1//===- llvm/CodeGen/BasicBlockMatchingAndInference.cpp ----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// The Propeller profile reader has already loaded the hash values of basic
// blocks, as well as the weights of basic blocks and edges in the CFG. In this
// file, we first match the basic blocks in the profile with those in the
// current MachineFunction using the basic block hashes, which yields the
// weights of some basic blocks and edges. If enough of the function's blocks
// were matched, we then infer the weights of all basic blocks and edges using
// an inference algorithm.
15//
16// TODO: Integrate part of the code in this file with BOLT's implementation into
17// the LLVM infrastructure, enabling both BOLT and Propeller to reuse it.
18//
19//===----------------------------------------------------------------------===//
20
#include "llvm/CodeGen/BasicBlockMatchingAndInference.h"
#include "llvm/CodeGen/BasicBlockSectionsProfileReader.h"
#include "llvm/CodeGen/MachineBlockHashInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include <limits>
#include <optional>
#include <unordered_map>
#include <utility>
#include <vector>
28
29using namespace llvm;
30
31static cl::opt<float>
32 PropellerInferThreshold("propeller-infer-threshold",
33 cl::desc("Threshold for infer stale profile"),
34 cl::init(Val: 0.6), cl::Optional);
35
36/// The object is used to identify and match basic blocks given their hashes.
37class StaleMatcher {
38public:
39 /// Initialize stale matcher.
40 void init(const std::vector<MachineBasicBlock *> &Blocks,
41 const std::vector<BlendedBlockHash> &Hashes) {
42 assert(Blocks.size() == Hashes.size() &&
43 "incorrect matcher initialization");
44 for (size_t I = 0; I < Blocks.size(); I++) {
45 MachineBasicBlock *Block = Blocks[I];
46 uint16_t OpHash = Hashes[I].getOpcodeHash();
47 OpHashToBlocks[OpHash].push_back(x: std::make_pair(x: Hashes[I], y&: Block));
48 }
49 }
50
51 /// Find the most similar block for a given hash.
52 MachineBasicBlock *matchBlock(BlendedBlockHash BlendedHash) const {
53 auto BlockIt = OpHashToBlocks.find(x: BlendedHash.getOpcodeHash());
54 if (BlockIt == OpHashToBlocks.end()) {
55 return nullptr;
56 }
57 MachineBasicBlock *BestBlock = nullptr;
58 uint64_t BestDist = std::numeric_limits<uint64_t>::max();
59 for (auto It : BlockIt->second) {
60 MachineBasicBlock *Block = It.second;
61 BlendedBlockHash Hash = It.first;
62 uint64_t Dist = Hash.distance(BBH: BlendedHash);
63 if (BestBlock == nullptr || Dist < BestDist) {
64 BestDist = Dist;
65 BestBlock = Block;
66 }
67 }
68 return BestBlock;
69 }
70
71private:
72 using HashBlockPairType = std::pair<BlendedBlockHash, MachineBasicBlock *>;
73 std::unordered_map<uint16_t, std::vector<HashBlockPairType>> OpHashToBlocks;
74};
75
// Register the pass under the command-line name "machine-block-match-infer",
// and record its dependencies on the two passes that getAnalysisUsage()
// requests: MachineBlockHashInfo and BasicBlockSectionsProfileReaderWrapperPass.
// NOTE(review): the two trailing `true` arguments are presumably the
// cfg-only and is-analysis flags of the INITIALIZE_PASS macros; confirm
// against llvm/PassSupport.h.
INITIALIZE_PASS_BEGIN(BasicBlockMatchingAndInference,
                      "machine-block-match-infer",
                      "Machine Block Matching and Inference Analysis", true,
                      true)
INITIALIZE_PASS_DEPENDENCY(MachineBlockHashInfo)
INITIALIZE_PASS_DEPENDENCY(BasicBlockSectionsProfileReaderWrapperPass)
INITIALIZE_PASS_END(BasicBlockMatchingAndInference, "machine-block-match-infer",
                    "Machine Block Matching and Inference Analysis", true, true)
