1//===- SPIRVConvergenceRegionAnalysis.h ------------------------*- C++ -*--===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// The analysis determines the convergence region for each basic block of
10// the module, and provides a tree-like structure describing the region
11// hierarchy.
12//
13//===----------------------------------------------------------------------===//
14
15#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVCONVERGENCEREGIONANALYSIS_H
16#define LLVM_LIB_TARGET_SPIRV_SPIRVCONVERGENCEREGIONANALYSIS_H
17
18#include "llvm/ADT/SmallPtrSet.h"
19#include "llvm/Analysis/CFG.h"
20#include "llvm/Analysis/LoopInfo.h"
21#include "llvm/IR/Dominators.h"
22#include <optional>
23#include <unordered_set>
24
25namespace llvm {
26class IntrinsicInst;
27class SPIRVSubtarget;
28class MachineFunction;
29class MachineModuleInfo;
30
31namespace SPIRV {
32
33// Returns the first convergence intrinsic found in |BB|, |nullopt| otherwise.
34std::optional<IntrinsicInst *> getConvergenceToken(BasicBlock *BB);
35std::optional<const IntrinsicInst *> getConvergenceToken(const BasicBlock *BB);
36
37// Describes a hierarchy of convergence regions.
38// A convergence region defines a CFG for which the execution flow can diverge
39// starting from the entry block, but should reconverge back before the end of
40// the exit blocks.
41class ConvergenceRegion {
42 DominatorTree &DT;
43 LoopInfo &LI;
44
45public:
46 // The parent region of this region, if any.
47 ConvergenceRegion *Parent = nullptr;
48 // The sub-regions contained in this region, if any.
49 SmallVector<ConvergenceRegion *> Children = {};
50 // The convergence instruction linked to this region, if any.
51 std::optional<IntrinsicInst *> ConvergenceToken = std::nullopt;
52 // The only block with a predecessor outside of this region.
53 BasicBlock *Entry = nullptr;
54 // All the blocks with an edge leaving this convergence region.
55 SmallPtrSet<BasicBlock *, 2> Exits = {};
56 // All the blocks that belongs to this region, including its subregions'.
57 SmallPtrSet<BasicBlock *, 8> Blocks = {};
58
59 // Creates a single convergence region encapsulating the whole function |F|.
60 ConvergenceRegion(DominatorTree &DT, LoopInfo &LI, Function &F);
61
62 // Creates a single convergence region defined by entry and exits nodes, a
63 // list of blocks, and possibly a convergence token.
64 ConvergenceRegion(DominatorTree &DT, LoopInfo &LI,
65 std::optional<IntrinsicInst *> ConvergenceToken,
66 BasicBlock *Entry, SmallPtrSet<BasicBlock *, 8> &&Blocks,
67 SmallPtrSet<BasicBlock *, 2> &&Exits);
68
69 ConvergenceRegion(ConvergenceRegion &&CR)
70 : DT(CR.DT), LI(CR.LI), Parent(std::move(CR.Parent)),
71 Children(std::move(CR.Children)),
72 ConvergenceToken(std::move(CR.ConvergenceToken)),
73 Entry(std::move(CR.Entry)), Exits(std::move(CR.Exits)),
74 Blocks(std::move(CR.Blocks)) {}
75
76 ConvergenceRegion(const ConvergenceRegion &other) = delete;
77
78 // Returns true if the given basic block belongs to this region, or to one of
79 // its subregion.
80 bool contains(const BasicBlock *BB) const { return Blocks.count(Ptr: BB) != 0; }
81
82 void releaseMemory();
83
84 // Write to the debug output this region's hierarchy.
85 // |IndentSize| defines the number of tabs to print before any new line.
86 void dump(const unsigned IndentSize = 0) const;
87};
88
89// Holds a ConvergenceRegion hierarchy.
90class ConvergenceRegionInfo {
91 // The convergence region this structure holds.
92 ConvergenceRegion *TopLevelRegion;
93
94public:
95 ConvergenceRegionInfo() : TopLevelRegion(nullptr) {}
96
97 // Creates a new ConvergenceRegionInfo. Ownership of the TopLevelRegion is
98 // passed to this object.
99 ConvergenceRegionInfo(ConvergenceRegion *TopLevelRegion)
100 : TopLevelRegion(TopLevelRegion) {}
101
102 ~ConvergenceRegionInfo() { releaseMemory(); }
103
104 ConvergenceRegionInfo(ConvergenceRegionInfo &&LHS)
105 : TopLevelRegion(LHS.TopLevelRegion) {
106 if (TopLevelRegion != LHS.TopLevelRegion) {
107 releaseMemory();
108 TopLevelRegion = LHS.TopLevelRegion;
109 }
110 LHS.TopLevelRegion = nullptr;
111 }
112
113 ConvergenceRegionInfo &operator=(ConvergenceRegionInfo &&LHS) {
114 if (TopLevelRegion != LHS.TopLevelRegion) {
115 releaseMemory();
116 TopLevelRegion = LHS.TopLevelRegion;
117 }
118 LHS.TopLevelRegion = nullptr;
119 return *this;
120 }
121
122 void releaseMemory() {
123 if (TopLevelRegion == nullptr)
124 return;
125
126 TopLevelRegion->releaseMemory();
127 delete TopLevelRegion;
128 TopLevelRegion = nullptr;
129 }
130
131 const ConvergenceRegion *getTopLevelRegion() const { return TopLevelRegion; }
132 ConvergenceRegion *getWritableTopLevelRegion() const {
133 return TopLevelRegion;
134 }
135};
136
137} // namespace SPIRV
138
139// Wrapper around the function above to use it with the legacy pass manager.
140class SPIRVConvergenceRegionAnalysisWrapperPass : public FunctionPass {
141 SPIRV::ConvergenceRegionInfo CRI;
142
143public:
144 static char ID;
145
146 SPIRVConvergenceRegionAnalysisWrapperPass();
147
148 void getAnalysisUsage(AnalysisUsage &AU) const override {
149 AU.setPreservesAll();
150 AU.addRequired<LoopInfoWrapperPass>();
151 AU.addRequired<DominatorTreeWrapperPass>();
152 };
153
154 bool runOnFunction(Function &F) override;
155
156 SPIRV::ConvergenceRegionInfo &getRegionInfo() { return CRI; }
157 const SPIRV::ConvergenceRegionInfo &getRegionInfo() const { return CRI; }
158};
159
160// Wrapper around the function above to use it with the new pass manager.
161class SPIRVConvergenceRegionAnalysis
162 : public AnalysisInfoMixin<SPIRVConvergenceRegionAnalysis> {
163 friend AnalysisInfoMixin<SPIRVConvergenceRegionAnalysis>;
164 static AnalysisKey Key;
165
166public:
167 using Result = SPIRV::ConvergenceRegionInfo;
168
169 Result run(Function &F, FunctionAnalysisManager &AM);
170};
171
172namespace SPIRV {
173ConvergenceRegionInfo getConvergenceRegions(Function &F, DominatorTree &DT,
174 LoopInfo &LI);
175} // namespace SPIRV
176
177} // namespace llvm
178#endif // LLVM_LIB_TARGET_SPIRV_SPIRVCONVERGENCEREGIONANALYSIS_H
179