| 1 | //===- PGOCtxProfFlattening.cpp - Contextual Instr. Flattening ------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
// Flattens the contextual profile and lowers it to MD_prof.
// This should run after all IPO passes (which are assumed to have maintained
// the contextual profile) have completed. Flattening consists of summing the
// values at the same index of the counters belonging to all the contexts of a
// function. The lowering consists of materializing the counter values as
// function entry counts and branch probabilities.
//
// This pass also removes the contextual instrumentation. The instrumentation
// was kept around until now because this pass relies on it.
| 18 | // |
| 19 | //===----------------------------------------------------------------------===// |
| 20 | |
| 21 | #include "llvm/Transforms/Instrumentation/PGOCtxProfFlattening.h" |
| 22 | #include "llvm/ADT/STLExtras.h" |
| 23 | #include "llvm/ADT/ScopeExit.h" |
| 24 | #include "llvm/Analysis/CFG.h" |
| 25 | #include "llvm/Analysis/CtxProfAnalysis.h" |
| 26 | #include "llvm/Analysis/ProfileSummaryInfo.h" |
| 27 | #include "llvm/IR/Analysis.h" |
| 28 | #include "llvm/IR/CFG.h" |
| 29 | #include "llvm/IR/Dominators.h" |
| 30 | #include "llvm/IR/Instructions.h" |
| 31 | #include "llvm/IR/IntrinsicInst.h" |
| 32 | #include "llvm/IR/Module.h" |
| 33 | #include "llvm/IR/PassManager.h" |
| 34 | #include "llvm/IR/ProfileSummary.h" |
| 35 | #include "llvm/ProfileData/ProfileCommon.h" |
| 36 | #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h" |
| 37 | #include "llvm/Transforms/Scalar/DCE.h" |
| 38 | #include "llvm/Transforms/Utils/BasicBlockUtils.h" |
| 39 | |
| 40 | using namespace llvm; |
| 41 | |
| 42 | #define DEBUG_TYPE "ctx_prof_flatten" |
| 43 | |
| 44 | namespace { |
| 45 | |
| 46 | /// Assign branch weights and function entry count. Also update the PSI |
| 47 | /// builder. |
| 48 | void assignProfileData(Function &F, ArrayRef<uint64_t> RawCounters) { |
| 49 | assert(!RawCounters.empty()); |
| 50 | ProfileAnnotator PA(F, RawCounters); |
| 51 | |
| 52 | F.setEntryCount(Count: RawCounters[0]); |
| 53 | SmallVector<uint64_t, 2> ProfileHolder; |
| 54 | |
| 55 | for (auto &BB : F) { |
| 56 | for (auto &I : BB) |
| 57 | if (auto *SI = dyn_cast<SelectInst>(Val: &I)) { |
| 58 | uint64_t TrueCount, FalseCount = 0; |
| 59 | if (!PA.getSelectInstrProfile(SI&: *SI, TrueCount, FalseCount)) |
| 60 | continue; |
| 61 | setProfMetadata(M: F.getParent(), TI: SI, EdgeCounts: {TrueCount, FalseCount}, |
| 62 | MaxCount: std::max(a: TrueCount, b: FalseCount)); |
| 63 | } |
| 64 | if (succ_size(BB: &BB) < 2) |
| 65 | continue; |
| 66 | uint64_t MaxCount = 0; |
| 67 | if (!PA.getOutgoingBranchWeights(BB, Profile&: ProfileHolder, MaxCount)) |
| 68 | continue; |
| 69 | assert(MaxCount > 0); |
| 70 | setProfMetadata(M: F.getParent(), TI: BB.getTerminator(), EdgeCounts: ProfileHolder, MaxCount); |
| 71 | } |
| 72 | } |
| 73 | |
| 74 | [[maybe_unused]] bool areAllBBsReachable(const Function &F, |
| 75 | FunctionAnalysisManager &FAM) { |
| 76 | auto &DT = FAM.getResult<DominatorTreeAnalysis>(IR&: const_cast<Function &>(F)); |
| 77 | return llvm::all_of( |
| 78 | Range: F, P: [&](const BasicBlock &BB) { return DT.isReachableFromEntry(A: &BB); }); |
| 79 | } |
| 80 | |
| 81 | void clearColdFunctionProfile(Function &F) { |
| 82 | for (auto &BB : F) |
| 83 | BB.getTerminator()->setMetadata(KindID: LLVMContext::MD_prof, Node: nullptr); |
| 84 | F.setEntryCount(Count: 0U); |
| 85 | } |
| 86 | |
| 87 | void removeInstrumentation(Function &F) { |
| 88 | for (auto &BB : F) |
| 89 | for (auto &I : llvm::make_early_inc_range(Range&: BB)) |
| 90 | if (isa<InstrProfCntrInstBase>(Val: I)) |
| 91 | I.eraseFromParent(); |
| 92 | } |
| 93 | |
| 94 | void annotateIndirectCall( |
| 95 | Module &M, CallBase &CB, |
| 96 | const DenseMap<uint32_t, FlatIndirectTargets> &FlatProf, |
| 97 | const InstrProfCallsite &Ins) { |
| 98 | auto Idx = Ins.getIndex()->getZExtValue(); |
| 99 | auto FIt = FlatProf.find(Val: Idx); |
| 100 | if (FIt == FlatProf.end()) |
| 101 | return; |
| 102 | const auto &Targets = FIt->second; |
| 103 | SmallVector<InstrProfValueData, 2> Data; |
| 104 | uint64_t Sum = 0; |
| 105 | for (auto &[Guid, Count] : Targets) { |
| 106 | Data.push_back(Elt: {/*.Value=*/Guid, /*.Count=*/Count}); |
| 107 | Sum += Count; |
| 108 | } |
| 109 | |
| 110 | llvm::sort(C&: Data, |
| 111 | Comp: [](const InstrProfValueData &A, const InstrProfValueData &B) { |
| 112 | return A.Count > B.Count; |
| 113 | }); |
| 114 | llvm::annotateValueSite(M, Inst&: CB, VDs: Data, Sum, |
| 115 | ValueKind: InstrProfValueKind::IPVK_IndirectCallTarget, |
| 116 | MaxMDCount: Data.size()); |
| 117 | LLVM_DEBUG(dbgs() << "[ctxprof] flat indirect call prof: " << CB |
| 118 | << CB.getMetadata(LLVMContext::MD_prof) << "\n" ); |
| 119 | } |
| 120 | |
| 121 | // We normally return a "Changed" bool, but the calling pass' run assumes |
| 122 | // something will change - some profile will be added - so this won't add much |
| 123 | // by returning false when applicable. |
| 124 | void annotateIndirectCalls(Module &M, const CtxProfAnalysis::Result &CtxProf) { |
| 125 | const auto FlatIndCalls = CtxProf.flattenVirtCalls(); |
| 126 | for (auto &F : M) { |
| 127 | if (F.isDeclaration()) |
| 128 | continue; |
| 129 | auto FlatProfIter = FlatIndCalls.find(Val: AssignGUIDPass::getGUID(F)); |
| 130 | if (FlatProfIter == FlatIndCalls.end()) |
| 131 | continue; |
| 132 | const auto &FlatProf = FlatProfIter->second; |
| 133 | for (auto &BB : F) { |
| 134 | for (auto &I : BB) { |
| 135 | auto *CB = dyn_cast<CallBase>(Val: &I); |
| 136 | if (!CB || !CB->isIndirectCall()) |
| 137 | continue; |
| 138 | if (auto *Ins = CtxProfAnalysis::getCallsiteInstrumentation(CB&: *CB)) |
| 139 | annotateIndirectCall(M, CB&: *CB, FlatProf, Ins: *Ins); |
| 140 | } |
| 141 | } |
| 142 | } |
| 143 | } |
| 144 | |
| 145 | } // namespace |
| 146 | |
| 147 | PreservedAnalyses PGOCtxProfFlatteningPass::run(Module &M, |
| 148 | ModuleAnalysisManager &MAM) { |
| 149 | // Ensure in all cases the instrumentation is removed: if this module had no |
| 150 | // roots, the contextual profile would evaluate to false, but there would |
| 151 | // still be instrumentation. |
| 152 | // Note: in such cases we leave as-is any other profile info (if present - |
| 153 | // e.g. synthetic weights, etc) because it wouldn't interfere with the |
| 154 | // contextual - based one (which would be in other modules) |
| 155 | auto OnExit = llvm::make_scope_exit(F: [&]() { |
| 156 | if (IsPreThinlink) |
| 157 | return; |
| 158 | for (auto &F : M) |
| 159 | removeInstrumentation(F); |
| 160 | }); |
| 161 | auto &CtxProf = MAM.getResult<CtxProfAnalysis>(IR&: M); |
| 162 | // post-thinlink, we only reprocess for the module(s) containing the |
| 163 | // contextual tree. For everything else, OnExit will just clean the |
| 164 | // instrumentation. |
| 165 | if (!IsPreThinlink && !CtxProf.isInSpecializedModule()) |
| 166 | return PreservedAnalyses::none(); |
| 167 | |
| 168 | if (IsPreThinlink) |
| 169 | annotateIndirectCalls(M, CtxProf); |
| 170 | const auto FlattenedProfile = CtxProf.flatten(); |
| 171 | |
| 172 | for (auto &F : M) { |
| 173 | if (F.isDeclaration()) |
| 174 | continue; |
| 175 | |
| 176 | assert(areAllBBsReachable( |
| 177 | F, MAM.getResult<FunctionAnalysisManagerModuleProxy>(M) |
| 178 | .getManager()) && |
| 179 | "Function has unreacheable basic blocks. The expectation was that " |
| 180 | "DCE was run before." ); |
| 181 | |
| 182 | auto It = FlattenedProfile.find(x: AssignGUIDPass::getGUID(F)); |
| 183 | // If this function didn't appear in the contextual profile, it's cold. |
| 184 | if (It == FlattenedProfile.end()) |
| 185 | clearColdFunctionProfile(F); |
| 186 | else |
| 187 | assignProfileData(F, RawCounters: It->second); |
| 188 | } |
| 189 | InstrProfSummaryBuilder PB(ProfileSummaryBuilder::DefaultCutoffs); |
| 190 | // use here the flat profiles just so the importer doesn't complain about |
| 191 | // how different the PSIs are between the module with the roots and the |
| 192 | // various modules it imports. |
| 193 | for (auto &C : FlattenedProfile) { |
| 194 | PB.addEntryCount(Count: C.second[0]); |
| 195 | for (auto V : llvm::drop_begin(RangeOrContainer: C.second)) |
| 196 | PB.addInternalCount(Count: V); |
| 197 | } |
| 198 | |
| 199 | M.setProfileSummary(M: PB.getSummary()->getMD(Context&: M.getContext()), |
| 200 | Kind: ProfileSummary::Kind::PSK_Instr); |
| 201 | PreservedAnalyses PA; |
| 202 | PA.abandon<ProfileSummaryAnalysis>(); |
| 203 | MAM.invalidate(IR&: M, PA); |
| 204 | auto &PSI = MAM.getResult<ProfileSummaryAnalysis>(IR&: M); |
| 205 | PSI.refresh(Other: PB.getSummary()); |
| 206 | return PreservedAnalyses::none(); |
| 207 | } |
| 208 | |