//===- PGOCtxProfLowering.cpp - Contextual PGO Instr. Lowering ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
9
10#include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
11#include "llvm/Analysis/OptimizationRemarkEmitter.h"
12#include "llvm/IR/Analysis.h"
13#include "llvm/IR/DiagnosticInfo.h"
14#include "llvm/IR/IRBuilder.h"
15#include "llvm/IR/Instructions.h"
16#include "llvm/IR/IntrinsicInst.h"
17#include "llvm/IR/Module.h"
18#include "llvm/IR/PassManager.h"
19#include "llvm/Support/CommandLine.h"
20#include <utility>
21
22using namespace llvm;
23
24#define DEBUG_TYPE "ctx-instr-lower"
25
26static cl::list<std::string> ContextRoots(
27 "profile-context-root", cl::Hidden,
28 cl::desc(
29 "A function name, assumed to be global, which will be treated as the "
30 "root of an interesting graph, which will be profiled independently "
31 "from other similar graphs."));
32
33bool PGOCtxProfLoweringPass::isContextualIRPGOEnabled() {
34 return !ContextRoots.empty();
35}
36
// Names of the symbols we expect compiler-rt to provide: three functions and
// two thread-local variables. They are grouped in a namespace purely for
// readability at the use sites.
namespace CompilerRtAPINames {
static const char *StartCtx = "__llvm_ctx_profile_start_context";
static const char *ReleaseCtx = "__llvm_ctx_profile_release_context";
static const char *GetCtx = "__llvm_ctx_profile_get_context";
static const char *ExpectedCalleeTLS = "__llvm_ctx_profile_expected_callee";
static const char *CallsiteTLS = "__llvm_ctx_profile_callsite";
} // namespace CompilerRtAPINames
46
47namespace {
48// The lowering logic and state.
49class CtxInstrumentationLowerer final {
50 Module &M;
51 ModuleAnalysisManager &MAM;
52 Type *ContextNodeTy = nullptr;
53 Type *ContextRootTy = nullptr;
54
55 DenseMap<const Function *, Constant *> ContextRootMap;
56 Function *StartCtx = nullptr;
57 Function *GetCtx = nullptr;
58 Function *ReleaseCtx = nullptr;
59 GlobalVariable *ExpectedCalleeTLS = nullptr;
60 GlobalVariable *CallsiteInfoTLS = nullptr;
61
62public:
63 CtxInstrumentationLowerer(Module &M, ModuleAnalysisManager &MAM);
64 // return true if lowering happened (i.e. a change was made)
65 bool lowerFunction(Function &F);
66};
67
68// llvm.instrprof.increment[.step] captures the total number of counters as one
69// of its parameters, and llvm.instrprof.callsite captures the total number of
70// callsites. Those values are the same for instances of those intrinsics in
71// this function. Find the first instance of each and return them.
72std::pair<uint32_t, uint32_t> getNrCountersAndCallsites(const Function &F) {
73 uint32_t NrCounters = 0;
74 uint32_t NrCallsites = 0;
75 for (const auto &BB : F) {
76 for (const auto &I : BB) {
77 if (const auto *Incr = dyn_cast<InstrProfIncrementInst>(Val: &I)) {
78 uint32_t V =
79 static_cast<uint32_t>(Incr->getNumCounters()->getZExtValue());
80 assert((!NrCounters || V == NrCounters) &&
81 "expected all llvm.instrprof.increment[.step] intrinsics to "
82 "have the same total nr of counters parameter");
83 NrCounters = V;
84 } else if (const auto *CSIntr = dyn_cast<InstrProfCallsite>(Val: &I)) {
85 uint32_t V =
86 static_cast<uint32_t>(CSIntr->getNumCounters()->getZExtValue());
87 assert((!NrCallsites || V == NrCallsites) &&
88 "expected all llvm.instrprof.callsite intrinsics to have the "
89 "same total nr of callsites parameter");
90 NrCallsites = V;
91 }
92#if NDEBUG
93 if (NrCounters && NrCallsites)
94 return std::make_pair(x&: NrCounters, y&: NrCallsites);
95#endif
96 }
97 }
98 return {NrCounters, NrCallsites};
99}
100} // namespace
101
102// set up tie-in with compiler-rt.
103// NOTE!!!
104// These have to match compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
105CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M,
106 ModuleAnalysisManager &MAM)
107 : M(M), MAM(MAM) {
108 auto *PointerTy = PointerType::get(C&: M.getContext(), AddressSpace: 0);
109 auto *SanitizerMutexType = Type::getInt8Ty(C&: M.getContext());
110 auto *I32Ty = Type::getInt32Ty(C&: M.getContext());
111 auto *I64Ty = Type::getInt64Ty(C&: M.getContext());
112
113 // The ContextRoot type
114 ContextRootTy =
115 StructType::get(Context&: M.getContext(), Elements: {
116 PointerTy, /*FirstNode*/
117 PointerTy, /*FirstMemBlock*/
118 PointerTy, /*CurrentMem*/
119 SanitizerMutexType, /*Taken*/
120 });
121 // The Context header.
122 ContextNodeTy = StructType::get(Context&: M.getContext(), Elements: {
123 I64Ty, /*Guid*/
124 PointerTy, /*Next*/
125 I32Ty, /*NrCounters*/
126 I32Ty, /*NrCallsites*/
127 });
128
129 // Define a global for each entrypoint. We'll reuse the entrypoint's name as
130 // prefix. We assume the entrypoint names to be unique.
131 for (const auto &Fname : ContextRoots) {
132 if (const auto *F = M.getFunction(Name: Fname)) {
133 if (F->isDeclaration())
134 continue;
135 auto *G = M.getOrInsertGlobal(Name: Fname + "_ctx_root", Ty: ContextRootTy);
136 cast<GlobalVariable>(Val: G)->setInitializer(
137 Constant::getNullValue(Ty: ContextRootTy));
138 ContextRootMap.insert(KV: std::make_pair(x&: F, y&: G));
139 for (const auto &BB : *F)
140 for (const auto &I : BB)
141 if (const auto *CB = dyn_cast<CallBase>(Val: &I))
142 if (CB->isMustTailCall()) {
143 M.getContext().emitError(
144 ErrorStr: "The function " + Fname +
145 " was indicated as a context root, but it features musttail "
146 "calls, which is not supported.");
147 }
148 }
149 }
150
151 // Declare the functions we will call.
152 StartCtx = cast<Function>(
153 Val: M.getOrInsertFunction(
154 Name: CompilerRtAPINames::StartCtx,
155 T: FunctionType::get(Result: ContextNodeTy->getPointerTo(),
156 Params: {ContextRootTy->getPointerTo(), /*ContextRoot*/
157 I64Ty, /*Guid*/ I32Ty,
158 /*NrCounters*/ I32Ty /*NrCallsites*/},
159 isVarArg: false))
160 .getCallee());
161 GetCtx = cast<Function>(
162 Val: M.getOrInsertFunction(Name: CompilerRtAPINames::GetCtx,
163 T: FunctionType::get(Result: ContextNodeTy->getPointerTo(),
164 Params: {PointerTy, /*Callee*/
165 I64Ty, /*Guid*/
166 I32Ty, /*NrCounters*/
167 I32Ty}, /*NrCallsites*/
168 isVarArg: false))
169 .getCallee());
170 ReleaseCtx = cast<Function>(
171 Val: M.getOrInsertFunction(
172 Name: CompilerRtAPINames::ReleaseCtx,
173 T: FunctionType::get(Result: Type::getVoidTy(C&: M.getContext()),
174 Params: {
175 ContextRootTy->getPointerTo(), /*ContextRoot*/
176 },
177 isVarArg: false))
178 .getCallee());
179
180 // Declare the TLSes we will need to use.
181 CallsiteInfoTLS =
182 new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage,
183 nullptr, CompilerRtAPINames::CallsiteTLS);
184 CallsiteInfoTLS->setThreadLocal(true);
185 CallsiteInfoTLS->setVisibility(llvm::GlobalValue::HiddenVisibility);
186 ExpectedCalleeTLS =
187 new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage,
188 nullptr, CompilerRtAPINames::ExpectedCalleeTLS);
189 ExpectedCalleeTLS->setThreadLocal(true);
190 ExpectedCalleeTLS->setVisibility(llvm::GlobalValue::HiddenVisibility);
191}
192
193PreservedAnalyses PGOCtxProfLoweringPass::run(Module &M,
194 ModuleAnalysisManager &MAM) {
195 CtxInstrumentationLowerer Lowerer(M, MAM);
196 bool Changed = false;
197 for (auto &F : M)
198 Changed |= Lowerer.lowerFunction(F);
199 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
200}
201
202bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
203 if (F.isDeclaration())
204 return false;
205 auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(IR&: M).getManager();
206 auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(IR&: F);
207
208 Value *Guid = nullptr;
209 auto [NrCounters, NrCallsites] = getNrCountersAndCallsites(F);
210
211 Value *Context = nullptr;
212 Value *RealContext = nullptr;
213
214 StructType *ThisContextType = nullptr;
215 Value *TheRootContext = nullptr;
216 Value *ExpectedCalleeTLSAddr = nullptr;
217 Value *CallsiteInfoTLSAddr = nullptr;
218
219 auto &Head = F.getEntryBlock();
220 for (auto &I : Head) {
221 // Find the increment intrinsic in the entry basic block.
222 if (auto *Mark = dyn_cast<InstrProfIncrementInst>(Val: &I)) {
223 assert(Mark->getIndex()->isZero());
224
225 IRBuilder<> Builder(Mark);
226 // FIXME(mtrofin): use InstrProfSymtab::getCanonicalName
227 Guid = Builder.getInt64(C: F.getGUID());
228 // The type of the context of this function is now knowable since we have
229 // NrCallsites and NrCounters. We delcare it here because it's more
230 // convenient - we have the Builder.
231 ThisContextType = StructType::get(
232 Context&: F.getContext(),
233 Elements: {ContextNodeTy, ArrayType::get(ElementType: Builder.getInt64Ty(), NumElements: NrCounters),
234 ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: NrCallsites)});
235 // Figure out which way we obtain the context object for this function -
236 // if it's an entrypoint, then we call StartCtx, otherwise GetCtx. In the
237 // former case, we also set TheRootContext since we need to release it
238 // at the end (plus it can be used to know if we have an entrypoint or a
239 // regular function)
240 auto Iter = ContextRootMap.find(Val: &F);
241 if (Iter != ContextRootMap.end()) {
242 TheRootContext = Iter->second;
243 Context = Builder.CreateCall(Callee: StartCtx, Args: {TheRootContext, Guid,
244 Builder.getInt32(C: NrCounters),
245 Builder.getInt32(C: NrCallsites)});
246 ORE.emit(
247 RemarkBuilder: [&] { return OptimizationRemark(DEBUG_TYPE, "Entrypoint", &F); });
248 } else {
249 Context =
250 Builder.CreateCall(Callee: GetCtx, Args: {&F, Guid, Builder.getInt32(C: NrCounters),
251 Builder.getInt32(C: NrCallsites)});
252 ORE.emit(RemarkBuilder: [&] {
253 return OptimizationRemark(DEBUG_TYPE, "RegularFunction", &F);
254 });
255 }
256 // The context could be scratch.
257 auto *CtxAsInt = Builder.CreatePtrToInt(V: Context, DestTy: Builder.getInt64Ty());
258 if (NrCallsites > 0) {
259 // Figure out which index of the TLS 2-element buffers to use.
260 // Scratch context => we use index == 1. Real contexts => index == 0.
261 auto *Index = Builder.CreateAnd(LHS: CtxAsInt, RHS: Builder.getInt64(C: 1));
262 // The GEPs corresponding to that index, in the respective TLS.
263 ExpectedCalleeTLSAddr = Builder.CreateGEP(
264 Ty: Builder.getInt8Ty()->getPointerTo(),
265 Ptr: Builder.CreateThreadLocalAddress(Ptr: ExpectedCalleeTLS), IdxList: {Index});
266 CallsiteInfoTLSAddr = Builder.CreateGEP(
267 Ty: Builder.getInt32Ty(),
268 Ptr: Builder.CreateThreadLocalAddress(Ptr: CallsiteInfoTLS), IdxList: {Index});
269 }
270 // Because the context pointer may have LSB set (to indicate scratch),
271 // clear it for the value we use as base address for the counter vector.
272 // This way, if later we want to have "real" (not clobbered) buffers
273 // acting as scratch, the lowering (at least this part of it that deals
274 // with counters) stays the same.
275 RealContext = Builder.CreateIntToPtr(
276 V: Builder.CreateAnd(LHS: CtxAsInt, RHS: Builder.getInt64(C: -2)),
277 DestTy: ThisContextType->getPointerTo());
278 I.eraseFromParent();
279 break;
280 }
281 }
282 if (!Context) {
283 ORE.emit(RemarkBuilder: [&] {
284 return OptimizationRemarkMissed(DEBUG_TYPE, "Skip", &F)
285 << "Function doesn't have instrumentation, skipping";
286 });
287 return false;
288 }
289
290 bool ContextWasReleased = false;
291 for (auto &BB : F) {
292 for (auto &I : llvm::make_early_inc_range(Range&: BB)) {
293 if (auto *Instr = dyn_cast<InstrProfCntrInstBase>(Val: &I)) {
294 IRBuilder<> Builder(Instr);
295 switch (Instr->getIntrinsicID()) {
296 case llvm::Intrinsic::instrprof_increment:
297 case llvm::Intrinsic::instrprof_increment_step: {
298 // Increments (or increment-steps) are just a typical load - increment
299 // - store in the RealContext.
300 auto *AsStep = cast<InstrProfIncrementInst>(Val: Instr);
301 auto *GEP = Builder.CreateGEP(
302 Ty: ThisContextType, Ptr: RealContext,
303 IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 1), AsStep->getIndex()});
304 Builder.CreateStore(
305 Val: Builder.CreateAdd(LHS: Builder.CreateLoad(Ty: Builder.getInt64Ty(), Ptr: GEP),
306 RHS: AsStep->getStep()),
307 Ptr: GEP);
308 } break;
309 case llvm::Intrinsic::instrprof_callsite:
310 // callsite lowering: write the called value in the expected callee
311 // TLS we treat the TLS as volatile because of signal handlers and to
312 // avoid these being moved away from the callsite they decorate.
313 auto *CSIntrinsic = dyn_cast<InstrProfCallsite>(Val: Instr);
314 Builder.CreateStore(Val: CSIntrinsic->getCallee(), Ptr: ExpectedCalleeTLSAddr,
315 isVolatile: true);
316 // write the GEP of the slot in the sub-contexts portion of the
317 // context in TLS. Now, here, we use the actual Context value - as
318 // returned from compiler-rt - which may have the LSB set if the
319 // Context was scratch. Since the header of the context object and
320 // then the values are all 8-aligned (or, really, insofar as we care,
321 // they are even) - if the context is scratch (meaning, an odd value),
322 // so will the GEP. This is important because this is then visible to
323 // compiler-rt which will produce scratch contexts for callers that
324 // have a scratch context.
325 Builder.CreateStore(
326 Val: Builder.CreateGEP(Ty: ThisContextType, Ptr: Context,
327 IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 2),
328 CSIntrinsic->getIndex()}),
329 Ptr: CallsiteInfoTLSAddr, isVolatile: true);
330 break;
331 }
332 I.eraseFromParent();
333 } else if (TheRootContext && isa<ReturnInst>(Val: I)) {
334 // Remember to release the context if we are an entrypoint.
335 IRBuilder<> Builder(&I);
336 Builder.CreateCall(Callee: ReleaseCtx, Args: {TheRootContext});
337 ContextWasReleased = true;
338 }
339 }
340 }
341 // FIXME: This would happen if the entrypoint tailcalls. A way to fix would be
342 // to disallow this, (so this then stays as an error), another is to detect
343 // that and then do a wrapper or disallow the tail call. This only affects
344 // instrumentation, when we want to detect the call graph.
345 if (TheRootContext && !ContextWasReleased)
346 F.getContext().emitError(
347 ErrorStr: "[ctx_prof] An entrypoint was instrumented but it has no `ret` "
348 "instructions above which to release the context: " +
349 F.getName());
350 return true;
351}
352