1 | //===- PGOCtxProfLowering.cpp - Contextual PGO Instr. Lowering ------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | |
10 | #include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h" |
11 | #include "llvm/ADT/STLExtras.h" |
12 | #include "llvm/Analysis/CFG.h" |
13 | #include "llvm/Analysis/CtxProfAnalysis.h" |
14 | #include "llvm/Analysis/OptimizationRemarkEmitter.h" |
15 | #include "llvm/IR/Analysis.h" |
16 | #include "llvm/IR/Constants.h" |
17 | #include "llvm/IR/DiagnosticInfo.h" |
18 | #include "llvm/IR/GlobalValue.h" |
19 | #include "llvm/IR/IRBuilder.h" |
20 | #include "llvm/IR/InstrTypes.h" |
21 | #include "llvm/IR/Instructions.h" |
22 | #include "llvm/IR/IntrinsicInst.h" |
23 | #include "llvm/IR/Module.h" |
24 | #include "llvm/IR/PassManager.h" |
25 | #include "llvm/ProfileData/CtxInstrContextNode.h" |
26 | #include "llvm/ProfileData/InstrProf.h" |
27 | #include "llvm/Support/CommandLine.h" |
28 | #include <utility> |
29 | |
30 | using namespace llvm; |
31 | |
32 | #define DEBUG_TYPE "ctx-instr-lower" |
33 | |
34 | static cl::list<std::string> ContextRoots( |
35 | "profile-context-root" , cl::Hidden, |
36 | cl::desc( |
37 | "A function name, assumed to be global, which will be treated as the " |
38 | "root of an interesting graph, which will be profiled independently " |
39 | "from other similar graphs." )); |
40 | |
41 | bool PGOCtxProfLoweringPass::isCtxIRPGOInstrEnabled() { |
42 | return !ContextRoots.empty(); |
43 | } |
44 | |
// The names of the symbols we expect in compiler-rt. They live in a namespace
// for readability, and are compile-time constants so that they cannot be
// reassigned by accident.
namespace CompilerRtAPINames {
static constexpr const char *StartCtx = "__llvm_ctx_profile_start_context";
static constexpr const char *ReleaseCtx = "__llvm_ctx_profile_release_context";
static constexpr const char *GetCtx = "__llvm_ctx_profile_get_context";
static constexpr const char *ExpectedCalleeTLS =
    "__llvm_ctx_profile_expected_callee";
static constexpr const char *CallsiteTLS = "__llvm_ctx_profile_callsite";
} // namespace CompilerRtAPINames
54 | |
55 | namespace { |
56 | // The lowering logic and state. |
57 | class CtxInstrumentationLowerer final { |
58 | Module &M; |
59 | ModuleAnalysisManager &MAM; |
60 | Type *ContextNodeTy = nullptr; |
61 | StructType *FunctionDataTy = nullptr; |
62 | |
63 | DenseSet<const Function *> ContextRootSet; |
64 | Function *StartCtx = nullptr; |
65 | Function *GetCtx = nullptr; |
66 | Function *ReleaseCtx = nullptr; |
67 | GlobalVariable *ExpectedCalleeTLS = nullptr; |
68 | GlobalVariable *CallsiteInfoTLS = nullptr; |
69 | Constant *CannotBeRootInitializer = nullptr; |
70 | |
71 | public: |
72 | CtxInstrumentationLowerer(Module &M, ModuleAnalysisManager &MAM); |
73 | // return true if lowering happened (i.e. a change was made) |
74 | bool lowerFunction(Function &F); |
75 | }; |
76 | |
77 | // llvm.instrprof.increment[.step] captures the total number of counters as one |
78 | // of its parameters, and llvm.instrprof.callsite captures the total number of |
79 | // callsites. Those values are the same for instances of those intrinsics in |
80 | // this function. Find the first instance of each and return them. |
81 | std::pair<uint32_t, uint32_t> getNumCountersAndCallsites(const Function &F) { |
82 | uint32_t NumCounters = 0; |
83 | uint32_t NumCallsites = 0; |
84 | for (const auto &BB : F) { |
85 | for (const auto &I : BB) { |
86 | if (const auto *Incr = dyn_cast<InstrProfIncrementInst>(Val: &I)) { |
87 | uint32_t V = |
88 | static_cast<uint32_t>(Incr->getNumCounters()->getZExtValue()); |
89 | assert((!NumCounters || V == NumCounters) && |
90 | "expected all llvm.instrprof.increment[.step] intrinsics to " |
91 | "have the same total nr of counters parameter" ); |
92 | NumCounters = V; |
93 | } else if (const auto *CSIntr = dyn_cast<InstrProfCallsite>(Val: &I)) { |
94 | uint32_t V = |
95 | static_cast<uint32_t>(CSIntr->getNumCounters()->getZExtValue()); |
96 | assert((!NumCallsites || V == NumCallsites) && |
97 | "expected all llvm.instrprof.callsite intrinsics to have the " |
98 | "same total nr of callsites parameter" ); |
99 | NumCallsites = V; |
100 | } |
101 | #if NDEBUG |
102 | if (NumCounters && NumCallsites) |
103 | return std::make_pair(x&: NumCounters, y&: NumCallsites); |
104 | #endif |
105 | } |
106 | } |
107 | return {NumCounters, NumCallsites}; |
108 | } |
109 | |
110 | void emitUnsupportedRootError(const Function &F, StringRef Reason) { |
111 | F.getContext().emitError(ErrorStr: "[ctxprof] The function " + F.getName() + |
112 | " was indicated as context root but " + Reason + |
113 | ", which is not supported." ); |
114 | } |
115 | } // namespace |
116 | |
117 | // set up tie-in with compiler-rt. |
118 | // NOTE!!! |
119 | // These have to match compiler-rt/lib/ctx_profile/CtxInstrProfiling.h |
120 | CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M, |
121 | ModuleAnalysisManager &MAM) |
122 | : M(M), MAM(MAM) { |
123 | auto *PointerTy = PointerType::get(C&: M.getContext(), AddressSpace: 0); |
124 | auto *SanitizerMutexType = Type::getInt8Ty(C&: M.getContext()); |
125 | auto *I32Ty = Type::getInt32Ty(C&: M.getContext()); |
126 | auto *I64Ty = Type::getInt64Ty(C&: M.getContext()); |
127 | |
128 | #define _PTRDECL(_, __) PointerTy, |
129 | #define _VOLATILE_PTRDECL(_, __) PointerTy, |
130 | #define _CONTEXT_ROOT PointerTy, |
131 | #define _MUTEXDECL(_) SanitizerMutexType, |
132 | |
133 | FunctionDataTy = StructType::get( |
134 | Context&: M.getContext(), Elements: {CTXPROF_FUNCTION_DATA(_PTRDECL, _CONTEXT_ROOT, |
135 | _VOLATILE_PTRDECL, _MUTEXDECL)}); |
136 | #undef _PTRDECL |
137 | #undef _CONTEXT_ROOT |
138 | #undef _VOLATILE_PTRDECL |
139 | #undef _MUTEXDECL |
140 | |
141 | #define _PTRDECL(_, __) Constant::getNullValue(PointerTy), |
142 | #define _VOLATILE_PTRDECL(_, __) _PTRDECL(_, __) |
143 | #define _MUTEXDECL(_) Constant::getNullValue(SanitizerMutexType), |
144 | #define _CONTEXT_ROOT \ |
145 | Constant::getIntegerValue( \ |
146 | PointerTy, \ |
147 | APInt(M.getDataLayout().getPointerTypeSizeInBits(PointerTy), 1U)), |
148 | CannotBeRootInitializer = ConstantStruct::get( |
149 | T: FunctionDataTy, V: {CTXPROF_FUNCTION_DATA(_PTRDECL, _CONTEXT_ROOT, |
150 | _VOLATILE_PTRDECL, _MUTEXDECL)}); |
151 | #undef _PTRDECL |
152 | #undef _CONTEXT_ROOT |
153 | #undef _VOLATILE_PTRDECL |
154 | #undef _MUTEXDECL |
155 | |
156 | // The Context header. |
157 | ContextNodeTy = StructType::get(Context&: M.getContext(), Elements: { |
158 | I64Ty, /*Guid*/ |
159 | PointerTy, /*Next*/ |
160 | I32Ty, /*NumCounters*/ |
161 | I32Ty, /*NumCallsites*/ |
162 | }); |
163 | |
164 | // Define a global for each entrypoint. We'll reuse the entrypoint's name |
165 | // as prefix. We assume the entrypoint names to be unique. |
166 | for (const auto &Fname : ContextRoots) { |
167 | if (const auto *F = M.getFunction(Name: Fname)) { |
168 | if (F->isDeclaration()) |
169 | continue; |
170 | ContextRootSet.insert(V: F); |
171 | for (const auto &BB : *F) |
172 | for (const auto &I : BB) |
173 | if (const auto *CB = dyn_cast<CallBase>(Val: &I)) |
174 | if (CB->isMustTailCall()) |
175 | emitUnsupportedRootError(F: *F, Reason: "it features musttail calls" ); |
176 | } |
177 | } |
178 | |
179 | // Declare the functions we will call. |
180 | StartCtx = cast<Function>( |
181 | Val: M.getOrInsertFunction( |
182 | Name: CompilerRtAPINames::StartCtx, |
183 | T: FunctionType::get(Result: PointerTy, |
184 | Params: {PointerTy, /*FunctionData*/ |
185 | I64Ty, /*Guid*/ I32Ty, |
186 | /*NumCounters*/ I32Ty /*NumCallsites*/}, |
187 | isVarArg: false)) |
188 | .getCallee()); |
189 | GetCtx = cast<Function>( |
190 | Val: M.getOrInsertFunction(Name: CompilerRtAPINames::GetCtx, |
191 | T: FunctionType::get(Result: PointerTy, |
192 | Params: {PointerTy, /*FunctionData*/ |
193 | PointerTy, /*Callee*/ |
194 | I64Ty, /*Guid*/ |
195 | I32Ty, /*NumCounters*/ |
196 | I32Ty}, /*NumCallsites*/ |
197 | isVarArg: false)) |
198 | .getCallee()); |
199 | ReleaseCtx = cast<Function>( |
200 | Val: M.getOrInsertFunction(Name: CompilerRtAPINames::ReleaseCtx, |
201 | T: FunctionType::get(Result: Type::getVoidTy(C&: M.getContext()), |
202 | Params: { |
203 | PointerTy, /*FunctionData*/ |
204 | }, |
205 | isVarArg: false)) |
206 | .getCallee()); |
207 | |
208 | // Declare the TLSes we will need to use. |
209 | CallsiteInfoTLS = |
210 | new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage, |
211 | nullptr, CompilerRtAPINames::CallsiteTLS); |
212 | CallsiteInfoTLS->setThreadLocal(true); |
213 | CallsiteInfoTLS->setVisibility(llvm::GlobalValue::HiddenVisibility); |
214 | ExpectedCalleeTLS = |
215 | new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage, |
216 | nullptr, CompilerRtAPINames::ExpectedCalleeTLS); |
217 | ExpectedCalleeTLS->setThreadLocal(true); |
218 | ExpectedCalleeTLS->setVisibility(llvm::GlobalValue::HiddenVisibility); |
219 | } |
220 | |
221 | PreservedAnalyses PGOCtxProfLoweringPass::run(Module &M, |
222 | ModuleAnalysisManager &MAM) { |
223 | CtxInstrumentationLowerer Lowerer(M, MAM); |
224 | bool Changed = false; |
225 | for (auto &F : M) |
226 | Changed |= Lowerer.lowerFunction(F); |
227 | return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); |
228 | } |
229 | |
230 | bool CtxInstrumentationLowerer::lowerFunction(Function &F) { |
231 | if (F.isDeclaration()) |
232 | return false; |
233 | |
234 | // Probably pointless to try to do anything here, unlikely to be |
235 | // performance-affecting. |
236 | if (!llvm::canReturn(F)) { |
237 | for (auto &BB : F) |
238 | for (auto &I : make_early_inc_range(Range&: BB)) |
239 | if (isa<InstrProfCntrInstBase>(Val: &I)) |
240 | I.eraseFromParent(); |
241 | if (ContextRootSet.contains(V: &F)) |
242 | emitUnsupportedRootError(F, Reason: "it does not return" ); |
243 | return true; |
244 | } |
245 | |
246 | auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(IR&: M).getManager(); |
247 | auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(IR&: F); |
248 | |
249 | Value *Guid = nullptr; |
250 | auto [NumCounters, NumCallsites] = getNumCountersAndCallsites(F); |
251 | |
252 | Value *Context = nullptr; |
253 | Value *RealContext = nullptr; |
254 | |
255 | StructType *ThisContextType = nullptr; |
256 | Value *TheRootFuctionData = nullptr; |
257 | Value *ExpectedCalleeTLSAddr = nullptr; |
258 | Value *CallsiteInfoTLSAddr = nullptr; |
259 | const bool HasMusttail = [&F]() { |
260 | for (auto &BB : F) |
261 | for (auto &I : BB) |
262 | if (auto *CB = dyn_cast<CallBase>(Val: &I)) |
263 | if (CB->isMustTailCall()) |
264 | return true; |
265 | return false; |
266 | }(); |
267 | |
268 | if (HasMusttail && ContextRootSet.contains(V: &F)) { |
269 | F.getContext().emitError( |
270 | ErrorStr: "[ctx_prof] A function with musttail calls was explicitly requested as " |
271 | "root. That is not supported because we cannot instrument a return " |
272 | "instruction to release the context: " + |
273 | F.getName()); |
274 | return false; |
275 | } |
276 | auto &Head = F.getEntryBlock(); |
277 | for (auto &I : Head) { |
278 | // Find the increment intrinsic in the entry basic block. |
279 | if (auto *Mark = dyn_cast<InstrProfIncrementInst>(Val: &I)) { |
280 | assert(Mark->getIndex()->isZero()); |
281 | |
282 | IRBuilder<> Builder(Mark); |
283 | Guid = Builder.getInt64( |
284 | C: AssignGUIDPass::getGUID(F: cast<Function>(Val&: *Mark->getNameValue()))); |
285 | // The type of the context of this function is now knowable since we have |
286 | // NumCallsites and NumCounters. We delcare it here because it's more |
287 | // convenient - we have the Builder. |
288 | ThisContextType = StructType::get( |
289 | Context&: F.getContext(), |
290 | Elements: {ContextNodeTy, ArrayType::get(ElementType: Builder.getInt64Ty(), NumElements: NumCounters), |
291 | ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: NumCallsites)}); |
292 | // Figure out which way we obtain the context object for this function - |
293 | // if it's an entrypoint, then we call StartCtx, otherwise GetCtx. In the |
294 | // former case, we also set TheRootFuctionData since we need to release it |
295 | // at the end (plus it can be used to know if we have an entrypoint or a |
296 | // regular function) |
297 | // Don't set a name, they end up taking a lot of space and we don't need |
298 | // them. |
299 | |
300 | // Zero-initialize the FunctionData, except for functions that have |
301 | // musttail calls. There, we set the CtxRoot field to 1, which will be |
302 | // treated as a "can't be set as root". |
303 | TheRootFuctionData = new GlobalVariable( |
304 | M, FunctionDataTy, false, GlobalVariable::InternalLinkage, |
305 | HasMusttail ? CannotBeRootInitializer |
306 | : Constant::getNullValue(Ty: FunctionDataTy)); |
307 | |
308 | if (ContextRootSet.contains(V: &F)) { |
309 | Context = Builder.CreateCall( |
310 | Callee: StartCtx, Args: {TheRootFuctionData, Guid, Builder.getInt32(C: NumCounters), |
311 | Builder.getInt32(C: NumCallsites)}); |
312 | ORE.emit( |
313 | RemarkBuilder: [&] { return OptimizationRemark(DEBUG_TYPE, "Entrypoint" , &F); }); |
314 | } else { |
315 | Context = Builder.CreateCall(Callee: GetCtx, Args: {TheRootFuctionData, &F, Guid, |
316 | Builder.getInt32(C: NumCounters), |
317 | Builder.getInt32(C: NumCallsites)}); |
318 | ORE.emit(RemarkBuilder: [&] { |
319 | return OptimizationRemark(DEBUG_TYPE, "RegularFunction" , &F); |
320 | }); |
321 | } |
322 | // The context could be scratch. |
323 | auto *CtxAsInt = Builder.CreatePtrToInt(V: Context, DestTy: Builder.getInt64Ty()); |
324 | if (NumCallsites > 0) { |
325 | // Figure out which index of the TLS 2-element buffers to use. |
326 | // Scratch context => we use index == 1. Real contexts => index == 0. |
327 | auto *Index = Builder.CreateAnd(LHS: CtxAsInt, RHS: Builder.getInt64(C: 1)); |
328 | // The GEPs corresponding to that index, in the respective TLS. |
329 | ExpectedCalleeTLSAddr = Builder.CreateGEP( |
330 | Ty: PointerType::getUnqual(C&: F.getContext()), |
331 | Ptr: Builder.CreateThreadLocalAddress(Ptr: ExpectedCalleeTLS), IdxList: {Index}); |
332 | CallsiteInfoTLSAddr = Builder.CreateGEP( |
333 | Ty: Builder.getInt32Ty(), |
334 | Ptr: Builder.CreateThreadLocalAddress(Ptr: CallsiteInfoTLS), IdxList: {Index}); |
335 | } |
336 | // Because the context pointer may have LSB set (to indicate scratch), |
337 | // clear it for the value we use as base address for the counter vector. |
338 | // This way, if later we want to have "real" (not clobbered) buffers |
339 | // acting as scratch, the lowering (at least this part of it that deals |
340 | // with counters) stays the same. |
341 | RealContext = Builder.CreateIntToPtr( |
342 | V: Builder.CreateAnd(LHS: CtxAsInt, RHS: Builder.getInt64(C: -2)), |
343 | DestTy: PointerType::getUnqual(C&: F.getContext())); |
344 | I.eraseFromParent(); |
345 | break; |
346 | } |
347 | } |
348 | if (!Context) { |
349 | ORE.emit(RemarkBuilder: [&] { |
350 | return OptimizationRemarkMissed(DEBUG_TYPE, "Skip" , &F) |
351 | << "Function doesn't have instrumentation, skipping" ; |
352 | }); |
353 | return false; |
354 | } |
355 | |
356 | bool ContextWasReleased = false; |
357 | for (auto &BB : F) { |
358 | for (auto &I : llvm::make_early_inc_range(Range&: BB)) { |
359 | if (auto *Instr = dyn_cast<InstrProfCntrInstBase>(Val: &I)) { |
360 | IRBuilder<> Builder(Instr); |
361 | switch (Instr->getIntrinsicID()) { |
362 | case llvm::Intrinsic::instrprof_increment: |
363 | case llvm::Intrinsic::instrprof_increment_step: { |
364 | // Increments (or increment-steps) are just a typical load - increment |
365 | // - store in the RealContext. |
366 | auto *AsStep = cast<InstrProfIncrementInst>(Val: Instr); |
367 | auto *GEP = Builder.CreateGEP( |
368 | Ty: ThisContextType, Ptr: RealContext, |
369 | IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 1), AsStep->getIndex()}); |
370 | Builder.CreateStore( |
371 | Val: Builder.CreateAdd(LHS: Builder.CreateLoad(Ty: Builder.getInt64Ty(), Ptr: GEP), |
372 | RHS: AsStep->getStep()), |
373 | Ptr: GEP); |
374 | } break; |
375 | case llvm::Intrinsic::instrprof_callsite: |
376 | // callsite lowering: write the called value in the expected callee |
377 | // TLS we treat the TLS as volatile because of signal handlers and to |
378 | // avoid these being moved away from the callsite they decorate. |
379 | auto *CSIntrinsic = dyn_cast<InstrProfCallsite>(Val: Instr); |
380 | Builder.CreateStore(Val: CSIntrinsic->getCallee(), Ptr: ExpectedCalleeTLSAddr, |
381 | isVolatile: true); |
382 | // write the GEP of the slot in the sub-contexts portion of the |
383 | // context in TLS. Now, here, we use the actual Context value - as |
384 | // returned from compiler-rt - which may have the LSB set if the |
385 | // Context was scratch. Since the header of the context object and |
386 | // then the values are all 8-aligned (or, really, insofar as we care, |
387 | // they are even) - if the context is scratch (meaning, an odd value), |
388 | // so will the GEP. This is important because this is then visible to |
389 | // compiler-rt which will produce scratch contexts for callers that |
390 | // have a scratch context. |
391 | Builder.CreateStore( |
392 | Val: Builder.CreateGEP(Ty: ThisContextType, Ptr: Context, |
393 | IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 2), |
394 | CSIntrinsic->getIndex()}), |
395 | Ptr: CallsiteInfoTLSAddr, isVolatile: true); |
396 | break; |
397 | } |
398 | I.eraseFromParent(); |
399 | } else if (!HasMusttail && isa<ReturnInst>(Val: I)) { |
400 | // Remember to release the context if we are an entrypoint. |
401 | IRBuilder<> Builder(&I); |
402 | Builder.CreateCall(Callee: ReleaseCtx, Args: {TheRootFuctionData}); |
403 | ContextWasReleased = true; |
404 | } |
405 | } |
406 | } |
407 | if (!HasMusttail && !ContextWasReleased) |
408 | F.getContext().emitError( |
409 | ErrorStr: "[ctx_prof] A function that doesn't have musttail calls was " |
410 | "instrumented but it has no `ret` " |
411 | "instructions above which to release the context: " + |
412 | F.getName()); |
413 | return true; |
414 | } |
415 | |
416 | PreservedAnalyses NoinlineNonPrevailing::run(Module &M, |
417 | ModuleAnalysisManager &MAM) { |
418 | bool Changed = false; |
419 | for (auto &F : M) { |
420 | if (F.isDeclaration()) |
421 | continue; |
422 | if (F.hasFnAttribute(Kind: Attribute::NoInline)) |
423 | continue; |
424 | if (!F.isWeakForLinker()) |
425 | continue; |
426 | |
427 | if (F.hasFnAttribute(Kind: Attribute::AlwaysInline)) |
428 | F.removeFnAttr(Kind: Attribute::AlwaysInline); |
429 | |
430 | F.addFnAttr(Kind: Attribute::NoInline); |
431 | Changed = true; |
432 | } |
433 | if (Changed) |
434 | return PreservedAnalyses::none(); |
435 | return PreservedAnalyses::all(); |
436 | } |
437 | |