1 | //===- DependencyGraph.cpp ------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h" |
10 | #include "llvm/ADT/ArrayRef.h" |
11 | #include "llvm/SandboxIR/Instruction.h" |
12 | #include "llvm/SandboxIR/Utils.h" |
13 | #include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h" |
14 | |
15 | namespace llvm::sandboxir { |
16 | |
17 | User::op_iterator PredIterator::skipBadIt(User::op_iterator OpIt, |
18 | User::op_iterator OpItE, |
19 | const DependencyGraph &DAG) { |
20 | auto Skip = [&DAG](auto OpIt) { |
21 | auto *I = dyn_cast<Instruction>((*OpIt).get()); |
22 | return I == nullptr || DAG.getNode(I) == nullptr; |
23 | }; |
24 | while (OpIt != OpItE && Skip(OpIt)) |
25 | ++OpIt; |
26 | return OpIt; |
27 | } |
28 | |
29 | PredIterator::value_type PredIterator::operator*() { |
30 | // If it's a DGNode then we dereference the operand iterator. |
31 | if (!isa<MemDGNode>(Val: N)) { |
32 | assert(OpIt != OpItE && "Can't dereference end iterator!" ); |
33 | return DAG->getNode(I: cast<Instruction>(Val: (Value *)*OpIt)); |
34 | } |
35 | // It's a MemDGNode, so we check if we return either the use-def operand, |
36 | // or a mem predecessor. |
37 | if (OpIt != OpItE) |
38 | return DAG->getNode(I: cast<Instruction>(Val: (Value *)*OpIt)); |
39 | // It's a MemDGNode with OpIt == end, so we need to use MemIt. |
40 | assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && |
41 | "Cant' dereference end iterator!" ); |
42 | return *MemIt; |
43 | } |
44 | |
45 | PredIterator &PredIterator::operator++() { |
46 | // If it's a DGNode then we increment the use-def iterator. |
47 | if (!isa<MemDGNode>(Val: N)) { |
48 | assert(OpIt != OpItE && "Already at end!" ); |
49 | ++OpIt; |
50 | // Skip operands that are not instructions or are outside the DAG. |
51 | OpIt = PredIterator::skipBadIt(OpIt, OpItE, DAG: *DAG); |
52 | return *this; |
53 | } |
54 | // It's a MemDGNode, so if we are not at the end of the use-def iterator we |
55 | // need to first increment that. |
56 | if (OpIt != OpItE) { |
57 | ++OpIt; |
58 | // Skip operands that are not instructions or are outside the DAG. |
59 | OpIt = PredIterator::skipBadIt(OpIt, OpItE, DAG: *DAG); |
60 | return *this; |
61 | } |
62 | // It's a MemDGNode with OpIt == end, so we need to increment MemIt. |
63 | assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && "Already at end!" ); |
64 | ++MemIt; |
65 | return *this; |
66 | } |
67 | |
68 | bool PredIterator::operator==(const PredIterator &Other) const { |
69 | assert(DAG == Other.DAG && "Iterators of different DAGs!" ); |
70 | assert(N == Other.N && "Iterators of different nodes!" ); |
71 | return OpIt == Other.OpIt && MemIt == Other.MemIt; |
72 | } |
73 | |
74 | void DGNode::setSchedBundle(SchedBundle &SB) { |
75 | if (this->SB != nullptr) |
76 | this->SB->eraseFromBundle(N: this); |
77 | this->SB = &SB; |
78 | } |
79 | |
80 | DGNode::~DGNode() { |
81 | if (SB == nullptr) |
82 | return; |
83 | SB->eraseFromBundle(N: this); |
84 | } |
85 | |
86 | #ifndef NDEBUG |
87 | void DGNode::print(raw_ostream &OS, bool PrintDeps) const { |
88 | OS << *I << " USuccs:" << UnscheduledSuccs << " Sched:" << Scheduled << "\n" ; |
89 | } |
90 | void DGNode::dump() const { print(dbgs()); } |
91 | void MemDGNode::print(raw_ostream &OS, bool PrintDeps) const { |
92 | DGNode::print(OS, false); |
93 | if (PrintDeps) { |
94 | // Print memory preds. |
95 | static constexpr const unsigned Indent = 4; |
96 | for (auto *Pred : MemPreds) |
97 | OS.indent(Indent) << "<-" << *Pred->getInstruction() << "\n" ; |
98 | } |
99 | } |
100 | #endif // NDEBUG |
101 | |
102 | MemDGNode * |
103 | MemDGNodeIntervalBuilder::getTopMemDGNode(const Interval<Instruction> &Intvl, |
104 | const DependencyGraph &DAG) { |
105 | Instruction *I = Intvl.top(); |
106 | Instruction *BeforeI = Intvl.bottom(); |
107 | // Walk down the chain looking for a mem-dep candidate instruction. |
108 | while (!DGNode::isMemDepNodeCandidate(I) && I != BeforeI) |
109 | I = I->getNextNode(); |
110 | if (!DGNode::isMemDepNodeCandidate(I)) |
111 | return nullptr; |
112 | return cast<MemDGNode>(Val: DAG.getNode(I)); |
113 | } |
114 | |
115 | MemDGNode * |
116 | MemDGNodeIntervalBuilder::getBotMemDGNode(const Interval<Instruction> &Intvl, |
117 | const DependencyGraph &DAG) { |
118 | Instruction *I = Intvl.bottom(); |
119 | Instruction *AfterI = Intvl.top(); |
120 | // Walk up the chain looking for a mem-dep candidate instruction. |
121 | while (!DGNode::isMemDepNodeCandidate(I) && I != AfterI) |
122 | I = I->getPrevNode(); |
123 | if (!DGNode::isMemDepNodeCandidate(I)) |
124 | return nullptr; |
125 | return cast<MemDGNode>(Val: DAG.getNode(I)); |
126 | } |
127 | |
128 | Interval<MemDGNode> |
129 | MemDGNodeIntervalBuilder::make(const Interval<Instruction> &Instrs, |
130 | DependencyGraph &DAG) { |
131 | if (Instrs.empty()) |
132 | return {}; |
133 | auto *TopMemN = getTopMemDGNode(Intvl: Instrs, DAG); |
134 | // If we couldn't find a mem node in range TopN - BotN then it's empty. |
135 | if (TopMemN == nullptr) |
136 | return {}; |
137 | auto *BotMemN = getBotMemDGNode(Intvl: Instrs, DAG); |
138 | assert(BotMemN != nullptr && "TopMemN should be null too!" ); |
139 | // Now that we have the mem-dep nodes, create and return the range. |
140 | return Interval<MemDGNode>(TopMemN, BotMemN); |
141 | } |
142 | |
143 | DependencyGraph::DependencyType |
144 | DependencyGraph::getRoughDepType(Instruction *FromI, Instruction *ToI) { |
145 | // TODO: Perhaps compile-time improvement by skipping if neither is mem? |
146 | if (FromI->mayWriteToMemory()) { |
147 | if (ToI->mayReadFromMemory()) |
148 | return DependencyType::ReadAfterWrite; |
149 | if (ToI->mayWriteToMemory()) |
150 | return DependencyType::WriteAfterWrite; |
151 | } else if (FromI->mayReadFromMemory()) { |
152 | if (ToI->mayWriteToMemory()) |
153 | return DependencyType::WriteAfterRead; |
154 | } |
155 | if (isa<sandboxir::PHINode>(Val: FromI) || isa<sandboxir::PHINode>(Val: ToI)) |
156 | return DependencyType::Control; |
157 | if (ToI->isTerminator()) |
158 | return DependencyType::Control; |
159 | if (DGNode::isStackSaveOrRestoreIntrinsic(I: FromI) || |
160 | DGNode::isStackSaveOrRestoreIntrinsic(I: ToI)) |
161 | return DependencyType::Other; |
162 | return DependencyType::None; |
163 | } |
164 | |
165 | static bool isOrdered(Instruction *I) { |
166 | auto IsOrdered = [](Instruction *I) { |
167 | if (auto *LI = dyn_cast<LoadInst>(Val: I)) |
168 | return !LI->isUnordered(); |
169 | if (auto *SI = dyn_cast<StoreInst>(Val: I)) |
170 | return !SI->isUnordered(); |
171 | if (DGNode::isFenceLike(I)) |
172 | return true; |
173 | return false; |
174 | }; |
175 | bool Is = IsOrdered(I); |
176 | assert((!Is || DGNode::isMemDepCandidate(I)) && |
177 | "An ordered instruction must be a MemDepCandidate!" ); |
178 | return Is; |
179 | } |
180 | |
181 | bool DependencyGraph::alias(Instruction *SrcI, Instruction *DstI, |
182 | DependencyType DepType) { |
183 | std::optional<MemoryLocation> DstLocOpt = |
184 | Utils::memoryLocationGetOrNone(I: DstI); |
185 | if (!DstLocOpt) |
186 | return true; |
187 | // Check aliasing. |
188 | assert((SrcI->mayReadFromMemory() || SrcI->mayWriteToMemory()) && |
189 | "Expected a mem instr" ); |
190 | // TODO: Check AABudget |
191 | ModRefInfo SrcModRef = |
192 | isOrdered(I: SrcI) |
193 | ? ModRefInfo::ModRef |
194 | : Utils::aliasAnalysisGetModRefInfo(BatchAA&: *BatchAA, I: SrcI, OptLoc: *DstLocOpt); |
195 | switch (DepType) { |
196 | case DependencyType::ReadAfterWrite: |
197 | case DependencyType::WriteAfterWrite: |
198 | return isModSet(MRI: SrcModRef); |
199 | case DependencyType::WriteAfterRead: |
200 | return isRefSet(MRI: SrcModRef); |
201 | default: |
202 | llvm_unreachable("Expected only RAW, WAW and WAR!" ); |
203 | } |
204 | } |
205 | |
206 | bool DependencyGraph::hasDep(Instruction *SrcI, Instruction *DstI) { |
207 | DependencyType RoughDepType = getRoughDepType(FromI: SrcI, ToI: DstI); |
208 | switch (RoughDepType) { |
209 | case DependencyType::ReadAfterWrite: |
210 | case DependencyType::WriteAfterWrite: |
211 | case DependencyType::WriteAfterRead: |
212 | return alias(SrcI, DstI, DepType: RoughDepType); |
213 | case DependencyType::Control: |
214 | // Adding actual dep edges from PHIs/to terminator would just create too |
215 | // many edges, which would be bad for compile-time. |
216 | // So we ignore them in the DAG formation but handle them in the |
217 | // scheduler, while sorting the ready list. |
218 | return false; |
219 | case DependencyType::Other: |
220 | return true; |
221 | case DependencyType::None: |
222 | return false; |
223 | } |
224 | llvm_unreachable("Unknown DependencyType enum" ); |
225 | } |
226 | |
227 | void DependencyGraph::scanAndAddDeps(MemDGNode &DstN, |
228 | const Interval<MemDGNode> &SrcScanRange) { |
229 | assert(isa<MemDGNode>(DstN) && |
230 | "DstN is the mem dep destination, so it must be mem" ); |
231 | Instruction *DstI = DstN.getInstruction(); |
232 | // Walk up the instruction chain from ScanRange bottom to top, looking for |
233 | // memory instrs that may alias. |
234 | for (MemDGNode &SrcN : reverse(C: SrcScanRange)) { |
235 | Instruction *SrcI = SrcN.getInstruction(); |
236 | if (hasDep(SrcI, DstI)) |
237 | DstN.addMemPred(PredN: &SrcN); |
238 | } |
239 | } |
240 | |
241 | void DependencyGraph::setDefUseUnscheduledSuccs( |
242 | const Interval<Instruction> &NewInterval) { |
243 | // +---+ |
244 | // | | Def |
245 | // | | | |
246 | // | | v |
247 | // | | Use |
248 | // +---+ |
249 | // Set the intra-interval counters in NewInterval. |
250 | for (Instruction &I : NewInterval) { |
251 | for (Value *Op : I.operands()) { |
252 | auto *OpI = dyn_cast<Instruction>(Val: Op); |
253 | if (OpI == nullptr) |
254 | continue; |
255 | // TODO: For now don't cross BBs. |
256 | if (OpI->getParent() != I.getParent()) |
257 | continue; |
258 | if (!NewInterval.contains(I: OpI)) |
259 | continue; |
260 | auto *OpN = getNode(I: OpI); |
261 | if (OpN == nullptr) |
262 | continue; |
263 | ++OpN->UnscheduledSuccs; |
264 | } |
265 | } |
266 | |
267 | // Now handle the cross-interval edges. |
268 | bool NewIsAbove = DAGInterval.empty() || NewInterval.comesBefore(Other: DAGInterval); |
269 | const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval; |
270 | const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval; |
271 | // +---+ |
272 | // |Top| |
273 | // | | Def |
274 | // +---+ | |
275 | // | | v |
276 | // |Bot| Use |
277 | // | | |
278 | // +---+ |
279 | // Walk over all instructions in "BotInterval" and update the counter |
280 | // of operands that are in "TopInterval". |
281 | for (Instruction &BotI : BotInterval) { |
282 | auto *BotN = getNode(I: &BotI); |
283 | // Skip scheduled nodes. |
284 | if (BotN->scheduled()) |
285 | continue; |
286 | for (Value *Op : BotI.operands()) { |
287 | auto *OpI = dyn_cast<Instruction>(Val: Op); |
288 | if (OpI == nullptr) |
289 | continue; |
290 | auto *OpN = getNode(I: OpI); |
291 | if (OpN == nullptr) |
292 | continue; |
293 | if (!TopInterval.contains(I: OpI)) |
294 | continue; |
295 | ++OpN->UnscheduledSuccs; |
296 | } |
297 | } |
298 | } |
299 | |
300 | void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) { |
301 | // Create Nodes only for the new sections of the DAG. |
302 | DGNode *LastN = getOrCreateNode(I: NewInterval.top()); |
303 | MemDGNode *LastMemN = dyn_cast<MemDGNode>(Val: LastN); |
304 | for (Instruction &I : drop_begin(RangeOrContainer: NewInterval)) { |
305 | auto *N = getOrCreateNode(I: &I); |
306 | // Build the Mem node chain. |
307 | if (auto *MemN = dyn_cast<MemDGNode>(Val: N)) { |
308 | MemN->setPrevNode(LastMemN); |
309 | LastMemN = MemN; |
310 | } |
311 | } |
312 | // Link new MemDGNode chain with the old one, if any. |
313 | if (!DAGInterval.empty()) { |
314 | bool NewIsAbove = NewInterval.comesBefore(Other: DAGInterval); |
315 | const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval; |
316 | const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval; |
317 | MemDGNode *LinkTopN = |
318 | MemDGNodeIntervalBuilder::getBotMemDGNode(Intvl: TopInterval, DAG: *this); |
319 | MemDGNode *LinkBotN = |
320 | MemDGNodeIntervalBuilder::getTopMemDGNode(Intvl: BotInterval, DAG: *this); |
321 | assert((LinkTopN == nullptr || LinkBotN == nullptr || |
322 | LinkTopN->comesBefore(LinkBotN)) && |
323 | "Wrong order!" ); |
324 | if (LinkTopN != nullptr && LinkBotN != nullptr) { |
325 | LinkTopN->setNextNode(LinkBotN); |
326 | } |
327 | #ifndef NDEBUG |
328 | // TODO: Remove this once we've done enough testing. |
329 | // Check that the chain is well formed. |
330 | auto UnionIntvl = DAGInterval.getUnionInterval(NewInterval); |
331 | MemDGNode *ChainTopN = |
332 | MemDGNodeIntervalBuilder::getTopMemDGNode(UnionIntvl, *this); |
333 | MemDGNode *ChainBotN = |
334 | MemDGNodeIntervalBuilder::getBotMemDGNode(UnionIntvl, *this); |
335 | if (ChainTopN != nullptr && ChainBotN != nullptr) { |
336 | for (auto *N = ChainTopN->getNextNode(), *LastN = ChainTopN; N != nullptr; |
337 | LastN = N, N = N->getNextNode()) { |
338 | assert(N == LastN->getNextNode() && "Bad chain!" ); |
339 | assert(N->getPrevNode() == LastN && "Bad chain!" ); |
340 | } |
341 | } |
342 | #endif // NDEBUG |
343 | } |
344 | |
345 | setDefUseUnscheduledSuccs(NewInterval); |
346 | } |
347 | |
348 | MemDGNode *DependencyGraph::getMemDGNodeBefore(DGNode *N, bool IncludingN, |
349 | MemDGNode *SkipN) const { |
350 | auto *I = N->getInstruction(); |
351 | for (auto *PrevI = IncludingN ? I : I->getPrevNode(); PrevI != nullptr; |
352 | PrevI = PrevI->getPrevNode()) { |
353 | auto *PrevN = getNodeOrNull(I: PrevI); |
354 | if (PrevN == nullptr) |
355 | return nullptr; |
356 | auto *PrevMemN = dyn_cast<MemDGNode>(Val: PrevN); |
357 | if (PrevMemN != nullptr && PrevMemN != SkipN) |
358 | return PrevMemN; |
359 | } |
360 | return nullptr; |
361 | } |
362 | |
363 | MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N, bool IncludingN, |
364 | MemDGNode *SkipN) const { |
365 | auto *I = N->getInstruction(); |
366 | for (auto *NextI = IncludingN ? I : I->getNextNode(); NextI != nullptr; |
367 | NextI = NextI->getNextNode()) { |
368 | auto *NextN = getNodeOrNull(I: NextI); |
369 | if (NextN == nullptr) |
370 | return nullptr; |
371 | auto *NextMemN = dyn_cast<MemDGNode>(Val: NextN); |
372 | if (NextMemN != nullptr && NextMemN != SkipN) |
373 | return NextMemN; |
374 | } |
375 | return nullptr; |
376 | } |
377 | |
378 | void DependencyGraph::notifyCreateInstr(Instruction *I) { |
379 | if (Ctx->getTracker().getState() == Tracker::TrackerState::Reverting) |
380 | // We don't maintain the DAG while reverting. |
381 | return; |
382 | // Nothing to do if the node is not in the focus range of the DAG. |
383 | if (!(DAGInterval.contains(I) || DAGInterval.touches(Elm: I))) |
384 | return; |
385 | // Include `I` into the interval. |
386 | DAGInterval = DAGInterval.getUnionInterval(Other: {I, I}); |
387 | auto *N = getOrCreateNode(I); |
388 | auto *MemN = dyn_cast<MemDGNode>(Val: N); |
389 | |
390 | // Update the MemDGNode chain if this is a memory node. |
391 | if (MemN != nullptr) { |
392 | if (auto *PrevMemN = getMemDGNodeBefore(N: MemN, /*IncludingN=*/false)) { |
393 | PrevMemN->NextMemN = MemN; |
394 | MemN->PrevMemN = PrevMemN; |
395 | } |
396 | if (auto *NextMemN = getMemDGNodeAfter(N: MemN, /*IncludingN=*/false)) { |
397 | NextMemN->PrevMemN = MemN; |
398 | MemN->NextMemN = NextMemN; |
399 | } |
400 | |
401 | // Add Mem dependencies. |
402 | // 1. Scan for deps above `I` for deps to `I`: AboveN->MemN. |
403 | if (DAGInterval.top()->comesBefore(Other: I)) { |
404 | Interval<Instruction> AboveIntvl(DAGInterval.top(), I->getPrevNode()); |
405 | auto SrcInterval = MemDGNodeIntervalBuilder::make(Instrs: AboveIntvl, DAG&: *this); |
406 | scanAndAddDeps(DstN&: *MemN, SrcScanRange: SrcInterval); |
407 | } |
408 | // 2. Scan for deps below `I` for deps from `I`: MemN->BelowN. |
409 | if (I->comesBefore(Other: DAGInterval.bottom())) { |
410 | Interval<Instruction> BelowIntvl(I->getNextNode(), DAGInterval.bottom()); |
411 | for (MemDGNode &BelowN : |
412 | MemDGNodeIntervalBuilder::make(Instrs: BelowIntvl, DAG&: *this)) |
413 | scanAndAddDeps(DstN&: BelowN, SrcScanRange: Interval<MemDGNode>(MemN, MemN)); |
414 | } |
415 | } |
416 | } |
417 | |
418 | void DependencyGraph::notifyMoveInstr(Instruction *I, const BBIterator &To) { |
419 | if (Ctx->getTracker().getState() == Tracker::TrackerState::Reverting) |
420 | // We don't maintain the DAG while reverting. |
421 | return; |
422 | // NOTE: This function runs before `I` moves to its new destination. |
423 | BasicBlock *BB = To.getNodeParent(); |
424 | assert(!(To != BB->end() && &*To == I->getNextNode()) && |
425 | !(To == BB->end() && std::next(I->getIterator()) == BB->end()) && |
426 | "Should not have been called if destination is same as origin." ); |
427 | |
428 | // TODO: We can only handle fully internal movements within DAGInterval or at |
429 | // the borders, i.e., right before the top or right after the bottom. |
430 | assert(To.getNodeParent() == I->getParent() && |
431 | "TODO: We don't support movement across BBs!" ); |
432 | assert( |
433 | (To == std::next(DAGInterval.bottom()->getIterator()) || |
434 | (To != BB->end() && std::next(To) == DAGInterval.top()->getIterator()) || |
435 | (To != BB->end() && DAGInterval.contains(&*To))) && |
436 | "TODO: To should be either within the DAGInterval or right " |
437 | "before/after it." ); |
438 | |
439 | // Make a copy of the DAGInterval before we update it. |
440 | auto OrigDAGInterval = DAGInterval; |
441 | |
442 | // Maintain the DAGInterval. |
443 | DAGInterval.notifyMoveInstr(I, BeforeIt: To); |
444 | |
445 | // TODO: Perhaps check if this is legal by checking the dependencies? |
446 | |
447 | // Update the MemDGNode chain to reflect the instr movement if necessary. |
448 | DGNode *N = getNodeOrNull(I); |
449 | if (N == nullptr) |
450 | return; |
451 | MemDGNode *MemN = dyn_cast<MemDGNode>(Val: N); |
452 | if (MemN == nullptr) |
453 | return; |
454 | |
455 | // First safely detach it from the existing chain. |
456 | MemN->detachFromChain(); |
457 | |
458 | // Now insert it back into the chain at the new location. |
459 | // |
460 | // We won't always have a DGNode to insert before it. If `To` is BB->end() or |
461 | // if it points to an instr after DAGInterval.bottom() then we will have to |
462 | // find a node to insert *after*. |
463 | // |
464 | // BB: BB: |
465 | // I1 I1 ^ |
466 | // I2 I2 | DAGInteval [I1 to I3] |
467 | // I3 I3 V |
468 | // I4 I4 <- `To` == right after DAGInterval |
469 | // <- `To` == BB->end() |
470 | // |
471 | if (To == BB->end() || |
472 | To == std::next(x: OrigDAGInterval.bottom()->getIterator())) { |
473 | // If we don't have a node to insert before, find a node to insert after and |
474 | // update the chain. |
475 | DGNode *InsertAfterN = getNode(I: &*std::prev(x: To)); |
476 | MemN->setPrevNode( |
477 | getMemDGNodeBefore(N: InsertAfterN, /*IncludingN=*/true, /*SkipN=*/MemN)); |
478 | } else { |
479 | // We have a node to insert before, so update the chain. |
480 | DGNode *BeforeToN = getNode(I: &*To); |
481 | MemN->setPrevNode( |
482 | getMemDGNodeBefore(N: BeforeToN, /*IncludingN=*/false, /*SkipN=*/MemN)); |
483 | MemN->setNextNode( |
484 | getMemDGNodeAfter(N: BeforeToN, /*IncludingN=*/true, /*SkipN=*/MemN)); |
485 | } |
486 | } |
487 | |
488 | void DependencyGraph::notifyEraseInstr(Instruction *I) { |
489 | if (Ctx->getTracker().getState() == Tracker::TrackerState::Reverting) |
490 | // We don't maintain the DAG while reverting. |
491 | return; |
492 | auto *N = getNode(I); |
493 | if (N == nullptr) |
494 | // Early return if there is no DAG node for `I`. |
495 | return; |
496 | if (auto *MemN = dyn_cast<MemDGNode>(Val: getNode(I))) { |
497 | // Update the MemDGNode chain if this is a memory node. |
498 | auto *PrevMemN = getMemDGNodeBefore(N: MemN, /*IncludingN=*/false); |
499 | auto *NextMemN = getMemDGNodeAfter(N: MemN, /*IncludingN=*/false); |
500 | if (PrevMemN != nullptr) |
501 | PrevMemN->NextMemN = NextMemN; |
502 | if (NextMemN != nullptr) |
503 | NextMemN->PrevMemN = PrevMemN; |
504 | |
505 | // Drop the memory dependencies from both predecessors and successors. |
506 | while (!MemN->memPreds().empty()) { |
507 | auto *PredN = *MemN->memPreds().begin(); |
508 | MemN->removeMemPred(PredN); |
509 | } |
510 | while (!MemN->memSuccs().empty()) { |
511 | auto *SuccN = *MemN->memSuccs().begin(); |
512 | SuccN->removeMemPred(PredN: MemN); |
513 | } |
514 | // NOTE: The unscheduled succs for MemNodes get updated be setMemPred(). |
515 | } else { |
516 | // If this is a non-mem node we only need to update UnscheduledSuccs. |
517 | if (!N->scheduled()) |
518 | for (auto *PredN : N->preds(DAG&: *this)) |
519 | PredN->decrUnscheduledSuccs(); |
520 | } |
521 | // Finally erase the Node. |
522 | InstrToNodeMap.erase(Val: I); |
523 | } |
524 | |
525 | void DependencyGraph::notifySetUse(const Use &U, Value *NewSrc) { |
526 | // Update the UnscheduledSuccs counter for both the current source and NewSrc |
527 | // if needed. |
528 | if (auto *CurrSrcI = dyn_cast<Instruction>(Val: U.get())) { |
529 | if (auto *CurrSrcN = getNode(I: CurrSrcI)) { |
530 | CurrSrcN->decrUnscheduledSuccs(); |
531 | } |
532 | } |
533 | if (auto *NewSrcI = dyn_cast<Instruction>(Val: NewSrc)) { |
534 | if (auto *NewSrcN = getNode(I: NewSrcI)) { |
535 | ++NewSrcN->UnscheduledSuccs; |
536 | } |
537 | } |
538 | } |
539 | |
540 | Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) { |
541 | if (Instrs.empty()) |
542 | return {}; |
543 | |
544 | Interval<Instruction> InstrsInterval(Instrs); |
545 | Interval<Instruction> Union = DAGInterval.getUnionInterval(Other: InstrsInterval); |
546 | auto NewInterval = Union.getSingleDiff(Other: DAGInterval); |
547 | if (NewInterval.empty()) |
548 | return {}; |
549 | |
550 | createNewNodes(NewInterval); |
551 | |
552 | // Create the dependencies. |
553 | // |
554 | // 1. This is a new DAG, DAGInterval is empty. Fully scan the whole interval. |
555 | // +---+ - - |
556 | // | | SrcN | | |
557 | // | | | | SrcRange | |
558 | // |New| v | | DstRange |
559 | // | | DstN - | |
560 | // | | | |
561 | // +---+ - |
562 | // We are scanning for deps with destination in NewInterval and sources in |
563 | // NewInterval until DstN, for each DstN. |
564 | auto FullScan = [this](const Interval<Instruction> Intvl) { |
565 | auto DstRange = MemDGNodeIntervalBuilder::make(Instrs: Intvl, DAG&: *this); |
566 | if (!DstRange.empty()) { |
567 | for (MemDGNode &DstN : drop_begin(RangeOrContainer&: DstRange)) { |
568 | auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode()); |
569 | scanAndAddDeps(DstN, SrcScanRange: SrcRange); |
570 | } |
571 | } |
572 | }; |
573 | auto MemDAGInterval = MemDGNodeIntervalBuilder::make(Instrs: DAGInterval, DAG&: *this); |
574 | if (MemDAGInterval.empty()) { |
575 | FullScan(NewInterval); |
576 | } |
577 | // 2. The new section is below the old section. |
578 | // +---+ - |
579 | // | | | |
580 | // |Old| SrcN | |
581 | // | | | | |
582 | // +---+ | | SrcRange |
583 | // +---+ | | - |
584 | // | | | | | |
585 | // |New| v | | DstRange |
586 | // | | DstN - | |
587 | // | | | |
588 | // +---+ - |
589 | // We are scanning for deps with destination in NewInterval because the deps |
590 | // in DAGInterval have already been computed. We consider sources in the whole |
591 | // range including both NewInterval and DAGInterval until DstN, for each DstN. |
592 | else if (DAGInterval.bottom()->comesBefore(Other: NewInterval.top())) { |
593 | auto DstRange = MemDGNodeIntervalBuilder::make(Instrs: NewInterval, DAG&: *this); |
594 | auto SrcRangeFull = MemDAGInterval.getUnionInterval(Other: DstRange); |
595 | for (MemDGNode &DstN : DstRange) { |
596 | auto SrcRange = |
597 | Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode()); |
598 | scanAndAddDeps(DstN, SrcScanRange: SrcRange); |
599 | } |
600 | } |
601 | // 3. The new section is above the old section. |
602 | else if (NewInterval.bottom()->comesBefore(Other: DAGInterval.top())) { |
603 | // +---+ - - |
604 | // | | SrcN | | |
605 | // |New| | | SrcRange | DstRange |
606 | // | | v | | |
607 | // | | DstN - | |
608 | // | | | |
609 | // +---+ - |
610 | // +---+ |
611 | // |Old| |
612 | // | | |
613 | // +---+ |
614 | // When scanning for deps with destination in NewInterval we need to fully |
615 | // scan the interval. This is the same as the scanning for a new DAG. |
616 | FullScan(NewInterval); |
617 | |
618 | // +---+ - |
619 | // | | | |
620 | // |New| SrcN | SrcRange |
621 | // | | | | |
622 | // | | | | |
623 | // | | | | |
624 | // +---+ | - |
625 | // +---+ | - |
626 | // |Old| v | DstRange |
627 | // | | DstN | |
628 | // +---+ - |
629 | // When scanning for deps with destination in DAGInterval we need to |
630 | // consider sources from the NewInterval only, because all intra-DAGInterval |
631 | // dependencies have already been created. |
632 | auto DstRangeOld = MemDAGInterval; |
633 | auto SrcRange = MemDGNodeIntervalBuilder::make(Instrs: NewInterval, DAG&: *this); |
634 | for (MemDGNode &DstN : DstRangeOld) |
635 | scanAndAddDeps(DstN, SrcScanRange: SrcRange); |
636 | } else { |
637 | llvm_unreachable("We don't expect extending in both directions!" ); |
638 | } |
639 | |
640 | DAGInterval = Union; |
641 | return NewInterval; |
642 | } |
643 | |
644 | #ifndef NDEBUG |
645 | void DependencyGraph::print(raw_ostream &OS) const { |
646 | // InstrToNodeMap is unordered so we need to create an ordered vector. |
647 | SmallVector<DGNode *> Nodes; |
648 | Nodes.reserve(InstrToNodeMap.size()); |
649 | for (const auto &Pair : InstrToNodeMap) |
650 | Nodes.push_back(Pair.second.get()); |
651 | // Sort them based on which one comes first in the BB. |
652 | sort(Nodes, [](DGNode *N1, DGNode *N2) { |
653 | return N1->getInstruction()->comesBefore(N2->getInstruction()); |
654 | }); |
655 | for (auto *N : Nodes) |
656 | N->print(OS, /*PrintDeps=*/true); |
657 | } |
658 | |
659 | void DependencyGraph::dump() const { |
660 | print(dbgs()); |
661 | dbgs() << "\n" ; |
662 | } |
663 | #endif // NDEBUG |
664 | |
665 | } // namespace llvm::sandboxir |
666 | |