1 | //===- JumpTableToSwitch.cpp ----------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "llvm/Transforms/Scalar/JumpTableToSwitch.h" |
10 | #include "llvm/ADT/SmallVector.h" |
11 | #include "llvm/Analysis/ConstantFolding.h" |
12 | #include "llvm/Analysis/DomTreeUpdater.h" |
13 | #include "llvm/Analysis/OptimizationRemarkEmitter.h" |
14 | #include "llvm/Analysis/PostDominators.h" |
15 | #include "llvm/IR/IRBuilder.h" |
16 | #include "llvm/Support/CommandLine.h" |
17 | #include "llvm/Transforms/Utils/BasicBlockUtils.h" |
18 | |
19 | using namespace llvm; |
20 | |
21 | static cl::opt<unsigned> |
22 | JumpTableSizeThreshold("jump-table-to-switch-size-threshold" , cl::Hidden, |
23 | cl::desc("Only split jump tables with size less or " |
24 | "equal than JumpTableSizeThreshold." ), |
25 | cl::init(Val: 10)); |
26 | |
27 | // TODO: Consider adding a cost model for profitability analysis of this |
28 | // transformation. Currently we replace a jump table with a switch if all the |
29 | // functions in the jump table are smaller than the provided threshold. |
30 | static cl::opt<unsigned> FunctionSizeThreshold( |
31 | "jump-table-to-switch-function-size-threshold" , cl::Hidden, |
32 | cl::desc("Only split jump tables containing functions whose sizes are less " |
33 | "or equal than this threshold." ), |
34 | cl::init(Val: 50)); |
35 | |
36 | #define DEBUG_TYPE "jump-table-to-switch" |
37 | |
38 | namespace { |
39 | struct JumpTableTy { |
40 | Value *Index; |
41 | SmallVector<Function *, 10> Funcs; |
42 | }; |
43 | } // anonymous namespace |
44 | |
45 | static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP, |
46 | PointerType *PtrTy) { |
47 | Constant *Ptr = dyn_cast<Constant>(Val: GEP->getPointerOperand()); |
48 | if (!Ptr) |
49 | return std::nullopt; |
50 | |
51 | GlobalVariable *GV = dyn_cast<GlobalVariable>(Val: Ptr); |
52 | if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer()) |
53 | return std::nullopt; |
54 | |
55 | Function &F = *GEP->getParent()->getParent(); |
56 | const DataLayout &DL = F.getDataLayout(); |
57 | const unsigned BitWidth = |
58 | DL.getIndexSizeInBits(AS: GEP->getPointerAddressSpace()); |
59 | MapVector<Value *, APInt> VariableOffsets; |
60 | APInt ConstantOffset(BitWidth, 0); |
61 | if (!GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset)) |
62 | return std::nullopt; |
63 | if (VariableOffsets.size() != 1) |
64 | return std::nullopt; |
65 | // TODO: consider supporting more general patterns |
66 | if (!ConstantOffset.isZero()) |
67 | return std::nullopt; |
68 | APInt StrideBytes = VariableOffsets.front().second; |
69 | const uint64_t JumpTableSizeBytes = DL.getTypeAllocSize(Ty: GV->getValueType()); |
70 | if (JumpTableSizeBytes % StrideBytes.getZExtValue() != 0) |
71 | return std::nullopt; |
72 | const uint64_t N = JumpTableSizeBytes / StrideBytes.getZExtValue(); |
73 | if (N > JumpTableSizeThreshold) |
74 | return std::nullopt; |
75 | |
76 | JumpTableTy JumpTable; |
77 | JumpTable.Index = VariableOffsets.front().first; |
78 | JumpTable.Funcs.reserve(N); |
79 | for (uint64_t Index = 0; Index < N; ++Index) { |
80 | // ConstantOffset is zero. |
81 | APInt Offset = Index * StrideBytes; |
82 | Constant *C = |
83 | ConstantFoldLoadFromConst(C: GV->getInitializer(), Ty: PtrTy, Offset, DL); |
84 | auto *Func = dyn_cast_or_null<Function>(Val: C); |
85 | if (!Func || Func->isDeclaration() || |
86 | Func->getInstructionCount() > FunctionSizeThreshold) |
87 | return std::nullopt; |
88 | JumpTable.Funcs.push_back(Elt: Func); |
89 | } |
90 | return JumpTable; |
91 | } |
92 | |
93 | static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT, |
94 | DomTreeUpdater &DTU, |
95 | OptimizationRemarkEmitter &ORE) { |
96 | const bool IsVoid = CB->getType() == Type::getVoidTy(C&: CB->getContext()); |
97 | |
98 | SmallVector<DominatorTree::UpdateType, 8> DTUpdates; |
99 | BasicBlock *BB = CB->getParent(); |
100 | BasicBlock *Tail = SplitBlock(Old: BB, SplitPt: CB, DTU: &DTU, LI: nullptr, MSSAU: nullptr, |
101 | BBName: BB->getName() + Twine(".tail" )); |
102 | DTUpdates.push_back(Elt: {DominatorTree::Delete, BB, Tail}); |
103 | BB->getTerminator()->eraseFromParent(); |
104 | |
105 | Function &F = *BB->getParent(); |
106 | BasicBlock *BBUnreachable = BasicBlock::Create( |
107 | Context&: F.getContext(), Name: "default.switch.case.unreachable" , Parent: &F, InsertBefore: Tail); |
108 | IRBuilder<> BuilderUnreachable(BBUnreachable); |
109 | BuilderUnreachable.CreateUnreachable(); |
110 | |
111 | IRBuilder<> Builder(BB); |
112 | SwitchInst *Switch = Builder.CreateSwitch(V: JT.Index, Dest: BBUnreachable); |
113 | DTUpdates.push_back(Elt: {DominatorTree::Insert, BB, BBUnreachable}); |
114 | |
115 | IRBuilder<> BuilderTail(CB); |
116 | PHINode *PHI = |
117 | IsVoid ? nullptr : BuilderTail.CreatePHI(Ty: CB->getType(), NumReservedValues: JT.Funcs.size()); |
118 | |
119 | for (auto [Index, Func] : llvm::enumerate(First: JT.Funcs)) { |
120 | BasicBlock *B = BasicBlock::Create(Context&: Func->getContext(), |
121 | Name: "call." + Twine(Index), Parent: &F, InsertBefore: Tail); |
122 | DTUpdates.push_back(Elt: {DominatorTree::Insert, BB, B}); |
123 | DTUpdates.push_back(Elt: {DominatorTree::Insert, B, Tail}); |
124 | |
125 | CallBase *Call = cast<CallBase>(Val: CB->clone()); |
126 | Call->setCalledFunction(Func); |
127 | Call->insertInto(ParentBB: B, It: B->end()); |
128 | Switch->addCase( |
129 | OnVal: cast<ConstantInt>(Val: ConstantInt::get(Ty: JT.Index->getType(), V: Index)), Dest: B); |
130 | BranchInst::Create(IfTrue: Tail, InsertBefore: B); |
131 | if (PHI) |
132 | PHI->addIncoming(V: Call, BB: B); |
133 | } |
134 | DTU.applyUpdates(Updates: DTUpdates); |
135 | ORE.emit(RemarkBuilder: [&]() { |
136 | return OptimizationRemark(DEBUG_TYPE, "ReplacedJumpTableWithSwitch" , CB) |
137 | << "expanded indirect call into switch" ; |
138 | }); |
139 | if (PHI) |
140 | CB->replaceAllUsesWith(V: PHI); |
141 | CB->eraseFromParent(); |
142 | return Tail; |
143 | } |
144 | |
145 | PreservedAnalyses JumpTableToSwitchPass::run(Function &F, |
146 | FunctionAnalysisManager &AM) { |
147 | OptimizationRemarkEmitter &ORE = |
148 | AM.getResult<OptimizationRemarkEmitterAnalysis>(IR&: F); |
149 | DominatorTree *DT = AM.getCachedResult<DominatorTreeAnalysis>(IR&: F); |
150 | PostDominatorTree *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(IR&: F); |
151 | DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy); |
152 | bool Changed = false; |
153 | for (BasicBlock &BB : make_early_inc_range(Range&: F)) { |
154 | BasicBlock *CurrentBB = &BB; |
155 | while (CurrentBB) { |
156 | BasicBlock *SplittedOutTail = nullptr; |
157 | for (Instruction &I : make_early_inc_range(Range&: *CurrentBB)) { |
158 | auto *Call = dyn_cast<CallInst>(Val: &I); |
159 | if (!Call || Call->getCalledFunction() || Call->isMustTailCall()) |
160 | continue; |
161 | auto *L = dyn_cast<LoadInst>(Val: Call->getCalledOperand()); |
162 | // Skip atomic or volatile loads. |
163 | if (!L || !L->isSimple()) |
164 | continue; |
165 | auto *GEP = dyn_cast<GetElementPtrInst>(Val: L->getPointerOperand()); |
166 | if (!GEP) |
167 | continue; |
168 | auto *PtrTy = dyn_cast<PointerType>(Val: L->getType()); |
169 | assert(PtrTy && "call operand must be a pointer" ); |
170 | std::optional<JumpTableTy> JumpTable = parseJumpTable(GEP, PtrTy); |
171 | if (!JumpTable) |
172 | continue; |
173 | SplittedOutTail = expandToSwitch(CB: Call, JT: *JumpTable, DTU, ORE); |
174 | Changed = true; |
175 | break; |
176 | } |
177 | CurrentBB = SplittedOutTail ? SplittedOutTail : nullptr; |
178 | } |
179 | } |
180 | |
181 | if (!Changed) |
182 | return PreservedAnalyses::all(); |
183 | |
184 | PreservedAnalyses PA; |
185 | if (DT) |
186 | PA.preserve<DominatorTreeAnalysis>(); |
187 | if (PDT) |
188 | PA.preserve<PostDominatorTreeAnalysis>(); |
189 | return PA; |
190 | } |
191 | |