1//===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Merge the multiple exit targets of a convergence region into a single block.
// Each exit target is assigned a constant value. Every exiting block stores the
// value of the target it branches to into a stack variable, and a load + switch
// in the new exit block re-routes control to the correct basic block.
12//
13//===----------------------------------------------------------------------===//
14
15#include "Analysis/SPIRVConvergenceRegionAnalysis.h"
16#include "SPIRV.h"
17#include "SPIRVSubtarget.h"
18#include "SPIRVUtils.h"
19#include "llvm/ADT/DenseMap.h"
20#include "llvm/ADT/SmallPtrSet.h"
21#include "llvm/Analysis/LoopInfo.h"
22#include "llvm/IR/Dominators.h"
23#include "llvm/IR/IRBuilder.h"
24#include "llvm/IR/Intrinsics.h"
25#include "llvm/InitializePasses.h"
26#include "llvm/Transforms/Utils/Cloning.h"
27#include "llvm/Transforms/Utils/LoopSimplify.h"
28#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
29
30using namespace llvm;
31
32namespace {
33
34class SPIRVMergeRegionExitTargets : public FunctionPass {
35public:
36 static char ID;
37
38 SPIRVMergeRegionExitTargets() : FunctionPass(ID) {}
39
40 // Gather all the successors of |BB|.
41 // This function asserts if the terminator neither a branch, switch or return.
42 std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) {
43 std::unordered_set<BasicBlock *> output;
44 auto *T = BB->getTerminator();
45
46 if (auto *BI = dyn_cast<BranchInst>(Val: T)) {
47 output.insert(x: BI->getSuccessor(i: 0));
48 if (BI->isConditional())
49 output.insert(x: BI->getSuccessor(i: 1));
50 return output;
51 }
52
53 if (auto *SI = dyn_cast<SwitchInst>(Val: T)) {
54 output.insert(x: SI->getDefaultDest());
55 for (auto &Case : SI->cases())
56 output.insert(x: Case.getCaseSuccessor());
57 return output;
58 }
59
60 assert(isa<ReturnInst>(T) && "Unhandled terminator type.");
61 return output;
62 }
63
64 /// Create a value in BB set to the value associated with the branch the block
65 /// terminator will take.
66 llvm::Value *createExitVariable(
67 BasicBlock *BB,
68 const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) {
69 auto *T = BB->getTerminator();
70 if (isa<ReturnInst>(Val: T))
71 return nullptr;
72
73 IRBuilder<> Builder(BB);
74 Builder.SetInsertPoint(T);
75
76 if (auto *BI = dyn_cast<BranchInst>(Val: T)) {
77
78 BasicBlock *LHSTarget = BI->getSuccessor(i: 0);
79 BasicBlock *RHSTarget =
80 BI->isConditional() ? BI->getSuccessor(i: 1) : nullptr;
81
82 Value *LHS = TargetToValue.lookup(Val: LHSTarget);
83 Value *RHS = TargetToValue.lookup(Val: RHSTarget);
84
85 if (LHS == nullptr || RHS == nullptr)
86 return LHS == nullptr ? RHS : LHS;
87 return Builder.CreateSelect(C: BI->getCondition(), True: LHS, False: RHS);
88 }
89
90 // TODO: add support for switch cases.
91 llvm_unreachable("Unhandled terminator type.");
92 }
93
94 /// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|.
95 void replaceBranchTargets(BasicBlock *BB,
96 const SmallPtrSet<BasicBlock *, 4> &ToReplace,
97 BasicBlock *NewTarget) {
98 auto *T = BB->getTerminator();
99 if (isa<ReturnInst>(Val: T))
100 return;
101
102 if (auto *BI = dyn_cast<BranchInst>(Val: T)) {
103 for (size_t i = 0; i < BI->getNumSuccessors(); i++) {
104 if (ToReplace.count(Ptr: BI->getSuccessor(i)) != 0)
105 BI->setSuccessor(idx: i, NewSucc: NewTarget);
106 }
107 return;
108 }
109
110 if (auto *SI = dyn_cast<SwitchInst>(Val: T)) {
111 for (size_t i = 0; i < SI->getNumSuccessors(); i++) {
112 if (ToReplace.count(Ptr: SI->getSuccessor(idx: i)) != 0)
113 SI->setSuccessor(idx: i, NewSucc: NewTarget);
114 }
115 return;
116 }
117
118 assert(false && "Unhandled terminator type.");
119 }
120
121 AllocaInst *CreateVariable(Function &F, Type *Type,
122 BasicBlock::iterator Position) {
123 const DataLayout &DL = F.getDataLayout();
124 return new AllocaInst(Type, DL.getAllocaAddrSpace(), nullptr, "reg",
125 Position);
126 }
127
128 // Run the pass on the given convergence region, ignoring the sub-regions.
129 // Returns true if the CFG changed, false otherwise.
130 bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
131 SPIRV::ConvergenceRegion *CR) {
132 // Gather all the exit targets for this region.
133 SmallPtrSet<BasicBlock *, 4> ExitTargets;
134 for (BasicBlock *Exit : CR->Exits) {
135 for (BasicBlock *Target : gatherSuccessors(BB: Exit)) {
136 if (CR->Blocks.count(Ptr: Target) == 0)
137 ExitTargets.insert(Ptr: Target);
138 }
139 }
140
141 // If we have zero or one exit target, nothing do to.
142 if (ExitTargets.size() <= 1)
143 return false;
144
145 // Create the new single exit target.
146 auto F = CR->Entry->getParent();
147 auto NewExitTarget = BasicBlock::Create(Context&: F->getContext(), Name: "new.exit", Parent: F);
148 IRBuilder<> Builder(NewExitTarget);
149
150 AllocaInst *Variable = CreateVariable(F&: *F, Type: Builder.getInt32Ty(),
151 Position: F->begin()->getFirstInsertionPt());
152
153 // CodeGen output needs to be stable. Using the set as-is would order
154 // the targets differently depending on the allocation pattern.
155 // Sorting per basic-block ordering in the function.
156 std::vector<BasicBlock *> SortedExitTargets;
157 std::vector<BasicBlock *> SortedExits;
158 for (BasicBlock &BB : *F) {
159 if (ExitTargets.count(Ptr: &BB) != 0)
160 SortedExitTargets.push_back(x: &BB);
161 if (CR->Exits.count(Ptr: &BB) != 0)
162 SortedExits.push_back(x: &BB);
163 }
164
165 // Creating one constant per distinct exit target. This will be route to the
166 // correct target.
167 DenseMap<BasicBlock *, ConstantInt *> TargetToValue;
168 for (BasicBlock *Target : SortedExitTargets)
169 TargetToValue.insert(
170 KV: std::make_pair(x&: Target, y: Builder.getInt32(C: TargetToValue.size())));
171
172 // Creating one variable per exit node, set to the constant matching the
173 // targeted external block.
174 std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable;
175 for (auto Exit : SortedExits) {
176 llvm::Value *Value = createExitVariable(BB: Exit, TargetToValue);
177 IRBuilder<> B2(Exit);
178 B2.SetInsertPoint(Exit->getFirstInsertionPt());
179 B2.CreateStore(Val: Value, Ptr: Variable);
180 ExitToVariable.emplace_back(args: std::make_pair(x&: Exit, y&: Value));
181 }
182
183 llvm::Value *Load = Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: Variable);
184
185 // Creating the switch to jump to the correct exit target.
186 llvm::SwitchInst *Sw = Builder.CreateSwitch(V: Load, Dest: SortedExitTargets[0],
187 NumCases: SortedExitTargets.size() - 1);
188 for (size_t i = 1; i < SortedExitTargets.size(); i++) {
189 BasicBlock *BB = SortedExitTargets[i];
190 Sw->addCase(OnVal: TargetToValue[BB], Dest: BB);
191 }
192
193 // Fix exit branches to redirect to the new exit.
194 for (auto Exit : CR->Exits)
195 replaceBranchTargets(BB: Exit, ToReplace: ExitTargets, NewTarget: NewExitTarget);
196
197 CR = CR->Parent;
198 while (CR) {
199 CR->Blocks.insert(Ptr: NewExitTarget);
200 CR = CR->Parent;
201 }
202
203 return true;
204 }
205
206 /// Run the pass on the given convergence region and sub-regions (DFS).
207 /// Returns true if a region/sub-region was modified, false otherwise.
208 /// This returns as soon as one region/sub-region has been modified.
209 bool runOnConvergenceRegion(LoopInfo &LI, SPIRV::ConvergenceRegion *CR) {
210 for (auto *Child : CR->Children)
211 if (runOnConvergenceRegion(LI, CR: Child))
212 return true;
213
214 return runOnConvergenceRegionNoRecurse(LI, CR);
215 }
216
217#if !NDEBUG
218 /// Validates each edge exiting the region has the same destination basic
219 /// block.
220 void validateRegionExits(const SPIRV::ConvergenceRegion *CR) {
221 for (auto *Child : CR->Children)
222 validateRegionExits(Child);
223
224 std::unordered_set<BasicBlock *> ExitTargets;
225 for (auto *Exit : CR->Exits) {
226 auto Set = gatherSuccessors(Exit);
227 for (auto *BB : Set) {
228 if (CR->Blocks.count(BB) == 0)
229 ExitTargets.insert(BB);
230 }
231 }
232
233 assert(ExitTargets.size() <= 1);
234 }
235#endif
236
237 bool runOnFunction(Function &F) override {
238 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
239 auto *TopLevelRegion =
240 getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
241 .getRegionInfo()
242 .getWritableTopLevelRegion();
243
244 // FIXME: very inefficient method: each time a region is modified, we bubble
245 // back up, and recompute the whole convergence region tree. Once the
246 // algorithm is completed and test coverage good enough, rewrite this pass
247 // to be efficient instead of simple.
248 bool modified = false;
249 while (runOnConvergenceRegion(LI, CR: TopLevelRegion)) {
250 modified = true;
251 }
252
253#if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS)
254 validateRegionExits(TopLevelRegion);
255#endif
256 return modified;
257 }
258
259 void getAnalysisUsage(AnalysisUsage &AU) const override {
260 AU.addRequired<DominatorTreeWrapperPass>();
261 AU.addRequired<LoopInfoWrapperPass>();
262 AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
263
264 AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>();
265 FunctionPass::getAnalysisUsage(AU);
266 }
267};
268} // namespace
269
// Pass identifier; it is handed to the FunctionPass base constructor in
// SPIRVMergeRegionExitTargets().
char SPIRVMergeRegionExitTargets::ID = 0;

// Legacy pass-manager registration.
// NOTE(review): the registered command-line name/description
// ("split-region-exit-blocks") do not match the class name
// (MergeRegionExitTargets); kept as-is since external tests may rely on it.
// NOTE(review): LoopSimplify is declared as a dependency here but is not
// requested in getAnalysisUsage() -- confirm whether it is still needed.
INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
                      "SPIRV split region exit blocks", false, false)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass)

INITIALIZE_PASS_END(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
                    "SPIRV split region exit blocks", false, false)
281
282FunctionPass *llvm::createSPIRVMergeRegionExitTargetsPass() {
283 return new SPIRVMergeRegionExitTargets();
284}
285