| 1 | //===----------------------------------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // The analysis determines the convergence region for each basic block of |
| 10 | // the module, and provides a tree-like structure describing the region |
| 11 | // hierarchy. |
| 12 | // |
| 13 | //===----------------------------------------------------------------------===// |
| 14 | |
| 15 | #include "SPIRVConvergenceRegionAnalysis.h" |
| 16 | #include "SPIRV.h" |
| 17 | #include "llvm/Analysis/LoopInfo.h" |
| 18 | #include "llvm/IR/Dominators.h" |
| 19 | #include "llvm/IR/IntrinsicInst.h" |
| 20 | #include "llvm/InitializePasses.h" |
| 21 | #include "llvm/Transforms/Utils/LoopSimplify.h" |
| 22 | #include <optional> |
| 23 | #include <queue> |
| 24 | |
| 25 | #define DEBUG_TYPE "spirv-convergence-region-analysis" |
| 26 | |
| 27 | using namespace llvm; |
| 28 | using namespace SPIRV; |
| 29 | |
// Registers the legacy-pass-manager wrapper under the command-line name
// "convergence-region" and declares the passes it depends on (LoopSimplify,
// the dominator tree and loop info). The two trailing `true` arguments are,
// per the INITIALIZE_PASS_BEGIN/END macro signature, the CFG-only and
// is-analysis flags.
INITIALIZE_PASS_BEGIN(SPIRVConvergenceRegionAnalysisWrapperPass,
                      "convergence-region",
                      "SPIRV convergence regions analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(SPIRVConvergenceRegionAnalysisWrapperPass,
                    "convergence-region", "SPIRV convergence regions analysis",
                    true, true)
