| 1 | //===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
// Merge the multiple exit targets of a convergence region into a single block.
// Each exit target is assigned a constant value. Every exiting block stores the
// constant of the target it branches to into a function-local variable, and the
// new exit target loads that variable and switches on it to re-route control to
// the correct basic block.
| 12 | // |
| 13 | //===----------------------------------------------------------------------===// |
| 14 | |
| 15 | #include "Analysis/SPIRVConvergenceRegionAnalysis.h" |
| 16 | #include "SPIRV.h" |
| 17 | #include "SPIRVSubtarget.h" |
| 18 | #include "SPIRVUtils.h" |
| 19 | #include "llvm/ADT/DenseMap.h" |
| 20 | #include "llvm/ADT/SmallPtrSet.h" |
| 21 | #include "llvm/Analysis/LoopInfo.h" |
| 22 | #include "llvm/CodeGen/IntrinsicLowering.h" |
| 23 | #include "llvm/IR/Dominators.h" |
| 24 | #include "llvm/IR/IRBuilder.h" |
| 25 | #include "llvm/IR/Intrinsics.h" |
| 26 | #include "llvm/InitializePasses.h" |
| 27 | #include "llvm/Transforms/Utils/Cloning.h" |
| 28 | #include "llvm/Transforms/Utils/LoopSimplify.h" |
| 29 | #include "llvm/Transforms/Utils/LowerMemIntrinsics.h" |
| 30 | |
| 31 | using namespace llvm; |
| 32 | |
| 33 | namespace { |
| 34 | |
| 35 | class SPIRVMergeRegionExitTargets : public FunctionPass { |
| 36 | public: |
| 37 | static char ID; |
| 38 | |
| 39 | SPIRVMergeRegionExitTargets() : FunctionPass(ID) {} |
| 40 | |
| 41 | // Gather all the successors of |BB|. |
| 42 | // This function asserts if the terminator neither a branch, switch or return. |
| 43 | std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) { |
| 44 | std::unordered_set<BasicBlock *> output; |
| 45 | auto *T = BB->getTerminator(); |
| 46 | |
| 47 | if (auto *BI = dyn_cast<BranchInst>(Val: T)) { |
| 48 | output.insert(x: BI->getSuccessor(i: 0)); |
| 49 | if (BI->isConditional()) |
| 50 | output.insert(x: BI->getSuccessor(i: 1)); |
| 51 | return output; |
| 52 | } |
| 53 | |
| 54 | if (auto *SI = dyn_cast<SwitchInst>(Val: T)) { |
| 55 | output.insert(x: SI->getDefaultDest()); |
| 56 | for (auto &Case : SI->cases()) |
| 57 | output.insert(x: Case.getCaseSuccessor()); |
| 58 | return output; |
| 59 | } |
| 60 | |
| 61 | assert(isa<ReturnInst>(T) && "Unhandled terminator type." ); |
| 62 | return output; |
| 63 | } |
| 64 | |
| 65 | /// Create a value in BB set to the value associated with the branch the block |
| 66 | /// terminator will take. |
| 67 | llvm::Value *createExitVariable( |
| 68 | BasicBlock *BB, |
| 69 | const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) { |
| 70 | auto *T = BB->getTerminator(); |
| 71 | if (isa<ReturnInst>(Val: T)) |
| 72 | return nullptr; |
| 73 | |
| 74 | IRBuilder<> Builder(BB); |
| 75 | Builder.SetInsertPoint(T); |
| 76 | |
| 77 | if (auto *BI = dyn_cast<BranchInst>(Val: T)) { |
| 78 | |
| 79 | BasicBlock *LHSTarget = BI->getSuccessor(i: 0); |
| 80 | BasicBlock *RHSTarget = |
| 81 | BI->isConditional() ? BI->getSuccessor(i: 1) : nullptr; |
| 82 | |
| 83 | Value *LHS = TargetToValue.lookup(Val: LHSTarget); |
| 84 | Value *RHS = TargetToValue.lookup(Val: RHSTarget); |
| 85 | |
| 86 | if (LHS == nullptr || RHS == nullptr) |
| 87 | return LHS == nullptr ? RHS : LHS; |
| 88 | return Builder.CreateSelect(C: BI->getCondition(), True: LHS, False: RHS); |
| 89 | } |
| 90 | |
| 91 | // TODO: add support for switch cases. |
| 92 | llvm_unreachable("Unhandled terminator type." ); |
| 93 | } |
| 94 | |
| 95 | /// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|. |
| 96 | void replaceBranchTargets(BasicBlock *BB, |
| 97 | const SmallPtrSet<BasicBlock *, 4> &ToReplace, |
| 98 | BasicBlock *NewTarget) { |
| 99 | auto *T = BB->getTerminator(); |
| 100 | if (isa<ReturnInst>(Val: T)) |
| 101 | return; |
| 102 | |
| 103 | if (auto *BI = dyn_cast<BranchInst>(Val: T)) { |
| 104 | for (size_t i = 0; i < BI->getNumSuccessors(); i++) { |
| 105 | if (ToReplace.count(Ptr: BI->getSuccessor(i)) != 0) |
| 106 | BI->setSuccessor(idx: i, NewSucc: NewTarget); |
| 107 | } |
| 108 | return; |
| 109 | } |
| 110 | |
| 111 | if (auto *SI = dyn_cast<SwitchInst>(Val: T)) { |
| 112 | for (size_t i = 0; i < SI->getNumSuccessors(); i++) { |
| 113 | if (ToReplace.count(Ptr: SI->getSuccessor(idx: i)) != 0) |
| 114 | SI->setSuccessor(idx: i, NewSucc: NewTarget); |
| 115 | } |
| 116 | return; |
| 117 | } |
| 118 | |
| 119 | assert(false && "Unhandled terminator type." ); |
| 120 | } |
| 121 | |
| 122 | AllocaInst *CreateVariable(Function &F, Type *Type, |
| 123 | BasicBlock::iterator Position) { |
| 124 | const DataLayout &DL = F.getDataLayout(); |
| 125 | return new AllocaInst(Type, DL.getAllocaAddrSpace(), nullptr, "reg" , |
| 126 | Position); |
| 127 | } |
| 128 | |
| 129 | // Run the pass on the given convergence region, ignoring the sub-regions. |
| 130 | // Returns true if the CFG changed, false otherwise. |
| 131 | bool runOnConvergenceRegionNoRecurse(LoopInfo &LI, |
| 132 | SPIRV::ConvergenceRegion *CR) { |
| 133 | // Gather all the exit targets for this region. |
| 134 | SmallPtrSet<BasicBlock *, 4> ExitTargets; |
| 135 | for (BasicBlock *Exit : CR->Exits) { |
| 136 | for (BasicBlock *Target : gatherSuccessors(BB: Exit)) { |
| 137 | if (CR->Blocks.count(Ptr: Target) == 0) |
| 138 | ExitTargets.insert(Ptr: Target); |
| 139 | } |
| 140 | } |
| 141 | |
| 142 | // If we have zero or one exit target, nothing do to. |
| 143 | if (ExitTargets.size() <= 1) |
| 144 | return false; |
| 145 | |
| 146 | // Create the new single exit target. |
| 147 | auto F = CR->Entry->getParent(); |
| 148 | auto NewExitTarget = BasicBlock::Create(Context&: F->getContext(), Name: "new.exit" , Parent: F); |
| 149 | IRBuilder<> Builder(NewExitTarget); |
| 150 | |
| 151 | AllocaInst *Variable = CreateVariable(F&: *F, Type: Builder.getInt32Ty(), |
| 152 | Position: F->begin()->getFirstInsertionPt()); |
| 153 | |
| 154 | // CodeGen output needs to be stable. Using the set as-is would order |
| 155 | // the targets differently depending on the allocation pattern. |
| 156 | // Sorting per basic-block ordering in the function. |
| 157 | std::vector<BasicBlock *> SortedExitTargets; |
| 158 | std::vector<BasicBlock *> SortedExits; |
| 159 | for (BasicBlock &BB : *F) { |
| 160 | if (ExitTargets.count(Ptr: &BB) != 0) |
| 161 | SortedExitTargets.push_back(x: &BB); |
| 162 | if (CR->Exits.count(Ptr: &BB) != 0) |
| 163 | SortedExits.push_back(x: &BB); |
| 164 | } |
| 165 | |
| 166 | // Creating one constant per distinct exit target. This will be route to the |
| 167 | // correct target. |
| 168 | DenseMap<BasicBlock *, ConstantInt *> TargetToValue; |
| 169 | for (BasicBlock *Target : SortedExitTargets) |
| 170 | TargetToValue.insert( |
| 171 | KV: std::make_pair(x&: Target, y: Builder.getInt32(C: TargetToValue.size()))); |
| 172 | |
| 173 | // Creating one variable per exit node, set to the constant matching the |
| 174 | // targeted external block. |
| 175 | std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable; |
| 176 | for (auto Exit : SortedExits) { |
| 177 | llvm::Value *Value = createExitVariable(BB: Exit, TargetToValue); |
| 178 | IRBuilder<> B2(Exit); |
| 179 | B2.SetInsertPoint(Exit->getFirstInsertionPt()); |
| 180 | B2.CreateStore(Val: Value, Ptr: Variable); |
| 181 | ExitToVariable.emplace_back(args: std::make_pair(x&: Exit, y&: Value)); |
| 182 | } |
| 183 | |
| 184 | llvm::Value *Load = Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: Variable); |
| 185 | |
| 186 | // Creating the switch to jump to the correct exit target. |
| 187 | llvm::SwitchInst *Sw = Builder.CreateSwitch(V: Load, Dest: SortedExitTargets[0], |
| 188 | NumCases: SortedExitTargets.size() - 1); |
| 189 | for (size_t i = 1; i < SortedExitTargets.size(); i++) { |
| 190 | BasicBlock *BB = SortedExitTargets[i]; |
| 191 | Sw->addCase(OnVal: TargetToValue[BB], Dest: BB); |
| 192 | } |
| 193 | |
| 194 | // Fix exit branches to redirect to the new exit. |
| 195 | for (auto Exit : CR->Exits) |
| 196 | replaceBranchTargets(BB: Exit, ToReplace: ExitTargets, NewTarget: NewExitTarget); |
| 197 | |
| 198 | CR = CR->Parent; |
| 199 | while (CR) { |
| 200 | CR->Blocks.insert(Ptr: NewExitTarget); |
| 201 | CR = CR->Parent; |
| 202 | } |
| 203 | |
| 204 | return true; |
| 205 | } |
| 206 | |
| 207 | /// Run the pass on the given convergence region and sub-regions (DFS). |
| 208 | /// Returns true if a region/sub-region was modified, false otherwise. |
| 209 | /// This returns as soon as one region/sub-region has been modified. |
| 210 | bool runOnConvergenceRegion(LoopInfo &LI, SPIRV::ConvergenceRegion *CR) { |
| 211 | for (auto *Child : CR->Children) |
| 212 | if (runOnConvergenceRegion(LI, CR: Child)) |
| 213 | return true; |
| 214 | |
| 215 | return runOnConvergenceRegionNoRecurse(LI, CR); |
| 216 | } |
| 217 | |
| 218 | #if !NDEBUG |
| 219 | /// Validates each edge exiting the region has the same destination basic |
| 220 | /// block. |
| 221 | void validateRegionExits(const SPIRV::ConvergenceRegion *CR) { |
| 222 | for (auto *Child : CR->Children) |
| 223 | validateRegionExits(Child); |
| 224 | |
| 225 | std::unordered_set<BasicBlock *> ExitTargets; |
| 226 | for (auto *Exit : CR->Exits) { |
| 227 | auto Set = gatherSuccessors(Exit); |
| 228 | for (auto *BB : Set) { |
| 229 | if (CR->Blocks.count(BB) == 0) |
| 230 | ExitTargets.insert(BB); |
| 231 | } |
| 232 | } |
| 233 | |
| 234 | assert(ExitTargets.size() <= 1); |
| 235 | } |
| 236 | #endif |
| 237 | |
| 238 | virtual bool runOnFunction(Function &F) override { |
| 239 | LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); |
| 240 | auto *TopLevelRegion = |
| 241 | getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>() |
| 242 | .getRegionInfo() |
| 243 | .getWritableTopLevelRegion(); |
| 244 | |
| 245 | // FIXME: very inefficient method: each time a region is modified, we bubble |
| 246 | // back up, and recompute the whole convergence region tree. Once the |
| 247 | // algorithm is completed and test coverage good enough, rewrite this pass |
| 248 | // to be efficient instead of simple. |
| 249 | bool modified = false; |
| 250 | while (runOnConvergenceRegion(LI, CR: TopLevelRegion)) { |
| 251 | modified = true; |
| 252 | } |
| 253 | |
| 254 | #if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS) |
| 255 | validateRegionExits(TopLevelRegion); |
| 256 | #endif |
| 257 | return modified; |
| 258 | } |
| 259 | |
| 260 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
| 261 | AU.addRequired<DominatorTreeWrapperPass>(); |
| 262 | AU.addRequired<LoopInfoWrapperPass>(); |
| 263 | AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>(); |
| 264 | |
| 265 | AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>(); |
| 266 | FunctionPass::getAnalysisUsage(AU); |
| 267 | } |
| 268 | }; |
| 269 | } // namespace |
| 270 | |
// Pass identifier handed to FunctionPass(ID) in the constructor. Per the
// legacy pass manager convention, the variable's address (not its value)
// identifies the pass.
char SPIRVMergeRegionExitTargets::ID = 0;

// Registers the pass with the legacy pass registry: command-line name,
// description, then (false, false) = not CFG-only, not an analysis.
// NOTE(review): the command-line name "split-region-exit-blocks" and the
// description do not match the class name (this pass merges exit targets).
// Confirm no tests/tools rely on this name before renaming it.
// NOTE(review): LoopSimplify is declared as a dependency here, but
// getAnalysisUsage() does not require it. Confirm which one is intended.
INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
                      "SPIRV split region exit blocks", false, false)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass)

INITIALIZE_PASS_END(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
                    "SPIRV split region exit blocks", false, false)
| 282 | |
| 283 | FunctionPass *llvm::createSPIRVMergeRegionExitTargetsPass() { |
| 284 | return new SPIRVMergeRegionExitTargets(); |
| 285 | } |
| 286 | |