1//===- SuspendCrossingInfo.cpp - Utility for suspend crossing values ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
// SuspendCrossingInfo maintains the data needed to answer, for any two
// BasicBlocks A and B, whether there is a path from A to B that passes through
// a suspend point. Note that SuspendCrossingInfo is invalidated by any change
// to the CFG, including adding or removing BBs, because its
// BlockToIndexMapping stores BB pointers.
13//===----------------------------------------------------------------------===//
14
15#include "llvm/Transforms/Coroutines/SuspendCrossingInfo.h"
16#include "llvm/IR/ModuleSlotTracker.h"
17
18// The "coro-suspend-crossing" flag is very noisy. There is another debug type,
19// "coro-frame", which results in leaner debug spew.
20#define DEBUG_TYPE "coro-suspend-crossing"
21
22namespace llvm {
23#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
24static void dumpBasicBlockLabel(const BasicBlock *BB, ModuleSlotTracker &MST) {
25 if (BB->hasName()) {
26 dbgs() << BB->getName();
27 return;
28 }
29
30 dbgs() << MST.getLocalSlot(BB);
31}
32
33LLVM_DUMP_METHOD void
34SuspendCrossingInfo::dump(StringRef Label, BitVector const &BV,
35 const ReversePostOrderTraversal<Function *> &RPOT,
36 ModuleSlotTracker &MST) const {
37 dbgs() << Label << ":";
38 for (const BasicBlock *BB : RPOT) {
39 auto BBNo = Mapping.blockToIndex(BB);
40 if (BV[BBNo]) {
41 dbgs() << " ";
42 dumpBasicBlockLabel(BB, MST);
43 }
44 }
45 dbgs() << "\n";
46}
47
48LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const {
49 if (Block.empty())
50 return;
51
52 BasicBlock *const B = Mapping.indexToBlock(0);
53 Function *F = B->getParent();
54
55 ModuleSlotTracker MST(F->getParent());
56 MST.incorporateFunction(*F);
57
58 ReversePostOrderTraversal<Function *> RPOT(F);
59 for (const BasicBlock *BB : RPOT) {
60 auto BBNo = Mapping.blockToIndex(BB);
61 dumpBasicBlockLabel(BB, MST);
62 dbgs() << ":\n";
63 dump(" Consumes", Block[BBNo].Consumes, RPOT, MST);
64 dump(" Kills", Block[BBNo].Kills, RPOT, MST);
65 }
66 dbgs() << "\n";
67}
68#endif
69
70bool SuspendCrossingInfo::hasPathCrossingSuspendPoint(BasicBlock *From,
71 BasicBlock *To) const {
72 size_t const FromIndex = Mapping.blockToIndex(BB: From);
73 size_t const ToIndex = Mapping.blockToIndex(BB: To);
74 bool const Result = Block[ToIndex].Kills[FromIndex];
75 LLVM_DEBUG(if (Result) dbgs() << From->getName() << " => " << To->getName()
76 << " crosses suspend point\n");
77 return Result;
78}
79
80bool SuspendCrossingInfo::hasPathOrLoopCrossingSuspendPoint(
81 BasicBlock *From, BasicBlock *To) const {
82 size_t const FromIndex = Mapping.blockToIndex(BB: From);
83 size_t const ToIndex = Mapping.blockToIndex(BB: To);
84 bool Result = Block[ToIndex].Kills[FromIndex] ||
85 (From == To && Block[ToIndex].KillLoop);
86 LLVM_DEBUG(if (Result) dbgs() << From->getName() << " => " << To->getName()
87 << " crosses suspend point (path or loop)\n");
88 return Result;
89}
90
91template <bool Initialize>
92bool SuspendCrossingInfo::computeBlockData(
93 const ReversePostOrderTraversal<Function *> &RPOT) {
94 bool Changed = false;
95
96 for (const BasicBlock *BB : RPOT) {
97 auto BBNo = Mapping.blockToIndex(BB);
98 auto &B = Block[BBNo];
99
100 // We don't need to count the predecessors when initialization.
101 if constexpr (!Initialize)
102 // If all the predecessors of the current Block don't change,
103 // the BlockData for the current block must not change too.
104 if (all_of(predecessors(BD: B), [this](BasicBlock *BB) {
105 return !Block[Mapping.blockToIndex(BB)].Changed;
106 })) {
107 B.Changed = false;
108 continue;
109 }
110
111 // Saved Consumes and Kills bitsets so that it is easy to see
112 // if anything changed after propagation.
113 auto SavedConsumes = B.Consumes;
114 auto SavedKills = B.Kills;
115
116 for (BasicBlock *PI : predecessors(BD: B)) {
117 auto PrevNo = Mapping.blockToIndex(BB: PI);
118 auto &P = Block[PrevNo];
119
120 // Propagate Kills and Consumes from predecessors into B.
121 B.Consumes |= P.Consumes;
122 B.Kills |= P.Kills;
123
124 // If block P is a suspend block, it should propagate kills into block
125 // B for every block P consumes.
126 if (P.Suspend)
127 B.Kills |= P.Consumes;
128 }
129
130 if (B.Suspend) {
131 // If block B is a suspend block, it should kill all of the blocks it
132 // consumes.
133 B.Kills |= B.Consumes;
134 } else if (B.End) {
135 // If block B is an end block, it should not propagate kills as the
136 // blocks following coro.end() are reached during initial invocation
137 // of the coroutine while all the data are still available on the
138 // stack or in the registers.
139 B.Kills.reset();
140 } else {
141 // This is reached when B block it not Suspend nor coro.end and it
142 // need to make sure that it is not in the kill set.
143 B.KillLoop |= B.Kills[BBNo];
144 B.Kills.reset(Idx: BBNo);
145 }
146
147 if constexpr (!Initialize) {
148 B.Changed = (B.Kills != SavedKills) || (B.Consumes != SavedConsumes);
149 Changed |= B.Changed;
150 }
151 }
152
153 return Changed;
154}
155
156SuspendCrossingInfo::SuspendCrossingInfo(
157 Function &F, const SmallVectorImpl<AnyCoroSuspendInst *> &CoroSuspends,
158 const SmallVectorImpl<AnyCoroEndInst *> &CoroEnds)
159 : Mapping(F) {
160 const size_t N = Mapping.size();
161 Block.resize(N);
162
163 // Initialize every block so that it consumes itself
164 for (size_t I = 0; I < N; ++I) {
165 auto &B = Block[I];
166 B.Consumes.resize(N);
167 B.Kills.resize(N);
168 B.Consumes.set(I);
169 B.Changed = true;
170 }
171
172 // Mark all CoroEnd Blocks. We do not propagate Kills beyond coro.ends as
173 // the code beyond coro.end is reachable during initial invocation of the
174 // coroutine.
175 for (auto *CE : CoroEnds) {
176 // Verify CoroEnd was normalized
177 assert(CE->getParent()->getFirstInsertionPt() == CE->getIterator() &&
178 CE->getParent()->size() <= 2 && "CoroEnd must be in its own BB");
179
180 getBlockData(BB: CE->getParent()).End = true;
181 }
182
183 // Mark all suspend blocks and indicate that they kill everything they
184 // consume. Note, that crossing coro.save also requires a spill, as any code
185 // between coro.save and coro.suspend may resume the coroutine and all of the
186 // state needs to be saved by that time.
187 auto markSuspendBlock = [&](IntrinsicInst *BarrierInst) {
188 BasicBlock *SuspendBlock = BarrierInst->getParent();
189 auto &B = getBlockData(BB: SuspendBlock);
190 B.Suspend = true;
191 B.Kills |= B.Consumes;
192 };
193 for (auto *CSI : CoroSuspends) {
194 // Verify CoroSuspend was normalized
195 assert(CSI->getParent()->getFirstInsertionPt() == CSI->getIterator() &&
196 CSI->getParent()->size() <= 2 &&
197 "CoroSuspend must be in its own BB");
198
199 markSuspendBlock(CSI);
200 if (auto *Save = CSI->getCoroSave())
201 markSuspendBlock(Save);
202 }
203
204 // It is considered to be faster to use RPO traversal for forward-edges
205 // dataflow analysis.
206 ReversePostOrderTraversal<Function *> RPOT(&F);
207 computeBlockData</*Initialize=*/true>(RPOT);
208 while (computeBlockData</*Initialize*/ false>(RPOT))
209 ;
210
211 LLVM_DEBUG(dump());
212}
213
214} // namespace llvm
215