1 | //===- PGOCtxProfFlattening.cpp - Contextual Instr. Flattening ------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Flattens the contextual profile and lowers it to MD_prof. |
10 | // This should happen after all IPO (which is assumed to have maintained the |
11 | // contextual profile) happened. Flattening consists of summing the values at |
12 | // the same index of the counters belonging to all the contexts of a function. |
13 | // The lowering consists of materializing the counter values to function |
14 | // entrypoint counts and branch probabilities. |
15 | // |
16 | // This pass also removes contextual instrumentation, which has been kept around |
17 | // to facilitate its functionality. |
18 | // |
19 | //===----------------------------------------------------------------------===// |
20 | |
21 | #include "llvm/Transforms/Instrumentation/PGOCtxProfFlattening.h" |
22 | #include "llvm/ADT/STLExtras.h" |
23 | #include "llvm/ADT/ScopeExit.h" |
24 | #include "llvm/Analysis/CFG.h" |
25 | #include "llvm/Analysis/CtxProfAnalysis.h" |
26 | #include "llvm/Analysis/ProfileSummaryInfo.h" |
27 | #include "llvm/IR/Analysis.h" |
28 | #include "llvm/IR/CFG.h" |
29 | #include "llvm/IR/Dominators.h" |
30 | #include "llvm/IR/Instructions.h" |
31 | #include "llvm/IR/IntrinsicInst.h" |
32 | #include "llvm/IR/Module.h" |
33 | #include "llvm/IR/PassManager.h" |
34 | #include "llvm/IR/ProfileSummary.h" |
35 | #include "llvm/ProfileData/ProfileCommon.h" |
36 | #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h" |
37 | #include "llvm/Transforms/Scalar/DCE.h" |
38 | #include "llvm/Transforms/Utils/BasicBlockUtils.h" |
39 | |
40 | using namespace llvm; |
41 | |
42 | #define DEBUG_TYPE "ctx_prof_flatten" |
43 | |
44 | namespace { |
45 | |
46 | /// Assign branch weights and function entry count. Also update the PSI |
47 | /// builder. |
48 | void assignProfileData(Function &F, ArrayRef<uint64_t> RawCounters) { |
49 | assert(!RawCounters.empty()); |
50 | ProfileAnnotator PA(F, RawCounters); |
51 | |
52 | F.setEntryCount(Count: RawCounters[0]); |
53 | SmallVector<uint64_t, 2> ProfileHolder; |
54 | |
55 | for (auto &BB : F) { |
56 | for (auto &I : BB) |
57 | if (auto *SI = dyn_cast<SelectInst>(Val: &I)) { |
58 | uint64_t TrueCount, FalseCount = 0; |
59 | if (!PA.getSelectInstrProfile(SI&: *SI, TrueCount, FalseCount)) |
60 | continue; |
61 | setProfMetadata(M: F.getParent(), TI: SI, EdgeCounts: {TrueCount, FalseCount}, |
62 | MaxCount: std::max(a: TrueCount, b: FalseCount)); |
63 | } |
64 | if (succ_size(BB: &BB) < 2) |
65 | continue; |
66 | uint64_t MaxCount = 0; |
67 | if (!PA.getOutgoingBranchWeights(BB, Profile&: ProfileHolder, MaxCount)) |
68 | continue; |
69 | assert(MaxCount > 0); |
70 | setProfMetadata(M: F.getParent(), TI: BB.getTerminator(), EdgeCounts: ProfileHolder, MaxCount); |
71 | } |
72 | } |
73 | |
74 | [[maybe_unused]] bool areAllBBsReachable(const Function &F, |
75 | FunctionAnalysisManager &FAM) { |
76 | auto &DT = FAM.getResult<DominatorTreeAnalysis>(IR&: const_cast<Function &>(F)); |
77 | return llvm::all_of( |
78 | Range: F, P: [&](const BasicBlock &BB) { return DT.isReachableFromEntry(A: &BB); }); |
79 | } |
80 | |
81 | void clearColdFunctionProfile(Function &F) { |
82 | for (auto &BB : F) |
83 | BB.getTerminator()->setMetadata(KindID: LLVMContext::MD_prof, Node: nullptr); |
84 | F.setEntryCount(Count: 0U); |
85 | } |
86 | |
87 | void removeInstrumentation(Function &F) { |
88 | for (auto &BB : F) |
89 | for (auto &I : llvm::make_early_inc_range(Range&: BB)) |
90 | if (isa<InstrProfCntrInstBase>(Val: I)) |
91 | I.eraseFromParent(); |
92 | } |
93 | |
94 | void annotateIndirectCall( |
95 | Module &M, CallBase &CB, |
96 | const DenseMap<uint32_t, FlatIndirectTargets> &FlatProf, |
97 | const InstrProfCallsite &Ins) { |
98 | auto Idx = Ins.getIndex()->getZExtValue(); |
99 | auto FIt = FlatProf.find(Val: Idx); |
100 | if (FIt == FlatProf.end()) |
101 | return; |
102 | const auto &Targets = FIt->second; |
103 | SmallVector<InstrProfValueData, 2> Data; |
104 | uint64_t Sum = 0; |
105 | for (auto &[Guid, Count] : Targets) { |
106 | Data.push_back(Elt: {/*.Value=*/Guid, /*.Count=*/Count}); |
107 | Sum += Count; |
108 | } |
109 | |
110 | llvm::sort(C&: Data, |
111 | Comp: [](const InstrProfValueData &A, const InstrProfValueData &B) { |
112 | return A.Count > B.Count; |
113 | }); |
114 | llvm::annotateValueSite(M, Inst&: CB, VDs: Data, Sum, |
115 | ValueKind: InstrProfValueKind::IPVK_IndirectCallTarget, |
116 | MaxMDCount: Data.size()); |
117 | LLVM_DEBUG(dbgs() << "[ctxprof] flat indirect call prof: " << CB |
118 | << CB.getMetadata(LLVMContext::MD_prof) << "\n" ); |
119 | } |
120 | |
121 | // We normally return a "Changed" bool, but the calling pass' run assumes |
122 | // something will change - some profile will be added - so this won't add much |
123 | // by returning false when applicable. |
124 | void annotateIndirectCalls(Module &M, const CtxProfAnalysis::Result &CtxProf) { |
125 | const auto FlatIndCalls = CtxProf.flattenVirtCalls(); |
126 | for (auto &F : M) { |
127 | if (F.isDeclaration()) |
128 | continue; |
129 | auto FlatProfIter = FlatIndCalls.find(Val: AssignGUIDPass::getGUID(F)); |
130 | if (FlatProfIter == FlatIndCalls.end()) |
131 | continue; |
132 | const auto &FlatProf = FlatProfIter->second; |
133 | for (auto &BB : F) { |
134 | for (auto &I : BB) { |
135 | auto *CB = dyn_cast<CallBase>(Val: &I); |
136 | if (!CB || !CB->isIndirectCall()) |
137 | continue; |
138 | if (auto *Ins = CtxProfAnalysis::getCallsiteInstrumentation(CB&: *CB)) |
139 | annotateIndirectCall(M, CB&: *CB, FlatProf, Ins: *Ins); |
140 | } |
141 | } |
142 | } |
143 | } |
144 | |
145 | } // namespace |
146 | |
147 | PreservedAnalyses PGOCtxProfFlatteningPass::run(Module &M, |
148 | ModuleAnalysisManager &MAM) { |
149 | // Ensure in all cases the instrumentation is removed: if this module had no |
150 | // roots, the contextual profile would evaluate to false, but there would |
151 | // still be instrumentation. |
152 | // Note: in such cases we leave as-is any other profile info (if present - |
153 | // e.g. synthetic weights, etc) because it wouldn't interfere with the |
154 | // contextual - based one (which would be in other modules) |
155 | auto OnExit = llvm::make_scope_exit(F: [&]() { |
156 | if (IsPreThinlink) |
157 | return; |
158 | for (auto &F : M) |
159 | removeInstrumentation(F); |
160 | }); |
161 | auto &CtxProf = MAM.getResult<CtxProfAnalysis>(IR&: M); |
162 | // post-thinlink, we only reprocess for the module(s) containing the |
163 | // contextual tree. For everything else, OnExit will just clean the |
164 | // instrumentation. |
165 | if (!IsPreThinlink && !CtxProf.isInSpecializedModule()) |
166 | return PreservedAnalyses::none(); |
167 | |
168 | if (IsPreThinlink) |
169 | annotateIndirectCalls(M, CtxProf); |
170 | const auto FlattenedProfile = CtxProf.flatten(); |
171 | |
172 | for (auto &F : M) { |
173 | if (F.isDeclaration()) |
174 | continue; |
175 | |
176 | assert(areAllBBsReachable( |
177 | F, MAM.getResult<FunctionAnalysisManagerModuleProxy>(M) |
178 | .getManager()) && |
179 | "Function has unreacheable basic blocks. The expectation was that " |
180 | "DCE was run before." ); |
181 | |
182 | auto It = FlattenedProfile.find(x: AssignGUIDPass::getGUID(F)); |
183 | // If this function didn't appear in the contextual profile, it's cold. |
184 | if (It == FlattenedProfile.end()) |
185 | clearColdFunctionProfile(F); |
186 | else |
187 | assignProfileData(F, RawCounters: It->second); |
188 | } |
189 | InstrProfSummaryBuilder PB(ProfileSummaryBuilder::DefaultCutoffs); |
190 | // use here the flat profiles just so the importer doesn't complain about |
191 | // how different the PSIs are between the module with the roots and the |
192 | // various modules it imports. |
193 | for (auto &C : FlattenedProfile) { |
194 | PB.addEntryCount(Count: C.second[0]); |
195 | for (auto V : llvm::drop_begin(RangeOrContainer: C.second)) |
196 | PB.addInternalCount(Count: V); |
197 | } |
198 | |
199 | M.setProfileSummary(M: PB.getSummary()->getMD(Context&: M.getContext()), |
200 | Kind: ProfileSummary::Kind::PSK_Instr); |
201 | PreservedAnalyses PA; |
202 | PA.abandon<ProfileSummaryAnalysis>(); |
203 | MAM.invalidate(IR&: M, PA); |
204 | auto &PSI = MAM.getResult<ProfileSummaryAnalysis>(IR&: M); |
205 | PSI.refresh(Other: PB.getSummary()); |
206 | return PreservedAnalyses::none(); |
207 | } |
208 | |