1//===-- InsertCodePrefetch.cpp ---=========--------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// \file
10/// Code Prefetch Insertion Pass.
11//===----------------------------------------------------------------------===//
12/// This pass inserts code prefetch instructions according to the prefetch
13/// directives in the basic block section profile. The target of a prefetch can
14/// be the beginning of any dynamic basic block, that is the beginning of a
15/// machine basic block, or immediately after a callsite. A global symbol is
16/// emitted at the position of the target so it can be addressed from the
17/// prefetch instruction from any module. In order to insert prefetch hints,
18/// `TargetInstrInfo::insertCodePrefetchInstr` must be implemented by the
19/// target.
20//===----------------------------------------------------------------------===//
21
#include "llvm/CodeGen/InsertCodePrefetch.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/CodeGen/BasicBlockSectionUtils.h"
#include "llvm/CodeGen/BasicBlockSectionsProfileReader.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/InitializePasses.h"
#include "llvm/MC/MCContext.h"
#include "llvm/MC/MCSymbolELF.h"
#include "llvm/Object/ELFTypes.h"
39
40using namespace llvm;
41#define DEBUG_TYPE "insert-code-prefetch"
42
43SmallString<128> llvm::getPrefetchTargetSymbolName(StringRef FunctionName,
44 const UniqueBBID &BBID,
45 unsigned CallsiteIndex) {
46 SmallString<128> R("__llvm_prefetch_target_");
47 R += FunctionName;
48 R += "_";
49 R += utostr(X: BBID.BaseID);
50 R += "_";
51 R += utostr(X: CallsiteIndex);
52 return R;
53}
54
55namespace {
56class InsertCodePrefetch : public MachineFunctionPass {
57public:
58 static char ID;
59
60 InsertCodePrefetch() : MachineFunctionPass(ID) {}
61
62 StringRef getPassName() const override {
63 return "Code Prefetch Inserter Pass";
64 }
65
66 void getAnalysisUsage(AnalysisUsage &AU) const override;
67
68 // Sets prefetch targets based on the bb section profile.
69 bool runOnMachineFunction(MachineFunction &MF) override;
70};
71
72} // end anonymous namespace
73
74//===----------------------------------------------------------------------===//
75// Implementation
76//===----------------------------------------------------------------------===//
77
// Storage for the pass identifier passed to the MachineFunctionPass
// constructor.
char InsertCodePrefetch::ID = 0;
// Register the pass under the command-line name DEBUG_TYPE
// ("insert-code-prefetch"). The registration declares a dependency on the
// basic block sections profile reader, whose data drives this pass. The
// trailing `true, false` are the registry's CFG-only and is-analysis flags
// (per the INITIALIZE_PASS macros in llvm/PassSupport.h).
INITIALIZE_PASS_BEGIN(InsertCodePrefetch, DEBUG_TYPE, "Code prefetch insertion",
                      true, false)
INITIALIZE_PASS_DEPENDENCY(BasicBlockSectionsProfileReaderWrapperPass)
INITIALIZE_PASS_END(InsertCodePrefetch, DEBUG_TYPE, "Code prefetch insertion",
                    true, false)
