1//===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Merge the multiple exit targets of a convergence region into a single block.
// Each exit target is assigned a constant value. Every exiting block stores the
// value of the target it branches to into a function-local variable; the new
// exit block loads that variable and switches on it to re-route control to the
// correct basic block.
12//
13//===----------------------------------------------------------------------===//
14
15#include "Analysis/SPIRVConvergenceRegionAnalysis.h"
16#include "SPIRV.h"
17#include "SPIRVSubtarget.h"
18#include "SPIRVUtils.h"
19#include "llvm/ADT/DenseMap.h"
20#include "llvm/ADT/SmallPtrSet.h"
21#include "llvm/Analysis/LoopInfo.h"
22#include "llvm/IR/Dominators.h"
23#include "llvm/IR/IRBuilder.h"
24#include "llvm/IR/Intrinsics.h"
25#include "llvm/InitializePasses.h"
26#include "llvm/Transforms/Utils/Cloning.h"
27#include "llvm/Transforms/Utils/LoopSimplify.h"
28#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
29
30using namespace llvm;
31
32namespace {
33
34class SPIRVMergeRegionExitTargets : public FunctionPass {
35public:
36 static char ID;
37
38 SPIRVMergeRegionExitTargets() : FunctionPass(ID) {}
39
40 /// Create a value in BB set to the value associated with the branch the block
41 /// terminator will take.
42 llvm::Value *createExitVariable(
43 BasicBlock *BB,
44 const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) {
45 auto *T = BB->getTerminator();
46 if (isa<ReturnInst>(Val: T))
47 return nullptr;
48 if (auto *BI = dyn_cast<UncondBrInst>(Val: T))
49 return TargetToValue.lookup(Val: BI->getSuccessor());
50
51 IRBuilder<> Builder(BB);
52 Builder.SetInsertPoint(T);
53
54 if (auto *BI = dyn_cast<CondBrInst>(Val: T)) {
55 Value *LHS = TargetToValue.lookup(Val: BI->getSuccessor(i: 0));
56 Value *RHS = TargetToValue.lookup(Val: BI->getSuccessor(i: 1));
57
58 if (LHS == nullptr || RHS == nullptr)
59 return LHS == nullptr ? RHS : LHS;
60 return Builder.CreateSelect(C: BI->getCondition(), True: LHS, False: RHS);
61 }
62
63 // TODO: add support for switch cases.
64 llvm_unreachable("Unhandled terminator type.");
65 }
66
67 AllocaInst *CreateVariable(Function &F, Type *Type,
68 BasicBlock::iterator Position) {
69 const DataLayout &DL = F.getDataLayout();
70 return new AllocaInst(Type, DL.getAllocaAddrSpace(), nullptr, "reg",
71 Position);
72 }
73
74 // Run the pass on the given convergence region, ignoring the sub-regions.
75 // Returns true if the CFG changed, false otherwise.
76 bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
77 SPIRV::ConvergenceRegion *CR) {
78 // Gather all the exit targets for this region.
79 SmallPtrSet<BasicBlock *, 4> ExitTargets;
80 for (BasicBlock *Exit : CR->Exits) {
81 for (BasicBlock *Target : successors(BB: Exit)) {
82 if (CR->Blocks.count(Ptr: Target) == 0)
83 ExitTargets.insert(Ptr: Target);
84 }
85 }
86
87 // If we have zero or one exit target, nothing do to.
88 if (ExitTargets.size() <= 1)
89 return false;
90
91 // Create the new single exit target.
92 auto F = CR->Entry->getParent();
93 auto NewExitTarget = BasicBlock::Create(Context&: F->getContext(), Name: "new.exit", Parent: F);
94 IRBuilder<> Builder(NewExitTarget);
95
96 AllocaInst *Variable = CreateVariable(F&: *F, Type: Builder.getInt32Ty(),
97 Position: F->begin()->getFirstInsertionPt());
98
99 // CodeGen output needs to be stable. Using the set as-is would order
100 // the targets differently depending on the allocation pattern.
101 // Sorting per basic-block ordering in the function.
102 std::vector<BasicBlock *> SortedExitTargets;
103 std::vector<BasicBlock *> SortedExits;
104 for (BasicBlock &BB : *F) {
105 if (ExitTargets.count(Ptr: &BB) != 0)
106 SortedExitTargets.push_back(x: &BB);
107 if (CR->Exits.count(Ptr: &BB) != 0)
108 SortedExits.push_back(x: &BB);
109 }
110
111 // Creating one constant per distinct exit target. This will be route to the
112 // correct target.
113 DenseMap<BasicBlock *, ConstantInt *> TargetToValue;
114 for (BasicBlock *Target : SortedExitTargets)
115 TargetToValue.insert(
116 KV: std::make_pair(x&: Target, y: Builder.getInt32(C: TargetToValue.size())));
117
118 // Creating one variable per exit node, set to the constant matching the
119 // targeted external block.
120 std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable;
121 for (auto Exit : SortedExits) {
122 llvm::Value *Value = createExitVariable(BB: Exit, TargetToValue);
123 IRBuilder<> B2(Exit);
124 B2.SetInsertPoint(Exit->getFirstInsertionPt());
125 B2.CreateStore(Val: Value, Ptr: Variable);
126 ExitToVariable.emplace_back(args: std::make_pair(x&: Exit, y&: Value));
127 }
128
129 llvm::Value *Load = Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: Variable);
130
131 // Creating the switch to jump to the correct exit target.
132 llvm::SwitchInst *Sw = Builder.CreateSwitch(V: Load, Dest: SortedExitTargets[0],
133 NumCases: SortedExitTargets.size() - 1);
134 for (size_t i = 1; i < SortedExitTargets.size(); i++) {
135 BasicBlock *BB = SortedExitTargets[i];
136 Sw->addCase(OnVal: TargetToValue[BB], Dest: BB);
137 }
138
139 // Fix exit branches to redirect to the new exit.
140 for (auto Exit : CR->Exits) {
141 Instruction *T = Exit->getTerminator();
142 for (auto I = succ_begin(I: T), E = succ_end(I: T); I != E; ++I)
143 if (ExitTargets.contains(Ptr: *I))
144 I.getUse()->set(NewExitTarget);
145 }
146
147 CR = CR->Parent;
148 while (CR) {
149 CR->Blocks.insert(Ptr: NewExitTarget);
150 CR = CR->Parent;
151 }
152
153 return true;
154 }
155
156 /// Run the pass on the given convergence region and sub-regions (DFS).
157 /// Returns true if a region/sub-region was modified, false otherwise.
158 /// This returns as soon as one region/sub-region has been modified.
159 bool runOnConvergenceRegion(LoopInfo &LI, SPIRV::ConvergenceRegion *CR) {
160 for (auto *Child : CR->Children)
161 if (runOnConvergenceRegion(LI, CR: Child))
162 return true;
163
164 return runOnConvergenceRegionNoRecurse(LI, CR);
165 }
166
167#if !NDEBUG
168 /// Validates each edge exiting the region has the same destination basic
169 /// block.
170 void validateRegionExits(const SPIRV::ConvergenceRegion *CR) {
171 for (auto *Child : CR->Children)
172 validateRegionExits(Child);
173
174 std::unordered_set<BasicBlock *> ExitTargets;
175 for (auto *Exit : CR->Exits) {
176 for (auto *BB : successors(Exit)) {
177 if (CR->Blocks.count(BB) == 0)
178 ExitTargets.insert(BB);
179 }
180 }
181
182 assert(ExitTargets.size() <= 1);
183 }
184#endif
185
186 bool runOnFunction(Function &F) override {
187 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
188 auto *TopLevelRegion =
189 getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
190 .getRegionInfo()
191 .getWritableTopLevelRegion();
192
193 // FIXME: very inefficient method: each time a region is modified, we bubble
194 // back up, and recompute the whole convergence region tree. Once the
195 // algorithm is completed and test coverage good enough, rewrite this pass
196 // to be efficient instead of simple.
197 bool modified = false;
198 while (runOnConvergenceRegion(LI, CR: TopLevelRegion)) {
199 modified = true;
200 }
201
202#if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS)
203 validateRegionExits(TopLevelRegion);
204#endif
205 return modified;
206 }
207
208 void getAnalysisUsage(AnalysisUsage &AU) const override {
209 AU.addRequired<DominatorTreeWrapperPass>();
210 AU.addRequired<LoopInfoWrapperPass>();
211 AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
212
213 AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>();
214 FunctionPass::getAnalysisUsage(AU);
215 }
216};
217} // namespace
218
// Pass identifier handed to the FunctionPass constructor of the class above.
char SPIRVMergeRegionExitTargets::ID = 0;

// Register the pass with the legacy pass registry under the command-line name
// "split-region-exit-blocks", listing the passes it depends on. The two `false`
// flags mark it as neither CFG-only nor an analysis pass.
// NOTE(review): the registered name and description ("split region exit
// blocks") do not match what the pass does (merging region exit targets).
// Renaming changes the command-line name, so confirm with users first.
// NOTE(review): LoopSimplify is declared as a dependency here but is not
// requested in getAnalysisUsage() — confirm it is still needed.
INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
                      "SPIRV split region exit blocks", false, false)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass)

INITIALIZE_PASS_END(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
                    "SPIRV split region exit blocks", false, false)
230
231FunctionPass *llvm::createSPIRVMergeRegionExitTargetsPass() {
232 return new SPIRVMergeRegionExitTargets();
233}
234