1 | //===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// Merge the multiple exit targets of a convergence region into a single block.
// Each exit target is assigned a constant value. Every exiting block stores the
// constant of the target it branches to into a variable allocated in the entry
// block; the new exit block loads that variable and switches on it to re-route
// control to the correct basic block.
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #include "Analysis/SPIRVConvergenceRegionAnalysis.h" |
16 | #include "SPIRV.h" |
17 | #include "SPIRVSubtarget.h" |
18 | #include "SPIRVUtils.h" |
19 | #include "llvm/ADT/DenseMap.h" |
20 | #include "llvm/ADT/SmallPtrSet.h" |
21 | #include "llvm/Analysis/LoopInfo.h" |
22 | #include "llvm/CodeGen/IntrinsicLowering.h" |
23 | #include "llvm/IR/Dominators.h" |
24 | #include "llvm/IR/IRBuilder.h" |
25 | #include "llvm/IR/Intrinsics.h" |
26 | #include "llvm/InitializePasses.h" |
27 | #include "llvm/Transforms/Utils/Cloning.h" |
28 | #include "llvm/Transforms/Utils/LoopSimplify.h" |
29 | #include "llvm/Transforms/Utils/LowerMemIntrinsics.h" |
30 | |
31 | using namespace llvm; |
32 | |
33 | namespace { |
34 | |
35 | class SPIRVMergeRegionExitTargets : public FunctionPass { |
36 | public: |
37 | static char ID; |
38 | |
39 | SPIRVMergeRegionExitTargets() : FunctionPass(ID) {} |
40 | |
41 | // Gather all the successors of |BB|. |
42 | // This function asserts if the terminator neither a branch, switch or return. |
43 | std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) { |
44 | std::unordered_set<BasicBlock *> output; |
45 | auto *T = BB->getTerminator(); |
46 | |
47 | if (auto *BI = dyn_cast<BranchInst>(Val: T)) { |
48 | output.insert(x: BI->getSuccessor(i: 0)); |
49 | if (BI->isConditional()) |
50 | output.insert(x: BI->getSuccessor(i: 1)); |
51 | return output; |
52 | } |
53 | |
54 | if (auto *SI = dyn_cast<SwitchInst>(Val: T)) { |
55 | output.insert(x: SI->getDefaultDest()); |
56 | for (auto &Case : SI->cases()) |
57 | output.insert(x: Case.getCaseSuccessor()); |
58 | return output; |
59 | } |
60 | |
61 | assert(isa<ReturnInst>(T) && "Unhandled terminator type." ); |
62 | return output; |
63 | } |
64 | |
65 | /// Create a value in BB set to the value associated with the branch the block |
66 | /// terminator will take. |
67 | llvm::Value *createExitVariable( |
68 | BasicBlock *BB, |
69 | const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) { |
70 | auto *T = BB->getTerminator(); |
71 | if (isa<ReturnInst>(Val: T)) |
72 | return nullptr; |
73 | |
74 | IRBuilder<> Builder(BB); |
75 | Builder.SetInsertPoint(T); |
76 | |
77 | if (auto *BI = dyn_cast<BranchInst>(Val: T)) { |
78 | |
79 | BasicBlock *LHSTarget = BI->getSuccessor(i: 0); |
80 | BasicBlock *RHSTarget = |
81 | BI->isConditional() ? BI->getSuccessor(i: 1) : nullptr; |
82 | |
83 | Value *LHS = TargetToValue.lookup(Val: LHSTarget); |
84 | Value *RHS = TargetToValue.lookup(Val: RHSTarget); |
85 | |
86 | if (LHS == nullptr || RHS == nullptr) |
87 | return LHS == nullptr ? RHS : LHS; |
88 | return Builder.CreateSelect(C: BI->getCondition(), True: LHS, False: RHS); |
89 | } |
90 | |
91 | // TODO: add support for switch cases. |
92 | llvm_unreachable("Unhandled terminator type." ); |
93 | } |
94 | |
95 | /// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|. |
96 | void replaceBranchTargets(BasicBlock *BB, |
97 | const SmallPtrSet<BasicBlock *, 4> &ToReplace, |
98 | BasicBlock *NewTarget) { |
99 | auto *T = BB->getTerminator(); |
100 | if (isa<ReturnInst>(Val: T)) |
101 | return; |
102 | |
103 | if (auto *BI = dyn_cast<BranchInst>(Val: T)) { |
104 | for (size_t i = 0; i < BI->getNumSuccessors(); i++) { |
105 | if (ToReplace.count(Ptr: BI->getSuccessor(i)) != 0) |
106 | BI->setSuccessor(idx: i, NewSucc: NewTarget); |
107 | } |
108 | return; |
109 | } |
110 | |
111 | if (auto *SI = dyn_cast<SwitchInst>(Val: T)) { |
112 | for (size_t i = 0; i < SI->getNumSuccessors(); i++) { |
113 | if (ToReplace.count(Ptr: SI->getSuccessor(idx: i)) != 0) |
114 | SI->setSuccessor(idx: i, NewSucc: NewTarget); |
115 | } |
116 | return; |
117 | } |
118 | |
119 | assert(false && "Unhandled terminator type." ); |
120 | } |
121 | |
122 | AllocaInst *CreateVariable(Function &F, Type *Type, |
123 | BasicBlock::iterator Position) { |
124 | const DataLayout &DL = F.getDataLayout(); |
125 | return new AllocaInst(Type, DL.getAllocaAddrSpace(), nullptr, "reg" , |
126 | Position); |
127 | } |
128 | |
129 | // Run the pass on the given convergence region, ignoring the sub-regions. |
130 | // Returns true if the CFG changed, false otherwise. |
131 | bool runOnConvergenceRegionNoRecurse(LoopInfo &LI, |
132 | SPIRV::ConvergenceRegion *CR) { |
133 | // Gather all the exit targets for this region. |
134 | SmallPtrSet<BasicBlock *, 4> ExitTargets; |
135 | for (BasicBlock *Exit : CR->Exits) { |
136 | for (BasicBlock *Target : gatherSuccessors(BB: Exit)) { |
137 | if (CR->Blocks.count(Ptr: Target) == 0) |
138 | ExitTargets.insert(Ptr: Target); |
139 | } |
140 | } |
141 | |
142 | // If we have zero or one exit target, nothing do to. |
143 | if (ExitTargets.size() <= 1) |
144 | return false; |
145 | |
146 | // Create the new single exit target. |
147 | auto F = CR->Entry->getParent(); |
148 | auto NewExitTarget = BasicBlock::Create(Context&: F->getContext(), Name: "new.exit" , Parent: F); |
149 | IRBuilder<> Builder(NewExitTarget); |
150 | |
151 | AllocaInst *Variable = CreateVariable(F&: *F, Type: Builder.getInt32Ty(), |
152 | Position: F->begin()->getFirstInsertionPt()); |
153 | |
154 | // CodeGen output needs to be stable. Using the set as-is would order |
155 | // the targets differently depending on the allocation pattern. |
156 | // Sorting per basic-block ordering in the function. |
157 | std::vector<BasicBlock *> SortedExitTargets; |
158 | std::vector<BasicBlock *> SortedExits; |
159 | for (BasicBlock &BB : *F) { |
160 | if (ExitTargets.count(Ptr: &BB) != 0) |
161 | SortedExitTargets.push_back(x: &BB); |
162 | if (CR->Exits.count(Ptr: &BB) != 0) |
163 | SortedExits.push_back(x: &BB); |
164 | } |
165 | |
166 | // Creating one constant per distinct exit target. This will be route to the |
167 | // correct target. |
168 | DenseMap<BasicBlock *, ConstantInt *> TargetToValue; |
169 | for (BasicBlock *Target : SortedExitTargets) |
170 | TargetToValue.insert( |
171 | KV: std::make_pair(x&: Target, y: Builder.getInt32(C: TargetToValue.size()))); |
172 | |
173 | // Creating one variable per exit node, set to the constant matching the |
174 | // targeted external block. |
175 | std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable; |
176 | for (auto Exit : SortedExits) { |
177 | llvm::Value *Value = createExitVariable(BB: Exit, TargetToValue); |
178 | IRBuilder<> B2(Exit); |
179 | B2.SetInsertPoint(Exit->getFirstInsertionPt()); |
180 | B2.CreateStore(Val: Value, Ptr: Variable); |
181 | ExitToVariable.emplace_back(args: std::make_pair(x&: Exit, y&: Value)); |
182 | } |
183 | |
184 | llvm::Value *Load = Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: Variable); |
185 | |
186 | // Creating the switch to jump to the correct exit target. |
187 | llvm::SwitchInst *Sw = Builder.CreateSwitch(V: Load, Dest: SortedExitTargets[0], |
188 | NumCases: SortedExitTargets.size() - 1); |
189 | for (size_t i = 1; i < SortedExitTargets.size(); i++) { |
190 | BasicBlock *BB = SortedExitTargets[i]; |
191 | Sw->addCase(OnVal: TargetToValue[BB], Dest: BB); |
192 | } |
193 | |
194 | // Fix exit branches to redirect to the new exit. |
195 | for (auto Exit : CR->Exits) |
196 | replaceBranchTargets(BB: Exit, ToReplace: ExitTargets, NewTarget: NewExitTarget); |
197 | |
198 | CR = CR->Parent; |
199 | while (CR) { |
200 | CR->Blocks.insert(Ptr: NewExitTarget); |
201 | CR = CR->Parent; |
202 | } |
203 | |
204 | return true; |
205 | } |
206 | |
207 | /// Run the pass on the given convergence region and sub-regions (DFS). |
208 | /// Returns true if a region/sub-region was modified, false otherwise. |
209 | /// This returns as soon as one region/sub-region has been modified. |
210 | bool runOnConvergenceRegion(LoopInfo &LI, SPIRV::ConvergenceRegion *CR) { |
211 | for (auto *Child : CR->Children) |
212 | if (runOnConvergenceRegion(LI, CR: Child)) |
213 | return true; |
214 | |
215 | return runOnConvergenceRegionNoRecurse(LI, CR); |
216 | } |
217 | |
218 | #if !NDEBUG |
219 | /// Validates each edge exiting the region has the same destination basic |
220 | /// block. |
221 | void validateRegionExits(const SPIRV::ConvergenceRegion *CR) { |
222 | for (auto *Child : CR->Children) |
223 | validateRegionExits(Child); |
224 | |
225 | std::unordered_set<BasicBlock *> ExitTargets; |
226 | for (auto *Exit : CR->Exits) { |
227 | auto Set = gatherSuccessors(Exit); |
228 | for (auto *BB : Set) { |
229 | if (CR->Blocks.count(BB) == 0) |
230 | ExitTargets.insert(BB); |
231 | } |
232 | } |
233 | |
234 | assert(ExitTargets.size() <= 1); |
235 | } |
236 | #endif |
237 | |
238 | virtual bool runOnFunction(Function &F) override { |
239 | LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); |
240 | auto *TopLevelRegion = |
241 | getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>() |
242 | .getRegionInfo() |
243 | .getWritableTopLevelRegion(); |
244 | |
245 | // FIXME: very inefficient method: each time a region is modified, we bubble |
246 | // back up, and recompute the whole convergence region tree. Once the |
247 | // algorithm is completed and test coverage good enough, rewrite this pass |
248 | // to be efficient instead of simple. |
249 | bool modified = false; |
250 | while (runOnConvergenceRegion(LI, CR: TopLevelRegion)) { |
251 | modified = true; |
252 | } |
253 | |
254 | #if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS) |
255 | validateRegionExits(TopLevelRegion); |
256 | #endif |
257 | return modified; |
258 | } |
259 | |
260 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
261 | AU.addRequired<DominatorTreeWrapperPass>(); |
262 | AU.addRequired<LoopInfoWrapperPass>(); |
263 | AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>(); |
264 | |
265 | AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>(); |
266 | FunctionPass::getAnalysisUsage(AU); |
267 | } |
268 | }; |
269 | } // namespace |
270 | |
// Definition of the pass ID handed to the FunctionPass base constructor.
char SPIRVMergeRegionExitTargets::ID = 0;

// Legacy pass-registry boilerplate: declares the pass under the argument
// "split-region-exit-blocks" together with the passes it depends on.
// NOTE(review): the argument and description say "split", but the pass merges
// region exit targets into one block. Renaming the argument would change its
// command-line name, so both strings are left as-is.
// NOTE(review): LoopSimplify is listed as a dependency here but is not
// requested in getAnalysisUsage.
INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
                      "SPIRV split region exit blocks", false, false)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass)

INITIALIZE_PASS_END(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
                    "SPIRV split region exit blocks", false, false)
282 | |
283 | FunctionPass *llvm::createSPIRVMergeRegionExitTargetsPass() { |
284 | return new SPIRVMergeRegionExitTargets(); |
285 | } |
286 | |