1//===- CtxProfAnalysis.cpp - contextual profile analysis ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implementation of the contextual profile analysis, which maintains contextual
10// profiling info through IPO passes.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/Analysis/CtxProfAnalysis.h"
15#include "llvm/ADT/APInt.h"
16#include "llvm/ADT/STLExtras.h"
17#include "llvm/Analysis/CFG.h"
18#include "llvm/IR/Analysis.h"
19#include "llvm/IR/Dominators.h"
20#include "llvm/IR/IntrinsicInst.h"
21#include "llvm/IR/Module.h"
22#include "llvm/IR/PassManager.h"
23#include "llvm/ProfileData/PGOCtxProfReader.h"
24#include "llvm/Support/CommandLine.h"
25#include "llvm/Support/MemoryBuffer.h"
26#include "llvm/Support/Path.h"
27#include <deque>
28#include <memory>
29
30#define DEBUG_TYPE "ctx_prof"
31
32using namespace llvm;
// Command-line knobs for the contextual profile analysis. UseCtxProfile is
// deliberately non-static: the CtxProfAnalysis constructor below reads it as
// a fallback when no explicit profile path is provided.
cl::opt<std::string>
    UseCtxProfile("use-ctx-profile", cl::init(Val: ""), cl::Hidden,
                  cl::desc("Use the specified contextual profile file"));

// Controls how much CtxProfAnalysisPrinterPass emits (see its run() below).
static cl::opt<CtxProfAnalysisPrinterPass::PrintMode> PrintLevel(
    "ctx-profile-printer-level",
    cl::init(Val: CtxProfAnalysisPrinterPass::PrintMode::YAML), cl::Hidden,
    cl::values(clEnumValN(CtxProfAnalysisPrinterPass::PrintMode::Everything,
                          "everything", "print everything - most verbose"),
               clEnumValN(CtxProfAnalysisPrinterPass::PrintMode::YAML, "yaml",
                          "just the yaml representation of the profile")),
    cl::desc("Verbosity level of the contextual profile printer pass."));

// Testing override consulted by PGOContextualProfile::isInSpecializedModule().
static cl::opt<bool> ForceIsInSpecializedModule(
    "ctx-profile-force-is-specialized", cl::init(Val: false),
    cl::desc("Treat the given module as-if it were containing the "
             "post-thinlink module containing the root"));

