1//===- CodeExtractor.cpp - Pull code region into a new function -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the interface to tear out a code region, such as an
10// individual loop or a parallel section, into a new function, replacing it with
11// a call to the new function.
12//
13//===----------------------------------------------------------------------===//
14
15#include "llvm/Transforms/Utils/CodeExtractor.h"
16#include "llvm/ADT/ArrayRef.h"
17#include "llvm/ADT/DenseMap.h"
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/ADT/SetVector.h"
20#include "llvm/ADT/SmallPtrSet.h"
21#include "llvm/ADT/SmallVector.h"
22#include "llvm/Analysis/AssumptionCache.h"
23#include "llvm/Analysis/BlockFrequencyInfo.h"
24#include "llvm/Analysis/BlockFrequencyInfoImpl.h"
25#include "llvm/Analysis/BranchProbabilityInfo.h"
26#include "llvm/IR/Argument.h"
27#include "llvm/IR/Attributes.h"
28#include "llvm/IR/BasicBlock.h"
29#include "llvm/IR/CFG.h"
30#include "llvm/IR/Constant.h"
31#include "llvm/IR/Constants.h"
32#include "llvm/IR/DIBuilder.h"
33#include "llvm/IR/DataLayout.h"
34#include "llvm/IR/DebugInfo.h"
35#include "llvm/IR/DebugInfoMetadata.h"
36#include "llvm/IR/DerivedTypes.h"
37#include "llvm/IR/Dominators.h"
38#include "llvm/IR/Function.h"
39#include "llvm/IR/GlobalValue.h"
40#include "llvm/IR/InstIterator.h"
41#include "llvm/IR/InstrTypes.h"
42#include "llvm/IR/Instruction.h"
43#include "llvm/IR/Instructions.h"
44#include "llvm/IR/IntrinsicInst.h"
45#include "llvm/IR/Intrinsics.h"
46#include "llvm/IR/LLVMContext.h"
47#include "llvm/IR/MDBuilder.h"
48#include "llvm/IR/Module.h"
49#include "llvm/IR/PatternMatch.h"
50#include "llvm/IR/Type.h"
51#include "llvm/IR/User.h"
52#include "llvm/IR/Value.h"
53#include "llvm/IR/Verifier.h"
54#include "llvm/Support/BlockFrequency.h"
55#include "llvm/Support/BranchProbability.h"
56#include "llvm/Support/Casting.h"
57#include "llvm/Support/CommandLine.h"
58#include "llvm/Support/Debug.h"
59#include "llvm/Support/ErrorHandling.h"
60#include "llvm/Support/raw_ostream.h"
61#include "llvm/Transforms/Utils/BasicBlockUtils.h"
62#include <cassert>
63#include <cstdint>
64#include <iterator>
65#include <map>
66#include <utility>
67#include <vector>
68
69using namespace llvm;
70using namespace llvm::PatternMatch;
71using ProfileCount = Function::ProfileCount;
72
73#define DEBUG_TYPE "code-extractor"
74
// Provide a command-line option to aggregate function arguments into a struct
// for functions produced by the code extractor. This is useful when converting
// extracted functions to pthread-based code, as only one argument (void*) can
// be passed in to pthread_create(). The option is OR'ed with the constructor's
// AggregateArgs flag (see CodeExtractor::CodeExtractor).
static cl::opt<bool>
AggregateArgsOpt("aggregate-extracted-args", cl::Hidden,
 cl::desc("Aggregate arguments to code-extracted functions"));
82
/// Test whether a block is valid for extraction.
///
/// \param BB            the block being checked.
/// \param Result        the full candidate region; EH-related successors and
///                      handlers must stay inside this set.
/// \param AllowVarArgs  if true, llvm.va_start is tolerated inside the region.
/// \param AllowAlloca   if true, alloca instructions are tolerated.
/// \returns true if \p BB can safely be moved into an extracted function.
static bool isBlockValidForExtraction(const BasicBlock &BB,
                                      const SetVector<BasicBlock *> &Result,
                                      bool AllowVarArgs, bool AllowAlloca) {
  // taking the address of a basic block moved to another function is illegal
  if (BB.hasAddressTaken())
    return false;

  // don't hoist code that uses another basicblock address, as it's likely to
  // lead to unexpected behavior, like cross-function jumps
  SmallPtrSet<User const *, 16> Visited;
  SmallVector<User const *, 16> ToVisit(llvm::make_pointer_range(Range: BB));

  // Depth-first walk of the operands reachable from this block's
  // instructions, looking for any BlockAddress constant.
  while (!ToVisit.empty()) {
    User const *Curr = ToVisit.pop_back_val();
    if (!Visited.insert(Ptr: Curr).second)
      continue;
    if (isa<BlockAddress const>(Val: Curr))
      return false; // even a reference to self is likely to be not compatible

    // Instructions from other blocks will be checked when those blocks are
    // visited; don't traverse through them here.
    if (isa<Instruction>(Val: Curr) && cast<Instruction>(Val: Curr)->getParent() != &BB)
      continue;

    for (auto const &U : Curr->operands()) {
      if (auto *UU = dyn_cast<User>(Val: U))
        ToVisit.push_back(Elt: UU);
    }
  }

  // If explicitly requested, allow vastart and alloca. For invoke instructions
  // verify that extraction is valid.
  for (BasicBlock::const_iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
    if (isa<AllocaInst>(Val: I)) {
      if (!AllowAlloca)
        return false;
      continue;
    }

    if (const auto *II = dyn_cast<InvokeInst>(Val&: I)) {
      // Unwind destination (either a landingpad, catchswitch, or cleanuppad)
      // must be a part of the subgraph which is being extracted.
      if (auto *UBB = II->getUnwindDest())
        if (!Result.count(key: UBB))
          return false;
      continue;
    }

    // All catch handlers of a catchswitch instruction as well as the unwind
    // destination must be in the subgraph.
    if (const auto *CSI = dyn_cast<CatchSwitchInst>(Val&: I)) {
      if (auto *UBB = CSI->getUnwindDest())
        if (!Result.count(key: UBB))
          return false;
      for (const auto *HBB : CSI->handlers())
        if (!Result.count(key: const_cast<BasicBlock*>(HBB)))
          return false;
      continue;
    }

    // Make sure that entire catch handler is within subgraph. It is sufficient
    // to check that catch return's block is in the list.
    if (const auto *CPI = dyn_cast<CatchPadInst>(Val&: I)) {
      for (const auto *U : CPI->users())
        if (const auto *CRI = dyn_cast<CatchReturnInst>(Val: U))
          if (!Result.count(key: const_cast<BasicBlock*>(CRI->getParent())))
            return false;
      continue;
    }

    // And do similar checks for cleanup handler - the entire handler must be
    // in subgraph which is going to be extracted. For cleanup return should
    // additionally check that the unwind destination is also in the subgraph.
    if (const auto *CPI = dyn_cast<CleanupPadInst>(Val&: I)) {
      for (const auto *U : CPI->users())
        if (const auto *CRI = dyn_cast<CleanupReturnInst>(Val: U))
          if (!Result.count(key: const_cast<BasicBlock*>(CRI->getParent())))
            return false;
      continue;
    }
    if (const auto *CRI = dyn_cast<CleanupReturnInst>(Val&: I)) {
      if (auto *UBB = CRI->getUnwindDest())
        if (!Result.count(key: UBB))
          return false;
      continue;
    }

    if (const CallInst *CI = dyn_cast<CallInst>(Val&: I)) {
      // musttail calls have several restrictions, generally enforcing matching
      // calling conventions between the caller parent and musttail callee.
      // We can't usually honor them, because the extracted function has a
      // different signature altogether, taking inputs/outputs and returning
      // a control-flow identifier rather than the actual return value.
      if (CI->isMustTailCall())
        return false;

      if (const Function *F = CI->getCalledFunction()) {
        auto IID = F->getIntrinsicID();
        if (IID == Intrinsic::vastart) {
          if (AllowVarArgs)
            continue;
          else
            return false;
        }

        // Currently, we miscompile outlined copies of eh_typid_for. There are
        // proposals for fixing this in llvm.org/PR39545.
        if (IID == Intrinsic::eh_typeid_for)
          return false;
      }
    }
  }

  return true;
}
197
/// Build a set of blocks to extract if the input blocks are viable.
///
/// \param BBs           candidate blocks; the first becomes the region entry
///                      and is the only block allowed predecessors outside
///                      the region.
/// \param DT            optional dominator tree, used to skip unreachable
///                      (dead) blocks.
/// \param AllowVarArgs  forwarded to isBlockValidForExtraction().
/// \param AllowAlloca   forwarded to isBlockValidForExtraction().
/// \returns the validated block set, or an empty set if any block makes
///          extraction impossible.
static SetVector<BasicBlock *>
buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
                        bool AllowVarArgs, bool AllowAlloca) {
  assert(!BBs.empty() && "The set of blocks to extract must be non-empty");
  SetVector<BasicBlock *> Result;

  // Loop over the blocks, adding them to our set-vector, and aborting with an
  // empty set if we encounter invalid blocks.
  for (BasicBlock *BB : BBs) {
    // If this block is dead, don't process it.
    if (DT && !DT->isReachableFromEntry(A: BB))
      continue;

    if (!Result.insert(X: BB))
      llvm_unreachable("Repeated basic blocks in extraction input");
  }

  LLVM_DEBUG(dbgs() << "Region front block: " << Result.front()->getName()
                    << '\n');

  for (auto *BB : Result) {
    if (!isBlockValidForExtraction(BB: *BB, Result, AllowVarArgs, AllowAlloca))
      return {};

    // Make sure that the first block is not a landing pad.
    if (BB == Result.front()) {
      if (BB->isEHPad()) {
        LLVM_DEBUG(dbgs() << "The first block cannot be an unwind block\n");
        return {};
      }
      continue;
    }

    // All blocks other than the first must not have predecessors outside of
    // the subgraph which is being extracted.
    for (auto *PBB : predecessors(BB))
      if (!Result.count(key: PBB)) {
        LLVM_DEBUG(dbgs() << "No blocks in this region may have entries from "
                             "outside the region except for the first block!\n"
                          << "Problematic source BB: " << BB->getName() << "\n"
                          << "Problematic destination BB: " << PBB->getName()
                          << "\n");
        return {};
      }
  }

  return Result;
}
247
248/// isAlignmentPreservedForAddrCast - Return true if the cast operation
249/// for specified target preserves original alignment
250static bool isAlignmentPreservedForAddrCast(const Triple &TargetTriple) {
251 switch (TargetTriple.getArch()) {
252 case Triple::ArchType::amdgcn:
253 case Triple::ArchType::r600:
254 return true;
255 // TODO: Add other architectures for which we are certain that alignment
256 // is preserved during address space cast operations.
257 default:
258 return false;
259 }
260 return false;
261}
262
/// Construct a CodeExtractor over the given blocks. The block list is
/// validated immediately via buildExtractionBlockSet(); if validation fails
/// the stored region is empty, which later makes isEligible() return false.
/// Note that argument aggregation can also be forced globally through the
/// -aggregate-extracted-args command-line flag.
CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
                             bool AggregateArgs, BlockFrequencyInfo *BFI,
                             BranchProbabilityInfo *BPI, AssumptionCache *AC,
                             bool AllowVarArgs, bool AllowAlloca,
                             BasicBlock *AllocationBlock, std::string Suffix,
                             bool ArgsInZeroAddressSpace)
    : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
      BPI(BPI), AC(AC), AllocationBlock(AllocationBlock),
      AllowVarArgs(AllowVarArgs),
      Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
      Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace) {}