| 39 | |
| 40 | namespace { |
| 41 | |
| 42 | template <typename BasicBlockType, typename IntrinsicInstType> |
| 43 | std::optional<IntrinsicInstType *> |
| 44 | getConvergenceTokenInternal(BasicBlockType *BB) { |
| 45 | static_assert(std::is_const_v<IntrinsicInstType> == |
| 46 | std::is_const_v<BasicBlockType>, |
| 47 | "Constness must match between input and output." ); |
| 48 | static_assert(std::is_same_v<BasicBlock, std::remove_const_t<BasicBlockType>>, |
| 49 | "Input must be a basic block." ); |
| 50 | static_assert( |
| 51 | std::is_same_v<IntrinsicInst, std::remove_const_t<IntrinsicInstType>>, |
| 52 | "Output type must be an intrinsic instruction." ); |
| 53 | |
| 54 | for (auto &I : *BB) { |
| 55 | if (auto *CI = dyn_cast<ConvergenceControlInst>(&I)) { |
| 56 | // Make sure that the anchor or entry intrinsics did not reach here with a |
| 57 | // parent token. This should have failed the verifier. |
| 58 | assert(CI->isLoop() || |
| 59 | !CI->getOperandBundle(LLVMContext::OB_convergencectrl)); |
| 60 | return CI; |
| 61 | } |
| 62 | |
| 63 | if (auto *CI = dyn_cast<CallInst>(&I)) { |
| 64 | auto OB = CI->getOperandBundle(LLVMContext::OB_convergencectrl); |
| 65 | if (!OB.has_value()) |
| 66 | continue; |
| 67 | return dyn_cast<IntrinsicInst>(OB.value().Inputs[0]); |
| 68 | } |
| 69 | } |
| 70 | |
| 71 | return std::nullopt; |
| 72 | } |
| 73 | } // anonymous namespace |
| 74 | |
| 75 | // Given a ConvergenceRegion tree with |Start| as its root, finds the smallest |
| 76 | // region |Entry| belongs to. If |Entry| does not belong to the region defined |
| 77 | // by |Start|, this function returns |nullptr|. |
| 78 | static ConvergenceRegion *findParentRegion(ConvergenceRegion *Start, |
| 79 | BasicBlock *Entry) { |
| 80 | ConvergenceRegion *Candidate = nullptr; |
| 81 | ConvergenceRegion *NextCandidate = Start; |
| 82 | |
| 83 | while (Candidate != NextCandidate && NextCandidate != nullptr) { |
| 84 | Candidate = NextCandidate; |
| 85 | NextCandidate = nullptr; |
| 86 | |
| 87 | // End of the search, we can return. |
| 88 | if (Candidate->Children.size() == 0) |
| 89 | return Candidate; |
| 90 | |
| 91 | for (auto *Child : Candidate->Children) { |
| 92 | if (Child->Blocks.count(Ptr: Entry) != 0) { |
| 93 | NextCandidate = Child; |
| 94 | break; |
| 95 | } |
| 96 | } |
| 97 | } |
| 98 | |
| 99 | return Candidate; |
| 100 | } |
| 101 | |
| 102 | std::optional<IntrinsicInst *> |
| 103 | llvm::SPIRV::getConvergenceToken(BasicBlock *BB) { |
| 104 | return getConvergenceTokenInternal<BasicBlock, IntrinsicInst>(BB); |
| 105 | } |
| 106 | |
| 107 | std::optional<const IntrinsicInst *> |
| 108 | llvm::SPIRV::getConvergenceToken(const BasicBlock *BB) { |
| 109 | return getConvergenceTokenInternal<const BasicBlock, const IntrinsicInst>(BB); |
| 110 | } |
| 111 | |
| 112 | ConvergenceRegion::ConvergenceRegion(DominatorTree &DT, LoopInfo &LI, |
| 113 | Function &F) |
| 114 | : DT(DT), LI(LI), Parent(nullptr) { |
| 115 | Entry = &F.getEntryBlock(); |
| 116 | ConvergenceToken = getConvergenceToken(BB: Entry); |
| 117 | for (auto &B : F) { |
| 118 | Blocks.insert(Ptr: &B); |
| 119 | if (isa<ReturnInst>(Val: B.getTerminator())) |
| 120 | Exits.insert(Ptr: &B); |
| 121 | } |
| 122 | } |
| 123 | |
| 124 | ConvergenceRegion::ConvergenceRegion( |
| 125 | DominatorTree &DT, LoopInfo &LI, |
| 126 | std::optional<IntrinsicInst *> ConvergenceToken, BasicBlock *Entry, |
| 127 | SmallPtrSet<BasicBlock *, 8> &&Blocks, SmallPtrSet<BasicBlock *, 2> &&Exits) |
| 128 | : DT(DT), LI(LI), ConvergenceToken(ConvergenceToken), Entry(Entry), |
| 129 | Exits(std::move(Exits)), Blocks(std::move(Blocks)) { |
| 130 | for ([[maybe_unused]] auto *BB : this->Exits) |
| 131 | assert(this->Blocks.count(BB) != 0); |
| 132 | assert(this->Blocks.count(this->Entry) != 0); |
| 133 | } |
| 134 | |
| 135 | void ConvergenceRegion::releaseMemory() { |
| 136 | // Parent memory is owned by the parent. |
| 137 | Parent = nullptr; |
| 138 | for (auto *Child : Children) { |
| 139 | Child->releaseMemory(); |
| 140 | delete Child; |
| 141 | } |
| 142 | Children.resize(N: 0); |
| 143 | } |
| 144 | |
| 145 | void ConvergenceRegion::dump(const unsigned IndentSize) const { |
| 146 | const std::string Indent(IndentSize, '\t'); |
| 147 | dbgs() << Indent << this << ": {\n" ; |
| 148 | dbgs() << Indent << " Parent: " << Parent << "\n" ; |
| 149 | |
| 150 | if (ConvergenceToken.value_or(u: nullptr)) { |
| 151 | dbgs() << Indent |
| 152 | << " ConvergenceToken: " << ConvergenceToken.value()->getName() |
| 153 | << "\n" ; |
| 154 | } |
| 155 | |
| 156 | if (Entry->getName() != "" ) |
| 157 | dbgs() << Indent << " Entry: " << Entry->getName() << "\n" ; |
| 158 | else |
| 159 | dbgs() << Indent << " Entry: " << Entry << "\n" ; |
| 160 | |
| 161 | dbgs() << Indent << " Exits: { " ; |
| 162 | for (const auto &Exit : Exits) { |
| 163 | if (Exit->getName() != "" ) |
| 164 | dbgs() << Exit->getName() << ", " ; |
| 165 | else |
| 166 | dbgs() << Exit << ", " ; |
| 167 | } |
| 168 | dbgs() << " }\n" ; |
| 169 | |
| 170 | dbgs() << Indent << " Blocks: { " ; |
| 171 | for (const auto &Block : Blocks) { |
| 172 | if (Block->getName() != "" ) |
| 173 | dbgs() << Block->getName() << ", " ; |
| 174 | else |
| 175 | dbgs() << Block << ", " ; |
| 176 | } |
| 177 | dbgs() << " }\n" ; |
| 178 | |
| 179 | dbgs() << Indent << " Children: {\n" ; |
| 180 | for (const auto Child : Children) |
| 181 | Child->dump(IndentSize: IndentSize + 2); |
| 182 | dbgs() << Indent << " }\n" ; |
| 183 | |
| 184 | dbgs() << Indent << "}\n" ; |
| 185 | } |
| 186 | |
| 187 | namespace { |
| 188 | class ConvergenceRegionAnalyzer { |
| 189 | public: |
| 190 | ConvergenceRegionAnalyzer(Function &F, DominatorTree &DT, LoopInfo &LI) |
| 191 | : DT(DT), LI(LI), F(F) {} |
| 192 | |
| 193 | private: |
| 194 | bool isBackEdge(const BasicBlock *From, const BasicBlock *To) const { |
| 195 | if (From == To) |
| 196 | return true; |
| 197 | |
| 198 | // We only handle loop in the simplified form. This means: |
| 199 | // - a single back-edge, a single latch. |
| 200 | // - meaning the back-edge target can only be the loop header. |
| 201 | // - meaning the From can only be the loop latch. |
| 202 | if (!LI.isLoopHeader(BB: To)) |
| 203 | return false; |
| 204 | |
| 205 | auto *L = LI.getLoopFor(BB: To); |
| 206 | if (L->contains(BB: From) && L->isLoopLatch(BB: From)) |
| 207 | return true; |
| 208 | |
| 209 | return false; |
| 210 | } |
| 211 | |
| 212 | std::unordered_set<BasicBlock *> |
| 213 | findPathsToMatch(LoopInfo &LI, BasicBlock *From, |
| 214 | std::function<bool(const BasicBlock *)> isMatch) const { |
| 215 | std::unordered_set<BasicBlock *> Output; |
| 216 | |
| 217 | if (isMatch(From)) |
| 218 | Output.insert(x: From); |
| 219 | |
| 220 | auto *Terminator = From->getTerminator(); |
| 221 | for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) { |
| 222 | auto *To = Terminator->getSuccessor(Idx: i); |
| 223 | // Ignore back edges. |
| 224 | if (isBackEdge(From, To)) |
| 225 | continue; |
| 226 | |
| 227 | auto ChildSet = findPathsToMatch(LI, From: To, isMatch); |
| 228 | if (ChildSet.size() == 0) |
| 229 | continue; |
| 230 | |
| 231 | Output.insert(first: ChildSet.begin(), last: ChildSet.end()); |
| 232 | Output.insert(x: From); |
| 233 | if (LI.isLoopHeader(BB: From)) { |
| 234 | auto *L = LI.getLoopFor(BB: From); |
| 235 | for (auto *BB : L->getBlocks()) { |
| 236 | Output.insert(x: BB); |
| 237 | } |
| 238 | } |
| 239 | } |
| 240 | |
| 241 | return Output; |
| 242 | } |
| 243 | |
| 244 | SmallPtrSet<BasicBlock *, 2> |
| 245 | findExitNodes(const SmallPtrSetImpl<BasicBlock *> &RegionBlocks) { |
| 246 | SmallPtrSet<BasicBlock *, 2> Exits; |
| 247 | |
| 248 | for (auto *B : RegionBlocks) { |
| 249 | auto *Terminator = B->getTerminator(); |
| 250 | for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) { |
| 251 | auto *Child = Terminator->getSuccessor(Idx: i); |
| 252 | if (RegionBlocks.count(Ptr: Child) == 0) |
| 253 | Exits.insert(Ptr: B); |
| 254 | } |
| 255 | } |
| 256 | |
| 257 | return Exits; |
| 258 | } |
| 259 | |
| 260 | public: |
| 261 | ConvergenceRegionInfo analyze() { |
| 262 | ConvergenceRegion *TopLevelRegion = new ConvergenceRegion(DT, LI, F); |
| 263 | std::queue<Loop *> ToProcess; |
| 264 | for (auto *L : LI.getLoopsInPreorder()) |
| 265 | ToProcess.push(x: L); |
| 266 | |
| 267 | while (ToProcess.size() != 0) { |
| 268 | auto *L = ToProcess.front(); |
| 269 | ToProcess.pop(); |
| 270 | |
| 271 | auto CT = getConvergenceToken(BB: L->getHeader()); |
| 272 | SmallPtrSet<BasicBlock *, 8> RegionBlocks(llvm::from_range, L->blocks()); |
| 273 | SmallVector<BasicBlock *> LoopExits; |
| 274 | L->getExitingBlocks(ExitingBlocks&: LoopExits); |
| 275 | if (CT.has_value()) { |
| 276 | for (auto *Exit : LoopExits) { |
| 277 | auto N = findPathsToMatch(LI, From: Exit, isMatch: [&CT](const BasicBlock *block) { |
| 278 | auto Token = getConvergenceToken(BB: block); |
| 279 | if (Token == std::nullopt) |
| 280 | return false; |
| 281 | return Token.value() == CT.value(); |
| 282 | }); |
| 283 | RegionBlocks.insert_range(R&: N); |
| 284 | } |
| 285 | } |
| 286 | |
| 287 | auto RegionExits = findExitNodes(RegionBlocks); |
| 288 | ConvergenceRegion *Region = new ConvergenceRegion( |
| 289 | DT, LI, CT, L->getHeader(), std::move(RegionBlocks), |
| 290 | std::move(RegionExits)); |
| 291 | Region->Parent = findParentRegion(Start: TopLevelRegion, Entry: Region->Entry); |
| 292 | assert(Region->Parent != nullptr && "This is impossible." ); |
| 293 | Region->Parent->Children.push_back(Elt: Region); |
| 294 | } |
| 295 | |
| 296 | return ConvergenceRegionInfo(TopLevelRegion); |
| 297 | } |
| 298 | |
| 299 | private: |
| 300 | DominatorTree &DT; |
| 301 | LoopInfo &LI; |
| 302 | Function &F; |
| 303 | }; |
| 304 | } // anonymous namespace |
| 305 | |
| 306 | ConvergenceRegionInfo llvm::SPIRV::getConvergenceRegions(Function &F, |
| 307 | DominatorTree &DT, |
| 308 | LoopInfo &LI) { |
| 309 | ConvergenceRegionAnalyzer Analyzer(F, DT, LI); |
| 310 | return Analyzer.analyze(); |
| 311 | } |
| 312 | |
// Identifier of the legacy wrapper pass, handed to the FunctionPass base
// class by the constructor below.
char SPIRVConvergenceRegionAnalysisWrapperPass::ID = 0;