// Metadata kind under which AssignGUIDPass stores each function's GUID.
const char *AssignGUIDPass::GUIDMetadataName = "guid";
52
53namespace llvm {
// Pimpl for ProfileAnnotator. Given a function and its raw counter values,
// reconstructs a count for every basic block and every (non-excluded) CFG
// edge by iterating to a fixed point: a block's count equals the sum of its
// in-edge counts and also the sum of its out-edge counts, so whenever all
// but one quantity in such a sum is known, the remaining one can be derived.
class ProfileAnnotatorImpl final {
  friend class ProfileAnnotator;
  class BBInfo;
  // A CFG edge between two BBInfos, with an optionally-known count. Src and
  // Dest are set at construction and never null; Count starts unknown.
  struct EdgeInfo {
    BBInfo *const Src;
    BBInfo *const Dest;
    std::optional<uint64_t> Count;

    explicit EdgeInfo(BBInfo &Src, BBInfo &Dest) : Src(&Src), Dest(&Dest) {}
  };

  // Per-basic-block bookkeeping: the (possibly still unknown) block count,
  // the adjacent edges, and how many of those edges still lack a count.
  class BBInfo {
    std::optional<uint64_t> Count;
    // OutEdges is dimensioned to match the number of terminator operands.
    // Entries in the vector match the index in the terminator operand list. In
    // some cases - see `shouldExcludeEdge` and its implementation - an entry
    // will be nullptr.
    // InEdges doesn't have the above constraint.
    SmallVector<EdgeInfo *> OutEdges;
    SmallVector<EdgeInfo *> InEdges;
    size_t UnknownCountOutEdges = 0;
    size_t UnknownCountInEdges = 0;

    // Pass AssumeAllKnown when we try to propagate counts from edges to BBs -
    // because all the edge counters must be known.
    // Return std::nullopt if there were no edges to sum. The user can decide
    // how to interpret that.
    std::optional<uint64_t> getEdgeSum(const SmallVector<EdgeInfo *> &Edges,
                                       bool AssumeAllKnown) const {
      std::optional<uint64_t> Sum;
      for (const auto *E : Edges) {
        // `Edges` may be `OutEdges`, case in which `E` could be nullptr.
        if (E) {
          if (!Sum.has_value())
            Sum = 0;
          *Sum += (AssumeAllKnown ? *E->Count : E->Count.value_or(u: 0U));
        }
      }
      return Sum;
    }

    // Try to set this block's count from the given fully-known edge set.
    // Returns true iff a count was established.
    bool computeCountFrom(const SmallVector<EdgeInfo *> &Edges) {
      assert(!Count.has_value());
      Count = getEdgeSum(Edges, AssumeAllKnown: true);
      return Count.has_value();
    }

    // Precondition (asserted below): exactly one edge in `Edges` has an
    // unknown count. Derive it as BB count minus the sum of the known edges,
    // clamped at zero to guard against counter noise.
    void setSingleUnknownEdgeCount(SmallVector<EdgeInfo *> &Edges) {
      uint64_t KnownSum = getEdgeSum(Edges, AssumeAllKnown: false).value_or(u: 0U);
      uint64_t EdgeVal = *Count > KnownSum ? *Count - KnownSum : 0U;
      EdgeInfo *E = nullptr;
      for (auto *I : Edges)
        if (I && !I->Count.has_value()) {
          E = I;
// In release builds stop at the first unknown edge; in asserts builds keep
// scanning to verify it really is the only one.
#ifdef NDEBUG
          break;
#else
          assert((!E || E == I) &&
                 "Expected exactly one edge to have an unknown count, "
                 "found a second one");
          continue;
#endif
        }
      assert(E && "Expected exactly one edge to have an unknown count");
      assert(!E->Count.has_value());
      E->Count = EdgeVal;
      assert(E->Src->UnknownCountOutEdges > 0);
      assert(E->Dest->UnknownCountInEdges > 0);
      --E->Src->UnknownCountOutEdges;
      --E->Dest->UnknownCountInEdges;
    }

  public:
    BBInfo(size_t NumInEdges, size_t NumOutEdges, std::optional<uint64_t> Count)
        : Count(Count) {
      // For in edges, we just want to pre-allocate enough space, since we know
      // it at this stage. For out edges, we will insert edges at the indices
      // corresponding to positions in this BB's terminator instruction, so we
      // construct a default (nullptr values)-initialized vector. A nullptr edge
      // corresponds to those that are excluded (see shouldExcludeEdge).
      InEdges.reserve(N: NumInEdges);
      OutEdges.resize(N: NumOutEdges);
    }

    // If every out edge has a known count, set this BB's count to their sum.
    bool tryTakeCountFromKnownOutEdges(const BasicBlock &BB) {
      if (!UnknownCountOutEdges) {
        return computeCountFrom(Edges: OutEdges);
      }
      return false;
    }

    // If every in edge has a known count, set this BB's count to their sum.
    bool tryTakeCountFromKnownInEdges(const BasicBlock &BB) {
      if (!UnknownCountInEdges) {
        return computeCountFrom(Edges: InEdges);
      }
      return false;
    }

    void addInEdge(EdgeInfo &Info) {
      InEdges.push_back(Elt: &Info);
      ++UnknownCountInEdges;
    }

    // For the out edges, we care about the position we place them in, which is
    // the position in terminator instruction's list (at construction). Later,
    // we build branch_weights metadata with edge frequency values matching
    // these positions.
    void addOutEdge(size_t Index, EdgeInfo &Info) {
      OutEdges[Index] = &Info;
      ++UnknownCountOutEdges;
    }

    bool hasCount() const { return Count.has_value(); }

    uint64_t getCount() const { return *Count; }

    // If exactly one in edge lacks a count, derive it from this BB's count.
    bool trySetSingleUnknownInEdgeCount() {
      if (UnknownCountInEdges == 1) {
        setSingleUnknownEdgeCount(InEdges);
        return true;
      }
      return false;
    }

    // If exactly one out edge lacks a count, derive it from this BB's count.
    bool trySetSingleUnknownOutEdgeCount() {
      if (UnknownCountOutEdges == 1) {
        setSingleUnknownEdgeCount(OutEdges);
        return true;
      }
      return false;
    }
    size_t getNumOutEdges() const { return OutEdges.size(); }

    // Count of the out edge at terminator-operand position `Index`; excluded
    // (nullptr) edges report 0.
    uint64_t getEdgeCount(size_t Index) const {
      if (auto *E = OutEdges[Index])
        return *E->Count;
      return 0U;
    }
  };

  const Function &F;
  ArrayRef<uint64_t> Counters;
  // To be accessed through getBBInfo() after construction.
  std::map<const BasicBlock *, BBInfo> BBInfos;
  std::vector<EdgeInfo> EdgeInfos;

  // The only criteria for exclusion is faux suspend -> exit edges in presplit
  // coroutines. The API serves for readability, currently.
  bool shouldExcludeEdge(const BasicBlock &Src, const BasicBlock &Dest) const {
    return llvm::isPresplitCoroSuspendExitEdge(Src, Dest);
  }

  BBInfo &getBBInfo(const BasicBlock &BB) { return BBInfos.find(x: &BB)->second; }

  const BBInfo &getBBInfo(const BasicBlock &BB) const {
    return BBInfos.find(x: &BB)->second;
  }

  // validation function after we propagate the counters: all BBs and edges'
  // counters must have a value.
  bool allCountersAreAssigned() const {
    for (const auto &BBInfo : BBInfos)
      if (!BBInfo.second.hasCount())
        return false;
    for (const auto &EdgeInfo : EdgeInfos)
      if (!EdgeInfo.Count.has_value())
        return false;
    return true;
  }

  /// Check that all paths from the entry basic block that use edges with
  /// non-zero counts arrive at a basic block with no successors (i.e. "exit")
  bool allTakenPathsExit() const {
    std::deque<const BasicBlock *> Worklist;
    DenseSet<const BasicBlock *> Visited;
    Worklist.push_back(x: &F.getEntryBlock());
    bool HitExit = false;
    while (!Worklist.empty()) {
      const auto *BB = Worklist.front();
      Worklist.pop_front();
      if (!Visited.insert(V: BB).second)
        continue;
      if (succ_size(BB) == 0) {
        // A taken path must not end in `unreachable` - the program presumably
        // didn't crash during collection.
        if (isa<UnreachableInst>(Val: BB->getTerminator()))
          return false;
        HitExit = true;
        continue;
      }
      if (succ_size(BB) == 1) {
        Worklist.push_back(x: BB->getUniqueSuccessor());
        continue;
      }
      const auto &BBInfo = getBBInfo(BB: *BB);
      bool HasAWayOut = false;
      for (auto I = 0U; I < BB->getTerminator()->getNumSuccessors(); ++I) {
        const auto *Succ = BB->getTerminator()->getSuccessor(Idx: I);
        if (!shouldExcludeEdge(Src: *BB, Dest: *Succ)) {
          if (BBInfo.getEdgeCount(Index: I) > 0) {
            HasAWayOut = true;
            Worklist.push_back(x: Succ);
          }
        }
      }
      if (!HasAWayOut)
        return false;
    }
    return HitExit;
  }

  // Every select in a block with a non-zero count must have a (step)
  // instrumentation counter associated with it; see getSelectInstrProfile.
  bool allNonColdSelectsHaveProfile() const {
    for (const auto &BB : F) {
      if (getBBInfo(BB).getCount() > 0) {
        for (const auto &I : BB) {
          if (const auto *SI = dyn_cast<SelectInst>(Val: &I)) {
            if (const auto *Inst = CtxProfAnalysis::getSelectInstrumentation(
                    SI&: *const_cast<SelectInst *>(SI))) {
              auto Index = Inst->getIndex()->getZExtValue();
              assert(Index < Counters.size());
              if (Counters[Index] == 0)
                return false;
            }
          }
        }
      }
    }
    return true;
  }

  // This is an adaptation of PGOUseFunc::populateCounters.
  // FIXME(mtrofin): look into factoring the code to share one implementation.
  void propagateCounterValues() {
    // Fixed-point iteration: keep sweeping the function while any new BB or
    // edge count could be derived during the last sweep.
    bool KeepGoing = true;
    while (KeepGoing) {
      KeepGoing = false;
      for (const auto &BB : F) {
        auto &Info = getBBInfo(BB);
        if (!Info.hasCount())
          KeepGoing |= Info.tryTakeCountFromKnownOutEdges(BB) ||
                       Info.tryTakeCountFromKnownInEdges(BB);
        if (Info.hasCount()) {
          KeepGoing |= Info.trySetSingleUnknownOutEdgeCount();
          KeepGoing |= Info.trySetSingleUnknownInEdgeCount();
        }
      }
    }
    assert(allCountersAreAssigned() &&
           "[ctx-prof] Expected all counters have been assigned.");
    assert(allTakenPathsExit() &&
           "[ctx-prof] Encountered a BB with more than one successor, where "
           "all outgoing edges have a 0 count. This occurs in non-exiting "
           "functions (message pumps, usually) which are not supported in the "
           "contextual profiling case");
    assert(allNonColdSelectsHaveProfile() &&
           "[ctx-prof] All non-cold select instructions were expected to have "
           "a profile.");
  }

public:
  // Builds the BB/edge graph, seeds counts from the instrumented blocks (and
  // `unreachable`-terminated blocks, which must have count 0), then runs the
  // propagation to completion.
  ProfileAnnotatorImpl(const Function &F, ArrayRef<uint64_t> Counters)
      : F(F), Counters(Counters) {
    assert(!F.isDeclaration());
    assert(!Counters.empty());
    size_t NrEdges = 0;
    for (const auto &BB : F) {
      std::optional<uint64_t> Count;
      if (auto *Ins = CtxProfAnalysis::getBBInstrumentation(
              BB&: const_cast<BasicBlock &>(BB))) {
        auto Index = Ins->getIndex()->getZExtValue();
        assert(Index < Counters.size() &&
               "The index must be inside the counters vector by construction - "
               "tripping this assertion indicates a bug in how the contextual "
               "profile is managed by IPO transforms");
        (void)Index;
        Count = Counters[Ins->getIndex()->getZExtValue()];
      } else if (isa<UnreachableInst>(Val: BB.getTerminator())) {
        // The program presumably didn't crash.
        Count = 0;
      }
      auto [It, Ins] =
          BBInfos.insert(x: {&BB, {pred_size(BB: &BB), succ_size(BB: &BB), Count}});
      (void)Ins;
      assert(Ins && "We iterate through the function's BBs, no reason to "
                    "insert one more than once");
      NrEdges += llvm::count_if(Range: successors(BB: &BB), P: [&](const auto *Succ) {
        return !shouldExcludeEdge(Src: BB, Dest: *Succ);
      });
    }
    // Pre-allocate the vector, we want references to its contents to be stable.
    EdgeInfos.reserve(n: NrEdges);
    for (const auto &BB : F) {
      auto &Info = getBBInfo(BB);
      for (auto I = 0U; I < BB.getTerminator()->getNumSuccessors(); ++I) {
        const auto *Succ = BB.getTerminator()->getSuccessor(Idx: I);
        if (!shouldExcludeEdge(Src: BB, Dest: *Succ)) {
          auto &EI = EdgeInfos.emplace_back(args&: getBBInfo(BB), args&: getBBInfo(BB: *Succ));
          Info.addOutEdge(Index: I, Info&: EI);
          getBBInfo(BB: *Succ).addInEdge(Info&: EI);
        }
      }
    }
    assert(EdgeInfos.capacity() == NrEdges &&
           "The capacity of EdgeInfos should have stayed unchanged it was "
           "populated, because we need pointers to its contents to be stable");
    propagateCounterValues();
  }

  uint64_t getBBCount(const BasicBlock &BB) { return getBBInfo(BB).getCount(); }
};
362
363} // namespace llvm
364
// Thin pimpl wrapper: all the propagation work happens eagerly in
// ProfileAnnotatorImpl's constructor.
ProfileAnnotator::ProfileAnnotator(const Function &F,
                                   ArrayRef<uint64_t> RawCounters)
    : PImpl(std::make_unique<ProfileAnnotatorImpl>(args: F, args&: RawCounters)) {}
