| 1 | //===- CtxProfAnalysis.cpp - contextual profile analysis ------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // Implementation of the contextual profile analysis, which maintains contextual |
| 10 | // profiling info through IPO passes. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "llvm/Analysis/CtxProfAnalysis.h" |
| 15 | #include "llvm/ADT/APInt.h" |
| 16 | #include "llvm/ADT/STLExtras.h" |
| 17 | #include "llvm/Analysis/CFG.h" |
| 18 | #include "llvm/IR/Analysis.h" |
| 19 | #include "llvm/IR/Dominators.h" |
| 20 | #include "llvm/IR/IntrinsicInst.h" |
| 21 | #include "llvm/IR/Module.h" |
| 22 | #include "llvm/IR/PassManager.h" |
| 23 | #include "llvm/ProfileData/PGOCtxProfReader.h" |
| 24 | #include "llvm/Support/CommandLine.h" |
| 25 | #include "llvm/Support/MemoryBuffer.h" |
| 26 | #include "llvm/Support/Path.h" |
| 27 | #include <deque> |
| 28 | #include <memory> |
| 29 | |
| 30 | #define DEBUG_TYPE "ctx_prof" |
| 31 | |
| 32 | using namespace llvm; |
| 33 | cl::opt<std::string> |
| 34 | UseCtxProfile("use-ctx-profile" , cl::init(Val: "" ), cl::Hidden, |
| 35 | cl::desc("Use the specified contextual profile file" )); |
| 36 | |
| 37 | static cl::opt<CtxProfAnalysisPrinterPass::PrintMode> PrintLevel( |
| 38 | "ctx-profile-printer-level" , |
| 39 | cl::init(Val: CtxProfAnalysisPrinterPass::PrintMode::YAML), cl::Hidden, |
| 40 | cl::values(clEnumValN(CtxProfAnalysisPrinterPass::PrintMode::Everything, |
| 41 | "everything" , "print everything - most verbose" ), |
| 42 | clEnumValN(CtxProfAnalysisPrinterPass::PrintMode::YAML, "yaml" , |
| 43 | "just the yaml representation of the profile" )), |
| 44 | cl::desc("Verbosity level of the contextual profile printer pass." )); |
| 45 | |
| 46 | static cl::opt<bool> ForceIsInSpecializedModule( |
| 47 | "ctx-profile-force-is-specialized" , cl::init(Val: false), |
| 48 | cl::desc("Treat the given module as-if it were containing the " |
| 49 | "post-thinlink module containing the root" )); |
| 50 | |
| 51 | const char *AssignGUIDPass::GUIDMetadataName = "guid" ; |
| 52 | |
| 53 | namespace llvm { |
| 54 | class ProfileAnnotatorImpl final { |
| 55 | friend class ProfileAnnotator; |
| 56 | class BBInfo; |
| 57 | struct EdgeInfo { |
| 58 | BBInfo *const Src; |
| 59 | BBInfo *const Dest; |
| 60 | std::optional<uint64_t> Count; |
| 61 | |
| 62 | explicit EdgeInfo(BBInfo &Src, BBInfo &Dest) : Src(&Src), Dest(&Dest) {} |
| 63 | }; |
| 64 | |
| 65 | class BBInfo { |
| 66 | std::optional<uint64_t> Count; |
| 67 | // OutEdges is dimensioned to match the number of terminator operands. |
| 68 | // Entries in the vector match the index in the terminator operand list. In |
| 69 | // some cases - see `shouldExcludeEdge` and its implementation - an entry |
| 70 | // will be nullptr. |
| 71 | // InEdges doesn't have the above constraint. |
| 72 | SmallVector<EdgeInfo *> OutEdges; |
| 73 | SmallVector<EdgeInfo *> InEdges; |
| 74 | size_t UnknownCountOutEdges = 0; |
| 75 | size_t UnknownCountInEdges = 0; |
| 76 | |
| 77 | // Pass AssumeAllKnown when we try to propagate counts from edges to BBs - |
| 78 | // because all the edge counters must be known. |
| 79 | // Return std::nullopt if there were no edges to sum. The user can decide |
| 80 | // how to interpret that. |
| 81 | std::optional<uint64_t> getEdgeSum(const SmallVector<EdgeInfo *> &Edges, |
| 82 | bool AssumeAllKnown) const { |
| 83 | std::optional<uint64_t> Sum; |
| 84 | for (const auto *E : Edges) { |
| 85 | // `Edges` may be `OutEdges`, case in which `E` could be nullptr. |
| 86 | if (E) { |
| 87 | if (!Sum.has_value()) |
| 88 | Sum = 0; |
| 89 | *Sum += (AssumeAllKnown ? *E->Count : E->Count.value_or(u: 0U)); |
| 90 | } |
| 91 | } |
| 92 | return Sum; |
| 93 | } |
| 94 | |
| 95 | bool computeCountFrom(const SmallVector<EdgeInfo *> &Edges) { |
| 96 | assert(!Count.has_value()); |
| 97 | Count = getEdgeSum(Edges, AssumeAllKnown: true); |
| 98 | return Count.has_value(); |
| 99 | } |
| 100 | |
| 101 | void setSingleUnknownEdgeCount(SmallVector<EdgeInfo *> &Edges) { |
| 102 | uint64_t KnownSum = getEdgeSum(Edges, AssumeAllKnown: false).value_or(u: 0U); |
| 103 | uint64_t EdgeVal = *Count > KnownSum ? *Count - KnownSum : 0U; |
| 104 | EdgeInfo *E = nullptr; |
| 105 | for (auto *I : Edges) |
| 106 | if (I && !I->Count.has_value()) { |
| 107 | E = I; |
| 108 | #ifdef NDEBUG |
| 109 | break; |
| 110 | #else |
| 111 | assert((!E || E == I) && |
| 112 | "Expected exactly one edge to have an unknown count, " |
| 113 | "found a second one" ); |
| 114 | continue; |
| 115 | #endif |
| 116 | } |
| 117 | assert(E && "Expected exactly one edge to have an unknown count" ); |
| 118 | assert(!E->Count.has_value()); |
| 119 | E->Count = EdgeVal; |
| 120 | assert(E->Src->UnknownCountOutEdges > 0); |
| 121 | assert(E->Dest->UnknownCountInEdges > 0); |
| 122 | --E->Src->UnknownCountOutEdges; |
| 123 | --E->Dest->UnknownCountInEdges; |
| 124 | } |
| 125 | |
| 126 | public: |
| 127 | BBInfo(size_t NumInEdges, size_t NumOutEdges, std::optional<uint64_t> Count) |
| 128 | : Count(Count) { |
| 129 | // For in edges, we just want to pre-allocate enough space, since we know |
| 130 | // it at this stage. For out edges, we will insert edges at the indices |
| 131 | // corresponding to positions in this BB's terminator instruction, so we |
| 132 | // construct a default (nullptr values)-initialized vector. A nullptr edge |
| 133 | // corresponds to those that are excluded (see shouldExcludeEdge). |
| 134 | InEdges.reserve(N: NumInEdges); |
| 135 | OutEdges.resize(N: NumOutEdges); |
| 136 | } |
| 137 | |
| 138 | bool tryTakeCountFromKnownOutEdges(const BasicBlock &BB) { |
| 139 | if (!UnknownCountOutEdges) { |
| 140 | return computeCountFrom(Edges: OutEdges); |
| 141 | } |
| 142 | return false; |
| 143 | } |
| 144 | |
| 145 | bool tryTakeCountFromKnownInEdges(const BasicBlock &BB) { |
| 146 | if (!UnknownCountInEdges) { |
| 147 | return computeCountFrom(Edges: InEdges); |
| 148 | } |
| 149 | return false; |
| 150 | } |
| 151 | |
| 152 | void addInEdge(EdgeInfo &Info) { |
| 153 | InEdges.push_back(Elt: &Info); |
| 154 | ++UnknownCountInEdges; |
| 155 | } |
| 156 | |
| 157 | // For the out edges, we care about the position we place them in, which is |
| 158 | // the position in terminator instruction's list (at construction). Later, |
| 159 | // we build branch_weights metadata with edge frequency values matching |
| 160 | // these positions. |
| 161 | void addOutEdge(size_t Index, EdgeInfo &Info) { |
| 162 | OutEdges[Index] = &Info; |
| 163 | ++UnknownCountOutEdges; |
| 164 | } |
| 165 | |
| 166 | bool hasCount() const { return Count.has_value(); } |
| 167 | |
| 168 | uint64_t getCount() const { return *Count; } |
| 169 | |
| 170 | bool trySetSingleUnknownInEdgeCount() { |
| 171 | if (UnknownCountInEdges == 1) { |
| 172 | setSingleUnknownEdgeCount(InEdges); |
| 173 | return true; |
| 174 | } |
| 175 | return false; |
| 176 | } |
| 177 | |
| 178 | bool trySetSingleUnknownOutEdgeCount() { |
| 179 | if (UnknownCountOutEdges == 1) { |
| 180 | setSingleUnknownEdgeCount(OutEdges); |
| 181 | return true; |
| 182 | } |
| 183 | return false; |
| 184 | } |
| 185 | size_t getNumOutEdges() const { return OutEdges.size(); } |
| 186 | |
| 187 | uint64_t getEdgeCount(size_t Index) const { |
| 188 | if (auto *E = OutEdges[Index]) |
| 189 | return *E->Count; |
| 190 | return 0U; |
| 191 | } |
| 192 | }; |
| 193 | |
| 194 | const Function &F; |
| 195 | ArrayRef<uint64_t> Counters; |
| 196 | // To be accessed through getBBInfo() after construction. |
| 197 | std::map<const BasicBlock *, BBInfo> BBInfos; |
| 198 | std::vector<EdgeInfo> EdgeInfos; |
| 199 | |
| 200 | // The only criteria for exclusion is faux suspend -> exit edges in presplit |
| 201 | // coroutines. The API serves for readability, currently. |
| 202 | bool shouldExcludeEdge(const BasicBlock &Src, const BasicBlock &Dest) const { |
| 203 | return llvm::isPresplitCoroSuspendExitEdge(Src, Dest); |
| 204 | } |
| 205 | |
| 206 | BBInfo &getBBInfo(const BasicBlock &BB) { return BBInfos.find(x: &BB)->second; } |
| 207 | |
| 208 | const BBInfo &getBBInfo(const BasicBlock &BB) const { |
| 209 | return BBInfos.find(x: &BB)->second; |
| 210 | } |
| 211 | |
| 212 | // validation function after we propagate the counters: all BBs and edges' |
| 213 | // counters must have a value. |
| 214 | bool allCountersAreAssigned() const { |
| 215 | for (const auto &BBInfo : BBInfos) |
| 216 | if (!BBInfo.second.hasCount()) |
| 217 | return false; |
| 218 | for (const auto &EdgeInfo : EdgeInfos) |
| 219 | if (!EdgeInfo.Count.has_value()) |
| 220 | return false; |
| 221 | return true; |
| 222 | } |
| 223 | |
| 224 | /// Check that all paths from the entry basic block that use edges with |
| 225 | /// non-zero counts arrive at a basic block with no successors (i.e. "exit") |
| 226 | bool allTakenPathsExit() const { |
| 227 | std::deque<const BasicBlock *> Worklist; |
| 228 | DenseSet<const BasicBlock *> Visited; |
| 229 | Worklist.push_back(x: &F.getEntryBlock()); |
| 230 | bool HitExit = false; |
| 231 | while (!Worklist.empty()) { |
| 232 | const auto *BB = Worklist.front(); |
| 233 | Worklist.pop_front(); |
| 234 | if (!Visited.insert(V: BB).second) |
| 235 | continue; |
| 236 | if (succ_size(BB) == 0) { |
| 237 | if (isa<UnreachableInst>(Val: BB->getTerminator())) |
| 238 | return false; |
| 239 | HitExit = true; |
| 240 | continue; |
| 241 | } |
| 242 | if (succ_size(BB) == 1) { |
| 243 | Worklist.push_back(x: BB->getUniqueSuccessor()); |
| 244 | continue; |
| 245 | } |
| 246 | const auto &BBInfo = getBBInfo(BB: *BB); |
| 247 | bool HasAWayOut = false; |
| 248 | for (auto I = 0U; I < BB->getTerminator()->getNumSuccessors(); ++I) { |
| 249 | const auto *Succ = BB->getTerminator()->getSuccessor(Idx: I); |
| 250 | if (!shouldExcludeEdge(Src: *BB, Dest: *Succ)) { |
| 251 | if (BBInfo.getEdgeCount(Index: I) > 0) { |
| 252 | HasAWayOut = true; |
| 253 | Worklist.push_back(x: Succ); |
| 254 | } |
| 255 | } |
| 256 | } |
| 257 | if (!HasAWayOut) |
| 258 | return false; |
| 259 | } |
| 260 | return HitExit; |
| 261 | } |
| 262 | |
| 263 | bool allNonColdSelectsHaveProfile() const { |
| 264 | for (const auto &BB : F) { |
| 265 | if (getBBInfo(BB).getCount() > 0) { |
| 266 | for (const auto &I : BB) { |
| 267 | if (const auto *SI = dyn_cast<SelectInst>(Val: &I)) { |
| 268 | if (const auto *Inst = CtxProfAnalysis::getSelectInstrumentation( |
| 269 | SI&: *const_cast<SelectInst *>(SI))) { |
| 270 | auto Index = Inst->getIndex()->getZExtValue(); |
| 271 | assert(Index < Counters.size()); |
| 272 | if (Counters[Index] == 0) |
| 273 | return false; |
| 274 | } |
| 275 | } |
| 276 | } |
| 277 | } |
| 278 | } |
| 279 | return true; |
| 280 | } |
| 281 | |
| 282 | // This is an adaptation of PGOUseFunc::populateCounters. |
| 283 | // FIXME(mtrofin): look into factoring the code to share one implementation. |
| 284 | void propagateCounterValues() { |
| 285 | bool KeepGoing = true; |
| 286 | while (KeepGoing) { |
| 287 | KeepGoing = false; |
| 288 | for (const auto &BB : F) { |
| 289 | auto &Info = getBBInfo(BB); |
| 290 | if (!Info.hasCount()) |
| 291 | KeepGoing |= Info.tryTakeCountFromKnownOutEdges(BB) || |
| 292 | Info.tryTakeCountFromKnownInEdges(BB); |
| 293 | if (Info.hasCount()) { |
| 294 | KeepGoing |= Info.trySetSingleUnknownOutEdgeCount(); |
| 295 | KeepGoing |= Info.trySetSingleUnknownInEdgeCount(); |
| 296 | } |
| 297 | } |
| 298 | } |
| 299 | assert(allCountersAreAssigned() && |
| 300 | "[ctx-prof] Expected all counters have been assigned." ); |
| 301 | assert(allTakenPathsExit() && |
| 302 | "[ctx-prof] Encountered a BB with more than one successor, where " |
| 303 | "all outgoing edges have a 0 count. This occurs in non-exiting " |
| 304 | "functions (message pumps, usually) which are not supported in the " |
| 305 | "contextual profiling case" ); |
| 306 | assert(allNonColdSelectsHaveProfile() && |
| 307 | "[ctx-prof] All non-cold select instructions were expected to have " |
| 308 | "a profile." ); |
| 309 | } |
| 310 | |
| 311 | public: |
| 312 | ProfileAnnotatorImpl(const Function &F, ArrayRef<uint64_t> Counters) |
| 313 | : F(F), Counters(Counters) { |
| 314 | assert(!F.isDeclaration()); |
| 315 | assert(!Counters.empty()); |
| 316 | size_t NrEdges = 0; |
| 317 | for (const auto &BB : F) { |
| 318 | std::optional<uint64_t> Count; |
| 319 | if (auto *Ins = CtxProfAnalysis::getBBInstrumentation( |
| 320 | BB&: const_cast<BasicBlock &>(BB))) { |
| 321 | auto Index = Ins->getIndex()->getZExtValue(); |
| 322 | assert(Index < Counters.size() && |
| 323 | "The index must be inside the counters vector by construction - " |
| 324 | "tripping this assertion indicates a bug in how the contextual " |
| 325 | "profile is managed by IPO transforms" ); |
| 326 | (void)Index; |
| 327 | Count = Counters[Ins->getIndex()->getZExtValue()]; |
| 328 | } else if (isa<UnreachableInst>(Val: BB.getTerminator())) { |
| 329 | // The program presumably didn't crash. |
| 330 | Count = 0; |
| 331 | } |
| 332 | auto [It, Ins] = |
| 333 | BBInfos.insert(x: {&BB, {pred_size(BB: &BB), succ_size(BB: &BB), Count}}); |
| 334 | (void)Ins; |
| 335 | assert(Ins && "We iterate through the function's BBs, no reason to " |
| 336 | "insert one more than once" ); |
| 337 | NrEdges += llvm::count_if(Range: successors(BB: &BB), P: [&](const auto *Succ) { |
| 338 | return !shouldExcludeEdge(Src: BB, Dest: *Succ); |
| 339 | }); |
| 340 | } |
| 341 | // Pre-allocate the vector, we want references to its contents to be stable. |
| 342 | EdgeInfos.reserve(n: NrEdges); |
| 343 | for (const auto &BB : F) { |
| 344 | auto &Info = getBBInfo(BB); |
| 345 | for (auto I = 0U; I < BB.getTerminator()->getNumSuccessors(); ++I) { |
| 346 | const auto *Succ = BB.getTerminator()->getSuccessor(Idx: I); |
| 347 | if (!shouldExcludeEdge(Src: BB, Dest: *Succ)) { |
| 348 | auto &EI = EdgeInfos.emplace_back(args&: getBBInfo(BB), args&: getBBInfo(BB: *Succ)); |
| 349 | Info.addOutEdge(Index: I, Info&: EI); |
| 350 | getBBInfo(BB: *Succ).addInEdge(Info&: EI); |
| 351 | } |
| 352 | } |
| 353 | } |
| 354 | assert(EdgeInfos.capacity() == NrEdges && |
| 355 | "The capacity of EdgeInfos should have stayed unchanged it was " |
| 356 | "populated, because we need pointers to its contents to be stable" ); |
| 357 | propagateCounterValues(); |
| 358 | } |
| 359 | |
| 360 | uint64_t getBBCount(const BasicBlock &BB) { return getBBInfo(BB).getCount(); } |
| 361 | }; |
| 362 | |
| 363 | } // namespace llvm |
| 364 | |
| 365 | ProfileAnnotator::ProfileAnnotator(const Function &F, |
| 366 | ArrayRef<uint64_t> RawCounters) |
| 367 | : PImpl(std::make_unique<ProfileAnnotatorImpl>(args: F, args&: RawCounters)) {} |
| 368 | |
| 369 | ProfileAnnotator::~ProfileAnnotator() = default; |
| 370 | |
| 371 | uint64_t ProfileAnnotator::getBBCount(const BasicBlock &BB) const { |
| 372 | return PImpl->getBBCount(BB); |
| 373 | } |
| 374 | |
| 375 | bool ProfileAnnotator::getSelectInstrProfile(SelectInst &SI, |
| 376 | uint64_t &TrueCount, |
| 377 | uint64_t &FalseCount) const { |
| 378 | const auto &BBInfo = PImpl->getBBInfo(BB: *SI.getParent()); |
| 379 | TrueCount = FalseCount = 0; |
| 380 | if (BBInfo.getCount() == 0) |
| 381 | return false; |
| 382 | |
| 383 | auto *Step = CtxProfAnalysis::getSelectInstrumentation(SI); |
| 384 | if (!Step) |
| 385 | return false; |
| 386 | auto Index = Step->getIndex()->getZExtValue(); |
| 387 | assert(Index < PImpl->Counters.size() && |
| 388 | "The index of the step instruction must be inside the " |
| 389 | "counters vector by " |
| 390 | "construction - tripping this assertion indicates a bug in " |
| 391 | "how the contextual profile is managed by IPO transforms" ); |
| 392 | auto TotalCount = BBInfo.getCount(); |
| 393 | TrueCount = PImpl->Counters[Index]; |
| 394 | FalseCount = (TotalCount > TrueCount ? TotalCount - TrueCount : 0U); |
| 395 | return true; |
| 396 | } |
| 397 | |
| 398 | bool ProfileAnnotator::getOutgoingBranchWeights( |
| 399 | BasicBlock &BB, SmallVectorImpl<uint64_t> &Profile, |
| 400 | uint64_t &MaxCount) const { |
| 401 | Profile.clear(); |
| 402 | |
| 403 | if (succ_size(BB: &BB) < 2) |
| 404 | return false; |
| 405 | |
| 406 | auto *Term = BB.getTerminator(); |
| 407 | Profile.resize(N: Term->getNumSuccessors()); |
| 408 | |
| 409 | const auto &BBInfo = PImpl->getBBInfo(BB); |
| 410 | MaxCount = 0; |
| 411 | for (unsigned SuccIdx = 0, Size = BBInfo.getNumOutEdges(); SuccIdx < Size; |
| 412 | ++SuccIdx) { |
| 413 | uint64_t EdgeCount = BBInfo.getEdgeCount(Index: SuccIdx); |
| 414 | if (EdgeCount > MaxCount) |
| 415 | MaxCount = EdgeCount; |
| 416 | Profile[SuccIdx] = EdgeCount; |
| 417 | } |
| 418 | return MaxCount > 0; |
| 419 | } |
| 420 | |
| 421 | PreservedAnalyses AssignGUIDPass::run(Module &M, ModuleAnalysisManager &MAM) { |
| 422 | for (auto &F : M.functions()) { |
| 423 | if (F.isDeclaration()) |
| 424 | continue; |
| 425 | if (F.getMetadata(Kind: GUIDMetadataName)) |
| 426 | continue; |
| 427 | const GlobalValue::GUID GUID = F.getGUID(); |
| 428 | F.setMetadata(Kind: GUIDMetadataName, |
| 429 | Node: MDNode::get(Context&: M.getContext(), |
| 430 | MDs: {ConstantAsMetadata::get(C: ConstantInt::get( |
| 431 | Ty: Type::getInt64Ty(C&: M.getContext()), V: GUID))})); |
| 432 | } |
| 433 | return PreservedAnalyses::none(); |
| 434 | } |
| 435 | |
| 436 | GlobalValue::GUID AssignGUIDPass::getGUID(const Function &F) { |
| 437 | if (F.isDeclaration()) { |
| 438 | assert(GlobalValue::isExternalLinkage(F.getLinkage())); |
| 439 | return F.getGUID(); |
| 440 | } |
| 441 | auto *MD = F.getMetadata(Kind: GUIDMetadataName); |
| 442 | assert(MD && "guid not found for defined function" ); |
| 443 | return cast<ConstantInt>(Val: cast<ConstantAsMetadata>(Val: MD->getOperand(I: 0)) |
| 444 | ->getValue() |
| 445 | ->stripPointerCasts()) |
| 446 | ->getZExtValue(); |
| 447 | } |
| 448 | AnalysisKey CtxProfAnalysis::Key; |
| 449 | |
| 450 | CtxProfAnalysis::CtxProfAnalysis(std::optional<StringRef> Profile) |
| 451 | : Profile([&]() -> std::optional<StringRef> { |
| 452 | if (Profile) |
| 453 | return *Profile; |
| 454 | if (UseCtxProfile.getNumOccurrences()) |
| 455 | return UseCtxProfile; |
| 456 | return std::nullopt; |
| 457 | }()) {} |
| 458 | |
| 459 | PGOContextualProfile CtxProfAnalysis::run(Module &M, |
| 460 | ModuleAnalysisManager &MAM) { |
| 461 | if (!Profile) |
| 462 | return {}; |
| 463 | ErrorOr<std::unique_ptr<MemoryBuffer>> MB = MemoryBuffer::getFile(Filename: *Profile); |
| 464 | if (auto EC = MB.getError()) { |
| 465 | M.getContext().emitError(ErrorStr: "could not open contextual profile file: " + |
| 466 | EC.message()); |
| 467 | return {}; |
| 468 | } |
| 469 | PGOCtxProfileReader Reader(MB.get()->getBuffer()); |
| 470 | auto MaybeProfiles = Reader.loadProfiles(); |
| 471 | if (!MaybeProfiles) { |
| 472 | M.getContext().emitError(ErrorStr: "contextual profile file is invalid: " + |
| 473 | toString(E: MaybeProfiles.takeError())); |
| 474 | return {}; |
| 475 | } |
| 476 | |
| 477 | // FIXME: We should drive this from ThinLTO, but for the time being, use the |
| 478 | // module name as indicator. |
| 479 | // We want to *only* keep the contextual profiles in modules that capture |
| 480 | // context trees. That allows us to compute specific PSIs, for example. |
| 481 | auto DetermineRootsInModule = [&M]() -> const DenseSet<GlobalValue::GUID> { |
| 482 | DenseSet<GlobalValue::GUID> ProfileRootsInModule; |
| 483 | auto ModName = M.getName(); |
| 484 | auto Filename = sys::path::filename(path: ModName); |
| 485 | // Drop the file extension. |
| 486 | Filename = Filename.substr(Start: 0, N: Filename.find_last_of(C: '.')); |
| 487 | // See if it parses |
| 488 | APInt Guid; |
| 489 | // getAsInteger returns true if there are more chars to read other than the |
| 490 | // integer. So the "false" test is what we want. |
| 491 | if (!Filename.getAsInteger(Radix: 0, Result&: Guid)) |
| 492 | ProfileRootsInModule.insert(V: Guid.getZExtValue()); |
| 493 | return ProfileRootsInModule; |
| 494 | }; |
| 495 | const auto ProfileRootsInModule = DetermineRootsInModule(); |
| 496 | PGOContextualProfile Result; |
| 497 | |
| 498 | // the logic from here on allows for modules that contain - by design - more |
| 499 | // than one root. We currently don't support that, because the determination |
| 500 | // happens based on the module name matching the root guid, but the logic can |
| 501 | // avoid assuming that. |
| 502 | if (!ProfileRootsInModule.empty()) { |
| 503 | Result.IsInSpecializedModule = true; |
| 504 | // Trim first the roots that aren't in this module. |
| 505 | for (auto &[RootGuid, _] : |
| 506 | llvm::make_early_inc_range(Range&: MaybeProfiles->Contexts)) |
| 507 | if (!ProfileRootsInModule.contains(V: RootGuid)) |
| 508 | MaybeProfiles->Contexts.erase(x: RootGuid); |
| 509 | // we can also drop the flat profiles |
| 510 | MaybeProfiles->FlatProfiles.clear(); |
| 511 | } |
| 512 | |
| 513 | for (const auto &F : M) { |
| 514 | if (F.isDeclaration()) |
| 515 | continue; |
| 516 | auto GUID = AssignGUIDPass::getGUID(F); |
| 517 | assert(GUID && "guid not found for defined function" ); |
| 518 | const auto &Entry = F.begin(); |
| 519 | uint32_t MaxCounters = 0; // we expect at least a counter. |
| 520 | for (const auto &I : *Entry) |
| 521 | if (auto *C = dyn_cast<InstrProfIncrementInst>(Val: &I)) { |
| 522 | MaxCounters = |
| 523 | static_cast<uint32_t>(C->getNumCounters()->getZExtValue()); |
| 524 | break; |
| 525 | } |
| 526 | if (!MaxCounters) |
| 527 | continue; |
| 528 | uint32_t MaxCallsites = 0; |
| 529 | for (const auto &BB : F) |
| 530 | for (const auto &I : BB) |
| 531 | if (auto *C = dyn_cast<InstrProfCallsite>(Val: &I)) { |
| 532 | MaxCallsites = |
| 533 | static_cast<uint32_t>(C->getNumCounters()->getZExtValue()); |
| 534 | break; |
| 535 | } |
| 536 | auto [It, Ins] = Result.FuncInfo.insert( |
| 537 | x: {GUID, PGOContextualProfile::FunctionInfo(F.getName())}); |
| 538 | (void)Ins; |
| 539 | assert(Ins); |
| 540 | It->second.NextCallsiteIndex = MaxCallsites; |
| 541 | It->second.NextCounterIndex = MaxCounters; |
| 542 | } |
| 543 | // If we made it this far, the Result is valid - which we mark by setting |
| 544 | // .Profiles. |
| 545 | Result.Profiles = std::move(*MaybeProfiles); |
| 546 | Result.initIndex(); |
| 547 | return Result; |
| 548 | } |
| 549 | |
| 550 | GlobalValue::GUID |
| 551 | PGOContextualProfile::getDefinedFunctionGUID(const Function &F) const { |
| 552 | if (auto It = FuncInfo.find(x: AssignGUIDPass::getGUID(F)); It != FuncInfo.end()) |
| 553 | return It->first; |
| 554 | return 0; |
| 555 | } |
| 556 | |
// The print verbosity is taken from the -ctx-profile-printer-level flag.
CtxProfAnalysisPrinterPass::CtxProfAnalysisPrinterPass(raw_ostream &OS)
    : OS(OS), Mode(PrintLevel) {}