// The constructor only forwards the pass identifier to FunctionPass.
SPIRVConvergenceRegionAnalysisWrapperPass::
    SPIRVConvergenceRegionAnalysisWrapperPass()
    : FunctionPass(ID) {}
| 318 | |
| 319 | bool SPIRVConvergenceRegionAnalysisWrapperPass::runOnFunction(Function &F) { |
| 320 | DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); |
| 321 | LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); |
| 322 | |
| 323 | CRI = SPIRV::getConvergenceRegions(F, DT, LI); |
| 324 | // Nothing was modified. |
| 325 | return false; |
| 326 | } |
| 327 | |
| 328 | SPIRVConvergenceRegionAnalysis::Result |
| 329 | SPIRVConvergenceRegionAnalysis::run(Function &F, FunctionAnalysisManager &AM) { |
| 330 | Result CRI; |
| 331 | auto &DT = AM.getResult<DominatorTreeAnalysis>(IR&: F); |
| 332 | auto &LI = AM.getResult<LoopAnalysis>(IR&: F); |
| 333 | CRI = SPIRV::getConvergenceRegions(F, DT, LI); |
| 334 | return CRI; |
| 335 | } |
| 336 | |
// Key identifying SPIRVConvergenceRegionAnalysis in the new pass manager.
AnalysisKey SPIRVConvergenceRegionAnalysis::Key;
| 338 | |