368
// Defined out-of-line so ProfileAnnotatorImpl is a complete type here.
ProfileAnnotator::~ProfileAnnotator() = default;

// Propagated execution count of \p BB (valid after construction).
uint64_t ProfileAnnotator::getBBCount(const BasicBlock &BB) const {
  return PImpl->getBBCount(BB);
}
374
// Computes true/false counts for a `select`. The step instrumentation counter
// records how many times the select evaluated "true"; the false count is
// derived from the containing block's total. Returns false (with both counts
// zeroed) when the block is cold or the select has no instrumentation.
bool ProfileAnnotator::getSelectInstrProfile(SelectInst &SI,
                                             uint64_t &TrueCount,
                                             uint64_t &FalseCount) const {
  const auto &BBInfo = PImpl->getBBInfo(BB: *SI.getParent());
  TrueCount = FalseCount = 0;
  if (BBInfo.getCount() == 0)
    return false;

  auto *Step = CtxProfAnalysis::getSelectInstrumentation(SI);
  if (!Step)
    return false;
  auto Index = Step->getIndex()->getZExtValue();
  assert(Index < PImpl->Counters.size() &&
         "The index of the step instruction must be inside the "
         "counters vector by "
         "construction - tripping this assertion indicates a bug in "
         "how the contextual profile is managed by IPO transforms");
  auto TotalCount = BBInfo.getCount();
  TrueCount = PImpl->Counters[Index];
  // Clamp at 0 to guard against counter noise making TrueCount > TotalCount.
  FalseCount = (TotalCount > TrueCount ? TotalCount - TrueCount : 0U);
  return true;
}
397
398bool ProfileAnnotator::getOutgoingBranchWeights(
399 BasicBlock &BB, SmallVectorImpl<uint64_t> &Profile,
400 uint64_t &MaxCount) const {
401 Profile.clear();
402
403 if (succ_size(BB: &BB) < 2)
404 return false;
405
406 auto *Term = BB.getTerminator();
407 Profile.resize(N: Term->getNumSuccessors());
408
409 const auto &BBInfo = PImpl->getBBInfo(BB);
410 MaxCount = 0;
411 for (unsigned SuccIdx = 0, Size = BBInfo.getNumOutEdges(); SuccIdx < Size;
412 ++SuccIdx) {
413 uint64_t EdgeCount = BBInfo.getEdgeCount(Index: SuccIdx);
414 if (EdgeCount > MaxCount)
415 MaxCount = EdgeCount;
416 Profile[SuccIdx] = EdgeCount;
417 }
418 return MaxCount > 0;
419}
420
// Attaches each defined function's GUID as i64 metadata (kind "guid"), so the
// GUID survives IPO transforms (e.g. renaming) that would otherwise change the
// value computed from the name. Functions already carrying the metadata are
// left untouched, making the pass idempotent.
PreservedAnalyses AssignGUIDPass::run(Module &M, ModuleAnalysisManager &MAM) {
  for (auto &F : M.functions()) {
    if (F.isDeclaration())
      continue;
    if (F.getMetadata(Kind: GUIDMetadataName))
      continue;
    const GlobalValue::GUID GUID = F.getGUID();
    F.setMetadata(Kind: GUIDMetadataName,
                  Node: MDNode::get(Context&: M.getContext(),
                              MDs: {ConstantAsMetadata::get(C: ConstantInt::get(
                                  Ty: Type::getInt64Ty(C&: M.getContext()), V: GUID))}));
  }
  // Conservatively invalidate: metadata was (potentially) added to functions.
  return PreservedAnalyses::none();
}
435
// Retrieves a function's GUID: for declarations (which AssignGUIDPass skips)
// fall back to the name-derived GUID; for definitions, read the value stored
// in the "guid" metadata by AssignGUIDPass::run.
GlobalValue::GUID AssignGUIDPass::getGUID(const Function &F) {
  if (F.isDeclaration()) {
    assert(GlobalValue::isExternalLinkage(F.getLinkage()));
    return F.getGUID();
  }
  auto *MD = F.getMetadata(Kind: GUIDMetadataName);
  assert(MD && "guid not found for defined function");
  return cast<ConstantInt>(Val: cast<ConstantAsMetadata>(Val: MD->getOperand(I: 0))
                               ->getValue()
                               ->stripPointerCasts())
      ->getZExtValue();
}
// Unique identity for the analysis in the new pass manager.
AnalysisKey CtxProfAnalysis::Key;
449
// Resolve the profile path at construction: an explicitly-passed path wins,
// otherwise fall back to -use-ctx-profile if it was given on the command
// line, otherwise no profile (run() will then return an empty result).
CtxProfAnalysis::CtxProfAnalysis(std::optional<StringRef> Profile)
    : Profile([&]() -> std::optional<StringRef> {
        if (Profile)
          return *Profile;
        if (UseCtxProfile.getNumOccurrences())
          return UseCtxProfile;
        return std::nullopt;
      }()) {}
