| 1 | //===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
// Merge the multiple exit targets of a convergence region into a single block.
// Each exit target is assigned a constant value. Every exiting block stores the
// value of the target it was heading to into a function-local variable, and the
// new exit block loads that variable and switches on it to re-route to the
// correct basic block.
| 12 | // |
| 13 | //===----------------------------------------------------------------------===// |
| 14 | |
| 15 | #include "SPIRVMergeRegionExitTargets.h" |
| 16 | #include "Analysis/SPIRVConvergenceRegionAnalysis.h" |
| 17 | #include "SPIRV.h" |
| 18 | #include "SPIRVSubtarget.h" |
| 19 | #include "SPIRVUtils.h" |
| 20 | #include "llvm/ADT/DenseMap.h" |
| 21 | #include "llvm/ADT/SmallPtrSet.h" |
| 22 | #include "llvm/Analysis/LoopInfo.h" |
| 23 | #include "llvm/IR/Dominators.h" |
| 24 | #include "llvm/IR/IRBuilder.h" |
| 25 | #include "llvm/IR/Intrinsics.h" |
| 26 | #include "llvm/InitializePasses.h" |
| 27 | #include "llvm/Transforms/Utils/Cloning.h" |
| 28 | #include "llvm/Transforms/Utils/LoopSimplify.h" |
| 29 | #include "llvm/Transforms/Utils/LowerMemIntrinsics.h" |
| 30 | |
| 31 | using namespace llvm; |
| 32 | |
| 33 | namespace { |
| 34 | |
| 35 | // Run the pass on the given convergence region, ignoring the sub-regions. |
| 36 | // Returns true if the CFG changed, false otherwise. |
| 37 | static bool runOnConvergenceRegionNoRecurse(LoopInfo &LI, |
| 38 | SPIRV::ConvergenceRegion *CR) { |
| 39 | // Gather all the exit targets for this region. |
| 40 | SmallPtrSet<BasicBlock *, 4> ExitTargets; |
| 41 | for (BasicBlock *Exit : CR->Exits) { |
| 42 | for (BasicBlock *Target : successors(BB: Exit)) { |
| 43 | if (CR->Blocks.count(Ptr: Target) == 0) |
| 44 | ExitTargets.insert(Ptr: Target); |
| 45 | } |
| 46 | } |
| 47 | |
| 48 | // If we have zero or one exit target, nothing do to. |
| 49 | if (ExitTargets.size() <= 1) |
| 50 | return false; |
| 51 | |
| 52 | // Create the new single exit target. |
| 53 | auto F = CR->Entry->getParent(); |
| 54 | auto NewExitTarget = BasicBlock::Create(Context&: F->getContext(), Name: "new.exit" , Parent: F); |
| 55 | IRBuilder<> Builder(NewExitTarget); |
| 56 | |
| 57 | AllocaInst *Variable = createVariable(F&: *F, Type: Builder.getInt32Ty()); |
| 58 | |
| 59 | // CodeGen output needs to be stable. Using the set as-is would order |
| 60 | // the targets differently depending on the allocation pattern. |
| 61 | // Sorting per basic-block ordering in the function. |
| 62 | std::vector<BasicBlock *> SortedExitTargets; |
| 63 | std::vector<BasicBlock *> SortedExits; |
| 64 | for (BasicBlock &BB : *F) { |
| 65 | if (ExitTargets.count(Ptr: &BB) != 0) |
| 66 | SortedExitTargets.push_back(x: &BB); |
| 67 | if (CR->Exits.count(Ptr: &BB) != 0) |
| 68 | SortedExits.push_back(x: &BB); |
| 69 | } |
| 70 | |
| 71 | // Creating one constant per distinct exit target. This will be route to the |
| 72 | // correct target. |
| 73 | DenseMap<BasicBlock *, ConstantInt *> TargetToValue; |
| 74 | for (BasicBlock *Target : SortedExitTargets) |
| 75 | TargetToValue.insert( |
| 76 | KV: std::make_pair(x&: Target, y: Builder.getInt32(C: TargetToValue.size()))); |
| 77 | |
| 78 | // Creating one variable per exit node, set to the constant matching the |
| 79 | // targeted external block. |
| 80 | std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable; |
| 81 | for (auto Exit : SortedExits) { |
| 82 | llvm::Value *Value = createExitVariable(BB: Exit, TargetToValue); |
| 83 | IRBuilder<> B2(Exit); |
| 84 | B2.SetInsertPoint(Exit->getFirstInsertionPt()); |
| 85 | B2.CreateStore(Val: Value, Ptr: Variable); |
| 86 | ExitToVariable.emplace_back(args: std::make_pair(x&: Exit, y&: Value)); |
| 87 | } |
| 88 | |
| 89 | llvm::Value *Load = Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: Variable); |
| 90 | |
| 91 | // Creating the switch to jump to the correct exit target. |
| 92 | llvm::SwitchInst *Sw = Builder.CreateSwitch(V: Load, Dest: SortedExitTargets[0], |
| 93 | NumCases: SortedExitTargets.size() - 1); |
| 94 | for (size_t i = 1; i < SortedExitTargets.size(); i++) { |
| 95 | BasicBlock *BB = SortedExitTargets[i]; |
| 96 | Sw->addCase(OnVal: TargetToValue[BB], Dest: BB); |
| 97 | } |
| 98 | |
| 99 | // Fix exit branches to redirect to the new exit. |
| 100 | for (auto Exit : CR->Exits) { |
| 101 | Instruction *T = Exit->getTerminator(); |
| 102 | for (auto I = succ_begin(I: T), E = succ_end(I: T); I != E; ++I) |
| 103 | if (ExitTargets.contains(Ptr: *I)) |
| 104 | I.getUse()->set(NewExitTarget); |
| 105 | } |
| 106 | |
| 107 | CR = CR->Parent; |
| 108 | while (CR) { |
| 109 | CR->Blocks.insert(Ptr: NewExitTarget); |
| 110 | CR = CR->Parent; |
| 111 | } |
| 112 | |
| 113 | return true; |
| 114 | } |
| 115 | |
| 116 | /// Run the pass on the given convergence region and sub-regions (DFS). |
| 117 | /// Returns true if a region/sub-region was modified, false otherwise. |
| 118 | /// This returns as soon as one region/sub-region has been modified. |
| 119 | static bool runOnConvergenceRegion(LoopInfo &LI, SPIRV::ConvergenceRegion *CR) { |
| 120 | for (auto *Child : CR->Children) |
| 121 | if (runOnConvergenceRegion(LI, CR: Child)) |
| 122 | return true; |
| 123 | |
| 124 | return runOnConvergenceRegionNoRecurse(LI, CR); |
| 125 | } |
| 126 | |
| 127 | #if !NDEBUG |
| 128 | /// Validates each edge exiting the region has the same destination basic |
| 129 | /// block. |
| 130 | static void validateRegionExits(const SPIRV::ConvergenceRegion *CR) { |
| 131 | for (auto *Child : CR->Children) |
| 132 | validateRegionExits(Child); |
| 133 | |
| 134 | SmallPtrSet<BasicBlock *, 0> ExitTargets; |
| 135 | for (auto *Exit : CR->Exits) { |
| 136 | for (auto *BB : successors(Exit)) { |
| 137 | if (CR->Blocks.count(BB) == 0) |
| 138 | ExitTargets.insert(BB); |
| 139 | } |
| 140 | } |
| 141 | |
| 142 | assert(ExitTargets.size() <= 1); |
| 143 | } |
| 144 | #endif |
| 145 | |
| 146 | static bool runImpl(Function &F, LoopInfo &LI, |
| 147 | SPIRV::ConvergenceRegionInfo &RegionInfo) { |
| 148 | auto *TopLevelRegion = RegionInfo.getWritableTopLevelRegion(); |
| 149 | |
| 150 | // FIXME: very inefficient method: each time a region is modified, we bubble |
| 151 | // back up, and recompute the whole convergence region tree. Once the |
| 152 | // algorithm is completed and test coverage good enough, rewrite this pass |
| 153 | // to be efficient instead of simple. |
| 154 | bool Modified = false; |
| 155 | while (runOnConvergenceRegion(LI, CR: TopLevelRegion)) { |
| 156 | Modified = true; |
| 157 | } |
| 158 | |
| 159 | #if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS) |
| 160 | validateRegionExits(TopLevelRegion); |
| 161 | #endif |
| 162 | return Modified; |
| 163 | } |
| 164 | |
| 165 | class SPIRVMergeRegionExitTargetsLegacy : public FunctionPass { |
| 166 | public: |
| 167 | static char ID; |
| 168 | |
| 169 | SPIRVMergeRegionExitTargetsLegacy() : FunctionPass(ID) {} |
| 170 | |
| 171 | bool runOnFunction(Function &F) override { |
| 172 | LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); |
| 173 | auto &RegionInfo = getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>() |
| 174 | .getRegionInfo(); |
| 175 | return runImpl(F, LI, RegionInfo); |
| 176 | } |
| 177 | |
| 178 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
| 179 | AU.addRequired<DominatorTreeWrapperPass>(); |
| 180 | AU.addRequired<LoopInfoWrapperPass>(); |
| 181 | AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>(); |
| 182 | |
| 183 | AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>(); |
| 184 | FunctionPass::getAnalysisUsage(AU); |
| 185 | } |
| 186 | }; |
| 187 | } // namespace |
| 188 | |
| 189 | PreservedAnalyses |
| 190 | SPIRVMergeRegionExitTargets::run(Function &F, FunctionAnalysisManager &AM) { |
| 191 | auto &LI = AM.getResult<LoopAnalysis>(IR&: F); |
| 192 | auto &RegionInfo = AM.getResult<SPIRVConvergenceRegionAnalysis>(IR&: F); |
| 193 | return runImpl(F, LI, RegionInfo) ? PreservedAnalyses::none() |
| 194 | : PreservedAnalyses::all(); |
| 195 | } |
| 196 | |
| 197 | char SPIRVMergeRegionExitTargetsLegacy::ID = 0; |
| 198 | |
| 199 | INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargetsLegacy, |
| 200 | "split-region-exit-blocks" , |
| 201 | "SPIRV split region exit blocks" , false, false) |
| 202 | INITIALIZE_PASS_DEPENDENCY(LoopSimplify) |
| 203 | INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) |
| 204 | INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) |
| 205 | INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass) |
| 206 | |
| 207 | INITIALIZE_PASS_END(SPIRVMergeRegionExitTargetsLegacy, |
| 208 | "split-region-exit-blocks" , |
| 209 | "SPIRV split region exit blocks" , false, false) |
| 210 | |
| 211 | FunctionPass *llvm::createSPIRVMergeRegionExitTargetsPass() { |
| 212 | return new SPIRVMergeRegionExitTargetsLegacy(); |
| 213 | } |
| 214 | |