1//===- RISCVVectorPeephole.cpp - MI Vector Pseudo Peepholes ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass performs various vector pseudo peephole optimisations after
10// instruction selection.
11//
12// Currently it converts vmerge.vvm to vmv.v.v
13// PseudoVMERGE_VVM %false, %false, %true, %allonesmask, %vl, %sew
14// ->
15// PseudoVMV_V_V %false, %true, %vl, %sew
16//
17// And masked pseudos to unmasked pseudos
18// PseudoVADD_V_V_MASK %passthru, %a, %b, %allonesmask, %vl, sew, policy
19// ->
// PseudoVADD_V_V %passthru, %a, %b, %vl, sew, policy
21//
22// It also converts AVLs to VLMAX where possible
23// %vl = VLENB * something
24// PseudoVADD_V_V %passthru, %a, %b, %vl, sew, policy
25// ->
26// PseudoVADD_V_V %passthru, %a, %b, -1, sew, policy
27//
28//===----------------------------------------------------------------------===//
29
30#include "RISCV.h"
31#include "RISCVSubtarget.h"
32#include "llvm/CodeGen/MachineFunctionPass.h"
33#include "llvm/CodeGen/MachineRegisterInfo.h"
34#include "llvm/CodeGen/TargetInstrInfo.h"
35#include "llvm/CodeGen/TargetRegisterInfo.h"
36
37using namespace llvm;
38
39#define DEBUG_TYPE "riscv-vector-peephole"
40
41namespace {
42
43class RISCVVectorPeephole : public MachineFunctionPass {
44public:
45 static char ID;
46 const TargetInstrInfo *TII;
47 MachineRegisterInfo *MRI;
48 const TargetRegisterInfo *TRI;
49 const RISCVSubtarget *ST;
50 RISCVVectorPeephole() : MachineFunctionPass(ID) {}
51
52 bool runOnMachineFunction(MachineFunction &MF) override;
53 MachineFunctionProperties getRequiredProperties() const override {
54 return MachineFunctionProperties().setIsSSA();
55 }
56
57 StringRef getPassName() const override {
58 return "RISC-V Vector Peephole Optimization";
59 }
60
61private:
62 bool convertToVLMAX(MachineInstr &MI) const;
63 bool convertToWholeRegister(MachineInstr &MI) const;
64 bool convertToUnmasked(MachineInstr &MI) const;
65 bool convertAllOnesVMergeToVMv(MachineInstr &MI) const;
66 bool convertSameMaskVMergeToVMv(MachineInstr &MI);
67 bool foldUndefPassthruVMV_V_V(MachineInstr &MI);
68 bool foldVMV_V_V(MachineInstr &MI);
69 bool foldVMergeToMask(MachineInstr &MI) const;
70
71 bool hasSameEEW(const MachineInstr &User, const MachineInstr &Src) const;
72 bool isAllOnesMask(const MachineInstr *MaskDef) const;
73 std::optional<unsigned> getConstant(const MachineOperand &VL) const;
74 bool ensureDominates(const MachineOperand &Use, MachineInstr &Src) const;
75 Register
76 lookThruCopies(Register Reg, bool OneUseOnly = false,
77 SmallVectorImpl<MachineInstr *> *Copies = nullptr) const;
78};
79
80} // namespace
81
82char RISCVVectorPeephole::ID = 0;
83
84INITIALIZE_PASS(RISCVVectorPeephole, DEBUG_TYPE, "RISC-V Fold Masks", false,
85 false)
86
87/// Given \p User that has an input operand with EEW=SEW, which uses the dest
88/// operand of \p Src with an unknown EEW, return true if their EEWs match.
89bool RISCVVectorPeephole::hasSameEEW(const MachineInstr &User,
90 const MachineInstr &Src) const {
91 unsigned UserLog2SEW =
92 User.getOperand(i: RISCVII::getSEWOpNum(Desc: User.getDesc())).getImm();
93 unsigned SrcLog2SEW =
94 Src.getOperand(i: RISCVII::getSEWOpNum(Desc: Src.getDesc())).getImm();
95 unsigned SrcLog2EEW = RISCV::getDestLog2EEW(
96 Desc: TII->get(Opcode: RISCV::getRVVMCOpcode(RVVPseudoOpcode: Src.getOpcode())), Log2SEW: SrcLog2SEW);
97 return SrcLog2EEW == UserLog2SEW;
98}
99
100/// Check if an operand is an immediate or a materialized ADDI $x0, imm.
101std::optional<unsigned>
102RISCVVectorPeephole::getConstant(const MachineOperand &VL) const {
103 if (VL.isImm())
104 return VL.getImm();
105
106 MachineInstr *Def = MRI->getVRegDef(Reg: VL.getReg());
107 if (!Def || Def->getOpcode() != RISCV::ADDI ||
108 Def->getOperand(i: 1).getReg() != RISCV::X0)
109 return std::nullopt;
110 return Def->getOperand(i: 2).getImm();
111}
112
113/// Convert AVLs that are known to be VLMAX to the VLMAX sentinel.
114bool RISCVVectorPeephole::convertToVLMAX(MachineInstr &MI) const {
115 if (!RISCVII::hasVLOp(TSFlags: MI.getDesc().TSFlags) ||
116 !RISCVII::hasSEWOp(TSFlags: MI.getDesc().TSFlags))
117 return false;
118
119 auto LMUL = RISCVVType::decodeVLMUL(VLMul: RISCVII::getLMul(TSFlags: MI.getDesc().TSFlags));
120 // Fixed-point value, denominator=8
121 unsigned LMULFixed = LMUL.second ? (8 / LMUL.first) : 8 * LMUL.first;
122 unsigned Log2SEW = MI.getOperand(i: RISCVII::getSEWOpNum(Desc: MI.getDesc())).getImm();
123 // A Log2SEW of 0 is an operation on mask registers only
124 unsigned SEW = Log2SEW ? 1 << Log2SEW : 8;
125 assert(RISCVVType::isValidSEW(SEW) && "Unexpected SEW");
126 assert(8 * LMULFixed / SEW > 0);
127
128 // If the exact VLEN is known then we know VLMAX, check if the AVL == VLMAX.
129 MachineOperand &VL = MI.getOperand(i: RISCVII::getVLOpNum(Desc: MI.getDesc()));
130 if (auto VLen = ST->getRealVLen(), AVL = getConstant(VL);
131 VLen && AVL && (*VLen * LMULFixed) / SEW == *AVL * 8) {
132 VL.ChangeToImmediate(ImmVal: RISCV::VLMaxSentinel);
133 return true;
134 }
135
136 // If an AVL is a VLENB that's possibly scaled to be equal to VLMAX, convert
137 // it to the VLMAX sentinel value.
138 if (!VL.isReg())
139 return false;
140 MachineInstr *Def = MRI->getVRegDef(Reg: VL.getReg());
141 if (!Def)
142 return false;
143
144 // Fixed-point value, denominator=8
145 uint64_t ScaleFixed = 8;
146 // Check if the VLENB was potentially scaled with slli/srli
147 if (Def->getOpcode() == RISCV::SLLI) {
148 assert(Def->getOperand(2).getImm() < 64);
149 ScaleFixed <<= Def->getOperand(i: 2).getImm();
150 Def = MRI->getVRegDef(Reg: Def->getOperand(i: 1).getReg());
151 } else if (Def->getOpcode() == RISCV::SRLI) {
152 assert(Def->getOperand(2).getImm() < 64);
153 ScaleFixed >>= Def->getOperand(i: 2).getImm();
154 Def = MRI->getVRegDef(Reg: Def->getOperand(i: 1).getReg());
155 }
156
157 if (!Def || Def->getOpcode() != RISCV::PseudoReadVLENB)
158 return false;
159
160 // AVL = (VLENB * Scale)
161 //
162 // VLMAX = (VLENB * 8 * LMUL) / SEW
163 //
164 // AVL == VLMAX
165 // -> VLENB * Scale == (VLENB * 8 * LMUL) / SEW
166 // -> Scale == (8 * LMUL) / SEW
167 if (ScaleFixed != 8 * LMULFixed / SEW)
168 return false;
169
170 VL.ChangeToImmediate(ImmVal: RISCV::VLMaxSentinel);
171
172 return true;
173}
174
175bool RISCVVectorPeephole::isAllOnesMask(const MachineInstr *MaskDef) const {
176 while (MaskDef->isCopy() && MaskDef->getOperand(i: 1).getReg().isVirtual())
177 MaskDef = MRI->getVRegDef(Reg: MaskDef->getOperand(i: 1).getReg());
178
179 // TODO: Check that the VMSET is the expected bitwidth? The pseudo has
180 // undefined behaviour if it's the wrong bitwidth, so we could choose to
181 // assume that it's all-ones? Same applies to its VL.
182 switch (MaskDef->getOpcode()) {
183 case RISCV::PseudoVMSET_M_B1:
184 case RISCV::PseudoVMSET_M_B2:
185 case RISCV::PseudoVMSET_M_B4:
186 case RISCV::PseudoVMSET_M_B8:
187 case RISCV::PseudoVMSET_M_B16:
188 case RISCV::PseudoVMSET_M_B32:
189 case RISCV::PseudoVMSET_M_B64:
190 return true;
191 default:
192 return false;
193 }
194}
195
196/// Convert unit strided unmasked loads and stores to whole-register equivalents
197/// to avoid the dependency on $vl and $vtype.
198///
199/// %x = PseudoVLE8_V_M1 %passthru, %ptr, %vlmax, policy
200/// PseudoVSE8_V_M1 %v, %ptr, %vlmax
201///
202/// ->
203///
204/// %x = VL1RE8_V %ptr
205/// VS1R_V %v, %ptr
206bool RISCVVectorPeephole::convertToWholeRegister(MachineInstr &MI) const {
207#define CASE_WHOLE_REGISTER_LMUL_SEW(lmul, sew) \
208 case RISCV::PseudoVLE##sew##_V_M##lmul: \
209 NewOpc = RISCV::VL##lmul##RE##sew##_V; \
210 break; \
211 case RISCV::PseudoVSE##sew##_V_M##lmul: \
212 NewOpc = RISCV::VS##lmul##R_V; \
213 break;
214#define CASE_WHOLE_REGISTER_LMUL(lmul) \
215 CASE_WHOLE_REGISTER_LMUL_SEW(lmul, 8) \
216 CASE_WHOLE_REGISTER_LMUL_SEW(lmul, 16) \
217 CASE_WHOLE_REGISTER_LMUL_SEW(lmul, 32) \
218 CASE_WHOLE_REGISTER_LMUL_SEW(lmul, 64)
219
220 unsigned NewOpc;
221 switch (MI.getOpcode()) {
222 CASE_WHOLE_REGISTER_LMUL(1)
223 CASE_WHOLE_REGISTER_LMUL(2)
224 CASE_WHOLE_REGISTER_LMUL(4)
225 CASE_WHOLE_REGISTER_LMUL(8)
226 default:
227 return false;
228 }
229
230 MachineOperand &VLOp = MI.getOperand(i: RISCVII::getVLOpNum(Desc: MI.getDesc()));
231 if (!VLOp.isImm() || VLOp.getImm() != RISCV::VLMaxSentinel)
232 return false;
233
234 // Whole register instructions aren't pseudos so they don't have
235 // policy/SEW/AVL ops, and they don't have passthrus.
236 if (RISCVII::hasVecPolicyOp(TSFlags: MI.getDesc().TSFlags))
237 MI.removeOperand(OpNo: RISCVII::getVecPolicyOpNum(Desc: MI.getDesc()));
238 MI.removeOperand(OpNo: RISCVII::getSEWOpNum(Desc: MI.getDesc()));
239 MI.removeOperand(OpNo: RISCVII::getVLOpNum(Desc: MI.getDesc()));
240 if (RISCVII::isFirstDefTiedToFirstUse(Desc: MI.getDesc()))
241 MI.removeOperand(OpNo: 1);
242
243 MI.setDesc(TII->get(Opcode: NewOpc));
244
245 return true;
246}
247
248static unsigned getVMV_V_VOpcodeForVMERGE_VVM(const MachineInstr &MI) {
249#define CASE_VMERGE_TO_VMV(lmul) \
250 case RISCV::PseudoVMERGE_VVM_##lmul: \
251 return RISCV::PseudoVMV_V_V_##lmul;
252 switch (MI.getOpcode()) {
253 default:
254 return 0;
255 CASE_VMERGE_TO_VMV(MF8)
256 CASE_VMERGE_TO_VMV(MF4)
257 CASE_VMERGE_TO_VMV(MF2)
258 CASE_VMERGE_TO_VMV(M1)
259 CASE_VMERGE_TO_VMV(M2)
260 CASE_VMERGE_TO_VMV(M4)
261 CASE_VMERGE_TO_VMV(M8)
262 }
263}
264
265/// Convert a PseudoVMERGE_VVM with an all ones mask to a PseudoVMV_V_V.
266///
267/// %x = PseudoVMERGE_VVM %passthru, %false, %true, %allones, sew, vl
268/// ->
269/// %x = PseudoVMV_V_V %passthru, %true, vl, sew, tu_mu
270bool RISCVVectorPeephole::convertAllOnesVMergeToVMv(MachineInstr &MI) const {
271 unsigned NewOpc = getVMV_V_VOpcodeForVMERGE_VVM(MI);
272 if (!NewOpc)
273 return false;
274 if (!isAllOnesMask(MaskDef: MRI->getVRegDef(Reg: MI.getOperand(i: 4).getReg())))
275 return false;
276
277 MI.setDesc(TII->get(Opcode: NewOpc));
278 MI.removeOperand(OpNo: 2); // False operand
279 MI.removeOperand(OpNo: 3); // Mask operand
280 MI.addOperand(
281 Op: MachineOperand::CreateImm(Val: RISCVVType::TAIL_UNDISTURBED_MASK_UNDISTURBED));
282
283 // vmv.v.v doesn't have a mask operand, so we may be able to inflate the
284 // register class for the destination and passthru operands e.g. VRNoV0 -> VR
285 MRI->recomputeRegClass(Reg: MI.getOperand(i: 0).getReg());
286 if (MI.getOperand(i: 1).getReg().isValid())
287 MRI->recomputeRegClass(Reg: MI.getOperand(i: 1).getReg());
288 return true;
289}
290
291// If \p Reg is defined by one or more COPYs of virtual registers, traverses
292// the chain and returns the root non-COPY source.
293Register RISCVVectorPeephole::lookThruCopies(
294 Register Reg, bool OneUseOnly,
295 SmallVectorImpl<MachineInstr *> *Copies) const {
296 while (MachineInstr *Def = MRI->getUniqueVRegDef(Reg)) {
297 if (!Def->isFullCopy())
298 break;
299 Register Src = Def->getOperand(i: 1).getReg();
300 if (!Src.isVirtual())
301 break;
302 if (OneUseOnly && !MRI->hasOneNonDBGUse(RegNo: Reg))
303 break;
304 if (Copies)
305 Copies->push_back(Elt: Def);
306 Reg = Src;
307 }
308 return Reg;
309}
310
311/// If a PseudoVMERGE_VVM's true operand is a masked pseudo and both have the
312/// same mask, and the masked pseudo's passthru is the same as the false
313/// operand, we can convert the PseudoVMERGE_VVM to a PseudoVMV_V_V.
314///
315/// %true = PseudoVADD_VV_M1_MASK %false, %x, %y, %mask, vl1, sew, policy
316/// %x = PseudoVMERGE_VVM %passthru, %false, %true, %mask, vl2, sew
317/// ->
318/// %true = PseudoVADD_VV_M1_MASK %false, %x, %y, %mask, vl1, sew, policy
319/// %x = PseudoVMV_V_V %passthru, %true, vl2, sew, tu_mu
320bool RISCVVectorPeephole::convertSameMaskVMergeToVMv(MachineInstr &MI) {
321 unsigned NewOpc = getVMV_V_VOpcodeForVMERGE_VVM(MI);
322 if (!NewOpc)
323 return false;
324 MachineInstr *True = MRI->getVRegDef(Reg: MI.getOperand(i: 3).getReg());
325
326 if (!True || True->getParent() != MI.getParent())
327 return false;
328
329 auto *TrueMaskedInfo = RISCV::getMaskedPseudoInfo(MaskedPseudo: True->getOpcode());
330 if (!TrueMaskedInfo || !hasSameEEW(User: MI, Src: *True))
331 return false;
332
333 Register TrueMaskReg = lookThruCopies(
334 Reg: True->getOperand(i: TrueMaskedInfo->MaskOpIdx + True->getNumExplicitDefs())
335 .getReg());
336 Register MIMaskReg = lookThruCopies(Reg: MI.getOperand(i: 4).getReg());
337 if (!TrueMaskReg.isVirtual() || TrueMaskReg != MIMaskReg)
338 return false;
339
340 // Masked off lanes past TrueVL will come from False, and converting to vmv
341 // will lose these lanes unless MIVL <= TrueVL.
342 // TODO: We could relax this for False == Passthru and True policy == TU
343 const MachineOperand &MIVL = MI.getOperand(i: RISCVII::getVLOpNum(Desc: MI.getDesc()));
344 const MachineOperand &TrueVL =
345 True->getOperand(i: RISCVII::getVLOpNum(Desc: True->getDesc()));
346 if (!RISCV::isVLKnownLE(LHS: MIVL, RHS: TrueVL))
347 return false;
348
349 // True's passthru needs to be equivalent to False
350 Register TruePassthruReg = True->getOperand(i: 1).getReg();
351 Register FalseReg = MI.getOperand(i: 2).getReg();
352 if (TruePassthruReg != FalseReg) {
353 // If True's passthru is undef see if we can change it to False
354 if (TruePassthruReg.isValid() ||
355 !MRI->hasOneUse(RegNo: MI.getOperand(i: 3).getReg()) ||
356 !ensureDominates(Use: MI.getOperand(i: 2), Src&: *True))
357 return false;
358 True->getOperand(i: 1).setReg(MI.getOperand(i: 2).getReg());
359 // If True is masked then its passthru needs to be in VRNoV0.
360 MRI->constrainRegClass(Reg: True->getOperand(i: 1).getReg(),
361 RC: TII->getRegClass(MCID: True->getDesc(), OpNum: 1));
362 }
363
364 MI.setDesc(TII->get(Opcode: NewOpc));
365 MI.removeOperand(OpNo: 2); // False operand
366 MI.removeOperand(OpNo: 3); // Mask operand
367 MI.addOperand(
368 Op: MachineOperand::CreateImm(Val: RISCVVType::TAIL_UNDISTURBED_MASK_UNDISTURBED));
369
370 // vmv.v.v doesn't have a mask operand, so we may be able to inflate the
371 // register class for the destination and passthru operands e.g. VRNoV0 -> VR
372 MRI->recomputeRegClass(Reg: MI.getOperand(i: 0).getReg());
373 if (MI.getOperand(i: 1).getReg().isValid())
374 MRI->recomputeRegClass(Reg: MI.getOperand(i: 1).getReg());
375 return true;
376}
377
378bool RISCVVectorPeephole::convertToUnmasked(MachineInstr &MI) const {
379 const RISCV::RISCVMaskedPseudoInfo *I =
380 RISCV::getMaskedPseudoInfo(MaskedPseudo: MI.getOpcode());
381 if (!I)
382 return false;
383
384 if (!isAllOnesMask(MaskDef: MRI->getVRegDef(
385 Reg: MI.getOperand(i: I->MaskOpIdx + MI.getNumExplicitDefs()).getReg())))
386 return false;
387
388 // There are two classes of pseudos in the table - compares and
389 // everything else. See the comment on RISCVMaskedPseudo for details.
390 const unsigned Opc = I->UnmaskedPseudo;
391 const MCInstrDesc &MCID = TII->get(Opcode: Opc);
392 [[maybe_unused]] const bool HasPolicyOp =
393 RISCVII::hasVecPolicyOp(TSFlags: MCID.TSFlags);
394 const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(Desc: MCID);
395 const MCInstrDesc &MaskedMCID = TII->get(Opcode: MI.getOpcode());
396 assert((RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags) ||
397 !RISCVII::hasVecPolicyOp(MCID.TSFlags)) &&
398 "Unmasked pseudo has policy but masked pseudo doesn't?");
399 assert(HasPolicyOp == HasPassthru && "Unexpected pseudo structure");
400 assert(!(HasPassthru && !RISCVII::isFirstDefTiedToFirstUse(MaskedMCID)) &&
401 "Unmasked with passthru but masked with no passthru?");
402 (void)HasPolicyOp;
403
404 MI.setDesc(MCID);
405
406 // Drop the policy operand if unmasked doesn't need it.
407 if (RISCVII::hasVecPolicyOp(TSFlags: MaskedMCID.TSFlags) &&
408 !RISCVII::hasVecPolicyOp(TSFlags: MCID.TSFlags))
409 MI.removeOperand(OpNo: RISCVII::getVecPolicyOpNum(Desc: MaskedMCID));
410
411 // TODO: Increment all MaskOpIdxs in tablegen by num of explicit defs?
412 unsigned MaskOpIdx = I->MaskOpIdx + MI.getNumExplicitDefs();
413 MI.removeOperand(OpNo: MaskOpIdx);
414
415 // The unmasked pseudo will no longer be constrained to the vrnov0 reg class,
416 // so try and relax it to vr.
417 MRI->recomputeRegClass(Reg: MI.getOperand(i: 0).getReg());
418
419 // If the original masked pseudo had a passthru, relax it or remove it.
420 if (RISCVII::isFirstDefTiedToFirstUse(Desc: MaskedMCID)) {
421 unsigned PassthruOpIdx = MI.getNumExplicitDefs();
422 if (HasPassthru) {
423 if (MI.getOperand(i: PassthruOpIdx).getReg())
424 MRI->recomputeRegClass(Reg: MI.getOperand(i: PassthruOpIdx).getReg());
425 } else
426 MI.removeOperand(OpNo: PassthruOpIdx);
427 }
428
429 return true;
430}
431
432/// Given A and B are in the same MBB, returns true if A comes before B.
433static bool dominates(MachineBasicBlock::const_iterator A,
434 MachineBasicBlock::const_iterator B) {
435 assert(A->getParent() == B->getParent());
436 const MachineBasicBlock *MBB = A->getParent();
437 auto MBBEnd = MBB->end();
438 if (B == MBBEnd)
439 return true;
440
441 MachineBasicBlock::const_iterator I = MBB->begin();
442 for (; &*I != A && &*I != B; ++I)
443 ;
444
445 return &*I == A;
446}
447
448/// If the register in \p MO doesn't dominate \p Src, try to move \p Src so it
449/// does. Returns false if doesn't dominate and we can't move. \p MO must be in
450/// the same basic block as \Src.
451bool RISCVVectorPeephole::ensureDominates(const MachineOperand &MO,
452 MachineInstr &Src) const {
453 assert(MO.getParent()->getParent() == Src.getParent());
454 if (!MO.isReg() || !MO.getReg().isValid())
455 return true;
456
457 MachineInstr *Def = MRI->getVRegDef(Reg: MO.getReg());
458 if (Def->getParent() == Src.getParent() && !dominates(A: Def, B: Src)) {
459 if (!RISCVInstrInfo::isSafeToMove(From: Src, To: *Def->getNextNode()))
460 return false;
461 Src.moveBefore(MovePos: Def->getNextNode());
462 }
463
464 return true;
465}
466
467/// If a PseudoVMV_V_V's passthru is undef then we can replace it with its input
468bool RISCVVectorPeephole::foldUndefPassthruVMV_V_V(MachineInstr &MI) {
469 if (RISCV::getRVVMCOpcode(RVVPseudoOpcode: MI.getOpcode()) != RISCV::VMV_V_V)
470 return false;
471 if (MI.getOperand(i: 1).getReg().isValid())
472 return false;
473
474 // If the input was a pseudo with a policy operand, we can give it a tail
475 // agnostic policy if MI's undef tail subsumes the input's.
476 MachineInstr *Src = MRI->getVRegDef(Reg: MI.getOperand(i: 2).getReg());
477 if (Src && !Src->hasUnmodeledSideEffects() &&
478 MRI->hasOneUse(RegNo: MI.getOperand(i: 2).getReg()) &&
479 RISCVII::hasVLOp(TSFlags: Src->getDesc().TSFlags) &&
480 RISCVII::hasVecPolicyOp(TSFlags: Src->getDesc().TSFlags) && hasSameEEW(User: MI, Src: *Src)) {
481 const MachineOperand &MIVL = MI.getOperand(i: 3);
482 const MachineOperand &SrcVL =
483 Src->getOperand(i: RISCVII::getVLOpNum(Desc: Src->getDesc()));
484
485 MachineOperand &SrcPolicy =
486 Src->getOperand(i: RISCVII::getVecPolicyOpNum(Desc: Src->getDesc()));
487
488 if (RISCV::isVLKnownLE(LHS: MIVL, RHS: SrcVL))
489 SrcPolicy.setImm(SrcPolicy.getImm() | RISCVVType::TAIL_AGNOSTIC);
490 }
491
492 MRI->constrainRegClass(Reg: MI.getOperand(i: 2).getReg(),
493 RC: MRI->getRegClass(Reg: MI.getOperand(i: 0).getReg()));
494 MRI->replaceRegWith(FromReg: MI.getOperand(i: 0).getReg(), ToReg: MI.getOperand(i: 2).getReg());
495 MRI->clearKillFlags(Reg: MI.getOperand(i: 2).getReg());
496 MI.eraseFromParent();
497 return true;
498}
499
500/// If a PseudoVMV_V_V is the only user of its input, fold its passthru and VL
501/// into it.
502///
503/// %x = PseudoVADD_V_V_M1 %passthru, %a, %b, %vl1, sew, policy
504/// %y = PseudoVMV_V_V_M1 %passthru, %x, %vl2, sew, policy
505/// (where %vl1 <= %vl2)
506///
507/// ->
508///
509/// %y = PseudoVADD_V_V_M1 %passthru, %a, %b, vl1, sew, policy
510bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) {
511 if (RISCV::getRVVMCOpcode(RVVPseudoOpcode: MI.getOpcode()) != RISCV::VMV_V_V)
512 return false;
513
514 MachineOperand &Passthru = MI.getOperand(i: 1);
515
516 if (!MRI->hasOneUse(RegNo: MI.getOperand(i: 2).getReg()))
517 return false;
518
519 MachineInstr *Src = MRI->getVRegDef(Reg: MI.getOperand(i: 2).getReg());
520 if (!Src || Src->hasUnmodeledSideEffects() ||
521 Src->getParent() != MI.getParent() ||
522 !RISCVII::isFirstDefTiedToFirstUse(Desc: Src->getDesc()) ||
523 !RISCVII::hasVLOp(TSFlags: Src->getDesc().TSFlags))
524 return false;
525
526 // Src's dest needs to have the same EEW as MI's input.
527 if (!hasSameEEW(User: MI, Src: *Src))
528 return false;
529
530 std::optional<std::pair<unsigned, unsigned>> NeedsCommute;
531
532 // Src needs to have the same passthru as VMV_V_V
533 MachineOperand &SrcPassthru = Src->getOperand(i: Src->getNumExplicitDefs());
534 if (SrcPassthru.getReg().isValid() &&
535 SrcPassthru.getReg() != Passthru.getReg()) {
536 // If Src's passthru != Passthru, check if it uses Passthru in another
537 // operand and try to commute it.
538 int OtherIdx = Src->findRegisterUseOperandIdx(Reg: Passthru.getReg(), TRI);
539 if (OtherIdx == -1)
540 return false;
541 unsigned OpIdx1 = OtherIdx;
542 unsigned OpIdx2 = Src->getNumExplicitDefs();
543 if (!TII->findCommutedOpIndices(MI: *Src, SrcOpIdx1&: OpIdx1, SrcOpIdx2&: OpIdx2))
544 return false;
545 NeedsCommute = {OpIdx1, OpIdx2};
546 }
547
548 // Src VL will have already been reduced if legal by RISCVVLOptimizer,
549 // so we don't need to handle a smaller source VL here. However, the
550 // user's VL may be larger
551 MachineOperand &SrcVL = Src->getOperand(i: RISCVII::getVLOpNum(Desc: Src->getDesc()));
552 if (!RISCV::isVLKnownLE(LHS: SrcVL, RHS: MI.getOperand(i: 3)))
553 return false;
554
555 // If the new passthru doesn't dominate Src, try to move Src so it does.
556 if (!ensureDominates(MO: Passthru, Src&: *Src))
557 return false;
558
559 if (NeedsCommute) {
560 auto [OpIdx1, OpIdx2] = *NeedsCommute;
561 [[maybe_unused]] bool Commuted =
562 TII->commuteInstruction(MI&: *Src, /*NewMI=*/false, OpIdx1, OpIdx2);
563 assert(Commuted && "Failed to commute Src?");
564 }
565
566 if (SrcPassthru.getReg() != Passthru.getReg()) {
567 SrcPassthru.setReg(Passthru.getReg());
568 // If Src is masked then its passthru needs to be in VRNoV0.
569 if (Passthru.getReg().isValid())
570 MRI->constrainRegClass(
571 Reg: Passthru.getReg(),
572 RC: TII->getRegClass(MCID: Src->getDesc(), OpNum: SrcPassthru.getOperandNo()));
573 }
574
575 if (RISCVII::hasVecPolicyOp(TSFlags: Src->getDesc().TSFlags)) {
576 // If MI was tail agnostic and the VL didn't increase, preserve it.
577 int64_t Policy = RISCVVType::TAIL_UNDISTURBED_MASK_UNDISTURBED;
578 if ((MI.getOperand(i: 5).getImm() & RISCVVType::TAIL_AGNOSTIC) &&
579 RISCV::isVLKnownLE(LHS: MI.getOperand(i: 3), RHS: SrcVL))
580 Policy |= RISCVVType::TAIL_AGNOSTIC;
581 Src->getOperand(i: RISCVII::getVecPolicyOpNum(Desc: Src->getDesc())).setImm(Policy);
582 }
583
584 MRI->constrainRegClass(Reg: Src->getOperand(i: 0).getReg(),
585 RC: MRI->getRegClass(Reg: MI.getOperand(i: 0).getReg()));
586 MRI->replaceRegWith(FromReg: MI.getOperand(i: 0).getReg(), ToReg: Src->getOperand(i: 0).getReg());
587 MI.eraseFromParent();
588
589 return true;
590}
591
592/// Try to fold away VMERGE_VVM instructions into their operands:
593///
594/// %true = PseudoVADD_VV ...
595/// %x = PseudoVMERGE_VVM_M1 %false, %false, %true, %mask
596/// ->
597/// %x = PseudoVADD_VV_M1_MASK %false, ..., %mask
598///
599/// We can only fold if vmerge's passthru operand, vmerge's false operand and
600/// %true's passthru operand (if it has one) are the same. This is because we
601/// have to consolidate them into one passthru operand in the result.
602///
603/// If %true is masked, then we can use its mask instead of vmerge's if vmerge's
604/// mask is all ones.
605///
606/// The resulting VL is the minimum of the two VLs.
607///
608/// The resulting policy is the effective policy the vmerge would have had,
609/// i.e. whether or not it's passthru operand was implicit-def.
610bool RISCVVectorPeephole::foldVMergeToMask(MachineInstr &MI) const {
611 if (RISCV::getRVVMCOpcode(RVVPseudoOpcode: MI.getOpcode()) != RISCV::VMERGE_VVM)
612 return false;
613
614 // Collect chain of COPYs on True's result for later cleanup.
615 SmallVector<MachineInstr *, 4> TrueCopies;
616 Register PassthruReg = lookThruCopies(Reg: MI.getOperand(i: 1).getReg());
617 const MachineOperand &FalseOp = MI.getOperand(i: 2);
618 Register FalseReg = lookThruCopies(Reg: FalseOp.getReg());
619 Register TrueReg = lookThruCopies(Reg: MI.getOperand(i: 3).getReg(),
620 /*OneUseOnly=*/true, Copies: &TrueCopies);
621 if (!TrueReg.isVirtual() || !MRI->hasOneUse(RegNo: TrueReg))
622 return false;
623 MachineInstr &True = *MRI->getUniqueVRegDef(Reg: TrueReg);
624 if (True.getParent() != MI.getParent())
625 return false;
626 const MachineOperand &MaskOp = MI.getOperand(i: 4);
627 MachineInstr *Mask = MRI->getUniqueVRegDef(Reg: MaskOp.getReg());
628 assert(Mask);
629
630 const RISCV::RISCVMaskedPseudoInfo *Info =
631 RISCV::lookupMaskedIntrinsicByUnmasked(UnmaskedPseudo: True.getOpcode());
632 if (!Info)
633 return false;
634
635 // If the EEW of True is different from vmerge's SEW, then we can't fold.
636 if (!hasSameEEW(User: MI, Src: True))
637 return false;
638
639 // We require that either passthru and false are the same, or that passthru
640 // is undefined.
641 if (PassthruReg && !(PassthruReg.isVirtual() && PassthruReg == FalseReg))
642 return false;
643
644 std::optional<std::pair<unsigned, unsigned>> NeedsCommute;
645
646 // If True has a passthru operand then it needs to be the same as vmerge's
647 // False, since False will be used for the result's passthru operand.
648 Register TruePassthru;
649 if (RISCVII::isFirstDefTiedToFirstUse(Desc: True.getDesc()))
650 TruePassthru =
651 lookThruCopies(Reg: True.getOperand(i: True.getNumExplicitDefs()).getReg());
652 if (TruePassthru && !(TruePassthru.isVirtual() && TruePassthru == FalseReg)) {
653 // If True's passthru != False, check if it uses False in another operand
654 // and try to commute it.
655 int OtherIdx = True.findRegisterUseOperandIdx(Reg: FalseReg, TRI);
656 if (OtherIdx == -1)
657 return false;
658 unsigned OpIdx1 = OtherIdx;
659 unsigned OpIdx2 = True.getNumExplicitDefs();
660 if (!TII->findCommutedOpIndices(MI: True, SrcOpIdx1&: OpIdx1, SrcOpIdx2&: OpIdx2))
661 return false;
662 NeedsCommute = {OpIdx1, OpIdx2};
663 }
664
665 // Make sure it doesn't raise any observable fp exceptions, since changing the
666 // active elements will affect how fflags is set.
667 if (True.hasUnmodeledSideEffects() || True.mayRaiseFPException())
668 return false;
669
670 const MachineOperand &VMergeVL =
671 MI.getOperand(i: RISCVII::getVLOpNum(Desc: MI.getDesc()));
672 const MachineOperand &TrueVL =
673 True.getOperand(i: RISCVII::getVLOpNum(Desc: True.getDesc()));
674
675 MachineOperand MinVL = MachineOperand::CreateImm(Val: 0);
676 if (RISCV::isVLKnownLE(LHS: TrueVL, RHS: VMergeVL))
677 MinVL = TrueVL;
678 else if (RISCV::isVLKnownLE(LHS: VMergeVL, RHS: TrueVL))
679 MinVL = VMergeVL;
680 else if (!TruePassthru && !True.mayLoadOrStore())
681 // If True's passthru is undef, we can use vmerge's vl.
682 MinVL = VMergeVL;
683 else
684 return false;
685
686 unsigned RVVTSFlags =
687 TII->get(Opcode: RISCV::getRVVMCOpcode(RVVPseudoOpcode: True.getOpcode())).TSFlags;
688 if (RISCVII::elementsDependOnVL(TSFlags: RVVTSFlags) && !TrueVL.isIdenticalTo(Other: MinVL))
689 return false;
690 if (RISCVII::elementsDependOnMask(TSFlags: RVVTSFlags) && !isAllOnesMask(MaskDef: Mask))
691 return false;
692
693 // Use a tumu policy, relaxing it to tail agnostic provided that the passthru
694 // operand is undefined.
695 //
696 // However, if the VL became smaller than what the vmerge had originally, then
697 // elements past VL that were previously in the vmerge's body will have moved
698 // to the tail. In that case we always need to use tail undisturbed to
699 // preserve them.
700 uint64_t Policy = RISCVVType::TAIL_UNDISTURBED_MASK_UNDISTURBED;
701 if (!PassthruReg && RISCV::isVLKnownLE(LHS: VMergeVL, RHS: MinVL))
702 Policy |= RISCVVType::TAIL_AGNOSTIC;
703
704 assert(RISCVII::hasVecPolicyOp(True.getDesc().TSFlags) &&
705 "Foldable unmasked pseudo should have a policy op already");
706
707 // Make sure both mask and false dominate True and its copies, otherwise move
708 // down True so it does. VL will always dominate since if it's a register they
709 // need to be the same.
710 const MachineOperand *DomOp = &MaskOp;
711 MachineInstr *False = MRI->getUniqueVRegDef(Reg: FalseReg);
712 if (False && False->getParent() == Mask->getParent() &&
713 dominates(A: Mask, B: False))
714 DomOp = &FalseOp;
715 if (!ensureDominates(MO: *DomOp, Src&: True))
716 return false;
717
718 if (NeedsCommute) {
719 auto [OpIdx1, OpIdx2] = *NeedsCommute;
720 [[maybe_unused]] bool Commuted =
721 TII->commuteInstruction(MI&: True, /*NewMI=*/false, OpIdx1, OpIdx2);
722 assert(Commuted && "Failed to commute True?");
723 Info = RISCV::lookupMaskedIntrinsicByUnmasked(UnmaskedPseudo: True.getOpcode());
724 }
725
726 True.setDesc(TII->get(Opcode: Info->MaskedPseudo));
727
728 // Insert the mask operand.
729 // TODO: Increment MaskOpIdx by number of explicit defs?
730 True.insert(InsertBefore: True.operands_begin() + Info->MaskOpIdx +
731 True.getNumExplicitDefs(),
732 Ops: MachineOperand::CreateReg(Reg: MaskOp.getReg(), isDef: false));
733
734 // Update the passthru, AVL and policy.
735 True.getOperand(i: True.getNumExplicitDefs()).setReg(FalseReg);
736 True.removeOperand(OpNo: RISCVII::getVLOpNum(Desc: True.getDesc()));
737 True.insert(InsertBefore: True.operands_begin() + RISCVII::getVLOpNum(Desc: True.getDesc()),
738 Ops: MinVL);
739 True.getOperand(i: RISCVII::getVecPolicyOpNum(Desc: True.getDesc())).setImm(Policy);
740
741 MRI->replaceRegWith(FromReg: True.getOperand(i: 0).getReg(), ToReg: MI.getOperand(i: 0).getReg());
742 // Now that True is masked, constrain its operands from vr -> vrnov0.
743 for (MachineOperand &MO : True.explicit_operands()) {
744 if (!MO.isReg() || !MO.getReg().isVirtual())
745 continue;
746 MRI->constrainRegClass(
747 Reg: MO.getReg(), RC: True.getRegClassConstraint(OpIdx: MO.getOperandNo(), TII, TRI));
748 }
749 // We should clear the IsKill flag since we have a new use now.
750 MRI->clearKillFlags(Reg: FalseReg);
751 MI.eraseFromParent();
752
753 // Cleanup all the COPYs on True's value. We have to manually do this because
754 // sometimes sinking True causes these COPY to be invalid (use before define).
755 for (MachineInstr *TrueCopy : TrueCopies)
756 TrueCopy->eraseFromParent();
757
758 return true;
759}
760
761bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) {
762 if (skipFunction(F: MF.getFunction()))
763 return false;
764
765 // Skip if the vector extension is not enabled.
766 ST = &MF.getSubtarget<RISCVSubtarget>();
767 if (!ST->hasVInstructions())
768 return false;
769
770 TII = ST->getInstrInfo();
771 MRI = &MF.getRegInfo();
772 TRI = MRI->getTargetRegisterInfo();
773
774 bool Changed = false;
775
776 for (MachineBasicBlock &MBB : MF) {
777 for (MachineInstr &MI : make_early_inc_range(Range&: MBB))
778 Changed |= foldVMergeToMask(MI);
779
780 for (MachineInstr &MI : make_early_inc_range(Range&: MBB)) {
781 Changed |= convertToVLMAX(MI);
782 Changed |= convertToUnmasked(MI);
783 Changed |= convertToWholeRegister(MI);
784 Changed |= convertAllOnesVMergeToVMv(MI);
785 Changed |= convertSameMaskVMergeToVMv(MI);
786 if (foldUndefPassthruVMV_V_V(MI)) {
787 Changed |= true;
788 continue; // MI is erased
789 }
790 Changed |= foldVMV_V_V(MI);
791 }
792 }
793
794 return Changed;
795}
796
797FunctionPass *llvm::createRISCVVectorPeepholePass() {
798 return new RISCVVectorPeephole();
799}
800