1//===- JumpTableToSwitch.cpp ----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/Transforms/Scalar/JumpTableToSwitch.h"
10#include "llvm/ADT/DenseSet.h"
11#include "llvm/ADT/STLExtras.h"
12#include "llvm/ADT/SmallVector.h"
13#include "llvm/ADT/Statistic.h"
14#include "llvm/Analysis/ConstantFolding.h"
15#include "llvm/Analysis/CtxProfAnalysis.h"
16#include "llvm/Analysis/DomTreeUpdater.h"
17#include "llvm/Analysis/OptimizationRemarkEmitter.h"
18#include "llvm/Analysis/PostDominators.h"
19#include "llvm/IR/IRBuilder.h"
20#include "llvm/IR/LLVMContext.h"
21#include "llvm/IR/ProfDataUtils.h"
22#include "llvm/ProfileData/InstrProf.h"
23#include "llvm/Support/CommandLine.h"
24#include "llvm/Support/Error.h"
25#include "llvm/Transforms/Utils/BasicBlockUtils.h"
26#include <limits>
27
28using namespace llvm;
29
30static cl::opt<unsigned>
31 JumpTableSizeThreshold("jump-table-to-switch-size-threshold", cl::Hidden,
32 cl::desc("Only split jump tables with size less or "
33 "equal than JumpTableSizeThreshold."),
34 cl::init(Val: 10));
35
36// TODO: Consider adding a cost model for profitability analysis of this
37// transformation. Currently we replace a jump table with a switch if all the
38// functions in the jump table are smaller than the provided threshold.
39static cl::opt<unsigned> FunctionSizeThreshold(
40 "jump-table-to-switch-function-size-threshold", cl::Hidden,
41 cl::desc("Only split jump tables containing functions whose sizes are less "
42 "or equal than this threshold."),
43 cl::init(Val: 50));
44
45namespace llvm {
46extern cl::opt<bool> ProfcheckDisableMetadataFixes;
47} // end namespace llvm
48
49#define DEBUG_TYPE "jump-table-to-switch"
50
51STATISTIC(NumEligibleJumpTables, "The number of jump tables seen by the pass "
52 "that can be converted if deemed profitable.");
53STATISTIC(NumJumpTablesConverted,
54 "The number of jump tables converted into switches.");
55
56namespace {
57struct JumpTableTy {
58 Value *Index;
59 SmallVector<Function *, 10> Funcs;
60};
61} // anonymous namespace
62
63static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP,
64 PointerType *PtrTy) {
65 Constant *Ptr = dyn_cast<Constant>(Val: GEP->getPointerOperand());
66 if (!Ptr)
67 return std::nullopt;
68
69 GlobalVariable *GV = dyn_cast<GlobalVariable>(Val: Ptr);
70 if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
71 return std::nullopt;
72
73 Function &F = *GEP->getParent()->getParent();
74 const DataLayout &DL = F.getDataLayout();
75 const unsigned BitWidth =
76 DL.getIndexSizeInBits(AS: GEP->getPointerAddressSpace());
77 SmallMapVector<Value *, APInt, 4> VariableOffsets;
78 APInt ConstantOffset(BitWidth, 0);
79 if (!GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset))
80 return std::nullopt;
81 if (VariableOffsets.size() != 1)
82 return std::nullopt;
83 // TODO: consider supporting more general patterns
84 if (!ConstantOffset.isZero())
85 return std::nullopt;
86 APInt StrideBytes = VariableOffsets.front().second;
87 const uint64_t JumpTableSizeBytes = GV->getGlobalSize(DL);
88 if (JumpTableSizeBytes % StrideBytes.getZExtValue() != 0)
89 return std::nullopt;
90 ++NumEligibleJumpTables;
91 const uint64_t N = JumpTableSizeBytes / StrideBytes.getZExtValue();
92 if (N > JumpTableSizeThreshold)
93 return std::nullopt;
94
95 JumpTableTy JumpTable;
96 JumpTable.Index = VariableOffsets.front().first;
97 JumpTable.Funcs.reserve(N);
98 for (uint64_t Index = 0; Index < N; ++Index) {
99 // ConstantOffset is zero.
100 APInt Offset = Index * StrideBytes;
101 Constant *C =
102 ConstantFoldLoadFromConst(C: GV->getInitializer(), Ty: PtrTy, Offset, DL);
103 auto *Func = dyn_cast_or_null<Function>(Val: C);
104 if (!Func || Func->isDeclaration() ||
105 Func->getInstructionCount() > FunctionSizeThreshold)
106 return std::nullopt;
107 JumpTable.Funcs.push_back(Elt: Func);
108 }
109 return JumpTable;
110}
111
112static BasicBlock *
113expandToSwitch(CallBase *CB, const JumpTableTy &JT, DomTreeUpdater &DTU,
114 OptimizationRemarkEmitter &ORE,
115 llvm::function_ref<GlobalValue::GUID(const Function &)>
116 GetGuidForFunction) {
117 ++NumJumpTablesConverted;
118 const bool IsVoid = CB->getType() == Type::getVoidTy(C&: CB->getContext());
119
120 SmallVector<DominatorTree::UpdateType, 8> DTUpdates;
121 BasicBlock *BB = CB->getParent();
122 BasicBlock *Tail = SplitBlock(Old: BB, SplitPt: CB, DTU: &DTU, LI: nullptr, MSSAU: nullptr,
123 BBName: BB->getName() + Twine(".tail"));
124 DTUpdates.push_back(Elt: {DominatorTree::Delete, BB, Tail});
125 BB->getTerminator()->eraseFromParent();
126
127 Function &F = *BB->getParent();
128 BasicBlock *BBUnreachable = BasicBlock::Create(
129 Context&: F.getContext(), Name: "default.switch.case.unreachable", Parent: &F, InsertBefore: Tail);
130 IRBuilder<> BuilderUnreachable(BBUnreachable);
131 BuilderUnreachable.CreateUnreachable();
132
133 IRBuilder<> Builder(BB);
134 SwitchInst *Switch = Builder.CreateSwitch(V: JT.Index, Dest: BBUnreachable);
135 DTUpdates.push_back(Elt: {DominatorTree::Insert, BB, BBUnreachable});
136
137 IRBuilder<> BuilderTail(CB);
138 PHINode *PHI =
139 IsVoid ? nullptr : BuilderTail.CreatePHI(Ty: CB->getType(), NumReservedValues: JT.Funcs.size());
140 const auto *ProfMD = CB->getMetadata(KindID: LLVMContext::MD_prof);
141
142 SmallVector<uint64_t> BranchWeights;
143 DenseMap<GlobalValue::GUID, uint64_t> GuidToCounter;
144 const bool HadProfile = isValueProfileMD(ProfileData: ProfMD);
145 if (HadProfile) {
146 // The assumptions, coming in, are that the functions in JT.Funcs are
147 // defined in this module (from parseJumpTable).
148 assert(llvm::all_of(
149 JT.Funcs, [](const Function *F) { return F && !F->isDeclaration(); }));
150 BranchWeights.reserve(N: JT.Funcs.size() + 1);
151 // The first is the default target, which is the unreachable block created
152 // above.
153 BranchWeights.push_back(Elt: 0U);
154 uint64_t TotalCount = 0;
155 auto Targets = getValueProfDataFromInst(
156 Inst: *CB, ValueKind: InstrProfValueKind::IPVK_IndirectCallTarget,
157 MaxNumValueData: std::numeric_limits<uint32_t>::max(), TotalC&: TotalCount);
158
159 for (const auto &[G, C] : Targets) {
160 [[maybe_unused]] auto It = GuidToCounter.insert(KV: {G, C});
161 assert(It.second);
162 }
163 }
164 for (auto [Index, Func] : llvm::enumerate(First: JT.Funcs)) {
165 BasicBlock *B = BasicBlock::Create(Context&: Func->getContext(),
166 Name: "call." + Twine(Index), Parent: &F, InsertBefore: Tail);
167 DTUpdates.push_back(Elt: {DominatorTree::Insert, BB, B});
168 DTUpdates.push_back(Elt: {DominatorTree::Insert, B, Tail});
169
170 CallBase *Call = cast<CallBase>(Val: CB->clone());
171 // The MD_prof metadata (VP kind), if it existed, can be dropped, it doesn't
172 // make sense on a direct call. Note that the values are used for the branch
173 // weights of the switch.
174 Call->setMetadata(KindID: LLVMContext::MD_prof, Node: nullptr);
175 Call->setCalledFunction(Func);
176 Call->insertInto(ParentBB: B, It: B->end());
177 Switch->addCase(
178 OnVal: cast<ConstantInt>(Val: ConstantInt::get(Ty: JT.Index->getType(), V: Index)), Dest: B);
179 GlobalValue::GUID FctID = GetGuidForFunction(*Func);
180 // It'd be OK to _not_ find target functions in GuidToCounter, e.g. suppose
181 // just some of the jump targets are taken (for the given profile).
182 BranchWeights.push_back(Elt: FctID == 0U ? 0U
183 : GuidToCounter.lookup_or(Val: FctID, Default: 0U));
184 UncondBrInst::Create(IfTrue: Tail, InsertBefore: B);
185 if (PHI)
186 PHI->addIncoming(V: Call, BB: B);
187 }
188 DTU.applyUpdates(Updates: DTUpdates);
189 ORE.emit(RemarkBuilder: [&]() {
190 return OptimizationRemark(DEBUG_TYPE, "ReplacedJumpTableWithSwitch", CB)
191 << "expanded indirect call into switch";
192 });
193 if (HadProfile && !ProfcheckDisableMetadataFixes) {
194 // At least one of the targets must've been taken.
195 assert(llvm::any_of(BranchWeights, not_equal_to(0)));
196 setBranchWeights(I&: *Switch, Weights: downscaleWeights(Weights: BranchWeights),
197 /*IsExpected=*/false);
198 } else
199 setExplicitlyUnknownBranchWeights(I&: *Switch, DEBUG_TYPE);
200 if (PHI)
201 CB->replaceAllUsesWith(V: PHI);
202 CB->eraseFromParent();
203 return Tail;
204}
205
206PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
207 FunctionAnalysisManager &AM) {
208 OptimizationRemarkEmitter &ORE =
209 AM.getResult<OptimizationRemarkEmitterAnalysis>(IR&: F);
210 DominatorTree *DT = AM.getCachedResult<DominatorTreeAnalysis>(IR&: F);
211 PostDominatorTree *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(IR&: F);
212 DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
213 bool Changed = false;
214 auto FuncToGuid = [&](const Function &Fct) {
215 if (Fct.getMetadata(Kind: AssignGUIDPass::GUIDMetadataName))
216 return AssignGUIDPass::getGUID(F: Fct);
217
218 return Function::getGUIDAssumingExternalLinkage(GlobalName: getIRPGOFuncName(F, InLTO));
219 };
220
221 for (BasicBlock &BB : make_early_inc_range(Range&: F)) {
222 BasicBlock *CurrentBB = &BB;
223 while (CurrentBB) {
224 BasicBlock *SplittedOutTail = nullptr;
225 for (Instruction &I : make_early_inc_range(Range&: *CurrentBB)) {
226 auto *Call = dyn_cast<CallInst>(Val: &I);
227 if (!Call || Call->getCalledFunction() || Call->isMustTailCall())
228 continue;
229 auto *L = dyn_cast<LoadInst>(Val: Call->getCalledOperand());
230 // Skip atomic or volatile loads.
231 if (!L || !L->isSimple())
232 continue;
233 auto *GEP = dyn_cast<GetElementPtrInst>(Val: L->getPointerOperand());
234 if (!GEP)
235 continue;
236 auto *PtrTy = dyn_cast<PointerType>(Val: L->getType());
237 assert(PtrTy && "call operand must be a pointer");
238 std::optional<JumpTableTy> JumpTable = parseJumpTable(GEP, PtrTy);
239 if (!JumpTable)
240 continue;
241 SplittedOutTail =
242 expandToSwitch(CB: Call, JT: *JumpTable, DTU, ORE, GetGuidForFunction: FuncToGuid);
243 Changed = true;
244 break;
245 }
246 CurrentBB = SplittedOutTail ? SplittedOutTail : nullptr;
247 }
248 }
249
250 if (!Changed)
251 return PreservedAnalyses::all();
252
253 PreservedAnalyses PA;
254 if (DT)
255 PA.preserve<DominatorTreeAnalysis>();
256 if (PDT)
257 PA.preserve<PostDominatorTreeAnalysis>();
258 return PA;
259}
260