1//===- DependencyGraph.cpp ------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h"
10#include "llvm/ADT/ArrayRef.h"
11#include "llvm/SandboxIR/Instruction.h"
12#include "llvm/SandboxIR/Utils.h"
13#include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"
14
15namespace llvm::sandboxir {
16
17User::op_iterator PredIterator::skipBadIt(User::op_iterator OpIt,
18 User::op_iterator OpItE,
19 const DependencyGraph &DAG) {
20 auto Skip = [&DAG](auto OpIt) {
21 auto *I = dyn_cast<Instruction>((*OpIt).get());
22 return I == nullptr || DAG.getNode(I) == nullptr;
23 };
24 while (OpIt != OpItE && Skip(OpIt))
25 ++OpIt;
26 return OpIt;
27}
28
29PredIterator::value_type PredIterator::operator*() {
30 // If it's a DGNode then we dereference the operand iterator.
31 if (!isa<MemDGNode>(Val: N)) {
32 assert(OpIt != OpItE && "Can't dereference end iterator!");
33 return DAG->getNode(I: cast<Instruction>(Val: (Value *)*OpIt));
34 }
35 // It's a MemDGNode, so we check if we return either the use-def operand,
36 // or a mem predecessor.
37 if (OpIt != OpItE)
38 return DAG->getNode(I: cast<Instruction>(Val: (Value *)*OpIt));
39 // It's a MemDGNode with OpIt == end, so we need to use MemIt.
40 assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() &&
41 "Cant' dereference end iterator!");
42 return *MemIt;
43}
44
45PredIterator &PredIterator::operator++() {
46 // If it's a DGNode then we increment the use-def iterator.
47 if (!isa<MemDGNode>(Val: N)) {
48 assert(OpIt != OpItE && "Already at end!");
49 ++OpIt;
50 // Skip operands that are not instructions or are outside the DAG.
51 OpIt = PredIterator::skipBadIt(OpIt, OpItE, DAG: *DAG);
52 return *this;
53 }
54 // It's a MemDGNode, so if we are not at the end of the use-def iterator we
55 // need to first increment that.
56 if (OpIt != OpItE) {
57 ++OpIt;
58 // Skip operands that are not instructions or are outside the DAG.
59 OpIt = PredIterator::skipBadIt(OpIt, OpItE, DAG: *DAG);
60 return *this;
61 }
62 // It's a MemDGNode with OpIt == end, so we need to increment MemIt.
63 assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && "Already at end!");
64 ++MemIt;
65 return *this;
66}
67
68bool PredIterator::operator==(const PredIterator &Other) const {
69 assert(DAG == Other.DAG && "Iterators of different DAGs!");
70 assert(N == Other.N && "Iterators of different nodes!");
71 return OpIt == Other.OpIt && MemIt == Other.MemIt;
72}
73
74User::user_iterator SuccIterator::skipOutOfScope(User::user_iterator UserIt,
75 User::user_iterator UserItE,
76 const DependencyGraph &DAG) {
77 auto Skip = [&DAG](User::user_iterator UserIt) {
78 auto *I = dyn_cast<Instruction>(Val: *UserIt);
79 return I == nullptr || DAG.getNode(I) == nullptr;
80 };
81 while (UserIt != UserItE && Skip(UserIt))
82 ++UserIt;
83 return UserIt;
84}
85
86SuccIterator::value_type SuccIterator::operator*() {
87 // If it's a DGNode then we dereference the user iterator.
88 if (!isa<MemDGNode>(Val: N)) {
89 assert(UserIt != UserItE && "Can't dereference end iterator!");
90 return DAG->getNode(I: cast<Instruction>(Val: (Value *)*UserIt));
91 }
92 // It's a MemDGNode, so we check if we return either the def-use operand,
93 // or a mem predecessor.
94 if (UserIt != UserItE)
95 return DAG->getNode(I: cast<Instruction>(Val: (Value *)*UserIt));
96 // It's a MemDGNode with UserIt == end, so we need to use MemIt.
97 assert(MemIt != cast<MemDGNode>(N)->MemSuccs.end() &&
98 "Cant' dereference end iterator!");
99 return *MemIt;
100}
101
102SuccIterator &SuccIterator::operator++() {
103 // If it's a DGNode then we increment the use-def iterator.
104 if (!isa<MemDGNode>(Val: N)) {
105 assert(UserIt != UserItE && "Already at end!");
106 ++UserIt;
107 // Skip users that are not instructions or are outside the DAG.
108 UserIt = SuccIterator::skipOutOfScope(UserIt, UserItE, DAG: *DAG);
109 return *this;
110 }
111 // It's a MemDGNode, so if we are not at the end of the def-use iterator we
112 // need to first increment that.
113 if (UserIt != UserItE) {
114 ++UserIt;
115 // Skip operands that are not instructions or are outside the DAG.
116 UserIt = SuccIterator::skipOutOfScope(UserIt, UserItE, DAG: *DAG);
117 return *this;
118 }
119 // It's a MemDGNode with UserIt == end, so we need to increment MemIt.
120 assert(MemIt != cast<MemDGNode>(N)->MemSuccs.end() && "Already at end!");
121 ++MemIt;
122 return *this;
123}
124
125bool SuccIterator::operator==(const SuccIterator &Other) const {
126 assert(DAG == Other.DAG && "Iterators of different DAGs!");
127 assert(N == Other.N && "Iterators of different nodes!");
128 return UserIt == Other.UserIt && MemIt == Other.MemIt;
129}
130
131void DGNode::setSchedBundle(SchedBundle &SB) {
132 if (this->SB != nullptr)
133 this->SB->eraseFromBundle(N: this);
134 this->SB = &SB;
135}
136
137DGNode::~DGNode() {
138 if (SB == nullptr)
139 return;
140 SB->eraseFromBundle(N: this);
141}
142
143#ifndef NDEBUG
144void DGNode::print(raw_ostream &OS, bool PrintDeps) const {
145 OS << *I << " USuccs:" << UnscheduledSuccs << " UPreds:" << UnscheduledPreds
146 << " Sched:" << Scheduled << "\n";
147}
148void DGNode::dump() const { print(dbgs()); }
149void MemDGNode::print(raw_ostream &OS, bool PrintDeps) const {
150 DGNode::print(OS, false);
151 if (PrintDeps) {
152 // Print memory preds.
153 static constexpr unsigned Indent = 4;
154 for (auto *Pred : MemPreds)
155 OS.indent(Indent) << "<-" << *Pred->getInstruction() << "\n";
156 }
157}
158#endif // NDEBUG
159
160MemDGNode *
161MemDGNodeIntervalBuilder::getTopMemDGNode(const Interval<Instruction> &Intvl,
162 const DependencyGraph &DAG) {
163 Instruction *I = Intvl.top();
164 Instruction *BeforeI = Intvl.bottom();
165 // Walk down the chain looking for a mem-dep candidate instruction.
166 while (!DGNode::isMemDepNodeCandidate(I) && I != BeforeI)
167 I = I->getNextNode();
168 if (!DGNode::isMemDepNodeCandidate(I))
169 return nullptr;
170 return cast<MemDGNode>(Val: DAG.getNode(I));
171}
172
173MemDGNode *
174MemDGNodeIntervalBuilder::getBotMemDGNode(const Interval<Instruction> &Intvl,
175 const DependencyGraph &DAG) {
176 Instruction *I = Intvl.bottom();
177 Instruction *AfterI = Intvl.top();
178 // Walk up the chain looking for a mem-dep candidate instruction.
179 while (!DGNode::isMemDepNodeCandidate(I) && I != AfterI)
180 I = I->getPrevNode();
181 if (!DGNode::isMemDepNodeCandidate(I))
182 return nullptr;
183 return cast<MemDGNode>(Val: DAG.getNode(I));
184}
185
186Interval<MemDGNode>
187MemDGNodeIntervalBuilder::make(const Interval<Instruction> &Instrs,
188 DependencyGraph &DAG) {
189 if (Instrs.empty())
190 return {};
191 auto *TopMemN = getTopMemDGNode(Intvl: Instrs, DAG);
192 // If we couldn't find a mem node in range TopN - BotN then it's empty.
193 if (TopMemN == nullptr)
194 return {};
195 auto *BotMemN = getBotMemDGNode(Intvl: Instrs, DAG);
196 assert(BotMemN != nullptr && "TopMemN should be null too!");
197 // Now that we have the mem-dep nodes, create and return the range.
198 return Interval<MemDGNode>(TopMemN, BotMemN);
199}
200
201DependencyGraph::DependencyType
202DependencyGraph::getRoughDepType(Instruction *FromI, Instruction *ToI) {
203 // TODO: Perhaps compile-time improvement by skipping if neither is mem?
204 if (FromI->mayWriteToMemory()) {
205 if (ToI->mayReadFromMemory())
206 return DependencyType::ReadAfterWrite;
207 if (ToI->mayWriteToMemory())
208 return DependencyType::WriteAfterWrite;
209 } else if (FromI->mayReadFromMemory()) {
210 if (ToI->mayWriteToMemory())
211 return DependencyType::WriteAfterRead;
212 }
213 if (isa<sandboxir::PHINode>(Val: FromI) || isa<sandboxir::PHINode>(Val: ToI))
214 return DependencyType::Control;
215 if (ToI->isTerminator())
216 return DependencyType::Control;
217 if (DGNode::isStackSaveOrRestoreIntrinsic(I: FromI) ||
218 DGNode::isStackSaveOrRestoreIntrinsic(I: ToI))
219 return DependencyType::Other;
220 return DependencyType::None;
221}
222
223static bool isOrdered(Instruction *I) {
224 auto IsOrdered = [](Instruction *I) {
225 if (auto *LI = dyn_cast<LoadInst>(Val: I))
226 return !LI->isUnordered();
227 if (auto *SI = dyn_cast<StoreInst>(Val: I))
228 return !SI->isUnordered();
229 if (DGNode::isFenceLike(I))
230 return true;
231 return false;
232 };
233 bool Is = IsOrdered(I);
234 assert((!Is || DGNode::isMemDepCandidate(I)) &&
235 "An ordered instruction must be a MemDepCandidate!");
236 return Is;
237}
238
239bool DependencyGraph::alias(Instruction *SrcI, Instruction *DstI,
240 DependencyType DepType) {
241 std::optional<MemoryLocation> DstLocOpt =
242 Utils::memoryLocationGetOrNone(I: DstI);
243 if (!DstLocOpt)
244 return true;
245 // Check aliasing.
246 assert((SrcI->mayReadFromMemory() || SrcI->mayWriteToMemory()) &&
247 "Expected a mem instr");
248 // TODO: Check AABudget
249 ModRefInfo SrcModRef =
250 isOrdered(I: SrcI)
251 ? ModRefInfo::ModRef
252 : Utils::aliasAnalysisGetModRefInfo(BatchAA&: *BatchAA, I: SrcI, OptLoc: *DstLocOpt);
253 switch (DepType) {
254 case DependencyType::ReadAfterWrite:
255 case DependencyType::WriteAfterWrite:
256 return isModSet(MRI: SrcModRef);
257 case DependencyType::WriteAfterRead:
258 return isRefSet(MRI: SrcModRef);
259 default:
260 llvm_unreachable("Expected only RAW, WAW and WAR!");
261 }
262}
263
264bool DependencyGraph::hasDep(Instruction *SrcI, Instruction *DstI) {
265 DependencyType RoughDepType = getRoughDepType(FromI: SrcI, ToI: DstI);
266 switch (RoughDepType) {
267 case DependencyType::ReadAfterWrite:
268 case DependencyType::WriteAfterWrite:
269 case DependencyType::WriteAfterRead:
270 return alias(SrcI, DstI, DepType: RoughDepType);
271 case DependencyType::Control:
272 // Adding actual dep edges from PHIs/to terminator would just create too
273 // many edges, which would be bad for compile-time.
274 // So we ignore them in the DAG formation but handle them in the
275 // scheduler, while sorting the ready list.
276 return false;
277 case DependencyType::Other:
278 return true;
279 case DependencyType::None:
280 return false;
281 }
282 llvm_unreachable("Unknown DependencyType enum");
283}
284
285void DependencyGraph::scanAndAddDeps(MemDGNode &DstN,
286 const Interval<MemDGNode> &SrcScanRange) {
287 assert(isa<MemDGNode>(DstN) &&
288 "DstN is the mem dep destination, so it must be mem");
289 Instruction *DstI = DstN.getInstruction();
290 // Walk up the instruction chain from ScanRange bottom to top, looking for
291 // memory instrs that may alias.
292 for (MemDGNode &SrcN : reverse(C: SrcScanRange)) {
293 Instruction *SrcI = SrcN.getInstruction();
294 if (hasDep(SrcI, DstI))
295 DstN.addMemPred(PredN: &SrcN);
296 }
297}
298
299void DependencyGraph::setDefUseUnscheduledSuccs(
300 const Interval<Instruction> &NewInterval) {
301 // +---+
302 // | | Def
303 // | | |
304 // | | v
305 // | | Use
306 // +---+
307 // Set the intra-interval counters in NewInterval.
308 for (Instruction &I : NewInterval) {
309 unsigned CntUnschedPreds = 0;
310 for (Value *Op : I.operands()) {
311 auto *OpI = dyn_cast<Instruction>(Val: Op);
312 if (OpI == nullptr)
313 continue;
314 // TODO: For now don't cross BBs.
315 if (OpI->getParent() != I.getParent())
316 continue;
317 if (!NewInterval.contains(I: OpI))
318 continue;
319 auto *OpN = getNode(I: OpI);
320 if (OpN == nullptr)
321 continue;
322 OpN->incrUnscheduledSuccs();
323 if (!OpN->scheduled())
324 ++CntUnschedPreds;
325 }
326 getNode(I: &I)->UnscheduledPreds = CntUnschedPreds;
327 }
328
329 // Now handle the cross-interval edges.
330 bool NewIsAbove = DAGInterval.empty() || NewInterval.comesBefore(Other: DAGInterval);
331 const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval;
332 const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval;
333 // +---+
334 // |Top|
335 // | | Def
336 // +---+ |
337 // | | v
338 // |Bot| Use
339 // | |
340 // +---+
341 // Walk over all instructions in "BotInterval" and update the counter
342 // of operands that are in "TopInterval".
343 for (Instruction &BotI : BotInterval) {
344 auto *BotN = getNode(I: &BotI);
345 // Skip scheduled nodes.
346 if (BotN->scheduled())
347 continue;
348 unsigned CntUnscheduledPreds = 0;
349 for (Value *Op : BotI.operands()) {
350 auto *OpI = dyn_cast<Instruction>(Val: Op);
351 if (OpI == nullptr)
352 continue;
353 auto *OpN = getNode(I: OpI);
354 if (OpN == nullptr)
355 continue;
356 if (!TopInterval.contains(I: OpI))
357 continue;
358 if (!OpN->scheduled()) {
359 OpN->incrUnscheduledSuccs();
360 ++CntUnscheduledPreds;
361 }
362 }
363 *BotN->UnscheduledPreds += CntUnscheduledPreds;
364 }
365}
366
367void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
368 // Create Nodes only for the new sections of the DAG.
369 DGNode *LastN = getOrCreateNode(I: NewInterval.top());
370 MemDGNode *LastMemN = dyn_cast<MemDGNode>(Val: LastN);
371 for (Instruction &I : drop_begin(RangeOrContainer: NewInterval)) {
372 auto *N = getOrCreateNode(I: &I);
373 // Build the Mem node chain.
374 if (auto *MemN = dyn_cast<MemDGNode>(Val: N)) {
375 MemN->setPrevNode(LastMemN);
376 LastMemN = MemN;
377 }
378 }
379 // Link new MemDGNode chain with the old one, if any.
380 if (!DAGInterval.empty()) {
381 bool NewIsAbove = NewInterval.comesBefore(Other: DAGInterval);
382 const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval;
383 const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval;
384 MemDGNode *LinkTopN =
385 MemDGNodeIntervalBuilder::getBotMemDGNode(Intvl: TopInterval, DAG: *this);
386 MemDGNode *LinkBotN =
387 MemDGNodeIntervalBuilder::getTopMemDGNode(Intvl: BotInterval, DAG: *this);
388 assert((LinkTopN == nullptr || LinkBotN == nullptr ||
389 LinkTopN->comesBefore(LinkBotN)) &&
390 "Wrong order!");
391 if (LinkTopN != nullptr && LinkBotN != nullptr) {
392 LinkTopN->setNextNode(LinkBotN);
393 }
394#ifndef NDEBUG
395 // TODO: Remove this once we've done enough testing.
396 // Check that the chain is well formed.
397 auto UnionIntvl = DAGInterval.getUnionInterval(NewInterval);
398 MemDGNode *ChainTopN =
399 MemDGNodeIntervalBuilder::getTopMemDGNode(UnionIntvl, *this);
400 MemDGNode *ChainBotN =
401 MemDGNodeIntervalBuilder::getBotMemDGNode(UnionIntvl, *this);
402 if (ChainTopN != nullptr && ChainBotN != nullptr) {
403 for (auto *N = ChainTopN->getNextNode(), *LastN = ChainTopN; N != nullptr;
404 LastN = N, N = N->getNextNode()) {
405 assert(N == LastN->getNextNode() && "Bad chain!");
406 assert(N->getPrevNode() == LastN && "Bad chain!");
407 }
408 }
409#endif // NDEBUG
410 }
411
412 setDefUseUnscheduledSuccs(NewInterval);
413}
414
415MemDGNode *DependencyGraph::getMemDGNodeBefore(DGNode *N, bool IncludingN,
416 MemDGNode *SkipN) const {
417 auto *I = N->getInstruction();
418 for (auto *PrevI = IncludingN ? I : I->getPrevNode(); PrevI != nullptr;
419 PrevI = PrevI->getPrevNode()) {
420 auto *PrevN = getNodeOrNull(I: PrevI);
421 if (PrevN == nullptr)
422 return nullptr;
423 auto *PrevMemN = dyn_cast<MemDGNode>(Val: PrevN);
424 if (PrevMemN != nullptr && PrevMemN != SkipN)
425 return PrevMemN;
426 }
427 return nullptr;
428}
429
430MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N, bool IncludingN,
431 MemDGNode *SkipN) const {
432 auto *I = N->getInstruction();
433 for (auto *NextI = IncludingN ? I : I->getNextNode(); NextI != nullptr;
434 NextI = NextI->getNextNode()) {
435 auto *NextN = getNodeOrNull(I: NextI);
436 if (NextN == nullptr)
437 return nullptr;
438 auto *NextMemN = dyn_cast<MemDGNode>(Val: NextN);
439 if (NextMemN != nullptr && NextMemN != SkipN)
440 return NextMemN;
441 }
442 return nullptr;
443}
444
445void DependencyGraph::notifyCreateInstr(Instruction *I) {
446 if (Ctx->getTracker().getState() == Tracker::TrackerState::Reverting)
447 // We don't maintain the DAG while reverting.
448 return;
449 // Nothing to do if the node is not in the focus range of the DAG.
450 if (!(DAGInterval.contains(I) || DAGInterval.touches(Elm: I)))
451 return;
452 // Include `I` into the interval.
453 DAGInterval = DAGInterval.getUnionInterval(Other: {I, I});
454 auto *N = getOrCreateNode(I);
455 auto *MemN = dyn_cast<MemDGNode>(Val: N);
456
457 // Update the MemDGNode chain if this is a memory node.
458 if (MemN != nullptr) {
459 if (auto *PrevMemN = getMemDGNodeBefore(N: MemN, /*IncludingN=*/false)) {
460 PrevMemN->NextMemN = MemN;
461 MemN->PrevMemN = PrevMemN;
462 }
463 if (auto *NextMemN = getMemDGNodeAfter(N: MemN, /*IncludingN=*/false)) {
464 NextMemN->PrevMemN = MemN;
465 MemN->NextMemN = NextMemN;
466 }
467
468 // Add Mem dependencies.
469 // 1. Scan for deps above `I` for deps to `I`: AboveN->MemN.
470 if (DAGInterval.top()->comesBefore(Other: I)) {
471 Interval<Instruction> AboveIntvl(DAGInterval.top(), I->getPrevNode());
472 auto SrcInterval = MemDGNodeIntervalBuilder::make(Instrs: AboveIntvl, DAG&: *this);
473 scanAndAddDeps(DstN&: *MemN, SrcScanRange: SrcInterval);
474 }
475 // 2. Scan for deps below `I` for deps from `I`: MemN->BelowN.
476 if (I->comesBefore(Other: DAGInterval.bottom())) {
477 Interval<Instruction> BelowIntvl(I->getNextNode(), DAGInterval.bottom());
478 for (MemDGNode &BelowN :
479 MemDGNodeIntervalBuilder::make(Instrs: BelowIntvl, DAG&: *this))
480 scanAndAddDeps(DstN&: BelowN, SrcScanRange: Interval<MemDGNode>(MemN, MemN));
481 }
482 }
483}
484
485void DependencyGraph::notifyMoveInstr(Instruction *I, const BBIterator &To) {
486 if (Ctx->getTracker().getState() == Tracker::TrackerState::Reverting)
487 // We don't maintain the DAG while reverting.
488 return;
489 // NOTE: This function runs before `I` moves to its new destination.
490 BasicBlock *BB = To.getNodeParent();
491 assert(!(To != BB->end() && &*To == I->getNextNode()) &&
492 !(To == BB->end() && std::next(I->getIterator()) == BB->end()) &&
493 "Should not have been called if destination is same as origin.");
494
495 // TODO: We can only handle fully internal movements within DAGInterval or at
496 // the borders, i.e., right before the top or right after the bottom.
497 assert(To.getNodeParent() == I->getParent() &&
498 "TODO: We don't support movement across BBs!");
499 assert(
500 (To == std::next(DAGInterval.bottom()->getIterator()) ||
501 (To != BB->end() && std::next(To) == DAGInterval.top()->getIterator()) ||
502 (To != BB->end() && DAGInterval.contains(&*To))) &&
503 "TODO: To should be either within the DAGInterval or right "
504 "before/after it.");
505
506 // Make a copy of the DAGInterval before we update it.
507 auto OrigDAGInterval = DAGInterval;
508
509 // Maintain the DAGInterval.
510 DAGInterval.notifyMoveInstr(I, BeforeIt: To);
511
512 // TODO: Perhaps check if this is legal by checking the dependencies?
513
514 // Update the MemDGNode chain to reflect the instr movement if necessary.
515 DGNode *N = getNodeOrNull(I);
516 if (N == nullptr)
517 return;
518 MemDGNode *MemN = dyn_cast<MemDGNode>(Val: N);
519 if (MemN == nullptr)
520 return;
521
522 // First safely detach it from the existing chain.
523 MemN->detachFromChain();
524
525 // Now insert it back into the chain at the new location.
526 //
527 // We won't always have a DGNode to insert before it. If `To` is BB->end() or
528 // if it points to an instr after DAGInterval.bottom() then we will have to
529 // find a node to insert *after*.
530 //
531 // BB: BB:
532 // I1 I1 ^
533 // I2 I2 | DAGInteval [I1 to I3]
534 // I3 I3 V
535 // I4 I4 <- `To` == right after DAGInterval
536 // <- `To` == BB->end()
537 //
538 if (To == BB->end() ||
539 To == std::next(x: OrigDAGInterval.bottom()->getIterator())) {
540 // If we don't have a node to insert before, find a node to insert after and
541 // update the chain.
542 DGNode *InsertAfterN = getNode(I: &*std::prev(x: To));
543 MemN->setPrevNode(
544 getMemDGNodeBefore(N: InsertAfterN, /*IncludingN=*/true, /*SkipN=*/MemN));
545 } else {
546 // We have a node to insert before, so update the chain.
547 DGNode *BeforeToN = getNode(I: &*To);
548 MemN->setPrevNode(
549 getMemDGNodeBefore(N: BeforeToN, /*IncludingN=*/false, /*SkipN=*/MemN));
550 MemN->setNextNode(
551 getMemDGNodeAfter(N: BeforeToN, /*IncludingN=*/true, /*SkipN=*/MemN));
552 }
553}
554
555void DependencyGraph::notifyEraseInstr(Instruction *I) {
556 if (Ctx->getTracker().getState() == Tracker::TrackerState::Reverting)
557 // We don't maintain the DAG while reverting.
558 return;
559 auto *N = getNode(I);
560 if (N == nullptr)
561 // Early return if there is no DAG node for `I`.
562 return;
563 if (auto *MemN = dyn_cast<MemDGNode>(Val: getNode(I))) {
564 // Update the MemDGNode chain if this is a memory node.
565 auto *PrevMemN = getMemDGNodeBefore(N: MemN, /*IncludingN=*/false);
566 auto *NextMemN = getMemDGNodeAfter(N: MemN, /*IncludingN=*/false);
567 if (PrevMemN != nullptr)
568 PrevMemN->NextMemN = NextMemN;
569 if (NextMemN != nullptr)
570 NextMemN->PrevMemN = PrevMemN;
571
572 // Drop the memory dependencies from both predecessors and successors.
573 while (!MemN->memPreds().empty()) {
574 auto *PredN = *MemN->memPreds().begin();
575 MemN->removeMemPred(PredN);
576 }
577 while (!MemN->memSuccs().empty()) {
578 auto *SuccN = *MemN->memSuccs().begin();
579 SuccN->removeMemPred(PredN: MemN);
580 }
581 // NOTE: The unscheduled succs for MemNodes get updated be setMemPred().
582 } else {
583 // If this is a non-mem node we only need to update UnscheduledSuccs.
584 if (!N->scheduled()) {
585 for (auto *PredN : N->preds(DAG&: *this))
586 PredN->decrUnscheduledSuccs();
587 for (auto *SuccN : N->succs(DAG&: *this))
588 SuccN->decrUnscheduledPreds();
589 }
590 }
591 // Finally erase the Node.
592 InstrToNodeMap.erase(Val: I);
593}
594
595void DependencyGraph::notifySetUse(const Use &U, Value *NewSrc) {
596 // If U.User is not in the DAG, then we should not attempt to decrement
597 // CurrSrcN's unscheduled successors.
598 // ------- ------- -
599 // CurrSrc | DAG interval
600 // | NewSrc |
601 // ---|--- ---|--- -
602 // U.User U.User
603 auto *UserI = dyn_cast_or_null<Instruction>(Val: U.getUser());
604 if (UserI == nullptr)
605 return;
606 auto *UserN = getNode(I: UserI);
607 if (UserN == nullptr)
608 return;
609 // If UserN is marked as scheduled then we should not update CrrSrcN' or
610 // NewSrcN's unscheduled successors.
611 if (UserN->scheduled())
612 return;
613 // Update the UnscheduledSuccs counter for both the current source and
614 // NewSrc if needed.
615 if (auto *CurrSrcI = dyn_cast<Instruction>(Val: U.get())) {
616 if (auto *CurrSrcN = getNode(I: CurrSrcI)) {
617 // If CurrSrcN is scheduled there is no point in updating UnscheduleSuccs.
618 if (!CurrSrcN->scheduled()) {
619 CurrSrcN->decrUnscheduledSuccs();
620 UserN->decrUnscheduledPreds();
621 }
622 }
623 }
624 if (auto *NewSrcI = dyn_cast<Instruction>(Val: NewSrc)) {
625 if (auto *NewSrcN = getNode(I: NewSrcI)) {
626 // If CurrSrcN is scheduled there is no point in updating UnscheduleSuccs.
627 if (!NewSrcN->scheduled()) {
628 NewSrcN->incrUnscheduledSuccs();
629 UserN->incrUnscheduledPreds();
630 }
631 }
632 }
633}
634
635Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
636 if (Instrs.empty())
637 return {};
638
639 Interval<Instruction> InstrsInterval(Instrs);
640 Interval<Instruction> Union = DAGInterval.getUnionInterval(Other: InstrsInterval);
641 auto NewInterval = Union.getSingleDiff(Other: DAGInterval);
642 if (NewInterval.empty())
643 return {};
644
645 createNewNodes(NewInterval);
646
647 // Create the dependencies.
648 //
649 // 1. This is a new DAG, DAGInterval is empty. Fully scan the whole interval.
650 // +---+ - -
651 // | | SrcN | |
652 // | | | | SrcRange |
653 // |New| v | | DstRange
654 // | | DstN - |
655 // | | |
656 // +---+ -
657 // We are scanning for deps with destination in NewInterval and sources in
658 // NewInterval until DstN, for each DstN.
659 auto FullScan = [this](const Interval<Instruction> Intvl) {
660 auto DstRange = MemDGNodeIntervalBuilder::make(Instrs: Intvl, DAG&: *this);
661 if (!DstRange.empty()) {
662 for (MemDGNode &DstN : drop_begin(RangeOrContainer&: DstRange)) {
663 auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode());
664 scanAndAddDeps(DstN, SrcScanRange: SrcRange);
665 }
666 }
667 };
668 auto MemDAGInterval = MemDGNodeIntervalBuilder::make(Instrs: DAGInterval, DAG&: *this);
669 if (MemDAGInterval.empty()) {
670 FullScan(NewInterval);
671 }
672 // 2. The new section is below the old section.
673 // +---+ -
674 // | | |
675 // |Old| SrcN |
676 // | | | |
677 // +---+ | | SrcRange
678 // +---+ | | -
679 // | | | | |
680 // |New| v | | DstRange
681 // | | DstN - |
682 // | | |
683 // +---+ -
684 // We are scanning for deps with destination in NewInterval because the deps
685 // in DAGInterval have already been computed. We consider sources in the whole
686 // range including both NewInterval and DAGInterval until DstN, for each DstN.
687 else if (DAGInterval.bottom()->comesBefore(Other: NewInterval.top())) {
688 auto DstRange = MemDGNodeIntervalBuilder::make(Instrs: NewInterval, DAG&: *this);
689 auto SrcRangeFull = MemDAGInterval.getUnionInterval(Other: DstRange);
690 for (MemDGNode &DstN : DstRange) {
691 auto SrcRange =
692 Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode());
693 scanAndAddDeps(DstN, SrcScanRange: SrcRange);
694 }
695 }
696 // 3. The new section is above the old section.
697 else if (NewInterval.bottom()->comesBefore(Other: DAGInterval.top())) {
698 // +---+ - -
699 // | | SrcN | |
700 // |New| | | SrcRange | DstRange
701 // | | v | |
702 // | | DstN - |
703 // | | |
704 // +---+ -
705 // +---+
706 // |Old|
707 // | |
708 // +---+
709 // When scanning for deps with destination in NewInterval we need to fully
710 // scan the interval. This is the same as the scanning for a new DAG.
711 FullScan(NewInterval);
712
713 // +---+ -
714 // | | |
715 // |New| SrcN | SrcRange
716 // | | | |
717 // | | | |
718 // | | | |
719 // +---+ | -
720 // +---+ | -
721 // |Old| v | DstRange
722 // | | DstN |
723 // +---+ -
724 // When scanning for deps with destination in DAGInterval we need to
725 // consider sources from the NewInterval only, because all intra-DAGInterval
726 // dependencies have already been created.
727 auto DstRangeOld = MemDAGInterval;
728 auto SrcRange = MemDGNodeIntervalBuilder::make(Instrs: NewInterval, DAG&: *this);
729 for (MemDGNode &DstN : DstRangeOld)
730 scanAndAddDeps(DstN, SrcScanRange: SrcRange);
731 } else {
732 llvm_unreachable("We don't expect extending in both directions!");
733 }
734
735 DAGInterval = Union;
736 return NewInterval;
737}
738
739#ifndef NDEBUG
740void DependencyGraph::print(raw_ostream &OS) const {
741 // InstrToNodeMap is unordered so we need to create an ordered vector.
742 SmallVector<DGNode *> Nodes;
743 Nodes.reserve(InstrToNodeMap.size());
744 for (const auto &Pair : InstrToNodeMap)
745 Nodes.push_back(Pair.second.get());
746 // Sort them based on which one comes first in the BB.
747 sort(Nodes, [](DGNode *N1, DGNode *N2) {
748 return N1->getInstruction()->comesBefore(N2->getInstruction());
749 });
750 for (auto *N : Nodes)
751 N->print(OS, /*PrintDeps=*/true);
752}
753
754void DependencyGraph::dump() const {
755 print(dbgs());
756 dbgs() << "\n";
757}
758#endif // NDEBUG
759
760} // namespace llvm::sandboxir
761