84
85static bool setPrefetchTargets(MachineFunction &MF,
86 const SmallVector<CallsiteID> &PrefetchTargets) {
87 if (PrefetchTargets.empty())
88 return false;
89 // Set each block's prefetch targets so AsmPrinter can emit a special symbol
90 // there.
91 DenseMap<UniqueBBID, SmallVector<unsigned>> PrefetchTargetsByBBID;
92 for (const auto &Target : PrefetchTargets)
93 PrefetchTargetsByBBID[Target.BBID].push_back(Elt: Target.CallsiteIndex);
94 // Sort and uniquify the callsite indices for every block.
95 for (auto &[K, V] : PrefetchTargetsByBBID) {
96 llvm::sort(C&: V);
97 V.erase(CS: llvm::unique(R&: V), CE: V.end());
98 }
99 MF.setPrefetchTargets(PrefetchTargetsByBBID);
100 return true;
101}
102
103static bool
104insertPrefetchHints(MachineFunction &MF,
105 const SmallVector<PrefetchHint> &PrefetchHints) {
106 bool PrefetchInserted = false;
107 bool IsELF = MF.getTarget().getTargetTriple().isOSBinFormatELF();
108 const Module *M = MF.getFunction().getParent();
109 DenseMap<UniqueBBID, SmallVector<PrefetchHint>> PrefetchHintsBySiteBBID;
110 for (const auto &H : PrefetchHints)
111 PrefetchHintsBySiteBBID[H.SiteID.BBID].push_back(Elt: H);
112 // Sort prefetch hints by their callsite index so we can insert them by one
113 // pass over the block's instructions.
114 for (auto &[SiteBBID, Hints] : PrefetchHintsBySiteBBID) {
115 llvm::stable_sort(
116 Range&: Hints, C: [](const PrefetchHint &H1, const PrefetchHint &H2) {
117 return H1.SiteID.CallsiteIndex < H2.SiteID.CallsiteIndex;
118 });
119 }
120 auto PtrTy =
121 PointerType::getUnqual(C&: MF.getFunction().getParent()->getContext());
122 const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
123 for (auto &BB : MF) {
124 auto It = PrefetchHintsBySiteBBID.find(Val: *BB.getBBID());
125 if (It == PrefetchHintsBySiteBBID.end())
126 continue;
127 const auto &BBHints = It->second;
128 unsigned NumCallsInBB = 0;
129 auto InstrIt = BB.begin();
130 for (auto HintIt = BBHints.begin(); HintIt != BBHints.end();) {
131 auto NextInstrIt = InstrIt == BB.end() ? BB.end() : std::next(x: InstrIt);
132 // Insert all the prefetch hints which must be placed after this call (or
133 // at the beginning of the block if `NumCallsInBB` is zero.
134 while (HintIt != BBHints.end() &&
135 HintIt->SiteID.CallsiteIndex == NumCallsInBB) {
136 bool TargetFunctionDefined = false;
137 if (Function *TargetFunction = M->getFunction(Name: HintIt->TargetFunction))
138 TargetFunctionDefined = !TargetFunction->isDeclaration();
139
140 auto TargetSymbolName = getPrefetchTargetSymbolName(
141 FunctionName: HintIt->TargetFunction, BBID: HintIt->TargetID.BBID,
142 CallsiteIndex: HintIt->TargetID.CallsiteIndex);
143 auto *GV = MF.getFunction().getParent()->getOrInsertGlobal(
144 Name: TargetSymbolName, Ty: PtrTy);
145 MachineInstr *PrefetchInstr =
146 TII->insertCodePrefetchInstr(MBB&: BB, InsertBefore: InstrIt, GV);
147 if (!TargetFunctionDefined && IsELF) {
148 // If the target function is not defined in this module, we guard
149 // against undefined prefetch target symbol by emitting a fallback
150 // symbol with weak linkage right after the prefetch instruction. If
151 // there is no strong symbol, the fallback will be used and we
152 // prefetch the next address:
153 //
154 // prefetchit1 __llvm_prefetch_target_foo_x_y(%rip)
155 // .weak __llvm_prefetch_target_foo_x_y
156 // __llvm_prefetch_target_foo_x_y:
157 MCSymbolELF *WeakFallbackSym = static_cast<MCSymbolELF *>(
158 MF.getContext().getOrCreateSymbol(Name: TargetSymbolName));
159 // The fallback symbol may have been defined via another prefetch
160 // instruction in the same module, in which case we should not emit it
161 // here. Ideally, getOrCreateSymbol should tell us if the symbol
162 // existed, but we use `isBindingSet()` since that API is not
163 // available.
164 if (!WeakFallbackSym->isBindingSet()) {
165 WeakFallbackSym->setBinding(ELF::STB_WEAK);
166 PrefetchInstr->setPostInstrSymbol(MF, Symbol: WeakFallbackSym);
167 }
168 }
169 PrefetchInserted = true;
170 ++HintIt;
171 }
172 if (InstrIt == BB.end())
173 break;
174 if (InstrIt->isCall())
175 ++NumCallsInBB;
176 InstrIt = NextInstrIt;
177 }
178 }
179 return PrefetchInserted;
180}
181
182bool InsertCodePrefetch::runOnMachineFunction(MachineFunction &MF) {
183 assert(MF.getTarget().getBBSectionsType() == BasicBlockSection::List &&
184 "BB Sections list not enabled!");
185 if (hasInstrProfHashMismatch(MF))
186 return false;
187
188 auto &ProfileReader =
189 getAnalysis<BasicBlockSectionsProfileReaderWrapperPass>();
190 bool R = setPrefetchTargets(
191 MF, PrefetchTargets: ProfileReader.getPrefetchTargetsForFunction(FuncName: MF.getName()));
192 bool S = insertPrefetchHints(
193 MF, PrefetchHints: ProfileReader.getPrefetchHintsForFunction(FuncName: MF.getName()));
194 return R || S;
195}
196
197void InsertCodePrefetch::getAnalysisUsage(AnalysisUsage &AU) const {
198 AU.setPreservesAll();
199 AU.addRequired<BasicBlockSectionsProfileReaderWrapperPass>();
200 MachineFunctionPass::getAnalysisUsage(AU);
201}
202
203MachineFunctionPass *llvm::createInsertCodePrefetchPass() {
204 return new InsertCodePrefetch();
205}
206