1//===- JumpTableToSwitch.cpp ----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/Transforms/Scalar/JumpTableToSwitch.h"
10#include "llvm/ADT/SmallVector.h"
11#include "llvm/Analysis/ConstantFolding.h"
12#include "llvm/Analysis/DomTreeUpdater.h"
13#include "llvm/Analysis/OptimizationRemarkEmitter.h"
14#include "llvm/Analysis/PostDominators.h"
15#include "llvm/IR/IRBuilder.h"
16#include "llvm/Support/CommandLine.h"
17#include "llvm/Transforms/Utils/BasicBlockUtils.h"
18
19using namespace llvm;
20
21static cl::opt<unsigned>
22 JumpTableSizeThreshold("jump-table-to-switch-size-threshold", cl::Hidden,
23 cl::desc("Only split jump tables with size less or "
24 "equal than JumpTableSizeThreshold."),
25 cl::init(Val: 10));
26
27// TODO: Consider adding a cost model for profitability analysis of this
28// transformation. Currently we replace a jump table with a switch if all the
29// functions in the jump table are smaller than the provided threshold.
30static cl::opt<unsigned> FunctionSizeThreshold(
31 "jump-table-to-switch-function-size-threshold", cl::Hidden,
32 cl::desc("Only split jump tables containing functions whose sizes are less "
33 "or equal than this threshold."),
34 cl::init(Val: 50));
35
// Identifier of this pass; expandToSwitch() passes it as the first argument of
// the OptimizationRemark it emits.
#define DEBUG_TYPE "jump-table-to-switch"
37
38namespace {
39struct JumpTableTy {
40 Value *Index;
41 SmallVector<Function *, 10> Funcs;
42};
43} // anonymous namespace
44
45static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP,
46 PointerType *PtrTy) {
47 Constant *Ptr = dyn_cast<Constant>(Val: GEP->getPointerOperand());
48 if (!Ptr)
49 return std::nullopt;
50
51 GlobalVariable *GV = dyn_cast<GlobalVariable>(Val: Ptr);
52 if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
53 return std::nullopt;
54
55 Function &F = *GEP->getParent()->getParent();
56 const DataLayout &DL = F.getDataLayout();
57 const unsigned BitWidth =
58 DL.getIndexSizeInBits(AS: GEP->getPointerAddressSpace());
59 MapVector<Value *, APInt> VariableOffsets;
60 APInt ConstantOffset(BitWidth, 0);
61 if (!GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset))
62 return std::nullopt;
63 if (VariableOffsets.size() != 1)
64 return std::nullopt;
65 // TODO: consider supporting more general patterns
66 if (!ConstantOffset.isZero())
67 return std::nullopt;
68 APInt StrideBytes = VariableOffsets.front().second;
69 const uint64_t JumpTableSizeBytes = DL.getTypeAllocSize(Ty: GV->getValueType());
70 if (JumpTableSizeBytes % StrideBytes.getZExtValue() != 0)
71 return std::nullopt;
72 const uint64_t N = JumpTableSizeBytes / StrideBytes.getZExtValue();
73 if (N > JumpTableSizeThreshold)
74 return std::nullopt;
75
76 JumpTableTy JumpTable;
77 JumpTable.Index = VariableOffsets.front().first;
78 JumpTable.Funcs.reserve(N);
79 for (uint64_t Index = 0; Index < N; ++Index) {
80 // ConstantOffset is zero.
81 APInt Offset = Index * StrideBytes;
82 Constant *C =
83 ConstantFoldLoadFromConst(C: GV->getInitializer(), Ty: PtrTy, Offset, DL);
84 auto *Func = dyn_cast_or_null<Function>(Val: C);
85 if (!Func || Func->isDeclaration() ||
86 Func->getInstructionCount() > FunctionSizeThreshold)
87 return std::nullopt;
88 JumpTable.Funcs.push_back(Elt: Func);
89 }
90 return JumpTable;
91}
92
93static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
94 DomTreeUpdater &DTU,
95 OptimizationRemarkEmitter &ORE) {
96 const bool IsVoid = CB->getType() == Type::getVoidTy(C&: CB->getContext());
97
98 SmallVector<DominatorTree::UpdateType, 8> DTUpdates;
99 BasicBlock *BB = CB->getParent();
100 BasicBlock *Tail = SplitBlock(Old: BB, SplitPt: CB, DTU: &DTU, LI: nullptr, MSSAU: nullptr,
101 BBName: BB->getName() + Twine(".tail"));
102 DTUpdates.push_back(Elt: {DominatorTree::Delete, BB, Tail});
103 BB->getTerminator()->eraseFromParent();
104
105 Function &F = *BB->getParent();
106 BasicBlock *BBUnreachable = BasicBlock::Create(
107 Context&: F.getContext(), Name: "default.switch.case.unreachable", Parent: &F, InsertBefore: Tail);
108 IRBuilder<> BuilderUnreachable(BBUnreachable);
109 BuilderUnreachable.CreateUnreachable();
110
111 IRBuilder<> Builder(BB);
112 SwitchInst *Switch = Builder.CreateSwitch(V: JT.Index, Dest: BBUnreachable);
113 DTUpdates.push_back(Elt: {DominatorTree::Insert, BB, BBUnreachable});
114
115 IRBuilder<> BuilderTail(CB);
116 PHINode *PHI =
117 IsVoid ? nullptr : BuilderTail.CreatePHI(Ty: CB->getType(), NumReservedValues: JT.Funcs.size());
118
119 for (auto [Index, Func] : llvm::enumerate(First: JT.Funcs)) {
120 BasicBlock *B = BasicBlock::Create(Context&: Func->getContext(),
121 Name: "call." + Twine(Index), Parent: &F, InsertBefore: Tail);
122 DTUpdates.push_back(Elt: {DominatorTree::Insert, BB, B});
123 DTUpdates.push_back(Elt: {DominatorTree::Insert, B, Tail});
124
125 CallBase *Call = cast<CallBase>(Val: CB->clone());
126 Call->setCalledFunction(Func);
127 Call->insertInto(ParentBB: B, It: B->end());
128 Switch->addCase(
129 OnVal: cast<ConstantInt>(Val: ConstantInt::get(Ty: JT.Index->getType(), V: Index)), Dest: B);
130 BranchInst::Create(IfTrue: Tail, InsertBefore: B);
131 if (PHI)
132 PHI->addIncoming(V: Call, BB: B);
133 }
134 DTU.applyUpdates(Updates: DTUpdates);
135 ORE.emit(RemarkBuilder: [&]() {
136 return OptimizationRemark(DEBUG_TYPE, "ReplacedJumpTableWithSwitch", CB)
137 << "expanded indirect call into switch";
138 });
139 if (PHI)
140 CB->replaceAllUsesWith(V: PHI);
141 CB->eraseFromParent();
142 return Tail;
143}
144
145PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
146 FunctionAnalysisManager &AM) {
147 OptimizationRemarkEmitter &ORE =
148 AM.getResult<OptimizationRemarkEmitterAnalysis>(IR&: F);
149 DominatorTree *DT = AM.getCachedResult<DominatorTreeAnalysis>(IR&: F);
150 PostDominatorTree *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(IR&: F);
151 DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
152 bool Changed = false;
153 for (BasicBlock &BB : make_early_inc_range(Range&: F)) {
154 BasicBlock *CurrentBB = &BB;
155 while (CurrentBB) {
156 BasicBlock *SplittedOutTail = nullptr;
157 for (Instruction &I : make_early_inc_range(Range&: *CurrentBB)) {
158 auto *Call = dyn_cast<CallInst>(Val: &I);
159 if (!Call || Call->getCalledFunction() || Call->isMustTailCall())
160 continue;
161 auto *L = dyn_cast<LoadInst>(Val: Call->getCalledOperand());
162 // Skip atomic or volatile loads.
163 if (!L || !L->isSimple())
164 continue;
165 auto *GEP = dyn_cast<GetElementPtrInst>(Val: L->getPointerOperand());
166 if (!GEP)
167 continue;
168 auto *PtrTy = dyn_cast<PointerType>(Val: L->getType());
169 assert(PtrTy && "call operand must be a pointer");
170 std::optional<JumpTableTy> JumpTable = parseJumpTable(GEP, PtrTy);
171 if (!JumpTable)
172 continue;
173 SplittedOutTail = expandToSwitch(CB: Call, JT: *JumpTable, DTU, ORE);
174 Changed = true;
175 break;
176 }
177 CurrentBB = SplittedOutTail ? SplittedOutTail : nullptr;
178 }
179 }
180
181 if (!Changed)
182 return PreservedAnalyses::all();
183
184 PreservedAnalyses PA;
185 if (DT)
186 PA.preserve<DominatorTreeAnalysis>();
187 if (PDT)
188 PA.preserve<PostDominatorTreeAnalysis>();
189 return PA;
190}
191