1//===- MaterializationUtils.cpp - Builds and manipulates coroutine frame
2//-------------===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
// This file contains classes used to materialize insts after suspend points.
10//===----------------------------------------------------------------------===//
11
12#include "llvm/Transforms/Coroutines/MaterializationUtils.h"
13#include "CoroInternal.h"
14#include "llvm/ADT/PostOrderIterator.h"
15#include "llvm/IR/Dominators.h"
16#include "llvm/IR/InstIterator.h"
17#include "llvm/IR/Instruction.h"
18#include "llvm/IR/ModuleSlotTracker.h"
19#include "llvm/Transforms/Coroutines/SpillUtils.h"
20#include <deque>
21
22using namespace llvm;
23
24using namespace coro;
25
26// The "coro-suspend-crossing" flag is very noisy. There is another debug type,
27// "coro-frame", which results in leaner debug spew.
28#define DEBUG_TYPE "coro-suspend-crossing"
29
30namespace {
31
32// RematGraph is used to construct a DAG for rematerializable instructions
33// When the constructor is invoked with a candidate instruction (which is
34// materializable) it builds a DAG of materializable instructions from that
35// point.
36// Typically, for each instruction identified as re-materializable across a
37// suspend point, a RematGraph will be created.
38struct RematGraph {
39 // Each RematNode in the graph contains the edges to instructions providing
40 // operands in the current node.
41 struct RematNode {
42 Instruction *Node;
43 SmallVector<RematNode *> Operands;
44 RematNode() = default;
45 RematNode(Instruction *V) : Node(V) {}
46 };
47
48 RematNode *EntryNode;
49 using RematNodeMap =
50 SmallMapVector<Instruction *, std::unique_ptr<RematNode>, 8>;
51 RematNodeMap Remats;
52 const std::function<bool(Instruction &)> &MaterializableCallback;
53 SuspendCrossingInfo &Checker;
54
55 RematGraph(const std::function<bool(Instruction &)> &MaterializableCallback,
56 Instruction *I, SuspendCrossingInfo &Checker)
57 : MaterializableCallback(MaterializableCallback), Checker(Checker) {
58 std::unique_ptr<RematNode> FirstNode = std::make_unique<RematNode>(args&: I);
59 EntryNode = FirstNode.get();
60 std::deque<std::unique_ptr<RematNode>> WorkList;
61 addNode(NUPtr: std::move(FirstNode), WorkList, FirstUse: cast<User>(Val: I));
62 while (WorkList.size()) {
63 std::unique_ptr<RematNode> N = std::move(WorkList.front());
64 WorkList.pop_front();
65 addNode(NUPtr: std::move(N), WorkList, FirstUse: cast<User>(Val: I));
66 }
67 }
68
69 void addNode(std::unique_ptr<RematNode> NUPtr,
70 std::deque<std::unique_ptr<RematNode>> &WorkList,
71 User *FirstUse) {
72 RematNode *N = NUPtr.get();
73 auto [It, Inserted] = Remats.try_emplace(Key: N->Node);
74 if (!Inserted)
75 return;
76
77 // We haven't see this node yet - add to the list
78 It->second = std::move(NUPtr);
79 for (auto &Def : N->Node->operands()) {
80 Instruction *D = dyn_cast<Instruction>(Val: Def.get());
81 if (!D || !MaterializableCallback(*D) ||
82 !Checker.isDefinitionAcrossSuspend(I&: *D, U: FirstUse))
83 continue;
84
85 if (auto It = Remats.find(Key: D); It != Remats.end()) {
86 // Already have this in the graph
87 N->Operands.push_back(Elt: It->second.get());
88 continue;
89 }
90
91 bool NoMatch = true;
92 for (auto &I : WorkList) {
93 if (I->Node == D) {
94 NoMatch = false;
95 N->Operands.push_back(Elt: I.get());
96 break;
97 }
98 }
99 if (NoMatch) {
100 // Create a new node
101 std::unique_ptr<RematNode> ChildNode = std::make_unique<RematNode>(args&: D);
102 N->Operands.push_back(Elt: ChildNode.get());
103 WorkList.push_back(x: std::move(ChildNode));
104 }
105 }
106 }
107
108#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
109 static void dumpBasicBlockLabel(const BasicBlock *BB,
110 ModuleSlotTracker &MST) {
111 if (BB->hasName()) {
112 dbgs() << BB->getName();
113 return;
114 }
115
116 dbgs() << MST.getLocalSlot(BB);
117 }
118
119 void dump() const {
120 BasicBlock *BB = EntryNode->Node->getParent();
121 Function *F = BB->getParent();
122
123 ModuleSlotTracker MST(F->getParent());
124 MST.incorporateFunction(*F);
125
126 dbgs() << "Entry (";
127 dumpBasicBlockLabel(BB, MST);
128 dbgs() << ") : " << *EntryNode->Node << "\n";
129 for (auto &E : Remats) {
130 dbgs() << *(E.first) << "\n";
131 for (RematNode *U : E.second->Operands)
132 dbgs() << " " << *U->Node << "\n";
133 }
134 }
135#endif
136};
137
138} // namespace
139
140namespace llvm {
141template <> struct GraphTraits<RematGraph *> {
142 using NodeRef = RematGraph::RematNode *;
143 using ChildIteratorType = RematGraph::RematNode **;
144
145 static NodeRef getEntryNode(RematGraph *G) { return G->EntryNode; }
146 static ChildIteratorType child_begin(NodeRef N) {
147 return N->Operands.begin();
148 }
149 static ChildIteratorType child_end(NodeRef N) { return N->Operands.end(); }
150};
151
152} // end namespace llvm
153
154// For each instruction identified as materializable across the suspend point,
155// and its associated DAG of other rematerializable instructions,
156// recreate the DAG of instructions after the suspend point.
157static void rewriteMaterializableInstructions(
158 const SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8>
159 &AllRemats) {
160 // This has to be done in 2 phases
161 // Do the remats and record the required defs to be replaced in the
162 // original use instructions
163 // Once all the remats are complete, replace the uses in the final
164 // instructions with the new defs
165 typedef struct {
166 Instruction *Use;
167 Instruction *Def;
168 Instruction *Remat;
169 } ProcessNode;
170
171 SmallVector<ProcessNode> FinalInstructionsToProcess;
172
173 for (const auto &E : AllRemats) {
174 Instruction *Use = E.first;
175 Instruction *CurrentMaterialization = nullptr;
176 RematGraph *RG = E.second.get();
177 ReversePostOrderTraversal<RematGraph *> RPOT(RG);
178 SmallVector<Instruction *> InstructionsToProcess;
179
180 // If the target use is actually a suspend instruction then we have to
181 // insert the remats into the end of the predecessor (there should only be
182 // one). This is so that suspend blocks always have the suspend instruction
183 // as the first instruction.
184 BasicBlock::iterator InsertPoint = Use->getParent()->getFirstInsertionPt();
185 if (isa<AnyCoroSuspendInst>(Val: Use)) {
186 BasicBlock *SuspendPredecessorBlock =
187 Use->getParent()->getSinglePredecessor();
188 assert(SuspendPredecessorBlock && "malformed coro suspend instruction");
189 InsertPoint = SuspendPredecessorBlock->getTerminator()->getIterator();
190 }
191
192 // Note: skip the first instruction as this is the actual use that we're
193 // rematerializing everything for.
194 auto I = RPOT.begin();
195 ++I;
196 for (; I != RPOT.end(); ++I) {
197 Instruction *D = (*I)->Node;
198 CurrentMaterialization = D->clone();
199 CurrentMaterialization->setName(D->getName());
200 CurrentMaterialization->insertBefore(InsertPos: InsertPoint);
201 InsertPoint = CurrentMaterialization->getIterator();
202
203 // Replace all uses of Def in the instructions being added as part of this
204 // rematerialization group
205 for (auto &I : InstructionsToProcess)
206 I->replaceUsesOfWith(From: D, To: CurrentMaterialization);
207
208 // Don't replace the final use at this point as this can cause problems
209 // for other materializations. Instead, for any final use that uses a
210 // define that's being rematerialized, record the replace values
211 for (unsigned i = 0, E = Use->getNumOperands(); i != E; ++i)
212 if (Use->getOperand(i) == D) // Is this operand pointing to oldval?
213 FinalInstructionsToProcess.push_back(
214 Elt: {.Use: Use, .Def: D, .Remat: CurrentMaterialization});
215
216 InstructionsToProcess.push_back(Elt: CurrentMaterialization);
217 }
218 }
219
220 // Finally, replace the uses with the defines that we've just rematerialized
221 for (auto &R : FinalInstructionsToProcess) {
222 if (auto *PN = dyn_cast<PHINode>(Val: R.Use)) {
223 assert(PN->getNumIncomingValues() == 1 && "unexpected number of incoming "
224 "values in the PHINode");
225 PN->replaceAllUsesWith(V: R.Remat);
226 PN->eraseFromParent();
227 continue;
228 }
229 R.Use->replaceUsesOfWith(From: R.Def, To: R.Remat);
230 }
231}
232
233/// Default materializable callback
234// Check for instructions that we can recreate on resume as opposed to spill
235// the result into a coroutine frame.
236bool llvm::coro::defaultMaterializable(Instruction &V) {
237 return (isa<CastInst>(Val: &V) || isa<GetElementPtrInst>(Val: &V) ||
238 isa<BinaryOperator>(Val: &V) || isa<CmpInst>(Val: &V) || isa<SelectInst>(Val: &V));
239}
240
241bool llvm::coro::isTriviallyMaterializable(Instruction &V) {
242 return defaultMaterializable(V);
243}
244
245#ifndef NDEBUG
246static void dumpRemats(
247 StringRef Title,
248 const SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8> &RM) {
249 dbgs() << "------------- " << Title << "--------------\n";
250 for (const auto &E : RM) {
251 E.second->dump();
252 dbgs() << "--\n";
253 }
254}
255#endif
256
257void coro::doRematerializations(
258 Function &F, SuspendCrossingInfo &Checker,
259 std::function<bool(Instruction &)> IsMaterializable) {
260 if (F.hasOptNone())
261 return;
262
263 coro::SpillInfo Spills;
264
265 // See if there are materializable instructions across suspend points
266 // We record these as the starting point to also identify materializable
267 // defs of uses in these operations
268 for (Instruction &I : instructions(F)) {
269 if (!IsMaterializable(I))
270 continue;
271 for (User *U : I.users())
272 if (Checker.isDefinitionAcrossSuspend(I, U))
273 Spills[&I].push_back(Elt: cast<Instruction>(Val: U));
274 }
275
276 // Process each of the identified rematerializable instructions
277 // and add predecessor instructions that can also be rematerialized.
278 // This is actually a graph of instructions since we could potentially
279 // have multiple uses of a def in the set of predecessor instructions.
280 // The approach here is to maintain a graph of instructions for each bottom
281 // level instruction - where we have a unique set of instructions (nodes)
282 // and edges between them. We then walk the graph in reverse post-dominator
283 // order to insert them past the suspend point, but ensure that ordering is
284 // correct. We also rely on CSE removing duplicate defs for remats of
285 // different instructions with a def in common (rather than maintaining more
286 // complex graphs for each suspend point)
287
288 // We can do this by adding new nodes to the list for each suspend
289 // point. Then using standard GraphTraits to give a reverse post-order
290 // traversal when we insert the nodes after the suspend
291 SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8> AllRemats;
292 for (auto &E : Spills) {
293 for (Instruction *U : E.second) {
294 // Don't process a user twice (this can happen if the instruction uses
295 // more than one rematerializable def)
296 auto [It, Inserted] = AllRemats.try_emplace(Key: U);
297 if (!Inserted)
298 continue;
299
300 // Constructor creates the whole RematGraph for the given Use
301 auto RematUPtr =
302 std::make_unique<RematGraph>(args&: IsMaterializable, args&: U, args&: Checker);
303
304 LLVM_DEBUG(dbgs() << "***** Next remat group *****\n";
305 ReversePostOrderTraversal<RematGraph *> RPOT(RematUPtr.get());
306 for (auto I = RPOT.begin(); I != RPOT.end();
307 ++I) { (*I)->Node->dump(); } dbgs()
308 << "\n";);
309
310 It->second = std::move(RematUPtr);
311 }
312 }
313
314 // Rewrite materializable instructions to be materialized at the use
315 // point.
316 LLVM_DEBUG(dumpRemats("Materializations", AllRemats));
317 rewriteMaterializableInstructions(AllRemats);
318}
319