1//===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
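// Merge the multiple exit targets of a convergence region into a single block.
// Each exit target is assigned a distinct i32 constant. Every exiting block
// stores the constant of the target it branches to into a function-local
// variable, and the new exit block loads that variable and switches on it to
// re-route control to the correct basic block.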
12//
13//===----------------------------------------------------------------------===//
14
15#include "SPIRVMergeRegionExitTargets.h"
16#include "Analysis/SPIRVConvergenceRegionAnalysis.h"
17#include "SPIRV.h"
18#include "SPIRVSubtarget.h"
19#include "SPIRVUtils.h"
20#include "llvm/ADT/DenseMap.h"
21#include "llvm/ADT/SmallPtrSet.h"
22#include "llvm/Analysis/LoopInfo.h"
23#include "llvm/IR/Dominators.h"
24#include "llvm/IR/IRBuilder.h"
25#include "llvm/IR/Intrinsics.h"
26#include "llvm/InitializePasses.h"
27#include "llvm/Transforms/Utils/Cloning.h"
28#include "llvm/Transforms/Utils/LoopSimplify.h"
29#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
30
31using namespace llvm;
32
33namespace {
34
35// Run the pass on the given convergence region, ignoring the sub-regions.
36// Returns true if the CFG changed, false otherwise.
37static bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
38 SPIRV::ConvergenceRegion *CR) {
39 // Gather all the exit targets for this region.
40 SmallPtrSet<BasicBlock *, 4> ExitTargets;
41 for (BasicBlock *Exit : CR->Exits) {
42 for (BasicBlock *Target : successors(BB: Exit)) {
43 if (CR->Blocks.count(Ptr: Target) == 0)
44 ExitTargets.insert(Ptr: Target);
45 }
46 }
47
48 // If we have zero or one exit target, nothing do to.
49 if (ExitTargets.size() <= 1)
50 return false;
51
52 // Create the new single exit target.
53 auto F = CR->Entry->getParent();
54 auto NewExitTarget = BasicBlock::Create(Context&: F->getContext(), Name: "new.exit", Parent: F);
55 IRBuilder<> Builder(NewExitTarget);
56
57 AllocaInst *Variable = createVariable(F&: *F, Type: Builder.getInt32Ty());
58
59 // CodeGen output needs to be stable. Using the set as-is would order
60 // the targets differently depending on the allocation pattern.
61 // Sorting per basic-block ordering in the function.
62 std::vector<BasicBlock *> SortedExitTargets;
63 std::vector<BasicBlock *> SortedExits;
64 for (BasicBlock &BB : *F) {
65 if (ExitTargets.count(Ptr: &BB) != 0)
66 SortedExitTargets.push_back(x: &BB);
67 if (CR->Exits.count(Ptr: &BB) != 0)
68 SortedExits.push_back(x: &BB);
69 }
70
71 // Creating one constant per distinct exit target. This will be route to the
72 // correct target.
73 DenseMap<BasicBlock *, ConstantInt *> TargetToValue;
74 for (BasicBlock *Target : SortedExitTargets)
75 TargetToValue.insert(
76 KV: std::make_pair(x&: Target, y: Builder.getInt32(C: TargetToValue.size())));
77
78 // Creating one variable per exit node, set to the constant matching the
79 // targeted external block.
80 std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable;
81 for (auto Exit : SortedExits) {
82 llvm::Value *Value = createExitVariable(BB: Exit, TargetToValue);
83 IRBuilder<> B2(Exit);
84 B2.SetInsertPoint(Exit->getFirstInsertionPt());
85 B2.CreateStore(Val: Value, Ptr: Variable);
86 ExitToVariable.emplace_back(args: std::make_pair(x&: Exit, y&: Value));
87 }
88
89 llvm::Value *Load = Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: Variable);
90
91 // Creating the switch to jump to the correct exit target.
92 llvm::SwitchInst *Sw = Builder.CreateSwitch(V: Load, Dest: SortedExitTargets[0],
93 NumCases: SortedExitTargets.size() - 1);
94 for (size_t i = 1; i < SortedExitTargets.size(); i++) {
95 BasicBlock *BB = SortedExitTargets[i];
96 Sw->addCase(OnVal: TargetToValue[BB], Dest: BB);
97 }
98
99 // Fix exit branches to redirect to the new exit.
100 for (auto Exit : CR->Exits) {
101 Instruction *T = Exit->getTerminator();
102 for (auto I = succ_begin(I: T), E = succ_end(I: T); I != E; ++I)
103 if (ExitTargets.contains(Ptr: *I))
104 I.getUse()->set(NewExitTarget);
105 }
106
107 CR = CR->Parent;
108 while (CR) {
109 CR->Blocks.insert(Ptr: NewExitTarget);
110 CR = CR->Parent;
111 }
112
113 return true;
114}
115
116/// Run the pass on the given convergence region and sub-regions (DFS).
117/// Returns true if a region/sub-region was modified, false otherwise.
118/// This returns as soon as one region/sub-region has been modified.
119static bool runOnConvergenceRegion(LoopInfo &LI, SPIRV::ConvergenceRegion *CR) {
120 for (auto *Child : CR->Children)
121 if (runOnConvergenceRegion(LI, CR: Child))
122 return true;
123
124 return runOnConvergenceRegionNoRecurse(LI, CR);
125}
126
127#if !NDEBUG
128/// Validates each edge exiting the region has the same destination basic
129/// block.
130static void validateRegionExits(const SPIRV::ConvergenceRegion *CR) {
131 for (auto *Child : CR->Children)
132 validateRegionExits(Child);
133
134 SmallPtrSet<BasicBlock *, 0> ExitTargets;
135 for (auto *Exit : CR->Exits) {
136 for (auto *BB : successors(Exit)) {
137 if (CR->Blocks.count(BB) == 0)
138 ExitTargets.insert(BB);
139 }
140 }
141
142 assert(ExitTargets.size() <= 1);
143}
144#endif
145
146static bool runImpl(Function &F, LoopInfo &LI,
147 SPIRV::ConvergenceRegionInfo &RegionInfo) {
148 auto *TopLevelRegion = RegionInfo.getWritableTopLevelRegion();
149
150 // FIXME: very inefficient method: each time a region is modified, we bubble
151 // back up, and recompute the whole convergence region tree. Once the
152 // algorithm is completed and test coverage good enough, rewrite this pass
153 // to be efficient instead of simple.
154 bool Modified = false;
155 while (runOnConvergenceRegion(LI, CR: TopLevelRegion)) {
156 Modified = true;
157 }
158
159#if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS)
160 validateRegionExits(TopLevelRegion);
161#endif
162 return Modified;
163}
164
165class SPIRVMergeRegionExitTargetsLegacy : public FunctionPass {
166public:
167 static char ID;
168
169 SPIRVMergeRegionExitTargetsLegacy() : FunctionPass(ID) {}
170
171 bool runOnFunction(Function &F) override {
172 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
173 auto &RegionInfo = getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
174 .getRegionInfo();
175 return runImpl(F, LI, RegionInfo);
176 }
177
178 void getAnalysisUsage(AnalysisUsage &AU) const override {
179 AU.addRequired<DominatorTreeWrapperPass>();
180 AU.addRequired<LoopInfoWrapperPass>();
181 AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
182
183 AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>();
184 FunctionPass::getAnalysisUsage(AU);
185 }
186};
187} // namespace
188
189PreservedAnalyses
190SPIRVMergeRegionExitTargets::run(Function &F, FunctionAnalysisManager &AM) {
191 auto &LI = AM.getResult<LoopAnalysis>(IR&: F);
192 auto &RegionInfo = AM.getResult<SPIRVConvergenceRegionAnalysis>(IR&: F);
193 return runImpl(F, LI, RegionInfo) ? PreservedAnalyses::none()
194 : PreservedAnalyses::all();
195}
196
197char SPIRVMergeRegionExitTargetsLegacy::ID = 0;
198
199INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargetsLegacy,
200 "split-region-exit-blocks",
201 "SPIRV split region exit blocks", false, false)
202INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
203INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
204INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
205INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass)
206
207INITIALIZE_PASS_END(SPIRVMergeRegionExitTargetsLegacy,
208 "split-region-exit-blocks",
209 "SPIRV split region exit blocks", false, false)
210
211FunctionPass *llvm::createSPIRVMergeRegionExitTargetsPass() {
212 return new SPIRVMergeRegionExitTargetsLegacy();
213}
214