84
85char BasicBlockMatchingAndInference::ID = 0;
86
87BasicBlockMatchingAndInference::BasicBlockMatchingAndInference()
88 : MachineFunctionPass(ID) {}
89
90void BasicBlockMatchingAndInference::getAnalysisUsage(AnalysisUsage &AU) const {
91 AU.addRequired<MachineBlockHashInfo>();
92 AU.addRequired<BasicBlockSectionsProfileReaderWrapperPass>();
93 AU.setPreservesAll();
94 MachineFunctionPass::getAnalysisUsage(AU);
95}
96
97std::optional<BasicBlockMatchingAndInference::WeightInfo>
98BasicBlockMatchingAndInference::getWeightInfo(StringRef FuncName) const {
99 auto It = ProgramWeightInfo.find(Key: FuncName);
100 if (It == ProgramWeightInfo.end()) {
101 return std::nullopt;
102 }
103 return It->second;
104}
105
106BasicBlockMatchingAndInference::WeightInfo
107BasicBlockMatchingAndInference::initWeightInfoByMatching(MachineFunction &MF) {
108 std::vector<MachineBasicBlock *> Blocks;
109 std::vector<BlendedBlockHash> Hashes;
110 auto BSPR = &getAnalysis<BasicBlockSectionsProfileReaderWrapperPass>();
111 auto MBHI = &getAnalysis<MachineBlockHashInfo>();
112 for (auto &Block : MF) {
113 Blocks.push_back(x: &Block);
114 Hashes.push_back(x: BlendedBlockHash(MBHI->getMBBHash(MBB: Block)));
115 }
116 StaleMatcher Matcher;
117 Matcher.init(Blocks, Hashes);
118 BasicBlockMatchingAndInference::WeightInfo MatchWeight;
119 const CFGProfile *CFG = BSPR->getFunctionCFGProfile(FuncName: MF.getName());
120 if (CFG == nullptr)
121 return MatchWeight;
122 for (auto &BlockCount : CFG->NodeCounts) {
123 if (CFG->BBHashes.count(Val: BlockCount.first.BaseID)) {
124 auto Hash = CFG->BBHashes.lookup(Val: BlockCount.first.BaseID);
125 MachineBasicBlock *Block = Matcher.matchBlock(BlendedHash: BlendedBlockHash(Hash));
126 // When a basic block has clone copies, sum their counts.
127 if (Block != nullptr)
128 MatchWeight.BlockWeights[Block] += BlockCount.second;
129 }
130 }
131 for (auto &PredItem : CFG->EdgeCounts) {
132 auto PredID = PredItem.first.BaseID;
133 if (!CFG->BBHashes.count(Val: PredID))
134 continue;
135 auto PredHash = CFG->BBHashes.lookup(Val: PredID);
136 MachineBasicBlock *PredBlock =
137 Matcher.matchBlock(BlendedHash: BlendedBlockHash(PredHash));
138 if (PredBlock == nullptr)
139 continue;
140 for (auto &SuccItem : PredItem.second) {
141 auto SuccID = SuccItem.first.BaseID;
142 auto EdgeWeight = SuccItem.second;
143 if (CFG->BBHashes.count(Val: SuccID)) {
144 auto SuccHash = CFG->BBHashes.lookup(Val: SuccID);
145 MachineBasicBlock *SuccBlock =
146 Matcher.matchBlock(BlendedHash: BlendedBlockHash(SuccHash));
147 // When an edge has clone copies, sum their counts.
148 if (SuccBlock != nullptr)
149 MatchWeight.EdgeWeights[std::make_pair(x&: PredBlock, y&: SuccBlock)] +=
150 EdgeWeight;
151 }
152 }
153 }
154 return MatchWeight;
155}
156
157void BasicBlockMatchingAndInference::generateWeightInfoByInference(
158 MachineFunction &MF,
159 BasicBlockMatchingAndInference::WeightInfo &MatchWeight) {
160 BlockEdgeMap Successors;
161 for (auto &Block : MF) {
162 for (auto *Succ : Block.successors())
163 Successors[&Block].push_back(Elt: Succ);
164 }
165 SampleProfileInference<MachineFunction> SPI(
166 MF, Successors, MatchWeight.BlockWeights, MatchWeight.EdgeWeights);
167 BlockWeightMap BlockWeights;
168 EdgeWeightMap EdgeWeights;
169 SPI.apply(BlockWeights, EdgeWeights);
170 ProgramWeightInfo.try_emplace(
171 Key: MF.getName(), Args: BasicBlockMatchingAndInference::WeightInfo{
172 .BlockWeights: std::move(BlockWeights), .EdgeWeights: std::move(EdgeWeights)});
173}
174
175bool BasicBlockMatchingAndInference::runOnMachineFunction(MachineFunction &MF) {
176 if (MF.empty())
177 return false;
178 auto MatchWeight = initWeightInfoByMatching(MF);
179 // If the ratio of the number of MBBs in matching to the total number of MBBs
180 // in the function is less than the threshold value, the processing should be
181 // abandoned.
182 if (static_cast<float>(MatchWeight.BlockWeights.size()) / MF.size() <
183 PropellerInferThreshold) {
184 return false;
185 }
186 generateWeightInfoByInference(MF, MatchWeight);
187 return false;
188}
189
190MachineFunctionPass *llvm::createBasicBlockMatchingAndInferencePass() {
191 return new BasicBlockMatchingAndInference();
192}
193