1 | //===- SuspendCrossingInfo.cpp - Utility for suspend crossing values ------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
// SuspendCrossingInfo maintains the data needed to answer the question:
// given two BasicBlocks A and B, is there a path from A to B that passes
// through a suspend point? Note that SuspendCrossingInfo is invalidated by
// changes to the CFG, including adding or removing BBs, because it uses BB
// pointers in the BlockToIndexMapping.
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #include "llvm/Transforms/Coroutines/SuspendCrossingInfo.h" |
16 | #include "llvm/IR/ModuleSlotTracker.h" |
17 | |
18 | // The "coro-suspend-crossing" flag is very noisy. There is another debug type, |
19 | // "coro-frame", which results in leaner debug spew. |
20 | #define DEBUG_TYPE "coro-suspend-crossing" |
21 | |
22 | namespace llvm { |
23 | #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) |
24 | static void dumpBasicBlockLabel(const BasicBlock *BB, ModuleSlotTracker &MST) { |
25 | if (BB->hasName()) { |
26 | dbgs() << BB->getName(); |
27 | return; |
28 | } |
29 | |
30 | dbgs() << MST.getLocalSlot(BB); |
31 | } |
32 | |
33 | LLVM_DUMP_METHOD void |
34 | SuspendCrossingInfo::dump(StringRef Label, BitVector const &BV, |
35 | const ReversePostOrderTraversal<Function *> &RPOT, |
36 | ModuleSlotTracker &MST) const { |
37 | dbgs() << Label << ":" ; |
38 | for (const BasicBlock *BB : RPOT) { |
39 | auto BBNo = Mapping.blockToIndex(BB); |
40 | if (BV[BBNo]) { |
41 | dbgs() << " " ; |
42 | dumpBasicBlockLabel(BB, MST); |
43 | } |
44 | } |
45 | dbgs() << "\n" ; |
46 | } |
47 | |
48 | LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const { |
49 | if (Block.empty()) |
50 | return; |
51 | |
52 | BasicBlock *const B = Mapping.indexToBlock(0); |
53 | Function *F = B->getParent(); |
54 | |
55 | ModuleSlotTracker MST(F->getParent()); |
56 | MST.incorporateFunction(*F); |
57 | |
58 | ReversePostOrderTraversal<Function *> RPOT(F); |
59 | for (const BasicBlock *BB : RPOT) { |
60 | auto BBNo = Mapping.blockToIndex(BB); |
61 | dumpBasicBlockLabel(BB, MST); |
62 | dbgs() << ":\n" ; |
63 | dump(" Consumes" , Block[BBNo].Consumes, RPOT, MST); |
64 | dump(" Kills" , Block[BBNo].Kills, RPOT, MST); |
65 | } |
66 | dbgs() << "\n" ; |
67 | } |
68 | #endif |
69 | |
70 | bool SuspendCrossingInfo::hasPathCrossingSuspendPoint(BasicBlock *From, |
71 | BasicBlock *To) const { |
72 | size_t const FromIndex = Mapping.blockToIndex(BB: From); |
73 | size_t const ToIndex = Mapping.blockToIndex(BB: To); |
74 | bool const Result = Block[ToIndex].Kills[FromIndex]; |
75 | LLVM_DEBUG(if (Result) dbgs() << From->getName() << " => " << To->getName() |
76 | << " crosses suspend point\n" ); |
77 | return Result; |
78 | } |
79 | |
80 | bool SuspendCrossingInfo::hasPathOrLoopCrossingSuspendPoint( |
81 | BasicBlock *From, BasicBlock *To) const { |
82 | size_t const FromIndex = Mapping.blockToIndex(BB: From); |
83 | size_t const ToIndex = Mapping.blockToIndex(BB: To); |
84 | bool Result = Block[ToIndex].Kills[FromIndex] || |
85 | (From == To && Block[ToIndex].KillLoop); |
86 | LLVM_DEBUG(if (Result) dbgs() << From->getName() << " => " << To->getName() |
87 | << " crosses suspend point (path or loop)\n" ); |
88 | return Result; |
89 | } |
90 | |
91 | template <bool Initialize> |
92 | bool SuspendCrossingInfo::computeBlockData( |
93 | const ReversePostOrderTraversal<Function *> &RPOT) { |
94 | bool Changed = false; |
95 | |
96 | for (const BasicBlock *BB : RPOT) { |
97 | auto BBNo = Mapping.blockToIndex(BB); |
98 | auto &B = Block[BBNo]; |
99 | |
100 | // We don't need to count the predecessors when initialization. |
101 | if constexpr (!Initialize) |
102 | // If all the predecessors of the current Block don't change, |
103 | // the BlockData for the current block must not change too. |
104 | if (all_of(predecessors(BD: B), [this](BasicBlock *BB) { |
105 | return !Block[Mapping.blockToIndex(BB)].Changed; |
106 | })) { |
107 | B.Changed = false; |
108 | continue; |
109 | } |
110 | |
111 | // Saved Consumes and Kills bitsets so that it is easy to see |
112 | // if anything changed after propagation. |
113 | auto SavedConsumes = B.Consumes; |
114 | auto SavedKills = B.Kills; |
115 | |
116 | for (BasicBlock *PI : predecessors(BD: B)) { |
117 | auto PrevNo = Mapping.blockToIndex(BB: PI); |
118 | auto &P = Block[PrevNo]; |
119 | |
120 | // Propagate Kills and Consumes from predecessors into B. |
121 | B.Consumes |= P.Consumes; |
122 | B.Kills |= P.Kills; |
123 | |
124 | // If block P is a suspend block, it should propagate kills into block |
125 | // B for every block P consumes. |
126 | if (P.Suspend) |
127 | B.Kills |= P.Consumes; |
128 | } |
129 | |
130 | if (B.Suspend) { |
131 | // If block B is a suspend block, it should kill all of the blocks it |
132 | // consumes. |
133 | B.Kills |= B.Consumes; |
134 | } else if (B.End) { |
135 | // If block B is an end block, it should not propagate kills as the |
136 | // blocks following coro.end() are reached during initial invocation |
137 | // of the coroutine while all the data are still available on the |
138 | // stack or in the registers. |
139 | B.Kills.reset(); |
140 | } else { |
141 | // This is reached when B block it not Suspend nor coro.end and it |
142 | // need to make sure that it is not in the kill set. |
143 | B.KillLoop |= B.Kills[BBNo]; |
144 | B.Kills.reset(Idx: BBNo); |
145 | } |
146 | |
147 | if constexpr (!Initialize) { |
148 | B.Changed = (B.Kills != SavedKills) || (B.Consumes != SavedConsumes); |
149 | Changed |= B.Changed; |
150 | } |
151 | } |
152 | |
153 | return Changed; |
154 | } |
155 | |
156 | SuspendCrossingInfo::SuspendCrossingInfo( |
157 | Function &F, const SmallVectorImpl<AnyCoroSuspendInst *> &CoroSuspends, |
158 | const SmallVectorImpl<AnyCoroEndInst *> &CoroEnds) |
159 | : Mapping(F) { |
160 | const size_t N = Mapping.size(); |
161 | Block.resize(N); |
162 | |
163 | // Initialize every block so that it consumes itself |
164 | for (size_t I = 0; I < N; ++I) { |
165 | auto &B = Block[I]; |
166 | B.Consumes.resize(N); |
167 | B.Kills.resize(N); |
168 | B.Consumes.set(I); |
169 | B.Changed = true; |
170 | } |
171 | |
172 | // Mark all CoroEnd Blocks. We do not propagate Kills beyond coro.ends as |
173 | // the code beyond coro.end is reachable during initial invocation of the |
174 | // coroutine. |
175 | for (auto *CE : CoroEnds) { |
176 | // Verify CoroEnd was normalized |
177 | assert(CE->getParent()->getFirstInsertionPt() == CE->getIterator() && |
178 | CE->getParent()->size() <= 2 && "CoroEnd must be in its own BB" ); |
179 | |
180 | getBlockData(BB: CE->getParent()).End = true; |
181 | } |
182 | |
183 | // Mark all suspend blocks and indicate that they kill everything they |
184 | // consume. Note, that crossing coro.save also requires a spill, as any code |
185 | // between coro.save and coro.suspend may resume the coroutine and all of the |
186 | // state needs to be saved by that time. |
187 | auto markSuspendBlock = [&](IntrinsicInst *BarrierInst) { |
188 | BasicBlock *SuspendBlock = BarrierInst->getParent(); |
189 | auto &B = getBlockData(BB: SuspendBlock); |
190 | B.Suspend = true; |
191 | B.Kills |= B.Consumes; |
192 | }; |
193 | for (auto *CSI : CoroSuspends) { |
194 | // Verify CoroSuspend was normalized |
195 | assert(CSI->getParent()->getFirstInsertionPt() == CSI->getIterator() && |
196 | CSI->getParent()->size() <= 2 && |
197 | "CoroSuspend must be in its own BB" ); |
198 | |
199 | markSuspendBlock(CSI); |
200 | if (auto *Save = CSI->getCoroSave()) |
201 | markSuspendBlock(Save); |
202 | } |
203 | |
204 | // It is considered to be faster to use RPO traversal for forward-edges |
205 | // dataflow analysis. |
206 | ReversePostOrderTraversal<Function *> RPOT(&F); |
207 | computeBlockData</*Initialize=*/true>(RPOT); |
208 | while (computeBlockData</*Initialize*/ false>(RPOT)) |
209 | ; |
210 | |
211 | LLVM_DEBUG(dump()); |
212 | } |
213 | |
214 | } // namespace llvm |
215 | |