| 1 | //===- DependencyGraph.cpp ------------------------------------------===// | 
|---|
| 2 | // | 
|---|
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | 
|---|
| 4 | // See https://llvm.org/LICENSE.txt for license information. | 
|---|
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
|---|
| 6 | // | 
|---|
| 7 | //===----------------------------------------------------------------------===// | 
|---|
| 8 |  | 
|---|
| 9 | #include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h" | 
|---|
| 10 | #include "llvm/ADT/ArrayRef.h" | 
|---|
| 11 | #include "llvm/SandboxIR/Instruction.h" | 
|---|
| 12 | #include "llvm/SandboxIR/Utils.h" | 
|---|
| 13 | #include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h" | 
|---|
| 14 |  | 
|---|
| 15 | namespace llvm::sandboxir { | 
|---|
| 16 |  | 
|---|
| 17 | User::op_iterator PredIterator::skipBadIt(User::op_iterator OpIt, | 
|---|
| 18 | User::op_iterator OpItE, | 
|---|
| 19 | const DependencyGraph &DAG) { | 
|---|
| 20 | auto Skip = [&DAG](auto OpIt) { | 
|---|
| 21 | auto *I = dyn_cast<Instruction>((*OpIt).get()); | 
|---|
| 22 | return I == nullptr || DAG.getNode(I) == nullptr; | 
|---|
| 23 | }; | 
|---|
| 24 | while (OpIt != OpItE && Skip(OpIt)) | 
|---|
| 25 | ++OpIt; | 
|---|
| 26 | return OpIt; | 
|---|
| 27 | } | 
|---|
| 28 |  | 
|---|
| 29 | PredIterator::value_type PredIterator::operator*() { | 
|---|
| 30 | // If it's a DGNode then we dereference the operand iterator. | 
|---|
| 31 | if (!isa<MemDGNode>(Val: N)) { | 
|---|
| 32 | assert(OpIt != OpItE && "Can't dereference end iterator!"); | 
|---|
| 33 | return DAG->getNode(I: cast<Instruction>(Val: (Value *)*OpIt)); | 
|---|
| 34 | } | 
|---|
| 35 | // It's a MemDGNode, so we check if we return either the use-def operand, | 
|---|
| 36 | // or a mem predecessor. | 
|---|
| 37 | if (OpIt != OpItE) | 
|---|
| 38 | return DAG->getNode(I: cast<Instruction>(Val: (Value *)*OpIt)); | 
|---|
| 39 | // It's a MemDGNode with OpIt == end, so we need to use MemIt. | 
|---|
| 40 | assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && | 
|---|
| 41 | "Cant' dereference end iterator!"); | 
|---|
| 42 | return *MemIt; | 
|---|
| 43 | } | 
|---|
| 44 |  | 
|---|
| 45 | PredIterator &PredIterator::operator++() { | 
|---|
| 46 | // If it's a DGNode then we increment the use-def iterator. | 
|---|
| 47 | if (!isa<MemDGNode>(Val: N)) { | 
|---|
| 48 | assert(OpIt != OpItE && "Already at end!"); | 
|---|
| 49 | ++OpIt; | 
|---|
| 50 | // Skip operands that are not instructions or are outside the DAG. | 
|---|
| 51 | OpIt = PredIterator::skipBadIt(OpIt, OpItE, DAG: *DAG); | 
|---|
| 52 | return *this; | 
|---|
| 53 | } | 
|---|
| 54 | // It's a MemDGNode, so if we are not at the end of the use-def iterator we | 
|---|
| 55 | // need to first increment that. | 
|---|
| 56 | if (OpIt != OpItE) { | 
|---|
| 57 | ++OpIt; | 
|---|
| 58 | // Skip operands that are not instructions or are outside the DAG. | 
|---|
| 59 | OpIt = PredIterator::skipBadIt(OpIt, OpItE, DAG: *DAG); | 
|---|
| 60 | return *this; | 
|---|
| 61 | } | 
|---|
| 62 | // It's a MemDGNode with OpIt == end, so we need to increment MemIt. | 
|---|
| 63 | assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && "Already at end!"); | 
|---|
| 64 | ++MemIt; | 
|---|
| 65 | return *this; | 
|---|
| 66 | } | 
|---|
| 67 |  | 
|---|
| 68 | bool PredIterator::operator==(const PredIterator &Other) const { | 
|---|
| 69 | assert(DAG == Other.DAG && "Iterators of different DAGs!"); | 
|---|
| 70 | assert(N == Other.N && "Iterators of different nodes!"); | 
|---|
| 71 | return OpIt == Other.OpIt && MemIt == Other.MemIt; | 
|---|
| 72 | } | 
|---|
| 73 |  | 
|---|
| 74 | void DGNode::setSchedBundle(SchedBundle &SB) { | 
|---|
| 75 | if (this->SB != nullptr) | 
|---|
| 76 | this->SB->eraseFromBundle(N: this); | 
|---|
| 77 | this->SB = &SB; | 
|---|
| 78 | } | 
|---|
| 79 |  | 
|---|
| 80 | DGNode::~DGNode() { | 
|---|
| 81 | if (SB == nullptr) | 
|---|
| 82 | return; | 
|---|
| 83 | SB->eraseFromBundle(N: this); | 
|---|
| 84 | } | 
|---|
| 85 |  | 
|---|
| 86 | #ifndef NDEBUG | 
|---|
| 87 | void DGNode::print(raw_ostream &OS, bool PrintDeps) const { | 
|---|
| 88 | OS << *I << " USuccs:"<< UnscheduledSuccs << " Sched:"<< Scheduled << "\n"; | 
|---|
| 89 | } | 
|---|
| 90 | void DGNode::dump() const { print(dbgs()); } | 
|---|
| 91 | void MemDGNode::print(raw_ostream &OS, bool PrintDeps) const { | 
|---|
| 92 | DGNode::print(OS, false); | 
|---|
| 93 | if (PrintDeps) { | 
|---|
| 94 | // Print memory preds. | 
|---|
| 95 | static constexpr const unsigned Indent = 4; | 
|---|
| 96 | for (auto *Pred : MemPreds) | 
|---|
| 97 | OS.indent(Indent) << "<-"<< *Pred->getInstruction() << "\n"; | 
|---|
| 98 | } | 
|---|
| 99 | } | 
|---|
| 100 | #endif // NDEBUG | 
|---|
| 101 |  | 
|---|
| 102 | MemDGNode * | 
|---|
| 103 | MemDGNodeIntervalBuilder::getTopMemDGNode(const Interval<Instruction> &Intvl, | 
|---|
| 104 | const DependencyGraph &DAG) { | 
|---|
| 105 | Instruction *I = Intvl.top(); | 
|---|
| 106 | Instruction *BeforeI = Intvl.bottom(); | 
|---|
| 107 | // Walk down the chain looking for a mem-dep candidate instruction. | 
|---|
| 108 | while (!DGNode::isMemDepNodeCandidate(I) && I != BeforeI) | 
|---|
| 109 | I = I->getNextNode(); | 
|---|
| 110 | if (!DGNode::isMemDepNodeCandidate(I)) | 
|---|
| 111 | return nullptr; | 
|---|
| 112 | return cast<MemDGNode>(Val: DAG.getNode(I)); | 
|---|
| 113 | } | 
|---|
| 114 |  | 
|---|
| 115 | MemDGNode * | 
|---|
| 116 | MemDGNodeIntervalBuilder::getBotMemDGNode(const Interval<Instruction> &Intvl, | 
|---|
| 117 | const DependencyGraph &DAG) { | 
|---|
| 118 | Instruction *I = Intvl.bottom(); | 
|---|
| 119 | Instruction *AfterI = Intvl.top(); | 
|---|
| 120 | // Walk up the chain looking for a mem-dep candidate instruction. | 
|---|
| 121 | while (!DGNode::isMemDepNodeCandidate(I) && I != AfterI) | 
|---|
| 122 | I = I->getPrevNode(); | 
|---|
| 123 | if (!DGNode::isMemDepNodeCandidate(I)) | 
|---|
| 124 | return nullptr; | 
|---|
| 125 | return cast<MemDGNode>(Val: DAG.getNode(I)); | 
|---|
| 126 | } | 
|---|
| 127 |  | 
|---|
| 128 | Interval<MemDGNode> | 
|---|
| 129 | MemDGNodeIntervalBuilder::make(const Interval<Instruction> &Instrs, | 
|---|
| 130 | DependencyGraph &DAG) { | 
|---|
| 131 | if (Instrs.empty()) | 
|---|
| 132 | return {}; | 
|---|
| 133 | auto *TopMemN = getTopMemDGNode(Intvl: Instrs, DAG); | 
|---|
| 134 | // If we couldn't find a mem node in range TopN - BotN then it's empty. | 
|---|
| 135 | if (TopMemN == nullptr) | 
|---|
| 136 | return {}; | 
|---|
| 137 | auto *BotMemN = getBotMemDGNode(Intvl: Instrs, DAG); | 
|---|
| 138 | assert(BotMemN != nullptr && "TopMemN should be null too!"); | 
|---|
| 139 | // Now that we have the mem-dep nodes, create and return the range. | 
|---|
| 140 | return Interval<MemDGNode>(TopMemN, BotMemN); | 
|---|
| 141 | } | 
|---|
| 142 |  | 
|---|
| 143 | DependencyGraph::DependencyType | 
|---|
| 144 | DependencyGraph::getRoughDepType(Instruction *FromI, Instruction *ToI) { | 
|---|
| 145 | // TODO: Perhaps compile-time improvement by skipping if neither is mem? | 
|---|
| 146 | if (FromI->mayWriteToMemory()) { | 
|---|
| 147 | if (ToI->mayReadFromMemory()) | 
|---|
| 148 | return DependencyType::ReadAfterWrite; | 
|---|
| 149 | if (ToI->mayWriteToMemory()) | 
|---|
| 150 | return DependencyType::WriteAfterWrite; | 
|---|
| 151 | } else if (FromI->mayReadFromMemory()) { | 
|---|
| 152 | if (ToI->mayWriteToMemory()) | 
|---|
| 153 | return DependencyType::WriteAfterRead; | 
|---|
| 154 | } | 
|---|
| 155 | if (isa<sandboxir::PHINode>(Val: FromI) || isa<sandboxir::PHINode>(Val: ToI)) | 
|---|
| 156 | return DependencyType::Control; | 
|---|
| 157 | if (ToI->isTerminator()) | 
|---|
| 158 | return DependencyType::Control; | 
|---|
| 159 | if (DGNode::isStackSaveOrRestoreIntrinsic(I: FromI) || | 
|---|
| 160 | DGNode::isStackSaveOrRestoreIntrinsic(I: ToI)) | 
|---|
| 161 | return DependencyType::Other; | 
|---|
| 162 | return DependencyType::None; | 
|---|
| 163 | } | 
|---|
| 164 |  | 
|---|
| 165 | static bool isOrdered(Instruction *I) { | 
|---|
| 166 | auto IsOrdered = [](Instruction *I) { | 
|---|
| 167 | if (auto *LI = dyn_cast<LoadInst>(Val: I)) | 
|---|
| 168 | return !LI->isUnordered(); | 
|---|
| 169 | if (auto *SI = dyn_cast<StoreInst>(Val: I)) | 
|---|
| 170 | return !SI->isUnordered(); | 
|---|
| 171 | if (DGNode::isFenceLike(I)) | 
|---|
| 172 | return true; | 
|---|
| 173 | return false; | 
|---|
| 174 | }; | 
|---|
| 175 | bool Is = IsOrdered(I); | 
|---|
| 176 | assert((!Is || DGNode::isMemDepCandidate(I)) && | 
|---|
| 177 | "An ordered instruction must be a MemDepCandidate!"); | 
|---|
| 178 | return Is; | 
|---|
| 179 | } | 
|---|
| 180 |  | 
|---|
| 181 | bool DependencyGraph::alias(Instruction *SrcI, Instruction *DstI, | 
|---|
| 182 | DependencyType DepType) { | 
|---|
| 183 | std::optional<MemoryLocation> DstLocOpt = | 
|---|
| 184 | Utils::memoryLocationGetOrNone(I: DstI); | 
|---|
| 185 | if (!DstLocOpt) | 
|---|
| 186 | return true; | 
|---|
| 187 | // Check aliasing. | 
|---|
| 188 | assert((SrcI->mayReadFromMemory() || SrcI->mayWriteToMemory()) && | 
|---|
| 189 | "Expected a mem instr"); | 
|---|
| 190 | // TODO: Check AABudget | 
|---|
| 191 | ModRefInfo SrcModRef = | 
|---|
| 192 | isOrdered(I: SrcI) | 
|---|
| 193 | ? ModRefInfo::ModRef | 
|---|
| 194 | : Utils::aliasAnalysisGetModRefInfo(BatchAA&: *BatchAA, I: SrcI, OptLoc: *DstLocOpt); | 
|---|
| 195 | switch (DepType) { | 
|---|
| 196 | case DependencyType::ReadAfterWrite: | 
|---|
| 197 | case DependencyType::WriteAfterWrite: | 
|---|
| 198 | return isModSet(MRI: SrcModRef); | 
|---|
| 199 | case DependencyType::WriteAfterRead: | 
|---|
| 200 | return isRefSet(MRI: SrcModRef); | 
|---|
| 201 | default: | 
|---|
| 202 | llvm_unreachable( "Expected only RAW, WAW and WAR!"); | 
|---|
| 203 | } | 
|---|
| 204 | } | 
|---|
| 205 |  | 
|---|
| 206 | bool DependencyGraph::hasDep(Instruction *SrcI, Instruction *DstI) { | 
|---|
| 207 | DependencyType RoughDepType = getRoughDepType(FromI: SrcI, ToI: DstI); | 
|---|
| 208 | switch (RoughDepType) { | 
|---|
| 209 | case DependencyType::ReadAfterWrite: | 
|---|
| 210 | case DependencyType::WriteAfterWrite: | 
|---|
| 211 | case DependencyType::WriteAfterRead: | 
|---|
| 212 | return alias(SrcI, DstI, DepType: RoughDepType); | 
|---|
| 213 | case DependencyType::Control: | 
|---|
| 214 | // Adding actual dep edges from PHIs/to terminator would just create too | 
|---|
| 215 | // many edges, which would be bad for compile-time. | 
|---|
| 216 | // So we ignore them in the DAG formation but handle them in the | 
|---|
| 217 | // scheduler, while sorting the ready list. | 
|---|
| 218 | return false; | 
|---|
| 219 | case DependencyType::Other: | 
|---|
| 220 | return true; | 
|---|
| 221 | case DependencyType::None: | 
|---|
| 222 | return false; | 
|---|
| 223 | } | 
|---|
| 224 | llvm_unreachable( "Unknown DependencyType enum"); | 
|---|
| 225 | } | 
|---|
| 226 |  | 
|---|
| 227 | void DependencyGraph::scanAndAddDeps(MemDGNode &DstN, | 
|---|
| 228 | const Interval<MemDGNode> &SrcScanRange) { | 
|---|
| 229 | assert(isa<MemDGNode>(DstN) && | 
|---|
| 230 | "DstN is the mem dep destination, so it must be mem"); | 
|---|
| 231 | Instruction *DstI = DstN.getInstruction(); | 
|---|
| 232 | // Walk up the instruction chain from ScanRange bottom to top, looking for | 
|---|
| 233 | // memory instrs that may alias. | 
|---|
| 234 | for (MemDGNode &SrcN : reverse(C: SrcScanRange)) { | 
|---|
| 235 | Instruction *SrcI = SrcN.getInstruction(); | 
|---|
| 236 | if (hasDep(SrcI, DstI)) | 
|---|
| 237 | DstN.addMemPred(PredN: &SrcN); | 
|---|
| 238 | } | 
|---|
| 239 | } | 
|---|
| 240 |  | 
|---|
| 241 | void DependencyGraph::setDefUseUnscheduledSuccs( | 
|---|
| 242 | const Interval<Instruction> &NewInterval) { | 
|---|
| 243 | // +---+ | 
|---|
| 244 | // |   |  Def | 
|---|
| 245 | // |   |   | | 
|---|
| 246 | // |   |   v | 
|---|
| 247 | // |   |  Use | 
|---|
| 248 | // +---+ | 
|---|
| 249 | // Set the intra-interval counters in NewInterval. | 
|---|
| 250 | for (Instruction &I : NewInterval) { | 
|---|
| 251 | for (Value *Op : I.operands()) { | 
|---|
| 252 | auto *OpI = dyn_cast<Instruction>(Val: Op); | 
|---|
| 253 | if (OpI == nullptr) | 
|---|
| 254 | continue; | 
|---|
| 255 | // TODO: For now don't cross BBs. | 
|---|
| 256 | if (OpI->getParent() != I.getParent()) | 
|---|
| 257 | continue; | 
|---|
| 258 | if (!NewInterval.contains(I: OpI)) | 
|---|
| 259 | continue; | 
|---|
| 260 | auto *OpN = getNode(I: OpI); | 
|---|
| 261 | if (OpN == nullptr) | 
|---|
| 262 | continue; | 
|---|
| 263 | ++OpN->UnscheduledSuccs; | 
|---|
| 264 | } | 
|---|
| 265 | } | 
|---|
| 266 |  | 
|---|
| 267 | // Now handle the cross-interval edges. | 
|---|
| 268 | bool NewIsAbove = DAGInterval.empty() || NewInterval.comesBefore(Other: DAGInterval); | 
|---|
| 269 | const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval; | 
|---|
| 270 | const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval; | 
|---|
| 271 | // +---+ | 
|---|
| 272 | // |Top| | 
|---|
| 273 | // |   |  Def | 
|---|
| 274 | // +---+   | | 
|---|
| 275 | // |   |   v | 
|---|
| 276 | // |Bot|  Use | 
|---|
| 277 | // |   | | 
|---|
| 278 | // +---+ | 
|---|
| 279 | // Walk over all instructions in "BotInterval" and update the counter | 
|---|
| 280 | // of operands that are in "TopInterval". | 
|---|
| 281 | for (Instruction &BotI : BotInterval) { | 
|---|
| 282 | auto *BotN = getNode(I: &BotI); | 
|---|
| 283 | // Skip scheduled nodes. | 
|---|
| 284 | if (BotN->scheduled()) | 
|---|
| 285 | continue; | 
|---|
| 286 | for (Value *Op : BotI.operands()) { | 
|---|
| 287 | auto *OpI = dyn_cast<Instruction>(Val: Op); | 
|---|
| 288 | if (OpI == nullptr) | 
|---|
| 289 | continue; | 
|---|
| 290 | auto *OpN = getNode(I: OpI); | 
|---|
| 291 | if (OpN == nullptr) | 
|---|
| 292 | continue; | 
|---|
| 293 | if (!TopInterval.contains(I: OpI)) | 
|---|
| 294 | continue; | 
|---|
| 295 | ++OpN->UnscheduledSuccs; | 
|---|
| 296 | } | 
|---|
| 297 | } | 
|---|
| 298 | } | 
|---|
| 299 |  | 
|---|
| 300 | void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) { | 
|---|
| 301 | // Create Nodes only for the new sections of the DAG. | 
|---|
| 302 | DGNode *LastN = getOrCreateNode(I: NewInterval.top()); | 
|---|
| 303 | MemDGNode *LastMemN = dyn_cast<MemDGNode>(Val: LastN); | 
|---|
| 304 | for (Instruction &I : drop_begin(RangeOrContainer: NewInterval)) { | 
|---|
| 305 | auto *N = getOrCreateNode(I: &I); | 
|---|
| 306 | // Build the Mem node chain. | 
|---|
| 307 | if (auto *MemN = dyn_cast<MemDGNode>(Val: N)) { | 
|---|
| 308 | MemN->setPrevNode(LastMemN); | 
|---|
| 309 | LastMemN = MemN; | 
|---|
| 310 | } | 
|---|
| 311 | } | 
|---|
| 312 | // Link new MemDGNode chain with the old one, if any. | 
|---|
| 313 | if (!DAGInterval.empty()) { | 
|---|
| 314 | bool NewIsAbove = NewInterval.comesBefore(Other: DAGInterval); | 
|---|
| 315 | const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval; | 
|---|
| 316 | const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval; | 
|---|
| 317 | MemDGNode *LinkTopN = | 
|---|
| 318 | MemDGNodeIntervalBuilder::getBotMemDGNode(Intvl: TopInterval, DAG: *this); | 
|---|
| 319 | MemDGNode *LinkBotN = | 
|---|
| 320 | MemDGNodeIntervalBuilder::getTopMemDGNode(Intvl: BotInterval, DAG: *this); | 
|---|
| 321 | assert((LinkTopN == nullptr || LinkBotN == nullptr || | 
|---|
| 322 | LinkTopN->comesBefore(LinkBotN)) && | 
|---|
| 323 | "Wrong order!"); | 
|---|
| 324 | if (LinkTopN != nullptr && LinkBotN != nullptr) { | 
|---|
| 325 | LinkTopN->setNextNode(LinkBotN); | 
|---|
| 326 | } | 
|---|
| 327 | #ifndef NDEBUG | 
|---|
| 328 | // TODO: Remove this once we've done enough testing. | 
|---|
| 329 | // Check that the chain is well formed. | 
|---|
| 330 | auto UnionIntvl = DAGInterval.getUnionInterval(NewInterval); | 
|---|
| 331 | MemDGNode *ChainTopN = | 
|---|
| 332 | MemDGNodeIntervalBuilder::getTopMemDGNode(UnionIntvl, *this); | 
|---|
| 333 | MemDGNode *ChainBotN = | 
|---|
| 334 | MemDGNodeIntervalBuilder::getBotMemDGNode(UnionIntvl, *this); | 
|---|
| 335 | if (ChainTopN != nullptr && ChainBotN != nullptr) { | 
|---|
| 336 | for (auto *N = ChainTopN->getNextNode(), *LastN = ChainTopN; N != nullptr; | 
|---|
| 337 | LastN = N, N = N->getNextNode()) { | 
|---|
| 338 | assert(N == LastN->getNextNode() && "Bad chain!"); | 
|---|
| 339 | assert(N->getPrevNode() == LastN && "Bad chain!"); | 
|---|
| 340 | } | 
|---|
| 341 | } | 
|---|
| 342 | #endif // NDEBUG | 
|---|
| 343 | } | 
|---|
| 344 |  | 
|---|
| 345 | setDefUseUnscheduledSuccs(NewInterval); | 
|---|
| 346 | } | 
|---|
| 347 |  | 
|---|
| 348 | MemDGNode *DependencyGraph::getMemDGNodeBefore(DGNode *N, bool IncludingN, | 
|---|
| 349 | MemDGNode *SkipN) const { | 
|---|
| 350 | auto *I = N->getInstruction(); | 
|---|
| 351 | for (auto *PrevI = IncludingN ? I : I->getPrevNode(); PrevI != nullptr; | 
|---|
| 352 | PrevI = PrevI->getPrevNode()) { | 
|---|
| 353 | auto *PrevN = getNodeOrNull(I: PrevI); | 
|---|
| 354 | if (PrevN == nullptr) | 
|---|
| 355 | return nullptr; | 
|---|
| 356 | auto *PrevMemN = dyn_cast<MemDGNode>(Val: PrevN); | 
|---|
| 357 | if (PrevMemN != nullptr && PrevMemN != SkipN) | 
|---|
| 358 | return PrevMemN; | 
|---|
| 359 | } | 
|---|
| 360 | return nullptr; | 
|---|
| 361 | } | 
|---|
| 362 |  | 
|---|
| 363 | MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N, bool IncludingN, | 
|---|
| 364 | MemDGNode *SkipN) const { | 
|---|
| 365 | auto *I = N->getInstruction(); | 
|---|
| 366 | for (auto *NextI = IncludingN ? I : I->getNextNode(); NextI != nullptr; | 
|---|
| 367 | NextI = NextI->getNextNode()) { | 
|---|
| 368 | auto *NextN = getNodeOrNull(I: NextI); | 
|---|
| 369 | if (NextN == nullptr) | 
|---|
| 370 | return nullptr; | 
|---|
| 371 | auto *NextMemN = dyn_cast<MemDGNode>(Val: NextN); | 
|---|
| 372 | if (NextMemN != nullptr && NextMemN != SkipN) | 
|---|
| 373 | return NextMemN; | 
|---|
| 374 | } | 
|---|
| 375 | return nullptr; | 
|---|
| 376 | } | 
|---|
| 377 |  | 
|---|
| 378 | void DependencyGraph::notifyCreateInstr(Instruction *I) { | 
|---|
| 379 | if (Ctx->getTracker().getState() == Tracker::TrackerState::Reverting) | 
|---|
| 380 | // We don't maintain the DAG while reverting. | 
|---|
| 381 | return; | 
|---|
| 382 | // Nothing to do if the node is not in the focus range of the DAG. | 
|---|
| 383 | if (!(DAGInterval.contains(I) || DAGInterval.touches(Elm: I))) | 
|---|
| 384 | return; | 
|---|
| 385 | // Include `I` into the interval. | 
|---|
| 386 | DAGInterval = DAGInterval.getUnionInterval(Other: {I, I}); | 
|---|
| 387 | auto *N = getOrCreateNode(I); | 
|---|
| 388 | auto *MemN = dyn_cast<MemDGNode>(Val: N); | 
|---|
| 389 |  | 
|---|
| 390 | // Update the MemDGNode chain if this is a memory node. | 
|---|
| 391 | if (MemN != nullptr) { | 
|---|
| 392 | if (auto *PrevMemN = getMemDGNodeBefore(N: MemN, /*IncludingN=*/false)) { | 
|---|
| 393 | PrevMemN->NextMemN = MemN; | 
|---|
| 394 | MemN->PrevMemN = PrevMemN; | 
|---|
| 395 | } | 
|---|
| 396 | if (auto *NextMemN = getMemDGNodeAfter(N: MemN, /*IncludingN=*/false)) { | 
|---|
| 397 | NextMemN->PrevMemN = MemN; | 
|---|
| 398 | MemN->NextMemN = NextMemN; | 
|---|
| 399 | } | 
|---|
| 400 |  | 
|---|
| 401 | // Add Mem dependencies. | 
|---|
| 402 | // 1. Scan for deps above `I` for deps to `I`: AboveN->MemN. | 
|---|
| 403 | if (DAGInterval.top()->comesBefore(Other: I)) { | 
|---|
| 404 | Interval<Instruction> AboveIntvl(DAGInterval.top(), I->getPrevNode()); | 
|---|
| 405 | auto SrcInterval = MemDGNodeIntervalBuilder::make(Instrs: AboveIntvl, DAG&: *this); | 
|---|
| 406 | scanAndAddDeps(DstN&: *MemN, SrcScanRange: SrcInterval); | 
|---|
| 407 | } | 
|---|
| 408 | // 2. Scan for deps below `I` for deps from `I`: MemN->BelowN. | 
|---|
| 409 | if (I->comesBefore(Other: DAGInterval.bottom())) { | 
|---|
| 410 | Interval<Instruction> BelowIntvl(I->getNextNode(), DAGInterval.bottom()); | 
|---|
| 411 | for (MemDGNode &BelowN : | 
|---|
| 412 | MemDGNodeIntervalBuilder::make(Instrs: BelowIntvl, DAG&: *this)) | 
|---|
| 413 | scanAndAddDeps(DstN&: BelowN, SrcScanRange: Interval<MemDGNode>(MemN, MemN)); | 
|---|
| 414 | } | 
|---|
| 415 | } | 
|---|
| 416 | } | 
|---|
| 417 |  | 
|---|
| 418 | void DependencyGraph::notifyMoveInstr(Instruction *I, const BBIterator &To) { | 
|---|
| 419 | if (Ctx->getTracker().getState() == Tracker::TrackerState::Reverting) | 
|---|
| 420 | // We don't maintain the DAG while reverting. | 
|---|
| 421 | return; | 
|---|
| 422 | // NOTE: This function runs before `I` moves to its new destination. | 
|---|
| 423 | BasicBlock *BB = To.getNodeParent(); | 
|---|
| 424 | assert(!(To != BB->end() && &*To == I->getNextNode()) && | 
|---|
| 425 | !(To == BB->end() && std::next(I->getIterator()) == BB->end()) && | 
|---|
| 426 | "Should not have been called if destination is same as origin."); | 
|---|
| 427 |  | 
|---|
| 428 | // TODO: We can only handle fully internal movements within DAGInterval or at | 
|---|
| 429 | // the borders, i.e., right before the top or right after the bottom. | 
|---|
| 430 | assert(To.getNodeParent() == I->getParent() && | 
|---|
| 431 | "TODO: We don't support movement across BBs!"); | 
|---|
| 432 | assert( | 
|---|
| 433 | (To == std::next(DAGInterval.bottom()->getIterator()) || | 
|---|
| 434 | (To != BB->end() && std::next(To) == DAGInterval.top()->getIterator()) || | 
|---|
| 435 | (To != BB->end() && DAGInterval.contains(&*To))) && | 
|---|
| 436 | "TODO: To should be either within the DAGInterval or right " | 
|---|
| 437 | "before/after it."); | 
|---|
| 438 |  | 
|---|
| 439 | // Make a copy of the DAGInterval before we update it. | 
|---|
| 440 | auto OrigDAGInterval = DAGInterval; | 
|---|
| 441 |  | 
|---|
| 442 | // Maintain the DAGInterval. | 
|---|
| 443 | DAGInterval.notifyMoveInstr(I, BeforeIt: To); | 
|---|
| 444 |  | 
|---|
| 445 | // TODO: Perhaps check if this is legal by checking the dependencies? | 
|---|
| 446 |  | 
|---|
| 447 | // Update the MemDGNode chain to reflect the instr movement if necessary. | 
|---|
| 448 | DGNode *N = getNodeOrNull(I); | 
|---|
| 449 | if (N == nullptr) | 
|---|
| 450 | return; | 
|---|
| 451 | MemDGNode *MemN = dyn_cast<MemDGNode>(Val: N); | 
|---|
| 452 | if (MemN == nullptr) | 
|---|
| 453 | return; | 
|---|
| 454 |  | 
|---|
| 455 | // First safely detach it from the existing chain. | 
|---|
| 456 | MemN->detachFromChain(); | 
|---|
| 457 |  | 
|---|
| 458 | // Now insert it back into the chain at the new location. | 
|---|
| 459 | // | 
|---|
| 460 | // We won't always have a DGNode to insert before it. If `To` is BB->end() or | 
|---|
| 461 | // if it points to an instr after DAGInterval.bottom() then we will have to | 
|---|
| 462 | // find a node to insert *after*. | 
|---|
| 463 | // | 
|---|
| 464 | // BB:                              BB: | 
|---|
| 465 | //  I1                               I1 ^ | 
|---|
| 466 | //  I2                               I2 | DAGInteval [I1 to I3] | 
|---|
| 467 | //  I3                               I3 V | 
|---|
| 468 | //  I4                               I4   <- `To` == right after DAGInterval | 
|---|
| 469 | //    <- `To` == BB->end() | 
|---|
| 470 | // | 
|---|
| 471 | if (To == BB->end() || | 
|---|
| 472 | To == std::next(x: OrigDAGInterval.bottom()->getIterator())) { | 
|---|
| 473 | // If we don't have a node to insert before, find a node to insert after and | 
|---|
| 474 | // update the chain. | 
|---|
| 475 | DGNode *InsertAfterN = getNode(I: &*std::prev(x: To)); | 
|---|
| 476 | MemN->setPrevNode( | 
|---|
| 477 | getMemDGNodeBefore(N: InsertAfterN, /*IncludingN=*/true, /*SkipN=*/MemN)); | 
|---|
| 478 | } else { | 
|---|
| 479 | // We have a node to insert before, so update the chain. | 
|---|
| 480 | DGNode *BeforeToN = getNode(I: &*To); | 
|---|
| 481 | MemN->setPrevNode( | 
|---|
| 482 | getMemDGNodeBefore(N: BeforeToN, /*IncludingN=*/false, /*SkipN=*/MemN)); | 
|---|
| 483 | MemN->setNextNode( | 
|---|
| 484 | getMemDGNodeAfter(N: BeforeToN, /*IncludingN=*/true, /*SkipN=*/MemN)); | 
|---|
| 485 | } | 
|---|
| 486 | } | 
|---|
| 487 |  | 
|---|
| 488 | void DependencyGraph::notifyEraseInstr(Instruction *I) { | 
|---|
| 489 | if (Ctx->getTracker().getState() == Tracker::TrackerState::Reverting) | 
|---|
| 490 | // We don't maintain the DAG while reverting. | 
|---|
| 491 | return; | 
|---|
| 492 | auto *N = getNode(I); | 
|---|
| 493 | if (N == nullptr) | 
|---|
| 494 | // Early return if there is no DAG node for `I`. | 
|---|
| 495 | return; | 
|---|
| 496 | if (auto *MemN = dyn_cast<MemDGNode>(Val: getNode(I))) { | 
|---|
| 497 | // Update the MemDGNode chain if this is a memory node. | 
|---|
| 498 | auto *PrevMemN = getMemDGNodeBefore(N: MemN, /*IncludingN=*/false); | 
|---|
| 499 | auto *NextMemN = getMemDGNodeAfter(N: MemN, /*IncludingN=*/false); | 
|---|
| 500 | if (PrevMemN != nullptr) | 
|---|
| 501 | PrevMemN->NextMemN = NextMemN; | 
|---|
| 502 | if (NextMemN != nullptr) | 
|---|
| 503 | NextMemN->PrevMemN = PrevMemN; | 
|---|
| 504 |  | 
|---|
| 505 | // Drop the memory dependencies from both predecessors and successors. | 
|---|
| 506 | while (!MemN->memPreds().empty()) { | 
|---|
| 507 | auto *PredN = *MemN->memPreds().begin(); | 
|---|
| 508 | MemN->removeMemPred(PredN); | 
|---|
| 509 | } | 
|---|
| 510 | while (!MemN->memSuccs().empty()) { | 
|---|
| 511 | auto *SuccN = *MemN->memSuccs().begin(); | 
|---|
| 512 | SuccN->removeMemPred(PredN: MemN); | 
|---|
| 513 | } | 
|---|
| 514 | // NOTE: The unscheduled succs for MemNodes get updated be setMemPred(). | 
|---|
| 515 | } else { | 
|---|
| 516 | // If this is a non-mem node we only need to update UnscheduledSuccs. | 
|---|
| 517 | if (!N->scheduled()) | 
|---|
| 518 | for (auto *PredN : N->preds(DAG&: *this)) | 
|---|
| 519 | PredN->decrUnscheduledSuccs(); | 
|---|
| 520 | } | 
|---|
| 521 | // Finally erase the Node. | 
|---|
| 522 | InstrToNodeMap.erase(Val: I); | 
|---|
| 523 | } | 
|---|
| 524 |  | 
|---|
| 525 | void DependencyGraph::notifySetUse(const Use &U, Value *NewSrc) { | 
|---|
| 526 | // Update the UnscheduledSuccs counter for both the current source and NewSrc | 
|---|
| 527 | // if needed. | 
|---|
| 528 | if (auto *CurrSrcI = dyn_cast<Instruction>(Val: U.get())) { | 
|---|
| 529 | if (auto *CurrSrcN = getNode(I: CurrSrcI)) { | 
|---|
| 530 | CurrSrcN->decrUnscheduledSuccs(); | 
|---|
| 531 | } | 
|---|
| 532 | } | 
|---|
| 533 | if (auto *NewSrcI = dyn_cast<Instruction>(Val: NewSrc)) { | 
|---|
| 534 | if (auto *NewSrcN = getNode(I: NewSrcI)) { | 
|---|
| 535 | ++NewSrcN->UnscheduledSuccs; | 
|---|
| 536 | } | 
|---|
| 537 | } | 
|---|
| 538 | } | 
|---|
| 539 |  | 
|---|
| 540 | Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) { | 
|---|
| 541 | if (Instrs.empty()) | 
|---|
| 542 | return {}; | 
|---|
| 543 |  | 
|---|
| 544 | Interval<Instruction> InstrsInterval(Instrs); | 
|---|
| 545 | Interval<Instruction> Union = DAGInterval.getUnionInterval(Other: InstrsInterval); | 
|---|
| 546 | auto NewInterval = Union.getSingleDiff(Other: DAGInterval); | 
|---|
| 547 | if (NewInterval.empty()) | 
|---|
| 548 | return {}; | 
|---|
| 549 |  | 
|---|
| 550 | createNewNodes(NewInterval); | 
|---|
| 551 |  | 
|---|
| 552 | // Create the dependencies. | 
|---|
| 553 | // | 
|---|
| 554 | // 1. This is a new DAG, DAGInterval is empty. Fully scan the whole interval. | 
|---|
| 555 | // +---+       -             - | 
|---|
| 556 | // |   | SrcN  |             | | 
|---|
| 557 | // |   |  |    | SrcRange    | | 
|---|
| 558 | // |New|  v    |             | DstRange | 
|---|
| 559 | // |   | DstN  -             | | 
|---|
| 560 | // |   |                     | | 
|---|
| 561 | // +---+                     - | 
|---|
| 562 | // We are scanning for deps with destination in NewInterval and sources in | 
|---|
| 563 | // NewInterval until DstN, for each DstN. | 
|---|
| 564 | auto FullScan = [this](const Interval<Instruction> Intvl) { | 
|---|
| 565 | auto DstRange = MemDGNodeIntervalBuilder::make(Instrs: Intvl, DAG&: *this); | 
|---|
| 566 | if (!DstRange.empty()) { | 
|---|
| 567 | for (MemDGNode &DstN : drop_begin(RangeOrContainer&: DstRange)) { | 
|---|
| 568 | auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode()); | 
|---|
| 569 | scanAndAddDeps(DstN, SrcScanRange: SrcRange); | 
|---|
| 570 | } | 
|---|
| 571 | } | 
|---|
| 572 | }; | 
|---|
| 573 | auto MemDAGInterval = MemDGNodeIntervalBuilder::make(Instrs: DAGInterval, DAG&: *this); | 
|---|
| 574 | if (MemDAGInterval.empty()) { | 
|---|
| 575 | FullScan(NewInterval); | 
|---|
| 576 | } | 
|---|
| 577 | // 2. The new section is below the old section. | 
|---|
| 578 | // +---+       - | 
|---|
| 579 | // |   |       | | 
|---|
| 580 | // |Old| SrcN  | | 
|---|
| 581 | // |   |  |    | | 
|---|
| 582 | // +---+  |    | SrcRange | 
|---|
| 583 | // +---+  |    |             - | 
|---|
| 584 | // |   |  |    |             | | 
|---|
| 585 | // |New|  v    |             | DstRange | 
|---|
| 586 | // |   | DstN  -             | | 
|---|
| 587 | // |   |                     | | 
|---|
| 588 | // +---+                     - | 
|---|
| 589 | // We are scanning for deps with destination in NewInterval because the deps | 
|---|
| 590 | // in DAGInterval have already been computed. We consider sources in the whole | 
|---|
| 591 | // range including both NewInterval and DAGInterval until DstN, for each DstN. | 
|---|
| 592 | else if (DAGInterval.bottom()->comesBefore(Other: NewInterval.top())) { | 
|---|
| 593 | auto DstRange = MemDGNodeIntervalBuilder::make(Instrs: NewInterval, DAG&: *this); | 
|---|
| 594 | auto SrcRangeFull = MemDAGInterval.getUnionInterval(Other: DstRange); | 
|---|
| 595 | for (MemDGNode &DstN : DstRange) { | 
|---|
| 596 | auto SrcRange = | 
|---|
| 597 | Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode()); | 
|---|
| 598 | scanAndAddDeps(DstN, SrcScanRange: SrcRange); | 
|---|
| 599 | } | 
|---|
| 600 | } | 
|---|
| 601 | // 3. The new section is above the old section. | 
|---|
| 602 | else if (NewInterval.bottom()->comesBefore(Other: DAGInterval.top())) { | 
|---|
| 603 | // +---+       -             - | 
|---|
| 604 | // |   | SrcN  |             | | 
|---|
| 605 | // |New|  |    | SrcRange    | DstRange | 
|---|
| 606 | // |   |  v    |             | | 
|---|
| 607 | // |   | DstN  -             | | 
|---|
| 608 | // |   |                     | | 
|---|
| 609 | // +---+                     - | 
|---|
| 610 | // +---+ | 
|---|
| 611 | // |Old| | 
|---|
| 612 | // |   | | 
|---|
| 613 | // +---+ | 
|---|
| 614 | // When scanning for deps with destination in NewInterval we need to fully | 
|---|
| 615 | // scan the interval. This is the same as the scanning for a new DAG. | 
|---|
| 616 | FullScan(NewInterval); | 
|---|
| 617 |  | 
|---|
| 618 | // +---+       - | 
|---|
| 619 | // |   |       | | 
|---|
| 620 | // |New| SrcN  | SrcRange | 
|---|
| 621 | // |   |  |    | | 
|---|
| 622 | // |   |  |    | | 
|---|
| 623 | // |   |  |    | | 
|---|
| 624 | // +---+  |    - | 
|---|
| 625 | // +---+  |                  - | 
|---|
| 626 | // |Old|  v                  | DstRange | 
|---|
| 627 | // |   | DstN                | | 
|---|
| 628 | // +---+                     - | 
|---|
| 629 | // When scanning for deps with destination in DAGInterval we need to | 
|---|
| 630 | // consider sources from the NewInterval only, because all intra-DAGInterval | 
|---|
| 631 | // dependencies have already been created. | 
|---|
| 632 | auto DstRangeOld = MemDAGInterval; | 
|---|
| 633 | auto SrcRange = MemDGNodeIntervalBuilder::make(Instrs: NewInterval, DAG&: *this); | 
|---|
| 634 | for (MemDGNode &DstN : DstRangeOld) | 
|---|
| 635 | scanAndAddDeps(DstN, SrcScanRange: SrcRange); | 
|---|
| 636 | } else { | 
|---|
| 637 | llvm_unreachable( "We don't expect extending in both directions!"); | 
|---|
| 638 | } | 
|---|
| 639 |  | 
|---|
| 640 | DAGInterval = Union; | 
|---|
| 641 | return NewInterval; | 
|---|
| 642 | } | 
|---|
| 643 |  | 
|---|
| 644 | #ifndef NDEBUG | 
|---|
| 645 | void DependencyGraph::print(raw_ostream &OS) const { | 
|---|
| 646 | // InstrToNodeMap is unordered so we need to create an ordered vector. | 
|---|
| 647 | SmallVector<DGNode *> Nodes; | 
|---|
| 648 | Nodes.reserve(InstrToNodeMap.size()); | 
|---|
| 649 | for (const auto &Pair : InstrToNodeMap) | 
|---|
| 650 | Nodes.push_back(Pair.second.get()); | 
|---|
| 651 | // Sort them based on which one comes first in the BB. | 
|---|
| 652 | sort(Nodes, [](DGNode *N1, DGNode *N2) { | 
|---|
| 653 | return N1->getInstruction()->comesBefore(N2->getInstruction()); | 
|---|
| 654 | }); | 
|---|
| 655 | for (auto *N : Nodes) | 
|---|
| 656 | N->print(OS, /*PrintDeps=*/true); | 
|---|
| 657 | } | 
|---|
| 658 |  | 
|---|
| 659 | void DependencyGraph::dump() const { | 
|---|
| 660 | print(dbgs()); | 
|---|
| 661 | dbgs() << "\n"; | 
|---|
| 662 | } | 
|---|
| 663 | #endif // NDEBUG | 
|---|
| 664 |  | 
|---|
| 665 | } // namespace llvm::sandboxir | 
|---|
| 666 |  | 
|---|