1 | //===- RISCVVectorPeephole.cpp - MI Vector Pseudo Peepholes ---------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This pass performs various vector pseudo peephole optimisations after |
10 | // instruction selection. |
11 | // |
12 | // Currently it converts vmerge.vvm to vmv.v.v |
13 | // PseudoVMERGE_VVM %false, %false, %true, %allonesmask, %vl, %sew |
14 | // -> |
15 | // PseudoVMV_V_V %false, %true, %vl, %sew |
16 | // |
17 | // And masked pseudos to unmasked pseudos |
18 | // PseudoVADD_V_V_MASK %passthru, %a, %b, %allonesmask, %vl, sew, policy |
19 | // -> |
// PseudoVADD_V_V %passthru, %a, %b, %vl, sew, policy
21 | // |
22 | // It also converts AVLs to VLMAX where possible |
23 | // %vl = VLENB * something |
24 | // PseudoVADD_V_V %passthru, %a, %b, %vl, sew, policy |
25 | // -> |
26 | // PseudoVADD_V_V %passthru, %a, %b, -1, sew, policy |
27 | // |
28 | //===----------------------------------------------------------------------===// |
29 | |
30 | #include "RISCV.h" |
31 | #include "RISCVISelDAGToDAG.h" |
32 | #include "RISCVSubtarget.h" |
33 | #include "llvm/CodeGen/MachineFunctionPass.h" |
34 | #include "llvm/CodeGen/MachineRegisterInfo.h" |
35 | #include "llvm/CodeGen/TargetInstrInfo.h" |
36 | #include "llvm/CodeGen/TargetRegisterInfo.h" |
37 | |
38 | using namespace llvm; |
39 | |
40 | #define DEBUG_TYPE "riscv-vector-peephole" |
41 | |
42 | namespace { |
43 | |
44 | class RISCVVectorPeephole : public MachineFunctionPass { |
45 | public: |
46 | static char ID; |
47 | const TargetInstrInfo *TII; |
48 | MachineRegisterInfo *MRI; |
49 | const TargetRegisterInfo *TRI; |
50 | RISCVVectorPeephole() : MachineFunctionPass(ID) {} |
51 | |
52 | bool runOnMachineFunction(MachineFunction &MF) override; |
53 | MachineFunctionProperties getRequiredProperties() const override { |
54 | return MachineFunctionProperties().set( |
55 | MachineFunctionProperties::Property::IsSSA); |
56 | } |
57 | |
58 | StringRef getPassName() const override { return "RISC-V Fold Masks" ; } |
59 | |
60 | private: |
61 | bool convertToVLMAX(MachineInstr &MI) const; |
62 | bool convertToUnmasked(MachineInstr &MI) const; |
63 | bool convertVMergeToVMv(MachineInstr &MI) const; |
64 | |
65 | bool isAllOnesMask(const MachineInstr *MaskDef) const; |
66 | |
67 | /// Maps uses of V0 to the corresponding def of V0. |
68 | DenseMap<const MachineInstr *, const MachineInstr *> V0Defs; |
69 | }; |
70 | |
71 | } // namespace |
72 | |
73 | char RISCVVectorPeephole::ID = 0; |
74 | |
75 | INITIALIZE_PASS(RISCVVectorPeephole, DEBUG_TYPE, "RISC-V Fold Masks" , false, |
76 | false) |
77 | |
78 | // If an AVL is a VLENB that's possibly scaled to be equal to VLMAX, convert it |
79 | // to the VLMAX sentinel value. |
80 | bool RISCVVectorPeephole::convertToVLMAX(MachineInstr &MI) const { |
81 | if (!RISCVII::hasVLOp(TSFlags: MI.getDesc().TSFlags) || |
82 | !RISCVII::hasSEWOp(TSFlags: MI.getDesc().TSFlags)) |
83 | return false; |
84 | MachineOperand &VL = MI.getOperand(i: RISCVII::getVLOpNum(Desc: MI.getDesc())); |
85 | if (!VL.isReg()) |
86 | return false; |
87 | MachineInstr *Def = MRI->getVRegDef(Reg: VL.getReg()); |
88 | if (!Def) |
89 | return false; |
90 | |
91 | // Fixed-point value, denominator=8 |
92 | uint64_t ScaleFixed = 8; |
93 | // Check if the VLENB was potentially scaled with slli/srli |
94 | if (Def->getOpcode() == RISCV::SLLI) { |
95 | assert(Def->getOperand(2).getImm() < 64); |
96 | ScaleFixed <<= Def->getOperand(i: 2).getImm(); |
97 | Def = MRI->getVRegDef(Reg: Def->getOperand(i: 1).getReg()); |
98 | } else if (Def->getOpcode() == RISCV::SRLI) { |
99 | assert(Def->getOperand(2).getImm() < 64); |
100 | ScaleFixed >>= Def->getOperand(i: 2).getImm(); |
101 | Def = MRI->getVRegDef(Reg: Def->getOperand(i: 1).getReg()); |
102 | } |
103 | |
104 | if (!Def || Def->getOpcode() != RISCV::PseudoReadVLENB) |
105 | return false; |
106 | |
107 | auto LMUL = RISCVVType::decodeVLMUL(VLMUL: RISCVII::getLMul(TSFlags: MI.getDesc().TSFlags)); |
108 | // Fixed-point value, denominator=8 |
109 | unsigned LMULFixed = LMUL.second ? (8 / LMUL.first) : 8 * LMUL.first; |
110 | unsigned Log2SEW = MI.getOperand(i: RISCVII::getSEWOpNum(Desc: MI.getDesc())).getImm(); |
111 | // A Log2SEW of 0 is an operation on mask registers only |
112 | unsigned SEW = Log2SEW ? 1 << Log2SEW : 8; |
113 | assert(RISCVVType::isValidSEW(SEW) && "Unexpected SEW" ); |
114 | assert(8 * LMULFixed / SEW > 0); |
115 | |
116 | // AVL = (VLENB * Scale) |
117 | // |
118 | // VLMAX = (VLENB * 8 * LMUL) / SEW |
119 | // |
120 | // AVL == VLMAX |
121 | // -> VLENB * Scale == (VLENB * 8 * LMUL) / SEW |
122 | // -> Scale == (8 * LMUL) / SEW |
123 | if (ScaleFixed != 8 * LMULFixed / SEW) |
124 | return false; |
125 | |
126 | VL.ChangeToImmediate(ImmVal: RISCV::VLMaxSentinel); |
127 | |
128 | return true; |
129 | } |
130 | |
131 | bool RISCVVectorPeephole::isAllOnesMask(const MachineInstr *MaskDef) const { |
132 | assert(MaskDef && MaskDef->isCopy() && |
133 | MaskDef->getOperand(0).getReg() == RISCV::V0); |
134 | Register SrcReg = TRI->lookThruCopyLike(SrcReg: MaskDef->getOperand(i: 1).getReg(), MRI); |
135 | if (!SrcReg.isVirtual()) |
136 | return false; |
137 | MaskDef = MRI->getVRegDef(Reg: SrcReg); |
138 | if (!MaskDef) |
139 | return false; |
140 | |
141 | // TODO: Check that the VMSET is the expected bitwidth? The pseudo has |
142 | // undefined behaviour if it's the wrong bitwidth, so we could choose to |
143 | // assume that it's all-ones? Same applies to its VL. |
144 | switch (MaskDef->getOpcode()) { |
145 | case RISCV::PseudoVMSET_M_B1: |
146 | case RISCV::PseudoVMSET_M_B2: |
147 | case RISCV::PseudoVMSET_M_B4: |
148 | case RISCV::PseudoVMSET_M_B8: |
149 | case RISCV::PseudoVMSET_M_B16: |
150 | case RISCV::PseudoVMSET_M_B32: |
151 | case RISCV::PseudoVMSET_M_B64: |
152 | return true; |
153 | default: |
154 | return false; |
155 | } |
156 | } |
157 | |
158 | // Transform (VMERGE_VVM_<LMUL> false, false, true, allones, vl, sew) to |
159 | // (VMV_V_V_<LMUL> false, true, vl, sew). It may decrease uses of VMSET. |
160 | bool RISCVVectorPeephole::convertVMergeToVMv(MachineInstr &MI) const { |
161 | #define CASE_VMERGE_TO_VMV(lmul) \ |
162 | case RISCV::PseudoVMERGE_VVM_##lmul: \ |
163 | NewOpc = RISCV::PseudoVMV_V_V_##lmul; \ |
164 | break; |
165 | unsigned NewOpc; |
166 | switch (MI.getOpcode()) { |
167 | default: |
168 | return false; |
169 | CASE_VMERGE_TO_VMV(MF8) |
170 | CASE_VMERGE_TO_VMV(MF4) |
171 | CASE_VMERGE_TO_VMV(MF2) |
172 | CASE_VMERGE_TO_VMV(M1) |
173 | CASE_VMERGE_TO_VMV(M2) |
174 | CASE_VMERGE_TO_VMV(M4) |
175 | CASE_VMERGE_TO_VMV(M8) |
176 | } |
177 | |
178 | Register MergeReg = MI.getOperand(i: 1).getReg(); |
179 | Register FalseReg = MI.getOperand(i: 2).getReg(); |
180 | // Check merge == false (or merge == undef) |
181 | if (MergeReg != RISCV::NoRegister && TRI->lookThruCopyLike(SrcReg: MergeReg, MRI) != |
182 | TRI->lookThruCopyLike(SrcReg: FalseReg, MRI)) |
183 | return false; |
184 | |
185 | assert(MI.getOperand(4).isReg() && MI.getOperand(4).getReg() == RISCV::V0); |
186 | if (!isAllOnesMask(MaskDef: V0Defs.lookup(Val: &MI))) |
187 | return false; |
188 | |
189 | MI.setDesc(TII->get(Opcode: NewOpc)); |
190 | MI.removeOperand(OpNo: 1); // Merge operand |
191 | MI.tieOperands(DefIdx: 0, UseIdx: 1); // Tie false to dest |
192 | MI.removeOperand(OpNo: 3); // Mask operand |
193 | MI.addOperand( |
194 | Op: MachineOperand::CreateImm(Val: RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED)); |
195 | |
196 | // vmv.v.v doesn't have a mask operand, so we may be able to inflate the |
197 | // register class for the destination and merge operands e.g. VRNoV0 -> VR |
198 | MRI->recomputeRegClass(Reg: MI.getOperand(i: 0).getReg()); |
199 | MRI->recomputeRegClass(Reg: MI.getOperand(i: 1).getReg()); |
200 | return true; |
201 | } |
202 | |
203 | bool RISCVVectorPeephole::convertToUnmasked(MachineInstr &MI) const { |
204 | const RISCV::RISCVMaskedPseudoInfo *I = |
205 | RISCV::getMaskedPseudoInfo(MaskedPseudo: MI.getOpcode()); |
206 | if (!I) |
207 | return false; |
208 | |
209 | if (!isAllOnesMask(MaskDef: V0Defs.lookup(Val: &MI))) |
210 | return false; |
211 | |
212 | // There are two classes of pseudos in the table - compares and |
213 | // everything else. See the comment on RISCVMaskedPseudo for details. |
214 | const unsigned Opc = I->UnmaskedPseudo; |
215 | const MCInstrDesc &MCID = TII->get(Opcode: Opc); |
216 | [[maybe_unused]] const bool HasPolicyOp = |
217 | RISCVII::hasVecPolicyOp(TSFlags: MCID.TSFlags); |
218 | const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(Desc: MCID); |
219 | #ifndef NDEBUG |
220 | const MCInstrDesc &MaskedMCID = TII->get(MI.getOpcode()); |
221 | assert(RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags) == |
222 | RISCVII::hasVecPolicyOp(MCID.TSFlags) && |
223 | "Masked and unmasked pseudos are inconsistent" ); |
224 | assert(HasPolicyOp == HasPassthru && "Unexpected pseudo structure" ); |
225 | #endif |
226 | (void)HasPolicyOp; |
227 | |
228 | MI.setDesc(MCID); |
229 | |
230 | // TODO: Increment all MaskOpIdxs in tablegen by num of explicit defs? |
231 | unsigned MaskOpIdx = I->MaskOpIdx + MI.getNumExplicitDefs(); |
232 | MI.removeOperand(OpNo: MaskOpIdx); |
233 | |
234 | // The unmasked pseudo will no longer be constrained to the vrnov0 reg class, |
235 | // so try and relax it to vr. |
236 | MRI->recomputeRegClass(Reg: MI.getOperand(i: 0).getReg()); |
237 | unsigned PassthruOpIdx = MI.getNumExplicitDefs(); |
238 | if (HasPassthru) { |
239 | if (MI.getOperand(i: PassthruOpIdx).getReg() != RISCV::NoRegister) |
240 | MRI->recomputeRegClass(Reg: MI.getOperand(i: PassthruOpIdx).getReg()); |
241 | } else |
242 | MI.removeOperand(OpNo: PassthruOpIdx); |
243 | |
244 | return true; |
245 | } |
246 | |
247 | bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) { |
248 | if (skipFunction(F: MF.getFunction())) |
249 | return false; |
250 | |
251 | // Skip if the vector extension is not enabled. |
252 | const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>(); |
253 | if (!ST.hasVInstructions()) |
254 | return false; |
255 | |
256 | TII = ST.getInstrInfo(); |
257 | MRI = &MF.getRegInfo(); |
258 | TRI = MRI->getTargetRegisterInfo(); |
259 | |
260 | bool Changed = false; |
261 | |
262 | // Masked pseudos coming out of isel will have their mask operand in the form: |
263 | // |
264 | // $v0:vr = COPY %mask:vr |
265 | // %x:vr = Pseudo_MASK %a:vr, %b:br, $v0:vr |
266 | // |
267 | // Because $v0 isn't in SSA, keep track of its definition at each use so we |
268 | // can check mask operands. |
269 | for (const MachineBasicBlock &MBB : MF) { |
270 | const MachineInstr *CurrentV0Def = nullptr; |
271 | for (const MachineInstr &MI : MBB) { |
272 | if (MI.readsRegister(Reg: RISCV::V0, TRI)) |
273 | V0Defs[&MI] = CurrentV0Def; |
274 | |
275 | if (MI.definesRegister(Reg: RISCV::V0, TRI)) |
276 | CurrentV0Def = &MI; |
277 | } |
278 | } |
279 | |
280 | for (MachineBasicBlock &MBB : MF) { |
281 | for (MachineInstr &MI : MBB) { |
282 | Changed |= convertToVLMAX(MI); |
283 | Changed |= convertToUnmasked(MI); |
284 | Changed |= convertVMergeToVMv(MI); |
285 | } |
286 | } |
287 | |
288 | return Changed; |
289 | } |
290 | |
291 | FunctionPass *llvm::createRISCVVectorPeepholePass() { |
292 | return new RISCVVectorPeephole(); |
293 | } |
294 | |