1//===- RISCVVectorPeephole.cpp - MI Vector Pseudo Peepholes ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass performs various vector pseudo peephole optimisations after
10// instruction selection.
11//
// For example, it converts a vmerge.vvm with an all-ones mask to vmv.v.v:
// PseudoVMERGE_VVM %false, %false, %true, %allonesmask, %vl, %sew
// ->
// PseudoVMV_V_V %false, %true, %vl, %sew
//
// It converts masked pseudos with an all-ones mask to unmasked pseudos:
// PseudoVADD_V_V_MASK %passthru, %a, %b, %allonesmask, %vl, sew, policy
// ->
// PseudoVADD_V_V %passthru, %a, %b, %vl, sew, policy
//
// It also converts AVLs to VLMAX where possible:
// %vl = VLENB * something
// PseudoVADD_V_V %passthru, %a, %b, %vl, sew, policy
// ->
// PseudoVADD_V_V %passthru, %a, %b, -1, sew, policy
//
// Other transforms include folding a vmerge.vvm into a masked form of the
// instruction that defines its true operand, folding a vmv.v.v into its
// source, and converting VLMAX unit-strided loads/stores into whole-register
// loads/stores.
//
28//===----------------------------------------------------------------------===//
29
30#include "RISCV.h"
31#include "RISCVSubtarget.h"
32#include "llvm/CodeGen/MachineFunctionPass.h"
33#include "llvm/CodeGen/MachineRegisterInfo.h"
34#include "llvm/CodeGen/TargetInstrInfo.h"
35#include "llvm/CodeGen/TargetRegisterInfo.h"
36
37using namespace llvm;
38
39#define DEBUG_TYPE "riscv-vector-peephole"
40
41namespace {
42
43class RISCVVectorPeephole : public MachineFunctionPass {
44public:
45 static char ID;
46 const TargetInstrInfo *TII;
47 MachineRegisterInfo *MRI;
48 const TargetRegisterInfo *TRI;
49 const RISCVSubtarget *ST;
50 RISCVVectorPeephole() : MachineFunctionPass(ID) {}
51
52 bool runOnMachineFunction(MachineFunction &MF) override;
53 MachineFunctionProperties getRequiredProperties() const override {
54 return MachineFunctionProperties().setIsSSA();
55 }
56
57 StringRef getPassName() const override {
58 return "RISC-V Vector Peephole Optimization";
59 }
60
61private:
62 bool convertToVLMAX(MachineInstr &MI) const;
63 bool convertToWholeRegister(MachineInstr &MI) const;
64 bool convertToUnmasked(MachineInstr &MI) const;
65 bool convertAllOnesVMergeToVMv(MachineInstr &MI) const;
66 bool convertSameMaskVMergeToVMv(MachineInstr &MI);
67 bool foldUndefPassthruVMV_V_V(MachineInstr &MI);
68 bool foldVMV_V_V(MachineInstr &MI);
69 bool foldVMergeToMask(MachineInstr &MI) const;
70
71 bool hasSameEEW(const MachineInstr &User, const MachineInstr &Src) const;
72 bool isAllOnesMask(const MachineInstr *MaskDef) const;
73 std::optional<unsigned> getConstant(const MachineOperand &VL) const;
74 bool ensureDominates(ArrayRef<const MachineOperand *> Defs,
75 MachineInstr &Use) const;
76 Register
77 lookThruCopies(Register Reg, bool OneUseOnly = false,
78 SmallVectorImpl<MachineInstr *> *Copies = nullptr) const;
79};
80
81} // namespace
82
83char RISCVVectorPeephole::ID = 0;
84
85INITIALIZE_PASS(RISCVVectorPeephole, DEBUG_TYPE, "RISC-V Fold Masks", false,
86 false)
87
88/// Given \p User that has an input operand with EEW=SEW, which uses the dest
89/// operand of \p Src with an unknown EEW, return true if their EEWs match.
90bool RISCVVectorPeephole::hasSameEEW(const MachineInstr &User,
91 const MachineInstr &Src) const {
92 unsigned UserLog2SEW =
93 User.getOperand(i: RISCVII::getSEWOpNum(Desc: User.getDesc())).getImm();
94 unsigned SrcLog2SEW =
95 Src.getOperand(i: RISCVII::getSEWOpNum(Desc: Src.getDesc())).getImm();
96 unsigned SrcLog2EEW = RISCV::getDestLog2EEW(
97 Desc: TII->get(Opcode: RISCV::getRVVMCOpcode(RVVPseudoOpcode: Src.getOpcode())), Log2SEW: SrcLog2SEW);
98 return SrcLog2EEW == UserLog2SEW;
99}
100
101/// Check if an operand is an immediate or a materialized ADDI $x0, imm.
102std::optional<unsigned>
103RISCVVectorPeephole::getConstant(const MachineOperand &VL) const {
104 if (VL.isImm())
105 return VL.getImm();
106
107 MachineInstr *Def = MRI->getVRegDef(Reg: VL.getReg());
108 if (!Def || Def->getOpcode() != RISCV::ADDI ||
109 Def->getOperand(i: 1).getReg() != RISCV::X0)
110 return std::nullopt;
111 return Def->getOperand(i: 2).getImm();
112}
113
114/// Convert AVLs that are known to be VLMAX to the VLMAX sentinel.
115bool RISCVVectorPeephole::convertToVLMAX(MachineInstr &MI) const {
116 if (!RISCVII::hasVLOp(TSFlags: MI.getDesc().TSFlags) ||
117 !RISCVII::hasSEWOp(TSFlags: MI.getDesc().TSFlags))
118 return false;
119
120 auto LMUL = RISCVVType::decodeVLMUL(VLMul: RISCVII::getLMul(TSFlags: MI.getDesc().TSFlags));
121 // Fixed-point value, denominator=8
122 unsigned LMULFixed = LMUL.second ? (8 / LMUL.first) : 8 * LMUL.first;
123 unsigned Log2SEW = MI.getOperand(i: RISCVII::getSEWOpNum(Desc: MI.getDesc())).getImm();
124 // A Log2SEW of 0 is an operation on mask registers only
125 unsigned SEW = Log2SEW ? 1 << Log2SEW : 8;
126 assert(RISCVVType::isValidSEW(SEW) && "Unexpected SEW");
127 assert(8 * LMULFixed / SEW > 0);
128
129 // If the exact VLEN is known then we know VLMAX, check if the AVL == VLMAX.
130 MachineOperand &VL = MI.getOperand(i: RISCVII::getVLOpNum(Desc: MI.getDesc()));
131 if (auto VLen = ST->getRealVLen(), AVL = getConstant(VL);
132 VLen && AVL && (*VLen * LMULFixed) / SEW == *AVL * 8) {
133 VL.ChangeToImmediate(ImmVal: RISCV::VLMaxSentinel);
134 return true;
135 }
136
137 // If an AVL is a VLENB that's possibly scaled to be equal to VLMAX, convert
138 // it to the VLMAX sentinel value.
139 if (!VL.isReg())
140 return false;
141 MachineInstr *Def = MRI->getVRegDef(Reg: VL.getReg());
142 if (!Def)
143 return false;
144
145 // Fixed-point value, denominator=8
146 uint64_t ScaleFixed = 8;
147 // Check if the VLENB was potentially scaled with slli/srli
148 if (Def->getOpcode() == RISCV::SLLI) {
149 assert(Def->getOperand(2).getImm() < 64);
150 ScaleFixed <<= Def->getOperand(i: 2).getImm();
151 Def = MRI->getVRegDef(Reg: Def->getOperand(i: 1).getReg());
152 } else if (Def->getOpcode() == RISCV::SRLI) {
153 assert(Def->getOperand(2).getImm() < 64);
154 ScaleFixed >>= Def->getOperand(i: 2).getImm();
155 Def = MRI->getVRegDef(Reg: Def->getOperand(i: 1).getReg());
156 }
157
158 if (!Def || Def->getOpcode() != RISCV::PseudoReadVLENB)
159 return false;
160
161 // AVL = (VLENB * Scale)
162 //
163 // VLMAX = (VLENB * 8 * LMUL) / SEW
164 //
165 // AVL == VLMAX
166 // -> VLENB * Scale == (VLENB * 8 * LMUL) / SEW
167 // -> Scale == (8 * LMUL) / SEW
168 if (ScaleFixed != 8 * LMULFixed / SEW)
169 return false;
170
171 VL.ChangeToImmediate(ImmVal: RISCV::VLMaxSentinel);
172
173 return true;
174}
175
176bool RISCVVectorPeephole::isAllOnesMask(const MachineInstr *MaskDef) const {
177 while (MaskDef->isCopy() && MaskDef->getOperand(i: 1).getReg().isVirtual())
178 MaskDef = MRI->getVRegDef(Reg: MaskDef->getOperand(i: 1).getReg());
179
180 // TODO: Check that the VMSET is the expected bitwidth? The pseudo has
181 // undefined behaviour if it's the wrong bitwidth, so we could choose to
182 // assume that it's all-ones? Same applies to its VL.
183 switch (MaskDef->getOpcode()) {
184 case RISCV::PseudoVMSET_M_B1:
185 case RISCV::PseudoVMSET_M_B2:
186 case RISCV::PseudoVMSET_M_B4:
187 case RISCV::PseudoVMSET_M_B8:
188 case RISCV::PseudoVMSET_M_B16:
189 case RISCV::PseudoVMSET_M_B32:
190 case RISCV::PseudoVMSET_M_B64:
191 return true;
192 default:
193 return false;
194 }
195}
196
197/// Convert unit strided unmasked loads and stores to whole-register equivalents
198/// to avoid the dependency on $vl and $vtype.
199///
200/// %x = PseudoVLE8_V_M1 %passthru, %ptr, %vlmax, policy
201/// PseudoVSE8_V_M1 %v, %ptr, %vlmax
202///
203/// ->
204///
205/// %x = VL1RE8_V %ptr
206/// VS1R_V %v, %ptr
207bool RISCVVectorPeephole::convertToWholeRegister(MachineInstr &MI) const {
208#define CASE_WHOLE_REGISTER_LMUL_SEW(lmul, sew) \
209 case RISCV::PseudoVLE##sew##_V_M##lmul: \
210 NewOpc = RISCV::VL##lmul##RE##sew##_V; \
211 break; \
212 case RISCV::PseudoVSE##sew##_V_M##lmul: \
213 NewOpc = RISCV::VS##lmul##R_V; \
214 break;
215#define CASE_WHOLE_REGISTER_LMUL(lmul) \
216 CASE_WHOLE_REGISTER_LMUL_SEW(lmul, 8) \
217 CASE_WHOLE_REGISTER_LMUL_SEW(lmul, 16) \
218 CASE_WHOLE_REGISTER_LMUL_SEW(lmul, 32) \
219 CASE_WHOLE_REGISTER_LMUL_SEW(lmul, 64)
220
221 unsigned NewOpc;
222 switch (MI.getOpcode()) {
223 CASE_WHOLE_REGISTER_LMUL(1)
224 CASE_WHOLE_REGISTER_LMUL(2)
225 CASE_WHOLE_REGISTER_LMUL(4)
226 CASE_WHOLE_REGISTER_LMUL(8)
227 default:
228 return false;
229 }
230
231 MachineOperand &VLOp = MI.getOperand(i: RISCVII::getVLOpNum(Desc: MI.getDesc()));
232 if (!VLOp.isImm() || VLOp.getImm() != RISCV::VLMaxSentinel)
233 return false;
234
235 // Whole register instructions aren't pseudos so they don't have
236 // policy/SEW/AVL ops, and they don't have passthrus.
237 if (RISCVII::hasVecPolicyOp(TSFlags: MI.getDesc().TSFlags))
238 MI.removeOperand(OpNo: RISCVII::getVecPolicyOpNum(Desc: MI.getDesc()));
239 MI.removeOperand(OpNo: RISCVII::getSEWOpNum(Desc: MI.getDesc()));
240 MI.removeOperand(OpNo: RISCVII::getVLOpNum(Desc: MI.getDesc()));
241 if (RISCVII::isFirstDefTiedToFirstUse(Desc: MI.getDesc()))
242 MI.removeOperand(OpNo: 1);
243
244 MI.setDesc(TII->get(Opcode: NewOpc));
245
246 return true;
247}
248
249static unsigned getVMV_V_VOpcodeForVMERGE_VVM(const MachineInstr &MI) {
250#define CASE_VMERGE_TO_VMV(lmul) \
251 case RISCV::PseudoVMERGE_VVM_##lmul: \
252 return RISCV::PseudoVMV_V_V_##lmul;
253 switch (MI.getOpcode()) {
254 default:
255 return 0;
256 CASE_VMERGE_TO_VMV(MF8)
257 CASE_VMERGE_TO_VMV(MF4)
258 CASE_VMERGE_TO_VMV(MF2)
259 CASE_VMERGE_TO_VMV(M1)
260 CASE_VMERGE_TO_VMV(M2)
261 CASE_VMERGE_TO_VMV(M4)
262 CASE_VMERGE_TO_VMV(M8)
263 }
264}
265
266/// Convert a PseudoVMERGE_VVM with an all ones mask to a PseudoVMV_V_V.
267///
268/// %x = PseudoVMERGE_VVM %passthru, %false, %true, %allones, sew, vl
269/// ->
270/// %x = PseudoVMV_V_V %passthru, %true, vl, sew, tu_mu
271bool RISCVVectorPeephole::convertAllOnesVMergeToVMv(MachineInstr &MI) const {
272 unsigned NewOpc = getVMV_V_VOpcodeForVMERGE_VVM(MI);
273 if (!NewOpc)
274 return false;
275 if (!isAllOnesMask(MaskDef: MRI->getVRegDef(Reg: MI.getOperand(i: 4).getReg())))
276 return false;
277
278 MI.setDesc(TII->get(Opcode: NewOpc));
279 MI.removeOperand(OpNo: 2); // False operand
280 MI.removeOperand(OpNo: 3); // Mask operand
281 MI.addOperand(
282 Op: MachineOperand::CreateImm(Val: RISCVVType::TAIL_UNDISTURBED_MASK_UNDISTURBED));
283
284 // vmv.v.v doesn't have a mask operand, so we may be able to inflate the
285 // register class for the destination and passthru operands e.g. VRNoV0 -> VR
286 MRI->recomputeRegClass(Reg: MI.getOperand(i: 0).getReg());
287 if (MI.getOperand(i: 1).getReg().isValid())
288 MRI->recomputeRegClass(Reg: MI.getOperand(i: 1).getReg());
289 return true;
290}
291
292// If \p Reg is defined by one or more COPYs of virtual registers, traverses
293// the chain and returns the root non-COPY source.
294Register RISCVVectorPeephole::lookThruCopies(
295 Register Reg, bool OneUseOnly,
296 SmallVectorImpl<MachineInstr *> *Copies) const {
297 while (MachineInstr *Def = MRI->getUniqueVRegDef(Reg)) {
298 if (!Def->isFullCopy())
299 break;
300 Register Src = Def->getOperand(i: 1).getReg();
301 if (!Src.isVirtual())
302 break;
303 if (OneUseOnly && !MRI->hasOneNonDBGUse(RegNo: Reg))
304 break;
305 if (Copies)
306 Copies->push_back(Elt: Def);
307 Reg = Src;
308 }
309 return Reg;
310}
311
312/// If a PseudoVMERGE_VVM's true operand is a masked pseudo and both have the
313/// same mask, and the masked pseudo's passthru is the same as the false
314/// operand, we can convert the PseudoVMERGE_VVM to a PseudoVMV_V_V.
315///
316/// %true = PseudoVADD_VV_M1_MASK %false, %x, %y, %mask, vl1, sew, policy
317/// %x = PseudoVMERGE_VVM %passthru, %false, %true, %mask, vl2, sew
318/// ->
319/// %true = PseudoVADD_VV_M1_MASK %false, %x, %y, %mask, vl1, sew, policy
320/// %x = PseudoVMV_V_V %passthru, %true, vl2, sew, tu_mu
321bool RISCVVectorPeephole::convertSameMaskVMergeToVMv(MachineInstr &MI) {
322 unsigned NewOpc = getVMV_V_VOpcodeForVMERGE_VVM(MI);
323 if (!NewOpc)
324 return false;
325 MachineInstr *True = MRI->getVRegDef(Reg: MI.getOperand(i: 3).getReg());
326
327 if (!True || True->getParent() != MI.getParent())
328 return false;
329
330 auto *TrueMaskedInfo = RISCV::getMaskedPseudoInfo(MaskedPseudo: True->getOpcode());
331 if (!TrueMaskedInfo || !hasSameEEW(User: MI, Src: *True))
332 return false;
333
334 Register TrueMaskReg = lookThruCopies(
335 Reg: True->getOperand(i: TrueMaskedInfo->MaskOpIdx + True->getNumExplicitDefs())
336 .getReg());
337 Register MIMaskReg = lookThruCopies(Reg: MI.getOperand(i: 4).getReg());
338 if (!TrueMaskReg.isVirtual() || TrueMaskReg != MIMaskReg)
339 return false;
340
341 // Masked off lanes past TrueVL will come from False, and converting to vmv
342 // will lose these lanes unless MIVL <= TrueVL.
343 // We can relax this when False == Passthru and True's tail policy is TU,
344 // because True's tail lanes will preserve its passthru (= False = Passthru).
345 const MachineOperand &MIVL = MI.getOperand(i: RISCVII::getVLOpNum(Desc: MI.getDesc()));
346 const MachineOperand &TrueVL =
347 True->getOperand(i: RISCVII::getVLOpNum(Desc: True->getDesc()));
348 Register FalseReg = MI.getOperand(i: 2).getReg();
349 if (!RISCV::isVLKnownLE(LHS: MIVL, RHS: TrueVL)) {
350 Register PassthruReg = MI.getOperand(i: 1).getReg();
351 if (FalseReg.isValid() && FalseReg != PassthruReg)
352 return false;
353 if (!RISCVII::hasVecPolicyOp(TSFlags: True->getDesc().TSFlags))
354 return false;
355 uint64_t TruePolicy =
356 True->getOperand(i: RISCVII::getVecPolicyOpNum(Desc: True->getDesc())).getImm();
357 if (TruePolicy & RISCVVType::TAIL_AGNOSTIC)
358 return false;
359 }
360
361 // True's passthru needs to be equivalent to False
362 Register TruePassthruReg = True->getOperand(i: 1).getReg();
363 if (TruePassthruReg != FalseReg) {
364 // If True's passthru is undef see if we can change it to False
365 if (TruePassthruReg.isValid() ||
366 !MRI->hasOneUse(RegNo: MI.getOperand(i: 3).getReg()) ||
367 !ensureDominates(Defs: &MI.getOperand(i: 2), Use&: *True))
368 return false;
369 True->getOperand(i: 1).setReg(MI.getOperand(i: 2).getReg());
370 // If True is masked then its passthru needs to be in VRNoV0.
371 MRI->constrainRegClass(Reg: True->getOperand(i: 1).getReg(),
372 RC: TII->getRegClass(MCID: True->getDesc(), OpNum: 1));
373 }
374
375 // If True is mask agnostic, we need to make it mask undisturbed.
376 if (RISCVII::hasVecPolicyOp(TSFlags: True->getDesc().TSFlags)) {
377 MachineOperand &PolicyOp =
378 True->getOperand(i: RISCVII::getVecPolicyOpNum(Desc: True->getDesc()));
379 PolicyOp.setImm(PolicyOp.getImm() & ~RISCVVType::MASK_AGNOSTIC);
380 }
381
382 MI.setDesc(TII->get(Opcode: NewOpc));
383 MI.removeOperand(OpNo: 2); // False operand
384 MI.removeOperand(OpNo: 3); // Mask operand
385 MI.addOperand(
386 Op: MachineOperand::CreateImm(Val: RISCVVType::TAIL_UNDISTURBED_MASK_UNDISTURBED));
387
388 // vmv.v.v doesn't have a mask operand, so we may be able to inflate the
389 // register class for the destination and passthru operands e.g. VRNoV0 -> VR
390 MRI->recomputeRegClass(Reg: MI.getOperand(i: 0).getReg());
391 if (MI.getOperand(i: 1).getReg().isValid())
392 MRI->recomputeRegClass(Reg: MI.getOperand(i: 1).getReg());
393 return true;
394}
395
396bool RISCVVectorPeephole::convertToUnmasked(MachineInstr &MI) const {
397 const RISCV::RISCVMaskedPseudoInfo *I =
398 RISCV::getMaskedPseudoInfo(MaskedPseudo: MI.getOpcode());
399 if (!I)
400 return false;
401
402 if (!isAllOnesMask(MaskDef: MRI->getVRegDef(
403 Reg: MI.getOperand(i: I->MaskOpIdx + MI.getNumExplicitDefs()).getReg())))
404 return false;
405
406 // There are two classes of pseudos in the table - compares and
407 // everything else. See the comment on RISCVMaskedPseudo for details.
408 const unsigned Opc = I->UnmaskedPseudo;
409 const MCInstrDesc &MCID = TII->get(Opcode: Opc);
410 [[maybe_unused]] const bool HasPolicyOp =
411 RISCVII::hasVecPolicyOp(TSFlags: MCID.TSFlags);
412 const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(Desc: MCID);
413 const MCInstrDesc &MaskedMCID = TII->get(Opcode: MI.getOpcode());
414 assert((RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags) ||
415 !RISCVII::hasVecPolicyOp(MCID.TSFlags)) &&
416 "Unmasked pseudo has policy but masked pseudo doesn't?");
417 assert(HasPolicyOp == HasPassthru && "Unexpected pseudo structure");
418 assert(!(HasPassthru && !RISCVII::isFirstDefTiedToFirstUse(MaskedMCID)) &&
419 "Unmasked with passthru but masked with no passthru?");
420 (void)HasPolicyOp;
421
422 MI.setDesc(MCID);
423
424 // Drop the policy operand if unmasked doesn't need it.
425 if (RISCVII::hasVecPolicyOp(TSFlags: MaskedMCID.TSFlags) &&
426 !RISCVII::hasVecPolicyOp(TSFlags: MCID.TSFlags))
427 MI.removeOperand(OpNo: RISCVII::getVecPolicyOpNum(Desc: MaskedMCID));
428
429 // TODO: Increment all MaskOpIdxs in tablegen by num of explicit defs?
430 unsigned MaskOpIdx = I->MaskOpIdx + MI.getNumExplicitDefs();
431 MI.removeOperand(OpNo: MaskOpIdx);
432
433 // The unmasked pseudo will no longer be constrained to the vrnov0 reg class,
434 // so try and relax it to vr.
435 MRI->recomputeRegClass(Reg: MI.getOperand(i: 0).getReg());
436
437 // If the original masked pseudo had a passthru, relax it or remove it.
438 if (RISCVII::isFirstDefTiedToFirstUse(Desc: MaskedMCID)) {
439 unsigned PassthruOpIdx = MI.getNumExplicitDefs();
440 if (HasPassthru) {
441 if (MI.getOperand(i: PassthruOpIdx).getReg())
442 MRI->recomputeRegClass(Reg: MI.getOperand(i: PassthruOpIdx).getReg());
443 } else
444 MI.removeOperand(OpNo: PassthruOpIdx);
445 }
446
447 return true;
448}
449
450/// Given A and B are in the same MBB, returns true if A comes before B.
451static bool strictlyDominates(MachineBasicBlock::const_iterator A,
452 MachineBasicBlock::const_iterator B) {
453 assert(A->getParent() == B->getParent());
454 if (A == B)
455 return false;
456 const MachineBasicBlock *MBB = A->getParent();
457 auto MBBEnd = MBB->end();
458 if (B == MBBEnd)
459 return true;
460
461 MachineBasicBlock::const_iterator I = MBB->begin();
462 for (; &*I != A && &*I != B; ++I)
463 ;
464
465 return &*I == A;
466}
467
468/// If a register in \p Defs doesn't dominate \p Use, try to move Use so it
469/// does. Returns false if any def doesn't dominate and we can't move Use. Each
470/// def must be in the same block as Use.
471bool RISCVVectorPeephole::ensureDominates(ArrayRef<const MachineOperand *> Defs,
472 MachineInstr &Use) const {
473 MachineInstr *Dest = &Use;
474
475 for (const MachineOperand *MO : Defs) {
476 assert(MO->getParent()->getParent() == Use.getParent());
477 if (!MO->isReg() || !MO->getReg().isValid())
478 continue;
479
480 MachineInstr *Def = MRI->getVRegDef(Reg: MO->getReg());
481 if (Def->getParent() == Dest->getParent() &&
482 !strictlyDominates(A: Def, B: *Dest)) {
483 if (!RISCVInstrInfo::isSafeToMove(From: *Dest, To: *Def->getNextNode()))
484 return false;
485 Dest = Def->getNextNode();
486 }
487 }
488
489 if (Dest != &Use)
490 Use.moveBefore(MovePos: Dest);
491
492 return true;
493}
494
495/// If a PseudoVMV_V_V's passthru is undef then we can replace it with its input
496bool RISCVVectorPeephole::foldUndefPassthruVMV_V_V(MachineInstr &MI) {
497 if (RISCV::getRVVMCOpcode(RVVPseudoOpcode: MI.getOpcode()) != RISCV::VMV_V_V)
498 return false;
499 if (MI.getOperand(i: 1).getReg().isValid())
500 return false;
501
502 // If the input was a pseudo with a policy operand, we can give it a tail
503 // agnostic policy if MI's undef tail subsumes the input's.
504 MachineInstr *Src = MRI->getVRegDef(Reg: MI.getOperand(i: 2).getReg());
505 if (Src && !Src->hasUnmodeledSideEffects() &&
506 MRI->hasOneUse(RegNo: MI.getOperand(i: 2).getReg()) &&
507 RISCVII::hasVLOp(TSFlags: Src->getDesc().TSFlags) &&
508 RISCVII::hasVecPolicyOp(TSFlags: Src->getDesc().TSFlags) && hasSameEEW(User: MI, Src: *Src)) {
509 const MachineOperand &MIVL = MI.getOperand(i: 3);
510 const MachineOperand &SrcVL =
511 Src->getOperand(i: RISCVII::getVLOpNum(Desc: Src->getDesc()));
512
513 MachineOperand &SrcPolicy =
514 Src->getOperand(i: RISCVII::getVecPolicyOpNum(Desc: Src->getDesc()));
515
516 if (RISCV::isVLKnownLE(LHS: MIVL, RHS: SrcVL))
517 SrcPolicy.setImm(SrcPolicy.getImm() | RISCVVType::TAIL_AGNOSTIC);
518 }
519
520 MRI->constrainRegClass(Reg: MI.getOperand(i: 2).getReg(),
521 RC: MRI->getRegClass(Reg: MI.getOperand(i: 0).getReg()));
522 MRI->replaceRegWith(FromReg: MI.getOperand(i: 0).getReg(), ToReg: MI.getOperand(i: 2).getReg());
523 MRI->clearKillFlags(Reg: MI.getOperand(i: 2).getReg());
524 MI.eraseFromParent();
525 return true;
526}
527
528/// If a PseudoVMV_V_V is the only user of its input, fold its passthru and VL
529/// into it.
530///
531/// %x = PseudoVADD_V_V_M1 %passthru, %a, %b, %vl1, sew, policy
532/// %y = PseudoVMV_V_V_M1 %passthru, %x, %vl2, sew, policy
533/// (where %vl1 <= %vl2)
534///
535/// ->
536///
537/// %y = PseudoVADD_V_V_M1 %passthru, %a, %b, vl1, sew, policy
538bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) {
539 if (RISCV::getRVVMCOpcode(RVVPseudoOpcode: MI.getOpcode()) != RISCV::VMV_V_V)
540 return false;
541
542 MachineOperand &Passthru = MI.getOperand(i: 1);
543
544 if (!MRI->hasOneUse(RegNo: MI.getOperand(i: 2).getReg()))
545 return false;
546
547 MachineInstr *Src = MRI->getVRegDef(Reg: MI.getOperand(i: 2).getReg());
548 if (!Src || Src->hasUnmodeledSideEffects() ||
549 Src->getParent() != MI.getParent() ||
550 !RISCVII::isFirstDefTiedToFirstUse(Desc: Src->getDesc()) ||
551 !RISCVII::hasVLOp(TSFlags: Src->getDesc().TSFlags))
552 return false;
553
554 // Src's dest needs to have the same EEW as MI's input.
555 if (!hasSameEEW(User: MI, Src: *Src))
556 return false;
557
558 std::optional<std::pair<unsigned, unsigned>> NeedsCommute;
559
560 // Src needs to have the same passthru as VMV_V_V
561 MachineOperand &SrcPassthru = Src->getOperand(i: Src->getNumExplicitDefs());
562 if (SrcPassthru.getReg().isValid() &&
563 SrcPassthru.getReg() != Passthru.getReg()) {
564 // If Src's passthru != Passthru, check if it uses Passthru in another
565 // operand and try to commute it.
566 int OtherIdx = Src->findRegisterUseOperandIdx(Reg: Passthru.getReg(), TRI);
567 if (OtherIdx == -1)
568 return false;
569 unsigned OpIdx1 = OtherIdx;
570 unsigned OpIdx2 = Src->getNumExplicitDefs();
571 if (!TII->findCommutedOpIndices(MI: *Src, SrcOpIdx1&: OpIdx1, SrcOpIdx2&: OpIdx2))
572 return false;
573 NeedsCommute = {OpIdx1, OpIdx2};
574 }
575
576 // Src VL will have already been reduced if legal by RISCVVLOptimizer,
577 // so we don't need to handle a smaller source VL here. However, the
578 // user's VL may be larger
579 MachineOperand &SrcVL = Src->getOperand(i: RISCVII::getVLOpNum(Desc: Src->getDesc()));
580 if (!RISCV::isVLKnownLE(LHS: SrcVL, RHS: MI.getOperand(i: 3)))
581 return false;
582
583 // If the new passthru doesn't dominate Src, try to move Src so it does.
584 if (!ensureDominates(Defs: &Passthru, Use&: *Src))
585 return false;
586
587 if (NeedsCommute) {
588 auto [OpIdx1, OpIdx2] = *NeedsCommute;
589 [[maybe_unused]] bool Commuted =
590 TII->commuteInstruction(MI&: *Src, /*NewMI=*/false, OpIdx1, OpIdx2);
591 assert(Commuted && "Failed to commute Src?");
592 }
593
594 if (SrcPassthru.getReg() != Passthru.getReg()) {
595 SrcPassthru.setReg(Passthru.getReg());
596 // If Src is masked then its passthru needs to be in VRNoV0.
597 if (Passthru.getReg().isValid())
598 MRI->constrainRegClass(
599 Reg: Passthru.getReg(),
600 RC: TII->getRegClass(MCID: Src->getDesc(), OpNum: SrcPassthru.getOperandNo()));
601 }
602
603 if (RISCVII::hasVecPolicyOp(TSFlags: Src->getDesc().TSFlags)) {
604 // If MI was tail agnostic and the VL didn't increase, preserve it.
605 int64_t Policy = RISCVVType::TAIL_UNDISTURBED_MASK_UNDISTURBED;
606 if ((MI.getOperand(i: 5).getImm() & RISCVVType::TAIL_AGNOSTIC) &&
607 RISCV::isVLKnownLE(LHS: MI.getOperand(i: 3), RHS: SrcVL))
608 Policy |= RISCVVType::TAIL_AGNOSTIC;
609 Src->getOperand(i: RISCVII::getVecPolicyOpNum(Desc: Src->getDesc())).setImm(Policy);
610 }
611
612 MRI->constrainRegClass(Reg: Src->getOperand(i: 0).getReg(),
613 RC: MRI->getRegClass(Reg: MI.getOperand(i: 0).getReg()));
614 MRI->replaceRegWith(FromReg: MI.getOperand(i: 0).getReg(), ToReg: Src->getOperand(i: 0).getReg());
615 MI.eraseFromParent();
616
617 return true;
618}
619
620/// Try to fold away VMERGE_VVM instructions into their operands:
621///
622/// %true = PseudoVADD_VV ...
623/// %x = PseudoVMERGE_VVM_M1 %false, %false, %true, %mask
624/// ->
625/// %x = PseudoVADD_VV_M1_MASK %false, ..., %mask
626///
627/// We can only fold if vmerge's passthru operand, vmerge's false operand and
628/// %true's passthru operand (if it has one) are the same. This is because we
629/// have to consolidate them into one passthru operand in the result.
630///
631/// If %true is masked, then we can use its mask instead of vmerge's if vmerge's
632/// mask is all ones.
633///
634/// The resulting VL is the minimum of the two VLs.
635///
636/// The resulting policy is the effective policy the vmerge would have had,
637/// i.e. whether or not it's passthru operand was implicit-def.
638bool RISCVVectorPeephole::foldVMergeToMask(MachineInstr &MI) const {
639 if (RISCV::getRVVMCOpcode(RVVPseudoOpcode: MI.getOpcode()) != RISCV::VMERGE_VVM)
640 return false;
641
642 // Collect chain of COPYs on True's result for later cleanup.
643 SmallVector<MachineInstr *, 4> TrueCopies;
644 Register PassthruReg = lookThruCopies(Reg: MI.getOperand(i: 1).getReg());
645 const MachineOperand &FalseOp = MI.getOperand(i: 2);
646 Register FalseReg = lookThruCopies(Reg: FalseOp.getReg());
647 Register TrueReg = lookThruCopies(Reg: MI.getOperand(i: 3).getReg(),
648 /*OneUseOnly=*/true, Copies: &TrueCopies);
649 if (!TrueReg.isVirtual() || !MRI->hasOneUse(RegNo: TrueReg))
650 return false;
651 MachineInstr &True = *MRI->getUniqueVRegDef(Reg: TrueReg);
652 if (True.getParent() != MI.getParent())
653 return false;
654 const MachineOperand &MaskOp = MI.getOperand(i: 4);
655 MachineInstr *Mask = MRI->getUniqueVRegDef(Reg: MaskOp.getReg());
656 assert(Mask);
657
658 const RISCV::RISCVMaskedPseudoInfo *Info =
659 RISCV::lookupMaskedIntrinsicByUnmasked(UnmaskedPseudo: True.getOpcode());
660 if (!Info)
661 return false;
662
663 // If the EEW of True is different from vmerge's SEW, then we can't fold.
664 if (!hasSameEEW(User: MI, Src: True))
665 return false;
666
667 // We require that either passthru and false are the same, or that passthru
668 // is undefined.
669 if (PassthruReg && !(PassthruReg.isVirtual() && PassthruReg == FalseReg))
670 return false;
671
672 std::optional<std::pair<unsigned, unsigned>> NeedsCommute;
673
674 // If True has a passthru operand then it needs to be the same as vmerge's
675 // False, since False will be used for the result's passthru operand.
676 Register TruePassthru;
677 if (RISCVII::isFirstDefTiedToFirstUse(Desc: True.getDesc()))
678 TruePassthru =
679 lookThruCopies(Reg: True.getOperand(i: True.getNumExplicitDefs()).getReg());
680 if (TruePassthru && !(TruePassthru.isVirtual() && TruePassthru == FalseReg)) {
681 // If True's passthru != False, check if it uses False in another operand
682 // and try to commute it.
683 int OtherIdx = True.findRegisterUseOperandIdx(Reg: FalseReg, TRI);
684 if (OtherIdx == -1)
685 return false;
686 unsigned OpIdx1 = OtherIdx;
687 unsigned OpIdx2 = True.getNumExplicitDefs();
688 if (!TII->findCommutedOpIndices(MI: True, SrcOpIdx1&: OpIdx1, SrcOpIdx2&: OpIdx2))
689 return false;
690 NeedsCommute = {OpIdx1, OpIdx2};
691 }
692
693 // Make sure it doesn't raise any observable fp exceptions, since changing the
694 // active elements will affect how fflags is set.
695 if (True.hasUnmodeledSideEffects() || True.mayRaiseFPException())
696 return false;
697
698 const MachineOperand &VMergeVL =
699 MI.getOperand(i: RISCVII::getVLOpNum(Desc: MI.getDesc()));
700 const MachineOperand &TrueVL =
701 True.getOperand(i: RISCVII::getVLOpNum(Desc: True.getDesc()));
702
703 MachineOperand MinVL = MachineOperand::CreateImm(Val: 0);
704 if (RISCV::isVLKnownLE(LHS: TrueVL, RHS: VMergeVL))
705 MinVL = TrueVL;
706 else if (RISCV::isVLKnownLE(LHS: VMergeVL, RHS: TrueVL))
707 MinVL = VMergeVL;
708 else if (!TruePassthru && !True.mayLoadOrStore())
709 // If True's passthru is undef, we can use vmerge's vl.
710 MinVL = VMergeVL;
711 else
712 return false;
713
714 unsigned RVVTSFlags =
715 TII->get(Opcode: RISCV::getRVVMCOpcode(RVVPseudoOpcode: True.getOpcode())).TSFlags;
716 if (RISCVII::elementsDependOnVL(TSFlags: RVVTSFlags) && !TrueVL.isIdenticalTo(Other: MinVL))
717 return false;
718 if (RISCVII::elementsDependOnMask(TSFlags: RVVTSFlags) && !isAllOnesMask(MaskDef: Mask))
719 return false;
720
721 // Use a tumu policy, relaxing it to tail agnostic provided that the passthru
722 // operand is undefined.
723 //
724 // However, if the VL became smaller than what the vmerge had originally, then
725 // elements past VL that were previously in the vmerge's body will have moved
726 // to the tail. In that case we always need to use tail undisturbed to
727 // preserve them.
728 uint64_t Policy = RISCVVType::TAIL_UNDISTURBED_MASK_UNDISTURBED;
729 if (!PassthruReg && RISCV::isVLKnownLE(LHS: VMergeVL, RHS: MinVL))
730 Policy |= RISCVVType::TAIL_AGNOSTIC;
731
732 assert(RISCVII::hasVecPolicyOp(True.getDesc().TSFlags) &&
733 "Foldable unmasked pseudo should have a policy op already");
734
735 // Make sure Mask, False and MinVL dominate True and its copies, otherwise
736 // move down True so it does.
737 if (!ensureDominates(Defs: {&MaskOp, &FalseOp, &MinVL}, Use&: True))
738 return false;
739
740 if (NeedsCommute) {
741 auto [OpIdx1, OpIdx2] = *NeedsCommute;
742 [[maybe_unused]] bool Commuted =
743 TII->commuteInstruction(MI&: True, /*NewMI=*/false, OpIdx1, OpIdx2);
744 assert(Commuted && "Failed to commute True?");
745 Info = RISCV::lookupMaskedIntrinsicByUnmasked(UnmaskedPseudo: True.getOpcode());
746 }
747
748 True.setDesc(TII->get(Opcode: Info->MaskedPseudo));
749
750 // Insert the mask operand.
751 // TODO: Increment MaskOpIdx by number of explicit defs?
752 True.insert(InsertBefore: True.operands_begin() + Info->MaskOpIdx +
753 True.getNumExplicitDefs(),
754 Ops: MachineOperand::CreateReg(Reg: MaskOp.getReg(), isDef: false));
755
756 // Update the passthru, AVL and policy.
757 True.getOperand(i: True.getNumExplicitDefs()).setReg(FalseReg);
758 True.removeOperand(OpNo: RISCVII::getVLOpNum(Desc: True.getDesc()));
759 True.insert(InsertBefore: True.operands_begin() + RISCVII::getVLOpNum(Desc: True.getDesc()),
760 Ops: MinVL);
761 True.getOperand(i: RISCVII::getVecPolicyOpNum(Desc: True.getDesc())).setImm(Policy);
762
763 MRI->replaceRegWith(FromReg: True.getOperand(i: 0).getReg(), ToReg: MI.getOperand(i: 0).getReg());
764 // Now that True is masked, constrain its operands from vr -> vrnov0.
765 for (MachineOperand &MO : True.explicit_operands()) {
766 if (!MO.isReg() || !MO.getReg().isVirtual())
767 continue;
768 MRI->constrainRegClass(
769 Reg: MO.getReg(), RC: True.getRegClassConstraint(OpIdx: MO.getOperandNo(), TII, TRI));
770 }
771 // We should clear the IsKill flag since we have a new use now.
772 MRI->clearKillFlags(Reg: FalseReg);
773 MI.eraseFromParent();
774
775 // Cleanup all the COPYs on True's value. We have to manually do this because
776 // sometimes sinking True causes these COPY to be invalid (use before define).
777 for (MachineInstr *TrueCopy : TrueCopies)
778 TrueCopy->eraseFromParent();
779
780 return true;
781}
782
783bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) {
784 if (skipFunction(F: MF.getFunction()))
785 return false;
786
787 // Skip if the vector extension is not enabled.
788 ST = &MF.getSubtarget<RISCVSubtarget>();
789 if (!ST->hasVInstructions())
790 return false;
791
792 TII = ST->getInstrInfo();
793 MRI = &MF.getRegInfo();
794 TRI = MRI->getTargetRegisterInfo();
795
796 bool Changed = false;
797
798 for (MachineBasicBlock &MBB : MF) {
799 for (MachineInstr &MI : make_early_inc_range(Range&: MBB))
800 Changed |= foldVMergeToMask(MI);
801
802 for (MachineInstr &MI : make_early_inc_range(Range&: MBB)) {
803 Changed |= convertToVLMAX(MI);
804 Changed |= convertToUnmasked(MI);
805 Changed |= convertToWholeRegister(MI);
806 Changed |= convertAllOnesVMergeToVMv(MI);
807 Changed |= convertSameMaskVMergeToVMv(MI);
808 if (foldUndefPassthruVMV_V_V(MI)) {
809 Changed |= true;
810 continue; // MI is erased
811 }
812 Changed |= foldVMV_V_V(MI);
813 }
814 }
815
816 return Changed;
817}
818
819FunctionPass *llvm::createRISCVVectorPeepholePass() {
820 return new RISCVVectorPeephole();
821}
822