1//===--- AMDGPUIGroupLP.cpp - AMDGPU IGroupLP ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// \file This file defines a set of schedule DAG mutations that can be used to
// override default scheduler behavior to enforce specific scheduling patterns.
// They should be used in cases where runtime performance considerations, such
// as inter-wavefront interactions, mean that compile-time heuristics cannot
// predict the optimal instruction ordering, or in kernels where optimal
// instruction scheduling is important enough to warrant manual intervention.
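//
// These mutations are typically requested from source via the AMDGPU
// scheduling builtins (the builtin names here are an assumption based on
// clang's AMDGPU builtin set):
//   __builtin_amdgcn_sched_group_barrier(mask, size, syncid)
//   __builtin_amdgcn_iglp_opt(strategy)
// which lower to SCHED_GROUP_BARRIER / IGLP_OPT pseudo instructions that the
// mutations below consume.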
15//
16//===----------------------------------------------------------------------===//
17
18#include "AMDGPUIGroupLP.h"
19#include "MCTargetDesc/AMDGPUMCTargetDesc.h"
20#include "SIInstrInfo.h"
21#include "SIMachineFunctionInfo.h"
22#include "llvm/ADT/BitmaskEnum.h"
23#include "llvm/ADT/DenseMap.h"
24#include "llvm/CodeGen/MachineScheduler.h"
25#include "llvm/CodeGen/TargetOpcodes.h"
26
27using namespace llvm;
28
29#define DEBUG_TYPE "igrouplp"
30
31namespace {
32
static cl::opt<bool> EnableExactSolver(
    "amdgpu-igrouplp-exact-solver", cl::Hidden,
    cl::desc("Whether to use the exponential time solver to fit "
             "the instructions to the pipeline as closely as "
             "possible."),
    cl::init(false));
39
static cl::opt<unsigned> CutoffForExact(
    "amdgpu-igrouplp-exact-solver-cutoff", cl::init(0), cl::Hidden,
    cl::desc("The maximum number of scheduling group conflicts "
             "which we attempt to solve with the exponential time "
             "exact solver. Problem sizes greater than this will "
             "be solved by the less accurate greedy algorithm. Selecting "
             "solver by size is superseded by manually selecting "
             "the solver (e.g. by amdgpu-igrouplp-exact-solver)."));
48
static cl::opt<uint64_t> MaxBranchesExplored(
    "amdgpu-igrouplp-exact-solver-max-branches", cl::init(0), cl::Hidden,
    cl::desc("The number of branches that we are willing to explore with "
             "the exact algorithm before giving up."));
53
54static cl::opt<bool> UseCostHeur(
55 "amdgpu-igrouplp-exact-solver-cost-heur", cl::init(Val: true), cl::Hidden,
56 cl::desc("Whether to use the cost heuristic to make choices as we "
57 "traverse the search space using the exact solver. Defaulted "
58 "to on, and if turned off, we will use the node order -- "
59 "attempting to put the later nodes in the later sched groups. "
60 "Experimentally, results are mixed, so this should be set on a "
61 "case-by-case basis."));
62
// Components of the mask that determines which instruction types may be
// classified into a SchedGroup.
65enum class SchedGroupMask {
66 NONE = 0u,
67 ALU = 1u << 0,
68 VALU = 1u << 1,
69 SALU = 1u << 2,
70 MFMA = 1u << 3,
71 VMEM = 1u << 4,
72 VMEM_READ = 1u << 5,
73 VMEM_WRITE = 1u << 6,
74 DS = 1u << 7,
75 DS_READ = 1u << 8,
76 DS_WRITE = 1u << 9,
77 TRANS = 1u << 10,
78 ALL = ALU | VALU | SALU | MFMA | VMEM | VMEM_READ | VMEM_WRITE | DS |
79 DS_READ | DS_WRITE | TRANS,
80 LLVM_MARK_AS_BITMASK_ENUM(/* LargestFlag = */ ALL)
81};
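
// For illustration, these bits mirror the mask operand of the scheduling
// barriers. Assuming the usual encoding, a request such as
//   2 x DS read: sched_group_barrier mask(0x100) size(2) syncid(0)
//   1 x MFMA:    sched_group_barrier mask(0x8)   size(1) syncid(0)
// yields one SchedGroup with SGMask == DS_READ and one with SGMask == MFMA.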
82
83class SchedGroup;
84
// The InstructionRule class is used to enact a filter which determines whether
// or not an SU maps to a given SchedGroup. It contains complementary data
// structures (e.g. Cache) to help those filters.
88class InstructionRule {
89protected:
90 const SIInstrInfo *TII;
91 unsigned SGID;
92 // A cache made available to the Filter to store SUnits for subsequent
93 // invocations of the Filter
94 std::optional<SmallVector<SUnit *, 4>> Cache;
95
96public:
97 virtual bool
98 apply(const SUnit *, const ArrayRef<SUnit *>,
99 SmallVectorImpl<SchedGroup> &) {
100 return true;
101 };
102
103 InstructionRule(const SIInstrInfo *TII, unsigned SGID,
104 bool NeedsCache = false)
105 : TII(TII), SGID(SGID) {
106 if (NeedsCache) {
107 Cache = SmallVector<SUnit *, 4>();
108 }
109 }
110
111 virtual ~InstructionRule() = default;
112};
113
114using SUnitsToCandidateSGsMap = DenseMap<SUnit *, SmallVector<int, 4>>;
115
// Classify instructions into groups to enable fine-tuned control over the
117// scheduler. These groups may be more specific than current SchedModel
118// instruction classes.
119class SchedGroup {
120private:
121 // Mask that defines which instruction types can be classified into this
122 // SchedGroup. The instruction types correspond to the mask from SCHED_BARRIER
123 // and SCHED_GROUP_BARRIER.
124 SchedGroupMask SGMask;
125
126 // Maximum number of SUnits that can be added to this group.
127 std::optional<unsigned> MaxSize;
128
129 // SchedGroups will only synchronize with other SchedGroups that have the same
130 // SyncID.
131 int SyncID = 0;
132
133 // SGID is used to map instructions to candidate SchedGroups
134 unsigned SGID;
135
136 // The different rules each instruction in this SchedGroup must conform to
137 SmallVector<std::shared_ptr<InstructionRule>, 4> Rules;
138
139 // Count of the number of created SchedGroups, used to initialize SGID.
140 static unsigned NumSchedGroups;
141
  // Try to add an edge from SU A to SU B.
143 bool tryAddEdge(SUnit *A, SUnit *B);
144
145 // Use SGMask to determine whether we can classify MI as a member of this
146 // SchedGroup object.
147 bool canAddMI(const MachineInstr &MI) const;
148
149public:
150 // Collection of SUnits that are classified as members of this group.
151 SmallVector<SUnit *, 32> Collection;
152
153 ScheduleDAGInstrs *DAG;
154 const SIInstrInfo *TII;
155
156 // Returns true if SU can be added to this SchedGroup.
157 bool canAddSU(SUnit &SU) const;
158
  // Add DAG dependencies between all SUnits in this SchedGroup and this SU. If
160 // MakePred is true, SU will be a predecessor of the SUnits in this
161 // SchedGroup, otherwise SU will be a successor.
162 void link(SUnit &SU, bool MakePred = false);
163
  // Add DAG dependencies, track which edges are added, and return the count of
  // missed edges.
166 int link(SUnit &SU, bool MakePred,
167 std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
168
  // Add DAG dependencies between all SUnits in this SchedGroup and this SU.
170 // Use the predicate to determine whether SU should be a predecessor (P =
171 // true) or a successor (P = false) of this SchedGroup.
172 void link(SUnit &SU, function_ref<bool(const SUnit *A, const SUnit *B)> P);
173
174 // Add DAG dependencies such that SUnits in this group shall be ordered
175 // before SUnits in OtherGroup.
176 void link(SchedGroup &OtherGroup);
177
178 // Returns true if no more instructions may be added to this group.
179 bool isFull() const { return MaxSize && Collection.size() >= *MaxSize; }
180
181 // Append a constraint that SUs must meet in order to fit into this
182 // SchedGroup. Since many rules involve the relationship between a SchedGroup
183 // and the SUnits in other SchedGroups, rules are checked at Pipeline Solve
  // time (rather than SchedGroup init time).
185 void addRule(std::shared_ptr<InstructionRule> NewRule) {
186 Rules.push_back(Elt: NewRule);
187 }
188
189 // Returns true if the SU matches all rules
190 bool allowedByRules(const SUnit *SU,
191 SmallVectorImpl<SchedGroup> &SyncPipe) const {
192 for (auto &Rule : Rules) {
193 if (!Rule->apply(SU, Collection, SyncPipe))
194 return false;
195 }
196 return true;
197 }
198
199 // Add SU to the SchedGroup.
200 void add(SUnit &SU) {
201 LLVM_DEBUG(dbgs() << "For SchedGroup with mask "
202 << format_hex((int)SGMask, 10, true) << " adding "
203 << *SU.getInstr());
204 Collection.push_back(Elt: &SU);
205 }
206
207 // Remove last element in the SchedGroup
208 void pop() { Collection.pop_back(); }
209
210 // Identify and add all relevant SUs from the DAG to this SchedGroup.
211 void initSchedGroup();
212
213 // Add instructions to the SchedGroup bottom up starting from RIter.
214 // PipelineInstrs is a set of instructions that should not be added to the
215 // SchedGroup even when the other conditions for adding it are satisfied.
216 // RIter will be added to the SchedGroup as well, and dependencies will be
217 // added so that RIter will always be scheduled at the end of the group.
218 void initSchedGroup(std::vector<SUnit>::reverse_iterator RIter,
219 SUnitsToCandidateSGsMap &SyncedInstrs);
220
221 void initSchedGroup(SUnitsToCandidateSGsMap &SyncedInstrs);
222
223 int getSyncID() { return SyncID; }
224
225 int getSGID() { return SGID; }
226
227 SchedGroupMask getMask() { return SGMask; }
228
229 SchedGroup(SchedGroupMask SGMask, std::optional<unsigned> MaxSize,
230 ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
231 : SGMask(SGMask), MaxSize(MaxSize), DAG(DAG), TII(TII) {
232 SGID = NumSchedGroups++;
233 }
234
235 SchedGroup(SchedGroupMask SGMask, std::optional<unsigned> MaxSize, int SyncID,
236 ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
237 : SGMask(SGMask), MaxSize(MaxSize), SyncID(SyncID), DAG(DAG), TII(TII) {
238 SGID = NumSchedGroups++;
239 }
240};
241
242using SUToCandSGsPair = std::pair<SUnit *, SmallVector<int, 4>>;
243using SUsToCandSGsVec = SmallVector<SUToCandSGsPair, 4>;
244
// The PipelineSolver is used to assign SUnits to SchedGroups in a pipeline
// in non-trivial cases. For example, if the requested pipeline is
// {VMEM_READ, VALU, MFMA, VMEM_READ} and we encounter a VMEM_READ instruction
// in the DAG, then we will have an instruction that cannot be trivially
// assigned to a SchedGroup. The PipelineSolver class implements two algorithms
// to find a good solution to the pipeline -- a greedy algorithm and an exact
// algorithm. The exact algorithm has an exponential time complexity and should
// only be used for small or medium sized problems where an exact solution is
// highly desired.
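// For example, with the pipeline above, an encountered VMEM_READ could map to
// either VMEM_READ group; the greedy algorithm commits to whichever choice
// locally misses the fewest edges, while the exact algorithm searches the
// assignments of all conflicting SUs to minimize the total number of missed
// edges.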
254class PipelineSolver {
255 [[maybe_unused]] ScheduleDAGMI *DAG;
256
257 // Instructions that can be assigned to multiple SchedGroups
258 DenseMap<int, SUnitsToCandidateSGsMap> SyncedInstrs;
259 SmallVector<SUsToCandSGsVec, 4> PipelineInstrs;
260 DenseMap<int, SmallVector<SchedGroup, 4>> SyncedSchedGroups;
261 // The current working pipeline
262 SmallVector<SmallVector<SchedGroup, 4>, 4> CurrPipeline;
263 // The pipeline that has the best solution found so far
264 SmallVector<SmallVector<SchedGroup, 4>, 4> BestPipeline;
265
266 // Whether or not we actually have any SyncedInstrs to try to solve.
267 bool NeedsSolver = false;
268
  // Compute an estimate of the size of the search tree -- the true size is
  // the product of each conflictedInst.Matches.size() across all SyncPipelines
271 unsigned computeProblemSize();
272
273 // The cost penalty of not assigning a SU to a SchedGroup
274 int MissPenalty = 0;
275
276 // Costs in terms of the number of edges we are unable to add
277 int BestCost = -1;
278 int CurrCost = 0;
279
280 // Index pointing to the conflicting instruction that is currently being
281 // fitted
282 int CurrConflInstNo = 0;
283 // Index to the pipeline that is currently being fitted
284 int CurrSyncGroupIdx = 0;
  // The first non-trivial pipeline
286 int BeginSyncGroupIdx = 0;
287
288 // How many branches we have explored
289 uint64_t BranchesExplored = 0;
290
291 // The direction in which we process the candidate SchedGroups per SU
292 bool IsBottomUp = true;
293
  // Update indices to fit the next conflicting instruction
  void advancePosition();
  // Recede indices to attempt to find a better fit for the previous
  // conflicting instruction
298 void retreatPosition();
299
300 // The exponential time algorithm which finds the provably best fit
301 bool solveExact();
302 // The polynomial time algorithm which attempts to find a good fit
303 bool solveGreedy();
304 // Find the best SchedGroup for the current SU using the heuristic given all
305 // current information. One step in the greedy algorithm. Templated against
306 // the SchedGroup iterator (either reverse or forward).
307 template <typename T>
308 void greedyFind(std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I,
309 T E);
310 // Whether or not the current solution is optimal
311 bool checkOptimal();
  // Populate the ready list, prioritizing fewest missed edges first.
313 // Templated against the SchedGroup iterator (either reverse or forward).
314 template <typename T>
315 void populateReadyList(SmallVectorImpl<std::pair<int, int>> &ReadyList, T I,
316 T E);
317 // Add edges corresponding to the SchedGroups as assigned by solver
318 void makePipeline();
319 // Link the SchedGroups in the best found pipeline.
  // Templated against the SchedGroup iterator (either reverse or forward).
321 template <typename T> void linkSchedGroups(T I, T E);
  // Add the edges from the SU to the other SchedGroups in the pipeline, and
323 // return the number of edges missed.
324 int addEdges(SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
325 std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
326 /// Link the pipeline as if \p SU was in the SchedGroup with ID \p SGID. It
327 /// returns the cost (in terms of missed pipeline edges), and tracks the edges
328 /// added in \p AddedEdges
329 template <typename T>
330 int linkSUnit(SUnit *SU, int SGID,
331 std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E);
332 /// Remove the edges passed via \p AddedEdges
333 void removeEdges(const std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
334 // Convert the passed in maps to arrays for bidirectional iterators
335 void convertSyncMapsToArrays();
336
337 void reset();
338
339public:
  // Invoke the solver to map instructions to instruction groups. A heuristic
  // and the command line options determine whether to use the exact or the
  // greedy algorithm.
342 void solve();
343
344 PipelineSolver(DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
345 DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
346 ScheduleDAGMI *DAG, bool IsBottomUp = true)
347 : DAG(DAG), SyncedInstrs(SyncedInstrs),
348 SyncedSchedGroups(SyncedSchedGroups), IsBottomUp(IsBottomUp) {
349
350 for (auto &PipelineInstrs : SyncedInstrs) {
351 if (PipelineInstrs.second.size() > 0) {
352 NeedsSolver = true;
353 break;
354 }
355 }
356
357 if (!NeedsSolver)
358 return;
359
360 convertSyncMapsToArrays();
361
362 CurrPipeline = BestPipeline;
363
364 while (static_cast<size_t>(BeginSyncGroupIdx) < PipelineInstrs.size() &&
365 PipelineInstrs[BeginSyncGroupIdx].size() == 0)
366 ++BeginSyncGroupIdx;
367
368 if (static_cast<size_t>(BeginSyncGroupIdx) >= PipelineInstrs.size())
369 return;
370 }
371};
372
373void PipelineSolver::reset() {
374
375 for (auto &SyncPipeline : CurrPipeline) {
376 for (auto &SG : SyncPipeline) {
377 SmallVector<SUnit *, 32> TempCollection = SG.Collection;
378 SG.Collection.clear();
379 auto *SchedBarr = llvm::find_if(Range&: TempCollection, P: [](SUnit *SU) {
380 return SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER;
381 });
382 if (SchedBarr != TempCollection.end())
383 SG.Collection.push_back(Elt: *SchedBarr);
384 }
385 }
386
387 CurrSyncGroupIdx = BeginSyncGroupIdx;
388 CurrConflInstNo = 0;
389 CurrCost = 0;
390}
391
392void PipelineSolver::convertSyncMapsToArrays() {
393 for (auto &SyncPipe : SyncedSchedGroups) {
394 BestPipeline.insert(I: BestPipeline.begin(), Elt: SyncPipe.second);
395 }
396
397 int PipelineIDx = SyncedInstrs.size() - 1;
398 PipelineInstrs.resize(N: SyncedInstrs.size());
399 for (auto &SyncInstrMap : SyncedInstrs) {
400 for (auto &SUsToCandSGs : SyncInstrMap.second) {
401 if (PipelineInstrs[PipelineIDx].size() == 0) {
402 PipelineInstrs[PipelineIDx].push_back(
403 Elt: std::pair(SUsToCandSGs.first, SUsToCandSGs.second));
404 continue;
405 }
406 auto *SortPosition = PipelineInstrs[PipelineIDx].begin();
407 // Insert them in sorted order -- this allows for good parsing order in
408 // the greedy algorithm
409 while (SortPosition != PipelineInstrs[PipelineIDx].end() &&
410 SUsToCandSGs.first->NodeNum > SortPosition->first->NodeNum)
411 ++SortPosition;
412 PipelineInstrs[PipelineIDx].insert(
413 I: SortPosition, Elt: std::pair(SUsToCandSGs.first, SUsToCandSGs.second));
414 }
415 --PipelineIDx;
416 }
417}
418
419template <typename T> void PipelineSolver::linkSchedGroups(T I, T E) {
420 for (; I != E; ++I) {
421 auto &GroupA = *I;
422 for (auto J = std::next(I); J != E; ++J) {
423 auto &GroupB = *J;
424 GroupA.link(GroupB);
425 }
426 }
427}
428
429void PipelineSolver::makePipeline() {
  // Preserve the order of barriers for subsequent SchedGroupBarrier mutations
431 for (auto &SyncPipeline : BestPipeline) {
432 LLVM_DEBUG(dbgs() << "Printing SchedGroups\n");
433 for (auto &SG : SyncPipeline) {
434 LLVM_DEBUG(dbgs() << "SchedGroup with SGID " << SG.getSGID()
435 << " has: \n");
436 SUnit *SGBarr = nullptr;
437 for (auto &SU : SG.Collection) {
438 if (SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
439 SGBarr = SU;
440 LLVM_DEBUG(dbgs() << "SU(" << SU->NodeNum << ")\n");
441 }
      // Command-line-requested IGroupLP doesn't have an SGBarr
443 if (!SGBarr)
444 continue;
445 SG.link(SU&: *SGBarr, MakePred: false);
446 }
447 }
448
449 for (auto &SyncPipeline : BestPipeline) {
450 IsBottomUp ? linkSchedGroups(I: SyncPipeline.rbegin(), E: SyncPipeline.rend())
451 : linkSchedGroups(I: SyncPipeline.begin(), E: SyncPipeline.end());
452 }
453}
454
455template <typename T>
456int PipelineSolver::linkSUnit(
457 SUnit *SU, int SGID, std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges,
458 T I, T E) {
459 bool MakePred = false;
460 int AddedCost = 0;
461 for (; I < E; ++I) {
462 if (I->getSGID() == SGID) {
463 MakePred = true;
464 continue;
465 }
466 auto Group = *I;
467 AddedCost += Group.link(*SU, MakePred, AddedEdges);
468 assert(AddedCost >= 0);
469 }
470 return AddedCost;
471}
472
473int PipelineSolver::addEdges(
474 SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
475 std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges) {
476
477 // For IsBottomUp, the first SchedGroup in SyncPipeline contains the
478 // instructions that are the ultimate successors in the resultant mutation.
479 // Therefore, in such a configuration, the SchedGroups occurring before the
480 // candidate SGID are successors of the candidate SchedGroup, thus the current
481 // SU should be linked as a predecessor to SUs in those SchedGroups. The
482 // opposite is true if !IsBottomUp. IsBottomUp occurs in the case of multiple
483 // SCHED_GROUP_BARRIERS, or if a user specifies IGLP_OPT SchedGroups using
484 // IsBottomUp (in reverse).
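  // For illustration: with IsBottomUp and SyncPipeline = {SG0, SG1, SG2}, an
  // SU being tried against SG1 is linked as a successor of the SUs in SG2
  // (visited first in reverse order) and as a predecessor of the SUs in SG0,
  // giving the overall order SG2 -> SG1 -> SG0.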
485 return IsBottomUp ? linkSUnit(SU, SGID, AddedEdges, I: SyncPipeline.rbegin(),
486 E: SyncPipeline.rend())
487 : linkSUnit(SU, SGID, AddedEdges, I: SyncPipeline.begin(),
488 E: SyncPipeline.end());
489}
490
491void PipelineSolver::removeEdges(
492 const std::vector<std::pair<SUnit *, SUnit *>> &EdgesToRemove) {
493 // Only remove the edges that we have added when testing
494 // the fit.
495 for (auto &PredSuccPair : EdgesToRemove) {
496 SUnit *Pred = PredSuccPair.first;
497 SUnit *Succ = PredSuccPair.second;
498
499 auto *Match = llvm::find_if(
500 Range&: Succ->Preds, P: [&Pred](SDep &P) { return P.getSUnit() == Pred; });
501 if (Match != Succ->Preds.end()) {
502 assert(Match->isArtificial());
503 Succ->removePred(D: *Match);
504 }
505 }
506}
507
508void PipelineSolver::advancePosition() {
509 ++CurrConflInstNo;
510
511 if (static_cast<size_t>(CurrConflInstNo) >=
512 PipelineInstrs[CurrSyncGroupIdx].size()) {
513 CurrConflInstNo = 0;
514 ++CurrSyncGroupIdx;
515 // Advance to next non-trivial pipeline
516 while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size() &&
517 PipelineInstrs[CurrSyncGroupIdx].size() == 0)
518 ++CurrSyncGroupIdx;
519 }
520}
521
522void PipelineSolver::retreatPosition() {
523 assert(CurrConflInstNo >= 0);
524 assert(CurrSyncGroupIdx >= 0);
525
526 if (CurrConflInstNo > 0) {
527 --CurrConflInstNo;
528 return;
529 }
530
531 if (CurrConflInstNo == 0) {
532 // If we return to the starting position, we have explored
533 // the entire tree
534 if (CurrSyncGroupIdx == BeginSyncGroupIdx)
535 return;
536
537 --CurrSyncGroupIdx;
538 // Go to previous non-trivial pipeline
539 while (PipelineInstrs[CurrSyncGroupIdx].size() == 0)
540 --CurrSyncGroupIdx;
541
542 CurrConflInstNo = PipelineInstrs[CurrSyncGroupIdx].size() - 1;
543 }
544}
545
546bool PipelineSolver::checkOptimal() {
547 if (static_cast<size_t>(CurrSyncGroupIdx) == PipelineInstrs.size()) {
548 if (BestCost == -1 || CurrCost < BestCost) {
549 BestPipeline = CurrPipeline;
550 BestCost = CurrCost;
551 LLVM_DEBUG(dbgs() << "Found Fit with cost " << BestCost << "\n");
552 }
553 assert(BestCost >= 0);
554 }
555
556 bool DoneExploring = false;
557 if (MaxBranchesExplored > 0 && BranchesExplored >= MaxBranchesExplored)
558 DoneExploring = true;
559
560 return (DoneExploring || BestCost == 0);
561}
562
563template <typename T>
564void PipelineSolver::populateReadyList(
565 SmallVectorImpl<std::pair<int, int>> &ReadyList, T I, T E) {
566 SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
567 auto SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
568 assert(CurrSU.second.size() >= 1);
569
570 for (; I != E; ++I) {
571 std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
572 int CandSGID = *I;
573 SchedGroup *Match = llvm::find_if(SyncPipeline, [CandSGID](SchedGroup &SG) {
574 return SG.getSGID() == CandSGID;
575 });
576 assert(Match);
577
578 if (UseCostHeur) {
579 if (Match->isFull()) {
580 ReadyList.push_back(Elt: std::pair(*I, MissPenalty));
581 continue;
582 }
583
584 int TempCost = addEdges(SyncPipeline, SU: CurrSU.first, SGID: CandSGID, AddedEdges);
585 ReadyList.push_back(Elt: std::pair(*I, TempCost));
586 removeEdges(EdgesToRemove: AddedEdges);
587 } else
588 ReadyList.push_back(Elt: std::pair(*I, -1));
589 }
590
591 if (UseCostHeur)
592 std::sort(first: ReadyList.begin(), last: ReadyList.end(), comp: llvm::less_second());
593
594 assert(ReadyList.size() == CurrSU.second.size());
595}
596
597bool PipelineSolver::solveExact() {
598 if (checkOptimal())
599 return true;
600
601 if (static_cast<size_t>(CurrSyncGroupIdx) == PipelineInstrs.size())
602 return false;
603
604 assert(static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size());
605 assert(static_cast<size_t>(CurrConflInstNo) <
606 PipelineInstrs[CurrSyncGroupIdx].size());
607 SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
608 LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum
609 << ") in Pipeline # " << CurrSyncGroupIdx << "\n");
610
611 // SchedGroup -> Cost pairs
612 SmallVector<std::pair<int, int>, 4> ReadyList;
613 // Prioritize the candidate sched groups in terms of lowest cost first
614 IsBottomUp ? populateReadyList(ReadyList, I: CurrSU.second.rbegin(),
615 E: CurrSU.second.rend())
616 : populateReadyList(ReadyList, I: CurrSU.second.begin(),
617 E: CurrSU.second.end());
618
619 auto *I = ReadyList.begin();
620 auto *E = ReadyList.end();
621 for (; I != E; ++I) {
622 // If we are trying SGs in least cost order, and the current SG is cost
623 // infeasible, then all subsequent SGs will also be cost infeasible, so we
624 // can prune.
625 if (BestCost != -1 && (CurrCost + I->second > BestCost))
626 return false;
627
628 int CandSGID = I->first;
629 int AddedCost = 0;
630 std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
631 auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
632 SchedGroup *Match;
633 for (auto &SG : SyncPipeline) {
634 if (SG.getSGID() == CandSGID)
635 Match = &SG;
636 }
637
638 if (Match->isFull())
639 continue;
640
641 if (!Match->allowedByRules(SU: CurrSU.first, SyncPipe&: SyncPipeline))
642 continue;
643
644 LLVM_DEBUG(dbgs() << "Assigning to SchedGroup with Mask "
645 << (int)Match->getMask() << "and ID " << CandSGID
646 << "\n");
647 Match->add(SU&: *CurrSU.first);
648 AddedCost = addEdges(SyncPipeline, SU: CurrSU.first, SGID: CandSGID, AddedEdges);
649 LLVM_DEBUG(dbgs() << "Cost of Assignment: " << AddedCost << "\n");
650 CurrCost += AddedCost;
651 advancePosition();
652 ++BranchesExplored;
653 bool FinishedExploring = false;
654 // If the Cost after adding edges is greater than a known solution,
655 // backtrack
656 if (CurrCost < BestCost || BestCost == -1) {
657 if (solveExact()) {
658 FinishedExploring = BestCost != 0;
659 if (!FinishedExploring)
660 return true;
661 }
662 }
663
664 retreatPosition();
665 CurrCost -= AddedCost;
666 removeEdges(EdgesToRemove: AddedEdges);
667 Match->pop();
668 CurrPipeline[CurrSyncGroupIdx] = SyncPipeline;
669 if (FinishedExploring)
670 return true;
671 }
672
  // Try the pipeline where the current instruction is omitted.
  // Potentially, if we omit a problematic instruction from the pipeline,
  // all the other instructions can fit nicely.
676 CurrCost += MissPenalty;
677 advancePosition();
678
679 LLVM_DEBUG(dbgs() << "NOT Assigned (" << CurrSU.first->NodeNum << ")\n");
680
681 bool FinishedExploring = false;
682 if (CurrCost < BestCost || BestCost == -1) {
683 if (solveExact()) {
684 bool FinishedExploring = BestCost != 0;
685 if (!FinishedExploring)
686 return true;
687 }
688 }
689
690 retreatPosition();
691 CurrCost -= MissPenalty;
692 return FinishedExploring;
693}
694
695template <typename T>
696void PipelineSolver::greedyFind(
697 std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E) {
698 SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
699 int BestNodeCost = -1;
700 int TempCost;
701 SchedGroup *BestGroup = nullptr;
702 int BestGroupID = -1;
703 auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
704 LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum
705 << ") in Pipeline # " << CurrSyncGroupIdx << "\n");
706
707 // Since we have added the potential SchedGroups from bottom up, but
708 // traversed the DAG from top down, parse over the groups from last to
709 // first. If we fail to do this for the greedy algorithm, the solution will
710 // likely not be good in more complex cases.
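  // For example, if this SU's candidate SGIDs are {0, 2, 5} and IsBottomUp is
  // set, the caller passes a reverse iterator range, so we evaluate SGID 5,
  // then 2, then 0, stopping early at the first zero-cost match.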
711 for (; I != E; ++I) {
712 std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
713 int CandSGID = *I;
714 SchedGroup *Match = llvm::find_if(SyncPipeline, [CandSGID](SchedGroup &SG) {
715 return SG.getSGID() == CandSGID;
716 });
717 assert(Match);
718
719 LLVM_DEBUG(dbgs() << "Trying SGID # " << CandSGID << " with Mask "
720 << (int)Match->getMask() << "\n");
721
722 if (Match->isFull()) {
723 LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " is full\n");
724 continue;
725 }
726 if (!Match->allowedByRules(SU: CurrSU.first, SyncPipe&: SyncPipeline)) {
727 LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " has conflicting rule\n");
728 continue;
729 }
730 TempCost = addEdges(SyncPipeline, SU: CurrSU.first, SGID: CandSGID, AddedEdges);
731 LLVM_DEBUG(dbgs() << "Cost of Group " << TempCost << "\n");
732 if (TempCost < BestNodeCost || BestNodeCost == -1) {
733 BestGroup = Match;
734 BestNodeCost = TempCost;
735 BestGroupID = CandSGID;
736 }
737 removeEdges(EdgesToRemove: AddedEdges);
738 if (BestNodeCost == 0)
739 break;
740 }
741
742 if (BestGroupID != -1) {
743 BestGroup->add(SU&: *CurrSU.first);
744 addEdges(SyncPipeline, SU: CurrSU.first, SGID: BestGroupID, AddedEdges);
745 LLVM_DEBUG(dbgs() << "Best Group has ID: " << BestGroupID << " and Mask"
746 << (int)BestGroup->getMask() << "\n");
747 BestCost += TempCost;
748 } else
749 BestCost += MissPenalty;
750
751 CurrPipeline[CurrSyncGroupIdx] = SyncPipeline;
752}
753
754bool PipelineSolver::solveGreedy() {
755 BestCost = 0;
756 std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
757
758 while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size()) {
759 SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
760 IsBottomUp
761 ? greedyFind(AddedEdges, I: CurrSU.second.rbegin(), E: CurrSU.second.rend())
762 : greedyFind(AddedEdges, I: CurrSU.second.begin(), E: CurrSU.second.end());
763 advancePosition();
764 }
765 BestPipeline = CurrPipeline;
766 removeEdges(EdgesToRemove: AddedEdges);
767 return false;
768}
769
770unsigned PipelineSolver::computeProblemSize() {
771 unsigned ProblemSize = 0;
772 for (auto &PipeConflicts : PipelineInstrs) {
773 ProblemSize += PipeConflicts.size();
774 }
775
776 return ProblemSize;
777}
778
779void PipelineSolver::solve() {
780 if (!NeedsSolver)
781 return;
782
783 unsigned ProblemSize = computeProblemSize();
784 assert(ProblemSize > 0);
785
786 bool BelowCutoff = (CutoffForExact > 0) && ProblemSize <= CutoffForExact;
787 MissPenalty = (ProblemSize / 2) + 1;
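  // e.g. a problem with 7 conflicting SUs gives MissPenalty = 4, so leaving an
  // SU out of the pipeline costs as much as missing four individual edges.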
788
789 LLVM_DEBUG(DAG->dump());
790 if (EnableExactSolver || BelowCutoff) {
791 LLVM_DEBUG(dbgs() << "Starting Greedy pipeline solver\n");
792 solveGreedy();
793 reset();
794 LLVM_DEBUG(dbgs() << "Greedy produced best cost of " << BestCost << "\n");
795 if (BestCost > 0) {
796 LLVM_DEBUG(dbgs() << "Starting EXACT pipeline solver\n");
797 solveExact();
798 LLVM_DEBUG(dbgs() << "Exact produced best cost of " << BestCost << "\n");
799 }
800 } else { // Use the Greedy Algorithm by default
801 LLVM_DEBUG(dbgs() << "Starting GREEDY pipeline solver\n");
802 solveGreedy();
803 }
804
805 makePipeline();
806 LLVM_DEBUG(dbgs() << "After applying mutation\n");
807 LLVM_DEBUG(DAG->dump());
808}
809
810enum IGLPStrategyID : int {
811 MFMASmallGemmOptID = 0,
812 MFMASmallGemmSingleWaveOptID = 1,
813 MFMAExpInterleaveID = 2,
814 MFMAExpSimpleInterleaveID = 3
815};
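
// The strategy is selected by the immediate operand of the IGLP_OPT
// instruction; e.g. (assuming the builtin's operand is passed through
// unchanged) __builtin_amdgcn_iglp_opt(1) requests
// MFMASmallGemmSingleWaveOptID.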
816
// Implements an IGLP scheduling strategy.
818class IGLPStrategy {
819protected:
820 ScheduleDAGInstrs *DAG;
821
822 const SIInstrInfo *TII;
823
824public:
825 /// Add SchedGroups to \p SyncedSchedGroups to implement this Strategy.
826 virtual bool applyIGLPStrategy(
827 DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
828 DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
829 AMDGPU::SchedulingPhase Phase) = 0;
830
831 // Returns true if this strategy should be applied to a ScheduleDAG.
832 virtual bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
833 AMDGPU::SchedulingPhase Phase) = 0;
834
835 bool IsBottomUp = true;
836
837 IGLPStrategy(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
838 : DAG(DAG), TII(TII) {}
839
840 virtual ~IGLPStrategy() = default;
841};
842
843class MFMASmallGemmOpt final : public IGLPStrategy {
844private:
845public:
846 bool applyIGLPStrategy(
847 DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
848 DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
849 AMDGPU::SchedulingPhase Phase) override;
850
851 bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
852 AMDGPU::SchedulingPhase Phase) override {
853 return true;
854 }
855
856 MFMASmallGemmOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
857 : IGLPStrategy(DAG, TII) {
858 IsBottomUp = true;
859 }
860};
861
862bool MFMASmallGemmOpt::applyIGLPStrategy(
863 DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
864 DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
865 AMDGPU::SchedulingPhase Phase) {
866 // Count the number of MFMA instructions.
867 unsigned MFMACount = 0;
868 for (const MachineInstr &I : *DAG)
869 if (TII->isMFMAorWMMA(MI: I))
870 ++MFMACount;
871
872 const unsigned PipelineSyncID = 0;
873 SchedGroup *SG = nullptr;
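  // This builds the pipeline as MFMACount * 3 repetitions of {2 x DS, 1 x
  // MFMA}, i.e. roughly DS, DS, MFMA, DS, DS, MFMA, ... so that LDS traffic is
  // interleaved with the MFMAs.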
874 for (unsigned I = 0; I < MFMACount * 3; ++I) {
875 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
876 Args: SchedGroupMask::DS, Args: 2, Args: PipelineSyncID, Args&: DAG, Args&: TII);
877 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
878
879 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
880 Args: SchedGroupMask::MFMA, Args: 1, Args: PipelineSyncID, Args&: DAG, Args&: TII);
881 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
882 }
883
884 return true;
885}
886
887class MFMAExpInterleaveOpt final : public IGLPStrategy {
888private:
889 // The count of TRANS SUs involved in the interleaved pipeline
890 static unsigned TransPipeCount;
891 // The count of MFMA SUs involved in the interleaved pipeline
892 static unsigned MFMAPipeCount;
893 // The count of Add SUs involved in the interleaved pipeline
894 static unsigned AddPipeCount;
895 // The number of transitive MFMA successors for each TRANS SU
896 static unsigned MFMAEnablement;
897 // The number of transitive TRANS predecessors for each MFMA SU
898 static unsigned ExpRequirement;
899 // The count of independent "chains" of MFMA instructions in the pipeline
900 static unsigned MFMAChains;
901 // The length of each independent "chain" of MFMA instructions
902 static unsigned MFMAChainLength;
903 // Whether or not the pipeline has V_CVT instructions
904 static bool HasCvt;
905 // Whether or not there are instructions between the TRANS instruction and
906 // V_CVT
907 static bool HasChainBetweenCvt;
  // The first occurring DS_READ which feeds an MFMA chain
909 static std::optional<unsigned> FirstPipeDSR;
910 // The MFMAPipe SUs with no MFMA predecessors
911 SmallVector<SUnit *, 4> MFMAChainSeeds;
912 // Compute the heuristics for the pipeline, returning whether or not the DAG
913 // is well formatted for the mutation
914 bool analyzeDAG(const SIInstrInfo *TII);
915
916 /// Whether or not the instruction is a transitive predecessor of an MFMA
917 /// instruction
918 class IsPipeExp final : public InstructionRule {
919 public:
920 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
921 SmallVectorImpl<SchedGroup> &SyncPipe) override {
922
923 auto *DAG = SyncPipe[0].DAG;
924
925 if (Cache->empty()) {
926 auto I = DAG->SUnits.rbegin();
927 auto E = DAG->SUnits.rend();
928 for (; I != E; I++) {
929 if (TII->isMFMAorWMMA(MI: *I->getInstr()))
930 Cache->push_back(Elt: &*I);
931 }
932 if (Cache->empty())
933 return false;
934 }
935
936 auto Reaches = any_of(Range&: *Cache, P: [&SU, &DAG](SUnit *TargetSU) {
937 return DAG->IsReachable(SU: TargetSU, TargetSU: const_cast<SUnit *>(SU));
938 });
939
940 return Reaches;
941 }
942 IsPipeExp(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
943 : InstructionRule(TII, SGID, NeedsCache) {}
944 };
945
946 /// Whether or not the instruction is a transitive predecessor of the
  /// \p Number th MFMA of the MFMAs occurring after a TRANS instruction
948 class EnablesNthMFMA final : public InstructionRule {
949 private:
950 unsigned Number = 1;
951
952 public:
953 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
954 SmallVectorImpl<SchedGroup> &SyncPipe) override {
955 bool FoundTrans = false;
956 unsigned Counter = 1;
957 auto *DAG = SyncPipe[0].DAG;
958
959 if (Cache->empty()) {
960 auto I = DAG->SUnits.begin();
961 auto E = DAG->SUnits.end();
962 for (; I != E; I++) {
963 if (FoundTrans && TII->isMFMAorWMMA(MI: *I->getInstr())) {
964 if (Counter == Number) {
965 Cache->push_back(Elt: &*I);
966 break;
967 }
968 ++Counter;
969 }
970 if (!FoundTrans && TII->isTRANS(Opcode: I->getInstr()->getOpcode()))
971 FoundTrans = true;
972 }
973 if (Cache->empty())
974 return false;
975 }
976
977 return DAG->IsReachable(SU: (*Cache)[0], TargetSU: const_cast<SUnit *>(SU));
978 }
979
980 EnablesNthMFMA(unsigned Number, const SIInstrInfo *TII, unsigned SGID,
981 bool NeedsCache = false)
982 : InstructionRule(TII, SGID, NeedsCache), Number(Number) {}
983 };
984
985 /// Whether or not the instruction enables the exact MFMA that is the \p
986 /// Number th MFMA in the chain starting with \p ChainSeed
987 class EnablesNthMFMAInChain final : public InstructionRule {
988 private:
989 unsigned Number = 1;
990 SUnit *ChainSeed;
991
992 public:
993 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
994 SmallVectorImpl<SchedGroup> &SyncPipe) override {
995 auto *DAG = SyncPipe[0].DAG;
996
997 if (!SU || !TII->isMFMAorWMMA(MI: *ChainSeed->getInstr()))
998 return false;
999
1000 if (Cache->empty()) {
1001 auto *TempSU = ChainSeed;
1002 auto Depth = Number;
1003 while (Depth > 0) {
1004 --Depth;
1005 bool Found = false;
1006 for (auto &Succ : TempSU->Succs) {
1007 if (TII->isMFMAorWMMA(MI: *Succ.getSUnit()->getInstr())) {
1008 TempSU = Succ.getSUnit();
1009 Found = true;
1010 break;
1011 }
1012 }
1013 if (!Found)
1014 return false;
1015 }
1016
1017 Cache->push_back(Elt: TempSU);
1018 }
1019 // If we failed to find the instruction to be placed into the cache, we
1020 // would have already exited.
1021 assert(!Cache->empty());
1022
1023 return DAG->IsReachable(SU: (*Cache)[0], TargetSU: const_cast<SUnit *>(SU));
1024 }
1025
1026 EnablesNthMFMAInChain(unsigned Number, SUnit *ChainSeed,
1027 const SIInstrInfo *TII, unsigned SGID,
1028 bool NeedsCache = false)
1029 : InstructionRule(TII, SGID, NeedsCache), Number(Number),
1030 ChainSeed(ChainSeed) {}
1031 };
1032
  /// Whether or not the instruction has fewer than \p Size immediate
  /// successors. If \p HasIntermediary is true, this also tests whether all
  /// successors of the SUnit have fewer than \p Size successors.
1036 class LessThanNSuccs final : public InstructionRule {
1037 private:
1038 unsigned Size = 1;
1039 bool HasIntermediary = false;
1040
1041 public:
1042 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1043 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1044 if (!SyncPipe.size())
1045 return false;
1046
1047 auto SuccSize = llvm::count_if(Range: SU->Succs, P: [](const SDep &Succ) {
1048 return Succ.getKind() == SDep::Data;
1049 });
1050 if (SuccSize >= Size)
1051 return false;
1052
1053 if (HasIntermediary) {
1054 for (auto Succ : SU->Succs) {
1055 auto SuccSize =
1056 llvm::count_if(Range&: Succ.getSUnit()->Succs, P: [](const SDep &SuccSucc) {
1057 return SuccSucc.getKind() == SDep::Data;
1058 });
1059 if (SuccSize >= Size)
1060 return false;
1061 }
1062 }
1063
1064 return true;
1065 }
1066 LessThanNSuccs(unsigned Size, const SIInstrInfo *TII, unsigned SGID,
1067 bool HasIntermediary = false, bool NeedsCache = false)
1068 : InstructionRule(TII, SGID, NeedsCache), Size(Size),
1069 HasIntermediary(HasIntermediary) {}
1070 };
1071
  /// Whether or not the instruction has greater than or equal to \p Size
  /// immediate successors. If \p HasIntermediary is true, this also tests
  /// whether all successors of the SUnit have greater than or equal to \p Size
  /// successors.
1076 class GreaterThanOrEqualToNSuccs final : public InstructionRule {
1077 private:
1078 unsigned Size = 1;
1079 bool HasIntermediary = false;
1080
1081 public:
1082 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1083 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1084 if (!SyncPipe.size())
1085 return false;
1086
1087 auto SuccSize = llvm::count_if(Range: SU->Succs, P: [](const SDep &Succ) {
1088 return Succ.getKind() == SDep::Data;
1089 });
1090 if (SuccSize >= Size)
1091 return true;
1092
1093 if (HasIntermediary) {
1094 for (auto Succ : SU->Succs) {
1095 auto SuccSize =
1096 llvm::count_if(Range&: Succ.getSUnit()->Succs, P: [](const SDep &SuccSucc) {
1097 return SuccSucc.getKind() == SDep::Data;
1098 });
1099 if (SuccSize >= Size)
1100 return true;
1101 }
1102 }
1103
1104 return false;
1105 }
1106 GreaterThanOrEqualToNSuccs(unsigned Size, const SIInstrInfo *TII,
1107 unsigned SGID, bool HasIntermediary = false,
1108 bool NeedsCache = false)
1109 : InstructionRule(TII, SGID, NeedsCache), Size(Size),
1110 HasIntermediary(HasIntermediary) {}
1111 };
1112
1113 // Whether or not the instruction is a relevant V_CVT instruction.
1114 class IsCvt final : public InstructionRule {
1115 public:
1116 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1117 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1118 auto Opc = SU->getInstr()->getOpcode();
1119 return Opc == AMDGPU::V_CVT_F16_F32_e32 ||
1120 Opc == AMDGPU::V_CVT_I32_F32_e32;
1121 }
1122 IsCvt(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1123 : InstructionRule(TII, SGID, NeedsCache) {}
1124 };
1125
1126 // Whether or not the instruction is FMA_F32.
1127 class IsFMA final : public InstructionRule {
1128 public:
1129 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1130 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1131 return SU->getInstr()->getOpcode() == AMDGPU::V_FMA_F32_e64 ||
1132 SU->getInstr()->getOpcode() == AMDGPU::V_PK_FMA_F32;
1133 }
1134 IsFMA(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1135 : InstructionRule(TII, SGID, NeedsCache) {}
1136 };
1137
1138 // Whether or not the instruction is a V_ADD_F32 instruction.
1139 class IsPipeAdd final : public InstructionRule {
1140 public:
1141 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1142 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1143 return SU->getInstr()->getOpcode() == AMDGPU::V_ADD_F32_e32;
1144 }
1145 IsPipeAdd(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1146 : InstructionRule(TII, SGID, NeedsCache) {}
1147 };
1148
1149 /// Whether or not the instruction is an immediate RAW successor
1150 /// of the SchedGroup \p Distance steps before.
1151 class IsSuccOfPrevNthGroup final : public InstructionRule {
1152 private:
1153 unsigned Distance = 1;
1154
1155 public:
1156 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1157 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1158 SchedGroup *OtherGroup = nullptr;
1159 if (!SyncPipe.size())
1160 return false;
1161
1162 for (auto &PipeSG : SyncPipe) {
1163 if ((unsigned)PipeSG.getSGID() == SGID - Distance)
1164 OtherGroup = &PipeSG;
1165 }
1166
1167 if (!OtherGroup)
1168 return false;
1169 if (!OtherGroup->Collection.size())
1170 return true;
1171
1172 for (auto &OtherEle : OtherGroup->Collection) {
1173 for (auto &Succ : OtherEle->Succs) {
1174 if (Succ.getSUnit() == SU && Succ.getKind() == SDep::Data)
1175 return true;
1176 }
1177 }
1178
1179 return false;
1180 }
1181 IsSuccOfPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
1182 unsigned SGID, bool NeedsCache = false)
1183 : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
1184 };
1185
1186 /// Whether or not the instruction is a transitive successor of any
  /// instruction in the SchedGroup \p Distance steps before.
1188 class IsReachableFromPrevNthGroup final : public InstructionRule {
1189 private:
1190 unsigned Distance = 1;
1191
1192 public:
1193 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1194 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1195 SchedGroup *OtherGroup = nullptr;
1196 if (!SyncPipe.size())
1197 return false;
1198
1199 for (auto &PipeSG : SyncPipe) {
1200 if ((unsigned)PipeSG.getSGID() == SGID - Distance)
1201 OtherGroup = &PipeSG;
1202 }
1203
1204 if (!OtherGroup)
1205 return false;
1206 if (!OtherGroup->Collection.size())
1207 return true;
1208
1209 auto *DAG = SyncPipe[0].DAG;
1210
1211 for (auto &OtherEle : OtherGroup->Collection)
1212 if (DAG->IsReachable(SU: const_cast<SUnit *>(SU), TargetSU: OtherEle))
1213 return true;
1214
1215 return false;
1216 }
1217 IsReachableFromPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
1218 unsigned SGID, bool NeedsCache = false)
1219 : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
1220 };
1221
  /// Whether or not the instruction occurs at or after the SU with NodeNum
  /// \p Number
1223 class OccursAtOrAfterNode final : public InstructionRule {
1224 private:
1225 unsigned Number = 1;
1226
1227 public:
1228 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1229 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1230
1231 return SU->NodeNum >= Number;
1232 }
1233 OccursAtOrAfterNode(unsigned Number, const SIInstrInfo *TII, unsigned SGID,
1234 bool NeedsCache = false)
1235 : InstructionRule(TII, SGID, NeedsCache), Number(Number) {}
1236 };
1237
1238 /// Whether or not the SU is exactly the \p Number th MFMA in the chain
1239 /// starting with \p ChainSeed
1240 class IsExactMFMA final : public InstructionRule {
1241 private:
1242 unsigned Number = 1;
1243 SUnit *ChainSeed;
1244
1245 public:
1246 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1247 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1248 if (!SU || !TII->isMFMAorWMMA(MI: *ChainSeed->getInstr()))
1249 return false;
1250
1251 if (Cache->empty()) {
1252 auto *TempSU = ChainSeed;
1253 auto Depth = Number;
1254 while (Depth > 0) {
1255 --Depth;
1256 bool Found = false;
1257 for (auto &Succ : TempSU->Succs) {
1258 if (TII->isMFMAorWMMA(MI: *Succ.getSUnit()->getInstr())) {
1259 TempSU = Succ.getSUnit();
1260 Found = true;
1261 break;
1262 }
1263 }
1264 if (!Found) {
1265 return false;
1266 }
1267 }
1268 Cache->push_back(Elt: TempSU);
1269 }
1270 // If we failed to find the instruction to be placed into the cache, we
1271 // would have already exited.
1272 assert(!Cache->empty());
1273
1274 return (*Cache)[0] == SU;
1275 }
1276
1277 IsExactMFMA(unsigned Number, SUnit *ChainSeed, const SIInstrInfo *TII,
1278 unsigned SGID, bool NeedsCache = false)
1279 : InstructionRule(TII, SGID, NeedsCache), Number(Number),
1280 ChainSeed(ChainSeed) {}
1281 };
1282
  // Whether the instruction occurs after the first TRANS instruction. This
  // implies the instruction cannot be a predecessor of the first TRANS
  // instruction.
1286 class OccursAfterExp final : public InstructionRule {
1287 public:
1288 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1289 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1290
1291 auto *DAG = SyncPipe[0].DAG;
1292 if (Cache->empty()) {
1293 for (auto &SU : DAG->SUnits)
1294 if (TII->isTRANS(Opcode: SU.getInstr()->getOpcode())) {
1295 Cache->push_back(Elt: &SU);
1296 break;
1297 }
1298 if (Cache->empty())
1299 return false;
1300 }
1301
1302 return SU->NodeNum > (*Cache)[0]->NodeNum;
1303 }
1304
1305 OccursAfterExp(const SIInstrInfo *TII, unsigned SGID,
1306 bool NeedsCache = false)
1307 : InstructionRule(TII, SGID, NeedsCache) {}
1308 };
1309
1310public:
1311 bool applyIGLPStrategy(
1312 DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
1313 DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1314 AMDGPU::SchedulingPhase Phase) override;
1315
1316 bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
1317 AMDGPU::SchedulingPhase Phase) override;
1318
1319 MFMAExpInterleaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
1320 : IGLPStrategy(DAG, TII) {
1321 IsBottomUp = false;
1322 }
1323};
1324
1325unsigned MFMAExpInterleaveOpt::TransPipeCount = 0;
1326unsigned MFMAExpInterleaveOpt::MFMAPipeCount = 0;
1327unsigned MFMAExpInterleaveOpt::AddPipeCount = 0;
1328unsigned MFMAExpInterleaveOpt::MFMAEnablement = 0;
1329unsigned MFMAExpInterleaveOpt::ExpRequirement = 0;
1330unsigned MFMAExpInterleaveOpt::MFMAChains = 0;
1331unsigned MFMAExpInterleaveOpt::MFMAChainLength = 0;
1332bool MFMAExpInterleaveOpt::HasCvt = false;
1333bool MFMAExpInterleaveOpt::HasChainBetweenCvt = false;
1334std::optional<unsigned> MFMAExpInterleaveOpt::FirstPipeDSR = std::nullopt;
1335
1336bool MFMAExpInterleaveOpt::analyzeDAG(const SIInstrInfo *TII) {
1337 SmallVector<SUnit *, 10> ExpPipeCands;
1338 SmallVector<SUnit *, 10> MFMAPipeCands;
1339 SmallVector<SUnit *, 10> MFMAPipeSUs;
1340 SmallVector<SUnit *, 10> PackSUs;
1341 SmallVector<SUnit *, 10> CvtSUs;
1342
1343 auto isBitPack = [](unsigned Opc) {
1344 return Opc == AMDGPU::V_PACK_B32_F16_e64 || Opc == AMDGPU::V_PERM_B32_e64;
1345 };
1346
1347 auto isCvt = [](unsigned Opc) {
1348 return Opc == AMDGPU::V_CVT_F16_F32_e32 || Opc == AMDGPU::V_CVT_I32_F32_e32;
1349 };
1350
1351 auto isAdd = [](unsigned Opc) { return Opc == AMDGPU::V_ADD_F32_e32; };
1352
1353 AddPipeCount = 0;
1354 for (SUnit &SU : DAG->SUnits) {
1355 auto Opc = SU.getInstr()->getOpcode();
1356 if (TII->isTRANS(Opcode: Opc)) {
      // Avoid counting a potential bonus V_EXP which all the MFMAs depend on
1358 if (SU.Succs.size() >= 7)
1359 continue;
1360 for (auto &Succ : SU.Succs) {
1361 if (Succ.getSUnit()->Succs.size() >= 7)
1362 continue;
1363 }
1364 ExpPipeCands.push_back(Elt: &SU);
1365 }
1366
1367 if (TII->isMFMAorWMMA(MI: *SU.getInstr()))
1368 MFMAPipeCands.push_back(Elt: &SU);
1369
1370 if (isBitPack(Opc))
1371 PackSUs.push_back(Elt: &SU);
1372
1373 if (isCvt(Opc))
1374 CvtSUs.push_back(Elt: &SU);
1375
1376 if (isAdd(Opc))
1377 ++AddPipeCount;
1378 }
1379
1380 if (!(PackSUs.size() && MFMAPipeCands.size() && ExpPipeCands.size()))
1381 return false;
1382
1383 TransPipeCount = 0;
1384
1385 std::optional<SUnit *> TempMFMA;
1386 std::optional<SUnit *> TempExp;
1387 // Count the number of EXPs that reach an MFMA
1388 for (auto &PredSU : ExpPipeCands) {
1389 for (auto &SuccSU : MFMAPipeCands) {
1390 if (DAG->IsReachable(SU: SuccSU, TargetSU: PredSU)) {
1391 if (!TempExp) {
1392 TempExp = PredSU;
1393 TempMFMA = SuccSU;
1394 }
1395 MFMAPipeSUs.push_back(Elt: SuccSU);
1396 ++TransPipeCount;
1397 break;
1398 }
1399 }
1400 }
1401
1402 if (!(TempExp && TempMFMA))
1403 return false;
1404
1405 HasChainBetweenCvt = none_of(Range&: (*TempExp)->Succs, P: [&isCvt](SDep &Succ) {
1406 return isCvt(Succ.getSUnit()->getInstr()->getOpcode());
1407 });
1408
1409 // Count the number of MFMAs that are reached by an EXP
1410 for (auto &SuccSU : MFMAPipeCands) {
1411 if (MFMAPipeSUs.size() &&
1412 any_of(Range&: MFMAPipeSUs, P: [&SuccSU](SUnit *PotentialMatch) {
1413 return PotentialMatch->NodeNum == SuccSU->NodeNum;
1414 }))
1415 continue;
1416
1417 for (auto &PredSU : ExpPipeCands) {
1418 if (DAG->IsReachable(SU: SuccSU, TargetSU: PredSU)) {
1419 MFMAPipeSUs.push_back(Elt: SuccSU);
1420 break;
1421 }
1422 }
1423 }
1424
1425 MFMAPipeCount = MFMAPipeSUs.size();
1426
1427 assert(TempExp && TempMFMA);
1428 assert(MFMAPipeCount > 0);
1429
1430 std::optional<SUnit *> TempCvt;
1431 for (auto &SuccSU : CvtSUs) {
1432 if (DAG->IsReachable(SU: SuccSU, TargetSU: *TempExp)) {
1433 TempCvt = SuccSU;
1434 break;
1435 }
1436 }
1437
1438 HasCvt = false;
1439 if (TempCvt.has_value()) {
1440 for (auto &SuccSU : MFMAPipeSUs) {
1441 if (DAG->IsReachable(SU: SuccSU, TargetSU: *TempCvt)) {
1442 HasCvt = true;
1443 break;
1444 }
1445 }
1446 }
1447
1448 MFMAChains = 0;
1449 for (auto &MFMAPipeSU : MFMAPipeSUs) {
1450 if (is_contained(Range&: MFMAChainSeeds, Element: MFMAPipeSU))
1451 continue;
1452 if (none_of(Range&: MFMAPipeSU->Preds, P: [&TII](SDep &Succ) {
1453 return TII->isMFMAorWMMA(MI: *Succ.getSUnit()->getInstr());
1454 })) {
1455 MFMAChainSeeds.push_back(Elt: MFMAPipeSU);
1456 ++MFMAChains;
1457 }
1458 }
1459
1460 if (!MFMAChains)
1461 return false;
1462
1463 for (auto Pred : MFMAChainSeeds[0]->Preds) {
1464 if (TII->isDS(Opcode: Pred.getSUnit()->getInstr()->getOpcode()) &&
1465 Pred.getSUnit()->getInstr()->mayLoad())
1466 FirstPipeDSR = Pred.getSUnit()->NodeNum;
1467 }
1468
1469 MFMAChainLength = MFMAPipeCount / MFMAChains;
1470
1471 // The number of bit pack operations that depend on a single V_EXP
1472 unsigned PackSuccCount =
1473 llvm::count_if(Range&: PackSUs, P: [this, &TempExp](SUnit *VPack) {
1474 return DAG->IsReachable(SU: VPack, TargetSU: *TempExp);
1475 });
1476
1477 // The number of bit pack operations an MFMA depends on
1478 unsigned PackPredCount =
1479 llvm::count_if(Range&: (*TempMFMA)->Preds, P: [&isBitPack](SDep &Pred) {
1480 auto Opc = Pred.getSUnit()->getInstr()->getOpcode();
1481 return isBitPack(Opc);
1482 });
1483
1484 auto *PackPred = llvm::find_if(Range&: (*TempMFMA)->Preds, P: [&isBitPack](SDep &Pred) {
1485 auto Opc = Pred.getSUnit()->getInstr()->getOpcode();
1486 return isBitPack(Opc);
1487 });
1488
1489 if (PackPred == (*TempMFMA)->Preds.end())
1490 return false;
1491
1492 MFMAEnablement = 0;
1493 ExpRequirement = 0;
1494 // How many MFMAs depend on a single bit pack operation
1495 MFMAEnablement =
1496 llvm::count_if(Range&: PackPred->getSUnit()->Succs, P: [&TII](SDep &Succ) {
1497 return TII->isMFMAorWMMA(MI: *Succ.getSUnit()->getInstr());
1498 });
1499
1500 // The number of MFMAs that depend on a single V_EXP
1501 MFMAEnablement *= PackSuccCount;
1502
1503 // The number of V_EXPs required to resolve all dependencies for an MFMA
1504 ExpRequirement =
1505 llvm::count_if(Range&: ExpPipeCands, P: [this, &PackPred](SUnit *ExpBase) {
1506 return DAG->IsReachable(SU: PackPred->getSUnit(), TargetSU: ExpBase);
1507 });
1508
1509 ExpRequirement *= PackPredCount;
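  // Putting it together with illustrative numbers: if each V_EXP feeds 2 bit
  // packs and each bit pack feeds 2 MFMAs, then MFMAEnablement = 4; if each
  // MFMA consumes 2 bit packs and each pack needs 2 V_EXPs, then
  // ExpRequirement = 4.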
1510 return true;
1511}
1512
1513bool MFMAExpInterleaveOpt::shouldApplyStrategy(ScheduleDAGInstrs *DAG,
1514 AMDGPU::SchedulingPhase Phase) {
1515 const GCNSubtarget &ST = DAG->MF.getSubtarget<GCNSubtarget>();
1516 const SIInstrInfo *TII = ST.getInstrInfo();
1517
1518 if (Phase != AMDGPU::SchedulingPhase::PostRA)
1519 MFMAChainSeeds.clear();
1520 if (Phase != AMDGPU::SchedulingPhase::PostRA && !analyzeDAG(TII))
1521 return false;
1522
1523 return true;
1524}
1525
1526bool MFMAExpInterleaveOpt::applyIGLPStrategy(
1527 DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
1528 DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1529 AMDGPU::SchedulingPhase Phase) {
1530
1531 bool IsSmallKernelType =
1532 MFMAEnablement == 2 && ExpRequirement == 4 && TransPipeCount == 32;
1533 bool IsLargeKernelType =
1534 MFMAEnablement == 4 && ExpRequirement == 4 && TransPipeCount == 64;
1535
1536 if (!(IsSmallKernelType || IsLargeKernelType))
1537 return false;
1538
1539 const GCNSubtarget &ST = DAG->MF.getSubtarget<GCNSubtarget>();
1540 const SIInstrInfo *TII = ST.getInstrInfo();
1541
1542 unsigned PipelineSyncID = 0;
1543 SchedGroup *SG = nullptr;
1544
1545 unsigned MFMAChain = 0;
1546 unsigned PositionInChain = 0;
1547 unsigned CurrMFMAForTransPosition = 0;
1548
1549 auto incrementTransPosition = [&MFMAChain, &PositionInChain,
1550 &CurrMFMAForTransPosition]() {
1551 CurrMFMAForTransPosition += MFMAEnablement;
1552 PositionInChain = (CurrMFMAForTransPosition / MFMAChains);
1553 MFMAChain = CurrMFMAForTransPosition % MFMAChains;
1554 };
1555
1556 auto getNextTransPositionInChain = [&CurrMFMAForTransPosition]() {
1557 auto TempMFMAForTrans = CurrMFMAForTransPosition + MFMAEnablement;
1558 return (TempMFMAForTrans / MFMAChains);
1559 };
1560
1561 auto getNextTransMFMAChain = [&CurrMFMAForTransPosition]() {
1562 auto TempMFMAForTrans = CurrMFMAForTransPosition + MFMAEnablement;
1563 return TempMFMAForTrans % MFMAChains;
1564 };
1565
1566 unsigned CurrMFMAPosition = 0;
1567 unsigned MFMAChainForMFMA = 0;
1568 unsigned PositionInChainForMFMA = 0;
1569
1570 auto incrementMFMAPosition = [&CurrMFMAPosition, &MFMAChainForMFMA,
1571 &PositionInChainForMFMA]() {
1572 ++CurrMFMAPosition;
1573 MFMAChainForMFMA = CurrMFMAPosition % MFMAChains;
1574 PositionInChainForMFMA = CurrMFMAPosition / MFMAChains;
1575 };
1576
1577 bool IsPostRA = Phase == AMDGPU::SchedulingPhase::PostRA;
1578 assert(IsPostRA || MFMAChainSeeds.size() == MFMAChains);
1579
1580 bool UsesFMA = IsSmallKernelType || !IsPostRA;
1581 bool UsesDSRead = IsLargeKernelType && !IsPostRA && FirstPipeDSR;
1582 bool UsesCvt = HasCvt && (IsSmallKernelType || !IsPostRA);
1583 bool UsesVALU = IsSmallKernelType;
1584
1585 // PHASE 1: "Prefetch"
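  // The groups built for this phase are, in order (optional groups gated on
  // UsesFMA / UsesDSRead / UsesCvt): two FMA groups of size ExpRequirement, a
  // DS_READ group of size 2, an EXP group of size ExpRequirement, then
  // ExpRequirement interleaved rounds of {CVT, FMA, EXP}, followed by the
  // "extra" EXP that enables all MFMAs.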
1586 if (UsesFMA) {
1587 // First Round FMA
1588 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1589 Args: SchedGroupMask::VALU, Args&: ExpRequirement, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
1590 if (!IsPostRA && MFMAChains) {
1591 SG->addRule(NewRule: std::make_shared<EnablesNthMFMAInChain>(
1592 args&: PositionInChain, args&: MFMAChainSeeds[MFMAChain], args&: TII, args: SG->getSGID(),
1593 args: true));
1594 } else
1595 SG->addRule(
1596 NewRule: std::make_shared<EnablesNthMFMA>(args: 1, args&: TII, args: SG->getSGID(), args: true));
1597 SG->addRule(NewRule: std::make_shared<IsFMA>(args&: TII, args: SG->getSGID()));
1598 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
1599
1600 // Second Round FMA
1601 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1602 Args: SchedGroupMask::VALU, Args&: ExpRequirement, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
1603 if (!IsPostRA && MFMAChains) {
1604 SG->addRule(NewRule: std::make_shared<EnablesNthMFMAInChain>(
1605 args: getNextTransPositionInChain(),
1606 args&: MFMAChainSeeds[getNextTransMFMAChain()], args&: TII, args: SG->getSGID(), args: true));
1607 } else
1608 SG->addRule(NewRule: std::make_shared<EnablesNthMFMA>(args: MFMAEnablement + 1, args&: TII,
1609 args: SG->getSGID(), args: true));
1610 SG->addRule(NewRule: std::make_shared<IsFMA>(args&: TII, args: SG->getSGID()));
1611 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
1612 }
1613
1614 if (UsesDSRead) {
1615 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1616 Args: SchedGroupMask::DS_READ, Args: 2, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
1617 SG->addRule(NewRule: std::make_shared<OccursAtOrAfterNode>(args&: *FirstPipeDSR, args&: TII,
1618 args: SG->getSGID()));
1619 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
1620 }
1621
1622 // First Round EXP
1623 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1624 Args: SchedGroupMask::TRANS, Args&: ExpRequirement, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
1625 if (!IsPostRA && MFMAChains)
1626 SG->addRule(NewRule: std::make_shared<EnablesNthMFMAInChain>(
1627 args&: PositionInChain, args&: MFMAChainSeeds[MFMAChain], args&: TII, args: SG->getSGID(), args: true));
1628 else
1629 SG->addRule(NewRule: std::make_shared<EnablesNthMFMA>(args: 1, args&: TII, args: SG->getSGID(), args: true));
1630 SG->addRule(NewRule: std::make_shared<IsPipeExp>(args&: TII, args: SG->getSGID(), args: true));
1631 SG->addRule(NewRule: std::make_shared<LessThanNSuccs>(args: 8, args&: TII, args: SG->getSGID(),
1632 args&: HasChainBetweenCvt));
1633 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
1634
1635 incrementTransPosition();
1636
1637 // First Round CVT, Third Round FMA, Second Round EXP; interleaved
1638 for (unsigned I = 0; I < ExpRequirement; I++) {
1639 // First Round CVT
1640 if (UsesCvt) {
1641 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1642 Args: SchedGroupMask::VALU, Args: 1, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
1643 SG->addRule(NewRule: std::make_shared<IsCvt>(args&: TII, args: SG->getSGID()));
1644 if (HasChainBetweenCvt)
1645 SG->addRule(NewRule: std::make_shared<IsReachableFromPrevNthGroup>(
1646 args: 1 + (2 + UsesFMA) * I, args&: TII, args: SG->getSGID()));
1647 else
1648 SG->addRule(NewRule: std::make_shared<IsSuccOfPrevNthGroup>(
1649 args: 1 + (2 + UsesFMA) * I, args&: TII, args: SG->getSGID()));
1650 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
1651 }
1652
1653 // Third Round FMA
1654 if (UsesFMA) {
1655 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1656 Args: SchedGroupMask::VALU, Args: 1, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
1657 if (!IsPostRA && MFMAChains) {
1658 SG->addRule(NewRule: std::make_shared<EnablesNthMFMAInChain>(
1659 args: getNextTransPositionInChain(),
1660 args&: MFMAChainSeeds[getNextTransMFMAChain()], args&: TII, args: SG->getSGID(), args: true));
1661 } else
1662 SG->addRule(NewRule: std::make_shared<EnablesNthMFMA>(args: 2 * MFMAEnablement + 1,
1663 args&: TII, args: SG->getSGID(), args: true));
1664 SG->addRule(NewRule: std::make_shared<IsFMA>(args&: TII, args: SG->getSGID()));
1665 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
1666 }
1667
1668 // Second Round EXP
1669 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1670 Args: SchedGroupMask::TRANS, Args: 1, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
1671 if (!IsPostRA && MFMAChains)
1672 SG->addRule(NewRule: std::make_shared<EnablesNthMFMAInChain>(
1673 args&: PositionInChain, args&: MFMAChainSeeds[MFMAChain], args&: TII, args: SG->getSGID(),
1674 args: true));
1675 else
1676 SG->addRule(NewRule: std::make_shared<EnablesNthMFMA>(args: MFMAEnablement + 1, args&: TII,
1677 args: SG->getSGID(), args: true));
1678 SG->addRule(NewRule: std::make_shared<IsPipeExp>(args&: TII, args: SG->getSGID(), args: true));
1679 SG->addRule(NewRule: std::make_shared<LessThanNSuccs>(args: 8, args&: TII, args: SG->getSGID(),
1680 args&: HasChainBetweenCvt));
1681 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
1682 }
1683
1684 // The "extra" EXP which enables all MFMA
1685 // TODO: UsesExtraExp
1686 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1687 Args: SchedGroupMask::TRANS, Args: 1, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
1688 SG->addRule(NewRule: std::make_shared<IsPipeExp>(args&: TII, args: SG->getSGID(), args: true));
1689 SG->addRule(NewRule: std::make_shared<GreaterThanOrEqualToNSuccs>(
1690 args: 8, args&: TII, args: SG->getSGID(), args&: HasChainBetweenCvt));
1691 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
1692
1693 // PHASE 2: Main Interleave Loop
1694
1695 // The number of MFMAs per iteration
1696 unsigned MFMARatio =
1697 MFMAEnablement > ExpRequirement ? MFMAEnablement / ExpRequirement : 1;
1698 // The number of Exps per iteration
1699 unsigned ExpRatio =
1700 MFMAEnablement > ExpRequirement ? 1 : ExpRequirement / MFMAEnablement;
  // The remaining EXPs
1702 unsigned RemainingExp = TransPipeCount > (2 * ExpRequirement)
1703 ? TransPipeCount - (2 * ExpRequirement)
1704 : 0;
1705 unsigned ExpLoopCount = RemainingExp / ExpRatio;
  // The number of MFMAs executed inside the interleave loop
1707 unsigned MFMAInLoop = MFMAPipeCount > (MFMAEnablement * 2)
1708 ? MFMAPipeCount - (MFMAEnablement * 2)
1709 : 0;
1710 unsigned MFMALoopCount = MFMAInLoop / MFMARatio;
1711 unsigned VALUOps =
1712 AddPipeCount < MFMAPipeCount ? 1 : AddPipeCount / MFMAPipeCount;
1713 unsigned LoopSize = std::min(a: ExpLoopCount, b: MFMALoopCount);
1714
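  // Each iteration of the main loop emits one MFMA group (MFMARatio wide),
  // optional VALU-add and DS_READ prefetch groups, and ExpRatio rounds of
  // interleaved CVT/FMA/EXP groups.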
1715 for (unsigned I = 0; I < LoopSize; I++) {
1716 if (!(I * ExpRatio % ExpRequirement))
1717 incrementTransPosition();
1718
1719 // Round N MFMA
1720 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1721 Args: SchedGroupMask::MFMA, Args&: MFMARatio, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
1722 if (!IsPostRA && MFMAChains)
1723 SG->addRule(NewRule: std::make_shared<IsExactMFMA>(
1724 args&: PositionInChainForMFMA, args&: MFMAChainSeeds[MFMAChainForMFMA], args&: TII,
1725 args: SG->getSGID(), args: true));
1726 else
1727 SG->addRule(NewRule: std::make_shared<OccursAfterExp>(args&: TII, args: SG->getSGID(), args: true));
1728 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
1729 incrementMFMAPosition();
1730
1731 if (UsesVALU) {
1732 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1733 Args: SchedGroupMask::VALU, Args&: VALUOps, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
1734 SG->addRule(NewRule: std::make_shared<IsPipeAdd>(args&: TII, args: SG->getSGID()));
1735 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
1736 }
1737
1738 if (UsesDSRead && !(I % 4)) {
1739 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1740 Args: SchedGroupMask::DS_READ, Args: 2, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
1741 SG->addRule(NewRule: std::make_shared<OccursAtOrAfterNode>(args&: *FirstPipeDSR, args&: TII,
1742 args: SG->getSGID()));
1743 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
1744 }
1745
1746 // CVT, EXP, FMA Interleaving
1747 for (unsigned J = 0; J < ExpRatio; J++) {
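      // Offsets used below to tie each CVT to a previous SchedGroup that
      // feeds it; the lookback distance grows with the number of iterations
      // already emitted and is capped by MaxMFMAOffset.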
1748 auto MFMAOffset = (1 + UsesVALU) * MFMARatio * (I + 1);
1749 auto MaxMFMAOffset =
1750 (1 + UsesVALU) * ExpRequirement * MFMARatio / ExpRatio;
1751
1752 // Round N + 1 CVT
1753 if (UsesCvt) {
1754 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1755 Args: SchedGroupMask::VALU, Args: 1, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
1756 SG->addRule(NewRule: std::make_shared<IsCvt>(args&: TII, args: SG->getSGID()));
1757 auto BaseDiff = (2 + UsesFMA) * (ExpRequirement - 1) + 1;
1758 auto DSROffset = I / 4 + 1;
1759 auto MaxDSROffset = MaxMFMAOffset / 4;
1760 // TODO: UsesExtraExp
1761 auto ExpOffset = I * ExpRatio + J >= ExpRequirement ? 0 : 1;
1762 auto CurrentOffset = UsesDSRead * std::min(a: MaxDSROffset, b: DSROffset) +
1763 std::min(a: MaxMFMAOffset, b: MFMAOffset) + BaseDiff +
1764 ExpOffset;
1765 if (HasChainBetweenCvt)
1766 SG->addRule(NewRule: std::make_shared<IsReachableFromPrevNthGroup>(
1767 args&: CurrentOffset, args&: TII, args: SG->getSGID()));
1768 else
1769 SG->addRule(NewRule: std::make_shared<IsSuccOfPrevNthGroup>(args&: CurrentOffset, args&: TII,
1770 args: SG->getSGID()));
1771 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
1772 }
1773
1774 // Round N + 3 FMA
1775 if (UsesFMA) {
1776 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1777 Args: SchedGroupMask::VALU, Args: 1, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
1778 if (!IsPostRA && MFMAChains)
1779 SG->addRule(NewRule: std::make_shared<EnablesNthMFMAInChain>(
1780 args: getNextTransPositionInChain(),
1781 args&: MFMAChainSeeds[getNextTransMFMAChain()], args&: TII, args: SG->getSGID(),
1782 args: true));
1783 else
1784 SG->addRule(NewRule: std::make_shared<EnablesNthMFMA>(
1785 args: (((I * ExpRatio + J) / ExpRequirement) + 3) * MFMAEnablement + 1,
1786 args&: TII, args: SG->getSGID(), args: true));
1787 SG->addRule(NewRule: std::make_shared<IsFMA>(args&: TII, args: SG->getSGID()));
1788 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
1789 }
1790
1791 // Round N + 2 Exp
1792 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1793 Args: SchedGroupMask::TRANS, Args: 1, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
1794 if (!IsPostRA && MFMAChains)
1795 SG->addRule(NewRule: std::make_shared<EnablesNthMFMAInChain>(
1796 args&: PositionInChain, args&: MFMAChainSeeds[MFMAChain], args&: TII, args: SG->getSGID(),
1797 args: true));
1798 else
1799 SG->addRule(NewRule: std::make_shared<EnablesNthMFMA>(
1800 args: (((I * ExpRatio + J) / ExpRequirement) + 2) * MFMAEnablement + 1,
1801 args&: TII, args: SG->getSGID(), args: true));
1802 SG->addRule(NewRule: std::make_shared<IsPipeExp>(args&: TII, args: SG->getSGID(), args: true));
1803 SG->addRule(NewRule: std::make_shared<LessThanNSuccs>(args: 8, args&: TII, args: SG->getSGID(),
1804 args&: HasChainBetweenCvt));
1805 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
1806 }
1807 }
1808
1809 // PHASE 3: Remaining MFMAs
1810 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1811 Args: SchedGroupMask::MFMA, Args: MFMAEnablement * 2, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
1812 SG->addRule(NewRule: std::make_shared<OccursAfterExp>(args&: TII, args: SG->getSGID(), args: true));
1813 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
1814 return true;
1815}
1816
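// A simpler interleaving strategy: alternate single-instruction TRANS and
// MFMA SchedGroups across the region with no additional instruction rules.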
1817class MFMAExpSimpleInterleaveOpt final : public IGLPStrategy {
1818public:
1819 bool applyIGLPStrategy(
1820 DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
1821 DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1822 AMDGPU::SchedulingPhase Phase) override;
1823
1824 bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
1825 AMDGPU::SchedulingPhase Phase) override {
1826 return true;
1827 }
1828
1829 MFMAExpSimpleInterleaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
1830 : IGLPStrategy(DAG, TII) {
1831 IsBottomUp = true;
1832 }
1833};
1834
1835bool MFMAExpSimpleInterleaveOpt::applyIGLPStrategy(
1836 DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
1837 DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1838 AMDGPU::SchedulingPhase Phase) {
1839 // Count the number of MFMA instructions.
1840 unsigned MFMACount = 0;
1841 for (const MachineInstr &I : *DAG)
1842 if (TII->isMFMAorWMMA(MI: I))
1843 ++MFMACount;
1844
1845 const unsigned PipelineSyncID = 0;
1846 for (unsigned I = 0; I < MFMACount * 3; ++I) {
1847 SchedGroup *SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1848 Args: SchedGroupMask::TRANS, Args: 1, Args: PipelineSyncID, Args&: DAG, Args&: TII);
1849 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
1850
1851 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1852 Args: SchedGroupMask::MFMA, Args: 1, Args: PipelineSyncID, Args&: DAG, Args&: TII);
1853 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
1854 }
1855
1856 return true;
1857}
1858
1859class MFMASmallGemmSingleWaveOpt final : public IGLPStrategy {
1860private:
  // Whether the DS_READ is a predecessor of the first four MFMAs in the region.
1862 class EnablesInitialMFMA final : public InstructionRule {
1863 public:
1864 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1865 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1866 if (!SyncPipe.size())
1867 return false;
1868 int MFMAsFound = 0;
1869 if (!Cache->size()) {
1870 for (auto &Elt : SyncPipe[0].DAG->SUnits) {
1871 if (TII->isMFMAorWMMA(MI: *Elt.getInstr())) {
1872 ++MFMAsFound;
1873 if (MFMAsFound > 4)
1874 break;
1875 Cache->push_back(Elt: &Elt);
1876 }
1877 }
1878 }
1879
1880 auto *DAG = SyncPipe[0].DAG;
1881 for (auto &Elt : *Cache) {
1882 if (DAG->IsReachable(SU: Elt, TargetSU: const_cast<SUnit *>(SU)))
1883 return true;
1884 }
1885 return false;
1886 }
1887
1888 EnablesInitialMFMA(const SIInstrInfo *TII, unsigned SGID,
1889 bool NeedsCache = false)
1890 : InstructionRule(TII, SGID, NeedsCache) {}
1891 };
1892
1893 // Whether the MI is a V_PERM and is a predecessor of a common DS_WRITE
1894 class IsPermForDSW final : public InstructionRule {
1895 public:
1896 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1897 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1898 auto *MI = SU->getInstr();
1899 if (MI->getOpcode() != AMDGPU::V_PERM_B32_e64)
1900 return false;
1901
1902 bool FitsInGroup = false;
1903 // Does the VALU have a DS_WRITE successor
1904 if (!Collection.size()) {
1905 for (auto &Succ : SU->Succs) {
1906 SUnit *SuccUnit = Succ.getSUnit();
1907 if (TII->isDS(MI: *SuccUnit->getInstr()) &&
1908 SuccUnit->getInstr()->mayStore()) {
1909 Cache->push_back(Elt: SuccUnit);
1910 FitsInGroup = true;
1911 }
1912 }
1913 return FitsInGroup;
1914 }
1915
      // Does the VALU have a DS_WRITE successor that is the same as another
      // VALU already in the group? The V_PERMs will all share one DS_WRITE
      // successor.
1918 return llvm::any_of(Range&: *Cache, P: [&SU](SUnit *Elt) {
1919 return llvm::any_of(Range: SU->Succs, P: [&Elt](const SDep &ThisSucc) {
1920 return ThisSucc.getSUnit() == Elt;
1921 });
1922 });
1923 }
1924
1925 IsPermForDSW(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1926 : InstructionRule(TII, SGID, NeedsCache) {}
1927 };
1928
  // Whether the SU is a successor of any element in the previous SchedGroup.
1930 class IsSuccOfPrevGroup final : public InstructionRule {
1931 public:
1932 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1933 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1934 SchedGroup *OtherGroup = nullptr;
1935 for (auto &PipeSG : SyncPipe) {
1936 if ((unsigned)PipeSG.getSGID() == SGID - 1) {
1937 OtherGroup = &PipeSG;
1938 }
1939 }
1940
1941 if (!OtherGroup)
1942 return false;
1943 if (!OtherGroup->Collection.size())
1944 return true;
1945
      // Does the previous VALU have this DS_WRITE as a successor?
1947 return any_of(Range&: OtherGroup->Collection, P: [&SU](SUnit *Elt) {
1948 return any_of(Range&: Elt->Succs,
1949 P: [&SU](SDep &Succ) { return Succ.getSUnit() == SU; });
1950 });
1951 }
1952 IsSuccOfPrevGroup(const SIInstrInfo *TII, unsigned SGID,
1953 bool NeedsCache = false)
1954 : InstructionRule(TII, SGID, NeedsCache) {}
1955 };
1956
  // Whether the combined load width of the group stays within 128 bits.
1958 class VMEMSize final : public InstructionRule {
1959 public:
1960 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1961 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1962 auto *MI = SU->getInstr();
1963 if (MI->getOpcode() == TargetOpcode::BUNDLE)
1964 return false;
1965 if (!Collection.size())
1966 return true;
1967
1968 int NumBits = 0;
1969
1970 auto TRI = TII->getRegisterInfo();
1971 auto &MRI = MI->getParent()->getParent()->getRegInfo();
1972 for (auto &Elt : Collection) {
1973 auto Op = Elt->getInstr()->getOperand(i: 0);
1974 auto Size =
1975 TRI.getRegSizeInBits(RC: *TRI.getRegClassForOperandReg(MRI, MO: Op));
1976 NumBits += Size;
1977 }
1978
1979 if (NumBits < 128) {
1980 assert(TII->isVMEM(*MI) && MI->mayLoad());
1981 if (NumBits + TRI.getRegSizeInBits(RC: *TRI.getRegClassForOperandReg(
1982 MRI, MO: MI->getOperand(i: 0))) <=
1983 128)
1984 return true;
1985 }
1986
1987 return false;
1988 }
1989
1990 VMEMSize(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1991 : InstructionRule(TII, SGID, NeedsCache) {}
1992 };
1993
1994 /// Whether the SU shares a V_PERM predecessor with any SU in the SchedGroup
1995 /// that is \p Distance steps away
1996 class SharesPredWithPrevNthGroup final : public InstructionRule {
1997 private:
1998 unsigned Distance = 1;
1999
2000 public:
2001 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
2002 SmallVectorImpl<SchedGroup> &SyncPipe) override {
2003 SchedGroup *OtherGroup = nullptr;
2004 if (!SyncPipe.size())
2005 return false;
2006
2007 if (!Cache->size()) {
2008
2009 for (auto &PipeSG : SyncPipe) {
2010 if ((unsigned)PipeSG.getSGID() == SGID - Distance) {
2011 OtherGroup = &PipeSG;
2012 }
2013 }
2014
2015 if (!OtherGroup)
2016 return false;
2017 if (!OtherGroup->Collection.size())
2018 return true;
2019
2020 for (auto &OtherEle : OtherGroup->Collection) {
2021 for (auto &Pred : OtherEle->Preds) {
2022 if (Pred.getSUnit()->getInstr()->getOpcode() ==
2023 AMDGPU::V_PERM_B32_e64)
2024 Cache->push_back(Elt: Pred.getSUnit());
2025 }
2026 }
2027
2028 // If the other group has no PERM preds, then this group won't share any
2029 if (!Cache->size())
2030 return false;
2031 }
2032
2033 auto *DAG = SyncPipe[0].DAG;
2034 // Does the previous DS_WRITE share a V_PERM predecessor with this
2035 // VMEM_READ
2036 return llvm::any_of(Range&: *Cache, P: [&SU, &DAG](SUnit *Elt) {
2037 return DAG->IsReachable(SU: const_cast<SUnit *>(SU), TargetSU: Elt);
2038 });
2039 }
2040 SharesPredWithPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
2041 unsigned SGID, bool NeedsCache = false)
2042 : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
2043 };
2044
2045public:
2046 bool applyIGLPStrategy(
2047 DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
2048 DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
2049 AMDGPU::SchedulingPhase Phase) override;
2050
2051 bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
2052 AMDGPU::SchedulingPhase Phase) override {
2053 return true;
2054 }
2055
2056 MFMASmallGemmSingleWaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
2057 : IGLPStrategy(DAG, TII) {
2058 IsBottomUp = false;
2059 }
2060};
2061
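// The DS_WRITE counts are computed once during the initial scheduling phase
// and reused when the strategy is re-applied in later phases.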
2062static unsigned DSWCount = 0;
2063static unsigned DSWWithPermCount = 0;
2064static unsigned DSWWithSharedVMEMCount = 0;
2065
2066bool MFMASmallGemmSingleWaveOpt::applyIGLPStrategy(
2067 DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
2068 DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
2069 AMDGPU::SchedulingPhase Phase) {
2070 unsigned MFMACount = 0;
2071 unsigned DSRCount = 0;
2072
2073 bool IsInitial = Phase == AMDGPU::SchedulingPhase::Initial;
2074
2075 assert((!IsInitial || (DSWCount == 0 && DSWWithPermCount == 0 &&
2076 DSWWithSharedVMEMCount == 0)) &&
2077 "DSWCounters should be zero in pre-RA scheduling!");
2078 SmallVector<SUnit *, 6> DSWithPerms;
2079 for (auto &SU : DAG->SUnits) {
2080 auto *I = SU.getInstr();
2081 if (TII->isMFMAorWMMA(MI: *I))
2082 ++MFMACount;
2083 else if (TII->isDS(MI: *I)) {
2084 if (I->mayLoad())
2085 ++DSRCount;
2086 else if (I->mayStore() && IsInitial) {
2087 ++DSWCount;
2088 for (auto Pred : SU.Preds) {
2089 if (Pred.getSUnit()->getInstr()->getOpcode() ==
2090 AMDGPU::V_PERM_B32_e64) {
2091 DSWithPerms.push_back(Elt: &SU);
2092 break;
2093 }
2094 }
2095 }
2096 }
2097 }
2098
2099 if (IsInitial) {
2100 DSWWithPermCount = DSWithPerms.size();
2101 auto *I = DSWithPerms.begin();
2102 auto *E = DSWithPerms.end();
2103
    // Get the count of DS_WRITEs with V_PERM predecessors which
    // have loop-carried dependencies (WAR) on the same VMEM_READs.
    // We consider partial overlap as a miss -- in other words,
    // for a given DS_WRITE, we only consider another DS_WRITE as matching
    // if there is a corresponding (in terms of the VMEM_READ it uses) V_PERM
    // pred for every V_PERM pred of this DS_WRITE.
2110 DenseMap<MachineInstr *, SUnit *> VMEMLookup;
2111 SmallVector<SUnit *, 6> Counted;
2112 for (; I != E; I++) {
2113 SUnit *Cand = nullptr;
2114 bool MissedAny = false;
2115 for (auto &Pred : (*I)->Preds) {
2116 if (Pred.getSUnit()->getInstr()->getOpcode() != AMDGPU::V_PERM_B32_e64)
2117 continue;
2118
2119 if (Cand && llvm::is_contained(Range&: Counted, Element: Cand))
2120 break;
2121
2122 for (auto &Succ : Pred.getSUnit()->Succs) {
2123 auto *MI = Succ.getSUnit()->getInstr();
2124 if (!TII->isVMEM(MI: *MI) || !MI->mayLoad())
2125 continue;
2126
2127 if (MissedAny || !VMEMLookup.size()) {
2128 MissedAny = true;
2129 VMEMLookup[MI] = *I;
2130 continue;
2131 }
2132
2133 auto [It, Inserted] = VMEMLookup.try_emplace(Key: MI, Args&: *I);
2134 if (Inserted) {
2135 MissedAny = true;
2136 continue;
2137 }
2138
2139 Cand = It->second;
2140 if (llvm::is_contained(Range&: Counted, Element: Cand)) {
2141 MissedAny = true;
2142 break;
2143 }
2144 }
2145 }
2146 if (!MissedAny && Cand) {
2147 DSWWithSharedVMEMCount += 2;
2148 Counted.push_back(Elt: Cand);
2149 Counted.push_back(Elt: *I);
2150 }
2151 }
2152 }
2153
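  // Each matched pair contributes two DS_WRITEs that both have V_PERM
  // predecessors, so the shared-VMEM count can never exceed the V_PERM count.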
2154 assert(DSWWithSharedVMEMCount <= DSWWithPermCount);
2155 SchedGroup *SG;
2156 unsigned PipelineSyncID = 0;
  // For kernels with V_PERMs, there are enough VALU instructions to mix in
  // between MFMAs.
2158 if (DSWWithPermCount) {
2159 for (unsigned I = 0; I < MFMACount; I++) {
2160 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2161 Args: SchedGroupMask::MFMA, Args: 1, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
2162 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
2163
2164 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2165 Args: SchedGroupMask::VALU, Args: 2, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
2166 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
2167 }
2168 }
2169
2170 PipelineSyncID = 1;
  // Phase 1: Break up DS_READ and MFMA clusters.
  // First, use DS_READs to make the initial MFMAs ready, then interleave MFMAs
  // with DS_READ prefetches.
2174
  // Make the initial MFMAs ready.
2176 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2177 Args: SchedGroupMask::DS_READ, Args: 4, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
2178 SG->addRule(NewRule: std::make_shared<EnablesInitialMFMA>(args&: TII, args: SG->getSGID(), args: true));
2179 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
2180
2181 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2182 Args: SchedGroupMask::MFMA, Args: 1, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
2183 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
2184
2185 // Interleave MFMA with DS_READ prefetch
2186 for (unsigned I = 0; I < DSRCount - 4; ++I) {
2187 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2188 Args: SchedGroupMask::DS_READ, Args: 1, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
2189 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
2190
2191 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2192 Args: SchedGroupMask::MFMA, Args: 1, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
2193 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
2194 }
2195
  // Phase 2a: Loop-carried dependency with V_PERM.
  // Schedule V_PERM & DS_WRITE as closely as possible to the VMEM_READ they
2198 // depend on. Interleave MFMA to keep XDL unit busy throughout.
2199 for (unsigned I = 0; I < DSWWithPermCount - DSWWithSharedVMEMCount; ++I) {
2200 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2201 Args: SchedGroupMask::VALU, Args: 4, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
2202 SG->addRule(NewRule: std::make_shared<IsPermForDSW>(args&: TII, args: SG->getSGID(), args: true));
2203 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
2204
2205 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2206 Args: SchedGroupMask::DS_WRITE, Args: 1, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
2207 SG->addRule(NewRule: std::make_shared<IsSuccOfPrevGroup>(args&: TII, args: SG->getSGID()));
2208 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
2209
2210 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2211 Args: SchedGroupMask::VMEM_READ, Args: 4, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
2212 SG->addRule(NewRule: std::make_shared<SharesPredWithPrevNthGroup>(
2213 args: 1, args&: TII, args: SG->getSGID(), args: true));
2214 SG->addRule(NewRule: std::make_shared<VMEMSize>(args&: TII, args: SG->getSGID()));
2215 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
2216
2217 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2218 Args: SchedGroupMask::MFMA, Args: 1, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
2219 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
2220
2221 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2222 Args: SchedGroupMask::VMEM_READ, Args: 4, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
2223 SG->addRule(NewRule: std::make_shared<SharesPredWithPrevNthGroup>(
2224 args: 3, args&: TII, args: SG->getSGID(), args: true));
2225 SG->addRule(NewRule: std::make_shared<VMEMSize>(args&: TII, args: SG->getSGID()));
2226 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
2227
2228 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2229 Args: SchedGroupMask::MFMA, Args: 1, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
2230 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
2231 }
2232
  // Phase 2b: Loop-carried dependency without V_PERM.
  // Schedule DS_WRITEs as closely as possible to the VMEM_READs they depend on.
2235 // Interleave MFMA to keep XDL unit busy throughout.
2236 for (unsigned I = 0; I < DSWCount - DSWWithPermCount; I++) {
2237 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2238 Args: SchedGroupMask::DS_WRITE, Args: 1, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
2239 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
2240
2241 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2242 Args: SchedGroupMask::VMEM_READ, Args: 4, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
2243 SG->addRule(NewRule: std::make_shared<VMEMSize>(args&: TII, args: SG->getSGID()));
2244 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
2245
2246 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2247 Args: SchedGroupMask::MFMA, Args: 1, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
2248 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
2249 }
2250
  // Phase 2c: Loop-carried dependency with V_PERM, where the VMEM_READs are
  // ultimately used by two DS_WRITEs.
  // Schedule V_PERM & DS_WRITE as closely as possible to the VMEM_READ they
  // depend on. Interleave MFMA to keep XDL unit busy throughout.
2255
2256 for (unsigned I = 0; I < DSWWithSharedVMEMCount; ++I) {
2257 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2258 Args: SchedGroupMask::VALU, Args: 4, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
2259 SG->addRule(NewRule: std::make_shared<IsPermForDSW>(args&: TII, args: SG->getSGID(), args: true));
2260 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
2261
2262 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2263 Args: SchedGroupMask::DS_WRITE, Args: 1, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
2264 SG->addRule(NewRule: std::make_shared<IsSuccOfPrevGroup>(args&: TII, args: SG->getSGID()));
2265 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
2266
2267 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2268 Args: SchedGroupMask::MFMA, Args: 1, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
2269 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
2270
2271 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2272 Args: SchedGroupMask::VALU, Args: 4, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
2273 SG->addRule(NewRule: std::make_shared<IsPermForDSW>(args&: TII, args: SG->getSGID(), args: true));
2274 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
2275
2276 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2277 Args: SchedGroupMask::DS_WRITE, Args: 1, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
2278 SG->addRule(NewRule: std::make_shared<IsSuccOfPrevGroup>(args&: TII, args: SG->getSGID()));
2279 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
2280
2281 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2282 Args: SchedGroupMask::MFMA, Args: 1, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
2283 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
2284
2285 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2286 Args: SchedGroupMask::VMEM_READ, Args: 4, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
2287 SG->addRule(NewRule: std::make_shared<SharesPredWithPrevNthGroup>(
2288 args: 2, args&: TII, args: SG->getSGID(), args: true));
2289 SG->addRule(NewRule: std::make_shared<VMEMSize>(args&: TII, args: SG->getSGID()));
2290 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
2291
2292 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2293 Args: SchedGroupMask::MFMA, Args: 1, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
2294 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
2295
2296 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2297 Args: SchedGroupMask::VMEM_READ, Args: 4, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
2298 SG->addRule(NewRule: std::make_shared<SharesPredWithPrevNthGroup>(
2299 args: 4, args&: TII, args: SG->getSGID(), args: true));
2300 SG->addRule(NewRule: std::make_shared<VMEMSize>(args&: TII, args: SG->getSGID()));
2301 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
2302
2303 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2304 Args: SchedGroupMask::MFMA, Args: 1, Args&: PipelineSyncID, Args&: DAG, Args&: TII);
2305 SG->initSchedGroup(SyncedInstrs&: SyncedInstrs[SG->getSyncID()]);
2306 }
2307
2308 return true;
2309}
2310
2311static std::unique_ptr<IGLPStrategy>
2312createIGLPStrategy(IGLPStrategyID ID, ScheduleDAGInstrs *DAG,
2313 const SIInstrInfo *TII) {
2314 switch (ID) {
2315 case MFMASmallGemmOptID:
2316 return std::make_unique<MFMASmallGemmOpt>(args&: DAG, args&: TII);
2317 case MFMASmallGemmSingleWaveOptID:
2318 return std::make_unique<MFMASmallGemmSingleWaveOpt>(args&: DAG, args&: TII);
2319 case MFMAExpInterleaveID:
2320 return std::make_unique<MFMAExpInterleaveOpt>(args&: DAG, args&: TII);
2321 case MFMAExpSimpleInterleaveID:
2322 return std::make_unique<MFMAExpSimpleInterleaveOpt>(args&: DAG, args&: TII);
2323 }
2324
2325 llvm_unreachable("Unknown IGLPStrategyID");
2326}
2327
2328class IGroupLPDAGMutation : public ScheduleDAGMutation {
2329private:
2330 const SIInstrInfo *TII;
2331
2332 ScheduleDAGMI *DAG;
2333
  // Organize lists of SchedGroups by their SyncID. SchedGroups /
  // SCHED_GROUP_BARRIERs with different SyncIDs will have no edges added
  // between them.
2337 DenseMap<int, SmallVector<SchedGroup, 4>> SyncedSchedGroups;
2338
2339 // Used to track instructions that can be mapped to multiple sched groups
2340 DenseMap<int, SUnitsToCandidateSGsMap> SyncedInstrs;
2341
2342 // Add DAG edges that enforce SCHED_BARRIER ordering.
2343 void addSchedBarrierEdges(SUnit &SU);
2344
2345 // Use a SCHED_BARRIER's mask to identify instruction SchedGroups that should
  // not be reordered across the SCHED_BARRIER. This is used for the base
2347 // SCHED_BARRIER, and not SCHED_GROUP_BARRIER. The difference is that
2348 // SCHED_BARRIER will always block all instructions that can be classified
2349 // into a particular SchedClass, whereas SCHED_GROUP_BARRIER has a fixed size
2350 // and may only synchronize with some SchedGroups. Returns the inverse of
2351 // Mask. SCHED_BARRIER's mask describes which instruction types should be
2352 // allowed to be scheduled across it. Invert the mask to get the
2353 // SchedGroupMask of instructions that should be barred.
2354 SchedGroupMask invertSchedBarrierMask(SchedGroupMask Mask) const;
2355
2356 // Create SchedGroups for a SCHED_GROUP_BARRIER.
2357 void initSchedGroupBarrierPipelineStage(
2358 std::vector<SUnit>::reverse_iterator RIter);
2359
2360 bool initIGLPOpt(SUnit &SU);
2361
2362public:
2363 void apply(ScheduleDAGInstrs *DAGInstrs) override;
2364
2365 // The order in which the PipelineSolver should process the candidate
2366 // SchedGroup for a PipelineInstr. BOTTOM_UP will try to add SUs to the last
2367 // created SchedGroup first, and will consider that as the ultimate
2368 // predecessor group when linking. TOP_DOWN instead links and processes the
2369 // first created SchedGroup first.
2370 bool IsBottomUp = true;
2371
2372 // The scheduling phase this application of IGLP corresponds with.
2373 AMDGPU::SchedulingPhase Phase = AMDGPU::SchedulingPhase::Initial;
2374
2375 IGroupLPDAGMutation() = default;
2376 IGroupLPDAGMutation(AMDGPU::SchedulingPhase Phase) : Phase(Phase) {}
2377};
2378
2379unsigned SchedGroup::NumSchedGroups = 0;
2380
2381bool SchedGroup::tryAddEdge(SUnit *A, SUnit *B) {
2382 if (A != B && DAG->canAddEdge(SuccSU: B, PredSU: A)) {
2383 DAG->addEdge(SuccSU: B, PredDep: SDep(A, SDep::Artificial));
2384 return true;
2385 }
2386 return false;
2387}
2388
2389bool SchedGroup::canAddMI(const MachineInstr &MI) const {
2390 bool Result = false;
2391 if (MI.isMetaInstruction())
2392 Result = false;
2393
2394 else if (((SGMask & SchedGroupMask::ALU) != SchedGroupMask::NONE) &&
2395 (TII->isVALU(MI) || TII->isMFMAorWMMA(MI) || TII->isSALU(MI) ||
2396 TII->isTRANS(MI)))
2397 Result = true;
2398
2399 else if (((SGMask & SchedGroupMask::VALU) != SchedGroupMask::NONE) &&
2400 TII->isVALU(MI) && !TII->isMFMAorWMMA(MI) && !TII->isTRANS(MI))
2401 Result = true;
2402
2403 else if (((SGMask & SchedGroupMask::SALU) != SchedGroupMask::NONE) &&
2404 TII->isSALU(MI))
2405 Result = true;
2406
2407 else if (((SGMask & SchedGroupMask::MFMA) != SchedGroupMask::NONE) &&
2408 TII->isMFMAorWMMA(MI))
2409 Result = true;
2410
2411 else if (((SGMask & SchedGroupMask::VMEM) != SchedGroupMask::NONE) &&
2412 TII->isVMEM(MI))
2413 Result = true;
2414
2415 else if (((SGMask & SchedGroupMask::VMEM_READ) != SchedGroupMask::NONE) &&
2416 MI.mayLoad() && TII->isVMEM(MI))
2417 Result = true;
2418
2419 else if (((SGMask & SchedGroupMask::VMEM_WRITE) != SchedGroupMask::NONE) &&
2420 MI.mayStore() && TII->isVMEM(MI))
2421 Result = true;
2422
2423 else if (((SGMask & SchedGroupMask::DS) != SchedGroupMask::NONE) &&
2424 TII->isDS(MI))
2425 Result = true;
2426
2427 else if (((SGMask & SchedGroupMask::DS_READ) != SchedGroupMask::NONE) &&
2428 MI.mayLoad() && TII->isDS(MI))
2429 Result = true;
2430
2431 else if (((SGMask & SchedGroupMask::DS_WRITE) != SchedGroupMask::NONE) &&
2432 MI.mayStore() && TII->isDS(MI))
2433 Result = true;
2434
2435 else if (((SGMask & SchedGroupMask::TRANS) != SchedGroupMask::NONE) &&
2436 TII->isTRANS(MI))
2437 Result = true;
2438
2439 LLVM_DEBUG(
2440 dbgs() << "For SchedGroup with mask " << format_hex((int)SGMask, 10, true)
2441 << (Result ? " could classify " : " unable to classify ") << MI);
2442
2443 return Result;
2444}
2445
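// Link SU with every instruction already in this SchedGroup (direction chosen
// by MakePred), returning the number of edges that could not be added.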
2446int SchedGroup::link(SUnit &SU, bool MakePred,
2447 std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges) {
2448 int MissedEdges = 0;
2449 for (auto *A : Collection) {
2450 SUnit *B = &SU;
2451 if (A == B || A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
2452 continue;
2453 if (MakePred)
2454 std::swap(a&: A, b&: B);
2455
2456 if (DAG->IsReachable(SU: B, TargetSU: A))
2457 continue;
2458
    // tryAddEdge returns false if there is a dependency that makes adding
    // the A->B edge impossible; otherwise it returns true.
2461 bool Added = tryAddEdge(A, B);
2462 if (Added)
2463 AddedEdges.emplace_back(args&: A, args&: B);
2464 else
2465 ++MissedEdges;
2466 }
2467
2468 return MissedEdges;
2469}
2470
2471void SchedGroup::link(SUnit &SU, bool MakePred) {
2472 for (auto *A : Collection) {
2473 SUnit *B = &SU;
2474 if (A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
2475 continue;
2476 if (MakePred)
2477 std::swap(a&: A, b&: B);
2478
2479 tryAddEdge(A, B);
2480 }
2481}
2482
2483void SchedGroup::link(SUnit &SU,
2484 function_ref<bool(const SUnit *A, const SUnit *B)> P) {
2485 for (auto *A : Collection) {
2486 SUnit *B = &SU;
2487 if (P(A, B))
2488 std::swap(a&: A, b&: B);
2489
2490 tryAddEdge(A, B);
2491 }
2492}
2493
2494void SchedGroup::link(SchedGroup &OtherGroup) {
2495 for (auto *B : OtherGroup.Collection)
2496 link(SU&: *B);
2497}
2498
2499bool SchedGroup::canAddSU(SUnit &SU) const {
2500 MachineInstr &MI = *SU.getInstr();
2501 if (MI.getOpcode() != TargetOpcode::BUNDLE)
2502 return canAddMI(MI);
2503
2504 // Special case for bundled MIs.
2505 const MachineBasicBlock *MBB = MI.getParent();
2506 MachineBasicBlock::instr_iterator B = MI.getIterator(), E = ++B;
2507 while (E != MBB->end() && E->isBundledWithPred())
2508 ++E;
2509
2510 // Return true if all of the bundled MIs can be added to this group.
2511 return std::all_of(first: B, last: E, pred: [this](MachineInstr &MI) { return canAddMI(MI); });
2512}
2513
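// Greedily fill the group with any SU in the region that it can accept,
// stopping once the group is full.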
2514void SchedGroup::initSchedGroup() {
2515 for (auto &SU : DAG->SUnits) {
2516 if (isFull())
2517 break;
2518
2519 if (canAddSU(SU))
2520 add(SU);
2521 }
2522}
2523
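// Initialize a group created for a SCHED_GROUP_BARRIER: walk upwards from the
// barrier recording every SU the group could accept as a candidate, then add
// the barrier itself, growing MaxSize to make room for it.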
2524void SchedGroup::initSchedGroup(std::vector<SUnit>::reverse_iterator RIter,
2525 SUnitsToCandidateSGsMap &SyncedInstrs) {
2526 SUnit &InitSU = *RIter;
2527 for (auto E = DAG->SUnits.rend(); RIter != E; ++RIter) {
2528 auto &SU = *RIter;
2529 if (isFull())
2530 break;
2531
2532 if (canAddSU(SU))
2533 SyncedInstrs[&SU].push_back(Elt: SGID);
2534 }
2535
2536 add(SU&: InitSU);
2537 assert(MaxSize);
2538 (*MaxSize)++;
2539}
2540
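// Record every SU in the region (scanning bottom-up) that this group could
// accept, as candidates for the pipeline solver.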
2541void SchedGroup::initSchedGroup(SUnitsToCandidateSGsMap &SyncedInstrs) {
2542 auto I = DAG->SUnits.rbegin();
2543 auto E = DAG->SUnits.rend();
2544 for (; I != E; ++I) {
2545 auto &SU = *I;
2546 if (isFull())
2547 break;
2548 if (canAddSU(SU))
2549 SyncedInstrs[&SU].push_back(Elt: SGID);
2550 }
2551}
2552
2553void IGroupLPDAGMutation::apply(ScheduleDAGInstrs *DAGInstrs) {
2554 const TargetSchedModel *TSchedModel = DAGInstrs->getSchedModel();
2555 if (!TSchedModel || DAGInstrs->SUnits.empty())
2556 return;
2557
2558 LLVM_DEBUG(dbgs() << "Applying IGroupLPDAGMutation...\n");
2559 const GCNSubtarget &ST = DAGInstrs->MF.getSubtarget<GCNSubtarget>();
2560 TII = ST.getInstrInfo();
2561 DAG = static_cast<ScheduleDAGMI *>(DAGInstrs);
2562 SyncedSchedGroups.clear();
2563 SyncedInstrs.clear();
2564 bool FoundSB = false;
2565 bool FoundIGLP = false;
2566 bool ShouldApplyIGLP = false;
2567 for (auto R = DAG->SUnits.rbegin(), E = DAG->SUnits.rend(); R != E; ++R) {
2568 unsigned Opc = R->getInstr()->getOpcode();
2569 // SCHED_[GROUP_]BARRIER and IGLP are mutually exclusive.
2570 if (Opc == AMDGPU::SCHED_BARRIER) {
2571 addSchedBarrierEdges(SU&: *R);
2572 FoundSB = true;
2573 } else if (Opc == AMDGPU::SCHED_GROUP_BARRIER) {
2574 initSchedGroupBarrierPipelineStage(RIter: R);
2575 FoundSB = true;
2576 } else if (Opc == AMDGPU::IGLP_OPT) {
2577 if (!FoundSB && !FoundIGLP) {
2578 FoundIGLP = true;
2579 ShouldApplyIGLP = initIGLPOpt(SU&: *R);
2580 }
2581 }
2582 }
2583
2584 if (FoundSB || (FoundIGLP && ShouldApplyIGLP)) {
2585 PipelineSolver PS(SyncedSchedGroups, SyncedInstrs, DAG, IsBottomUp);
2586 // PipelineSolver performs the mutation by adding the edges it
2587 // determined as the best
2588 PS.solve();
2589 return;
2590 }
2591}
2592
2593void IGroupLPDAGMutation::addSchedBarrierEdges(SUnit &SchedBarrier) {
2594 MachineInstr &MI = *SchedBarrier.getInstr();
2595 assert(MI.getOpcode() == AMDGPU::SCHED_BARRIER);
2596 // Remove all existing edges from the SCHED_BARRIER that were added due to the
2597 // instruction having side effects.
2598 LLVM_DEBUG(dbgs() << "Building SchedGroup for SchedBarrier with Mask: "
2599 << MI.getOperand(0).getImm() << "\n");
2600 auto InvertedMask =
2601 invertSchedBarrierMask(Mask: (SchedGroupMask)MI.getOperand(i: 0).getImm());
2602 SchedGroup SG(InvertedMask, std::nullopt, DAG, TII);
2603 SG.initSchedGroup();
2604
2605 // Preserve original instruction ordering relative to the SCHED_BARRIER.
2606 SG.link(
2607 SU&: SchedBarrier,
2608 P: (function_ref<bool(const SUnit *A, const SUnit *B)>)[](
2609 const SUnit *A, const SUnit *B) { return A->NodeNum > B->NodeNum; });
2610}
2611
2612SchedGroupMask
2613IGroupLPDAGMutation::invertSchedBarrierMask(SchedGroupMask Mask) const {
2614 // Invert mask and erase bits for types of instructions that are implied to be
2615 // allowed past the SCHED_BARRIER.
2616 SchedGroupMask InvertedMask = ~Mask;
2617
2618 // ALU implies VALU, SALU, MFMA, TRANS.
2619 if ((InvertedMask & SchedGroupMask::ALU) == SchedGroupMask::NONE)
2620 InvertedMask &= ~SchedGroupMask::VALU & ~SchedGroupMask::SALU &
2621 ~SchedGroupMask::MFMA & ~SchedGroupMask::TRANS;
2622 // VALU, SALU, MFMA, TRANS implies ALU.
2623 else if ((InvertedMask & SchedGroupMask::VALU) == SchedGroupMask::NONE ||
2624 (InvertedMask & SchedGroupMask::SALU) == SchedGroupMask::NONE ||
2625 (InvertedMask & SchedGroupMask::MFMA) == SchedGroupMask::NONE ||
2626 (InvertedMask & SchedGroupMask::TRANS) == SchedGroupMask::NONE)
2627 InvertedMask &= ~SchedGroupMask::ALU;
2628
2629 // VMEM implies VMEM_READ, VMEM_WRITE.
2630 if ((InvertedMask & SchedGroupMask::VMEM) == SchedGroupMask::NONE)
2631 InvertedMask &= ~SchedGroupMask::VMEM_READ & ~SchedGroupMask::VMEM_WRITE;
2632 // VMEM_READ, VMEM_WRITE implies VMEM.
2633 else if ((InvertedMask & SchedGroupMask::VMEM_READ) == SchedGroupMask::NONE ||
2634 (InvertedMask & SchedGroupMask::VMEM_WRITE) == SchedGroupMask::NONE)
2635 InvertedMask &= ~SchedGroupMask::VMEM;
2636
2637 // DS implies DS_READ, DS_WRITE.
2638 if ((InvertedMask & SchedGroupMask::DS) == SchedGroupMask::NONE)
2639 InvertedMask &= ~SchedGroupMask::DS_READ & ~SchedGroupMask::DS_WRITE;
2640 // DS_READ, DS_WRITE implies DS.
2641 else if ((InvertedMask & SchedGroupMask::DS_READ) == SchedGroupMask::NONE ||
2642 (InvertedMask & SchedGroupMask::DS_WRITE) == SchedGroupMask::NONE)
2643 InvertedMask &= ~SchedGroupMask::DS;
2644
2645 LLVM_DEBUG(dbgs() << "After Inverting, SchedGroup Mask: " << (int)InvertedMask
2646 << "\n");
2647
2648 return InvertedMask;
2649}
2650
2651void IGroupLPDAGMutation::initSchedGroupBarrierPipelineStage(
2652 std::vector<SUnit>::reverse_iterator RIter) {
2653 // Remove all existing edges from the SCHED_GROUP_BARRIER that were added due
2654 // to the instruction having side effects.
2655 MachineInstr &SGB = *RIter->getInstr();
2656 assert(SGB.getOpcode() == AMDGPU::SCHED_GROUP_BARRIER);
2657 int32_t SGMask = SGB.getOperand(i: 0).getImm();
2658 int32_t Size = SGB.getOperand(i: 1).getImm();
2659 int32_t SyncID = SGB.getOperand(i: 2).getImm();
2660
2661 auto &SG = SyncedSchedGroups[SyncID].emplace_back(Args: (SchedGroupMask)SGMask,
2662 Args&: Size, Args&: SyncID, Args&: DAG, Args&: TII);
2663
2664 SG.initSchedGroup(RIter, SyncedInstrs&: SyncedInstrs[SG.getSyncID()]);
2665}
2666
2667bool IGroupLPDAGMutation::initIGLPOpt(SUnit &SU) {
2668 IGLPStrategyID StrategyID =
2669 (IGLPStrategyID)SU.getInstr()->getOperand(i: 0).getImm();
2670 auto S = createIGLPStrategy(ID: StrategyID, DAG, TII);
2671 if (!S->shouldApplyStrategy(DAG, Phase))
2672 return false;
2673
2674 IsBottomUp = S->IsBottomUp;
2675 return S->applyIGLPStrategy(SyncedInstrs, SyncedSchedGroups, Phase);
2676}
2677
2678} // namespace
2679
/// \p Phase specifies whether or not this is a reentry into the
2681/// IGroupLPDAGMutation. Since there may be multiple scheduling passes on the
2682/// same scheduling region (e.g. pre and post-RA scheduling / multiple
2683/// scheduling "phases"), we can reenter this mutation framework more than once
2684/// for a given region.
2685std::unique_ptr<ScheduleDAGMutation>
2686llvm::createIGroupLPDAGMutation(AMDGPU::SchedulingPhase Phase) {
2687 return std::make_unique<IGroupLPDAGMutation>(args&: Phase);
2688}
2689