1//===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Instrumentation-based profile-guided optimization
10//
11//===----------------------------------------------------------------------===//
12
13#include "CodeGenPGO.h"
14#include "CGDebugInfo.h"
15#include "CodeGenFunction.h"
16#include "CoverageMappingGen.h"
17#include "clang/AST/RecursiveASTVisitor.h"
18#include "clang/AST/StmtVisitor.h"
19#include "llvm/IR/Intrinsics.h"
20#include "llvm/IR/MDBuilder.h"
21#include "llvm/Support/CommandLine.h"
22#include "llvm/Support/Endian.h"
23#include "llvm/Support/MD5.h"
24#include <optional>
25
26namespace llvm {
27extern cl::opt<bool> EnableSingleByteCoverage;
28} // namespace llvm
29
30static llvm::cl::opt<bool>
31 EnableValueProfiling("enable-value-profiling",
32 llvm::cl::desc("Enable value profiling"),
33 llvm::cl::Hidden, llvm::cl::init(Val: false));
34
35using namespace clang;
36using namespace CodeGen;
37
38void CodeGenPGO::setFuncName(StringRef Name,
39 llvm::GlobalValue::LinkageTypes Linkage) {
40 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
41 FuncName = llvm::getPGOFuncName(
42 RawFuncName: Name, Linkage, FileName: CGM.getCodeGenOpts().MainFileName,
43 Version: PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
44
45 // If we're generating a profile, create a variable for the name.
46 if (CGM.getCodeGenOpts().hasProfileClangInstr())
47 FuncNameVar = llvm::createPGOFuncNameVar(M&: CGM.getModule(), Linkage, PGOFuncName: FuncName);
48}
49
50void CodeGenPGO::setFuncName(llvm::Function *Fn) {
51 setFuncName(Name: Fn->getName(), Linkage: Fn->getLinkage());
52 // Create PGOFuncName meta data.
53 llvm::createPGOFuncNameMetadata(F&: *Fn, PGOFuncName: FuncName);
54}
55
/// Versions of the PGO control-flow hash. Versions are ordered, so later
/// versions compare greater, and older ones stay selectable to match profiles
/// written with them.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1 = 0,
  PGO_HASH_V2 = 1,
  PGO_HASH_V3 = 2,

  // Must always alias the newest version above.
  PGO_HASH_LATEST = PGO_HASH_V3
};
65
66namespace {
67/// Stable hasher for PGO region counters.
68///
69/// PGOHash produces a stable hash of a given function's control flow.
70///
71/// Changing the output of this hash will invalidate all previously generated
72/// profiles -- i.e., don't do it.
73///
74/// \note When this hash does eventually change (years?), we still need to
75/// support old hashes. We'll need to pull in the version number from the
76/// profile data format and use the matching hash function.
77class PGOHash {
78 uint64_t Working;
79 unsigned Count;
80 PGOHashVersion HashVersion;
81 llvm::MD5 MD5;
82
83 static const int NumBitsPerType = 6;
84 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
85 static const unsigned TooBig = 1u << NumBitsPerType;
86
87public:
88 /// Hash values for AST nodes.
89 ///
90 /// Distinct values for AST nodes that have region counters attached.
91 ///
92 /// These values must be stable. All new members must be added at the end,
93 /// and no members should be removed. Changing the enumeration value for an
94 /// AST node will affect the hash of every function that contains that node.
95 enum HashType : unsigned char {
96 None = 0,
97 LabelStmt = 1,
98 WhileStmt,
99 DoStmt,
100 ForStmt,
101 CXXForRangeStmt,
102 ObjCForCollectionStmt,
103 SwitchStmt,
104 CaseStmt,
105 DefaultStmt,
106 IfStmt,
107 CXXTryStmt,
108 CXXCatchStmt,
109 ConditionalOperator,
110 BinaryOperatorLAnd,
111 BinaryOperatorLOr,
112 BinaryConditionalOperator,
113 // The preceding values are available with PGO_HASH_V1.
114
115 EndOfScope,
116 IfThenBranch,
117 IfElseBranch,
118 GotoStmt,
119 IndirectGotoStmt,
120 BreakStmt,
121 ContinueStmt,
122 ReturnStmt,
123 ThrowExpr,
124 UnaryOperatorLNot,
125 BinaryOperatorLT,
126 BinaryOperatorGT,
127 BinaryOperatorLE,
128 BinaryOperatorGE,
129 BinaryOperatorEQ,
130 BinaryOperatorNE,
131 // The preceding values are available since PGO_HASH_V2.
132
133 // Keep this last. It's for the static assert that follows.
134 LastHashType
135 };
136 static_assert(LastHashType <= TooBig, "Too many types in HashType");
137
138 PGOHash(PGOHashVersion HashVersion)
139 : Working(0), Count(0), HashVersion(HashVersion) {}
140 void combine(HashType Type);
141 uint64_t finalize();
142 PGOHashVersion getHashVersion() const { return HashVersion; }
143};
144const int PGOHash::NumBitsPerType;
145const unsigned PGOHash::NumTypesPerWord;
146const unsigned PGOHash::TooBig;
147
148/// Get the PGO hash version used in the given indexed profile.
149static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
150 CodeGenModule &CGM) {
151 if (PGOReader->getVersion() <= 4)
152 return PGO_HASH_V1;
153 if (PGOReader->getVersion() <= 5)
154 return PGO_HASH_V2;
155 return PGO_HASH_V3;
156}
157
158/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
159struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
160 using Base = RecursiveASTVisitor<MapRegionCounters>;
161
162 /// The next counter value to assign.
163 unsigned NextCounter;
164 /// The function hash.
165 PGOHash Hash;
166 /// The map of statements to counters.
167 llvm::DenseMap<const Stmt *, CounterPair> &CounterMap;
168 /// The state of MC/DC Coverage in this function.
169 MCDC::State &MCDCState;
170 /// Maximum number of supported MC/DC conditions in a boolean expression.
171 unsigned MCDCMaxCond;
172 /// The profile version.
173 uint64_t ProfileVersion;
174 /// Diagnostics Engine used to report warnings.
175 DiagnosticsEngine &Diag;
176
177 MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
178 llvm::DenseMap<const Stmt *, CounterPair> &CounterMap,
179 MCDC::State &MCDCState, unsigned MCDCMaxCond,
180 DiagnosticsEngine &Diag)
181 : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
182 MCDCState(MCDCState), MCDCMaxCond(MCDCMaxCond),
183 ProfileVersion(ProfileVersion), Diag(Diag) {}
184
185 // Blocks and lambdas are handled as separate functions, so we need not
186 // traverse them in the parent context.
187 bool TraverseBlockExpr(BlockExpr *BE) { return true; }
188 bool TraverseLambdaExpr(LambdaExpr *LE) {
189 // Traverse the captures, but not the body.
190 for (auto C : zip(t: LE->captures(), u: LE->capture_inits()))
191 TraverseLambdaCapture(LE, C: &std::get<0>(t&: C), Init: std::get<1>(t&: C));
192 return true;
193 }
194 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
195
196 bool VisitDecl(const Decl *D) {
197 switch (D->getKind()) {
198 default:
199 break;
200 case Decl::Function:
201 case Decl::CXXMethod:
202 case Decl::CXXConstructor:
203 case Decl::CXXDestructor:
204 case Decl::CXXConversion:
205 case Decl::ObjCMethod:
206 case Decl::Block:
207 case Decl::Captured:
208 CounterMap[D->getBody()] = NextCounter++;
209 break;
210 }
211 return true;
212 }
213
214 /// If \p S gets a fresh counter, update the counter mappings. Return the
215 /// V1 hash of \p S.
216 PGOHash::HashType updateCounterMappings(Stmt *S) {
217 auto Type = getHashType(HashVersion: PGO_HASH_V1, S);
218 if (Type != PGOHash::None)
219 CounterMap[S] = NextCounter++;
220 return Type;
221 }
222
223 /// The following stacks are used with dataTraverseStmtPre() and
224 /// dataTraverseStmtPost() to track the depth of nested logical operators in a
225 /// boolean expression in a function. The ultimate purpose is to keep track
226 /// of the number of leaf-level conditions in the boolean expression so that a
227 /// profile bitmap can be allocated based on that number.
228 ///
229 /// The stacks are also used to find error cases and notify the user. A
230 /// standard logical operator nest for a boolean expression could be in a form
231 /// similar to this: "x = a && b && c && (d || f)"
232 unsigned NumCond = 0;
233 bool SplitNestedLogicalOp = false;
234 SmallVector<const Stmt *, 16> NonLogOpStack;
235 SmallVector<const BinaryOperator *, 16> LogOpStack;
236
237 // Hook: dataTraverseStmtPre() is invoked prior to visiting an AST Stmt node.
238 bool dataTraverseStmtPre(Stmt *S) {
239 /// If MC/DC is not enabled, MCDCMaxCond will be set to 0. Do nothing.
240 if (MCDCMaxCond == 0)
241 return true;
242
243 /// At the top of the logical operator nest, reset the number of conditions,
244 /// also forget previously seen split nesting cases.
245 if (LogOpStack.empty()) {
246 NumCond = 0;
247 SplitNestedLogicalOp = false;
248 }
249
250 if (const Expr *E = dyn_cast<Expr>(Val: S)) {
251 const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Val: E->IgnoreParens());
252 if (BinOp && BinOp->isLogicalOp()) {
253 /// Check for "split-nested" logical operators. This happens when a new
254 /// boolean expression logical-op nest is encountered within an existing
255 /// boolean expression, separated by a non-logical operator. For
256 /// example, in "x = (a && b && c && foo(d && f))", the "d && f" case
257 /// starts a new boolean expression that is separated from the other
258 /// conditions by the operator foo(). Split-nested cases are not
259 /// supported by MC/DC.
260 SplitNestedLogicalOp = SplitNestedLogicalOp || !NonLogOpStack.empty();
261
262 LogOpStack.push_back(Elt: BinOp);
263 return true;
264 }
265 }
266
267 /// Keep track of non-logical operators. These are OK as long as we don't
268 /// encounter a new logical operator after seeing one.
269 if (!LogOpStack.empty())
270 NonLogOpStack.push_back(Elt: S);
271
272 return true;
273 }
274
275 // Hook: dataTraverseStmtPost() is invoked by the AST visitor after visiting
276 // an AST Stmt node. MC/DC will use it to to signal when the top of a
277 // logical operation (boolean expression) nest is encountered.
278 bool dataTraverseStmtPost(Stmt *S) {
279 /// If MC/DC is not enabled, MCDCMaxCond will be set to 0. Do nothing.
280 if (MCDCMaxCond == 0)
281 return true;
282
283 if (const Expr *E = dyn_cast<Expr>(Val: S)) {
284 const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Val: E->IgnoreParens());
285 if (BinOp && BinOp->isLogicalOp()) {
286 assert(LogOpStack.back() == BinOp);
287 LogOpStack.pop_back();
288
289 /// At the top of logical operator nest:
290 if (LogOpStack.empty()) {
291 /// Was the "split-nested" logical operator case encountered?
292 if (SplitNestedLogicalOp) {
293 unsigned DiagID = Diag.getCustomDiagID(
294 L: DiagnosticsEngine::Warning,
295 FormatString: "unsupported MC/DC boolean expression; "
296 "contains an operation with a nested boolean expression. "
297 "Expression will not be covered");
298 Diag.Report(Loc: S->getBeginLoc(), DiagID);
299 return true;
300 }
301
302 /// Was the maximum number of conditions encountered?
303 if (NumCond > MCDCMaxCond) {
304 unsigned DiagID = Diag.getCustomDiagID(
305 L: DiagnosticsEngine::Warning,
306 FormatString: "unsupported MC/DC boolean expression; "
307 "number of conditions (%0) exceeds max (%1). "
308 "Expression will not be covered");
309 Diag.Report(Loc: S->getBeginLoc(), DiagID) << NumCond << MCDCMaxCond;
310 return true;
311 }
312
313 // Otherwise, allocate the Decision.
314 MCDCState.DecisionByStmt[BinOp].BitmapIdx = 0;
315 }
316 return true;
317 }
318 }
319
320 if (!LogOpStack.empty())
321 NonLogOpStack.pop_back();
322
323 return true;
324 }
325
326 /// The RHS of all logical operators gets a fresh counter in order to count
327 /// how many times the RHS evaluates to true or false, depending on the
328 /// semantics of the operator. This is only valid for ">= v7" of the profile
329 /// version so that we facilitate backward compatibility. In addition, in
330 /// order to use MC/DC, count the number of total LHS and RHS conditions.
331 bool VisitBinaryOperator(BinaryOperator *S) {
332 if (S->isLogicalOp()) {
333 if (CodeGenFunction::isInstrumentedCondition(C: S->getLHS()))
334 NumCond++;
335
336 if (CodeGenFunction::isInstrumentedCondition(C: S->getRHS())) {
337 if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
338 CounterMap[S->getRHS()] = NextCounter++;
339
340 NumCond++;
341 }
342 }
343 return Base::VisitBinaryOperator(S);
344 }
345
346 bool VisitConditionalOperator(ConditionalOperator *S) {
347 if (llvm::EnableSingleByteCoverage && S->getTrueExpr())
348 CounterMap[S->getTrueExpr()] = NextCounter++;
349 if (llvm::EnableSingleByteCoverage && S->getFalseExpr())
350 CounterMap[S->getFalseExpr()] = NextCounter++;
351 return Base::VisitConditionalOperator(S);
352 }
353
354 /// Include \p S in the function hash.
355 bool VisitStmt(Stmt *S) {
356 auto Type = updateCounterMappings(S);
357 if (Hash.getHashVersion() != PGO_HASH_V1)
358 Type = getHashType(HashVersion: Hash.getHashVersion(), S);
359 if (Type != PGOHash::None)
360 Hash.combine(Type);
361 return true;
362 }
363
364 bool TraverseIfStmt(IfStmt *If) {
365 // If we used the V1 hash, use the default traversal.
366 if (Hash.getHashVersion() == PGO_HASH_V1)
367 return Base::TraverseIfStmt(S: If);
368
369 // When single byte coverage mode is enabled, add a counter to then and
370 // else.
371 bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
372 for (Stmt *CS : If->children()) {
373 if (!CS || NoSingleByteCoverage)
374 continue;
375 if (CS == If->getThen())
376 CounterMap[If->getThen()] = NextCounter++;
377 else if (CS == If->getElse())
378 CounterMap[If->getElse()] = NextCounter++;
379 }
380
381 // Otherwise, keep track of which branch we're in while traversing.
382 VisitStmt(S: If);
383
384 for (Stmt *CS : If->children()) {
385 if (!CS)
386 continue;
387 if (CS == If->getThen())
388 Hash.combine(Type: PGOHash::IfThenBranch);
389 else if (CS == If->getElse())
390 Hash.combine(Type: PGOHash::IfElseBranch);
391 TraverseStmt(S: CS);
392 }
393 Hash.combine(Type: PGOHash::EndOfScope);
394 return true;
395 }
396
397 bool TraverseWhileStmt(WhileStmt *While) {
398 // When single byte coverage mode is enabled, add a counter to condition and
399 // body.
400 bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
401 for (Stmt *CS : While->children()) {
402 if (!CS || NoSingleByteCoverage)
403 continue;
404 if (CS == While->getCond())
405 CounterMap[While->getCond()] = NextCounter++;
406 else if (CS == While->getBody())
407 CounterMap[While->getBody()] = NextCounter++;
408 }
409
410 Base::TraverseWhileStmt(S: While);
411 if (Hash.getHashVersion() != PGO_HASH_V1)
412 Hash.combine(Type: PGOHash::EndOfScope);
413 return true;
414 }
415
416 bool TraverseDoStmt(DoStmt *Do) {
417 // When single byte coverage mode is enabled, add a counter to condition and
418 // body.
419 bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
420 for (Stmt *CS : Do->children()) {
421 if (!CS || NoSingleByteCoverage)
422 continue;
423 if (CS == Do->getCond())
424 CounterMap[Do->getCond()] = NextCounter++;
425 else if (CS == Do->getBody())
426 CounterMap[Do->getBody()] = NextCounter++;
427 }
428
429 Base::TraverseDoStmt(S: Do);
430 if (Hash.getHashVersion() != PGO_HASH_V1)
431 Hash.combine(Type: PGOHash::EndOfScope);
432 return true;
433 }
434
435 bool TraverseForStmt(ForStmt *For) {
436 // When single byte coverage mode is enabled, add a counter to condition,
437 // increment and body.
438 bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
439 for (Stmt *CS : For->children()) {
440 if (!CS || NoSingleByteCoverage)
441 continue;
442 if (CS == For->getCond())
443 CounterMap[For->getCond()] = NextCounter++;
444 else if (CS == For->getInc())
445 CounterMap[For->getInc()] = NextCounter++;
446 else if (CS == For->getBody())
447 CounterMap[For->getBody()] = NextCounter++;
448 }
449
450 Base::TraverseForStmt(S: For);
451 if (Hash.getHashVersion() != PGO_HASH_V1)
452 Hash.combine(Type: PGOHash::EndOfScope);
453 return true;
454 }
455
456 bool TraverseCXXForRangeStmt(CXXForRangeStmt *ForRange) {
457 // When single byte coverage mode is enabled, add a counter to body.
458 bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
459 for (Stmt *CS : ForRange->children()) {
460 if (!CS || NoSingleByteCoverage)
461 continue;
462 if (CS == ForRange->getBody())
463 CounterMap[ForRange->getBody()] = NextCounter++;
464 }
465
466 Base::TraverseCXXForRangeStmt(S: ForRange);
467 if (Hash.getHashVersion() != PGO_HASH_V1)
468 Hash.combine(Type: PGOHash::EndOfScope);
469 return true;
470 }
471
472// If the statement type \p N is nestable, and its nesting impacts profile
473// stability, define a custom traversal which tracks the end of the statement
474// in the hash (provided we're not using the V1 hash).
475#define DEFINE_NESTABLE_TRAVERSAL(N) \
476 bool Traverse##N(N *S) { \
477 Base::Traverse##N(S); \
478 if (Hash.getHashVersion() != PGO_HASH_V1) \
479 Hash.combine(PGOHash::EndOfScope); \
480 return true; \
481 }
482
483 DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
484 DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
485 DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
486
487 /// Get version \p HashVersion of the PGO hash for \p S.
488 PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
489 switch (S->getStmtClass()) {
490 default:
491 break;
492 case Stmt::LabelStmtClass:
493 return PGOHash::LabelStmt;
494 case Stmt::WhileStmtClass:
495 return PGOHash::WhileStmt;
496 case Stmt::DoStmtClass:
497 return PGOHash::DoStmt;
498 case Stmt::ForStmtClass:
499 return PGOHash::ForStmt;
500 case Stmt::CXXForRangeStmtClass:
501 return PGOHash::CXXForRangeStmt;
502 case Stmt::ObjCForCollectionStmtClass:
503 return PGOHash::ObjCForCollectionStmt;
504 case Stmt::SwitchStmtClass:
505 return PGOHash::SwitchStmt;
506 case Stmt::CaseStmtClass:
507 return PGOHash::CaseStmt;
508 case Stmt::DefaultStmtClass:
509 return PGOHash::DefaultStmt;
510 case Stmt::IfStmtClass:
511 return PGOHash::IfStmt;
512 case Stmt::CXXTryStmtClass:
513 return PGOHash::CXXTryStmt;
514 case Stmt::CXXCatchStmtClass:
515 return PGOHash::CXXCatchStmt;
516 case Stmt::ConditionalOperatorClass:
517 return PGOHash::ConditionalOperator;
518 case Stmt::BinaryConditionalOperatorClass:
519 return PGOHash::BinaryConditionalOperator;
520 case Stmt::BinaryOperatorClass: {
521 const BinaryOperator *BO = cast<BinaryOperator>(Val: S);
522 if (BO->getOpcode() == BO_LAnd)
523 return PGOHash::BinaryOperatorLAnd;
524 if (BO->getOpcode() == BO_LOr)
525 return PGOHash::BinaryOperatorLOr;
526 if (HashVersion >= PGO_HASH_V2) {
527 switch (BO->getOpcode()) {
528 default:
529 break;
530 case BO_LT:
531 return PGOHash::BinaryOperatorLT;
532 case BO_GT:
533 return PGOHash::BinaryOperatorGT;
534 case BO_LE:
535 return PGOHash::BinaryOperatorLE;
536 case BO_GE:
537 return PGOHash::BinaryOperatorGE;
538 case BO_EQ:
539 return PGOHash::BinaryOperatorEQ;
540 case BO_NE:
541 return PGOHash::BinaryOperatorNE;
542 }
543 }
544 break;
545 }
546 }
547
548 if (HashVersion >= PGO_HASH_V2) {
549 switch (S->getStmtClass()) {
550 default:
551 break;
552 case Stmt::GotoStmtClass:
553 return PGOHash::GotoStmt;
554 case Stmt::IndirectGotoStmtClass:
555 return PGOHash::IndirectGotoStmt;
556 case Stmt::BreakStmtClass:
557 return PGOHash::BreakStmt;
558 case Stmt::ContinueStmtClass:
559 return PGOHash::ContinueStmt;
560 case Stmt::ReturnStmtClass:
561 return PGOHash::ReturnStmt;
562 case Stmt::CXXThrowExprClass:
563 return PGOHash::ThrowExpr;
564 case Stmt::UnaryOperatorClass: {
565 const UnaryOperator *UO = cast<UnaryOperator>(Val: S);
566 if (UO->getOpcode() == UO_LNot)
567 return PGOHash::UnaryOperatorLNot;
568 break;
569 }
570 }
571 }
572
573 return PGOHash::None;
574 }
575};
576
577/// A StmtVisitor that propagates the raw counts through the AST and
578/// records the count at statements where the value may change.
579struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
580 /// PGO state.
581 CodeGenPGO &PGO;
582
583 /// A flag that is set when the current count should be recorded on the
584 /// next statement, such as at the exit of a loop.
585 bool RecordNextStmtCount;
586
587 /// The count at the current location in the traversal.
588 uint64_t CurrentCount;
589
590 /// The map of statements to count values.
591 llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
592
593 /// BreakContinueStack - Keep counts of breaks and continues inside loops.
594 struct BreakContinue {
595 uint64_t BreakCount = 0;
596 uint64_t ContinueCount = 0;
597 BreakContinue() = default;
598 };
599 SmallVector<BreakContinue, 8> BreakContinueStack;
600
601 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
602 CodeGenPGO &PGO)
603 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
604
605 void RecordStmtCount(const Stmt *S) {
606 if (RecordNextStmtCount) {
607 CountMap[S] = CurrentCount;
608 RecordNextStmtCount = false;
609 }
610 }
611
612 /// Set and return the current count.
613 uint64_t setCount(uint64_t Count) {
614 CurrentCount = Count;
615 return Count;
616 }
617
618 void VisitStmt(const Stmt *S) {
619 RecordStmtCount(S);
620 for (const Stmt *Child : S->children())
621 if (Child)
622 this->Visit(S: Child);
623 }
624
625 void VisitFunctionDecl(const FunctionDecl *D) {
626 // Counter tracks entry to the function body.
627 uint64_t BodyCount = setCount(PGO.getRegionCount(S: D->getBody()));
628 CountMap[D->getBody()] = BodyCount;
629 Visit(S: D->getBody());
630 }
631
632 // Skip lambda expressions. We visit these as FunctionDecls when we're
633 // generating them and aren't interested in the body when generating a
634 // parent context.
635 void VisitLambdaExpr(const LambdaExpr *LE) {}
636
637 void VisitCapturedDecl(const CapturedDecl *D) {
638 // Counter tracks entry to the capture body.
639 uint64_t BodyCount = setCount(PGO.getRegionCount(S: D->getBody()));
640 CountMap[D->getBody()] = BodyCount;
641 Visit(S: D->getBody());
642 }
643
644 void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
645 // Counter tracks entry to the method body.
646 uint64_t BodyCount = setCount(PGO.getRegionCount(S: D->getBody()));
647 CountMap[D->getBody()] = BodyCount;
648 Visit(S: D->getBody());
649 }
650
651 void VisitBlockDecl(const BlockDecl *D) {
652 // Counter tracks entry to the block body.
653 uint64_t BodyCount = setCount(PGO.getRegionCount(S: D->getBody()));
654 CountMap[D->getBody()] = BodyCount;
655 Visit(S: D->getBody());
656 }
657
658 void VisitReturnStmt(const ReturnStmt *S) {
659 RecordStmtCount(S);
660 if (S->getRetValue())
661 Visit(S: S->getRetValue());
662 CurrentCount = 0;
663 RecordNextStmtCount = true;
664 }
665
666 void VisitCXXThrowExpr(const CXXThrowExpr *E) {
667 RecordStmtCount(S: E);
668 if (E->getSubExpr())
669 Visit(S: E->getSubExpr());
670 CurrentCount = 0;
671 RecordNextStmtCount = true;
672 }
673
674 void VisitGotoStmt(const GotoStmt *S) {
675 RecordStmtCount(S);
676 CurrentCount = 0;
677 RecordNextStmtCount = true;
678 }
679
680 void VisitLabelStmt(const LabelStmt *S) {
681 RecordNextStmtCount = false;
682 // Counter tracks the block following the label.
683 uint64_t BlockCount = setCount(PGO.getRegionCount(S));
684 CountMap[S] = BlockCount;
685 Visit(S: S->getSubStmt());
686 }
687
688 void VisitBreakStmt(const BreakStmt *S) {
689 RecordStmtCount(S);
690 assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
691 BreakContinueStack.back().BreakCount += CurrentCount;
692 CurrentCount = 0;
693 RecordNextStmtCount = true;
694 }
695
696 void VisitContinueStmt(const ContinueStmt *S) {
697 RecordStmtCount(S);
698 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
699 BreakContinueStack.back().ContinueCount += CurrentCount;
700 CurrentCount = 0;
701 RecordNextStmtCount = true;
702 }
703
704 void VisitWhileStmt(const WhileStmt *S) {
705 RecordStmtCount(S);
706 uint64_t ParentCount = CurrentCount;
707
708 BreakContinueStack.push_back(Elt: BreakContinue());
709 // Visit the body region first so the break/continue adjustments can be
710 // included when visiting the condition.
711 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
712 CountMap[S->getBody()] = CurrentCount;
713 Visit(S: S->getBody());
714 uint64_t BackedgeCount = CurrentCount;
715
716 // ...then go back and propagate counts through the condition. The count
717 // at the start of the condition is the sum of the incoming edges,
718 // the backedge from the end of the loop body, and the edges from
719 // continue statements.
720 BreakContinue BC = BreakContinueStack.pop_back_val();
721 uint64_t CondCount =
722 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
723 CountMap[S->getCond()] = CondCount;
724 Visit(S: S->getCond());
725 setCount(BC.BreakCount + CondCount - BodyCount);
726 RecordNextStmtCount = true;
727 }
728
729 void VisitDoStmt(const DoStmt *S) {
730 RecordStmtCount(S);
731 uint64_t LoopCount = PGO.getRegionCount(S);
732
733 BreakContinueStack.push_back(Elt: BreakContinue());
734 // The count doesn't include the fallthrough from the parent scope. Add it.
735 uint64_t BodyCount = setCount(LoopCount + CurrentCount);
736 CountMap[S->getBody()] = BodyCount;
737 Visit(S: S->getBody());
738 uint64_t BackedgeCount = CurrentCount;
739
740 BreakContinue BC = BreakContinueStack.pop_back_val();
741 // The count at the start of the condition is equal to the count at the
742 // end of the body, plus any continues.
743 uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
744 CountMap[S->getCond()] = CondCount;
745 Visit(S: S->getCond());
746 setCount(BC.BreakCount + CondCount - LoopCount);
747 RecordNextStmtCount = true;
748 }
749
750 void VisitForStmt(const ForStmt *S) {
751 RecordStmtCount(S);
752 if (S->getInit())
753 Visit(S: S->getInit());
754
755 uint64_t ParentCount = CurrentCount;
756
757 BreakContinueStack.push_back(Elt: BreakContinue());
758 // Visit the body region first. (This is basically the same as a while
759 // loop; see further comments in VisitWhileStmt.)
760 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
761 CountMap[S->getBody()] = BodyCount;
762 Visit(S: S->getBody());
763 uint64_t BackedgeCount = CurrentCount;
764 BreakContinue BC = BreakContinueStack.pop_back_val();
765
766 // The increment is essentially part of the body but it needs to include
767 // the count for all the continue statements.
768 if (S->getInc()) {
769 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
770 CountMap[S->getInc()] = IncCount;
771 Visit(S: S->getInc());
772 }
773
774 // ...then go back and propagate counts through the condition.
775 uint64_t CondCount =
776 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
777 if (S->getCond()) {
778 CountMap[S->getCond()] = CondCount;
779 Visit(S: S->getCond());
780 }
781 setCount(BC.BreakCount + CondCount - BodyCount);
782 RecordNextStmtCount = true;
783 }
784
785 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
786 RecordStmtCount(S);
787 if (S->getInit())
788 Visit(S: S->getInit());
789 Visit(S: S->getLoopVarStmt());
790 Visit(S: S->getRangeStmt());
791 Visit(S: S->getBeginStmt());
792 Visit(S: S->getEndStmt());
793
794 uint64_t ParentCount = CurrentCount;
795 BreakContinueStack.push_back(Elt: BreakContinue());
796 // Visit the body region first. (This is basically the same as a while
797 // loop; see further comments in VisitWhileStmt.)
798 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
799 CountMap[S->getBody()] = BodyCount;
800 Visit(S: S->getBody());
801 uint64_t BackedgeCount = CurrentCount;
802 BreakContinue BC = BreakContinueStack.pop_back_val();
803
804 // The increment is essentially part of the body but it needs to include
805 // the count for all the continue statements.
806 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
807 CountMap[S->getInc()] = IncCount;
808 Visit(S: S->getInc());
809
810 // ...then go back and propagate counts through the condition.
811 uint64_t CondCount =
812 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
813 CountMap[S->getCond()] = CondCount;
814 Visit(S: S->getCond());
815 setCount(BC.BreakCount + CondCount - BodyCount);
816 RecordNextStmtCount = true;
817 }
818
819 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
820 RecordStmtCount(S);
821 Visit(S: S->getElement());
822 uint64_t ParentCount = CurrentCount;
823 BreakContinueStack.push_back(Elt: BreakContinue());
824 // Counter tracks the body of the loop.
825 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
826 CountMap[S->getBody()] = BodyCount;
827 Visit(S: S->getBody());
828 uint64_t BackedgeCount = CurrentCount;
829 BreakContinue BC = BreakContinueStack.pop_back_val();
830
831 setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
832 BodyCount);
833 RecordNextStmtCount = true;
834 }
835
836 void VisitSwitchStmt(const SwitchStmt *S) {
837 RecordStmtCount(S);
838 if (S->getInit())
839 Visit(S: S->getInit());
840 Visit(S: S->getCond());
841 CurrentCount = 0;
842 BreakContinueStack.push_back(Elt: BreakContinue());
843 Visit(S: S->getBody());
844 // If the switch is inside a loop, add the continue counts.
845 BreakContinue BC = BreakContinueStack.pop_back_val();
846 if (!BreakContinueStack.empty())
847 BreakContinueStack.back().ContinueCount += BC.ContinueCount;
848 // Counter tracks the exit block of the switch.
849 setCount(PGO.getRegionCount(S));
850 RecordNextStmtCount = true;
851 }
852
853 void VisitSwitchCase(const SwitchCase *S) {
854 RecordNextStmtCount = false;
855 // Counter for this particular case. This counts only jumps from the
856 // switch header and does not include fallthrough from the case before
857 // this one.
858 uint64_t CaseCount = PGO.getRegionCount(S);
859 setCount(CurrentCount + CaseCount);
860 // We need the count without fallthrough in the mapping, so it's more useful
861 // for branch probabilities.
862 CountMap[S] = CaseCount;
863 RecordNextStmtCount = true;
864 Visit(S: S->getSubStmt());
865 }
866
867 void VisitIfStmt(const IfStmt *S) {
868 RecordStmtCount(S);
869
870 if (S->isConsteval()) {
871 const Stmt *Stm = S->isNegatedConsteval() ? S->getThen() : S->getElse();
872 if (Stm)
873 Visit(S: Stm);
874 return;
875 }
876
877 uint64_t ParentCount = CurrentCount;
878 if (S->getInit())
879 Visit(S: S->getInit());
880 Visit(S: S->getCond());
881
882 // Counter tracks the "then" part of an if statement. The count for
883 // the "else" part, if it exists, will be calculated from this counter.
884 uint64_t ThenCount = setCount(PGO.getRegionCount(S));
885 CountMap[S->getThen()] = ThenCount;
886 Visit(S: S->getThen());
887 uint64_t OutCount = CurrentCount;
888
889 uint64_t ElseCount = ParentCount - ThenCount;
890 if (S->getElse()) {
891 setCount(ElseCount);
892 CountMap[S->getElse()] = ElseCount;
893 Visit(S: S->getElse());
894 OutCount += CurrentCount;
895 } else
896 OutCount += ElseCount;
897 setCount(OutCount);
898 RecordNextStmtCount = true;
899 }
900
901 void VisitCXXTryStmt(const CXXTryStmt *S) {
902 RecordStmtCount(S);
903 Visit(S: S->getTryBlock());
904 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
905 Visit(S: S->getHandler(i: I));
906 // Counter tracks the continuation block of the try statement.
907 setCount(PGO.getRegionCount(S));
908 RecordNextStmtCount = true;
909 }
910
911 void VisitCXXCatchStmt(const CXXCatchStmt *S) {
912 RecordNextStmtCount = false;
913 // Counter tracks the catch statement's handler block.
914 uint64_t CatchCount = setCount(PGO.getRegionCount(S));
915 CountMap[S] = CatchCount;
916 Visit(S: S->getHandlerBlock());
917 }
918
919 void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
920 RecordStmtCount(S: E);
921 uint64_t ParentCount = CurrentCount;
922 Visit(S: E->getCond());
923
924 // Counter tracks the "true" part of a conditional operator. The
925 // count in the "false" part will be calculated from this counter.
926 uint64_t TrueCount = setCount(PGO.getRegionCount(S: E));
927 CountMap[E->getTrueExpr()] = TrueCount;
928 Visit(S: E->getTrueExpr());
929 uint64_t OutCount = CurrentCount;
930
931 uint64_t FalseCount = setCount(ParentCount - TrueCount);
932 CountMap[E->getFalseExpr()] = FalseCount;
933 Visit(S: E->getFalseExpr());
934 OutCount += CurrentCount;
935
936 setCount(OutCount);
937 RecordNextStmtCount = true;
938 }
939
940 void VisitBinLAnd(const BinaryOperator *E) {
941 RecordStmtCount(S: E);
942 uint64_t ParentCount = CurrentCount;
943 Visit(S: E->getLHS());
944 // Counter tracks the right hand side of a logical and operator.
945 uint64_t RHSCount = setCount(PGO.getRegionCount(S: E));
946 CountMap[E->getRHS()] = RHSCount;
947 Visit(S: E->getRHS());
948 setCount(ParentCount + RHSCount - CurrentCount);
949 RecordNextStmtCount = true;
950 }
951
952 void VisitBinLOr(const BinaryOperator *E) {
953 RecordStmtCount(S: E);
954 uint64_t ParentCount = CurrentCount;
955 Visit(S: E->getLHS());
956 // Counter tracks the right hand side of a logical or operator.
957 uint64_t RHSCount = setCount(PGO.getRegionCount(S: E));
958 CountMap[E->getRHS()] = RHSCount;
959 Visit(S: E->getRHS());
960 setCount(ParentCount + RHSCount - CurrentCount);
961 RecordNextStmtCount = true;
962 }
963};
964} // end anonymous namespace
965
966void PGOHash::combine(HashType Type) {
967 // Check that we never combine 0 and only have six bits.
968 assert(Type && "Hash is invalid: unexpected type 0");
969 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
970
971 // Pass through MD5 if enough work has built up.
972 if (Count && Count % NumTypesPerWord == 0) {
973 using namespace llvm::support;
974 uint64_t Swapped =
975 endian::byte_swap<uint64_t, llvm::endianness::little>(value: Working);
976 MD5.update(Data: llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
977 Working = 0;
978 }
979
980 // Accumulate the current type.
981 ++Count;
982 Working = Working << NumBitsPerType | Type;
983}
984
985uint64_t PGOHash::finalize() {
986 // Use Working as the hash directly if we never used MD5.
987 if (Count <= NumTypesPerWord)
988 // No need to byte swap here, since none of the math was endian-dependent.
989 // This number will be byte-swapped as required on endianness transitions,
990 // so we will see the same value on the other side.
991 return Working;
992
993 // Check for remaining work in Working.
994 if (Working) {
995 // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
996 // is buggy because it converts a uint64_t into an array of uint8_t.
997 if (HashVersion < PGO_HASH_V3) {
998 MD5.update(Data: {(uint8_t)Working});
999 } else {
1000 using namespace llvm::support;
1001 uint64_t Swapped =
1002 endian::byte_swap<uint64_t, llvm::endianness::little>(value: Working);
1003 MD5.update(Data: llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
1004 }
1005 }
1006
1007 // Finalize the MD5 and return the hash.
1008 llvm::MD5::MD5Result Result;
1009 MD5.final(Result);
1010 return Result.low();
1011}
1012
1013void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
1014 const Decl *D = GD.getDecl();
1015 if (!D->hasBody())
1016 return;
1017
1018 // Skip CUDA/HIP kernel launch stub functions.
1019 if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
1020 D->hasAttr<CUDAGlobalAttr>())
1021 return;
1022
1023 bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
1024 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
1025 if (!InstrumentRegions && !PGOReader)
1026 return;
1027 if (D->isImplicit())
1028 return;
1029 // Constructors and destructors may be represented by several functions in IR.
1030 // If so, instrument only base variant, others are implemented by delegation
1031 // to the base one, it would be counted twice otherwise.
1032 if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
1033 if (const auto *CCD = dyn_cast<CXXConstructorDecl>(Val: D))
1034 if (GD.getCtorType() != Ctor_Base &&
1035 CodeGenFunction::IsConstructorDelegationValid(Ctor: CCD))
1036 return;
1037 }
1038 if (isa<CXXDestructorDecl>(Val: D) && GD.getDtorType() != Dtor_Base)
1039 return;
1040
1041 CGM.ClearUnusedCoverageMapping(D);
1042 if (Fn->hasFnAttribute(Kind: llvm::Attribute::NoProfile))
1043 return;
1044 if (Fn->hasFnAttribute(Kind: llvm::Attribute::SkipProfile))
1045 return;
1046
1047 SourceManager &SM = CGM.getContext().getSourceManager();
1048 if (!llvm::coverage::SystemHeadersCoverage &&
1049 SM.isInSystemHeader(Loc: D->getLocation()))
1050 return;
1051
1052 setFuncName(Fn);
1053
1054 mapRegionCounters(D);
1055 if (CGM.getCodeGenOpts().CoverageMapping)
1056 emitCounterRegionMapping(D);
1057 if (PGOReader) {
1058 loadRegionCounts(PGOReader, IsInMainFile: SM.isInMainFile(Loc: D->getLocation()));
1059 computeRegionCounts(D);
1060 applyFunctionAttributes(PGOReader, Fn);
1061 }
1062}
1063
1064void CodeGenPGO::mapRegionCounters(const Decl *D) {
1065 // Use the latest hash version when inserting instrumentation, but use the
1066 // version in the indexed profile if we're reading PGO data.
1067 PGOHashVersion HashVersion = PGO_HASH_LATEST;
1068 uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
1069 if (auto *PGOReader = CGM.getPGOReader()) {
1070 HashVersion = getPGOHashVersion(PGOReader, CGM);
1071 ProfileVersion = PGOReader->getVersion();
1072 }
1073
1074 // If MC/DC is enabled, set the MaxConditions to a preset value. Otherwise,
1075 // set it to zero. This value impacts the number of conditions accepted in a
1076 // given boolean expression, which impacts the size of the bitmap used to
1077 // track test vector execution for that boolean expression. Because the
1078 // bitmap scales exponentially (2^n) based on the number of conditions seen,
1079 // the maximum value is hard-coded at 6 conditions, which is more than enough
1080 // for most embedded applications. Setting a maximum value prevents the
1081 // bitmap footprint from growing too large without the user's knowledge. In
1082 // the future, this value could be adjusted with a command-line option.
1083 unsigned MCDCMaxConditions =
1084 (CGM.getCodeGenOpts().MCDCCoverage ? CGM.getCodeGenOpts().MCDCMaxConds
1085 : 0);
1086
1087 RegionCounterMap.reset(p: new llvm::DenseMap<const Stmt *, CounterPair>);
1088 RegionMCDCState.reset(p: new MCDC::State);
1089 MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap,
1090 *RegionMCDCState, MCDCMaxConditions, CGM.getDiags());
1091 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(Val: D))
1092 Walker.TraverseDecl(D: const_cast<FunctionDecl *>(FD));
1093 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(Val: D))
1094 Walker.TraverseDecl(D: const_cast<ObjCMethodDecl *>(MD));
1095 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(Val: D))
1096 Walker.TraverseDecl(D: const_cast<BlockDecl *>(BD));
1097 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(Val: D))
1098 Walker.TraverseDecl(D: const_cast<CapturedDecl *>(CD));
1099 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
1100 NumRegionCounters = Walker.NextCounter;
1101 FunctionHash = Walker.Hash.finalize();
1102}
1103
1104bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
1105 if (!D->getBody())
1106 return true;
1107
1108 // Skip host-only functions in the CUDA device compilation and device-only
1109 // functions in the host compilation. Just roughly filter them out based on
1110 // the function attributes. If there are effectively host-only or device-only
1111 // ones, their coverage mapping may still be generated.
1112 if (CGM.getLangOpts().CUDA &&
1113 ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
1114 !D->hasAttr<CUDAGlobalAttr>()) ||
1115 (!CGM.getLangOpts().CUDAIsDevice &&
1116 (D->hasAttr<CUDAGlobalAttr>() ||
1117 (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
1118 return true;
1119
1120 // Don't map the functions in system headers.
1121 const auto &SM = CGM.getContext().getSourceManager();
1122 auto Loc = D->getBody()->getBeginLoc();
1123 return !llvm::coverage::SystemHeadersCoverage && SM.isInSystemHeader(Loc);
1124}
1125
1126void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
1127 if (skipRegionMappingForDecl(D))
1128 return;
1129
1130 std::string CoverageMapping;
1131 llvm::raw_string_ostream OS(CoverageMapping);
1132 RegionMCDCState->BranchByStmt.clear();
1133 CoverageMappingGen MappingGen(
1134 *CGM.getCoverageMapping(), CGM.getContext().getSourceManager(),
1135 CGM.getLangOpts(), RegionCounterMap.get(), RegionMCDCState.get());
1136 MappingGen.emitCounterMapping(D, OS);
1137
1138 if (CoverageMapping.empty())
1139 return;
1140
1141 CGM.getCoverageMapping()->addFunctionMappingRecord(
1142 FunctionName: FuncNameVar, FunctionNameValue: FuncName, FunctionHash, CoverageMapping);
1143}
1144
1145void
1146CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
1147 llvm::GlobalValue::LinkageTypes Linkage) {
1148 if (skipRegionMappingForDecl(D))
1149 return;
1150
1151 std::string CoverageMapping;
1152 llvm::raw_string_ostream OS(CoverageMapping);
1153 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
1154 CGM.getContext().getSourceManager(),
1155 CGM.getLangOpts());
1156 MappingGen.emitEmptyMapping(D, OS);
1157
1158 if (CoverageMapping.empty())
1159 return;
1160
1161 setFuncName(Name, Linkage);
1162 CGM.getCoverageMapping()->addFunctionMappingRecord(
1163 FunctionName: FuncNameVar, FunctionNameValue: FuncName, FunctionHash, CoverageMapping, IsUsed: false);
1164}
1165
1166void CodeGenPGO::computeRegionCounts(const Decl *D) {
1167 StmtCountMap.reset(p: new llvm::DenseMap<const Stmt *, uint64_t>);
1168 ComputeRegionCounts Walker(*StmtCountMap, *this);
1169 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(Val: D))
1170 Walker.VisitFunctionDecl(D: FD);
1171 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(Val: D))
1172 Walker.VisitObjCMethodDecl(D: MD);
1173 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(Val: D))
1174 Walker.VisitBlockDecl(D: BD);
1175 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(Val: D))
1176 Walker.VisitCapturedDecl(D: const_cast<CapturedDecl *>(CD));
1177}
1178
1179void
1180CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
1181 llvm::Function *Fn) {
1182 if (!haveRegionCounts())
1183 return;
1184
1185 uint64_t FunctionCount = getRegionCount(S: nullptr);
1186 Fn->setEntryCount(Count: FunctionCount);
1187}
1188
1189std::pair<bool, bool> CodeGenPGO::getIsCounterPair(const Stmt *S) const {
1190 if (!RegionCounterMap)
1191 return {false, false};
1192
1193 auto I = RegionCounterMap->find(Val: S);
1194 if (I == RegionCounterMap->end())
1195 return {false, false};
1196
1197 return {I->second.Executed.hasValue(), I->second.Skipped.hasValue()};
1198}
1199
1200void CodeGenPGO::emitCounterSetOrIncrement(CGBuilderTy &Builder, const Stmt *S,
1201 llvm::Value *StepV) {
1202 if (!RegionCounterMap || !Builder.GetInsertBlock())
1203 return;
1204
1205 unsigned Counter = (*RegionCounterMap)[S].Executed;
1206
1207 // Make sure that pointer to global is passed in with zero addrspace
1208 // This is relevant during GPU profiling
1209 auto *NormalizedFuncNameVarPtr =
1210 llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast(
1211 C: FuncNameVar, Ty: llvm::PointerType::get(C&: CGM.getLLVMContext(), AddressSpace: 0));
1212
1213 llvm::Value *Args[] = {
1214 NormalizedFuncNameVarPtr, Builder.getInt64(C: FunctionHash),
1215 Builder.getInt32(C: NumRegionCounters), Builder.getInt32(C: Counter), StepV};
1216
1217 if (llvm::EnableSingleByteCoverage)
1218 Builder.CreateCall(Callee: CGM.getIntrinsic(IID: llvm::Intrinsic::instrprof_cover),
1219 Args: ArrayRef(Args, 4));
1220 else if (!StepV)
1221 Builder.CreateCall(Callee: CGM.getIntrinsic(IID: llvm::Intrinsic::instrprof_increment),
1222 Args: ArrayRef(Args, 4));
1223 else
1224 Builder.CreateCall(
1225 Callee: CGM.getIntrinsic(IID: llvm::Intrinsic::instrprof_increment_step), Args);
1226}
1227
1228bool CodeGenPGO::canEmitMCDCCoverage(const CGBuilderTy &Builder) {
1229 return (CGM.getCodeGenOpts().hasProfileClangInstr() &&
1230 CGM.getCodeGenOpts().MCDCCoverage && Builder.GetInsertBlock());
1231}
1232
1233void CodeGenPGO::emitMCDCParameters(CGBuilderTy &Builder) {
1234 if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1235 return;
1236
1237 auto *I8PtrTy = llvm::PointerType::getUnqual(C&: CGM.getLLVMContext());
1238
1239 // Emit intrinsic representing MCDC bitmap parameters at function entry.
1240 // This is used by the instrumentation pass, but it isn't actually lowered to
1241 // anything.
1242 llvm::Value *Args[3] = {llvm::ConstantExpr::getBitCast(C: FuncNameVar, Ty: I8PtrTy),
1243 Builder.getInt64(C: FunctionHash),
1244 Builder.getInt32(C: RegionMCDCState->BitmapBits)};
1245 Builder.CreateCall(
1246 Callee: CGM.getIntrinsic(IID: llvm::Intrinsic::instrprof_mcdc_parameters), Args);
1247}
1248
1249void CodeGenPGO::emitMCDCTestVectorBitmapUpdate(CGBuilderTy &Builder,
1250 const Expr *S,
1251 Address MCDCCondBitmapAddr,
1252 CodeGenFunction &CGF) {
1253 if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1254 return;
1255
1256 S = S->IgnoreParens();
1257
1258 auto DecisionStateIter = RegionMCDCState->DecisionByStmt.find(Val: S);
1259 if (DecisionStateIter == RegionMCDCState->DecisionByStmt.end())
1260 return;
1261
1262 // Don't create tvbitmap_update if the record is allocated but excluded.
1263 // Or `bitmap |= (1 << 0)` would be wrongly executed to the next bitmap.
1264 if (DecisionStateIter->second.Indices.size() == 0)
1265 return;
1266
1267 // Extract the offset of the global bitmap associated with this expression.
1268 unsigned MCDCTestVectorBitmapOffset = DecisionStateIter->second.BitmapIdx;
1269 auto *I8PtrTy = llvm::PointerType::getUnqual(C&: CGM.getLLVMContext());
1270
1271 // Emit intrinsic responsible for updating the global bitmap corresponding to
1272 // a boolean expression. The index being set is based on the value loaded
1273 // from a pointer to a dedicated temporary value on the stack that is itself
1274 // updated via emitMCDCCondBitmapReset() and emitMCDCCondBitmapUpdate(). The
1275 // index represents an executed test vector.
1276 llvm::Value *Args[4] = {llvm::ConstantExpr::getBitCast(C: FuncNameVar, Ty: I8PtrTy),
1277 Builder.getInt64(C: FunctionHash),
1278 Builder.getInt32(C: MCDCTestVectorBitmapOffset),
1279 MCDCCondBitmapAddr.emitRawPointer(CGF)};
1280 Builder.CreateCall(
1281 Callee: CGM.getIntrinsic(IID: llvm::Intrinsic::instrprof_mcdc_tvbitmap_update), Args);
1282}
1283
1284void CodeGenPGO::emitMCDCCondBitmapReset(CGBuilderTy &Builder, const Expr *S,
1285 Address MCDCCondBitmapAddr) {
1286 if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1287 return;
1288
1289 S = S->IgnoreParens();
1290
1291 if (!RegionMCDCState->DecisionByStmt.contains(Val: S))
1292 return;
1293
1294 // Emit intrinsic that resets a dedicated temporary value on the stack to 0.
1295 Builder.CreateStore(Val: Builder.getInt32(C: 0), Addr: MCDCCondBitmapAddr);
1296}
1297
1298void CodeGenPGO::emitMCDCCondBitmapUpdate(CGBuilderTy &Builder, const Expr *S,
1299 Address MCDCCondBitmapAddr,
1300 llvm::Value *Val,
1301 CodeGenFunction &CGF) {
1302 if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1303 return;
1304
1305 // Even though, for simplicity, parentheses and unary logical-NOT operators
1306 // are considered part of their underlying condition for both MC/DC and
1307 // branch coverage, the condition IDs themselves are assigned and tracked
1308 // using the underlying condition itself. This is done solely for
1309 // consistency since parentheses and logical-NOTs are ignored when checking
1310 // whether the condition is actually an instrumentable condition. This can
1311 // also make debugging a bit easier.
1312 S = CodeGenFunction::stripCond(C: S);
1313
1314 auto BranchStateIter = RegionMCDCState->BranchByStmt.find(Val: S);
1315 if (BranchStateIter == RegionMCDCState->BranchByStmt.end())
1316 return;
1317
1318 // Extract the ID of the condition we are setting in the bitmap.
1319 const auto &Branch = BranchStateIter->second;
1320 assert(Branch.ID >= 0 && "Condition has no ID!");
1321 assert(Branch.DecisionStmt);
1322
1323 // Cancel the emission if the Decision is erased after the allocation.
1324 const auto DecisionIter =
1325 RegionMCDCState->DecisionByStmt.find(Val: Branch.DecisionStmt);
1326 if (DecisionIter == RegionMCDCState->DecisionByStmt.end())
1327 return;
1328
1329 const auto &TVIdxs = DecisionIter->second.Indices[Branch.ID];
1330
1331 auto *CurTV = Builder.CreateLoad(Addr: MCDCCondBitmapAddr,
1332 Name: "mcdc." + Twine(Branch.ID + 1) + ".cur");
1333 auto *NewTV = Builder.CreateAdd(LHS: CurTV, RHS: Builder.getInt32(C: TVIdxs[true]));
1334 NewTV = Builder.CreateSelect(
1335 C: Val, True: NewTV, False: Builder.CreateAdd(LHS: CurTV, RHS: Builder.getInt32(C: TVIdxs[false])));
1336 Builder.CreateStore(Val: NewTV, Addr: MCDCCondBitmapAddr);
1337}
1338
1339void CodeGenPGO::setValueProfilingFlag(llvm::Module &M) {
1340 if (CGM.getCodeGenOpts().hasProfileClangInstr())
1341 M.addModuleFlag(Behavior: llvm::Module::Warning, Key: "EnableValueProfiling",
1342 Val: uint32_t(EnableValueProfiling));
1343}
1344
1345void CodeGenPGO::setProfileVersion(llvm::Module &M) {
1346 if (CGM.getCodeGenOpts().hasProfileClangInstr() &&
1347 llvm::EnableSingleByteCoverage) {
1348 const StringRef VarName(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR));
1349 llvm::Type *IntTy64 = llvm::Type::getInt64Ty(C&: M.getContext());
1350 uint64_t ProfileVersion =
1351 (INSTR_PROF_RAW_VERSION | VARIANT_MASK_BYTE_COVERAGE);
1352
1353 auto IRLevelVersionVariable = new llvm::GlobalVariable(
1354 M, IntTy64, true, llvm::GlobalValue::WeakAnyLinkage,
1355 llvm::Constant::getIntegerValue(Ty: IntTy64,
1356 V: llvm::APInt(64, ProfileVersion)),
1357 VarName);
1358
1359 IRLevelVersionVariable->setVisibility(llvm::GlobalValue::HiddenVisibility);
1360 llvm::Triple TT(M.getTargetTriple());
1361 if (TT.isGPU())
1362 IRLevelVersionVariable->setVisibility(
1363 llvm::GlobalValue::ProtectedVisibility);
1364 if (TT.supportsCOMDAT()) {
1365 IRLevelVersionVariable->setLinkage(llvm::GlobalValue::ExternalLinkage);
1366 IRLevelVersionVariable->setComdat(M.getOrInsertComdat(Name: VarName));
1367 }
1368 IRLevelVersionVariable->setDSOLocal(true);
1369 }
1370}
1371
1372// This method either inserts a call to the profile run-time during
1373// instrumentation or puts profile data into metadata for PGO use.
1374void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
1375 llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
1376
1377 if (!EnableValueProfiling)
1378 return;
1379
1380 if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
1381 return;
1382
1383 if (isa<llvm::Constant>(Val: ValuePtr))
1384 return;
1385
1386 bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
1387 if (InstrumentValueSites && RegionCounterMap) {
1388 auto BuilderInsertPoint = Builder.saveIP();
1389 Builder.SetInsertPoint(ValueSite);
1390 llvm::Value *Args[5] = {
1391 FuncNameVar,
1392 Builder.getInt64(C: FunctionHash),
1393 Builder.CreatePtrToInt(V: ValuePtr, DestTy: Builder.getInt64Ty()),
1394 Builder.getInt32(C: ValueKind),
1395 Builder.getInt32(C: NumValueSites[ValueKind]++)
1396 };
1397 Builder.CreateCall(
1398 Callee: CGM.getIntrinsic(IID: llvm::Intrinsic::instrprof_value_profile), Args);
1399 Builder.restoreIP(IP: BuilderInsertPoint);
1400 return;
1401 }
1402
1403 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
1404 if (PGOReader && haveRegionCounts()) {
1405 // We record the top most called three functions at each call site.
1406 // Profile metadata contains "VP" string identifying this metadata
1407 // as value profiling data, then a uint32_t value for the value profiling
1408 // kind, a uint64_t value for the total number of times the call is
1409 // executed, followed by the function hash and execution count (uint64_t)
1410 // pairs for each function.
1411 if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
1412 return;
1413
1414 llvm::annotateValueSite(M&: CGM.getModule(), Inst&: *ValueSite, InstrProfR: *ProfRecord,
1415 ValueKind: (llvm::InstrProfValueKind)ValueKind,
1416 SiteIndx: NumValueSites[ValueKind]);
1417
1418 NumValueSites[ValueKind]++;
1419 }
1420}
1421
1422void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
1423 bool IsInMainFile) {
1424 CGM.getPGOStats().addVisited(MainFile: IsInMainFile);
1425 RegionCounts.clear();
1426 auto RecordExpected = PGOReader->getInstrProfRecord(FuncName, FuncHash: FunctionHash);
1427 if (auto E = RecordExpected.takeError()) {
1428 auto IPE = std::get<0>(in: llvm::InstrProfError::take(E: std::move(E)));
1429 if (IPE == llvm::instrprof_error::unknown_function)
1430 CGM.getPGOStats().addMissing(MainFile: IsInMainFile);
1431 else if (IPE == llvm::instrprof_error::hash_mismatch)
1432 CGM.getPGOStats().addMismatched(MainFile: IsInMainFile);
1433 else if (IPE == llvm::instrprof_error::malformed)
1434 // TODO: Consider a more specific warning for this case.
1435 CGM.getPGOStats().addMismatched(MainFile: IsInMainFile);
1436 return;
1437 }
1438 ProfRecord =
1439 std::make_unique<llvm::InstrProfRecord>(args: std::move(RecordExpected.get()));
1440 RegionCounts = ProfRecord->Counts;
1441}
1442
/// Compute the divisor used to bring 64-bit weights down to 32 bits.
///
/// Dividing any weight up to \p MaxWeight by the returned value yields a
/// result strictly less than UINT32_MAX. Weights already below UINT32_MAX
/// need no scaling, so the divisor is 1 for them.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
1450
/// Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// less than \c Weight, typically the maximum of all weights being scaled
/// together. Otherwise the result may not fit in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  // The precondition guarantees this narrowing is lossless.
  return static_cast<uint32_t>(Scaled);
}
1466
1467llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1468 uint64_t FalseCount) const {
1469 // Check for empty weights.
1470 if (!TrueCount && !FalseCount)
1471 return nullptr;
1472
1473 // Calculate how to scale down to 32-bits.
1474 uint64_t Scale = calculateWeightScale(MaxWeight: std::max(a: TrueCount, b: FalseCount));
1475
1476 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1477 return MDHelper.createBranchWeights(TrueWeight: scaleBranchWeight(Weight: TrueCount, Scale),
1478 FalseWeight: scaleBranchWeight(Weight: FalseCount, Scale));
1479}
1480
1481llvm::MDNode *
1482CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const {
1483 // We need at least two elements to create meaningful weights.
1484 if (Weights.size() < 2)
1485 return nullptr;
1486
1487 // Check for empty weights.
1488 uint64_t MaxWeight = *llvm::max_element(Range&: Weights);
1489 if (MaxWeight == 0)
1490 return nullptr;
1491
1492 // Calculate how to scale down to 32-bits.
1493 uint64_t Scale = calculateWeightScale(MaxWeight);
1494
1495 SmallVector<uint32_t, 16> ScaledWeights;
1496 ScaledWeights.reserve(N: Weights.size());
1497 for (uint64_t W : Weights)
1498 ScaledWeights.push_back(Elt: scaleBranchWeight(Weight: W, Scale));
1499
1500 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1501 return MDHelper.createBranchWeights(Weights: ScaledWeights);
1502}
1503
1504llvm::MDNode *
1505CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1506 uint64_t LoopCount) const {
1507 if (!PGO->haveRegionCounts())
1508 return nullptr;
1509 std::optional<uint64_t> CondCount = PGO->getStmtCount(S: Cond);
1510 if (!CondCount || *CondCount == 0)
1511 return nullptr;
1512 return createProfileWeights(TrueCount: LoopCount,
1513 FalseCount: std::max(a: *CondCount, b: LoopCount) - LoopCount);
1514}
1515
1516void CodeGenFunction::incrementProfileCounter(const Stmt *S,
1517 llvm::Value *StepV) {
1518 if (CGM.getCodeGenOpts().hasProfileClangInstr() &&
1519 !CurFn->hasFnAttribute(Kind: llvm::Attribute::NoProfile) &&
1520 !CurFn->hasFnAttribute(Kind: llvm::Attribute::SkipProfile)) {
1521 auto AL = ApplyDebugLocation::CreateArtificial(CGF&: *this);
1522 PGO->emitCounterSetOrIncrement(Builder, S, StepV);
1523 }
1524 PGO->setCurrentStmt(S);
1525}
1526
1527std::pair<bool, bool> CodeGenFunction::getIsCounterPair(const Stmt *S) const {
1528 return PGO->getIsCounterPair(S);
1529}
1530void CodeGenFunction::markStmtAsUsed(bool Skipped, const Stmt *S) {
1531 PGO->markStmtAsUsed(Skipped, S);
1532}
1533void CodeGenFunction::markStmtMaybeUsed(const Stmt *S) {
1534 PGO->markStmtMaybeUsed(S);
1535}
1536
1537void CodeGenFunction::maybeCreateMCDCCondBitmap() {
1538 if (isMCDCCoverageEnabled()) {
1539 PGO->emitMCDCParameters(Builder);
1540 MCDCCondBitmapAddr = CreateIRTemp(T: getContext().UnsignedIntTy, Name: "mcdc.addr");
1541 }
1542}
1543void CodeGenFunction::maybeResetMCDCCondBitmap(const Expr *E) {
1544 if (isMCDCCoverageEnabled() && isBinaryLogicalOp(E)) {
1545 PGO->emitMCDCCondBitmapReset(Builder, S: E, MCDCCondBitmapAddr);
1546 PGO->setCurrentStmt(E);
1547 }
1548}
1549void CodeGenFunction::maybeUpdateMCDCTestVectorBitmap(const Expr *E) {
1550 if (isMCDCCoverageEnabled() && isBinaryLogicalOp(E)) {
1551 PGO->emitMCDCTestVectorBitmapUpdate(Builder, S: E, MCDCCondBitmapAddr, CGF&: *this);
1552 PGO->setCurrentStmt(E);
1553 }
1554}
1555
1556void CodeGenFunction::maybeUpdateMCDCCondBitmap(const Expr *E,
1557 llvm::Value *Val) {
1558 if (isMCDCCoverageEnabled()) {
1559 PGO->emitMCDCCondBitmapUpdate(Builder, S: E, MCDCCondBitmapAddr, Val, CGF&: *this);
1560 PGO->setCurrentStmt(E);
1561 }
1562}
1563
1564uint64_t CodeGenFunction::getProfileCount(const Stmt *S) {
1565 return PGO->getStmtCount(S).value_or(u: 0);
1566}
1567
1568/// Set the profiler's current count.
1569void CodeGenFunction::setCurrentProfileCount(uint64_t Count) {
1570 PGO->setCurrentRegionCount(Count);
1571}
1572
1573/// Get the profiler's current count. This is generally the count for the most
1574/// recently incremented counter.
1575uint64_t CodeGenFunction::getCurrentProfileCount() {
1576 return PGO->getCurrentRegionCount();
1577}
1578