274
275/// definedInRegion - Return true if the specified value is defined in the
276/// extracted region.
277static bool definedInRegion(const SetVector<BasicBlock *> &Blocks, Value *V) {
278 if (Instruction *I = dyn_cast<Instruction>(Val: V))
279 if (Blocks.count(key: I->getParent()))
280 return true;
281 return false;
282}
283
284/// definedInCaller - Return true if the specified value is defined in the
285/// function being code extracted, but not in the region being extracted.
286/// These values must be passed in as live-ins to the function.
287static bool definedInCaller(const SetVector<BasicBlock *> &Blocks, Value *V) {
288 if (isa<Argument>(Val: V)) return true;
289 if (Instruction *I = dyn_cast<Instruction>(Val: V))
290 if (!Blocks.count(key: I->getParent()))
291 return true;
292 return false;
293}
294
295static BasicBlock *getCommonExitBlock(const SetVector<BasicBlock *> &Blocks) {
296 BasicBlock *CommonExitBlock = nullptr;
297 auto hasNonCommonExitSucc = [&](BasicBlock *Block) {
298 for (auto *Succ : successors(BB: Block)) {
299 // Internal edges, ok.
300 if (Blocks.count(key: Succ))
301 continue;
302 if (!CommonExitBlock) {
303 CommonExitBlock = Succ;
304 continue;
305 }
306 if (CommonExitBlock != Succ)
307 return true;
308 }
309 return false;
310 };
311
312 if (any_of(Range: Blocks, P: hasNonCommonExitSucc))
313 return nullptr;
314
315 return CommonExitBlock;
316}
317
318CodeExtractorAnalysisCache::CodeExtractorAnalysisCache(Function &F) {
319 for (BasicBlock &BB : F) {
320 for (Instruction &II : BB.instructionsWithoutDebug())
321 if (auto *AI = dyn_cast<AllocaInst>(Val: &II))
322 Allocas.push_back(Elt: AI);
323
324 findSideEffectInfoForBlock(BB);
325 }
326}
327
/// Summarize the memory effects of a single block. Loads/stores whose base
/// is a known alloca are recorded in BaseMemAddrs; as soon as any access
/// through an unknown base or any other side-effecting instruction is seen,
/// the whole block is marked in SideEffectingBlocks and the scan stops.
void CodeExtractorAnalysisCache::findSideEffectInfoForBlock(BasicBlock &BB) {
  for (Instruction &II : BB.instructionsWithoutDebug()) {
    unsigned Opcode = II.getOpcode();
    Value *MemAddr = nullptr;
    switch (Opcode) {
    case Instruction::Store:
    case Instruction::Load: {
      if (Opcode == Instruction::Store) {
        StoreInst *SI = cast<StoreInst>(Val: &II);
        MemAddr = SI->getPointerOperand();
      } else {
        LoadInst *LI = cast<LoadInst>(Val: &II);
        MemAddr = LI->getPointerOperand();
      }
      // Global variable can not be aliased with locals.
      if (isa<Constant>(Val: MemAddr))
        break;
      Value *Base = MemAddr->stripInBoundsConstantOffsets();
      if (!isa<AllocaInst>(Val: Base)) {
        // Access through an unknown base: give up on this block entirely.
        SideEffectingBlocks.insert(V: &BB);
        return;
      }
      BaseMemAddrs[&BB].insert(V: Base);
      break;
    }
    default: {
      IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(Val: &II);
      if (IntrInst) {
        // Lifetime markers are harmless; any other intrinsic makes the block
        // conservatively side-effecting.
        if (IntrInst->isLifetimeStartOrEnd())
          break;
        SideEffectingBlocks.insert(V: &BB);
        return;
      }
      // Treat all the other cases conservatively if it has side effects.
      if (II.mayHaveSideEffects()) {
        SideEffectingBlocks.insert(V: &BB);
        return;
      }
    }
    }
  }
}
370
371bool CodeExtractorAnalysisCache::doesBlockContainClobberOfAddr(
372 BasicBlock &BB, AllocaInst *Addr) const {
373 if (SideEffectingBlocks.count(V: &BB))
374 return true;
375 auto It = BaseMemAddrs.find(Val: &BB);
376 if (It != BaseMemAddrs.end())
377 return It->second.count(V: Addr);
378 return false;
379}
380
381bool CodeExtractor::isLegalToShrinkwrapLifetimeMarkers(
382 const CodeExtractorAnalysisCache &CEAC, Instruction *Addr) const {
383 AllocaInst *AI = cast<AllocaInst>(Val: Addr->stripInBoundsConstantOffsets());
384 Function *Func = (*Blocks.begin())->getParent();
385 for (BasicBlock &BB : *Func) {
386 if (Blocks.count(key: &BB))
387 continue;
388 if (CEAC.doesBlockContainClobberOfAddr(BB, Addr: AI))
389 return false;
390 }
391 return true;
392}
393
/// Find (or create by splitting \p CommonExitBlock) a block suitable for
/// hoisting instructions into. If the region has exactly one block that
/// branches to \p CommonExitBlock, that block is returned. Otherwise
/// \p CommonExitBlock is split: external predecessors are redirected to the
/// bottom half, and the (now region-internal) top half is added to Blocks
/// and returned.
BasicBlock *
CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) {
  BasicBlock *SinglePredFromOutlineRegion = nullptr;
  assert(!Blocks.count(CommonExitBlock) &&
         "Expect a block outside the region!");
  // Look for a unique in-region predecessor of the common exit block.
  for (auto *Pred : predecessors(BB: CommonExitBlock)) {
    if (!Blocks.count(key: Pred))
      continue;
    if (!SinglePredFromOutlineRegion) {
      SinglePredFromOutlineRegion = Pred;
    } else if (SinglePredFromOutlineRegion != Pred) {
      SinglePredFromOutlineRegion = nullptr;
      break;
    }
  }

  if (SinglePredFromOutlineRegion)
    return SinglePredFromOutlineRegion;

#ifndef NDEBUG
  // Returns the first PHI of BB, if any. The loop always terminates on its
  // first iteration (both branches break), so the missing increment of I is
  // harmless.
  auto getFirstPHI = [](BasicBlock *BB) {
    BasicBlock::iterator I = BB->begin();
    PHINode *FirstPhi = nullptr;
    while (I != BB->end()) {
      PHINode *Phi = dyn_cast<PHINode>(I);
      if (!Phi)
        break;
      if (!FirstPhi) {
        FirstPhi = Phi;
        break;
      }
    }
    return FirstPhi;
  };
  // If there are any phi nodes, the single pred either exists or has already
  // be created before code extraction.
  assert(!getFirstPHI(CommonExitBlock) && "Phi not expected");
#endif

  BasicBlock *NewExitBlock =
      CommonExitBlock->splitBasicBlock(I: CommonExitBlock->getFirstNonPHIIt());

  // Redirect every out-of-region predecessor to the new (bottom) block,
  // leaving only region predecessors pointing at CommonExitBlock.
  for (BasicBlock *Pred :
       llvm::make_early_inc_range(Range: predecessors(BB: CommonExitBlock))) {
    if (Blocks.count(key: Pred))
      continue;
    Pred->getTerminator()->replaceUsesOfWith(From: CommonExitBlock, To: NewExitBlock);
  }
  // Now add the old exit block to the outline region.
  Blocks.insert(X: CommonExitBlock);
  return CommonExitBlock;
}
446
// Find the pair of life time markers for address 'Addr' that are either
// defined inside the outline region or can legally be shrinkwrapped into the
// outline region. If there are no other untracked uses of the address, return
// the pair of markers if found; otherwise return a pair of nullptr.
//
// On success the returned LifetimeMarkerInfo also records whether the start
// marker must be sunk into the region and/or the end marker hoisted past it.
CodeExtractor::LifetimeMarkerInfo
CodeExtractor::getLifetimeMarkers(const CodeExtractorAnalysisCache &CEAC,
                                  Instruction *Addr,
                                  BasicBlock *ExitBlock) const {
  LifetimeMarkerInfo Info;

  for (User *U : Addr->users()) {
    IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(Val: U);
    if (IntrInst) {
      // We don't model addresses with multiple start/end markers, but the
      // markers do not need to be in the region.
      if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start) {
        if (Info.LifeStart)
          return {};
        Info.LifeStart = IntrInst;
        continue;
      }
      if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_end) {
        if (Info.LifeEnd)
          return {};
        Info.LifeEnd = IntrInst;
        continue;
      }
    }
    // Find untracked uses of the address, bail.
    if (!definedInRegion(Blocks, V: U))
      return {};
  }

  // Both markers are required; a lone start or end cannot be shrinkwrapped.
  if (!Info.LifeStart || !Info.LifeEnd)
    return {};

  Info.SinkLifeStart = !definedInRegion(Blocks, V: Info.LifeStart);
  Info.HoistLifeEnd = !definedInRegion(Blocks, V: Info.LifeEnd);
  // Do legality check.
  if ((Info.SinkLifeStart || Info.HoistLifeEnd) &&
      !isLegalToShrinkwrapLifetimeMarkers(CEAC, Addr))
    return {};

  // Check to see if we have a place to do hoisting, if not, bail.
  if (Info.HoistLifeEnd && !ExitBlock)
    return {};

  return Info;
}
496
/// Collect allocas (and bitcasts of allocas) that can be sunk into the
/// outlined region, plus lifetime.end markers to hoist past it.
///
/// \param CEAC       analysis cache of allocas/side effects for the function.
/// \param SinkCands  out: values to sink into the extracted function.
/// \param HoistCands out: lifetime.end markers to hoist to \p ExitBlock.
/// \param ExitBlock  out: the region's common exit block (or nullptr).
void CodeExtractor::findAllocas(const CodeExtractorAnalysisCache &CEAC,
                                ValueSet &SinkCands, ValueSet &HoistCands,
                                BasicBlock *&ExitBlock) const {
  Function *Func = (*Blocks.begin())->getParent();
  ExitBlock = getCommonExitBlock(Blocks);

  // Records the sink/hoist decisions from a LifetimeMarkerInfo; returns
  // false when the info carries no markers at all.
  auto moveOrIgnoreLifetimeMarkers =
      [&](const LifetimeMarkerInfo &LMI) -> bool {
    if (!LMI.LifeStart)
      return false;
    if (LMI.SinkLifeStart) {
      LLVM_DEBUG(dbgs() << "Sinking lifetime.start: " << *LMI.LifeStart
                        << "\n");
      SinkCands.insert(X: LMI.LifeStart);
    }
    if (LMI.HoistLifeEnd) {
      LLVM_DEBUG(dbgs() << "Hoisting lifetime.end: " << *LMI.LifeEnd << "\n");
      HoistCands.insert(X: LMI.LifeEnd);
    }
    return true;
  };

  // Look up allocas in the original function in CodeExtractorAnalysisCache, as
  // this is much faster than walking all the instructions.
  for (AllocaInst *AI : CEAC.getAllocas()) {
    BasicBlock *BB = AI->getParent();
    if (Blocks.count(key: BB))
      continue;

    // As a prior call to extractCodeRegion() may have shrinkwrapped the alloca,
    // check whether it is actually still in the original function.
    Function *AIFunc = BB->getParent();
    if (AIFunc != Func)
      continue;

    LifetimeMarkerInfo MarkerInfo = getLifetimeMarkers(CEAC, Addr: AI, ExitBlock);
    bool Moved = moveOrIgnoreLifetimeMarkers(MarkerInfo);
    if (Moved) {
      LLVM_DEBUG(dbgs() << "Sinking alloca: " << *AI << "\n");
      SinkCands.insert(X: AI);
      continue;
    }

    // Find bitcasts in the outlined region that have lifetime marker users
    // outside that region. Replace the lifetime marker use with an
    // outside region bitcast to avoid unnecessary alloca/reload instructions
    // and extra lifetime markers.
    SmallVector<Instruction *, 2> LifetimeBitcastUsers;
    for (User *U : AI->users()) {
      if (!definedInRegion(Blocks, V: U))
        continue;

      if (U->stripInBoundsConstantOffsets() != AI)
        continue;

      Instruction *Bitcast = cast<Instruction>(Val: U);
      for (User *BU : Bitcast->users()) {
        auto *IntrInst = dyn_cast<LifetimeIntrinsic>(Val: BU);
        if (!IntrInst)
          continue;

        if (definedInRegion(Blocks, V: IntrInst))
          continue;

        LLVM_DEBUG(dbgs() << "Replace use of extracted region bitcast"
                          << *Bitcast << " in out-of-region lifetime marker "
                          << *IntrInst << "\n");
        LifetimeBitcastUsers.push_back(Elt: IntrInst);
      }
    }

    // Rewrite each such out-of-region lifetime marker to use a fresh cast of
    // the alloca inserted right before the marker itself.
    for (Instruction *I : LifetimeBitcastUsers) {
      Module *M = AIFunc->getParent();
      LLVMContext &Ctx = M->getContext();
      auto *Int8PtrTy = PointerType::getUnqual(C&: Ctx);
      CastInst *CastI =
          CastInst::CreatePointerCast(S: AI, Ty: Int8PtrTy, Name: "lt.cast", InsertBefore: I->getIterator());
      // Operand 1 of the lifetime intrinsic is the pointer argument.
      I->replaceUsesOfWith(From: I->getOperand(i: 1), To: CastI);
    }

    // Follow any bitcasts.
    SmallVector<Instruction *, 2> Bitcasts;
    SmallVector<LifetimeMarkerInfo, 2> BitcastLifetimeInfo;
    for (User *U : AI->users()) {
      if (U->stripInBoundsConstantOffsets() == AI) {
        Instruction *Bitcast = cast<Instruction>(Val: U);
        LifetimeMarkerInfo LMI = getLifetimeMarkers(CEAC, Addr: Bitcast, ExitBlock);
        if (LMI.LifeStart) {
          Bitcasts.push_back(Elt: Bitcast);
          BitcastLifetimeInfo.push_back(Elt: LMI);
          continue;
        }
      }

      // Found unknown use of AI.
      if (!definedInRegion(Blocks, V: U)) {
        Bitcasts.clear();
        break;
      }
    }

    // Either no bitcasts reference the alloca or there are unknown uses.
    if (Bitcasts.empty())
      continue;

    LLVM_DEBUG(dbgs() << "Sinking alloca (via bitcast): " << *AI << "\n");
    SinkCands.insert(X: AI);
    for (unsigned I = 0, E = Bitcasts.size(); I != E; ++I) {
      Instruction *BitcastAddr = Bitcasts[I];
      const LifetimeMarkerInfo &LMI = BitcastLifetimeInfo[I];
      assert(LMI.LifeStart &&
             "Unsafe to sink bitcast without lifetime markers");
      moveOrIgnoreLifetimeMarkers(LMI);
      if (!definedInRegion(Blocks, V: BitcastAddr)) {
        LLVM_DEBUG(dbgs() << "Sinking bitcast-of-alloca: " << *BitcastAddr
                          << "\n");
        SinkCands.insert(X: BitcastAddr);
      }
    }
  }
}
618
/// Return true if the region stored in this extractor can actually be
/// outlined: the block set is non-empty, varargs handling (if allowed) is
/// fully contained in the region, and stacksave/stackrestore pairs do not
/// cross the region boundary.
bool CodeExtractor::isEligible() const {
  if (Blocks.empty())
    return false;
  BasicBlock *Header = *Blocks.begin();
  Function *F = Header->getParent();

  // For functions with varargs, check that varargs handling is only done in the
  // outlined function, i.e vastart and vaend are only used in outlined blocks.
  if (AllowVarArgs && F->getFunctionType()->isVarArg()) {
    auto containsVarArgIntrinsic = [](const Instruction &I) {
      if (const CallInst *CI = dyn_cast<CallInst>(Val: &I))
        if (const Function *Callee = CI->getCalledFunction())
          return Callee->getIntrinsicID() == Intrinsic::vastart ||
                 Callee->getIntrinsicID() == Intrinsic::vaend;
      return false;
    };

    // Scan every block *outside* the region for va_start/va_end.
    for (auto &BB : *F) {
      if (Blocks.count(key: &BB))
        continue;
      if (llvm::any_of(Range&: BB, P: containsVarArgIntrinsic))
        return false;
    }
  }
  // stacksave as input implies stackrestore in the outlined function.
  // This can confuse prolog epilog insertion phase.
  // stacksave's uses must not cross outlined function.
  for (BasicBlock *BB : Blocks) {
    for (Instruction &I : *BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: &I);
      if (!II)
        continue;
      bool IsSave = II->getIntrinsicID() == Intrinsic::stacksave;
      bool IsRestore = II->getIntrinsicID() == Intrinsic::stackrestore;
      // A stacksave inside the region must not be used outside it.
      if (IsSave && any_of(Range: II->users(), P: [&Blks = this->Blocks](User *U) {
            return !definedInRegion(Blocks: Blks, V: U);
          }))
        return false;
      // A stackrestore inside the region must restore a save from inside it.
      if (IsRestore && !definedInRegion(Blocks, V: II->getArgOperand(i: 0)))
        return false;
    }
  }
  return true;
}
663
664void CodeExtractor::findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs,
665 const ValueSet &SinkCands,
666 bool CollectGlobalInputs) const {
667 for (BasicBlock *BB : Blocks) {
668 // If a used value is defined outside the region, it's an input. If an
669 // instruction is used outside the region, it's an output.
670 for (Instruction &II : *BB) {
671 for (auto &OI : II.operands()) {
672 Value *V = OI;
673 if (!SinkCands.count(key: V) &&
674 (definedInCaller(Blocks, V) ||
675 (CollectGlobalInputs && llvm::isa<llvm::GlobalVariable>(Val: V))))
676 Inputs.insert(X: V);
677 }
678
679 for (User *U : II.users())
680 if (!definedInRegion(Blocks, V: U)) {
681 Outputs.insert(X: &II);
682 break;
683 }
684 }
685 }
686}
687
/// severSplitPHINodesOfEntry - If a PHI node has multiple inputs from outside
/// of the region, we need to split the entry block of the region so that the
/// PHI node is easier to deal with. After the split, \p Header is updated to
/// point at the new (extractable) half; PHIs merging external values stay in
/// the old half, and new PHIs are built in the new header for values coming
/// from inside the region.
void CodeExtractor::severSplitPHINodesOfEntry(BasicBlock *&Header) {
  unsigned NumPredsFromRegion = 0;
  unsigned NumPredsOutsideRegion = 0;

  if (Header != &Header->getParent()->getEntryBlock()) {
    PHINode *PN = dyn_cast<PHINode>(Val: Header->begin());
    if (!PN) return; // No PHI nodes.

    // If the header node contains any PHI nodes, check to see if there is more
    // than one entry from outside the region. If so, we need to sever the
    // header block into two.
    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
      if (Blocks.count(key: PN->getIncomingBlock(i)))
        ++NumPredsFromRegion;
      else
        ++NumPredsOutsideRegion;

    // If there is one (or fewer) predecessor from outside the region, we don't
    // need to do anything special.
    if (NumPredsOutsideRegion <= 1) return;
  }

  // Otherwise, we need to split the header block into two pieces: one
  // containing PHI nodes merging values from outside of the region, and a
  // second that contains all of the code for the block and merges back any
  // incoming values from inside of the region.
  BasicBlock *NewBB = SplitBlock(Old: Header, SplitPt: Header->getFirstNonPHIIt(), DT);

  // We only want to code extract the second block now, and it becomes the new
  // header of the region.
  BasicBlock *OldPred = Header;
  Blocks.remove(X: OldPred);
  Blocks.insert(X: NewBB);
  Header = NewBB;

  // Okay, now we need to adjust the PHI nodes and any branches from within the
  // region to go to the new header block instead of the old header block.
  if (NumPredsFromRegion) {
    PHINode *PN = cast<PHINode>(Val: OldPred->begin());
    // Loop over all of the predecessors of OldPred that are in the region,
    // changing them to branch to NewBB instead.
    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
      if (Blocks.count(key: PN->getIncomingBlock(i))) {
        Instruction *TI = PN->getIncomingBlock(i)->getTerminator();
        TI->replaceUsesOfWith(From: OldPred, To: NewBB);
      }

    // Okay, everything within the region is now branching to the right block, we
    // just have to update the PHI nodes now, inserting PHI nodes into NewBB.
    BasicBlock::iterator AfterPHIs;
    for (AfterPHIs = OldPred->begin(); isa<PHINode>(Val: AfterPHIs); ++AfterPHIs) {
      PHINode *PN = cast<PHINode>(Val&: AfterPHIs);
      // Create a new PHI node in the new region, which has an incoming value
      // from OldPred of PN.
      PHINode *NewPN = PHINode::Create(Ty: PN->getType(), NumReservedValues: 1 + NumPredsFromRegion,
                                       NameStr: PN->getName() + ".ce");
      NewPN->insertBefore(InsertPos: NewBB->begin());
      PN->replaceAllUsesWith(V: NewPN);
      NewPN->addIncoming(V: PN, BB: OldPred);

      // Loop over all of the incoming value in PN, moving them to NewPN if they
      // are from the extracted region.
      // Note: i is re-checked after each removal since removeIncomingValue
      // shifts later entries down.
      for (unsigned i = 0; i != PN->getNumIncomingValues(); ++i) {
        if (Blocks.count(key: PN->getIncomingBlock(i))) {
          NewPN->addIncoming(V: PN->getIncomingValue(i), BB: PN->getIncomingBlock(i));
          PN->removeIncomingValue(Idx: i);
          --i;
        }
      }
    }
  }
}
763
/// severSplitPHINodesOfExits - if PHI nodes in exit blocks have inputs from
/// outlined region, we split these PHIs on two: one with inputs from region
/// and other with remaining incoming blocks; then first PHIs are placed in
/// outlined region. The new ".split" block created per exit block is added
/// to Blocks so it is outlined together with the region.
void CodeExtractor::severSplitPHINodesOfExits() {
  for (BasicBlock *ExitBB : ExtractedFuncRetVals) {
    BasicBlock *NewBB = nullptr;

    for (PHINode &PN : ExitBB->phis()) {
      // Find all incoming values from the outlining region.
      SmallVector<unsigned, 2> IncomingVals;
      for (unsigned i = 0; i < PN.getNumIncomingValues(); ++i)
        if (Blocks.count(key: PN.getIncomingBlock(i)))
          IncomingVals.push_back(Elt: i);

      // Do not process PHI if there is one (or fewer) predecessor from region.
      // If PHI has exactly one predecessor from region, only this one incoming
      // will be replaced on codeRepl block, so it should be safe to skip PHI.
      if (IncomingVals.size() <= 1)
        continue;

      // Create block for new PHIs and add it to the list of outlined if it
      // wasn't done before.
      if (!NewBB) {
        NewBB = BasicBlock::Create(Context&: ExitBB->getContext(),
                                   Name: ExitBB->getName() + ".split",
                                   Parent: ExitBB->getParent(), InsertBefore: ExitBB);
        // Redirect region predecessors of ExitBB through the new block.
        SmallVector<BasicBlock *, 4> Preds(predecessors(BB: ExitBB));
        for (BasicBlock *PredBB : Preds)
          if (Blocks.count(key: PredBB))
            PredBB->getTerminator()->replaceUsesOfWith(From: ExitBB, To: NewBB);
        BranchInst::Create(IfTrue: ExitBB, InsertBefore: NewBB);
        Blocks.insert(X: NewBB);
      }

      // Split this PHI: region-sourced values move into NewPN, which feeds
      // the original PHI through the NewBB edge.
      PHINode *NewPN = PHINode::Create(Ty: PN.getType(), NumReservedValues: IncomingVals.size(),
                                       NameStr: PN.getName() + ".ce");
      NewPN->insertBefore(InsertPos: NewBB->getFirstNonPHIIt());
      for (unsigned i : IncomingVals)
        NewPN->addIncoming(V: PN.getIncomingValue(i), BB: PN.getIncomingBlock(i));
      // Remove in reverse order so earlier indices stay valid.
      for (unsigned i : reverse(C&: IncomingVals))
        PN.removeIncomingValue(Idx: i, DeletePHIIfEmpty: false);
      PN.addIncoming(V: NewPN, BB: NewBB);
    }
  }
}
811
812void CodeExtractor::splitReturnBlocks() {
813 for (BasicBlock *Block : Blocks)
814 if (ReturnInst *RI = dyn_cast<ReturnInst>(Val: Block->getTerminator())) {
815 BasicBlock *New =
816 Block->splitBasicBlock(I: RI->getIterator(), BBName: Block->getName() + ".ret");
817 if (DT) {
818 // Old dominates New. New node dominates all other nodes dominated
819 // by Old.
820 DomTreeNode *OldNode = DT->getNode(BB: Block);
821 SmallVector<DomTreeNode *, 8> Children(OldNode->begin(),
822 OldNode->end());
823
824 DomTreeNode *NewNode = DT->addNewBlock(BB: New, DomBB: Block);
825
826 for (DomTreeNode *I : Children)
827 DT->changeImmediateDominator(N: I, NewIDom: NewNode);
828 }
829 }
830}
831
832Function *CodeExtractor::constructFunctionDeclaration(
833 const ValueSet &inputs, const ValueSet &outputs, BlockFrequency EntryFreq,
834 const Twine &Name, ValueSet &StructValues, StructType *&StructTy) {
835 LLVM_DEBUG(dbgs() << "inputs: " << inputs.size() << "\n");
836 LLVM_DEBUG(dbgs() << "outputs: " << outputs.size() << "\n");
837
838 Function *oldFunction = Blocks.front()->getParent();
839 Module *M = Blocks.front()->getModule();
840
841 // Assemble the function's parameter lists.
842 std::vector<Type *> ParamTy;
843 std::vector<Type *> AggParamTy;
844 const DataLayout &DL = M->getDataLayout();
845
846 // Add the types of the input values to the function's argument list
847 for (Value *value : inputs) {
848 LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n");
849 if (AggregateArgs && !ExcludeArgsFromAggregate.contains(key: value)) {
850 AggParamTy.push_back(x: value->getType());
851 StructValues.insert(X: value);
852 } else
853 ParamTy.push_back(x: value->getType());
854 }
855
856 // Add the types of the output values to the function's argument list.
857 for (Value *output : outputs) {
858 LLVM_DEBUG(dbgs() << "instr used in func: " << *output << "\n");
859 if (AggregateArgs && !ExcludeArgsFromAggregate.contains(key: output)) {
860 AggParamTy.push_back(x: output->getType());
861 StructValues.insert(X: output);
862 } else
863 ParamTy.push_back(
864 x: PointerType::get(C&: output->getContext(), AddressSpace: DL.getAllocaAddrSpace()));
865 }
866
867 assert(
868 (ParamTy.size() + AggParamTy.size()) ==
869 (inputs.size() + outputs.size()) &&
870 "Number of scalar and aggregate params does not match inputs, outputs");
871 assert((StructValues.empty() || AggregateArgs) &&
872 "Expeced StructValues only with AggregateArgs set");
873
874 // Concatenate scalar and aggregate params in ParamTy.
875 if (!AggParamTy.empty()) {
876 StructTy = StructType::get(Context&: M->getContext(), Elements: AggParamTy);
877 ParamTy.push_back(x: PointerType::get(
878 C&: M->getContext(), AddressSpace: ArgsInZeroAddressSpace ? 0 : DL.getAllocaAddrSpace()));
879 }
880
881 Type *RetTy = getSwitchType();
882 LLVM_DEBUG({
883 dbgs() << "Function type: " << *RetTy << " f(";
884 for (Type *i : ParamTy)
885 dbgs() << *i << ", ";
886 dbgs() << ")\n";
887 });
888
889 FunctionType *funcType = FunctionType::get(
890 Result: RetTy, Params: ParamTy, isVarArg: AllowVarArgs && oldFunction->isVarArg());
891
892 // Create the new function
893 Function *newFunction =
894 Function::Create(Ty: funcType, Linkage: GlobalValue::InternalLinkage,
895 AddrSpace: oldFunction->getAddressSpace(), N: Name, M);
896
897 // Propagate personality info to the new function if there is one.
898 if (oldFunction->hasPersonalityFn())
899 newFunction->setPersonalityFn(oldFunction->getPersonalityFn());
900
901 // Inherit all of the target dependent attributes and white-listed
902 // target independent attributes.
903 // (e.g. If the extracted region contains a call to an x86.sse
904 // instruction we need to make sure that the extracted region has the
905 // "target-features" attribute allowing it to be lowered.
906 // FIXME: This should be changed to check to see if a specific
907 // attribute can not be inherited.
908 for (const auto &Attr : oldFunction->getAttributes().getFnAttrs()) {
909 if (Attr.isStringAttribute()) {
910 if (Attr.getKindAsString() == "thunk")
911 continue;
912 } else
913 switch (Attr.getKindAsEnum()) {
914 // Those attributes cannot be propagated safely. Explicitly list them
915 // here so we get a warning if new attributes are added.
916 case Attribute::AllocSize:
917 case Attribute::Builtin:
918 case Attribute::Convergent:
919 case Attribute::JumpTable:
920 case Attribute::Naked:
921 case Attribute::NoBuiltin:
922 case Attribute::NoMerge:
923 case Attribute::NoReturn:
924 case Attribute::NoSync:
925 case Attribute::ReturnsTwice:
926 case Attribute::Speculatable:
927 case Attribute::StackAlignment:
928 case Attribute::WillReturn:
929 case Attribute::AllocKind:
930 case Attribute::PresplitCoroutine:
931 case Attribute::Memory:
932 case Attribute::NoFPClass:
933 case Attribute::CoroDestroyOnlyWhenComplete:
934 case Attribute::CoroElideSafe:
935 case Attribute::NoDivergenceSource:
936 continue;
937 // Those attributes should be safe to propagate to the extracted function.
938 case Attribute::AlwaysInline:
939 case Attribute::Cold:
940 case Attribute::DisableSanitizerInstrumentation:
941 case Attribute::FnRetThunkExtern:
942 case Attribute::Hot:
943 case Attribute::HybridPatchable:
944 case Attribute::NoRecurse:
945 case Attribute::InlineHint:
946 case Attribute::MinSize:
947 case Attribute::NoCallback:
948 case Attribute::NoDuplicate:
949 case Attribute::NoFree:
950 case Attribute::NoImplicitFloat:
951 case Attribute::NoInline:
952 case Attribute::NonLazyBind:
953 case Attribute::NoRedZone:
954 case Attribute::NoUnwind:
955 case Attribute::NoSanitizeBounds:
956 case Attribute::NoSanitizeCoverage:
957 case Attribute::NullPointerIsValid:
958 case Attribute::OptimizeForDebugging:
959 case Attribute::OptForFuzzing:
960 case Attribute::OptimizeNone:
961 case Attribute::OptimizeForSize:
962 case Attribute::SafeStack:
963 case Attribute::ShadowCallStack:
964 case Attribute::SanitizeAddress:
965 case Attribute::SanitizeMemory:
966 case Attribute::SanitizeNumericalStability:
967 case Attribute::SanitizeThread:
968 case Attribute::SanitizeType:
969 case Attribute::SanitizeHWAddress:
970 case Attribute::SanitizeMemTag:
971 case Attribute::SanitizeRealtime:
972 case Attribute::SanitizeRealtimeBlocking:
973 case Attribute::SpeculativeLoadHardening:
974 case Attribute::StackProtect:
975 case Attribute::StackProtectReq:
976 case Attribute::StackProtectStrong:
977 case Attribute::StrictFP:
978 case Attribute::UWTable:
979 case Attribute::VScaleRange:
980 case Attribute::NoCfCheck:
981 case Attribute::MustProgress:
982 case Attribute::NoProfile:
983 case Attribute::SkipProfile:
984 break;
985 // These attributes cannot be applied to functions.
986 case Attribute::Alignment:
987 case Attribute::AllocatedPointer:
988 case Attribute::AllocAlign:
989 case Attribute::ByVal:
990 case Attribute::Captures:
991 case Attribute::Dereferenceable:
992 case Attribute::DereferenceableOrNull:
993 case Attribute::ElementType:
994 case Attribute::InAlloca:
995 case Attribute::InReg:
996 case Attribute::Nest:
997 case Attribute::NoAlias:
998 case Attribute::NoUndef:
999 case Attribute::NonNull:
1000 case Attribute::Preallocated:
1001 case Attribute::ReadNone:
1002 case Attribute::ReadOnly:
1003 case Attribute::Returned:
1004 case Attribute::SExt:
1005 case Attribute::StructRet:
1006 case Attribute::SwiftError:
1007 case Attribute::SwiftSelf:
1008 case Attribute::SwiftAsync:
1009 case Attribute::ZExt:
1010 case Attribute::ImmArg:
1011 case Attribute::ByRef:
1012 case Attribute::WriteOnly:
1013 case Attribute::Writable:
1014 case Attribute::DeadOnUnwind:
1015 case Attribute::Range:
1016 case Attribute::Initializes:
1017 case Attribute::NoExt:
1018 // These are not really attributes.
1019 case Attribute::None:
1020 case Attribute::EndAttrKinds:
1021 case Attribute::EmptyKey:
1022 case Attribute::TombstoneKey:
1023 case Attribute::DeadOnReturn:
1024 llvm_unreachable("Not a function attribute");
1025 }
1026
1027 newFunction->addFnAttr(Attr);
1028 }
1029
1030 // Create scalar and aggregate iterators to name all of the arguments we
1031 // inserted.
1032 Function::arg_iterator ScalarAI = newFunction->arg_begin();
1033
1034 // Set names and attributes for input and output arguments.
1035 ScalarAI = newFunction->arg_begin();
1036 for (Value *input : inputs) {
1037 if (StructValues.contains(key: input))
1038 continue;
1039
1040 ScalarAI->setName(input->getName());
1041 if (input->isSwiftError())
1042 newFunction->addParamAttr(ArgNo: ScalarAI - newFunction->arg_begin(),
1043 Kind: Attribute::SwiftError);
1044 ++ScalarAI;
1045 }
1046 for (Value *output : outputs) {
1047 if (StructValues.contains(key: output))
1048 continue;
1049
1050 ScalarAI->setName(output->getName() + ".out");
1051 ++ScalarAI;
1052 }
1053
1054 // Update the entry count of the function.
1055 if (BFI) {
1056 auto Count = BFI->getProfileCountFromFreq(Freq: EntryFreq);
1057 if (Count.has_value())
1058 newFunction->setEntryCount(
1059 Count: ProfileCount(*Count, Function::PCT_Real)); // FIXME
1060 }
1061
1062 return newFunction;
1063}
1064
1065/// If the original function has debug info, we have to add a debug location
1066/// to the new branch instruction from the artificial entry block.
1067/// We use the debug location of the first instruction in the extracted
1068/// blocks, as there is no other equivalent line in the source code.
1069static void applyFirstDebugLoc(Function *oldFunction,
1070 ArrayRef<BasicBlock *> Blocks,
1071 Instruction *BranchI) {
1072 if (oldFunction->getSubprogram()) {
1073 any_of(Range&: Blocks, P: [&BranchI](const BasicBlock *BB) {
1074 return any_of(Range: *BB, P: [&BranchI](const Instruction &I) {
1075 if (!I.getDebugLoc())
1076 return false;
1077 BranchI->setDebugLoc(I.getDebugLoc());
1078 return true;
1079 });
1080 });
1081 }
1082}
1083
1084/// Erase lifetime.start markers which reference inputs to the extraction
1085/// region, and insert the referenced memory into \p LifetimesStart.
1086///
1087/// The extraction region is defined by a set of blocks (\p Blocks), and a set
1088/// of allocas which will be moved from the caller function into the extracted
1089/// function (\p SunkAllocas).
1090static void eraseLifetimeMarkersOnInputs(const SetVector<BasicBlock *> &Blocks,
1091 const SetVector<Value *> &SunkAllocas,
1092 SetVector<Value *> &LifetimesStart) {
1093 for (BasicBlock *BB : Blocks) {
1094 for (Instruction &I : llvm::make_early_inc_range(Range&: *BB)) {
1095 auto *II = dyn_cast<LifetimeIntrinsic>(Val: &I);
1096 if (!II)
1097 continue;
1098
1099 // Get the memory operand of the lifetime marker. If the underlying
1100 // object is a sunk alloca, or is otherwise defined in the extraction
1101 // region, the lifetime marker must not be erased.
1102 Value *Mem = II->getOperand(i_nocapture: 1)->stripInBoundsOffsets();
1103 if (SunkAllocas.count(key: Mem) || definedInRegion(Blocks, V: Mem))
1104 continue;
1105
1106 if (II->getIntrinsicID() == Intrinsic::lifetime_start)
1107 LifetimesStart.insert(X: Mem);
1108 II->eraseFromParent();
1109 }
1110 }
1111}
1112
1113/// Insert lifetime start/end markers surrounding the call to the new function
1114/// for objects defined in the caller.
1115static void insertLifetimeMarkersSurroundingCall(
1116 Module *M, ArrayRef<Value *> LifetimesStart, ArrayRef<Value *> LifetimesEnd,
1117 CallInst *TheCall) {
1118 LLVMContext &Ctx = M->getContext();
1119 auto NegativeOne = ConstantInt::getSigned(Ty: Type::getInt64Ty(C&: Ctx), V: -1);
1120 Instruction *Term = TheCall->getParent()->getTerminator();
1121
1122 // Emit lifetime markers for the pointers given in \p Objects. Insert the
1123 // markers before the call if \p InsertBefore, and after the call otherwise.
1124 auto insertMarkers = [&](Intrinsic::ID MarkerFunc, ArrayRef<Value *> Objects,
1125 bool InsertBefore) {
1126 for (Value *Mem : Objects) {
1127 assert((!isa<Instruction>(Mem) || cast<Instruction>(Mem)->getFunction() ==
1128 TheCall->getFunction()) &&
1129 "Input memory not defined in original function");
1130
1131 Function *Func =
1132 Intrinsic::getOrInsertDeclaration(M, id: MarkerFunc, Tys: Mem->getType());
1133 auto Marker = CallInst::Create(Func, Args: {NegativeOne, Mem});
1134 if (InsertBefore)
1135 Marker->insertBefore(InsertPos: TheCall->getIterator());
1136 else
1137 Marker->insertBefore(InsertPos: Term->getIterator());
1138 }
1139 };
1140
1141 if (!LifetimesStart.empty()) {
1142 insertMarkers(Intrinsic::lifetime_start, LifetimesStart,
1143 /*InsertBefore=*/true);
1144 }
1145
1146 if (!LifetimesEnd.empty()) {
1147 insertMarkers(Intrinsic::lifetime_end, LifetimesEnd,
1148 /*InsertBefore=*/false);
1149 }
1150}
1151
1152void CodeExtractor::moveCodeToFunction(Function *newFunction) {
1153 auto newFuncIt = newFunction->begin();
1154 for (BasicBlock *Block : Blocks) {
1155 // Delete the basic block from the old function, and the list of blocks
1156 Block->removeFromParent();
1157
1158 // Insert this basic block into the new function
1159 // Insert the original blocks after the entry block created
1160 // for the new function. The entry block may be followed
1161 // by a set of exit blocks at this point, but these exit
1162 // blocks better be placed at the end of the new function.
1163 newFuncIt = newFunction->insert(Position: std::next(x: newFuncIt), BB: Block);
1164 }
1165}
1166
1167void CodeExtractor::calculateNewCallTerminatorWeights(
1168 BasicBlock *CodeReplacer,
1169 const DenseMap<BasicBlock *, BlockFrequency> &ExitWeights,
1170 BranchProbabilityInfo *BPI) {
1171 using Distribution = BlockFrequencyInfoImplBase::Distribution;
1172 using BlockNode = BlockFrequencyInfoImplBase::BlockNode;
1173
1174 // Update the branch weights for the exit block.
1175 Instruction *TI = CodeReplacer->getTerminator();
1176 SmallVector<unsigned, 8> BranchWeights(TI->getNumSuccessors(), 0);
1177
1178 // Block Frequency distribution with dummy node.
1179 Distribution BranchDist;
1180
1181 SmallVector<BranchProbability, 4> EdgeProbabilities(
1182 TI->getNumSuccessors(), BranchProbability::getUnknown());
1183
1184 // Add each of the frequencies of the successors.
1185 for (unsigned i = 0, e = TI->getNumSuccessors(); i < e; ++i) {
1186 BlockNode ExitNode(i);
1187 uint64_t ExitFreq = ExitWeights.lookup(Val: TI->getSuccessor(Idx: i)).getFrequency();
1188 if (ExitFreq != 0)
1189 BranchDist.addExit(Node: ExitNode, Amount: ExitFreq);
1190 else
1191 EdgeProbabilities[i] = BranchProbability::getZero();
1192 }
1193
1194 // Check for no total weight.
1195 if (BranchDist.Total == 0) {
1196 BPI->setEdgeProbability(Src: CodeReplacer, Probs: EdgeProbabilities);
1197 return;
1198 }
1199
1200 // Normalize the distribution so that they can fit in unsigned.
1201 BranchDist.normalize();
1202
1203 // Create normalized branch weights and set the metadata.
1204 for (unsigned I = 0, E = BranchDist.Weights.size(); I < E; ++I) {
1205 const auto &Weight = BranchDist.Weights[I];
1206
1207 // Get the weight and update the current BFI.
1208 BranchWeights[Weight.TargetNode.Index] = Weight.Amount;
1209 BranchProbability BP(Weight.Amount, BranchDist.Total);
1210 EdgeProbabilities[Weight.TargetNode.Index] = BP;
1211 }
1212 BPI->setEdgeProbability(Src: CodeReplacer, Probs: EdgeProbabilities);
1213 TI->setMetadata(
1214 KindID: LLVMContext::MD_prof,
1215 Node: MDBuilder(TI->getContext()).createBranchWeights(Weights: BranchWeights));
1216}
1217
1218/// Erase debug info intrinsics which refer to values in \p F but aren't in
1219/// \p F.
1220static void eraseDebugIntrinsicsWithNonLocalRefs(Function &F) {
1221 for (Instruction &I : instructions(F)) {
1222 SmallVector<DbgVariableIntrinsic *, 4> DbgUsers;
1223 SmallVector<DbgVariableRecord *, 4> DbgVariableRecords;
1224 findDbgUsers(DbgInsts&: DbgUsers, V: &I, DbgVariableRecords: &DbgVariableRecords);
1225 for (DbgVariableIntrinsic *DVI : DbgUsers)
1226 if (DVI->getFunction() != &F)
1227 DVI->eraseFromParent();
1228 for (DbgVariableRecord *DVR : DbgVariableRecords)
1229 if (DVR->getFunction() != &F)
1230 DVR->eraseFromParent();
1231 }
1232}
1233
/// Fix up the debug info in the old and new functions. Following changes are
/// done.
/// 1. If a debug record points to a value that has been replaced, update the
/// record to use the new value.
/// 2. If an Input value that has been replaced was used as a location of a
/// debug record in the Parent function, then materialize a similar record in
/// the new function.
/// 3. Point line locations and debug intrinsics to the new subprogram scope
/// 4. Remove intrinsics which point to values outside of the new function.
static void fixupDebugInfoPostExtraction(Function &OldFunc, Function &NewFunc,
                                         CallInst &TheCall,
                                         const SetVector<Value *> &Inputs,
                                         ArrayRef<Value *> NewValues) {
  DISubprogram *OldSP = OldFunc.getSubprogram();
  LLVMContext &Ctx = OldFunc.getContext();

  // With no subprogram on the old function there is nothing to migrate; just
  // make sure neither function carries stale debug info.
  if (!OldSP) {
    // Erase any debug info the new function contains.
    stripDebugInfo(F&: NewFunc);
    // Make sure the old function doesn't contain any non-local metadata refs.
    eraseDebugIntrinsicsWithNonLocalRefs(F&: NewFunc);
    return;
  }

  // Create a subprogram for the new function. Leave out a description of the
  // function arguments, as the parameters don't correspond to anything at the
  // source level.
  assert(OldSP->getUnit() && "Missing compile unit for subprogram");
  DIBuilder DIB(*OldFunc.getParent(), /*AllowUnresolved=*/false,
                OldSP->getUnit());
  auto SPType = DIB.createSubroutineType(ParameterTypes: DIB.getOrCreateTypeArray(Elements: {}));
  DISubprogram::DISPFlags SPFlags = DISubprogram::SPFlagDefinition |
                                    DISubprogram::SPFlagOptimized |
                                    DISubprogram::SPFlagLocalToUnit;
  auto NewSP = DIB.createFunction(
      Scope: OldSP->getUnit(), Name: NewFunc.getName(), LinkageName: NewFunc.getName(), File: OldSP->getFile(),
      /*LineNo=*/0, Ty: SPType, /*ScopeLine=*/0, Flags: DINode::FlagZero, SPFlags);
  NewFunc.setSubprogram(NewSP);

  // For a debug record that already lives in the new function, just retarget
  // its location operand to the replacement value. Otherwise (the record is
  // in the parent function) materialize an equivalent declare/value record in
  // the new function's entry block.
  auto UpdateOrInsertDebugRecord = [&](auto *DR, Value *OldLoc, Value *NewLoc,
                                       DIExpression *Expr, bool Declare) {
    if (DR->getParent()->getParent() == &NewFunc) {
      DR->replaceVariableLocationOp(OldLoc, NewLoc);
      return;
    }
    if (Declare) {
      DIB.insertDeclare(NewLoc, DR->getVariable(), Expr, DR->getDebugLoc(),
                        &NewFunc.getEntryBlock());
      return;
    }
    DIB.insertDbgValueIntrinsic(
        Val: NewLoc, VarInfo: DR->getVariable(), Expr, DL: DR->getDebugLoc(),
        InsertPt: NewFunc.getEntryBlock().getTerminator()->getIterator());
  };
  for (auto [Input, NewVal] : zip_equal(t: Inputs, u&: NewValues)) {
    SmallVector<DbgVariableIntrinsic *, 1> DbgUsers;
    SmallVector<DbgVariableRecord *, 1> DPUsers;
    findDbgUsers(DbgInsts&: DbgUsers, V: Input, DbgVariableRecords: &DPUsers);
    DIExpression *Expr = DIB.createExpression();

    // Iterate the debug users of the Input values. If they are in the extracted
    // function then update their location with the new value. If they are in
    // the parent function then create a similar debug record.
    for (auto *DVI : DbgUsers)
      UpdateOrInsertDebugRecord(DVI, Input, NewVal, Expr,
                                isa<DbgDeclareInst>(Val: DVI));
    for (auto *DVR : DPUsers)
      UpdateOrInsertDebugRecord(DVR, Input, NewVal, Expr, DVR->isDbgDeclare());
  }

  auto IsInvalidLocation = [&NewFunc](Value *Location) {
    // Location is invalid if it isn't a constant, an instruction or an
    // argument, or is an instruction/argument but isn't in the new function.
    if (!Location || (!isa<Constant>(Val: Location) && !isa<Argument>(Val: Location) &&
                      !isa<Instruction>(Val: Location)))
      return true;

    if (Argument *Arg = dyn_cast<Argument>(Val: Location))
      return Arg->getParent() != &NewFunc;
    if (Instruction *LocationInst = dyn_cast<Instruction>(Val: Location))
      return LocationInst->getFunction() != &NewFunc;
    return false;
  };

  // Debug intrinsics in the new function need to be updated in one of two
  // ways:
  //  1) They need to be deleted, because they describe a value in the old
  //     function.
  //  2) They need to point to fresh metadata, e.g. because they currently
  //     point to a variable in the wrong scope.
  SmallDenseMap<DINode *, DINode *> RemappedMetadata;
  SmallVector<DbgVariableRecord *, 4> DVRsToDelete;
  // Cache for cloned scope chains, shared by all remapping below.
  DenseMap<const MDNode *, MDNode *> Cache;

  // Lazily create (and memoize) a variable in the new subprogram's scope
  // mirroring \p OldVar from the old function.
  auto GetUpdatedDIVariable = [&](DILocalVariable *OldVar) {
    DINode *&NewVar = RemappedMetadata[OldVar];
    if (!NewVar) {
      DILocalScope *NewScope = DILocalScope::cloneScopeForSubprogram(
          RootScope&: *OldVar->getScope(), NewSP&: *NewSP, Ctx, Cache);
      NewVar = DIB.createAutoVariable(
          Scope: NewScope, Name: OldVar->getName(), File: OldVar->getFile(), LineNo: OldVar->getLine(),
          Ty: OldVar->getType(), /*AlwaysPreserve=*/false, Flags: DINode::FlagZero,
          AlignInBits: OldVar->getAlignInBits());
    }
    return cast<DILocalVariable>(Val: NewVar);
  };

  auto UpdateDbgLabel = [&](auto *LabelRecord) {
    // Point the label record to a fresh label within the new function if
    // the record was not inlined from some other function.
    if (LabelRecord->getDebugLoc().getInlinedAt())
      return;
    DILabel *OldLabel = LabelRecord->getLabel();
    DINode *&NewLabel = RemappedMetadata[OldLabel];
    if (!NewLabel) {
      DILocalScope *NewScope = DILocalScope::cloneScopeForSubprogram(
          RootScope&: *OldLabel->getScope(), NewSP&: *NewSP, Ctx, Cache);
      NewLabel =
          DILabel::get(Context&: Ctx, Scope: NewScope, Name: OldLabel->getName(), File: OldLabel->getFile(),
                       Line: OldLabel->getLine(), Column: OldLabel->getColumn(),
                       IsArtificial: OldLabel->isArtificial(), CoroSuspendIdx: OldLabel->getCoroSuspendIdx());
    }
    LabelRecord->setLabel(cast<DILabel>(Val: NewLabel));
  };

  // Classify each debug record attached to \p I: fix up labels, queue records
  // with out-of-function locations for deletion, and rescope variables.
  auto UpdateDbgRecordsOnInst = [&](Instruction &I) -> void {
    for (DbgRecord &DR : I.getDbgRecordRange()) {
      if (DbgLabelRecord *DLR = dyn_cast<DbgLabelRecord>(Val: &DR)) {
        UpdateDbgLabel(DLR);
        continue;
      }

      DbgVariableRecord &DVR = cast<DbgVariableRecord>(Val&: DR);
      // If any of the used locations are invalid, delete the record.
      if (any_of(Range: DVR.location_ops(), P: IsInvalidLocation)) {
        DVRsToDelete.push_back(Elt: &DVR);
        continue;
      }

      // DbgAssign intrinsics have an extra Value argument:
      if (DVR.isDbgAssign() && IsInvalidLocation(DVR.getAddress())) {
        DVRsToDelete.push_back(Elt: &DVR);
        continue;
      }

      // If the variable was in the scope of the old function, i.e. it was not
      // inlined, point the intrinsic to a fresh variable within the new
      // function.
      if (!DVR.getDebugLoc().getInlinedAt())
        DVR.setVariable(GetUpdatedDIVariable(DVR.getVariable()));
    }
  };

  for (Instruction &I : instructions(F&: NewFunc))
    UpdateDbgRecordsOnInst(I);

  // Deletion is deferred until after the scan so the record range iterated
  // above is not invalidated mid-loop.
  for (auto *DVR : DVRsToDelete)
    DVR->getMarker()->MarkedInstr->dropOneDbgRecord(I: DVR);
  DIB.finalizeSubprogram(SP: NewSP);

  // Fix up the scope information attached to the line locations and the
  // debug assignment metadata in the new function.
  DenseMap<DIAssignID *, DIAssignID *> AssignmentIDMap;
  for (Instruction &I : instructions(F&: NewFunc)) {
    if (const DebugLoc &DL = I.getDebugLoc())
      I.setDebugLoc(
          DebugLoc::replaceInlinedAtSubprogram(DL, NewSP&: *NewSP, Ctx, Cache));
    for (DbgRecord &DR : I.getDbgRecordRange())
      DR.setDebugLoc(DebugLoc::replaceInlinedAtSubprogram(DL: DR.getDebugLoc(),
                                                          NewSP&: *NewSP, Ctx, Cache));

    // Loop info metadata may contain line locations. Fix them up.
    auto updateLoopInfoLoc = [&Ctx, &Cache, NewSP](Metadata *MD) -> Metadata * {
      if (auto *Loc = dyn_cast_or_null<DILocation>(Val: MD))
        return DebugLoc::replaceInlinedAtSubprogram(DL: Loc, NewSP&: *NewSP, Ctx, Cache);
      return MD;
    };
    updateLoopMetadataDebugLocations(I, Updater: updateLoopInfoLoc);
    at::remapAssignID(Map&: AssignmentIDMap, I);
  }
  // Give the call site a (line 0) location in the old scope if it has none,
  // so inlining/debug consumers always see a valid location.
  if (!TheCall.getDebugLoc())
    TheCall.setDebugLoc(DILocation::get(Context&: Ctx, Line: 0, Column: 0, Scope: OldSP));

  eraseDebugIntrinsicsWithNonLocalRefs(F&: NewFunc);
}
1419
1420Function *
1421CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC) {
1422 ValueSet Inputs, Outputs;
1423 return extractCodeRegion(CEAC, Inputs, Outputs);
1424}
1425
1426Function *
1427CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
1428 ValueSet &inputs, ValueSet &outputs) {
1429 if (!isEligible())
1430 return nullptr;
1431
1432 // Assumption: this is a single-entry code region, and the header is the first
1433 // block in the region.
1434 BasicBlock *header = *Blocks.begin();
1435 Function *oldFunction = header->getParent();
1436
1437 normalizeCFGForExtraction(header);
1438
1439 // Remove @llvm.assume calls that will be moved to the new function from the
1440 // old function's assumption cache.
1441 for (BasicBlock *Block : Blocks) {
1442 for (Instruction &I : llvm::make_early_inc_range(Range&: *Block)) {
1443 if (auto *AI = dyn_cast<AssumeInst>(Val: &I)) {
1444 if (AC)
1445 AC->unregisterAssumption(CI: AI);
1446 AI->eraseFromParent();
1447 }
1448 }
1449 }
1450
1451 ValueSet SinkingCands, HoistingCands;
1452 BasicBlock *CommonExit = nullptr;
1453 findAllocas(CEAC, SinkCands&: SinkingCands, HoistCands&: HoistingCands, ExitBlock&: CommonExit);
1454 assert(HoistingCands.empty() || CommonExit);
1455
1456 // Find inputs to, outputs from the code region.
1457 findInputsOutputs(Inputs&: inputs, Outputs&: outputs, SinkCands: SinkingCands);
1458
1459 // Collect objects which are inputs to the extraction region and also
1460 // referenced by lifetime start markers within it. The effects of these
1461 // markers must be replicated in the calling function to prevent the stack
1462 // coloring pass from merging slots which store input objects.
1463 ValueSet LifetimesStart;
1464 eraseLifetimeMarkersOnInputs(Blocks, SunkAllocas: SinkingCands, LifetimesStart);
1465
1466 if (!HoistingCands.empty()) {
1467 auto *HoistToBlock = findOrCreateBlockForHoisting(CommonExitBlock: CommonExit);
1468 Instruction *TI = HoistToBlock->getTerminator();
1469 for (auto *II : HoistingCands)
1470 cast<Instruction>(Val: II)->moveBefore(InsertPos: TI->getIterator());
1471 computeExtractedFuncRetVals();
1472 }
1473
1474 // CFG/ExitBlocks must not change hereafter
1475
1476 // Calculate the entry frequency of the new function before we change the root
1477 // block.
1478 BlockFrequency EntryFreq;
1479 DenseMap<BasicBlock *, BlockFrequency> ExitWeights;
1480 if (BFI) {
1481 assert(BPI && "Both BPI and BFI are required to preserve profile info");
1482 for (BasicBlock *Pred : predecessors(BB: header)) {
1483 if (Blocks.count(key: Pred))
1484 continue;
1485 EntryFreq +=
1486 BFI->getBlockFreq(BB: Pred) * BPI->getEdgeProbability(Src: Pred, Dst: header);
1487 }
1488
1489 for (BasicBlock *Succ : ExtractedFuncRetVals) {
1490 for (BasicBlock *Block : predecessors(BB: Succ)) {
1491 if (!Blocks.count(key: Block))
1492 continue;
1493
1494 // Update the branch weight for this successor.
1495 BlockFrequency &BF = ExitWeights[Succ];
1496 BF += BFI->getBlockFreq(BB: Block) * BPI->getEdgeProbability(Src: Block, Dst: Succ);
1497 }
1498 }
1499 }
1500
1501 // Determine position for the replacement code. Do so before header is moved
1502 // to the new function.
1503 BasicBlock *ReplIP = header;
1504 while (ReplIP && Blocks.count(key: ReplIP))
1505 ReplIP = ReplIP->getNextNode();
1506
1507 // Construct new function based on inputs/outputs & add allocas for all defs.
1508 std::string SuffixToUse =
1509 Suffix.empty()
1510 ? (header->getName().empty() ? "extracted" : header->getName().str())
1511 : Suffix;
1512
1513 ValueSet StructValues;
1514 StructType *StructTy = nullptr;
1515 Function *newFunction = constructFunctionDeclaration(
1516 inputs, outputs, EntryFreq, Name: oldFunction->getName() + "." + SuffixToUse,
1517 StructValues, StructTy);
1518 SmallVector<Value *> NewValues;
1519
1520 emitFunctionBody(inputs, outputs, StructValues, newFunction, StructArgTy: StructTy, header,
1521 SinkingCands, NewValues);
1522
1523 std::vector<Value *> Reloads;
1524 CallInst *TheCall = emitReplacerCall(
1525 inputs, outputs, StructValues, newFunction, StructArgTy: StructTy, oldFunction, ReplIP,
1526 EntryFreq, LifetimesStart: LifetimesStart.getArrayRef(), Reloads);
1527
1528 insertReplacerCall(oldFunction, header, codeReplacer: TheCall->getParent(), outputs,
1529 Reloads, ExitWeights);
1530
1531 fixupDebugInfoPostExtraction(OldFunc&: *oldFunction, NewFunc&: *newFunction, TheCall&: *TheCall, Inputs: inputs,
1532 NewValues);
1533
1534 LLVM_DEBUG(llvm::dbgs() << "After extractCodeRegion - newFunction:\n");
1535 LLVM_DEBUG(newFunction->dump());
1536 LLVM_DEBUG(llvm::dbgs() << "After extractCodeRegion - oldFunction:\n");
1537 LLVM_DEBUG(oldFunction->dump());
1538 LLVM_DEBUG(if (AC && verifyAssumptionCache(*oldFunction, *newFunction, AC))
1539 report_fatal_error("Stale Asumption cache for old Function!"));
1540 return newFunction;
1541}
1542
/// Prepare the parent function's CFG so the region can be extracted cleanly:
/// returns are split out of the region, and PHI nodes at the region's entry
/// and exits are severed so that every value crossing the region boundary is
/// a single, well-defined input or output. May replace \p header with a newly
/// split block.
void CodeExtractor::normalizeCFGForExtraction(BasicBlock *&header) {
  // If we have any return instructions in the region, split those blocks so
  // that the return is not in the region.
  splitReturnBlocks();

  // If we have to split PHI nodes of the entry or exit blocks, do so now.
  severSplitPHINodesOfEntry(Header&: header);

  // If a PHI in an exit block has multiple incoming values from the outlined
  // region, create a new PHI for those values within the region such that only
  // PHI itself becomes an output value, not each of its incoming values
  // individually.
  // (severSplitPHINodesOfExits consumes the exit list computed here.)
  computeExtractedFuncRetVals();
  severSplitPHINodesOfExits();
}
1558
1559void CodeExtractor::computeExtractedFuncRetVals() {
1560 ExtractedFuncRetVals.clear();
1561
1562 SmallPtrSet<BasicBlock *, 2> ExitBlocks;
1563 for (BasicBlock *Block : Blocks) {
1564 for (BasicBlock *Succ : successors(BB: Block)) {
1565 if (Blocks.count(key: Succ))
1566 continue;
1567
1568 bool IsNew = ExitBlocks.insert(Ptr: Succ).second;
1569 if (IsNew)
1570 ExtractedFuncRetVals.push_back(Elt: Succ);
1571 }
1572 }
1573}
1574
1575Type *CodeExtractor::getSwitchType() {
1576 LLVMContext &Context = Blocks.front()->getContext();
1577
1578 assert(ExtractedFuncRetVals.size() < 0xffff &&
1579 "too many exit blocks for switch");
1580 switch (ExtractedFuncRetVals.size()) {
1581 case 0:
1582 case 1:
1583 return Type::getVoidTy(C&: Context);
1584 case 2:
1585 // Conditional branch, return a bool
1586 return Type::getInt1Ty(C&: Context);
1587 default:
1588 return Type::getInt16Ty(C&: Context);
1589 }
1590}
1591
1592void CodeExtractor::emitFunctionBody(
1593 const ValueSet &inputs, const ValueSet &outputs,
1594 const ValueSet &StructValues, Function *newFunction,
1595 StructType *StructArgTy, BasicBlock *header, const ValueSet &SinkingCands,
1596 SmallVectorImpl<Value *> &NewValues) {
1597 Function *oldFunction = header->getParent();
1598 LLVMContext &Context = oldFunction->getContext();
1599
1600 // The new function needs a root node because other nodes can branch to the
1601 // head of the region, but the entry node of a function cannot have preds.
1602 BasicBlock *newFuncRoot =
1603 BasicBlock::Create(Context, Name: "newFuncRoot", Parent: newFunction);
1604
1605 // Now sink all instructions which only have non-phi uses inside the region.
1606 // Group the allocas at the start of the block, so that any bitcast uses of
1607 // the allocas are well-defined.
1608 for (auto *II : SinkingCands) {
1609 if (!isa<AllocaInst>(Val: II)) {
1610 cast<Instruction>(Val: II)->moveBefore(BB&: *newFuncRoot,
1611 I: newFuncRoot->getFirstInsertionPt());
1612 }
1613 }
1614 for (auto *II : SinkingCands) {
1615 if (auto *AI = dyn_cast<AllocaInst>(Val: II)) {
1616 AI->moveBefore(BB&: *newFuncRoot, I: newFuncRoot->getFirstInsertionPt());
1617 }
1618 }
1619
1620 Function::arg_iterator ScalarAI = newFunction->arg_begin();
1621 Argument *AggArg = StructValues.empty()
1622 ? nullptr
1623 : newFunction->getArg(i: newFunction->arg_size() - 1);
1624
1625 // Rewrite all users of the inputs in the extracted region to use the
1626 // arguments (or appropriate addressing into struct) instead.
1627 for (unsigned i = 0, e = inputs.size(), aggIdx = 0; i != e; ++i) {
1628 Value *RewriteVal;
1629 if (StructValues.contains(key: inputs[i])) {
1630 Value *Idx[2];
1631 Idx[0] = Constant::getNullValue(Ty: Type::getInt32Ty(C&: header->getContext()));
1632 Idx[1] = ConstantInt::get(Ty: Type::getInt32Ty(C&: header->getContext()), V: aggIdx);
1633 GetElementPtrInst *GEP = GetElementPtrInst::Create(
1634 PointeeType: StructArgTy, Ptr: AggArg, IdxList: Idx, NameStr: "gep_" + inputs[i]->getName(), InsertBefore: newFuncRoot);
1635 LoadInst *LoadGEP =
1636 new LoadInst(StructArgTy->getElementType(N: aggIdx), GEP,
1637 "loadgep_" + inputs[i]->getName(), newFuncRoot);
1638 // If we load pointer, we can add optional !align metadata
1639 // The existence of the !align metadata on the instruction tells
1640 // the optimizer that the value loaded is known to be aligned to
1641 // a boundary specified by the integer value in the metadata node.
1642 // Example:
1643 // %res = load ptr, ptr %input, align 8, !align !align_md_node
1644 // ^ ^
1645 // | |
1646 // alignment of %input address |
1647 // |
1648 // alignment of %res object
1649 if (StructArgTy->getElementType(N: aggIdx)->isPointerTy()) {
1650 unsigned AlignmentValue;
1651 const Triple &TargetTriple =
1652 newFunction->getParent()->getTargetTriple();
1653 const DataLayout &DL = header->getDataLayout();
1654 // Pointers without casting can provide more information about
1655 // alignment. Use pointers without casts if given target preserves
1656 // alignment information for cast the operation.
1657 if (isAlignmentPreservedForAddrCast(TargetTriple))
1658 AlignmentValue =
1659 inputs[i]->stripPointerCasts()->getPointerAlignment(DL).value();
1660 else
1661 AlignmentValue = inputs[i]->getPointerAlignment(DL).value();
1662 MDBuilder MDB(header->getContext());
1663 LoadGEP->setMetadata(
1664 KindID: LLVMContext::MD_align,
1665 Node: MDNode::get(
1666 Context&: header->getContext(),
1667 MDs: MDB.createConstant(C: ConstantInt::get(
1668 Ty: Type::getInt64Ty(C&: header->getContext()), V: AlignmentValue))));
1669 }
1670 RewriteVal = LoadGEP;
1671 ++aggIdx;
1672 } else
1673 RewriteVal = &*ScalarAI++;
1674
1675 NewValues.push_back(Elt: RewriteVal);
1676 }
1677
1678 moveCodeToFunction(newFunction);
1679
1680 for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
1681 Value *RewriteVal = NewValues[i];
1682
1683 std::vector<User *> Users(inputs[i]->user_begin(), inputs[i]->user_end());
1684 for (User *use : Users)
1685 if (Instruction *inst = dyn_cast<Instruction>(Val: use))
1686 if (Blocks.count(key: inst->getParent()))
1687 inst->replaceUsesOfWith(From: inputs[i], To: RewriteVal);
1688 }
1689
1690 // Since there may be multiple exits from the original region, make the new
1691 // function return an unsigned, switch on that number. This loop iterates
1692 // over all of the blocks in the extracted region, updating any terminator
1693 // instructions in the to-be-extracted region that branch to blocks that are
1694 // not in the region to be extracted.
1695 std::map<BasicBlock *, BasicBlock *> ExitBlockMap;
1696
1697 // Iterate over the previously collected targets, and create new blocks inside
1698 // the function to branch to.
1699 for (auto P : enumerate(First&: ExtractedFuncRetVals)) {
1700 BasicBlock *OldTarget = P.value();
1701 size_t SuccNum = P.index();
1702
1703 BasicBlock *NewTarget = BasicBlock::Create(
1704 Context, Name: OldTarget->getName() + ".exitStub", Parent: newFunction);
1705 ExitBlockMap[OldTarget] = NewTarget;
1706
1707 Value *brVal = nullptr;
1708 Type *RetTy = getSwitchType();
1709 assert(ExtractedFuncRetVals.size() < 0xffff &&
1710 "too many exit blocks for switch");
1711 switch (ExtractedFuncRetVals.size()) {
1712 case 0:
1713 case 1:
1714 // No value needed.
1715 break;
1716 case 2: // Conditional branch, return a bool
1717 brVal = ConstantInt::get(Ty: RetTy, V: !SuccNum);
1718 break;
1719 default:
1720 brVal = ConstantInt::get(Ty: RetTy, V: SuccNum);
1721 break;
1722 }
1723
1724 ReturnInst::Create(C&: Context, retVal: brVal, InsertBefore: NewTarget);
1725 }
1726
1727 for (BasicBlock *Block : Blocks) {
1728 Instruction *TI = Block->getTerminator();
1729 for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) {
1730 if (Blocks.count(key: TI->getSuccessor(Idx: i)))
1731 continue;
1732 BasicBlock *OldTarget = TI->getSuccessor(Idx: i);
1733 // add a new basic block which returns the appropriate value
1734 BasicBlock *NewTarget = ExitBlockMap[OldTarget];
1735 assert(NewTarget && "Unknown target block!");
1736
1737 // rewrite the original branch instruction with this new target
1738 TI->setSuccessor(Idx: i, BB: NewTarget);
1739 }
1740 }
1741
1742 // Loop over all of the PHI nodes in the header and exit blocks, and change
1743 // any references to the old incoming edge to be the new incoming edge.
1744 for (BasicBlock::iterator I = header->begin(); isa<PHINode>(Val: I); ++I) {
1745 PHINode *PN = cast<PHINode>(Val&: I);
1746 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
1747 if (!Blocks.count(key: PN->getIncomingBlock(i)))
1748 PN->setIncomingBlock(i, BB: newFuncRoot);
1749 }
1750
1751 // Connect newFunction entry block to new header.
1752 BranchInst *BranchI = BranchInst::Create(IfTrue: header, InsertBefore: newFuncRoot);
1753 applyFirstDebugLoc(oldFunction, Blocks: Blocks.getArrayRef(), BranchI);
1754
1755 // Store the arguments right after the definition of output value.
1756 // This should be proceeded after creating exit stubs to be ensure that invoke
1757 // result restore will be placed in the outlined function.
1758 ScalarAI = newFunction->arg_begin();
1759 unsigned AggIdx = 0;
1760
1761 for (Value *Input : inputs) {
1762 if (StructValues.contains(key: Input))
1763 ++AggIdx;
1764 else
1765 ++ScalarAI;
1766 }
1767
1768 for (Value *Output : outputs) {
1769 // Find proper insertion point.
1770 // In case Output is an invoke, we insert the store at the beginning in the
1771 // 'normal destination' BB. Otherwise we insert the store right after
1772 // Output.
1773 BasicBlock::iterator InsertPt;
1774 if (auto *InvokeI = dyn_cast<InvokeInst>(Val: Output))
1775 InsertPt = InvokeI->getNormalDest()->getFirstInsertionPt();
1776 else if (auto *Phi = dyn_cast<PHINode>(Val: Output))
1777 InsertPt = Phi->getParent()->getFirstInsertionPt();
1778 else if (auto *OutI = dyn_cast<Instruction>(Val: Output))
1779 InsertPt = std::next(x: OutI->getIterator());
1780 else {
1781 // Globals don't need to be updated, just advance to the next argument.
1782 if (StructValues.contains(key: Output))
1783 ++AggIdx;
1784 else
1785 ++ScalarAI;
1786 continue;
1787 }
1788
1789 assert((InsertPt->getFunction() == newFunction ||
1790 Blocks.count(InsertPt->getParent())) &&
1791 "InsertPt should be in new function");
1792
1793 if (StructValues.contains(key: Output)) {
1794 assert(AggArg && "Number of aggregate output arguments should match "
1795 "the number of defined values");
1796 Value *Idx[2];
1797 Idx[0] = Constant::getNullValue(Ty: Type::getInt32Ty(C&: Context));
1798 Idx[1] = ConstantInt::get(Ty: Type::getInt32Ty(C&: Context), V: AggIdx);
1799 GetElementPtrInst *GEP = GetElementPtrInst::Create(
1800 PointeeType: StructArgTy, Ptr: AggArg, IdxList: Idx, NameStr: "gep_" + Output->getName(), InsertBefore: InsertPt);
1801 new StoreInst(Output, GEP, InsertPt);
1802 ++AggIdx;
1803 } else {
1804 assert(ScalarAI != newFunction->arg_end() &&
1805 "Number of scalar output arguments should match "
1806 "the number of defined values");
1807 new StoreInst(Output, &*ScalarAI, InsertPt);
1808 ++ScalarAI;
1809 }
1810 }
1811
1812 if (ExtractedFuncRetVals.empty()) {
1813 // Mark the new function `noreturn` if applicable. Terminators which resume
1814 // exception propagation are treated as returning instructions. This is to
1815 // avoid inserting traps after calls to outlined functions which unwind.
1816 if (none_of(Range&: Blocks, P: [](const BasicBlock *BB) {
1817 const Instruction *Term = BB->getTerminator();
1818 return isa<ReturnInst>(Val: Term) || isa<ResumeInst>(Val: Term);
1819 }))
1820 newFunction->setDoesNotReturn();
1821 }
1822}
1823
/// Materialize the replacement for the extracted region in the original
/// function: create the "codeRepl" block, allocate stack slots for the
/// aggregate struct and the scalar outputs, store the struct-packed inputs,
/// emit the call to \p newFunction, reload the outputs, and lower the
/// returned exit-selector value into a ret / branch / switch depending on
/// how many distinct exit targets (ExtractedFuncRetVals) the region has.
/// \p Reloads receives one LoadInst per entry of \p outputs, in order.
/// Returns the emitted CallInst.
CallInst *CodeExtractor::emitReplacerCall(
    const ValueSet &inputs, const ValueSet &outputs,
    const ValueSet &StructValues, Function *newFunction,
    StructType *StructArgTy, Function *oldFunction, BasicBlock *ReplIP,
    BlockFrequency EntryFreq, ArrayRef<Value *> LifetimesStart,
    std::vector<Value *> &Reloads) {
  LLVMContext &Context = oldFunction->getContext();
  Module *M = oldFunction->getParent();
  const DataLayout &DL = M->getDataLayout();

  // This takes place of the original loop
  BasicBlock *codeReplacer =
      BasicBlock::Create(Context, Name: "codeRepl", Parent: oldFunction, InsertBefore: ReplIP);
  if (AllocationBlock)
    assert(AllocationBlock->getParent() == oldFunction &&
           "AllocationBlock is not in the same function");
  // All allocas go either into the caller-provided AllocationBlock or, by
  // default, the old function's entry block (so they are static allocas).
  BasicBlock *AllocaBlock =
      AllocationBlock ? AllocationBlock : &oldFunction->getEntryBlock();

  // Update the entry count of the function.
  if (BFI)
    BFI->setBlockFreq(BB: codeReplacer, Freq: EntryFreq);

  std::vector<Value *> params;

  // Add inputs as params, or to be filled into the struct
  for (Value *input : inputs) {
    if (StructValues.contains(key: input))
      continue;

    params.push_back(x: input);
  }

  // Create allocas for the outputs that are passed by reference as scalar
  // (non-aggregate) arguments; each is also remembered for reloading below.
  std::vector<Value *> ReloadOutputs;
  for (Value *output : outputs) {
    if (StructValues.contains(key: output))
      continue;

    AllocaInst *alloca = new AllocaInst(
        output->getType(), DL.getAllocaAddrSpace(), nullptr,
        output->getName() + ".loc", AllocaBlock->getFirstInsertionPt());
    params.push_back(x: alloca);
    ReloadOutputs.push_back(x: alloca);
  }

  // The single aggregate argument: holds all struct-packed inputs/outputs.
  AllocaInst *Struct = nullptr;
  if (!StructValues.empty()) {
    Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
                            "structArg", AllocaBlock->getFirstInsertionPt());
    if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
      // Callee expects the aggregate pointer in addrspace(0); cast it.
      auto *StructSpaceCast = new AddrSpaceCastInst(
          Struct, PointerType ::get(C&: Context, AddressSpace: 0), "structArg.ascast");
      StructSpaceCast->insertAfter(InsertPos: Struct->getIterator());
      params.push_back(x: StructSpaceCast);
    } else {
      params.push_back(x: Struct);
    }

    // Store each struct-packed input into its field before the call. Field
    // order must match the aggregate layout used by the outlined function.
    unsigned AggIdx = 0;
    for (Value *input : inputs) {
      if (!StructValues.contains(key: input))
        continue;

      Value *Idx[2];
      Idx[0] = Constant::getNullValue(Ty: Type::getInt32Ty(C&: Context));
      Idx[1] = ConstantInt::get(Ty: Type::getInt32Ty(C&: Context), V: AggIdx);
      GetElementPtrInst *GEP = GetElementPtrInst::Create(
          PointeeType: StructArgTy, Ptr: Struct, IdxList: Idx, NameStr: "gep_" + input->getName());
      GEP->insertInto(ParentBB: codeReplacer, It: codeReplacer->end());
      new StoreInst(input, GEP, codeReplacer);

      ++AggIdx;
    }
  }

  // Emit the call to the function. Only name the result if it is actually
  // used as an exit-block selector (more than one exit target).
  CallInst *call = CallInst::Create(
      Func: newFunction, Args: params, NameStr: ExtractedFuncRetVals.size() > 1 ? "targetBlock" : "",
      InsertBefore: codeReplacer);

  // Set swifterror parameter attributes. Aggregate-packed inputs only
  // advance the struct index; scalar params advance the call-arg index.
  unsigned ParamIdx = 0;
  unsigned AggIdx = 0;
  for (auto input : inputs) {
    if (StructValues.contains(key: input)) {
      ++AggIdx;
    } else {
      if (input->isSwiftError())
        call->addParamAttr(ArgNo: ParamIdx, Kind: Attribute::SwiftError);
      ++ParamIdx;
    }
  }

  // Add debug location to the new call, if the original function has debug
  // info. In that case, the terminator of the entry block of the extracted
  // function contains the first debug location of the extracted function,
  // set in extractCodeRegion.
  if (codeReplacer->getParent()->getSubprogram()) {
    if (auto DL = newFunction->getEntryBlock().getTerminator()->getDebugLoc())
      call->setDebugLoc(DL);
  }

  // Reload the outputs passed in by reference, use the struct if output is in
  // the aggregate or reload from the scalar argument. AggIdx continues past
  // the inputs counted above, matching the aggregate field layout.
  for (unsigned i = 0, e = outputs.size(), scalarIdx = 0; i != e; ++i) {
    Value *Output = nullptr;
    if (StructValues.contains(key: outputs[i])) {
      Value *Idx[2];
      Idx[0] = Constant::getNullValue(Ty: Type::getInt32Ty(C&: Context));
      Idx[1] = ConstantInt::get(Ty: Type::getInt32Ty(C&: Context), V: AggIdx);
      GetElementPtrInst *GEP = GetElementPtrInst::Create(
          PointeeType: StructArgTy, Ptr: Struct, IdxList: Idx, NameStr: "gep_reload_" + outputs[i]->getName());
      GEP->insertInto(ParentBB: codeReplacer, It: codeReplacer->end());
      Output = GEP;
      ++AggIdx;
    } else {
      Output = ReloadOutputs[scalarIdx];
      ++scalarIdx;
    }
    LoadInst *load =
        new LoadInst(outputs[i]->getType(), Output,
                     outputs[i]->getName() + ".reload", codeReplacer);
    Reloads.push_back(x: load);
  }

  // Now we can emit a switch statement using the call as a value. Start with
  // a placeholder condition and self-loop default; both are fixed up below.
  SwitchInst *TheSwitch =
      SwitchInst::Create(Value: Constant::getNullValue(Ty: Type::getInt16Ty(C&: Context)),
                         Default: codeReplacer, NumCases: 0, InsertBefore: codeReplacer);
  for (auto P : enumerate(First&: ExtractedFuncRetVals)) {
    BasicBlock *OldTarget = P.value();
    size_t SuccNum = P.index();

    TheSwitch->addCase(OnVal: ConstantInt::get(Ty: Type::getInt16Ty(C&: Context), V: SuccNum),
                       Dest: OldTarget);
  }

  // Now that we've done the deed, simplify the switch instruction.
  Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType();
  switch (ExtractedFuncRetVals.size()) {
  case 0:
    // There are no successors (the block containing the switch itself), which
    // means that previously this was the last part of the function, and hence
    // this should be rewritten as a `ret` or `unreachable`.
    if (newFunction->doesNotReturn()) {
      // If fn is no return, end with an unreachable terminator.
      (void)new UnreachableInst(Context, TheSwitch->getIterator());
    } else if (OldFnRetTy->isVoidTy()) {
      // We have no return value.
      ReturnInst::Create(C&: Context, retVal: nullptr,
                         InsertBefore: TheSwitch->getIterator()); // Return void
    } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) {
      // return what we have
      ReturnInst::Create(C&: Context, retVal: TheSwitch->getCondition(),
                         InsertBefore: TheSwitch->getIterator());
    } else {
      // Otherwise we must have code extracted an unwind or something, just
      // return whatever we want.
      ReturnInst::Create(C&: Context, retVal: Constant::getNullValue(Ty: OldFnRetTy),
                         InsertBefore: TheSwitch->getIterator());
    }

    TheSwitch->eraseFromParent();
    break;
  case 1:
    // Only a single destination, change the switch into an unconditional
    // branch. Successor 0 is the (dummy) default; the real exit target is
    // successor 1, added as the single case above.
    BranchInst::Create(IfTrue: TheSwitch->getSuccessor(idx: 1), InsertBefore: TheSwitch->getIterator());
    TheSwitch->eraseFromParent();
    break;
  case 2:
    // Only two destinations, convert to a condition branch.
    // Remark: This also swaps the target branches:
    // 0 -> false -> getSuccessor(2); 1 -> true -> getSuccessor(1)
    BranchInst::Create(IfTrue: TheSwitch->getSuccessor(idx: 1), IfFalse: TheSwitch->getSuccessor(idx: 2),
                       Cond: call, InsertBefore: TheSwitch->getIterator());
    TheSwitch->eraseFromParent();
    break;
  default:
    // Otherwise, make the default destination of the switch instruction be one
    // of the other successors.
    TheSwitch->setCondition(call);
    TheSwitch->setDefaultDest(
        TheSwitch->getSuccessor(idx: ExtractedFuncRetVals.size()));
    // Remove redundant case: the last case's target is now the default.
    TheSwitch->removeCase(
        I: SwitchInst::CaseIt(TheSwitch, ExtractedFuncRetVals.size() - 1));
    break;
  }

  // Insert lifetime markers around the reloads of any output values. The
  // allocas output values are stored in are only in-use in the codeRepl block.
  insertLifetimeMarkersSurroundingCall(M, LifetimesStart: ReloadOutputs, LifetimesEnd: ReloadOutputs, TheCall: call);

  // Replicate the effects of any lifetime start/end markers which referenced
  // input objects in the extraction region by placing markers around the call.
  insertLifetimeMarkersSurroundingCall(M: oldFunction->getParent(), LifetimesStart,
                                       LifetimesEnd: {}, TheCall: call);

  return call;
}
2026
/// Wire the codeReplacer block into the original function after the region
/// has been moved out: redirect external branches that targeted the region's
/// \p header to \p codeReplacer, retarget exit-block PHIs to receive their
/// value from codeReplacer, and replace remaining uses of the extracted
/// definitions in \p oldFunction with the reloaded values in \p Reloads.
void CodeExtractor::insertReplacerCall(
    Function *oldFunction, BasicBlock *header, BasicBlock *codeReplacer,
    const ValueSet &outputs, ArrayRef<Value *> Reloads,
    const DenseMap<BasicBlock *, BlockFrequency> &ExitWeights) {

  // Rewrite branches to basic blocks outside of the loop to new dummy blocks
  // within the new function. This must be done before we lose track of which
  // blocks were originally in the code region.
  // Snapshot the use list first: replaceUsesOfWith mutates it while we walk.
  std::vector<User *> Users(header->user_begin(), header->user_end());
  for (auto &U : Users)
    // The BasicBlock which contains the branch is not in the region
    // modify the branch target to a new block
    if (Instruction *I = dyn_cast<Instruction>(Val: U))
      if (I->isTerminator() && I->getFunction() == oldFunction &&
          !Blocks.count(key: I->getParent()))
        I->replaceUsesOfWith(From: header, To: codeReplacer);

  // When moving the code region it is sufficient to replace all uses to the
  // extracted function values. Since the original definition's block
  // dominated its use, it will also be dominated by codeReplacer's switch
  // which joined multiple exit blocks.
  for (BasicBlock *ExitBB : ExtractedFuncRetVals)
    for (PHINode &PN : ExitBB->phis()) {
      Value *IncomingCodeReplacerVal = nullptr;
      for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) {
        // Ignore incoming values from outside of the extracted region.
        if (!Blocks.count(key: PN.getIncomingBlock(i)))
          continue;

        // Ensure that there is only one incoming value from codeReplacer.
        // The first in-region edge is retargeted; any further in-region edge
        // must carry the same value, since all now funnel through codeRepl.
        if (!IncomingCodeReplacerVal) {
          PN.setIncomingBlock(i, BB: codeReplacer);
          IncomingCodeReplacerVal = PN.getIncomingValue(i);
        } else
          assert(IncomingCodeReplacerVal == PN.getIncomingValue(i) &&
                 "PHI has two incompatbile incoming values from codeRepl");
      }
    }

  // Redirect uses of each extracted output that remain in the old function to
  // the value reloaded after the call (uses inside the new function keep the
  // original definitions). Use list is snapshotted for the same reason as
  // above.
  for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
    Value *load = Reloads[i];
    std::vector<User *> Users(outputs[i]->user_begin(), outputs[i]->user_end());
    for (User *U : Users) {
      Instruction *inst = cast<Instruction>(Val: U);
      if (inst->getParent()->getParent() == oldFunction)
        inst->replaceUsesOfWith(From: outputs[i], To: load);
    }
  }

  // Update the branch weights for the exit block. Only needed when the call
  // result selects among multiple exit targets.
  if (BFI && ExtractedFuncRetVals.size() > 1)
    calculateNewCallTerminatorWeights(CodeReplacer: codeReplacer, ExitWeights, BPI);
}
2080
2081bool CodeExtractor::verifyAssumptionCache(const Function &OldFunc,
2082 const Function &NewFunc,
2083 AssumptionCache *AC) {
2084 for (auto AssumeVH : AC->assumptions()) {
2085 auto *I = dyn_cast_or_null<CallInst>(Val&: AssumeVH);
2086 if (!I)
2087 continue;
2088
2089 // There shouldn't be any llvm.assume intrinsics in the new function.
2090 if (I->getFunction() != &OldFunc)
2091 return true;
2092
2093 // There shouldn't be any stale affected values in the assumption cache
2094 // that were previously in the old function, but that have now been moved
2095 // to the new function.
2096 for (auto AffectedValVH : AC->assumptionsFor(V: I->getOperand(i_nocapture: 0))) {
2097 auto *AffectedCI = dyn_cast_or_null<CallInst>(Val&: AffectedValVH);
2098 if (!AffectedCI)
2099 continue;
2100 if (AffectedCI->getFunction() != &OldFunc)
2101 return true;
2102 auto *AssumedInst = cast<Instruction>(Val: AffectedCI->getOperand(i_nocapture: 0));
2103 if (AssumedInst->getFunction() != &OldFunc)
2104 return true;
2105 }
2106 }
2107 return false;
2108}
2109
2110void CodeExtractor::excludeArgFromAggregate(Value *Arg) {
2111 ExcludeArgsFromAggregate.insert(X: Arg);
2112}
2113