1 | //===----------------------------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// The analysis determines the convergence region for each basic block of a
// function, and provides a tree-like structure describing the region
// hierarchy.
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
#include "SPIRVConvergenceRegionAnalysis.h"
#include "SPIRV.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/InitializePasses.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include <functional>
#include <optional>
#include <queue>
#include <unordered_map>
#include <unordered_set>
24 | |
25 | #define DEBUG_TYPE "spirv-convergence-region-analysis" |
26 | |
27 | using namespace llvm; |
28 | using namespace SPIRV; |
29 | |
// Registers the legacy-PM wrapper pass under the command-line name
// "convergence-region" and declares the passes it depends on. LoopSimplify is
// listed because the analyzer below only recognizes back-edges of loops in
// simplified form (single latch, back-edge targeting the header). The trailing
// `true, true` arguments are INITIALIZE_PASS's `cfg` and `analysis` flags.
INITIALIZE_PASS_BEGIN(SPIRVConvergenceRegionAnalysisWrapperPass,
                      "convergence-region",
                      "SPIRV convergence regions analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(SPIRVConvergenceRegionAnalysisWrapperPass,
                    "convergence-region", "SPIRV convergence regions analysis",
                    true, true)
39 | |
40 | namespace { |
41 | |
42 | template <typename BasicBlockType, typename IntrinsicInstType> |
43 | std::optional<IntrinsicInstType *> |
44 | getConvergenceTokenInternal(BasicBlockType *BB) { |
45 | static_assert(std::is_const_v<IntrinsicInstType> == |
46 | std::is_const_v<BasicBlockType>, |
47 | "Constness must match between input and output." ); |
48 | static_assert(std::is_same_v<BasicBlock, std::remove_const_t<BasicBlockType>>, |
49 | "Input must be a basic block." ); |
50 | static_assert( |
51 | std::is_same_v<IntrinsicInst, std::remove_const_t<IntrinsicInstType>>, |
52 | "Output type must be an intrinsic instruction." ); |
53 | |
54 | for (auto &I : *BB) { |
55 | if (auto *CI = dyn_cast<ConvergenceControlInst>(&I)) { |
56 | // Make sure that the anchor or entry intrinsics did not reach here with a |
57 | // parent token. This should have failed the verifier. |
58 | assert(CI->isLoop() || |
59 | !CI->getOperandBundle(LLVMContext::OB_convergencectrl)); |
60 | return CI; |
61 | } |
62 | |
63 | if (auto *CI = dyn_cast<CallInst>(&I)) { |
64 | auto OB = CI->getOperandBundle(LLVMContext::OB_convergencectrl); |
65 | if (!OB.has_value()) |
66 | continue; |
67 | return dyn_cast<IntrinsicInst>(OB.value().Inputs[0]); |
68 | } |
69 | } |
70 | |
71 | return std::nullopt; |
72 | } |
73 | } // anonymous namespace |
74 | |
75 | // Given a ConvergenceRegion tree with |Start| as its root, finds the smallest |
76 | // region |Entry| belongs to. If |Entry| does not belong to the region defined |
77 | // by |Start|, this function returns |nullptr|. |
78 | static ConvergenceRegion *findParentRegion(ConvergenceRegion *Start, |
79 | BasicBlock *Entry) { |
80 | ConvergenceRegion *Candidate = nullptr; |
81 | ConvergenceRegion *NextCandidate = Start; |
82 | |
83 | while (Candidate != NextCandidate && NextCandidate != nullptr) { |
84 | Candidate = NextCandidate; |
85 | NextCandidate = nullptr; |
86 | |
87 | // End of the search, we can return. |
88 | if (Candidate->Children.size() == 0) |
89 | return Candidate; |
90 | |
91 | for (auto *Child : Candidate->Children) { |
92 | if (Child->Blocks.count(Ptr: Entry) != 0) { |
93 | NextCandidate = Child; |
94 | break; |
95 | } |
96 | } |
97 | } |
98 | |
99 | return Candidate; |
100 | } |
101 | |
102 | std::optional<IntrinsicInst *> |
103 | llvm::SPIRV::getConvergenceToken(BasicBlock *BB) { |
104 | return getConvergenceTokenInternal<BasicBlock, IntrinsicInst>(BB); |
105 | } |
106 | |
107 | std::optional<const IntrinsicInst *> |
108 | llvm::SPIRV::getConvergenceToken(const BasicBlock *BB) { |
109 | return getConvergenceTokenInternal<const BasicBlock, const IntrinsicInst>(BB); |
110 | } |
111 | |
112 | ConvergenceRegion::ConvergenceRegion(DominatorTree &DT, LoopInfo &LI, |
113 | Function &F) |
114 | : DT(DT), LI(LI), Parent(nullptr) { |
115 | Entry = &F.getEntryBlock(); |
116 | ConvergenceToken = getConvergenceToken(BB: Entry); |
117 | for (auto &B : F) { |
118 | Blocks.insert(Ptr: &B); |
119 | if (isa<ReturnInst>(Val: B.getTerminator())) |
120 | Exits.insert(Ptr: &B); |
121 | } |
122 | } |
123 | |
124 | ConvergenceRegion::ConvergenceRegion( |
125 | DominatorTree &DT, LoopInfo &LI, |
126 | std::optional<IntrinsicInst *> ConvergenceToken, BasicBlock *Entry, |
127 | SmallPtrSet<BasicBlock *, 8> &&Blocks, SmallPtrSet<BasicBlock *, 2> &&Exits) |
128 | : DT(DT), LI(LI), ConvergenceToken(ConvergenceToken), Entry(Entry), |
129 | Exits(std::move(Exits)), Blocks(std::move(Blocks)) { |
130 | for ([[maybe_unused]] auto *BB : this->Exits) |
131 | assert(this->Blocks.count(BB) != 0); |
132 | assert(this->Blocks.count(this->Entry) != 0); |
133 | } |
134 | |
135 | void ConvergenceRegion::releaseMemory() { |
136 | // Parent memory is owned by the parent. |
137 | Parent = nullptr; |
138 | for (auto *Child : Children) { |
139 | Child->releaseMemory(); |
140 | delete Child; |
141 | } |
142 | Children.resize(N: 0); |
143 | } |
144 | |
145 | void ConvergenceRegion::dump(const unsigned IndentSize) const { |
146 | const std::string Indent(IndentSize, '\t'); |
147 | dbgs() << Indent << this << ": {\n" ; |
148 | dbgs() << Indent << " Parent: " << Parent << "\n" ; |
149 | |
150 | if (ConvergenceToken.value_or(u: nullptr)) { |
151 | dbgs() << Indent |
152 | << " ConvergenceToken: " << ConvergenceToken.value()->getName() |
153 | << "\n" ; |
154 | } |
155 | |
156 | if (Entry->getName() != "" ) |
157 | dbgs() << Indent << " Entry: " << Entry->getName() << "\n" ; |
158 | else |
159 | dbgs() << Indent << " Entry: " << Entry << "\n" ; |
160 | |
161 | dbgs() << Indent << " Exits: { " ; |
162 | for (const auto &Exit : Exits) { |
163 | if (Exit->getName() != "" ) |
164 | dbgs() << Exit->getName() << ", " ; |
165 | else |
166 | dbgs() << Exit << ", " ; |
167 | } |
168 | dbgs() << " }\n" ; |
169 | |
170 | dbgs() << Indent << " Blocks: { " ; |
171 | for (const auto &Block : Blocks) { |
172 | if (Block->getName() != "" ) |
173 | dbgs() << Block->getName() << ", " ; |
174 | else |
175 | dbgs() << Block << ", " ; |
176 | } |
177 | dbgs() << " }\n" ; |
178 | |
179 | dbgs() << Indent << " Children: {\n" ; |
180 | for (const auto Child : Children) |
181 | Child->dump(IndentSize: IndentSize + 2); |
182 | dbgs() << Indent << " }\n" ; |
183 | |
184 | dbgs() << Indent << "}\n" ; |
185 | } |
186 | |
187 | namespace { |
188 | class ConvergenceRegionAnalyzer { |
189 | public: |
190 | ConvergenceRegionAnalyzer(Function &F, DominatorTree &DT, LoopInfo &LI) |
191 | : DT(DT), LI(LI), F(F) {} |
192 | |
193 | private: |
194 | bool isBackEdge(const BasicBlock *From, const BasicBlock *To) const { |
195 | if (From == To) |
196 | return true; |
197 | |
198 | // We only handle loop in the simplified form. This means: |
199 | // - a single back-edge, a single latch. |
200 | // - meaning the back-edge target can only be the loop header. |
201 | // - meaning the From can only be the loop latch. |
202 | if (!LI.isLoopHeader(BB: To)) |
203 | return false; |
204 | |
205 | auto *L = LI.getLoopFor(BB: To); |
206 | if (L->contains(BB: From) && L->isLoopLatch(BB: From)) |
207 | return true; |
208 | |
209 | return false; |
210 | } |
211 | |
212 | std::unordered_set<BasicBlock *> |
213 | findPathsToMatch(LoopInfo &LI, BasicBlock *From, |
214 | std::function<bool(const BasicBlock *)> isMatch) const { |
215 | std::unordered_set<BasicBlock *> Output; |
216 | |
217 | if (isMatch(From)) |
218 | Output.insert(x: From); |
219 | |
220 | auto *Terminator = From->getTerminator(); |
221 | for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) { |
222 | auto *To = Terminator->getSuccessor(Idx: i); |
223 | // Ignore back edges. |
224 | if (isBackEdge(From, To)) |
225 | continue; |
226 | |
227 | auto ChildSet = findPathsToMatch(LI, From: To, isMatch); |
228 | if (ChildSet.size() == 0) |
229 | continue; |
230 | |
231 | Output.insert(first: ChildSet.begin(), last: ChildSet.end()); |
232 | Output.insert(x: From); |
233 | if (LI.isLoopHeader(BB: From)) { |
234 | auto *L = LI.getLoopFor(BB: From); |
235 | for (auto *BB : L->getBlocks()) { |
236 | Output.insert(x: BB); |
237 | } |
238 | } |
239 | } |
240 | |
241 | return Output; |
242 | } |
243 | |
244 | SmallPtrSet<BasicBlock *, 2> |
245 | findExitNodes(const SmallPtrSetImpl<BasicBlock *> &RegionBlocks) { |
246 | SmallPtrSet<BasicBlock *, 2> Exits; |
247 | |
248 | for (auto *B : RegionBlocks) { |
249 | auto *Terminator = B->getTerminator(); |
250 | for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) { |
251 | auto *Child = Terminator->getSuccessor(Idx: i); |
252 | if (RegionBlocks.count(Ptr: Child) == 0) |
253 | Exits.insert(Ptr: B); |
254 | } |
255 | } |
256 | |
257 | return Exits; |
258 | } |
259 | |
260 | public: |
261 | ConvergenceRegionInfo analyze() { |
262 | ConvergenceRegion *TopLevelRegion = new ConvergenceRegion(DT, LI, F); |
263 | std::queue<Loop *> ToProcess; |
264 | for (auto *L : LI.getLoopsInPreorder()) |
265 | ToProcess.push(x: L); |
266 | |
267 | while (ToProcess.size() != 0) { |
268 | auto *L = ToProcess.front(); |
269 | ToProcess.pop(); |
270 | |
271 | auto CT = getConvergenceToken(BB: L->getHeader()); |
272 | SmallPtrSet<BasicBlock *, 8> RegionBlocks(llvm::from_range, L->blocks()); |
273 | SmallVector<BasicBlock *> LoopExits; |
274 | L->getExitingBlocks(ExitingBlocks&: LoopExits); |
275 | if (CT.has_value()) { |
276 | for (auto *Exit : LoopExits) { |
277 | auto N = findPathsToMatch(LI, From: Exit, isMatch: [&CT](const BasicBlock *block) { |
278 | auto Token = getConvergenceToken(BB: block); |
279 | if (Token == std::nullopt) |
280 | return false; |
281 | return Token.value() == CT.value(); |
282 | }); |
283 | RegionBlocks.insert_range(R&: N); |
284 | } |
285 | } |
286 | |
287 | auto RegionExits = findExitNodes(RegionBlocks); |
288 | ConvergenceRegion *Region = new ConvergenceRegion( |
289 | DT, LI, CT, L->getHeader(), std::move(RegionBlocks), |
290 | std::move(RegionExits)); |
291 | Region->Parent = findParentRegion(Start: TopLevelRegion, Entry: Region->Entry); |
292 | assert(Region->Parent != nullptr && "This is impossible." ); |
293 | Region->Parent->Children.push_back(Elt: Region); |
294 | } |
295 | |
296 | return ConvergenceRegionInfo(TopLevelRegion); |
297 | } |
298 | |
299 | private: |
300 | DominatorTree &DT; |
301 | LoopInfo &LI; |
302 | Function &F; |
303 | }; |
304 | } // anonymous namespace |
305 | |
306 | ConvergenceRegionInfo llvm::SPIRV::getConvergenceRegions(Function &F, |
307 | DominatorTree &DT, |
308 | LoopInfo &LI) { |
309 | ConvergenceRegionAnalyzer Analyzer(F, DT, LI); |
310 | return Analyzer.analyze(); |
311 | } |
312 | |
313 | char SPIRVConvergenceRegionAnalysisWrapperPass::ID = 0; |
314 | |
315 | SPIRVConvergenceRegionAnalysisWrapperPass:: |
316 | SPIRVConvergenceRegionAnalysisWrapperPass() |
317 | : FunctionPass(ID) {} |
318 | |
319 | bool SPIRVConvergenceRegionAnalysisWrapperPass::runOnFunction(Function &F) { |
320 | DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); |
321 | LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); |
322 | |
323 | CRI = SPIRV::getConvergenceRegions(F, DT, LI); |
324 | // Nothing was modified. |
325 | return false; |
326 | } |
327 | |
328 | SPIRVConvergenceRegionAnalysis::Result |
329 | SPIRVConvergenceRegionAnalysis::run(Function &F, FunctionAnalysisManager &AM) { |
330 | Result CRI; |
331 | auto &DT = AM.getResult<DominatorTreeAnalysis>(IR&: F); |
332 | auto &LI = AM.getResult<LoopAnalysis>(IR&: F); |
333 | CRI = SPIRV::getConvergenceRegions(F, DT, LI); |
334 | return CRI; |
335 | } |
336 | |
337 | AnalysisKey SPIRVConvergenceRegionAnalysis::Key; |
338 | |