//===- CtxProfAnalysis.cpp - contextual profile analysis ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implementation of the contextual profile analysis, which maintains contextual
// profiling info through IPO passes.
//
//===----------------------------------------------------------------------===//

#include "llvm/Analysis/CtxProfAnalysis.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/CFG.h"
#include "llvm/IR/Analysis.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/ProfileData/PGOCtxProfReader.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Path.h"
#include <deque>
#include <memory>

#define DEBUG_TYPE "ctx_prof"

using namespace llvm;

34namespace llvm {
35
36cl::opt<std::string>
37 UseCtxProfile("use-ctx-profile", cl::init(Val: ""), cl::Hidden,
38 cl::desc("Use the specified contextual profile file"));
39
40static cl::opt<CtxProfAnalysisPrinterPass::PrintMode> PrintLevel(
41 "ctx-profile-printer-level",
42 cl::init(Val: CtxProfAnalysisPrinterPass::PrintMode::YAML), cl::Hidden,
43 cl::values(clEnumValN(CtxProfAnalysisPrinterPass::PrintMode::Everything,
44 "everything", "print everything - most verbose"),
45 clEnumValN(CtxProfAnalysisPrinterPass::PrintMode::YAML, "yaml",
46 "just the yaml representation of the profile")),
47 cl::desc("Verbosity level of the contextual profile printer pass."));
48
49static cl::opt<bool> ForceIsInSpecializedModule(
50 "ctx-profile-force-is-specialized", cl::init(Val: false),
51 cl::desc("Treat the given module as-if it were containing the "
52 "post-thinlink module containing the root"));
53
54const char *AssignGUIDPass::GUIDMetadataName = "guid";
55
56class ProfileAnnotatorImpl final {
57 friend class ProfileAnnotator;
58 class BBInfo;
59 struct EdgeInfo {
60 BBInfo *const Src;
61 BBInfo *const Dest;
62 std::optional<uint64_t> Count;
63
64 explicit EdgeInfo(BBInfo &Src, BBInfo &Dest) : Src(&Src), Dest(&Dest) {}
65 };
66
67 class BBInfo {
68 std::optional<uint64_t> Count;
69 // OutEdges is dimensioned to match the number of terminator operands.
70 // Entries in the vector match the index in the terminator operand list. In
71 // some cases - see `shouldExcludeEdge` and its implementation - an entry
72 // will be nullptr.
73 // InEdges doesn't have the above constraint.
74 SmallVector<EdgeInfo *> OutEdges;
75 SmallVector<EdgeInfo *> InEdges;
76 size_t UnknownCountOutEdges = 0;
77 size_t UnknownCountInEdges = 0;
78
79 // Pass AssumeAllKnown when we try to propagate counts from edges to BBs -
80 // because all the edge counters must be known.
81 // Return std::nullopt if there were no edges to sum. The user can decide
82 // how to interpret that.
83 std::optional<uint64_t> getEdgeSum(const SmallVector<EdgeInfo *> &Edges,
84 bool AssumeAllKnown) const {
85 std::optional<uint64_t> Sum;
86 for (const auto *E : Edges) {
87 // `Edges` may be `OutEdges`, case in which `E` could be nullptr.
88 if (E) {
89 if (!Sum.has_value())
90 Sum = 0;
91 *Sum += (AssumeAllKnown ? *E->Count : E->Count.value_or(u: 0U));
92 }
93 }
94 return Sum;
95 }
96
97 bool computeCountFrom(const SmallVector<EdgeInfo *> &Edges) {
98 assert(!Count.has_value());
99 Count = getEdgeSum(Edges, AssumeAllKnown: true);
100 return Count.has_value();
101 }
102
103 void setSingleUnknownEdgeCount(SmallVector<EdgeInfo *> &Edges) {
104 uint64_t KnownSum = getEdgeSum(Edges, AssumeAllKnown: false).value_or(u: 0U);
105 uint64_t EdgeVal = *Count > KnownSum ? *Count - KnownSum : 0U;
106 EdgeInfo *E = nullptr;
107 for (auto *I : Edges)
108 if (I && !I->Count.has_value()) {
109 E = I;
110#ifdef NDEBUG
111 break;
112#else
113 assert((!E || E == I) &&
114 "Expected exactly one edge to have an unknown count, "
115 "found a second one");
116 continue;
117#endif
118 }
119 assert(E && "Expected exactly one edge to have an unknown count");
120 assert(!E->Count.has_value());
121 E->Count = EdgeVal;
122 assert(E->Src->UnknownCountOutEdges > 0);
123 assert(E->Dest->UnknownCountInEdges > 0);
124 --E->Src->UnknownCountOutEdges;
125 --E->Dest->UnknownCountInEdges;
126 }
127
128 public:
129 BBInfo(size_t NumInEdges, size_t NumOutEdges, std::optional<uint64_t> Count)
130 : Count(Count) {
131 // For in edges, we just want to pre-allocate enough space, since we know
132 // it at this stage. For out edges, we will insert edges at the indices
133 // corresponding to positions in this BB's terminator instruction, so we
134 // construct a default (nullptr values)-initialized vector. A nullptr edge
135 // corresponds to those that are excluded (see shouldExcludeEdge).
136 InEdges.reserve(N: NumInEdges);
137 OutEdges.resize(N: NumOutEdges);
138 }
139
140 bool tryTakeCountFromKnownOutEdges(const BasicBlock &BB) {
141 if (!UnknownCountOutEdges) {
142 return computeCountFrom(Edges: OutEdges);
143 }
144 return false;
145 }
146
147 bool tryTakeCountFromKnownInEdges(const BasicBlock &BB) {
148 if (!UnknownCountInEdges) {
149 return computeCountFrom(Edges: InEdges);
150 }
151 return false;
152 }
153
154 void addInEdge(EdgeInfo &Info) {
155 InEdges.push_back(Elt: &Info);
156 ++UnknownCountInEdges;
157 }
158
159 // For the out edges, we care about the position we place them in, which is
160 // the position in terminator instruction's list (at construction). Later,
161 // we build branch_weights metadata with edge frequency values matching
162 // these positions.
163 void addOutEdge(size_t Index, EdgeInfo &Info) {
164 OutEdges[Index] = &Info;
165 ++UnknownCountOutEdges;
166 }
167
168 bool hasCount() const { return Count.has_value(); }
169
170 uint64_t getCount() const { return *Count; }
171
172 bool trySetSingleUnknownInEdgeCount() {
173 if (UnknownCountInEdges == 1) {
174 setSingleUnknownEdgeCount(InEdges);
175 return true;
176 }
177 return false;
178 }
179
180 bool trySetSingleUnknownOutEdgeCount() {
181 if (UnknownCountOutEdges == 1) {
182 setSingleUnknownEdgeCount(OutEdges);
183 return true;
184 }
185 return false;
186 }
187 size_t getNumOutEdges() const { return OutEdges.size(); }
188
189 uint64_t getEdgeCount(size_t Index) const {
190 if (auto *E = OutEdges[Index])
191 return *E->Count;
192 return 0U;
193 }
194 };
195
196 const Function &F;
197 ArrayRef<uint64_t> Counters;
198 // To be accessed through getBBInfo() after construction.
199 std::map<const BasicBlock *, BBInfo> BBInfos;
200 std::vector<EdgeInfo> EdgeInfos;
201
202 // The only criteria for exclusion is faux suspend -> exit edges in presplit
203 // coroutines. The API serves for readability, currently.
204 bool shouldExcludeEdge(const BasicBlock &Src, const BasicBlock &Dest) const {
205 return llvm::isPresplitCoroSuspendExitEdge(Src, Dest);
206 }
207
208 BBInfo &getBBInfo(const BasicBlock &BB) { return BBInfos.find(x: &BB)->second; }
209
210 const BBInfo &getBBInfo(const BasicBlock &BB) const {
211 return BBInfos.find(x: &BB)->second;
212 }
213
214 // validation function after we propagate the counters: all BBs and edges'
215 // counters must have a value.
216 bool allCountersAreAssigned() const {
217 for (const auto &BBInfo : BBInfos)
218 if (!BBInfo.second.hasCount())
219 return false;
220 for (const auto &EdgeInfo : EdgeInfos)
221 if (!EdgeInfo.Count.has_value())
222 return false;
223 return true;
224 }
225
226 /// Check that all paths from the entry basic block that use edges with
227 /// non-zero counts arrive at a basic block with no successors (i.e. "exit")
228 bool allTakenPathsExit() const {
229 std::deque<const BasicBlock *> Worklist;
230 DenseSet<const BasicBlock *> Visited;
231 Worklist.push_back(x: &F.getEntryBlock());
232 bool HitExit = false;
233 while (!Worklist.empty()) {
234 const auto *BB = Worklist.front();
235 Worklist.pop_front();
236 if (!Visited.insert(V: BB).second)
237 continue;
238 if (succ_size(BB) == 0) {
239 if (isa<UnreachableInst>(Val: BB->getTerminator()))
240 return false;
241 HitExit = true;
242 continue;
243 }
244 if (succ_size(BB) == 1) {
245 Worklist.push_back(x: BB->getUniqueSuccessor());
246 continue;
247 }
248 const auto &BBInfo = getBBInfo(BB: *BB);
249 bool HasAWayOut = false;
250 for (auto I = 0U; I < BB->getTerminator()->getNumSuccessors(); ++I) {
251 const auto *Succ = BB->getTerminator()->getSuccessor(Idx: I);
252 if (!shouldExcludeEdge(Src: *BB, Dest: *Succ)) {
253 if (BBInfo.getEdgeCount(Index: I) > 0) {
254 HasAWayOut = true;
255 Worklist.push_back(x: Succ);
256 }
257 }
258 }
259 if (!HasAWayOut)
260 return false;
261 }
262 return HitExit;
263 }
264
265 bool allNonColdSelectsHaveProfile() const {
266 for (const auto &BB : F) {
267 if (getBBInfo(BB).getCount() > 0) {
268 for (const auto &I : BB) {
269 if (const auto *SI = dyn_cast<SelectInst>(Val: &I)) {
270 if (const auto *Inst = CtxProfAnalysis::getSelectInstrumentation(
271 SI&: *const_cast<SelectInst *>(SI))) {
272 auto Index = Inst->getIndex()->getZExtValue();
273 assert(Index < Counters.size());
274 if (Counters[Index] == 0)
275 return false;
276 }
277 }
278 }
279 }
280 }
281 return true;
282 }
283
284 // This is an adaptation of PGOUseFunc::populateCounters.
285 // FIXME(mtrofin): look into factoring the code to share one implementation.
286 void propagateCounterValues() {
287 bool KeepGoing = true;
288 while (KeepGoing) {
289 KeepGoing = false;
290 for (const auto &BB : F) {
291 auto &Info = getBBInfo(BB);
292 if (!Info.hasCount())
293 KeepGoing |= Info.tryTakeCountFromKnownOutEdges(BB) ||
294 Info.tryTakeCountFromKnownInEdges(BB);
295 if (Info.hasCount()) {
296 KeepGoing |= Info.trySetSingleUnknownOutEdgeCount();
297 KeepGoing |= Info.trySetSingleUnknownInEdgeCount();
298 }
299 }
300 }
301 assert(allCountersAreAssigned() &&
302 "[ctx-prof] Expected all counters have been assigned.");
303 assert(allTakenPathsExit() &&
304 "[ctx-prof] Encountered a BB with more than one successor, where "
305 "all outgoing edges have a 0 count. This occurs in non-exiting "
306 "functions (message pumps, usually) which are not supported in the "
307 "contextual profiling case");
308 assert(allNonColdSelectsHaveProfile() &&
309 "[ctx-prof] All non-cold select instructions were expected to have "
310 "a profile.");
311 }
312
313public:
314 ProfileAnnotatorImpl(const Function &F, ArrayRef<uint64_t> Counters)
315 : F(F), Counters(Counters) {
316 assert(!F.isDeclaration());
317 assert(!Counters.empty());
318 size_t NrEdges = 0;
319 for (const auto &BB : F) {
320 std::optional<uint64_t> Count;
321 if (auto *Ins = CtxProfAnalysis::getBBInstrumentation(
322 BB&: const_cast<BasicBlock &>(BB))) {
323 auto Index = Ins->getIndex()->getZExtValue();
324 assert(Index < Counters.size() &&
325 "The index must be inside the counters vector by construction - "
326 "tripping this assertion indicates a bug in how the contextual "
327 "profile is managed by IPO transforms");
328 (void)Index;
329 Count = Counters[Ins->getIndex()->getZExtValue()];
330 } else if (isa<UnreachableInst>(Val: BB.getTerminator())) {
331 // The program presumably didn't crash.
332 Count = 0;
333 }
334 auto [It, Ins] =
335 BBInfos.insert(x: {&BB, {pred_size(BB: &BB), succ_size(BB: &BB), Count}});
336 (void)Ins;
337 assert(Ins && "We iterate through the function's BBs, no reason to "
338 "insert one more than once");
339 NrEdges += llvm::count_if(Range: successors(BB: &BB), P: [&](const auto *Succ) {
340 return !shouldExcludeEdge(Src: BB, Dest: *Succ);
341 });
342 }
343 // Pre-allocate the vector, we want references to its contents to be stable.
344 EdgeInfos.reserve(n: NrEdges);
345 for (const auto &BB : F) {
346 auto &Info = getBBInfo(BB);
347 for (auto I = 0U; I < BB.getTerminator()->getNumSuccessors(); ++I) {
348 const auto *Succ = BB.getTerminator()->getSuccessor(Idx: I);
349 if (!shouldExcludeEdge(Src: BB, Dest: *Succ)) {
350 auto &EI = EdgeInfos.emplace_back(args&: getBBInfo(BB), args&: getBBInfo(BB: *Succ));
351 Info.addOutEdge(Index: I, Info&: EI);
352 getBBInfo(BB: *Succ).addInEdge(Info&: EI);
353 }
354 }
355 }
356 assert(EdgeInfos.capacity() == NrEdges &&
357 "The capacity of EdgeInfos should have stayed unchanged it was "
358 "populated, because we need pointers to its contents to be stable");
359 propagateCounterValues();
360 }
361
362 uint64_t getBBCount(const BasicBlock &BB) { return getBBInfo(BB).getCount(); }
363};
364
365} // namespace llvm
366
367ProfileAnnotator::ProfileAnnotator(const Function &F,
368 ArrayRef<uint64_t> RawCounters)
369 : PImpl(std::make_unique<ProfileAnnotatorImpl>(args: F, args&: RawCounters)) {}
370
371ProfileAnnotator::~ProfileAnnotator() = default;
372
373uint64_t ProfileAnnotator::getBBCount(const BasicBlock &BB) const {
374 return PImpl->getBBCount(BB);
375}
376
377bool ProfileAnnotator::getSelectInstrProfile(SelectInst &SI,
378 uint64_t &TrueCount,
379 uint64_t &FalseCount) const {
380 const auto &BBInfo = PImpl->getBBInfo(BB: *SI.getParent());
381 TrueCount = FalseCount = 0;
382 if (BBInfo.getCount() == 0)
383 return false;
384
385 auto *Step = CtxProfAnalysis::getSelectInstrumentation(SI);
386 if (!Step)
387 return false;
388 auto Index = Step->getIndex()->getZExtValue();
389 assert(Index < PImpl->Counters.size() &&
390 "The index of the step instruction must be inside the "
391 "counters vector by "
392 "construction - tripping this assertion indicates a bug in "
393 "how the contextual profile is managed by IPO transforms");
394 auto TotalCount = BBInfo.getCount();
395 TrueCount = PImpl->Counters[Index];
396 FalseCount = (TotalCount > TrueCount ? TotalCount - TrueCount : 0U);
397 return true;
398}
399
400bool ProfileAnnotator::getOutgoingBranchWeights(
401 BasicBlock &BB, SmallVectorImpl<uint64_t> &Profile,
402 uint64_t &MaxCount) const {
403 Profile.clear();
404
405 if (succ_size(BB: &BB) < 2)
406 return false;
407
408 auto *Term = BB.getTerminator();
409 Profile.resize(N: Term->getNumSuccessors());
410
411 const auto &BBInfo = PImpl->getBBInfo(BB);
412 MaxCount = 0;
413 for (unsigned SuccIdx = 0, Size = BBInfo.getNumOutEdges(); SuccIdx < Size;
414 ++SuccIdx) {
415 uint64_t EdgeCount = BBInfo.getEdgeCount(Index: SuccIdx);
416 if (EdgeCount > MaxCount)
417 MaxCount = EdgeCount;
418 Profile[SuccIdx] = EdgeCount;
419 }
420 return MaxCount > 0;
421}
422
423PreservedAnalyses AssignGUIDPass::run(Module &M, ModuleAnalysisManager &MAM) {
424 for (auto &F : M.functions()) {
425 if (F.isDeclaration())
426 continue;
427 if (F.getMetadata(Kind: GUIDMetadataName))
428 continue;
429 const GlobalValue::GUID GUID = F.getGUID();
430 F.setMetadata(Kind: GUIDMetadataName,
431 Node: MDNode::get(Context&: M.getContext(),
432 MDs: {ConstantAsMetadata::get(C: ConstantInt::get(
433 Ty: Type::getInt64Ty(C&: M.getContext()), V: GUID))}));
434 }
435 return PreservedAnalyses::none();
436}
437
438GlobalValue::GUID AssignGUIDPass::getGUID(const Function &F) {
439 if (F.isDeclaration()) {
440 assert(GlobalValue::isExternalLinkage(F.getLinkage()));
441 return F.getGUID();
442 }
443 auto *MD = F.getMetadata(Kind: GUIDMetadataName);
444 assert(MD && "guid not found for defined function");
445 return cast<ConstantInt>(Val: cast<ConstantAsMetadata>(Val: MD->getOperand(I: 0))
446 ->getValue()
447 ->stripPointerCasts())
448 ->getZExtValue();
449}
450AnalysisKey CtxProfAnalysis::Key;
451
452CtxProfAnalysis::CtxProfAnalysis(std::optional<StringRef> Profile)
453 : Profile([&]() -> std::optional<StringRef> {
454 if (Profile)
455 return *Profile;
456 if (UseCtxProfile.getNumOccurrences())
457 return UseCtxProfile;
458 return std::nullopt;
459 }()) {}
460
461PGOContextualProfile CtxProfAnalysis::run(Module &M,
462 ModuleAnalysisManager &MAM) {
463 if (!Profile)
464 return {};
465 ErrorOr<std::unique_ptr<MemoryBuffer>> MB = MemoryBuffer::getFile(Filename: *Profile);
466 if (auto EC = MB.getError()) {
467 M.getContext().emitError(ErrorStr: "could not open contextual profile file: " +
468 EC.message());
469 return {};
470 }
471 PGOCtxProfileReader Reader(MB.get()->getBuffer());
472 auto MaybeProfiles = Reader.loadProfiles();
473 if (!MaybeProfiles) {
474 M.getContext().emitError(ErrorStr: "contextual profile file is invalid: " +
475 toString(E: MaybeProfiles.takeError()));
476 return {};
477 }
478
479 // FIXME: We should drive this from ThinLTO, but for the time being, use the
480 // module name as indicator.
481 // We want to *only* keep the contextual profiles in modules that capture
482 // context trees. That allows us to compute specific PSIs, for example.
483 auto DetermineRootsInModule = [&M]() -> const DenseSet<GlobalValue::GUID> {
484 DenseSet<GlobalValue::GUID> ProfileRootsInModule;
485 auto ModName = M.getName();
486 auto Filename = sys::path::filename(path: ModName);
487 // Drop the file extension.
488 Filename = Filename.substr(Start: 0, N: Filename.find_last_of(C: '.'));
489 // See if it parses
490 APInt Guid;
491 // getAsInteger returns true if there are more chars to read other than the
492 // integer. So the "false" test is what we want.
493 if (!Filename.getAsInteger(Radix: 0, Result&: Guid))
494 ProfileRootsInModule.insert(V: Guid.getZExtValue());
495 return ProfileRootsInModule;
496 };
497 const auto ProfileRootsInModule = DetermineRootsInModule();
498 PGOContextualProfile Result;
499
500 // the logic from here on allows for modules that contain - by design - more
501 // than one root. We currently don't support that, because the determination
502 // happens based on the module name matching the root guid, but the logic can
503 // avoid assuming that.
504 if (!ProfileRootsInModule.empty()) {
505 Result.IsInSpecializedModule = true;
506 // Trim first the roots that aren't in this module.
507 for (auto &[RootGuid, _] :
508 llvm::make_early_inc_range(Range&: MaybeProfiles->Contexts))
509 if (!ProfileRootsInModule.contains(V: RootGuid))
510 MaybeProfiles->Contexts.erase(x: RootGuid);
511 // we can also drop the flat profiles
512 MaybeProfiles->FlatProfiles.clear();
513 }
514
515 for (const auto &F : M) {
516 if (F.isDeclaration())
517 continue;
518 auto GUID = AssignGUIDPass::getGUID(F);
519 assert(GUID && "guid not found for defined function");
520 const auto &Entry = F.begin();
521 uint32_t MaxCounters = 0; // we expect at least a counter.
522 for (const auto &I : *Entry)
523 if (auto *C = dyn_cast<InstrProfIncrementInst>(Val: &I)) {
524 MaxCounters =
525 static_cast<uint32_t>(C->getNumCounters()->getZExtValue());
526 break;
527 }
528 if (!MaxCounters)
529 continue;
530 uint32_t MaxCallsites = 0;
531 for (const auto &BB : F)
532 for (const auto &I : BB)
533 if (auto *C = dyn_cast<InstrProfCallsite>(Val: &I)) {
534 MaxCallsites =
535 static_cast<uint32_t>(C->getNumCounters()->getZExtValue());
536 break;
537 }
538 auto [It, Ins] = Result.FuncInfo.insert(
539 x: {GUID, PGOContextualProfile::FunctionInfo(F.getName())});
540 (void)Ins;
541 assert(Ins);
542 It->second.NextCallsiteIndex = MaxCallsites;
543 It->second.NextCounterIndex = MaxCounters;
544 }
545 // If we made it this far, the Result is valid - which we mark by setting
546 // .Profiles.
547 Result.Profiles = std::move(*MaybeProfiles);
548 Result.initIndex();
549 return Result;
550}
551
552GlobalValue::GUID
553PGOContextualProfile::getDefinedFunctionGUID(const Function &F) const {
554 if (auto It = FuncInfo.find(x: AssignGUIDPass::getGUID(F)); It != FuncInfo.end())
555 return It->first;
556 return 0;
557}
558
559CtxProfAnalysisPrinterPass::CtxProfAnalysisPrinterPass(raw_ostream &OS)
560 : OS(OS), Mode(PrintLevel) {}
561
562PreservedAnalyses CtxProfAnalysisPrinterPass::run(Module &M,
563 ModuleAnalysisManager &MAM) {
564 CtxProfAnalysis::Result &C = MAM.getResult<CtxProfAnalysis>(IR&: M);
565 if (C.contexts().empty()) {
566 OS << "No contextual profile was provided.\n";
567 return PreservedAnalyses::all();
568 }
569
570 if (Mode == PrintMode::Everything) {
571 OS << "Function Info:\n";
572 for (const auto &[Guid, FuncInfo] : C.FuncInfo)
573 OS << Guid << " : " << FuncInfo.Name
574 << ". MaxCounterID: " << FuncInfo.NextCounterIndex
575 << ". MaxCallsiteID: " << FuncInfo.NextCallsiteIndex << "\n";
576 }
577
578 if (Mode == PrintMode::Everything)
579 OS << "\nCurrent Profile:\n";
580 convertCtxProfToYaml(OS, Profile: C.profiles());
581 OS << "\n";
582 if (Mode == PrintMode::YAML)
583 return PreservedAnalyses::all();
584
585 OS << "\nFlat Profile:\n";
586 auto Flat = C.flatten();
587 for (const auto &[Guid, Counters] : Flat) {
588 OS << Guid << " : ";
589 for (auto V : Counters)
590 OS << V << " ";
591 OS << "\n";
592 }
593 return PreservedAnalyses::all();
594}
595
596InstrProfCallsite *CtxProfAnalysis::getCallsiteInstrumentation(CallBase &CB) {
597 if (!InstrProfCallsite::canInstrumentCallsite(CB))
598 return nullptr;
599 for (auto *Prev = CB.getPrevNode(); Prev; Prev = Prev->getPrevNode()) {
600 if (auto *IPC = dyn_cast<InstrProfCallsite>(Val: Prev))
601 return IPC;
602 assert(!isa<CallBase>(Prev) &&
603 "didn't expect to find another call, that's not the callsite "
604 "instrumentation, before an instrumentable callsite");
605 }
606 return nullptr;
607}
608
609InstrProfIncrementInst *CtxProfAnalysis::getBBInstrumentation(BasicBlock &BB) {
610 for (auto &I : BB)
611 if (auto *Incr = dyn_cast<InstrProfIncrementInst>(Val: &I))
612 if (!isa<InstrProfIncrementInstStep>(Val: &I))
613 return Incr;
614 return nullptr;
615}
616
617InstrProfIncrementInstStep *
618CtxProfAnalysis::getSelectInstrumentation(SelectInst &SI) {
619 Instruction *Prev = &SI;
620 while ((Prev = Prev->getPrevNode()))
621 if (auto *Step = dyn_cast<InstrProfIncrementInstStep>(Val: Prev))
622 return Step;
623 return nullptr;
624}
625
626template <class ProfTy>
627static void preorderVisitOneRoot(ProfTy &Profile,
628 function_ref<void(ProfTy &)> Visitor) {
629 std::function<void(ProfTy &)> Traverser = [&](auto &Ctx) {
630 Visitor(Ctx);
631 for (auto &[_, SubCtxSet] : Ctx.callsites())
632 for (auto &[__, Subctx] : SubCtxSet)
633 Traverser(Subctx);
634 };
635 Traverser(Profile);
636}
637
638template <class ProfilesTy, class ProfTy>
639static void preorderVisit(ProfilesTy &Profiles,
640 function_ref<void(ProfTy &)> Visitor) {
641 for (auto &[_, P] : Profiles)
642 preorderVisitOneRoot<ProfTy>(P, Visitor);
643}
644
645void PGOContextualProfile::initIndex() {
646 // Initialize the head of the index list for each function. We don't need it
647 // after this point.
648 DenseMap<GlobalValue::GUID, PGOCtxProfContext *> InsertionPoints;
649 for (auto &[Guid, FI] : FuncInfo)
650 InsertionPoints[Guid] = &FI.Index;
651 preorderVisit<PGOCtxProfContext::CallTargetMapTy, PGOCtxProfContext>(
652 Profiles&: Profiles.Contexts, Visitor: [&](PGOCtxProfContext &Ctx) {
653 auto InsertIt = InsertionPoints.find(Val: Ctx.guid());
654 if (InsertIt == InsertionPoints.end())
655 return;
656 // Insert at the end of the list. Since we traverse in preorder, it
657 // means that when we iterate the list from the beginning, we'd
658 // encounter the contexts in the order we would have, should we have
659 // performed a full preorder traversal.
660 InsertIt->second->Next = &Ctx;
661 Ctx.Previous = InsertIt->second;
662 InsertIt->second = &Ctx;
663 });
664}
665
666bool PGOContextualProfile::isInSpecializedModule() const {
667 return ForceIsInSpecializedModule.getNumOccurrences() > 0
668 ? ForceIsInSpecializedModule
669 : IsInSpecializedModule;
670}
671
672void PGOContextualProfile::update(Visitor V, const Function &F) {
673 assert(isFunctionKnown(F));
674 GlobalValue::GUID G = getDefinedFunctionGUID(F);
675 for (auto *Node = FuncInfo.find(x: G)->second.Index.Next; Node;
676 Node = Node->Next)
677 V(*reinterpret_cast<PGOCtxProfContext *>(Node));
678}
679
680void PGOContextualProfile::visit(ConstVisitor V, const Function *F) const {
681 if (!F)
682 return preorderVisit<const PGOCtxProfContext::CallTargetMapTy,
683 const PGOCtxProfContext>(Profiles: Profiles.Contexts, Visitor: V);
684 assert(isFunctionKnown(*F));
685 GlobalValue::GUID G = getDefinedFunctionGUID(F: *F);
686 for (const auto *Node = FuncInfo.find(x: G)->second.Index.Next; Node;
687 Node = Node->Next)
688 V(*reinterpret_cast<const PGOCtxProfContext *>(Node));
689}
690
691const CtxProfFlatProfile PGOContextualProfile::flatten() const {
692 CtxProfFlatProfile Flat;
693 auto Accummulate = [](SmallVectorImpl<uint64_t> &Into,
694 const SmallVectorImpl<uint64_t> &From,
695 uint64_t SamplingRate) {
696 if (Into.empty())
697 Into.resize(N: From.size());
698 assert(Into.size() == From.size() &&
699 "All contexts corresponding to a function should have the exact "
700 "same number of counters.");
701 for (size_t I = 0, E = Into.size(); I < E; ++I)
702 Into[I] += From[I] * SamplingRate;
703 };
704
705 for (const auto &[_, CtxRoot] : Profiles.Contexts) {
706 const uint64_t SamplingFactor = CtxRoot.getTotalRootEntryCount();
707 preorderVisitOneRoot<const PGOCtxProfContext>(
708 Profile: CtxRoot, Visitor: [&](const PGOCtxProfContext &Ctx) {
709 Accummulate(Flat[Ctx.guid()], Ctx.counters(), SamplingFactor);
710 });
711
712 for (const auto &[G, Unh] : CtxRoot.getUnhandled())
713 Accummulate(Flat[G], Unh, SamplingFactor);
714 }
715 // We don't sample "Flat" currently, so sampling rate is 1.
716 for (const auto &[G, FC] : Profiles.FlatProfiles)
717 Accummulate(Flat[G], FC, /*SamplingRate=*/1);
718 return Flat;
719}
720
721const CtxProfFlatIndirectCallProfile
722PGOContextualProfile::flattenVirtCalls() const {
723 CtxProfFlatIndirectCallProfile Ret;
724 for (const auto &[_, CtxRoot] : Profiles.Contexts) {
725 const uint64_t TotalRootEntryCount = CtxRoot.getTotalRootEntryCount();
726 preorderVisitOneRoot<const PGOCtxProfContext>(
727 Profile: CtxRoot, Visitor: [&](const PGOCtxProfContext &Ctx) {
728 auto &Targets = Ret[Ctx.guid()];
729 for (const auto &[ID, SubctxSet] : Ctx.callsites())
730 for (const auto &Subctx : SubctxSet)
731 Targets[ID][Subctx.first] +=
732 Subctx.second.getEntrycount() * TotalRootEntryCount;
733 });
734 }
735 return Ret;
736}
737
738void CtxProfAnalysis::collectIndirectCallPromotionList(
739 CallBase &IC, Result &Profile,
740 SetVector<std::pair<CallBase *, Function *>> &Candidates) {
741 const auto *Instr = CtxProfAnalysis::getCallsiteInstrumentation(CB&: IC);
742 if (!Instr)
743 return;
744 Module &M = *IC.getParent()->getModule();
745 const uint32_t CallID = Instr->getIndex()->getZExtValue();
746 Profile.visit(
747 V: [&](const PGOCtxProfContext &Ctx) {
748 const auto &Targets = Ctx.callsites().find(x: CallID);
749 if (Targets == Ctx.callsites().end())
750 return;
751 for (const auto &[Guid, _] : Targets->second)
752 if (auto Name = Profile.getFunctionName(GUID: Guid); !Name.empty())
753 if (auto *Target = M.getFunction(Name))
754 if (Target->hasFnAttribute(Kind: Attribute::AlwaysInline))
755 Candidates.insert(X: {&IC, Target});
756 },
757 F: IC.getCaller());
758}
759