1//===- RISCVVectorMaskDAGMutation.cpp - RISC-V Vector Mask DAGMutation ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// A schedule mutation that adds an artificial dependency between mask producer
// instructions and masked instructions, so that we can reduce the overlap
// between the live ranges of mask registers.
12//
13// The reason why we need to do this:
14// 1. When tracking register pressure, we don't track physical registers.
15// 2. We have a RegisterClass for mask register (which is `VMV0`), but we don't
16// use it by the time we reach scheduling. Instead, we use physical
17// register V0 directly and insert a `$v0 = COPY ...` before the use.
18// 3. For mask producers, we are using VR RegisterClass (we can allocate V0-V31
19// to it). So if V0 is not available, there are still 31 available registers
20// out there.
21//
22// This means that the RegPressureTracker can't track the pressure of mask
23// registers correctly.
24//
25// This schedule mutation is a workaround to fix this issue.
26//
27//===----------------------------------------------------------------------===//
28
29#include "MCTargetDesc/RISCVBaseInfo.h"
30#include "MCTargetDesc/RISCVMCTargetDesc.h"
31#include "RISCVTargetMachine.h"
32#include "llvm/CodeGen/LiveIntervals.h"
33#include "llvm/CodeGen/MachineInstr.h"
34#include "llvm/CodeGen/ScheduleDAGInstrs.h"
35#include "llvm/CodeGen/ScheduleDAGMutation.h"
36#include "llvm/TargetParser/RISCVTargetParser.h"
37
38#define DEBUG_TYPE "machine-scheduler"
39
40namespace llvm {
41
42static bool isCopyToV0(const MachineInstr &MI) {
43 return MI.isCopy() && MI.getOperand(i: 0).getReg() == RISCV::V0 &&
44 MI.getOperand(i: 1).getReg().isVirtual() &&
45 MI.getOperand(i: 1).getSubReg() == RISCV::NoSubRegister;
46}
47
48static bool isSoleUseCopyToV0(SUnit &SU) {
49 if (SU.Succs.size() != 1)
50 return false;
51 SDep &Dep = SU.Succs[0];
52 // Ignore dependencies other than data or strong ordering.
53 if (Dep.isWeak())
54 return false;
55
56 SUnit &DepSU = *Dep.getSUnit();
57 if (DepSU.isBoundaryNode())
58 return false;
59 return isCopyToV0(MI: *DepSU.getInstr());
60}
61
62class RISCVVectorMaskDAGMutation : public ScheduleDAGMutation {
63private:
64 const TargetRegisterInfo *TRI;
65
66public:
67 RISCVVectorMaskDAGMutation(const TargetRegisterInfo *TRI) : TRI(TRI) {}
68
69 void apply(ScheduleDAGInstrs *DAG) override {
70 SUnit *NearestUseV0SU = nullptr;
71 for (SUnit &SU : DAG->SUnits) {
72 const MachineInstr *MI = SU.getInstr();
73 if (MI->findRegisterUseOperand(Reg: RISCV::V0, TRI))
74 NearestUseV0SU = &SU;
75
76 if (NearestUseV0SU && NearestUseV0SU != &SU && isSoleUseCopyToV0(SU) &&
77 // For LMUL=8 cases, there will be more possibilities to spill.
78 // FIXME: We should use RegPressureTracker to do fine-grained
79 // controls.
80 RISCVII::getLMul(TSFlags: MI->getDesc().TSFlags) != RISCVVType::LMUL_8)
81 DAG->addEdge(SuccSU: &SU, PredDep: SDep(NearestUseV0SU, SDep::Artificial));
82 }
83 }
84};
85
86std::unique_ptr<ScheduleDAGMutation>
87createRISCVVectorMaskDAGMutation(const TargetRegisterInfo *TRI) {
88 return std::make_unique<RISCVVectorMaskDAGMutation>(args&: TRI);
89}
90
91} // namespace llvm
92