| 559 | |
| 560 | PreservedAnalyses CtxProfAnalysisPrinterPass::run(Module &M, |
| 561 | ModuleAnalysisManager &MAM) { |
| 562 | CtxProfAnalysis::Result &C = MAM.getResult<CtxProfAnalysis>(IR&: M); |
| 563 | if (C.contexts().empty()) { |
| 564 | OS << "No contextual profile was provided.\n" ; |
| 565 | return PreservedAnalyses::all(); |
| 566 | } |
| 567 | |
| 568 | if (Mode == PrintMode::Everything) { |
| 569 | OS << "Function Info:\n" ; |
| 570 | for (const auto &[Guid, FuncInfo] : C.FuncInfo) |
| 571 | OS << Guid << " : " << FuncInfo.Name |
| 572 | << ". MaxCounterID: " << FuncInfo.NextCounterIndex |
| 573 | << ". MaxCallsiteID: " << FuncInfo.NextCallsiteIndex << "\n" ; |
| 574 | } |
| 575 | |
| 576 | if (Mode == PrintMode::Everything) |
| 577 | OS << "\nCurrent Profile:\n" ; |
| 578 | convertCtxProfToYaml(OS, Profile: C.profiles()); |
| 579 | OS << "\n" ; |
| 580 | if (Mode == PrintMode::YAML) |
| 581 | return PreservedAnalyses::all(); |
| 582 | |
| 583 | OS << "\nFlat Profile:\n" ; |
| 584 | auto Flat = C.flatten(); |
| 585 | for (const auto &[Guid, Counters] : Flat) { |
| 586 | OS << Guid << " : " ; |
| 587 | for (auto V : Counters) |
| 588 | OS << V << " " ; |
| 589 | OS << "\n" ; |
| 590 | } |
| 591 | return PreservedAnalyses::all(); |
| 592 | } |
| 593 | |
| 594 | InstrProfCallsite *CtxProfAnalysis::getCallsiteInstrumentation(CallBase &CB) { |
| 595 | if (!InstrProfCallsite::canInstrumentCallsite(CB)) |
| 596 | return nullptr; |
| 597 | for (auto *Prev = CB.getPrevNode(); Prev; Prev = Prev->getPrevNode()) { |
| 598 | if (auto *IPC = dyn_cast<InstrProfCallsite>(Val: Prev)) |
| 599 | return IPC; |
| 600 | assert(!isa<CallBase>(Prev) && |
| 601 | "didn't expect to find another call, that's not the callsite " |
| 602 | "instrumentation, before an instrumentable callsite" ); |
| 603 | } |
| 604 | return nullptr; |
| 605 | } |
| 606 | |
| 607 | InstrProfIncrementInst *CtxProfAnalysis::getBBInstrumentation(BasicBlock &BB) { |
| 608 | for (auto &I : BB) |
| 609 | if (auto *Incr = dyn_cast<InstrProfIncrementInst>(Val: &I)) |
| 610 | if (!isa<InstrProfIncrementInstStep>(Val: &I)) |
| 611 | return Incr; |
| 612 | return nullptr; |
| 613 | } |
| 614 | |
| 615 | InstrProfIncrementInstStep * |
| 616 | CtxProfAnalysis::getSelectInstrumentation(SelectInst &SI) { |
| 617 | Instruction *Prev = &SI; |
| 618 | while ((Prev = Prev->getPrevNode())) |
| 619 | if (auto *Step = dyn_cast<InstrProfIncrementInstStep>(Val: Prev)) |
| 620 | return Step; |
| 621 | return nullptr; |
| 622 | } |
| 623 | |
| 624 | template <class ProfTy> |
| 625 | static void preorderVisitOneRoot(ProfTy &Profile, |
| 626 | function_ref<void(ProfTy &)> Visitor) { |
| 627 | std::function<void(ProfTy &)> Traverser = [&](auto &Ctx) { |
| 628 | Visitor(Ctx); |
| 629 | for (auto &[_, SubCtxSet] : Ctx.callsites()) |
| 630 | for (auto &[__, Subctx] : SubCtxSet) |
| 631 | Traverser(Subctx); |
| 632 | }; |
| 633 | Traverser(Profile); |
| 634 | } |
| 635 | |
| 636 | template <class ProfilesTy, class ProfTy> |
| 637 | static void preorderVisit(ProfilesTy &Profiles, |
| 638 | function_ref<void(ProfTy &)> Visitor) { |
| 639 | for (auto &[_, P] : Profiles) |
| 640 | preorderVisitOneRoot<ProfTy>(P, Visitor); |
| 641 | } |
| 642 | |
| 643 | void PGOContextualProfile::initIndex() { |
| 644 | // Initialize the head of the index list for each function. We don't need it |
| 645 | // after this point. |
| 646 | DenseMap<GlobalValue::GUID, PGOCtxProfContext *> InsertionPoints; |
| 647 | for (auto &[Guid, FI] : FuncInfo) |
| 648 | InsertionPoints[Guid] = &FI.Index; |
| 649 | preorderVisit<PGOCtxProfContext::CallTargetMapTy, PGOCtxProfContext>( |
| 650 | Profiles&: Profiles.Contexts, Visitor: [&](PGOCtxProfContext &Ctx) { |
| 651 | auto InsertIt = InsertionPoints.find(Val: Ctx.guid()); |
| 652 | if (InsertIt == InsertionPoints.end()) |
| 653 | return; |
| 654 | // Insert at the end of the list. Since we traverse in preorder, it |
| 655 | // means that when we iterate the list from the beginning, we'd |
| 656 | // encounter the contexts in the order we would have, should we have |
| 657 | // performed a full preorder traversal. |
| 658 | InsertIt->second->Next = &Ctx; |
| 659 | Ctx.Previous = InsertIt->second; |
| 660 | InsertIt->second = &Ctx; |
| 661 | }); |
| 662 | } |
| 663 | |
| 664 | bool PGOContextualProfile::isInSpecializedModule() const { |
| 665 | return ForceIsInSpecializedModule.getNumOccurrences() > 0 |
| 666 | ? ForceIsInSpecializedModule |
| 667 | : IsInSpecializedModule; |
| 668 | } |
| 669 | |
| 670 | void PGOContextualProfile::update(Visitor V, const Function &F) { |
| 671 | assert(isFunctionKnown(F)); |
| 672 | GlobalValue::GUID G = getDefinedFunctionGUID(F); |
| 673 | for (auto *Node = FuncInfo.find(x: G)->second.Index.Next; Node; |
| 674 | Node = Node->Next) |
| 675 | V(*reinterpret_cast<PGOCtxProfContext *>(Node)); |
| 676 | } |
| 677 | |
| 678 | void PGOContextualProfile::visit(ConstVisitor V, const Function *F) const { |
| 679 | if (!F) |
| 680 | return preorderVisit<const PGOCtxProfContext::CallTargetMapTy, |
| 681 | const PGOCtxProfContext>(Profiles: Profiles.Contexts, Visitor: V); |
| 682 | assert(isFunctionKnown(*F)); |
| 683 | GlobalValue::GUID G = getDefinedFunctionGUID(F: *F); |
| 684 | for (const auto *Node = FuncInfo.find(x: G)->second.Index.Next; Node; |
| 685 | Node = Node->Next) |
| 686 | V(*reinterpret_cast<const PGOCtxProfContext *>(Node)); |
| 687 | } |
| 688 | |
| 689 | const CtxProfFlatProfile PGOContextualProfile::flatten() const { |
| 690 | CtxProfFlatProfile Flat; |
| 691 | auto Accummulate = [](SmallVectorImpl<uint64_t> &Into, |
| 692 | const SmallVectorImpl<uint64_t> &From, |
| 693 | uint64_t SamplingRate) { |
| 694 | if (Into.empty()) |
| 695 | Into.resize(N: From.size()); |
| 696 | assert(Into.size() == From.size() && |
| 697 | "All contexts corresponding to a function should have the exact " |
| 698 | "same number of counters." ); |
| 699 | for (size_t I = 0, E = Into.size(); I < E; ++I) |
| 700 | Into[I] += From[I] * SamplingRate; |
| 701 | }; |
| 702 | |
| 703 | for (const auto &[_, CtxRoot] : Profiles.Contexts) { |
| 704 | const uint64_t SamplingFactor = CtxRoot.getTotalRootEntryCount(); |
| 705 | preorderVisitOneRoot<const PGOCtxProfContext>( |
| 706 | Profile: CtxRoot, Visitor: [&](const PGOCtxProfContext &Ctx) { |
| 707 | Accummulate(Flat[Ctx.guid()], Ctx.counters(), SamplingFactor); |
| 708 | }); |
| 709 | |
| 710 | for (const auto &[G, Unh] : CtxRoot.getUnhandled()) |
| 711 | Accummulate(Flat[G], Unh, SamplingFactor); |
| 712 | } |
| 713 | // We don't sample "Flat" currently, so sampling rate is 1. |
| 714 | for (const auto &[G, FC] : Profiles.FlatProfiles) |
| 715 | Accummulate(Flat[G], FC, /*SamplingRate=*/1); |
| 716 | return Flat; |
| 717 | } |
| 718 | |
| 719 | const CtxProfFlatIndirectCallProfile |
| 720 | PGOContextualProfile::flattenVirtCalls() const { |
| 721 | CtxProfFlatIndirectCallProfile Ret; |
| 722 | for (const auto &[_, CtxRoot] : Profiles.Contexts) { |
| 723 | const uint64_t TotalRootEntryCount = CtxRoot.getTotalRootEntryCount(); |
| 724 | preorderVisitOneRoot<const PGOCtxProfContext>( |
| 725 | Profile: CtxRoot, Visitor: [&](const PGOCtxProfContext &Ctx) { |
| 726 | auto &Targets = Ret[Ctx.guid()]; |
| 727 | for (const auto &[ID, SubctxSet] : Ctx.callsites()) |
| 728 | for (const auto &Subctx : SubctxSet) |
| 729 | Targets[ID][Subctx.first] += |
| 730 | Subctx.second.getEntrycount() * TotalRootEntryCount; |
| 731 | }); |
| 732 | } |
| 733 | return Ret; |
| 734 | } |
| 735 | |
| 736 | void CtxProfAnalysis::collectIndirectCallPromotionList( |
| 737 | CallBase &IC, Result &Profile, |
| 738 | SetVector<std::pair<CallBase *, Function *>> &Candidates) { |
| 739 | const auto *Instr = CtxProfAnalysis::getCallsiteInstrumentation(CB&: IC); |
| 740 | if (!Instr) |
| 741 | return; |
| 742 | Module &M = *IC.getParent()->getModule(); |
| 743 | const uint32_t CallID = Instr->getIndex()->getZExtValue(); |
| 744 | Profile.visit( |
| 745 | V: [&](const PGOCtxProfContext &Ctx) { |
| 746 | const auto &Targets = Ctx.callsites().find(x: CallID); |
| 747 | if (Targets == Ctx.callsites().end()) |
| 748 | return; |
| 749 | for (const auto &[Guid, _] : Targets->second) |
| 750 | if (auto Name = Profile.getFunctionName(GUID: Guid); !Name.empty()) |
| 751 | if (auto *Target = M.getFunction(Name)) |
| 752 | if (Target->hasFnAttribute(Kind: Attribute::AlwaysInline)) |
| 753 | Candidates.insert(X: {&IC, Target}); |
| 754 | }, |
| 755 | F: IC.getCaller()); |
| 756 | } |
| 757 | |