| 1 | //===- SuspendCrossingInfo.cpp - Utility for suspend crossing values ------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
// SuspendCrossingInfo maintains the data needed to answer the question:
// given two BasicBlocks A and B, is there a path from A to B that passes
// through a suspend point? Note that SuspendCrossingInfo is invalidated by
// changes to the CFG, including adding or removing BBs, because the
// BlockToIndexMapping stores BB pointers.
| 13 | //===----------------------------------------------------------------------===// |
| 14 | |
| 15 | #include "llvm/Transforms/Coroutines/SuspendCrossingInfo.h" |
| 16 | #include "llvm/IR/ModuleSlotTracker.h" |
| 17 | |
| 18 | // The "coro-suspend-crossing" flag is very noisy. There is another debug type, |
| 19 | // "coro-frame", which results in leaner debug spew. |
| 20 | #define DEBUG_TYPE "coro-suspend-crossing" |
| 21 | |
| 22 | namespace llvm { |
| 23 | #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) |
| 24 | static void dumpBasicBlockLabel(const BasicBlock *BB, ModuleSlotTracker &MST) { |
| 25 | if (BB->hasName()) { |
| 26 | dbgs() << BB->getName(); |
| 27 | return; |
| 28 | } |
| 29 | |
| 30 | dbgs() << MST.getLocalSlot(BB); |
| 31 | } |
| 32 | |
| 33 | LLVM_DUMP_METHOD void |
| 34 | SuspendCrossingInfo::dump(StringRef Label, BitVector const &BV, |
| 35 | const ReversePostOrderTraversal<Function *> &RPOT, |
| 36 | ModuleSlotTracker &MST) const { |
| 37 | dbgs() << Label << ":" ; |
| 38 | for (const BasicBlock *BB : RPOT) { |
| 39 | auto BBNo = Mapping.blockToIndex(BB); |
| 40 | if (BV[BBNo]) { |
| 41 | dbgs() << " " ; |
| 42 | dumpBasicBlockLabel(BB, MST); |
| 43 | } |
| 44 | } |
| 45 | dbgs() << "\n" ; |
| 46 | } |
| 47 | |
| 48 | LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const { |
| 49 | if (Block.empty()) |
| 50 | return; |
| 51 | |
| 52 | BasicBlock *const B = Mapping.indexToBlock(0); |
| 53 | Function *F = B->getParent(); |
| 54 | |
| 55 | ModuleSlotTracker MST(F->getParent()); |
| 56 | MST.incorporateFunction(*F); |
| 57 | |
| 58 | ReversePostOrderTraversal<Function *> RPOT(F); |
| 59 | for (const BasicBlock *BB : RPOT) { |
| 60 | auto BBNo = Mapping.blockToIndex(BB); |
| 61 | dumpBasicBlockLabel(BB, MST); |
| 62 | dbgs() << ":\n" ; |
| 63 | dump(" Consumes" , Block[BBNo].Consumes, RPOT, MST); |
| 64 | dump(" Kills" , Block[BBNo].Kills, RPOT, MST); |
| 65 | } |
| 66 | dbgs() << "\n" ; |
| 67 | } |
| 68 | #endif |
| 69 | |
| 70 | bool SuspendCrossingInfo::hasPathCrossingSuspendPoint(BasicBlock *From, |
| 71 | BasicBlock *To) const { |
| 72 | size_t const FromIndex = Mapping.blockToIndex(BB: From); |
| 73 | size_t const ToIndex = Mapping.blockToIndex(BB: To); |
| 74 | bool const Result = Block[ToIndex].Kills[FromIndex]; |
| 75 | LLVM_DEBUG(if (Result) dbgs() << From->getName() << " => " << To->getName() |
| 76 | << " crosses suspend point\n" ); |
| 77 | return Result; |
| 78 | } |
| 79 | |
| 80 | bool SuspendCrossingInfo::hasPathOrLoopCrossingSuspendPoint( |
| 81 | BasicBlock *From, BasicBlock *To) const { |
| 82 | size_t const FromIndex = Mapping.blockToIndex(BB: From); |
| 83 | size_t const ToIndex = Mapping.blockToIndex(BB: To); |
| 84 | bool Result = Block[ToIndex].Kills[FromIndex] || |
| 85 | (From == To && Block[ToIndex].KillLoop); |
| 86 | LLVM_DEBUG(if (Result) dbgs() << From->getName() << " => " << To->getName() |
| 87 | << " crosses suspend point (path or loop)\n" ); |
| 88 | return Result; |
| 89 | } |
| 90 | |
| 91 | template <bool Initialize> |
| 92 | bool SuspendCrossingInfo::computeBlockData( |
| 93 | const ReversePostOrderTraversal<Function *> &RPOT) { |
| 94 | bool Changed = false; |
| 95 | |
| 96 | for (const BasicBlock *BB : RPOT) { |
| 97 | auto BBNo = Mapping.blockToIndex(BB); |
| 98 | auto &B = Block[BBNo]; |
| 99 | |
| 100 | // We don't need to count the predecessors when initialization. |
| 101 | if constexpr (!Initialize) |
| 102 | // If all the predecessors of the current Block don't change, |
| 103 | // the BlockData for the current block must not change too. |
| 104 | if (all_of(predecessors(BD: B), [this](BasicBlock *BB) { |
| 105 | return !Block[Mapping.blockToIndex(BB)].Changed; |
| 106 | })) { |
| 107 | B.Changed = false; |
| 108 | continue; |
| 109 | } |
| 110 | |
| 111 | // Saved Consumes and Kills bitsets so that it is easy to see |
| 112 | // if anything changed after propagation. |
| 113 | auto SavedConsumes = B.Consumes; |
| 114 | auto SavedKills = B.Kills; |
| 115 | |
| 116 | for (BasicBlock *PI : predecessors(BD: B)) { |
| 117 | auto PrevNo = Mapping.blockToIndex(BB: PI); |
| 118 | auto &P = Block[PrevNo]; |
| 119 | |
| 120 | // Propagate Kills and Consumes from predecessors into B. |
| 121 | B.Consumes |= P.Consumes; |
| 122 | B.Kills |= P.Kills; |
| 123 | |
| 124 | // If block P is a suspend block, it should propagate kills into block |
| 125 | // B for every block P consumes. |
| 126 | if (P.Suspend) |
| 127 | B.Kills |= P.Consumes; |
| 128 | } |
| 129 | |
| 130 | if (B.Suspend) { |
| 131 | // If block B is a suspend block, it should kill all of the blocks it |
| 132 | // consumes. |
| 133 | B.Kills |= B.Consumes; |
| 134 | } else if (B.End) { |
| 135 | // If block B is an end block, it should not propagate kills as the |
| 136 | // blocks following coro.end() are reached during initial invocation |
| 137 | // of the coroutine while all the data are still available on the |
| 138 | // stack or in the registers. |
| 139 | B.Kills.reset(); |
| 140 | } else { |
| 141 | // This is reached when B block it not Suspend nor coro.end and it |
| 142 | // need to make sure that it is not in the kill set. |
| 143 | B.KillLoop |= B.Kills[BBNo]; |
| 144 | B.Kills.reset(Idx: BBNo); |
| 145 | } |
| 146 | |
| 147 | if constexpr (!Initialize) { |
| 148 | B.Changed = (B.Kills != SavedKills) || (B.Consumes != SavedConsumes); |
| 149 | Changed |= B.Changed; |
| 150 | } |
| 151 | } |
| 152 | |
| 153 | return Changed; |
| 154 | } |
| 155 | |
| 156 | SuspendCrossingInfo::SuspendCrossingInfo( |
| 157 | Function &F, const SmallVectorImpl<AnyCoroSuspendInst *> &CoroSuspends, |
| 158 | const SmallVectorImpl<AnyCoroEndInst *> &CoroEnds) |
| 159 | : Mapping(F) { |
| 160 | const size_t N = Mapping.size(); |
| 161 | Block.resize(N); |
| 162 | |
| 163 | // Initialize every block so that it consumes itself |
| 164 | for (size_t I = 0; I < N; ++I) { |
| 165 | auto &B = Block[I]; |
| 166 | B.Consumes.resize(N); |
| 167 | B.Kills.resize(N); |
| 168 | B.Consumes.set(I); |
| 169 | B.Changed = true; |
| 170 | } |
| 171 | |
| 172 | // Mark all CoroEnd Blocks. We do not propagate Kills beyond coro.ends as |
| 173 | // the code beyond coro.end is reachable during initial invocation of the |
| 174 | // coroutine. |
| 175 | for (auto *CE : CoroEnds) { |
| 176 | // Verify CoroEnd was normalized |
| 177 | assert(CE->getParent()->getFirstInsertionPt() == CE->getIterator() && |
| 178 | CE->getParent()->size() <= 2 && "CoroEnd must be in its own BB" ); |
| 179 | |
| 180 | getBlockData(BB: CE->getParent()).End = true; |
| 181 | } |
| 182 | |
| 183 | // Mark all suspend blocks and indicate that they kill everything they |
| 184 | // consume. Note, that crossing coro.save also requires a spill, as any code |
| 185 | // between coro.save and coro.suspend may resume the coroutine and all of the |
| 186 | // state needs to be saved by that time. |
| 187 | auto markSuspendBlock = [&](IntrinsicInst *BarrierInst) { |
| 188 | BasicBlock *SuspendBlock = BarrierInst->getParent(); |
| 189 | auto &B = getBlockData(BB: SuspendBlock); |
| 190 | B.Suspend = true; |
| 191 | B.Kills |= B.Consumes; |
| 192 | }; |
| 193 | for (auto *CSI : CoroSuspends) { |
| 194 | // Verify CoroSuspend was normalized |
| 195 | assert(CSI->getParent()->getFirstInsertionPt() == CSI->getIterator() && |
| 196 | CSI->getParent()->size() <= 2 && |
| 197 | "CoroSuspend must be in its own BB" ); |
| 198 | |
| 199 | markSuspendBlock(CSI); |
| 200 | if (auto *Save = CSI->getCoroSave()) |
| 201 | markSuspendBlock(Save); |
| 202 | } |
| 203 | |
| 204 | // It is considered to be faster to use RPO traversal for forward-edges |
| 205 | // dataflow analysis. |
| 206 | ReversePostOrderTraversal<Function *> RPOT(&F); |
| 207 | computeBlockData</*Initialize=*/true>(RPOT); |
| 208 | while (computeBlockData</*Initialize*/ false>(RPOT)) |
| 209 | ; |
| 210 | |
| 211 | LLVM_DEBUG(dump()); |
| 212 | } |
| 213 | |
| 214 | } // namespace llvm |
| 215 | |