| 1 | //===- DependencyGraph.cpp ------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h" |
| 10 | #include "llvm/ADT/ArrayRef.h" |
| 11 | #include "llvm/SandboxIR/Instruction.h" |
| 12 | #include "llvm/SandboxIR/Utils.h" |
| 13 | #include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h" |
| 14 | |
| 15 | namespace llvm::sandboxir { |
| 16 | |
| 17 | User::op_iterator PredIterator::skipBadIt(User::op_iterator OpIt, |
| 18 | User::op_iterator OpItE, |
| 19 | const DependencyGraph &DAG) { |
| 20 | auto Skip = [&DAG](auto OpIt) { |
| 21 | auto *I = dyn_cast<Instruction>((*OpIt).get()); |
| 22 | return I == nullptr || DAG.getNode(I) == nullptr; |
| 23 | }; |
| 24 | while (OpIt != OpItE && Skip(OpIt)) |
| 25 | ++OpIt; |
| 26 | return OpIt; |
| 27 | } |
| 28 | |
| 29 | PredIterator::value_type PredIterator::operator*() { |
| 30 | // If it's a DGNode then we dereference the operand iterator. |
| 31 | if (!isa<MemDGNode>(Val: N)) { |
| 32 | assert(OpIt != OpItE && "Can't dereference end iterator!" ); |
| 33 | return DAG->getNode(I: cast<Instruction>(Val: (Value *)*OpIt)); |
| 34 | } |
| 35 | // It's a MemDGNode, so we check if we return either the use-def operand, |
| 36 | // or a mem predecessor. |
| 37 | if (OpIt != OpItE) |
| 38 | return DAG->getNode(I: cast<Instruction>(Val: (Value *)*OpIt)); |
| 39 | // It's a MemDGNode with OpIt == end, so we need to use MemIt. |
| 40 | assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && |
| 41 | "Cant' dereference end iterator!" ); |
| 42 | return *MemIt; |
| 43 | } |
| 44 | |
| 45 | PredIterator &PredIterator::operator++() { |
| 46 | // If it's a DGNode then we increment the use-def iterator. |
| 47 | if (!isa<MemDGNode>(Val: N)) { |
| 48 | assert(OpIt != OpItE && "Already at end!" ); |
| 49 | ++OpIt; |
| 50 | // Skip operands that are not instructions or are outside the DAG. |
| 51 | OpIt = PredIterator::skipBadIt(OpIt, OpItE, DAG: *DAG); |
| 52 | return *this; |
| 53 | } |
| 54 | // It's a MemDGNode, so if we are not at the end of the use-def iterator we |
| 55 | // need to first increment that. |
| 56 | if (OpIt != OpItE) { |
| 57 | ++OpIt; |
| 58 | // Skip operands that are not instructions or are outside the DAG. |
| 59 | OpIt = PredIterator::skipBadIt(OpIt, OpItE, DAG: *DAG); |
| 60 | return *this; |
| 61 | } |
| 62 | // It's a MemDGNode with OpIt == end, so we need to increment MemIt. |
| 63 | assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && "Already at end!" ); |
| 64 | ++MemIt; |
| 65 | return *this; |
| 66 | } |
| 67 | |
| 68 | bool PredIterator::operator==(const PredIterator &Other) const { |
| 69 | assert(DAG == Other.DAG && "Iterators of different DAGs!" ); |
| 70 | assert(N == Other.N && "Iterators of different nodes!" ); |
| 71 | return OpIt == Other.OpIt && MemIt == Other.MemIt; |
| 72 | } |
| 73 | |
| 74 | void DGNode::setSchedBundle(SchedBundle &SB) { |
| 75 | if (this->SB != nullptr) |
| 76 | this->SB->eraseFromBundle(N: this); |
| 77 | this->SB = &SB; |
| 78 | } |
| 79 | |
| 80 | DGNode::~DGNode() { |
| 81 | if (SB == nullptr) |
| 82 | return; |
| 83 | SB->eraseFromBundle(N: this); |
| 84 | } |
| 85 | |
| 86 | #ifndef NDEBUG |
| 87 | void DGNode::print(raw_ostream &OS, bool PrintDeps) const { |
| 88 | OS << *I << " USuccs:" << UnscheduledSuccs << " Sched:" << Scheduled << "\n" ; |
| 89 | } |
| 90 | void DGNode::dump() const { print(dbgs()); } |
| 91 | void MemDGNode::print(raw_ostream &OS, bool PrintDeps) const { |
| 92 | DGNode::print(OS, false); |
| 93 | if (PrintDeps) { |
| 94 | // Print memory preds. |
| 95 | static constexpr const unsigned Indent = 4; |
| 96 | for (auto *Pred : MemPreds) |
| 97 | OS.indent(Indent) << "<-" << *Pred->getInstruction() << "\n" ; |
| 98 | } |
| 99 | } |
| 100 | #endif // NDEBUG |
| 101 | |
| 102 | MemDGNode * |
| 103 | MemDGNodeIntervalBuilder::getTopMemDGNode(const Interval<Instruction> &Intvl, |
| 104 | const DependencyGraph &DAG) { |
| 105 | Instruction *I = Intvl.top(); |
| 106 | Instruction *BeforeI = Intvl.bottom(); |
| 107 | // Walk down the chain looking for a mem-dep candidate instruction. |
| 108 | while (!DGNode::isMemDepNodeCandidate(I) && I != BeforeI) |
| 109 | I = I->getNextNode(); |
| 110 | if (!DGNode::isMemDepNodeCandidate(I)) |
| 111 | return nullptr; |
| 112 | return cast<MemDGNode>(Val: DAG.getNode(I)); |
| 113 | } |
| 114 | |
| 115 | MemDGNode * |
| 116 | MemDGNodeIntervalBuilder::getBotMemDGNode(const Interval<Instruction> &Intvl, |
| 117 | const DependencyGraph &DAG) { |
| 118 | Instruction *I = Intvl.bottom(); |
| 119 | Instruction *AfterI = Intvl.top(); |
| 120 | // Walk up the chain looking for a mem-dep candidate instruction. |
| 121 | while (!DGNode::isMemDepNodeCandidate(I) && I != AfterI) |
| 122 | I = I->getPrevNode(); |
| 123 | if (!DGNode::isMemDepNodeCandidate(I)) |
| 124 | return nullptr; |
| 125 | return cast<MemDGNode>(Val: DAG.getNode(I)); |
| 126 | } |
| 127 | |
| 128 | Interval<MemDGNode> |
| 129 | MemDGNodeIntervalBuilder::make(const Interval<Instruction> &Instrs, |
| 130 | DependencyGraph &DAG) { |
| 131 | if (Instrs.empty()) |
| 132 | return {}; |
| 133 | auto *TopMemN = getTopMemDGNode(Intvl: Instrs, DAG); |
| 134 | // If we couldn't find a mem node in range TopN - BotN then it's empty. |
| 135 | if (TopMemN == nullptr) |
| 136 | return {}; |
| 137 | auto *BotMemN = getBotMemDGNode(Intvl: Instrs, DAG); |
| 138 | assert(BotMemN != nullptr && "TopMemN should be null too!" ); |
| 139 | // Now that we have the mem-dep nodes, create and return the range. |
| 140 | return Interval<MemDGNode>(TopMemN, BotMemN); |
| 141 | } |
| 142 | |
| 143 | DependencyGraph::DependencyType |
| 144 | DependencyGraph::getRoughDepType(Instruction *FromI, Instruction *ToI) { |
| 145 | // TODO: Perhaps compile-time improvement by skipping if neither is mem? |
| 146 | if (FromI->mayWriteToMemory()) { |
| 147 | if (ToI->mayReadFromMemory()) |
| 148 | return DependencyType::ReadAfterWrite; |
| 149 | if (ToI->mayWriteToMemory()) |
| 150 | return DependencyType::WriteAfterWrite; |
| 151 | } else if (FromI->mayReadFromMemory()) { |
| 152 | if (ToI->mayWriteToMemory()) |
| 153 | return DependencyType::WriteAfterRead; |
| 154 | } |
| 155 | if (isa<sandboxir::PHINode>(Val: FromI) || isa<sandboxir::PHINode>(Val: ToI)) |
| 156 | return DependencyType::Control; |
| 157 | if (ToI->isTerminator()) |
| 158 | return DependencyType::Control; |
| 159 | if (DGNode::isStackSaveOrRestoreIntrinsic(I: FromI) || |
| 160 | DGNode::isStackSaveOrRestoreIntrinsic(I: ToI)) |
| 161 | return DependencyType::Other; |
| 162 | return DependencyType::None; |
| 163 | } |
| 164 | |
| 165 | static bool isOrdered(Instruction *I) { |
| 166 | auto IsOrdered = [](Instruction *I) { |
| 167 | if (auto *LI = dyn_cast<LoadInst>(Val: I)) |
| 168 | return !LI->isUnordered(); |
| 169 | if (auto *SI = dyn_cast<StoreInst>(Val: I)) |
| 170 | return !SI->isUnordered(); |
| 171 | if (DGNode::isFenceLike(I)) |
| 172 | return true; |
| 173 | return false; |
| 174 | }; |
| 175 | bool Is = IsOrdered(I); |
| 176 | assert((!Is || DGNode::isMemDepCandidate(I)) && |
| 177 | "An ordered instruction must be a MemDepCandidate!" ); |
| 178 | return Is; |
| 179 | } |
| 180 | |
| 181 | bool DependencyGraph::alias(Instruction *SrcI, Instruction *DstI, |
| 182 | DependencyType DepType) { |
| 183 | std::optional<MemoryLocation> DstLocOpt = |
| 184 | Utils::memoryLocationGetOrNone(I: DstI); |
| 185 | if (!DstLocOpt) |
| 186 | return true; |
| 187 | // Check aliasing. |
| 188 | assert((SrcI->mayReadFromMemory() || SrcI->mayWriteToMemory()) && |
| 189 | "Expected a mem instr" ); |
| 190 | // TODO: Check AABudget |
| 191 | ModRefInfo SrcModRef = |
| 192 | isOrdered(I: SrcI) |
| 193 | ? ModRefInfo::ModRef |
| 194 | : Utils::aliasAnalysisGetModRefInfo(BatchAA&: *BatchAA, I: SrcI, OptLoc: *DstLocOpt); |
| 195 | switch (DepType) { |
| 196 | case DependencyType::ReadAfterWrite: |
| 197 | case DependencyType::WriteAfterWrite: |
| 198 | return isModSet(MRI: SrcModRef); |
| 199 | case DependencyType::WriteAfterRead: |
| 200 | return isRefSet(MRI: SrcModRef); |
| 201 | default: |
| 202 | llvm_unreachable("Expected only RAW, WAW and WAR!" ); |
| 203 | } |
| 204 | } |
| 205 | |
| 206 | bool DependencyGraph::hasDep(Instruction *SrcI, Instruction *DstI) { |
| 207 | DependencyType RoughDepType = getRoughDepType(FromI: SrcI, ToI: DstI); |
| 208 | switch (RoughDepType) { |
| 209 | case DependencyType::ReadAfterWrite: |
| 210 | case DependencyType::WriteAfterWrite: |
| 211 | case DependencyType::WriteAfterRead: |
| 212 | return alias(SrcI, DstI, DepType: RoughDepType); |
| 213 | case DependencyType::Control: |
| 214 | // Adding actual dep edges from PHIs/to terminator would just create too |
| 215 | // many edges, which would be bad for compile-time. |
| 216 | // So we ignore them in the DAG formation but handle them in the |
| 217 | // scheduler, while sorting the ready list. |
| 218 | return false; |
| 219 | case DependencyType::Other: |
| 220 | return true; |
| 221 | case DependencyType::None: |
| 222 | return false; |
| 223 | } |
| 224 | llvm_unreachable("Unknown DependencyType enum" ); |
| 225 | } |
| 226 | |
| 227 | void DependencyGraph::scanAndAddDeps(MemDGNode &DstN, |
| 228 | const Interval<MemDGNode> &SrcScanRange) { |
| 229 | assert(isa<MemDGNode>(DstN) && |
| 230 | "DstN is the mem dep destination, so it must be mem" ); |
| 231 | Instruction *DstI = DstN.getInstruction(); |
| 232 | // Walk up the instruction chain from ScanRange bottom to top, looking for |
| 233 | // memory instrs that may alias. |
| 234 | for (MemDGNode &SrcN : reverse(C: SrcScanRange)) { |
| 235 | Instruction *SrcI = SrcN.getInstruction(); |
| 236 | if (hasDep(SrcI, DstI)) |
| 237 | DstN.addMemPred(PredN: &SrcN); |
| 238 | } |
| 239 | } |
| 240 | |
| 241 | void DependencyGraph::setDefUseUnscheduledSuccs( |
| 242 | const Interval<Instruction> &NewInterval) { |
| 243 | // +---+ |
| 244 | // | | Def |
| 245 | // | | | |
| 246 | // | | v |
| 247 | // | | Use |
| 248 | // +---+ |
| 249 | // Set the intra-interval counters in NewInterval. |
| 250 | for (Instruction &I : NewInterval) { |
| 251 | for (Value *Op : I.operands()) { |
| 252 | auto *OpI = dyn_cast<Instruction>(Val: Op); |
| 253 | if (OpI == nullptr) |
| 254 | continue; |
| 255 | // TODO: For now don't cross BBs. |
| 256 | if (OpI->getParent() != I.getParent()) |
| 257 | continue; |
| 258 | if (!NewInterval.contains(I: OpI)) |
| 259 | continue; |
| 260 | auto *OpN = getNode(I: OpI); |
| 261 | if (OpN == nullptr) |
| 262 | continue; |
| 263 | ++OpN->UnscheduledSuccs; |
| 264 | } |
| 265 | } |
| 266 | |
| 267 | // Now handle the cross-interval edges. |
| 268 | bool NewIsAbove = DAGInterval.empty() || NewInterval.comesBefore(Other: DAGInterval); |
| 269 | const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval; |
| 270 | const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval; |
| 271 | // +---+ |
| 272 | // |Top| |
| 273 | // | | Def |
| 274 | // +---+ | |
| 275 | // | | v |
| 276 | // |Bot| Use |
| 277 | // | | |
| 278 | // +---+ |
| 279 | // Walk over all instructions in "BotInterval" and update the counter |
| 280 | // of operands that are in "TopInterval". |
| 281 | for (Instruction &BotI : BotInterval) { |
| 282 | auto *BotN = getNode(I: &BotI); |
| 283 | // Skip scheduled nodes. |
| 284 | if (BotN->scheduled()) |
| 285 | continue; |
| 286 | for (Value *Op : BotI.operands()) { |
| 287 | auto *OpI = dyn_cast<Instruction>(Val: Op); |
| 288 | if (OpI == nullptr) |
| 289 | continue; |
| 290 | auto *OpN = getNode(I: OpI); |
| 291 | if (OpN == nullptr) |
| 292 | continue; |
| 293 | if (!TopInterval.contains(I: OpI)) |
| 294 | continue; |
| 295 | ++OpN->UnscheduledSuccs; |
| 296 | } |
| 297 | } |
| 298 | } |
| 299 | |
| 300 | void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) { |
| 301 | // Create Nodes only for the new sections of the DAG. |
| 302 | DGNode *LastN = getOrCreateNode(I: NewInterval.top()); |
| 303 | MemDGNode *LastMemN = dyn_cast<MemDGNode>(Val: LastN); |
| 304 | for (Instruction &I : drop_begin(RangeOrContainer: NewInterval)) { |
| 305 | auto *N = getOrCreateNode(I: &I); |
| 306 | // Build the Mem node chain. |
| 307 | if (auto *MemN = dyn_cast<MemDGNode>(Val: N)) { |
| 308 | MemN->setPrevNode(LastMemN); |
| 309 | LastMemN = MemN; |
| 310 | } |
| 311 | } |
| 312 | // Link new MemDGNode chain with the old one, if any. |
| 313 | if (!DAGInterval.empty()) { |
| 314 | bool NewIsAbove = NewInterval.comesBefore(Other: DAGInterval); |
| 315 | const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval; |
| 316 | const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval; |
| 317 | MemDGNode *LinkTopN = |
| 318 | MemDGNodeIntervalBuilder::getBotMemDGNode(Intvl: TopInterval, DAG: *this); |
| 319 | MemDGNode *LinkBotN = |
| 320 | MemDGNodeIntervalBuilder::getTopMemDGNode(Intvl: BotInterval, DAG: *this); |
| 321 | assert((LinkTopN == nullptr || LinkBotN == nullptr || |
| 322 | LinkTopN->comesBefore(LinkBotN)) && |
| 323 | "Wrong order!" ); |
| 324 | if (LinkTopN != nullptr && LinkBotN != nullptr) { |
| 325 | LinkTopN->setNextNode(LinkBotN); |
| 326 | } |
| 327 | #ifndef NDEBUG |
| 328 | // TODO: Remove this once we've done enough testing. |
| 329 | // Check that the chain is well formed. |
| 330 | auto UnionIntvl = DAGInterval.getUnionInterval(NewInterval); |
| 331 | MemDGNode *ChainTopN = |
| 332 | MemDGNodeIntervalBuilder::getTopMemDGNode(UnionIntvl, *this); |
| 333 | MemDGNode *ChainBotN = |
| 334 | MemDGNodeIntervalBuilder::getBotMemDGNode(UnionIntvl, *this); |
| 335 | if (ChainTopN != nullptr && ChainBotN != nullptr) { |
| 336 | for (auto *N = ChainTopN->getNextNode(), *LastN = ChainTopN; N != nullptr; |
| 337 | LastN = N, N = N->getNextNode()) { |
| 338 | assert(N == LastN->getNextNode() && "Bad chain!" ); |
| 339 | assert(N->getPrevNode() == LastN && "Bad chain!" ); |
| 340 | } |
| 341 | } |
| 342 | #endif // NDEBUG |
| 343 | } |
| 344 | |
| 345 | setDefUseUnscheduledSuccs(NewInterval); |
| 346 | } |
| 347 | |
| 348 | MemDGNode *DependencyGraph::getMemDGNodeBefore(DGNode *N, bool IncludingN, |
| 349 | MemDGNode *SkipN) const { |
| 350 | auto *I = N->getInstruction(); |
| 351 | for (auto *PrevI = IncludingN ? I : I->getPrevNode(); PrevI != nullptr; |
| 352 | PrevI = PrevI->getPrevNode()) { |
| 353 | auto *PrevN = getNodeOrNull(I: PrevI); |
| 354 | if (PrevN == nullptr) |
| 355 | return nullptr; |
| 356 | auto *PrevMemN = dyn_cast<MemDGNode>(Val: PrevN); |
| 357 | if (PrevMemN != nullptr && PrevMemN != SkipN) |
| 358 | return PrevMemN; |
| 359 | } |
| 360 | return nullptr; |
| 361 | } |
| 362 | |
| 363 | MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N, bool IncludingN, |
| 364 | MemDGNode *SkipN) const { |
| 365 | auto *I = N->getInstruction(); |
| 366 | for (auto *NextI = IncludingN ? I : I->getNextNode(); NextI != nullptr; |
| 367 | NextI = NextI->getNextNode()) { |
| 368 | auto *NextN = getNodeOrNull(I: NextI); |
| 369 | if (NextN == nullptr) |
| 370 | return nullptr; |
| 371 | auto *NextMemN = dyn_cast<MemDGNode>(Val: NextN); |
| 372 | if (NextMemN != nullptr && NextMemN != SkipN) |
| 373 | return NextMemN; |
| 374 | } |
| 375 | return nullptr; |
| 376 | } |
| 377 | |
| 378 | void DependencyGraph::notifyCreateInstr(Instruction *I) { |
| 379 | if (Ctx->getTracker().getState() == Tracker::TrackerState::Reverting) |
| 380 | // We don't maintain the DAG while reverting. |
| 381 | return; |
| 382 | // Nothing to do if the node is not in the focus range of the DAG. |
| 383 | if (!(DAGInterval.contains(I) || DAGInterval.touches(Elm: I))) |
| 384 | return; |
| 385 | // Include `I` into the interval. |
| 386 | DAGInterval = DAGInterval.getUnionInterval(Other: {I, I}); |
| 387 | auto *N = getOrCreateNode(I); |
| 388 | auto *MemN = dyn_cast<MemDGNode>(Val: N); |
| 389 | |
| 390 | // Update the MemDGNode chain if this is a memory node. |
| 391 | if (MemN != nullptr) { |
| 392 | if (auto *PrevMemN = getMemDGNodeBefore(N: MemN, /*IncludingN=*/false)) { |
| 393 | PrevMemN->NextMemN = MemN; |
| 394 | MemN->PrevMemN = PrevMemN; |
| 395 | } |
| 396 | if (auto *NextMemN = getMemDGNodeAfter(N: MemN, /*IncludingN=*/false)) { |
| 397 | NextMemN->PrevMemN = MemN; |
| 398 | MemN->NextMemN = NextMemN; |
| 399 | } |
| 400 | |
| 401 | // Add Mem dependencies. |
| 402 | // 1. Scan for deps above `I` for deps to `I`: AboveN->MemN. |
| 403 | if (DAGInterval.top()->comesBefore(Other: I)) { |
| 404 | Interval<Instruction> AboveIntvl(DAGInterval.top(), I->getPrevNode()); |
| 405 | auto SrcInterval = MemDGNodeIntervalBuilder::make(Instrs: AboveIntvl, DAG&: *this); |
| 406 | scanAndAddDeps(DstN&: *MemN, SrcScanRange: SrcInterval); |
| 407 | } |
| 408 | // 2. Scan for deps below `I` for deps from `I`: MemN->BelowN. |
| 409 | if (I->comesBefore(Other: DAGInterval.bottom())) { |
| 410 | Interval<Instruction> BelowIntvl(I->getNextNode(), DAGInterval.bottom()); |
| 411 | for (MemDGNode &BelowN : |
| 412 | MemDGNodeIntervalBuilder::make(Instrs: BelowIntvl, DAG&: *this)) |
| 413 | scanAndAddDeps(DstN&: BelowN, SrcScanRange: Interval<MemDGNode>(MemN, MemN)); |
| 414 | } |
| 415 | } |
| 416 | } |
| 417 | |
| 418 | void DependencyGraph::notifyMoveInstr(Instruction *I, const BBIterator &To) { |
| 419 | if (Ctx->getTracker().getState() == Tracker::TrackerState::Reverting) |
| 420 | // We don't maintain the DAG while reverting. |
| 421 | return; |
| 422 | // NOTE: This function runs before `I` moves to its new destination. |
| 423 | BasicBlock *BB = To.getNodeParent(); |
| 424 | assert(!(To != BB->end() && &*To == I->getNextNode()) && |
| 425 | !(To == BB->end() && std::next(I->getIterator()) == BB->end()) && |
| 426 | "Should not have been called if destination is same as origin." ); |
| 427 | |
| 428 | // TODO: We can only handle fully internal movements within DAGInterval or at |
| 429 | // the borders, i.e., right before the top or right after the bottom. |
| 430 | assert(To.getNodeParent() == I->getParent() && |
| 431 | "TODO: We don't support movement across BBs!" ); |
| 432 | assert( |
| 433 | (To == std::next(DAGInterval.bottom()->getIterator()) || |
| 434 | (To != BB->end() && std::next(To) == DAGInterval.top()->getIterator()) || |
| 435 | (To != BB->end() && DAGInterval.contains(&*To))) && |
| 436 | "TODO: To should be either within the DAGInterval or right " |
| 437 | "before/after it." ); |
| 438 | |
| 439 | // Make a copy of the DAGInterval before we update it. |
| 440 | auto OrigDAGInterval = DAGInterval; |
| 441 | |
| 442 | // Maintain the DAGInterval. |
| 443 | DAGInterval.notifyMoveInstr(I, BeforeIt: To); |
| 444 | |
| 445 | // TODO: Perhaps check if this is legal by checking the dependencies? |
| 446 | |
| 447 | // Update the MemDGNode chain to reflect the instr movement if necessary. |
| 448 | DGNode *N = getNodeOrNull(I); |
| 449 | if (N == nullptr) |
| 450 | return; |
| 451 | MemDGNode *MemN = dyn_cast<MemDGNode>(Val: N); |
| 452 | if (MemN == nullptr) |
| 453 | return; |
| 454 | |
| 455 | // First safely detach it from the existing chain. |
| 456 | MemN->detachFromChain(); |
| 457 | |
| 458 | // Now insert it back into the chain at the new location. |
| 459 | // |
| 460 | // We won't always have a DGNode to insert before it. If `To` is BB->end() or |
| 461 | // if it points to an instr after DAGInterval.bottom() then we will have to |
| 462 | // find a node to insert *after*. |
| 463 | // |
| 464 | // BB: BB: |
| 465 | // I1 I1 ^ |
| 466 | // I2 I2 | DAGInteval [I1 to I3] |
| 467 | // I3 I3 V |
| 468 | // I4 I4 <- `To` == right after DAGInterval |
| 469 | // <- `To` == BB->end() |
| 470 | // |
| 471 | if (To == BB->end() || |
| 472 | To == std::next(x: OrigDAGInterval.bottom()->getIterator())) { |
| 473 | // If we don't have a node to insert before, find a node to insert after and |
| 474 | // update the chain. |
| 475 | DGNode *InsertAfterN = getNode(I: &*std::prev(x: To)); |
| 476 | MemN->setPrevNode( |
| 477 | getMemDGNodeBefore(N: InsertAfterN, /*IncludingN=*/true, /*SkipN=*/MemN)); |
| 478 | } else { |
| 479 | // We have a node to insert before, so update the chain. |
| 480 | DGNode *BeforeToN = getNode(I: &*To); |
| 481 | MemN->setPrevNode( |
| 482 | getMemDGNodeBefore(N: BeforeToN, /*IncludingN=*/false, /*SkipN=*/MemN)); |
| 483 | MemN->setNextNode( |
| 484 | getMemDGNodeAfter(N: BeforeToN, /*IncludingN=*/true, /*SkipN=*/MemN)); |
| 485 | } |
| 486 | } |
| 487 | |
| 488 | void DependencyGraph::notifyEraseInstr(Instruction *I) { |
| 489 | if (Ctx->getTracker().getState() == Tracker::TrackerState::Reverting) |
| 490 | // We don't maintain the DAG while reverting. |
| 491 | return; |
| 492 | auto *N = getNode(I); |
| 493 | if (N == nullptr) |
| 494 | // Early return if there is no DAG node for `I`. |
| 495 | return; |
| 496 | if (auto *MemN = dyn_cast<MemDGNode>(Val: getNode(I))) { |
| 497 | // Update the MemDGNode chain if this is a memory node. |
| 498 | auto *PrevMemN = getMemDGNodeBefore(N: MemN, /*IncludingN=*/false); |
| 499 | auto *NextMemN = getMemDGNodeAfter(N: MemN, /*IncludingN=*/false); |
| 500 | if (PrevMemN != nullptr) |
| 501 | PrevMemN->NextMemN = NextMemN; |
| 502 | if (NextMemN != nullptr) |
| 503 | NextMemN->PrevMemN = PrevMemN; |
| 504 | |
| 505 | // Drop the memory dependencies from both predecessors and successors. |
| 506 | while (!MemN->memPreds().empty()) { |
| 507 | auto *PredN = *MemN->memPreds().begin(); |
| 508 | MemN->removeMemPred(PredN); |
| 509 | } |
| 510 | while (!MemN->memSuccs().empty()) { |
| 511 | auto *SuccN = *MemN->memSuccs().begin(); |
| 512 | SuccN->removeMemPred(PredN: MemN); |
| 513 | } |
| 514 | // NOTE: The unscheduled succs for MemNodes get updated be setMemPred(). |
| 515 | } else { |
| 516 | // If this is a non-mem node we only need to update UnscheduledSuccs. |
| 517 | if (!N->scheduled()) |
| 518 | for (auto *PredN : N->preds(DAG&: *this)) |
| 519 | PredN->decrUnscheduledSuccs(); |
| 520 | } |
| 521 | // Finally erase the Node. |
| 522 | InstrToNodeMap.erase(Val: I); |
| 523 | } |
| 524 | |
| 525 | void DependencyGraph::notifySetUse(const Use &U, Value *NewSrc) { |
| 526 | // Update the UnscheduledSuccs counter for both the current source and NewSrc |
| 527 | // if needed. |
| 528 | if (auto *CurrSrcI = dyn_cast<Instruction>(Val: U.get())) { |
| 529 | if (auto *CurrSrcN = getNode(I: CurrSrcI)) { |
| 530 | CurrSrcN->decrUnscheduledSuccs(); |
| 531 | } |
| 532 | } |
| 533 | if (auto *NewSrcI = dyn_cast<Instruction>(Val: NewSrc)) { |
| 534 | if (auto *NewSrcN = getNode(I: NewSrcI)) { |
| 535 | ++NewSrcN->UnscheduledSuccs; |
| 536 | } |
| 537 | } |
| 538 | } |
| 539 | |
| 540 | Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) { |
| 541 | if (Instrs.empty()) |
| 542 | return {}; |
| 543 | |
| 544 | Interval<Instruction> InstrsInterval(Instrs); |
| 545 | Interval<Instruction> Union = DAGInterval.getUnionInterval(Other: InstrsInterval); |
| 546 | auto NewInterval = Union.getSingleDiff(Other: DAGInterval); |
| 547 | if (NewInterval.empty()) |
| 548 | return {}; |
| 549 | |
| 550 | createNewNodes(NewInterval); |
| 551 | |
| 552 | // Create the dependencies. |
| 553 | // |
| 554 | // 1. This is a new DAG, DAGInterval is empty. Fully scan the whole interval. |
| 555 | // +---+ - - |
| 556 | // | | SrcN | | |
| 557 | // | | | | SrcRange | |
| 558 | // |New| v | | DstRange |
| 559 | // | | DstN - | |
| 560 | // | | | |
| 561 | // +---+ - |
| 562 | // We are scanning for deps with destination in NewInterval and sources in |
| 563 | // NewInterval until DstN, for each DstN. |
| 564 | auto FullScan = [this](const Interval<Instruction> Intvl) { |
| 565 | auto DstRange = MemDGNodeIntervalBuilder::make(Instrs: Intvl, DAG&: *this); |
| 566 | if (!DstRange.empty()) { |
| 567 | for (MemDGNode &DstN : drop_begin(RangeOrContainer&: DstRange)) { |
| 568 | auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode()); |
| 569 | scanAndAddDeps(DstN, SrcScanRange: SrcRange); |
| 570 | } |
| 571 | } |
| 572 | }; |
| 573 | auto MemDAGInterval = MemDGNodeIntervalBuilder::make(Instrs: DAGInterval, DAG&: *this); |
| 574 | if (MemDAGInterval.empty()) { |
| 575 | FullScan(NewInterval); |
| 576 | } |
| 577 | // 2. The new section is below the old section. |
| 578 | // +---+ - |
| 579 | // | | | |
| 580 | // |Old| SrcN | |
| 581 | // | | | | |
| 582 | // +---+ | | SrcRange |
| 583 | // +---+ | | - |
| 584 | // | | | | | |
| 585 | // |New| v | | DstRange |
| 586 | // | | DstN - | |
| 587 | // | | | |
| 588 | // +---+ - |
| 589 | // We are scanning for deps with destination in NewInterval because the deps |
| 590 | // in DAGInterval have already been computed. We consider sources in the whole |
| 591 | // range including both NewInterval and DAGInterval until DstN, for each DstN. |
| 592 | else if (DAGInterval.bottom()->comesBefore(Other: NewInterval.top())) { |
| 593 | auto DstRange = MemDGNodeIntervalBuilder::make(Instrs: NewInterval, DAG&: *this); |
| 594 | auto SrcRangeFull = MemDAGInterval.getUnionInterval(Other: DstRange); |
| 595 | for (MemDGNode &DstN : DstRange) { |
| 596 | auto SrcRange = |
| 597 | Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode()); |
| 598 | scanAndAddDeps(DstN, SrcScanRange: SrcRange); |
| 599 | } |
| 600 | } |
| 601 | // 3. The new section is above the old section. |
| 602 | else if (NewInterval.bottom()->comesBefore(Other: DAGInterval.top())) { |
| 603 | // +---+ - - |
| 604 | // | | SrcN | | |
| 605 | // |New| | | SrcRange | DstRange |
| 606 | // | | v | | |
| 607 | // | | DstN - | |
| 608 | // | | | |
| 609 | // +---+ - |
| 610 | // +---+ |
| 611 | // |Old| |
| 612 | // | | |
| 613 | // +---+ |
| 614 | // When scanning for deps with destination in NewInterval we need to fully |
| 615 | // scan the interval. This is the same as the scanning for a new DAG. |
| 616 | FullScan(NewInterval); |
| 617 | |
| 618 | // +---+ - |
| 619 | // | | | |
| 620 | // |New| SrcN | SrcRange |
| 621 | // | | | | |
| 622 | // | | | | |
| 623 | // | | | | |
| 624 | // +---+ | - |
| 625 | // +---+ | - |
| 626 | // |Old| v | DstRange |
| 627 | // | | DstN | |
| 628 | // +---+ - |
| 629 | // When scanning for deps with destination in DAGInterval we need to |
| 630 | // consider sources from the NewInterval only, because all intra-DAGInterval |
| 631 | // dependencies have already been created. |
| 632 | auto DstRangeOld = MemDAGInterval; |
| 633 | auto SrcRange = MemDGNodeIntervalBuilder::make(Instrs: NewInterval, DAG&: *this); |
| 634 | for (MemDGNode &DstN : DstRangeOld) |
| 635 | scanAndAddDeps(DstN, SrcScanRange: SrcRange); |
| 636 | } else { |
| 637 | llvm_unreachable("We don't expect extending in both directions!" ); |
| 638 | } |
| 639 | |
| 640 | DAGInterval = Union; |
| 641 | return NewInterval; |
| 642 | } |
| 643 | |
| 644 | #ifndef NDEBUG |
| 645 | void DependencyGraph::print(raw_ostream &OS) const { |
| 646 | // InstrToNodeMap is unordered so we need to create an ordered vector. |
| 647 | SmallVector<DGNode *> Nodes; |
| 648 | Nodes.reserve(InstrToNodeMap.size()); |
| 649 | for (const auto &Pair : InstrToNodeMap) |
| 650 | Nodes.push_back(Pair.second.get()); |
| 651 | // Sort them based on which one comes first in the BB. |
| 652 | sort(Nodes, [](DGNode *N1, DGNode *N2) { |
| 653 | return N1->getInstruction()->comesBefore(N2->getInstruction()); |
| 654 | }); |
| 655 | for (auto *N : Nodes) |
| 656 | N->print(OS, /*PrintDeps=*/true); |
| 657 | } |
| 658 | |
| 659 | void DependencyGraph::dump() const { |
| 660 | print(dbgs()); |
| 661 | dbgs() << "\n" ; |
| 662 | } |
| 663 | #endif // NDEBUG |
| 664 | |
| 665 | } // namespace llvm::sandboxir |
| 666 | |