1//===- PGOCtxProfFlattening.cpp - Contextual Instr. Flattening ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Flattens the contextual profile and lowers it to MD_prof.
10// This should happen after all IPO (which is assumed to have maintained the
11// contextual profile) happened. Flattening consists of summing the values at
12// the same index of the counters belonging to all the contexts of a function.
13// The lowering consists of materializing the counter values to function
14// entrypoint counts and branch probabilities.
15//
16// This pass also removes contextual instrumentation, which has been kept around
17// to facilitate its functionality.
18//
19//===----------------------------------------------------------------------===//
20
21#include "llvm/Transforms/Instrumentation/PGOCtxProfFlattening.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/ADT/ScopeExit.h"
24#include "llvm/Analysis/CFG.h"
25#include "llvm/Analysis/CtxProfAnalysis.h"
26#include "llvm/Analysis/ProfileSummaryInfo.h"
27#include "llvm/IR/Analysis.h"
28#include "llvm/IR/CFG.h"
29#include "llvm/IR/Dominators.h"
30#include "llvm/IR/Instructions.h"
31#include "llvm/IR/IntrinsicInst.h"
32#include "llvm/IR/Module.h"
33#include "llvm/IR/PassManager.h"
34#include "llvm/IR/ProfileSummary.h"
35#include "llvm/ProfileData/ProfileCommon.h"
36#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
37#include "llvm/Transforms/Scalar/DCE.h"
38#include "llvm/Transforms/Utils/BasicBlockUtils.h"
39
40using namespace llvm;
41
42#define DEBUG_TYPE "ctx_prof_flatten"
43
44namespace {
45
46/// Assign branch weights and function entry count. Also update the PSI
47/// builder.
48void assignProfileData(Function &F, ArrayRef<uint64_t> RawCounters) {
49 assert(!RawCounters.empty());
50 ProfileAnnotator PA(F, RawCounters);
51
52 F.setEntryCount(Count: RawCounters[0]);
53 SmallVector<uint64_t, 2> ProfileHolder;
54
55 for (auto &BB : F) {
56 for (auto &I : BB)
57 if (auto *SI = dyn_cast<SelectInst>(Val: &I)) {
58 uint64_t TrueCount, FalseCount = 0;
59 if (!PA.getSelectInstrProfile(SI&: *SI, TrueCount, FalseCount))
60 continue;
61 setProfMetadata(M: F.getParent(), TI: SI, EdgeCounts: {TrueCount, FalseCount},
62 MaxCount: std::max(a: TrueCount, b: FalseCount));
63 }
64 if (succ_size(BB: &BB) < 2)
65 continue;
66 uint64_t MaxCount = 0;
67 if (!PA.getOutgoingBranchWeights(BB, Profile&: ProfileHolder, MaxCount))
68 continue;
69 assert(MaxCount > 0);
70 setProfMetadata(M: F.getParent(), TI: BB.getTerminator(), EdgeCounts: ProfileHolder, MaxCount);
71 }
72}
73
74[[maybe_unused]] bool areAllBBsReachable(const Function &F,
75 FunctionAnalysisManager &FAM) {
76 auto &DT = FAM.getResult<DominatorTreeAnalysis>(IR&: const_cast<Function &>(F));
77 return llvm::all_of(
78 Range: F, P: [&](const BasicBlock &BB) { return DT.isReachableFromEntry(A: &BB); });
79}
80
81void clearColdFunctionProfile(Function &F) {
82 for (auto &BB : F)
83 BB.getTerminator()->setMetadata(KindID: LLVMContext::MD_prof, Node: nullptr);
84 F.setEntryCount(Count: 0U);
85}
86
87void removeInstrumentation(Function &F) {
88 for (auto &BB : F)
89 for (auto &I : llvm::make_early_inc_range(Range&: BB))
90 if (isa<InstrProfCntrInstBase>(Val: I))
91 I.eraseFromParent();
92}
93
94void annotateIndirectCall(
95 Module &M, CallBase &CB,
96 const DenseMap<uint32_t, FlatIndirectTargets> &FlatProf,
97 const InstrProfCallsite &Ins) {
98 auto Idx = Ins.getIndex()->getZExtValue();
99 auto FIt = FlatProf.find(Val: Idx);
100 if (FIt == FlatProf.end())
101 return;
102 const auto &Targets = FIt->second;
103 SmallVector<InstrProfValueData, 2> Data;
104 uint64_t Sum = 0;
105 for (auto &[Guid, Count] : Targets) {
106 Data.push_back(Elt: {/*.Value=*/Guid, /*.Count=*/Count});
107 Sum += Count;
108 }
109
110 llvm::sort(C&: Data,
111 Comp: [](const InstrProfValueData &A, const InstrProfValueData &B) {
112 return A.Count > B.Count;
113 });
114 llvm::annotateValueSite(M, Inst&: CB, VDs: Data, Sum,
115 ValueKind: InstrProfValueKind::IPVK_IndirectCallTarget,
116 MaxMDCount: Data.size());
117 LLVM_DEBUG(dbgs() << "[ctxprof] flat indirect call prof: " << CB
118 << CB.getMetadata(LLVMContext::MD_prof) << "\n");
119}
120
121// We normally return a "Changed" bool, but the calling pass' run assumes
122// something will change - some profile will be added - so this won't add much
123// by returning false when applicable.
124void annotateIndirectCalls(Module &M, const CtxProfAnalysis::Result &CtxProf) {
125 const auto FlatIndCalls = CtxProf.flattenVirtCalls();
126 for (auto &F : M) {
127 if (F.isDeclaration())
128 continue;
129 auto FlatProfIter = FlatIndCalls.find(Val: AssignGUIDPass::getGUID(F));
130 if (FlatProfIter == FlatIndCalls.end())
131 continue;
132 const auto &FlatProf = FlatProfIter->second;
133 for (auto &BB : F) {
134 for (auto &I : BB) {
135 auto *CB = dyn_cast<CallBase>(Val: &I);
136 if (!CB || !CB->isIndirectCall())
137 continue;
138 if (auto *Ins = CtxProfAnalysis::getCallsiteInstrumentation(CB&: *CB))
139 annotateIndirectCall(M, CB&: *CB, FlatProf, Ins: *Ins);
140 }
141 }
142 }
143}
144
145} // namespace
146
147PreservedAnalyses PGOCtxProfFlatteningPass::run(Module &M,
148 ModuleAnalysisManager &MAM) {
149 // Ensure in all cases the instrumentation is removed: if this module had no
150 // roots, the contextual profile would evaluate to false, but there would
151 // still be instrumentation.
152 // Note: in such cases we leave as-is any other profile info (if present -
153 // e.g. synthetic weights, etc) because it wouldn't interfere with the
154 // contextual - based one (which would be in other modules)
155 auto OnExit = llvm::make_scope_exit(F: [&]() {
156 if (IsPreThinlink)
157 return;
158 for (auto &F : M)
159 removeInstrumentation(F);
160 });
161 auto &CtxProf = MAM.getResult<CtxProfAnalysis>(IR&: M);
162 // post-thinlink, we only reprocess for the module(s) containing the
163 // contextual tree. For everything else, OnExit will just clean the
164 // instrumentation.
165 if (!IsPreThinlink && !CtxProf.isInSpecializedModule())
166 return PreservedAnalyses::none();
167
168 if (IsPreThinlink)
169 annotateIndirectCalls(M, CtxProf);
170 const auto FlattenedProfile = CtxProf.flatten();
171
172 for (auto &F : M) {
173 if (F.isDeclaration())
174 continue;
175
176 assert(areAllBBsReachable(
177 F, MAM.getResult<FunctionAnalysisManagerModuleProxy>(M)
178 .getManager()) &&
179 "Function has unreacheable basic blocks. The expectation was that "
180 "DCE was run before.");
181
182 auto It = FlattenedProfile.find(x: AssignGUIDPass::getGUID(F));
183 // If this function didn't appear in the contextual profile, it's cold.
184 if (It == FlattenedProfile.end())
185 clearColdFunctionProfile(F);
186 else
187 assignProfileData(F, RawCounters: It->second);
188 }
189 InstrProfSummaryBuilder PB(ProfileSummaryBuilder::DefaultCutoffs);
190 // use here the flat profiles just so the importer doesn't complain about
191 // how different the PSIs are between the module with the roots and the
192 // various modules it imports.
193 for (auto &C : FlattenedProfile) {
194 PB.addEntryCount(Count: C.second[0]);
195 for (auto V : llvm::drop_begin(RangeOrContainer: C.second))
196 PB.addInternalCount(Count: V);
197 }
198
199 M.setProfileSummary(M: PB.getSummary()->getMD(Context&: M.getContext()),
200 Kind: ProfileSummary::Kind::PSK_Instr);
201 PreservedAnalyses PA;
202 PA.abandon<ProfileSummaryAnalysis>();
203 MAM.invalidate(IR&: M, PA);
204 auto &PSI = MAM.getResult<ProfileSummaryAnalysis>(IR&: M);
205 PSI.refresh(Other: PB.getSummary());
206 return PreservedAnalyses::none();
207}
208