1 | //===- PGOCtxProfLowering.cpp - Contextual PGO Instr. Lowering ------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | |
10 | #include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h" |
11 | #include "llvm/Analysis/OptimizationRemarkEmitter.h" |
12 | #include "llvm/IR/Analysis.h" |
13 | #include "llvm/IR/DiagnosticInfo.h" |
14 | #include "llvm/IR/IRBuilder.h" |
15 | #include "llvm/IR/Instructions.h" |
16 | #include "llvm/IR/IntrinsicInst.h" |
17 | #include "llvm/IR/Module.h" |
18 | #include "llvm/IR/PassManager.h" |
19 | #include "llvm/Support/CommandLine.h" |
20 | #include <utility> |
21 | |
22 | using namespace llvm; |
23 | |
24 | #define DEBUG_TYPE "ctx-instr-lower" |
25 | |
26 | static cl::list<std::string> ContextRoots( |
27 | "profile-context-root" , cl::Hidden, |
28 | cl::desc( |
29 | "A function name, assumed to be global, which will be treated as the " |
30 | "root of an interesting graph, which will be profiled independently " |
31 | "from other similar graphs." )); |
32 | |
33 | bool PGOCtxProfLoweringPass::isContextualIRPGOEnabled() { |
34 | return !ContextRoots.empty(); |
35 | } |
36 | |
// The names of the symbols we expect compiler-rt to provide. They are grouped
// in a namespace for readability and to keep them distinct from the
// like-named members of the lowerer. The pointers are constexpr so the names
// cannot be accidentally re-pointed at runtime.
namespace CompilerRtAPINames {
static constexpr const char *StartCtx = "__llvm_ctx_profile_start_context";
static constexpr const char *ReleaseCtx = "__llvm_ctx_profile_release_context";
static constexpr const char *GetCtx = "__llvm_ctx_profile_get_context";
static constexpr const char *ExpectedCalleeTLS =
    "__llvm_ctx_profile_expected_callee";
static constexpr const char *CallsiteTLS = "__llvm_ctx_profile_callsite";
} // namespace CompilerRtAPINames
46 | |
47 | namespace { |
48 | // The lowering logic and state. |
49 | class CtxInstrumentationLowerer final { |
50 | Module &M; |
51 | ModuleAnalysisManager &MAM; |
52 | Type *ContextNodeTy = nullptr; |
53 | Type *ContextRootTy = nullptr; |
54 | |
55 | DenseMap<const Function *, Constant *> ContextRootMap; |
56 | Function *StartCtx = nullptr; |
57 | Function *GetCtx = nullptr; |
58 | Function *ReleaseCtx = nullptr; |
59 | GlobalVariable *ExpectedCalleeTLS = nullptr; |
60 | GlobalVariable *CallsiteInfoTLS = nullptr; |
61 | |
62 | public: |
63 | CtxInstrumentationLowerer(Module &M, ModuleAnalysisManager &MAM); |
64 | // return true if lowering happened (i.e. a change was made) |
65 | bool lowerFunction(Function &F); |
66 | }; |
67 | |
68 | // llvm.instrprof.increment[.step] captures the total number of counters as one |
69 | // of its parameters, and llvm.instrprof.callsite captures the total number of |
70 | // callsites. Those values are the same for instances of those intrinsics in |
71 | // this function. Find the first instance of each and return them. |
72 | std::pair<uint32_t, uint32_t> getNrCountersAndCallsites(const Function &F) { |
73 | uint32_t NrCounters = 0; |
74 | uint32_t NrCallsites = 0; |
75 | for (const auto &BB : F) { |
76 | for (const auto &I : BB) { |
77 | if (const auto *Incr = dyn_cast<InstrProfIncrementInst>(Val: &I)) { |
78 | uint32_t V = |
79 | static_cast<uint32_t>(Incr->getNumCounters()->getZExtValue()); |
80 | assert((!NrCounters || V == NrCounters) && |
81 | "expected all llvm.instrprof.increment[.step] intrinsics to " |
82 | "have the same total nr of counters parameter" ); |
83 | NrCounters = V; |
84 | } else if (const auto *CSIntr = dyn_cast<InstrProfCallsite>(Val: &I)) { |
85 | uint32_t V = |
86 | static_cast<uint32_t>(CSIntr->getNumCounters()->getZExtValue()); |
87 | assert((!NrCallsites || V == NrCallsites) && |
88 | "expected all llvm.instrprof.callsite intrinsics to have the " |
89 | "same total nr of callsites parameter" ); |
90 | NrCallsites = V; |
91 | } |
92 | #if NDEBUG |
93 | if (NrCounters && NrCallsites) |
94 | return std::make_pair(x&: NrCounters, y&: NrCallsites); |
95 | #endif |
96 | } |
97 | } |
98 | return {NrCounters, NrCallsites}; |
99 | } |
100 | } // namespace |
101 | |
102 | // set up tie-in with compiler-rt. |
103 | // NOTE!!! |
104 | // These have to match compiler-rt/lib/ctx_profile/CtxInstrProfiling.h |
105 | CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M, |
106 | ModuleAnalysisManager &MAM) |
107 | : M(M), MAM(MAM) { |
108 | auto *PointerTy = PointerType::get(C&: M.getContext(), AddressSpace: 0); |
109 | auto *SanitizerMutexType = Type::getInt8Ty(C&: M.getContext()); |
110 | auto *I32Ty = Type::getInt32Ty(C&: M.getContext()); |
111 | auto *I64Ty = Type::getInt64Ty(C&: M.getContext()); |
112 | |
113 | // The ContextRoot type |
114 | ContextRootTy = |
115 | StructType::get(Context&: M.getContext(), Elements: { |
116 | PointerTy, /*FirstNode*/ |
117 | PointerTy, /*FirstMemBlock*/ |
118 | PointerTy, /*CurrentMem*/ |
119 | SanitizerMutexType, /*Taken*/ |
120 | }); |
121 | // The Context header. |
122 | ContextNodeTy = StructType::get(Context&: M.getContext(), Elements: { |
123 | I64Ty, /*Guid*/ |
124 | PointerTy, /*Next*/ |
125 | I32Ty, /*NrCounters*/ |
126 | I32Ty, /*NrCallsites*/ |
127 | }); |
128 | |
129 | // Define a global for each entrypoint. We'll reuse the entrypoint's name as |
130 | // prefix. We assume the entrypoint names to be unique. |
131 | for (const auto &Fname : ContextRoots) { |
132 | if (const auto *F = M.getFunction(Name: Fname)) { |
133 | if (F->isDeclaration()) |
134 | continue; |
135 | auto *G = M.getOrInsertGlobal(Name: Fname + "_ctx_root" , Ty: ContextRootTy); |
136 | cast<GlobalVariable>(Val: G)->setInitializer( |
137 | Constant::getNullValue(Ty: ContextRootTy)); |
138 | ContextRootMap.insert(KV: std::make_pair(x&: F, y&: G)); |
139 | for (const auto &BB : *F) |
140 | for (const auto &I : BB) |
141 | if (const auto *CB = dyn_cast<CallBase>(Val: &I)) |
142 | if (CB->isMustTailCall()) { |
143 | M.getContext().emitError( |
144 | ErrorStr: "The function " + Fname + |
145 | " was indicated as a context root, but it features musttail " |
146 | "calls, which is not supported." ); |
147 | } |
148 | } |
149 | } |
150 | |
151 | // Declare the functions we will call. |
152 | StartCtx = cast<Function>( |
153 | Val: M.getOrInsertFunction( |
154 | Name: CompilerRtAPINames::StartCtx, |
155 | T: FunctionType::get(Result: ContextNodeTy->getPointerTo(), |
156 | Params: {ContextRootTy->getPointerTo(), /*ContextRoot*/ |
157 | I64Ty, /*Guid*/ I32Ty, |
158 | /*NrCounters*/ I32Ty /*NrCallsites*/}, |
159 | isVarArg: false)) |
160 | .getCallee()); |
161 | GetCtx = cast<Function>( |
162 | Val: M.getOrInsertFunction(Name: CompilerRtAPINames::GetCtx, |
163 | T: FunctionType::get(Result: ContextNodeTy->getPointerTo(), |
164 | Params: {PointerTy, /*Callee*/ |
165 | I64Ty, /*Guid*/ |
166 | I32Ty, /*NrCounters*/ |
167 | I32Ty}, /*NrCallsites*/ |
168 | isVarArg: false)) |
169 | .getCallee()); |
170 | ReleaseCtx = cast<Function>( |
171 | Val: M.getOrInsertFunction( |
172 | Name: CompilerRtAPINames::ReleaseCtx, |
173 | T: FunctionType::get(Result: Type::getVoidTy(C&: M.getContext()), |
174 | Params: { |
175 | ContextRootTy->getPointerTo(), /*ContextRoot*/ |
176 | }, |
177 | isVarArg: false)) |
178 | .getCallee()); |
179 | |
180 | // Declare the TLSes we will need to use. |
181 | CallsiteInfoTLS = |
182 | new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage, |
183 | nullptr, CompilerRtAPINames::CallsiteTLS); |
184 | CallsiteInfoTLS->setThreadLocal(true); |
185 | CallsiteInfoTLS->setVisibility(llvm::GlobalValue::HiddenVisibility); |
186 | ExpectedCalleeTLS = |
187 | new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage, |
188 | nullptr, CompilerRtAPINames::ExpectedCalleeTLS); |
189 | ExpectedCalleeTLS->setThreadLocal(true); |
190 | ExpectedCalleeTLS->setVisibility(llvm::GlobalValue::HiddenVisibility); |
191 | } |
192 | |
193 | PreservedAnalyses PGOCtxProfLoweringPass::run(Module &M, |
194 | ModuleAnalysisManager &MAM) { |
195 | CtxInstrumentationLowerer Lowerer(M, MAM); |
196 | bool Changed = false; |
197 | for (auto &F : M) |
198 | Changed |= Lowerer.lowerFunction(F); |
199 | return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); |
200 | } |
201 | |
202 | bool CtxInstrumentationLowerer::lowerFunction(Function &F) { |
203 | if (F.isDeclaration()) |
204 | return false; |
205 | auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(IR&: M).getManager(); |
206 | auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(IR&: F); |
207 | |
208 | Value *Guid = nullptr; |
209 | auto [NrCounters, NrCallsites] = getNrCountersAndCallsites(F); |
210 | |
211 | Value *Context = nullptr; |
212 | Value *RealContext = nullptr; |
213 | |
214 | StructType *ThisContextType = nullptr; |
215 | Value *TheRootContext = nullptr; |
216 | Value *ExpectedCalleeTLSAddr = nullptr; |
217 | Value *CallsiteInfoTLSAddr = nullptr; |
218 | |
219 | auto &Head = F.getEntryBlock(); |
220 | for (auto &I : Head) { |
221 | // Find the increment intrinsic in the entry basic block. |
222 | if (auto *Mark = dyn_cast<InstrProfIncrementInst>(Val: &I)) { |
223 | assert(Mark->getIndex()->isZero()); |
224 | |
225 | IRBuilder<> Builder(Mark); |
226 | // FIXME(mtrofin): use InstrProfSymtab::getCanonicalName |
227 | Guid = Builder.getInt64(C: F.getGUID()); |
228 | // The type of the context of this function is now knowable since we have |
229 | // NrCallsites and NrCounters. We delcare it here because it's more |
230 | // convenient - we have the Builder. |
231 | ThisContextType = StructType::get( |
232 | Context&: F.getContext(), |
233 | Elements: {ContextNodeTy, ArrayType::get(ElementType: Builder.getInt64Ty(), NumElements: NrCounters), |
234 | ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: NrCallsites)}); |
235 | // Figure out which way we obtain the context object for this function - |
236 | // if it's an entrypoint, then we call StartCtx, otherwise GetCtx. In the |
237 | // former case, we also set TheRootContext since we need to release it |
238 | // at the end (plus it can be used to know if we have an entrypoint or a |
239 | // regular function) |
240 | auto Iter = ContextRootMap.find(Val: &F); |
241 | if (Iter != ContextRootMap.end()) { |
242 | TheRootContext = Iter->second; |
243 | Context = Builder.CreateCall(Callee: StartCtx, Args: {TheRootContext, Guid, |
244 | Builder.getInt32(C: NrCounters), |
245 | Builder.getInt32(C: NrCallsites)}); |
246 | ORE.emit( |
247 | RemarkBuilder: [&] { return OptimizationRemark(DEBUG_TYPE, "Entrypoint" , &F); }); |
248 | } else { |
249 | Context = |
250 | Builder.CreateCall(Callee: GetCtx, Args: {&F, Guid, Builder.getInt32(C: NrCounters), |
251 | Builder.getInt32(C: NrCallsites)}); |
252 | ORE.emit(RemarkBuilder: [&] { |
253 | return OptimizationRemark(DEBUG_TYPE, "RegularFunction" , &F); |
254 | }); |
255 | } |
256 | // The context could be scratch. |
257 | auto *CtxAsInt = Builder.CreatePtrToInt(V: Context, DestTy: Builder.getInt64Ty()); |
258 | if (NrCallsites > 0) { |
259 | // Figure out which index of the TLS 2-element buffers to use. |
260 | // Scratch context => we use index == 1. Real contexts => index == 0. |
261 | auto *Index = Builder.CreateAnd(LHS: CtxAsInt, RHS: Builder.getInt64(C: 1)); |
262 | // The GEPs corresponding to that index, in the respective TLS. |
263 | ExpectedCalleeTLSAddr = Builder.CreateGEP( |
264 | Ty: Builder.getInt8Ty()->getPointerTo(), |
265 | Ptr: Builder.CreateThreadLocalAddress(Ptr: ExpectedCalleeTLS), IdxList: {Index}); |
266 | CallsiteInfoTLSAddr = Builder.CreateGEP( |
267 | Ty: Builder.getInt32Ty(), |
268 | Ptr: Builder.CreateThreadLocalAddress(Ptr: CallsiteInfoTLS), IdxList: {Index}); |
269 | } |
270 | // Because the context pointer may have LSB set (to indicate scratch), |
271 | // clear it for the value we use as base address for the counter vector. |
272 | // This way, if later we want to have "real" (not clobbered) buffers |
273 | // acting as scratch, the lowering (at least this part of it that deals |
274 | // with counters) stays the same. |
275 | RealContext = Builder.CreateIntToPtr( |
276 | V: Builder.CreateAnd(LHS: CtxAsInt, RHS: Builder.getInt64(C: -2)), |
277 | DestTy: ThisContextType->getPointerTo()); |
278 | I.eraseFromParent(); |
279 | break; |
280 | } |
281 | } |
282 | if (!Context) { |
283 | ORE.emit(RemarkBuilder: [&] { |
284 | return OptimizationRemarkMissed(DEBUG_TYPE, "Skip" , &F) |
285 | << "Function doesn't have instrumentation, skipping" ; |
286 | }); |
287 | return false; |
288 | } |
289 | |
290 | bool ContextWasReleased = false; |
291 | for (auto &BB : F) { |
292 | for (auto &I : llvm::make_early_inc_range(Range&: BB)) { |
293 | if (auto *Instr = dyn_cast<InstrProfCntrInstBase>(Val: &I)) { |
294 | IRBuilder<> Builder(Instr); |
295 | switch (Instr->getIntrinsicID()) { |
296 | case llvm::Intrinsic::instrprof_increment: |
297 | case llvm::Intrinsic::instrprof_increment_step: { |
298 | // Increments (or increment-steps) are just a typical load - increment |
299 | // - store in the RealContext. |
300 | auto *AsStep = cast<InstrProfIncrementInst>(Val: Instr); |
301 | auto *GEP = Builder.CreateGEP( |
302 | Ty: ThisContextType, Ptr: RealContext, |
303 | IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 1), AsStep->getIndex()}); |
304 | Builder.CreateStore( |
305 | Val: Builder.CreateAdd(LHS: Builder.CreateLoad(Ty: Builder.getInt64Ty(), Ptr: GEP), |
306 | RHS: AsStep->getStep()), |
307 | Ptr: GEP); |
308 | } break; |
309 | case llvm::Intrinsic::instrprof_callsite: |
310 | // callsite lowering: write the called value in the expected callee |
311 | // TLS we treat the TLS as volatile because of signal handlers and to |
312 | // avoid these being moved away from the callsite they decorate. |
313 | auto *CSIntrinsic = dyn_cast<InstrProfCallsite>(Val: Instr); |
314 | Builder.CreateStore(Val: CSIntrinsic->getCallee(), Ptr: ExpectedCalleeTLSAddr, |
315 | isVolatile: true); |
316 | // write the GEP of the slot in the sub-contexts portion of the |
317 | // context in TLS. Now, here, we use the actual Context value - as |
318 | // returned from compiler-rt - which may have the LSB set if the |
319 | // Context was scratch. Since the header of the context object and |
320 | // then the values are all 8-aligned (or, really, insofar as we care, |
321 | // they are even) - if the context is scratch (meaning, an odd value), |
322 | // so will the GEP. This is important because this is then visible to |
323 | // compiler-rt which will produce scratch contexts for callers that |
324 | // have a scratch context. |
325 | Builder.CreateStore( |
326 | Val: Builder.CreateGEP(Ty: ThisContextType, Ptr: Context, |
327 | IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 2), |
328 | CSIntrinsic->getIndex()}), |
329 | Ptr: CallsiteInfoTLSAddr, isVolatile: true); |
330 | break; |
331 | } |
332 | I.eraseFromParent(); |
333 | } else if (TheRootContext && isa<ReturnInst>(Val: I)) { |
334 | // Remember to release the context if we are an entrypoint. |
335 | IRBuilder<> Builder(&I); |
336 | Builder.CreateCall(Callee: ReleaseCtx, Args: {TheRootContext}); |
337 | ContextWasReleased = true; |
338 | } |
339 | } |
340 | } |
341 | // FIXME: This would happen if the entrypoint tailcalls. A way to fix would be |
342 | // to disallow this, (so this then stays as an error), another is to detect |
343 | // that and then do a wrapper or disallow the tail call. This only affects |
344 | // instrumentation, when we want to detect the call graph. |
345 | if (TheRootContext && !ContextWasReleased) |
346 | F.getContext().emitError( |
347 | ErrorStr: "[ctx_prof] An entrypoint was instrumented but it has no `ret` " |
348 | "instructions above which to release the context: " + |
349 | F.getName()); |
350 | return true; |
351 | } |
352 | |