458
// Loads the contextual profile from disk, trims it to the context roots that
// belong to this module, records per-function counter/callsite capacities,
// and builds the per-function index. Returns an empty (invalid) result if no
// profile was configured or the file is missing/malformed (emitting a module
// error in the latter cases).
PGOContextualProfile CtxProfAnalysis::run(Module &M,
                                          ModuleAnalysisManager &MAM) {
  if (!Profile)
    return {};
  ErrorOr<std::unique_ptr<MemoryBuffer>> MB = MemoryBuffer::getFile(Filename: *Profile);
  if (auto EC = MB.getError()) {
    M.getContext().emitError(ErrorStr: "could not open contextual profile file: " +
                             EC.message());
    return {};
  }
  PGOCtxProfileReader Reader(MB.get()->getBuffer());
  auto MaybeProfiles = Reader.loadProfiles();
  if (!MaybeProfiles) {
    M.getContext().emitError(ErrorStr: "contextual profile file is invalid: " +
                             toString(E: MaybeProfiles.takeError()));
    return {};
  }

  // FIXME: We should drive this from ThinLTO, but for the time being, use the
  // module name as indicator.
  // We want to *only* keep the contextual profiles in modules that capture
  // context trees. That allows us to compute specific PSIs, for example.
  auto DetermineRootsInModule = [&M]() -> const DenseSet<GlobalValue::GUID> {
    DenseSet<GlobalValue::GUID> ProfileRootsInModule;
    auto ModName = M.getName();
    auto Filename = sys::path::filename(path: ModName);
    // Drop the file extension.
    Filename = Filename.substr(Start: 0, N: Filename.find_last_of(C: '.'));
    // See if it parses
    APInt Guid;
    // getAsInteger returns true if there are more chars to read other than the
    // integer. So the "false" test is what we want.
    if (!Filename.getAsInteger(Radix: 0, Result&: Guid))
      ProfileRootsInModule.insert(V: Guid.getZExtValue());
    return ProfileRootsInModule;
  };
  const auto ProfileRootsInModule = DetermineRootsInModule();
  PGOContextualProfile Result;

  // the logic from here on allows for modules that contain - by design - more
  // than one root. We currently don't support that, because the determination
  // happens based on the module name matching the root guid, but the logic can
  // avoid assuming that.
  if (!ProfileRootsInModule.empty()) {
    Result.IsInSpecializedModule = true;
    // Trim first the roots that aren't in this module.
    for (auto &[RootGuid, _] :
         llvm::make_early_inc_range(Range&: MaybeProfiles->Contexts))
      if (!ProfileRootsInModule.contains(V: RootGuid))
        MaybeProfiles->Contexts.erase(x: RootGuid);
    // we can also drop the flat profiles
    MaybeProfiles->FlatProfiles.clear();
  }

  // Record, for each defined function, the number of counters and callsites
  // its instrumentation declared (the intrinsics carry the totals). Functions
  // with no counter intrinsic in the entry block are not instrumented and
  // are skipped.
  for (const auto &F : M) {
    if (F.isDeclaration())
      continue;
    auto GUID = AssignGUIDPass::getGUID(F);
    assert(GUID && "guid not found for defined function");
    const auto &Entry = F.begin();
    uint32_t MaxCounters = 0; // we expect at least a counter.
    for (const auto &I : *Entry)
      if (auto *C = dyn_cast<InstrProfIncrementInst>(Val: &I)) {
        MaxCounters =
            static_cast<uint32_t>(C->getNumCounters()->getZExtValue());
        break;
      }
    if (!MaxCounters)
      continue;
    uint32_t MaxCallsites = 0;
    for (const auto &BB : F)
      for (const auto &I : BB)
        if (auto *C = dyn_cast<InstrProfCallsite>(Val: &I)) {
          MaxCallsites =
              static_cast<uint32_t>(C->getNumCounters()->getZExtValue());
          break;
        }
    auto [It, Ins] = Result.FuncInfo.insert(
        x: {GUID, PGOContextualProfile::FunctionInfo(F.getName())});
    (void)Ins;
    assert(Ins);
    It->second.NextCallsiteIndex = MaxCallsites;
    It->second.NextCounterIndex = MaxCounters;
  }
  // If we made it this far, the Result is valid - which we mark by setting
  // .Profiles.
  Result.Profiles = std::move(*MaybeProfiles);
  Result.initIndex();
  return Result;
}
549
550GlobalValue::GUID
551PGOContextualProfile::getDefinedFunctionGUID(const Function &F) const {
552 if (auto It = FuncInfo.find(x: AssignGUIDPass::getGUID(F)); It != FuncInfo.end())
553 return It->first;
554 return 0;
555}
556
// The print verbosity is taken from the -ctx-profile-printer-level flag.
CtxProfAnalysisPrinterPass::CtxProfAnalysisPrinterPass(raw_ostream &OS)
    : OS(OS), Mode(PrintLevel) {}
559
// Prints the analysis result: always the YAML form of the profile; in
// "everything" mode additionally the per-function info table and the
// flattened profile. Purely observational - preserves all analyses.
PreservedAnalyses CtxProfAnalysisPrinterPass::run(Module &M,
                                                  ModuleAnalysisManager &MAM) {
  CtxProfAnalysis::Result &C = MAM.getResult<CtxProfAnalysis>(IR&: M);
  if (C.contexts().empty()) {
    OS << "No contextual profile was provided.\n";
    return PreservedAnalyses::all();
  }

  if (Mode == PrintMode::Everything) {
    OS << "Function Info:\n";
    for (const auto &[Guid, FuncInfo] : C.FuncInfo)
      OS << Guid << " : " << FuncInfo.Name
         << ". MaxCounterID: " << FuncInfo.NextCounterIndex
         << ". MaxCallsiteID: " << FuncInfo.NextCallsiteIndex << "\n";
  }

  if (Mode == PrintMode::Everything)
    OS << "\nCurrent Profile:\n";
  convertCtxProfToYaml(OS, Profile: C.profiles());
  OS << "\n";
  // YAML mode stops here; "everything" continues with the flat view.
  if (Mode == PrintMode::YAML)
    return PreservedAnalyses::all();

  OS << "\nFlat Profile:\n";
  auto Flat = C.flatten();
  for (const auto &[Guid, Counters] : Flat) {
    OS << Guid << " : ";
    for (auto V : Counters)
      OS << V << " ";
    OS << "\n";
  }
  return PreservedAnalyses::all();
}
593
594InstrProfCallsite *CtxProfAnalysis::getCallsiteInstrumentation(CallBase &CB) {
595 if (!InstrProfCallsite::canInstrumentCallsite(CB))
596 return nullptr;
597 for (auto *Prev = CB.getPrevNode(); Prev; Prev = Prev->getPrevNode()) {
598 if (auto *IPC = dyn_cast<InstrProfCallsite>(Val: Prev))
599 return IPC;
600 assert(!isa<CallBase>(Prev) &&
601 "didn't expect to find another call, that's not the callsite "
602 "instrumentation, before an instrumentable callsite");
603 }
604 return nullptr;
605}
606
607InstrProfIncrementInst *CtxProfAnalysis::getBBInstrumentation(BasicBlock &BB) {
608 for (auto &I : BB)
609 if (auto *Incr = dyn_cast<InstrProfIncrementInst>(Val: &I))
610 if (!isa<InstrProfIncrementInstStep>(Val: &I))
611 return Incr;
612 return nullptr;
613}
614
615InstrProfIncrementInstStep *
616CtxProfAnalysis::getSelectInstrumentation(SelectInst &SI) {
617 Instruction *Prev = &SI;
618 while ((Prev = Prev->getPrevNode()))
619 if (auto *Step = dyn_cast<InstrProfIncrementInstStep>(Val: Prev))
620 return Step;
621 return nullptr;
622}
623
// Preorder traversal of one context tree: visit the context itself, then
// recurse into every sub-context of every callsite. Uses std::function (not
// a lambda auto-recursion) so the traverser can name itself.
template <class ProfTy>
static void preorderVisitOneRoot(ProfTy &Profile,
                                 function_ref<void(ProfTy &)> Visitor) {
  std::function<void(ProfTy &)> Traverser = [&](auto &Ctx) {
    Visitor(Ctx);
    for (auto &[_, SubCtxSet] : Ctx.callsites())
      for (auto &[__, Subctx] : SubCtxSet)
        Traverser(Subctx);
  };
  Traverser(Profile);
}
635
// Preorder traversal over every root context tree in \p Profiles.
// ProfilesTy/ProfTy are templated so the same code serves both const and
// non-const visitation.
template <class ProfilesTy, class ProfTy>
static void preorderVisit(ProfilesTy &Profiles,
                          function_ref<void(ProfTy &)> Visitor) {
  for (auto &[_, P] : Profiles)
    preorderVisitOneRoot<ProfTy>(P, Visitor);
}
642
// Links all contexts belonging to the same function into a per-function
// doubly-linked list (headed at FunctionInfo::Index), so update()/visit()
// can enumerate a function's contexts without re-traversing the trees.
void PGOContextualProfile::initIndex() {
  // Initialize the head of the index list for each function. We don't need it
  // after this point.
  DenseMap<GlobalValue::GUID, PGOCtxProfContext *> InsertionPoints;
  for (auto &[Guid, FI] : FuncInfo)
    InsertionPoints[Guid] = &FI.Index;
  preorderVisit<PGOCtxProfContext::CallTargetMapTy, PGOCtxProfContext>(
      Profiles&: Profiles.Contexts, Visitor: [&](PGOCtxProfContext &Ctx) {
        auto InsertIt = InsertionPoints.find(Val: Ctx.guid());
        // Contexts for functions not defined in this module are not indexed.
        if (InsertIt == InsertionPoints.end())
          return;
        // Insert at the end of the list. Since we traverse in preorder, it
        // means that when we iterate the list from the beginning, we'd
        // encounter the contexts in the order we would have, should we have
        // performed a full preorder traversal.
        InsertIt->second->Next = &Ctx;
        Ctx.Previous = InsertIt->second;
        InsertIt->second = &Ctx;
      });
}
663
664bool PGOContextualProfile::isInSpecializedModule() const {
665 return ForceIsInSpecializedModule.getNumOccurrences() > 0
666 ? ForceIsInSpecializedModule
667 : IsInSpecializedModule;
668}
669
// Applies the mutating visitor \p V to every context of function \p F, by
// walking the per-function index list built in initIndex(). \p F must be a
// known (defined, instrumented) function.
void PGOContextualProfile::update(Visitor V, const Function &F) {
  assert(isFunctionKnown(F));
  GlobalValue::GUID G = getDefinedFunctionGUID(F);
  for (auto *Node = FuncInfo.find(x: G)->second.Index.Next; Node;
       Node = Node->Next)
    // Index nodes are links embedded in PGOCtxProfContext objects.
    V(*reinterpret_cast<PGOCtxProfContext *>(Node));
}
677
// Read-only counterpart of update(). With \p F null, visits every context of
// every root in preorder; otherwise visits only \p F's contexts via the
// per-function index list.
void PGOContextualProfile::visit(ConstVisitor V, const Function *F) const {
  if (!F)
    return preorderVisit<const PGOCtxProfContext::CallTargetMapTy,
                         const PGOCtxProfContext>(Profiles: Profiles.Contexts, Visitor: V);
  assert(isFunctionKnown(*F));
  GlobalValue::GUID G = getDefinedFunctionGUID(F: *F);
  for (const auto *Node = FuncInfo.find(x: G)->second.Index.Next; Node;
       Node = Node->Next)
    // Index nodes are links embedded in PGOCtxProfContext objects.
    V(*reinterpret_cast<const PGOCtxProfContext *>(Node));
}
688
// Collapses the contextual profile into a flat (context-insensitive) profile:
// per-GUID counter vectors, summing over all contexts a function appears in.
// Each root's contribution is scaled by its total entry count (the sampling
// factor); flat profiles are added unscaled.
const CtxProfFlatProfile PGOContextualProfile::flatten() const {
  CtxProfFlatProfile Flat;
  // Element-wise Into[i] += From[i] * SamplingRate; sizes the destination on
  // first use and asserts all contexts of a function agree on counter count.
  auto Accummulate = [](SmallVectorImpl<uint64_t> &Into,
                        const SmallVectorImpl<uint64_t> &From,
                        uint64_t SamplingRate) {
    if (Into.empty())
      Into.resize(N: From.size());
    assert(Into.size() == From.size() &&
           "All contexts corresponding to a function should have the exact "
           "same number of counters.");
    for (size_t I = 0, E = Into.size(); I < E; ++I)
      Into[I] += From[I] * SamplingRate;
  };

  for (const auto &[_, CtxRoot] : Profiles.Contexts) {
    const uint64_t SamplingFactor = CtxRoot.getTotalRootEntryCount();
    preorderVisitOneRoot<const PGOCtxProfContext>(
        Profile: CtxRoot, Visitor: [&](const PGOCtxProfContext &Ctx) {
          Accummulate(Flat[Ctx.guid()], Ctx.counters(), SamplingFactor);
        });

    // Also fold in the root's "unhandled" (non-context) counter vectors.
    for (const auto &[G, Unh] : CtxRoot.getUnhandled())
      Accummulate(Flat[G], Unh, SamplingFactor);
  }
  // We don't sample "Flat" currently, so sampling rate is 1.
  for (const auto &[G, FC] : Profiles.FlatProfiles)
    Accummulate(Flat[G], FC, /*SamplingRate=*/1);
  return Flat;
}
718
// Flattens indirect-call target information: for each caller GUID, for each
// callsite ID, accumulates per-target entry counts over all contexts, scaled
// by the root's total entry count.
const CtxProfFlatIndirectCallProfile
PGOContextualProfile::flattenVirtCalls() const {
  CtxProfFlatIndirectCallProfile Ret;
  for (const auto &[_, CtxRoot] : Profiles.Contexts) {
    const uint64_t TotalRootEntryCount = CtxRoot.getTotalRootEntryCount();
    preorderVisitOneRoot<const PGOCtxProfContext>(
        Profile: CtxRoot, Visitor: [&](const PGOCtxProfContext &Ctx) {
          auto &Targets = Ret[Ctx.guid()];
          for (const auto &[ID, SubctxSet] : Ctx.callsites())
            for (const auto &Subctx : SubctxSet)
              // Subctx.first is the callee GUID; its entry count is scaled
              // like in flatten().
              Targets[ID][Subctx.first] +=
                  Subctx.second.getEntrycount() * TotalRootEntryCount;
        });
  }
  return Ret;
}
735
// Collects (callsite, target) promotion candidates for the indirect call
// \p IC: targets recorded in any of the caller's contexts that are defined
// in this module and marked always_inline. Silently returns if the callsite
// has no instrumentation.
void CtxProfAnalysis::collectIndirectCallPromotionList(
    CallBase &IC, Result &Profile,
    SetVector<std::pair<CallBase *, Function *>> &Candidates) {
  const auto *Instr = CtxProfAnalysis::getCallsiteInstrumentation(CB&: IC);
  if (!Instr)
    return;
  Module &M = *IC.getParent()->getModule();
  const uint32_t CallID = Instr->getIndex()->getZExtValue();
  Profile.visit(
      V: [&](const PGOCtxProfContext &Ctx) {
        const auto &Targets = Ctx.callsites().find(x: CallID);
        if (Targets == Ctx.callsites().end())
          return;
        for (const auto &[Guid, _] : Targets->second)
          if (auto Name = Profile.getFunctionName(GUID: Guid); !Name.empty())
            if (auto *Target = M.getFunction(Name))
              if (Target->hasFnAttribute(Kind: Attribute::AlwaysInline))
                Candidates.insert(X: {&IC, Target});
      },
      F: IC.getCaller());
}
757