1//===-------------- RISCVVLOptimizer.cpp - VL Optimizer -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===---------------------------------------------------------------------===//
8//
9// This pass reduces the VL where possible at the MI level, before VSETVLI
10// instructions are inserted.
11//
12// The purpose of this optimization is to make the VL argument, for instructions
13// that have a VL argument, as small as possible.
14//
15// This is split into a sparse dataflow analysis where we determine what VL is
16// demanded by each instruction first, and then afterwards try to reduce the VL
17// of each instruction if it demands less than its VL operand.
18//
19// The analysis is explained in more detail in the 2025 EuroLLVM Developers'
20// Meeting talk "Accidental Dataflow Analysis: Extending the RISC-V VL
21// Optimizer", which is available on YouTube at
22// https://www.youtube.com/watch?v=Mfb5fRSdJAc
23//
24// The slides for the talk are available at
25// https://llvm.org/devmtg/2025-04/slides/technical_talk/lau_accidental_dataflow.pdf
26//
27//===---------------------------------------------------------------------===//
28
29#include "RISCV.h"
30#include "RISCVSubtarget.h"
31#include "llvm/ADT/PostOrderIterator.h"
32#include "llvm/CodeGen/MachineDominators.h"
33#include "llvm/CodeGen/MachineFunctionPass.h"
34#include "llvm/InitializePasses.h"
35
36using namespace llvm;
37
38#define DEBUG_TYPE "riscv-vl-optimizer"
39#define PASS_NAME "RISC-V VL Optimizer"
40
41namespace {
42
43/// Wrapper around MachineOperand that defaults to immediate 0.
44struct DemandedVL {
45 MachineOperand VL;
46 DemandedVL() : VL(MachineOperand::CreateImm(Val: 0)) {}
47 DemandedVL(MachineOperand VL) : VL(VL) {}
48 static DemandedVL vlmax() {
49 return DemandedVL(MachineOperand::CreateImm(Val: RISCV::VLMaxSentinel));
50 }
51 bool operator!=(const DemandedVL &Other) const {
52 return !VL.isIdenticalTo(Other: Other.VL);
53 }
54
55 DemandedVL max(const DemandedVL &X) const {
56 if (RISCV::isVLKnownLE(LHS: VL, RHS: X.VL))
57 return X;
58 if (RISCV::isVLKnownLE(LHS: X.VL, RHS: VL))
59 return *this;
60 return DemandedVL::vlmax();
61 }
62};
63
64class RISCVVLOptimizer : public MachineFunctionPass {
65 MachineRegisterInfo *MRI;
66 const MachineDominatorTree *MDT;
67 const TargetInstrInfo *TII;
68
69public:
70 static char ID;
71
72 RISCVVLOptimizer() : MachineFunctionPass(ID) {}
73
74 bool runOnMachineFunction(MachineFunction &MF) override;
75
76 void getAnalysisUsage(AnalysisUsage &AU) const override {
77 AU.setPreservesCFG();
78 AU.addRequired<MachineDominatorTreeWrapperPass>();
79 MachineFunctionPass::getAnalysisUsage(AU);
80 }
81
82 StringRef getPassName() const override { return PASS_NAME; }
83
84private:
85 DemandedVL getMinimumVLForUser(const MachineOperand &UserOp) const;
86 /// Returns true if the users of \p MI have compatible EEWs and SEWs.
87 bool checkUsers(const MachineInstr &MI) const;
88 bool tryReduceVL(MachineInstr &MI, MachineOperand VL) const;
89 bool isSupportedInstr(const MachineInstr &MI) const;
90 bool isCandidate(const MachineInstr &MI) const;
91 void transfer(const MachineInstr &MI);
92
93 /// For a given instruction, records what elements of it are demanded by
94 /// downstream users.
95 MapVector<const MachineInstr *, DemandedVL> DemandedVLs;
96 SetVector<const MachineInstr *> Worklist;
97
98 /// \returns all vector virtual registers that \p MI uses.
99 auto virtual_vec_uses(const MachineInstr &MI) const {
100 return make_filter_range(Range: MI.uses(), Pred: [this](const MachineOperand &MO) {
101 return MO.isReg() && MO.getReg().isVirtual() &&
102 RISCVRegisterInfo::isRVVRegClass(RC: MRI->getRegClass(Reg: MO.getReg()));
103 });
104 }
105};
106
107/// Represents the EMUL and EEW of a MachineOperand.
108struct OperandInfo {
109 // Represent as 1,2,4,8, ... and fractional indicator. This is because
110 // EMUL can take on values that don't map to RISCVVType::VLMUL values exactly.
111 // For example, a mask operand can have an EMUL less than MF8.
112 // If nullopt, then EMUL isn't used (i.e. only a single scalar is read).
113 std::optional<std::pair<unsigned, bool>> EMUL;
114
115 unsigned Log2EEW;
116
117 OperandInfo(RISCVVType::VLMUL EMUL, unsigned Log2EEW)
118 : EMUL(RISCVVType::decodeVLMUL(VLMul: EMUL)), Log2EEW(Log2EEW) {}
119
120 OperandInfo(std::pair<unsigned, bool> EMUL, unsigned Log2EEW)
121 : EMUL(EMUL), Log2EEW(Log2EEW) {}
122
123 OperandInfo(unsigned Log2EEW) : Log2EEW(Log2EEW) {}
124
125 OperandInfo() = delete;
126
127 /// Return true if the EMUL and EEW produced by \p Def is compatible with the
128 /// EMUL and EEW used by \p User.
129 static bool areCompatible(const OperandInfo &Def, const OperandInfo &User) {
130 if (Def.Log2EEW != User.Log2EEW)
131 return false;
132 if (User.EMUL && Def.EMUL != User.EMUL)
133 return false;
134 return true;
135 }
136
137 void print(raw_ostream &OS) const {
138 if (EMUL) {
139 OS << "EMUL: m";
140 if (EMUL->second)
141 OS << "f";
142 OS << EMUL->first;
143 } else
144 OS << "EMUL: none\n";
145 OS << ", EEW: " << (1 << Log2EEW);
146 }
147};
148
149} // end anonymous namespace
150
// Legacy pass-manager registration: define the pass identity and declare the
// MachineDominatorTree dependency used by the analysis.
char RISCVVLOptimizer::ID = 0;
INITIALIZE_PASS_BEGIN(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
INITIALIZE_PASS_END(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false)
155
156FunctionPass *llvm::createRISCVVLOptimizerPass() {
157 return new RISCVVLOptimizer();
158}
159
160[[maybe_unused]]
161static raw_ostream &operator<<(raw_ostream &OS, const OperandInfo &OI) {
162 OI.print(OS);
163 return OS;
164}
165
166[[maybe_unused]]
167static raw_ostream &operator<<(raw_ostream &OS,
168 const std::optional<OperandInfo> &OI) {
169 if (OI)
170 OI->print(OS);
171 else
172 OS << "nullopt";
173 return OS;
174}
175
176/// Return EMUL = (EEW / SEW) * LMUL where EEW comes from Log2EEW and LMUL and
177/// SEW are from the TSFlags of MI.
178static std::pair<unsigned, bool>
179getEMULEqualsEEWDivSEWTimesLMUL(unsigned Log2EEW, const MachineInstr &MI) {
180 RISCVVType::VLMUL MIVLMUL = RISCVII::getLMul(TSFlags: MI.getDesc().TSFlags);
181 auto [MILMUL, MILMULIsFractional] = RISCVVType::decodeVLMUL(VLMul: MIVLMUL);
182 unsigned MILog2SEW =
183 MI.getOperand(i: RISCVII::getSEWOpNum(Desc: MI.getDesc())).getImm();
184
185 // Mask instructions will have 0 as the SEW operand. But the LMUL of these
186 // instructions is calculated is as if the SEW operand was 3 (e8).
187 if (MILog2SEW == 0)
188 MILog2SEW = 3;
189
190 unsigned MISEW = 1 << MILog2SEW;
191
192 unsigned EEW = 1 << Log2EEW;
193 // Calculate (EEW/SEW)*LMUL preserving fractions less than 1. Use GCD
194 // to put fraction in simplest form.
195 unsigned Num = EEW, Denom = MISEW;
196 int GCD = MILMULIsFractional ? std::gcd(m: Num, n: Denom * MILMUL)
197 : std::gcd(m: Num * MILMUL, n: Denom);
198 Num = MILMULIsFractional ? Num / GCD : Num * MILMUL / GCD;
199 Denom = MILMULIsFractional ? Denom * MILMUL / GCD : Denom / GCD;
200 return std::make_pair(x&: Num > Denom ? Num : Denom, y: Denom > Num);
201}
202
203/// Dest has EEW=SEW. Source EEW=SEW/Factor (i.e. F2 => EEW/2).
204/// SEW comes from TSFlags of MI.
205static unsigned getIntegerExtensionOperandEEW(unsigned Factor,
206 const MachineInstr &MI,
207 const MachineOperand &MO) {
208 unsigned MILog2SEW =
209 MI.getOperand(i: RISCVII::getSEWOpNum(Desc: MI.getDesc())).getImm();
210
211 if (MO.getOperandNo() == 0)
212 return MILog2SEW;
213
214 unsigned MISEW = 1 << MILog2SEW;
215 unsigned EEW = MISEW / Factor;
216 unsigned Log2EEW = Log2_32(Value: EEW);
217
218 return Log2EEW;
219}
220
// Expand to the seven segment load/store pseudo opcodes (SEG2 through SEG8)
// for a given instruction prefix and EEW, written so the expansion can be
// used directly as a run of adjacent case labels in a switch.
#define VSEG_CASES(Prefix, EEW)                                                \
  RISCV::Prefix##SEG2E##EEW##_V:                                               \
  case RISCV::Prefix##SEG3E##EEW##_V:                                          \
  case RISCV::Prefix##SEG4E##EEW##_V:                                          \
  case RISCV::Prefix##SEG5E##EEW##_V:                                          \
  case RISCV::Prefix##SEG6E##EEW##_V:                                          \
  case RISCV::Prefix##SEG7E##EEW##_V:                                          \
  case RISCV::Prefix##SEG8E##EEW##_V
// Unit-stride segment stores (vsseg<n>e<eew>.v).
#define VSSEG_CASES(EEW) VSEG_CASES(VS, EEW)
// Strided segment stores (vssseg<n>e<eew>.v).
#define VSSSEG_CASES(EEW) VSEG_CASES(VSS, EEW)
// Unordered indexed segment stores (vsuxseg<n>ei<eew>.v).
#define VSUXSEG_CASES(EEW) VSEG_CASES(VSUX, I##EEW)
// Ordered indexed segment stores (vsoxseg<n>ei<eew>.v).
#define VSOXSEG_CASES(EEW) VSEG_CASES(VSOX, I##EEW)
233
234static std::optional<unsigned> getOperandLog2EEW(const MachineOperand &MO) {
235 const MachineInstr &MI = *MO.getParent();
236 const MCInstrDesc &Desc = MI.getDesc();
237 const RISCVVPseudosTable::PseudoInfo *RVV =
238 RISCVVPseudosTable::getPseudoInfo(Pseudo: MI.getOpcode());
239 assert(RVV && "Could not find MI in PseudoTable");
240
241 // MI has a SEW associated with it. The RVV specification defines
242 // the EEW of each operand and definition in relation to MI.SEW.
243 unsigned MILog2SEW = MI.getOperand(i: RISCVII::getSEWOpNum(Desc)).getImm();
244
245 const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(Desc);
246 const bool IsTied = RISCVII::isTiedPseudo(TSFlags: Desc.TSFlags);
247
248 bool IsMODef = MO.getOperandNo() == 0 ||
249 (HasPassthru && MO.getOperandNo() == MI.getNumExplicitDefs());
250
251 // All mask operands have EEW=1
252 const MCOperandInfo &Info = Desc.operands()[MO.getOperandNo()];
253 if (Info.OperandType == MCOI::OPERAND_REGISTER &&
254 Info.RegClass == RISCV::VMV0RegClassID)
255 return 0;
256
257 // switch against BaseInstr to reduce number of cases that need to be
258 // considered.
259 switch (RVV->BaseInstr) {
260
261 // 6. Configuration-Setting Instructions
262 // Configuration setting instructions do not read or write vector registers
263 case RISCV::VSETIVLI:
264 case RISCV::VSETVL:
265 case RISCV::VSETVLI:
266 llvm_unreachable("Configuration setting instructions do not read or write "
267 "vector registers");
268
269 // Vector Loads and Stores
270 // Vector Unit-Stride Instructions
271 // Vector Strided Instructions
272 /// Dest EEW encoded in the instruction
273 case RISCV::VLM_V:
274 case RISCV::VSM_V:
275 return 0;
276 case RISCV::VLE8_V:
277 case RISCV::VSE8_V:
278 case RISCV::VLSE8_V:
279 case RISCV::VSSE8_V:
280 case VSSEG_CASES(8):
281 case VSSSEG_CASES(8):
282 return 3;
283 case RISCV::VLE16_V:
284 case RISCV::VSE16_V:
285 case RISCV::VLSE16_V:
286 case RISCV::VSSE16_V:
287 case VSSEG_CASES(16):
288 case VSSSEG_CASES(16):
289 return 4;
290 case RISCV::VLE32_V:
291 case RISCV::VSE32_V:
292 case RISCV::VLSE32_V:
293 case RISCV::VSSE32_V:
294 case VSSEG_CASES(32):
295 case VSSSEG_CASES(32):
296 return 5;
297 case RISCV::VLE64_V:
298 case RISCV::VSE64_V:
299 case RISCV::VLSE64_V:
300 case RISCV::VSSE64_V:
301 case VSSEG_CASES(64):
302 case VSSSEG_CASES(64):
303 return 6;
304
305 // Vector Indexed Instructions
306 // vs(o|u)xei<eew>.v
307 // Dest/Data (operand 0) EEW=SEW. Source EEW=<eew>.
308 case RISCV::VLUXEI8_V:
309 case RISCV::VLOXEI8_V:
310 case RISCV::VSUXEI8_V:
311 case RISCV::VSOXEI8_V:
312 case VSUXSEG_CASES(8):
313 case VSOXSEG_CASES(8): {
314 if (MO.getOperandNo() == 0)
315 return MILog2SEW;
316 return 3;
317 }
318 case RISCV::VLUXEI16_V:
319 case RISCV::VLOXEI16_V:
320 case RISCV::VSUXEI16_V:
321 case RISCV::VSOXEI16_V:
322 case VSUXSEG_CASES(16):
323 case VSOXSEG_CASES(16): {
324 if (MO.getOperandNo() == 0)
325 return MILog2SEW;
326 return 4;
327 }
328 case RISCV::VLUXEI32_V:
329 case RISCV::VLOXEI32_V:
330 case RISCV::VSUXEI32_V:
331 case RISCV::VSOXEI32_V:
332 case VSUXSEG_CASES(32):
333 case VSOXSEG_CASES(32): {
334 if (MO.getOperandNo() == 0)
335 return MILog2SEW;
336 return 5;
337 }
338 case RISCV::VLUXEI64_V:
339 case RISCV::VLOXEI64_V:
340 case RISCV::VSUXEI64_V:
341 case RISCV::VSOXEI64_V:
342 case VSUXSEG_CASES(64):
343 case VSOXSEG_CASES(64): {
344 if (MO.getOperandNo() == 0)
345 return MILog2SEW;
346 return 6;
347 }
348
349 // Vector Integer Arithmetic Instructions
350 // Vector Single-Width Integer Add and Subtract
351 case RISCV::VADD_VI:
352 case RISCV::VADD_VV:
353 case RISCV::VADD_VX:
354 case RISCV::VSUB_VV:
355 case RISCV::VSUB_VX:
356 case RISCV::VRSUB_VI:
357 case RISCV::VRSUB_VX:
358 // Vector Bitwise Logical Instructions
359 // Vector Single-Width Shift Instructions
360 // EEW=SEW.
361 case RISCV::VAND_VI:
362 case RISCV::VAND_VV:
363 case RISCV::VAND_VX:
364 case RISCV::VOR_VI:
365 case RISCV::VOR_VV:
366 case RISCV::VOR_VX:
367 case RISCV::VXOR_VI:
368 case RISCV::VXOR_VV:
369 case RISCV::VXOR_VX:
370 case RISCV::VSLL_VI:
371 case RISCV::VSLL_VV:
372 case RISCV::VSLL_VX:
373 case RISCV::VSRL_VI:
374 case RISCV::VSRL_VV:
375 case RISCV::VSRL_VX:
376 case RISCV::VSRA_VI:
377 case RISCV::VSRA_VV:
378 case RISCV::VSRA_VX:
379 // Vector Integer Min/Max Instructions
380 // EEW=SEW.
381 case RISCV::VMINU_VV:
382 case RISCV::VMINU_VX:
383 case RISCV::VMIN_VV:
384 case RISCV::VMIN_VX:
385 case RISCV::VMAXU_VV:
386 case RISCV::VMAXU_VX:
387 case RISCV::VMAX_VV:
388 case RISCV::VMAX_VX:
389 // Vector Single-Width Integer Multiply Instructions
390 // Source and Dest EEW=SEW.
391 case RISCV::VMUL_VV:
392 case RISCV::VMUL_VX:
393 case RISCV::VMULH_VV:
394 case RISCV::VMULH_VX:
395 case RISCV::VMULHU_VV:
396 case RISCV::VMULHU_VX:
397 case RISCV::VMULHSU_VV:
398 case RISCV::VMULHSU_VX:
399 // Vector Integer Divide Instructions
400 // EEW=SEW.
401 case RISCV::VDIVU_VV:
402 case RISCV::VDIVU_VX:
403 case RISCV::VDIV_VV:
404 case RISCV::VDIV_VX:
405 case RISCV::VREMU_VV:
406 case RISCV::VREMU_VX:
407 case RISCV::VREM_VV:
408 case RISCV::VREM_VX:
409 // Vector Single-Width Integer Multiply-Add Instructions
410 // EEW=SEW.
411 case RISCV::VMACC_VV:
412 case RISCV::VMACC_VX:
413 case RISCV::VNMSAC_VV:
414 case RISCV::VNMSAC_VX:
415 case RISCV::VMADD_VV:
416 case RISCV::VMADD_VX:
417 case RISCV::VNMSUB_VV:
418 case RISCV::VNMSUB_VX:
419 // Vector Integer Merge Instructions
420 // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
421 // EEW=SEW, except the mask operand has EEW=1. Mask operand is handled
422 // before this switch.
423 case RISCV::VMERGE_VIM:
424 case RISCV::VMERGE_VVM:
425 case RISCV::VMERGE_VXM:
426 case RISCV::VADC_VIM:
427 case RISCV::VADC_VVM:
428 case RISCV::VADC_VXM:
429 case RISCV::VSBC_VVM:
430 case RISCV::VSBC_VXM:
431 // Vector Integer Move Instructions
432 // Vector Fixed-Point Arithmetic Instructions
433 // Vector Single-Width Saturating Add and Subtract
434 // Vector Single-Width Averaging Add and Subtract
435 // EEW=SEW.
436 case RISCV::VMV_V_I:
437 case RISCV::VMV_V_V:
438 case RISCV::VMV_V_X:
439 case RISCV::VSADDU_VI:
440 case RISCV::VSADDU_VV:
441 case RISCV::VSADDU_VX:
442 case RISCV::VSADD_VI:
443 case RISCV::VSADD_VV:
444 case RISCV::VSADD_VX:
445 case RISCV::VSSUBU_VV:
446 case RISCV::VSSUBU_VX:
447 case RISCV::VSSUB_VV:
448 case RISCV::VSSUB_VX:
449 case RISCV::VAADDU_VV:
450 case RISCV::VAADDU_VX:
451 case RISCV::VAADD_VV:
452 case RISCV::VAADD_VX:
453 case RISCV::VASUBU_VV:
454 case RISCV::VASUBU_VX:
455 case RISCV::VASUB_VV:
456 case RISCV::VASUB_VX:
457 // Vector Single-Width Fractional Multiply with Rounding and Saturation
458 // EEW=SEW. The instruction produces 2*SEW product internally but
459 // saturates to fit into SEW bits.
460 case RISCV::VSMUL_VV:
461 case RISCV::VSMUL_VX:
462 // Vector Single-Width Scaling Shift Instructions
463 // EEW=SEW.
464 case RISCV::VSSRL_VI:
465 case RISCV::VSSRL_VV:
466 case RISCV::VSSRL_VX:
467 case RISCV::VSSRA_VI:
468 case RISCV::VSSRA_VV:
469 case RISCV::VSSRA_VX:
470 // Vector Permutation Instructions
471 // Integer Scalar Move Instructions
472 // Floating-Point Scalar Move Instructions
473 // EEW=SEW.
474 case RISCV::VMV_X_S:
475 case RISCV::VMV_S_X:
476 case RISCV::VFMV_F_S:
477 case RISCV::VFMV_S_F:
478 // Vector Slide Instructions
479 // EEW=SEW.
480 case RISCV::VSLIDEUP_VI:
481 case RISCV::VSLIDEUP_VX:
482 case RISCV::VSLIDEDOWN_VI:
483 case RISCV::VSLIDEDOWN_VX:
484 case RISCV::VSLIDE1UP_VX:
485 case RISCV::VFSLIDE1UP_VF:
486 case RISCV::VSLIDE1DOWN_VX:
487 case RISCV::VFSLIDE1DOWN_VF:
488 // Vector Register Gather Instructions
489 // EEW=SEW. For mask operand, EEW=1.
490 case RISCV::VRGATHER_VI:
491 case RISCV::VRGATHER_VV:
492 case RISCV::VRGATHER_VX:
493 // Vector Element Index Instruction
494 case RISCV::VID_V:
495 // Vector Single-Width Floating-Point Add/Subtract Instructions
496 case RISCV::VFADD_VF:
497 case RISCV::VFADD_VV:
498 case RISCV::VFSUB_VF:
499 case RISCV::VFSUB_VV:
500 case RISCV::VFRSUB_VF:
501 // Vector Single-Width Floating-Point Multiply/Divide Instructions
502 case RISCV::VFMUL_VF:
503 case RISCV::VFMUL_VV:
504 case RISCV::VFDIV_VF:
505 case RISCV::VFDIV_VV:
506 case RISCV::VFRDIV_VF:
507 // Vector Single-Width Floating-Point Fused Multiply-Add Instructions
508 case RISCV::VFMACC_VV:
509 case RISCV::VFMACC_VF:
510 case RISCV::VFNMACC_VV:
511 case RISCV::VFNMACC_VF:
512 case RISCV::VFMSAC_VV:
513 case RISCV::VFMSAC_VF:
514 case RISCV::VFNMSAC_VV:
515 case RISCV::VFNMSAC_VF:
516 case RISCV::VFMADD_VV:
517 case RISCV::VFMADD_VF:
518 case RISCV::VFNMADD_VV:
519 case RISCV::VFNMADD_VF:
520 case RISCV::VFMSUB_VV:
521 case RISCV::VFMSUB_VF:
522 case RISCV::VFNMSUB_VV:
523 case RISCV::VFNMSUB_VF:
524 // Vector Floating-Point Square-Root Instruction
525 case RISCV::VFSQRT_V:
526 // Vector Floating-Point Reciprocal Square-Root Estimate Instruction
527 case RISCV::VFRSQRT7_V:
528 // Vector Floating-Point Reciprocal Estimate Instruction
529 case RISCV::VFREC7_V:
530 // Vector Floating-Point MIN/MAX Instructions
531 case RISCV::VFMIN_VF:
532 case RISCV::VFMIN_VV:
533 case RISCV::VFMAX_VF:
534 case RISCV::VFMAX_VV:
535 // Vector Floating-Point Sign-Injection Instructions
536 case RISCV::VFSGNJ_VF:
537 case RISCV::VFSGNJ_VV:
538 case RISCV::VFSGNJN_VV:
539 case RISCV::VFSGNJN_VF:
540 case RISCV::VFSGNJX_VF:
541 case RISCV::VFSGNJX_VV:
542 // Vector Floating-Point Classify Instruction
543 case RISCV::VFCLASS_V:
544 // Vector Floating-Point Move Instruction
545 case RISCV::VFMV_V_F:
546 // Single-Width Floating-Point/Integer Type-Convert Instructions
547 case RISCV::VFCVT_XU_F_V:
548 case RISCV::VFCVT_X_F_V:
549 case RISCV::VFCVT_RTZ_XU_F_V:
550 case RISCV::VFCVT_RTZ_X_F_V:
551 case RISCV::VFCVT_F_XU_V:
552 case RISCV::VFCVT_F_X_V:
553 // Vector Floating-Point Merge Instruction
554 case RISCV::VFMERGE_VFM:
555 // Vector count population in mask vcpop.m
556 // vfirst find-first-set mask bit
557 case RISCV::VCPOP_M:
558 case RISCV::VFIRST_M:
559 // Vector Bit-manipulation Instructions (Zvbb)
560 // Vector And-Not
561 case RISCV::VANDN_VV:
562 case RISCV::VANDN_VX:
563 // Vector Reverse Bits in Elements
564 case RISCV::VBREV_V:
565 // Vector Reverse Bits in Bytes
566 case RISCV::VBREV8_V:
567 // Vector Reverse Bytes
568 case RISCV::VREV8_V:
569 // Vector Count Leading Zeros
570 case RISCV::VCLZ_V:
571 // Vector Count Trailing Zeros
572 case RISCV::VCTZ_V:
573 // Vector Population Count
574 case RISCV::VCPOP_V:
575 // Vector Rotate Left
576 case RISCV::VROL_VV:
577 case RISCV::VROL_VX:
578 // Vector Rotate Right
579 case RISCV::VROR_VI:
580 case RISCV::VROR_VV:
581 case RISCV::VROR_VX:
582 // Vector Carry-less Multiplication Instructions (Zvbc)
583 // Vector Carry-less Multiply
584 case RISCV::VCLMUL_VV:
585 case RISCV::VCLMUL_VX:
586 // Vector Carry-less Multiply Return High Half
587 case RISCV::VCLMULH_VV:
588 case RISCV::VCLMULH_VX:
589
590 // Zvabd
591 case RISCV::VABS_V:
592 case RISCV::VABD_VV:
593 case RISCV::VABDU_VV:
594
595 // XRivosVizip
596 case RISCV::RI_VZIPEVEN_VV:
597 case RISCV::RI_VZIPODD_VV:
598 case RISCV::RI_VZIP2A_VV:
599 case RISCV::RI_VZIP2B_VV:
600 case RISCV::RI_VUNZIP2A_VV:
601 case RISCV::RI_VUNZIP2B_VV:
602 return MILog2SEW;
603
604 // Vector Widening Shift Left Logical (Zvbb)
605 case RISCV::VWSLL_VI:
606 case RISCV::VWSLL_VX:
607 case RISCV::VWSLL_VV:
608 // Vector Widening Integer Add/Subtract
609 // Def uses EEW=2*SEW . Operands use EEW=SEW.
610 case RISCV::VWADDU_VV:
611 case RISCV::VWADDU_VX:
612 case RISCV::VWSUBU_VV:
613 case RISCV::VWSUBU_VX:
614 case RISCV::VWADD_VV:
615 case RISCV::VWADD_VX:
616 case RISCV::VWSUB_VV:
617 case RISCV::VWSUB_VX:
618 // Vector Widening Integer Multiply Instructions
619 // Destination EEW=2*SEW. Source EEW=SEW.
620 case RISCV::VWMUL_VV:
621 case RISCV::VWMUL_VX:
622 case RISCV::VWMULSU_VV:
623 case RISCV::VWMULSU_VX:
624 case RISCV::VWMULU_VV:
625 case RISCV::VWMULU_VX:
626 // Vector Widening Integer Multiply-Add Instructions
627 // Destination EEW=2*SEW. Source EEW=SEW.
628 // A SEW-bit*SEW-bit multiply of the sources forms a 2*SEW-bit value, which
629 // is then added to the 2*SEW-bit Dest. These instructions never have a
630 // passthru operand.
631 case RISCV::VWMACCU_VV:
632 case RISCV::VWMACCU_VX:
633 case RISCV::VWMACC_VV:
634 case RISCV::VWMACC_VX:
635 case RISCV::VWMACCSU_VV:
636 case RISCV::VWMACCSU_VX:
637 case RISCV::VWMACCUS_VX:
638 // Vector Widening Floating-Point Fused Multiply-Add Instructions
639 case RISCV::VFWMACC_VF:
640 case RISCV::VFWMACC_VV:
641 case RISCV::VFWNMACC_VF:
642 case RISCV::VFWNMACC_VV:
643 case RISCV::VFWMSAC_VF:
644 case RISCV::VFWMSAC_VV:
645 case RISCV::VFWNMSAC_VF:
646 case RISCV::VFWNMSAC_VV:
647 case RISCV::VFWMACCBF16_VV:
648 case RISCV::VFWMACCBF16_VF:
649 // Vector Widening Floating-Point Add/Subtract Instructions
650 // Dest EEW=2*SEW. Source EEW=SEW.
651 case RISCV::VFWADD_VV:
652 case RISCV::VFWADD_VF:
653 case RISCV::VFWSUB_VV:
654 case RISCV::VFWSUB_VF:
655 // Vector Widening Floating-Point Multiply
656 case RISCV::VFWMUL_VF:
657 case RISCV::VFWMUL_VV:
658 // Widening Floating-Point/Integer Type-Convert Instructions
659 case RISCV::VFWCVT_XU_F_V:
660 case RISCV::VFWCVT_X_F_V:
661 case RISCV::VFWCVT_RTZ_XU_F_V:
662 case RISCV::VFWCVT_RTZ_X_F_V:
663 case RISCV::VFWCVT_F_XU_V:
664 case RISCV::VFWCVT_F_X_V:
665 case RISCV::VFWCVT_F_F_V:
666 case RISCV::VFWCVTBF16_F_F_V:
667 // Zvabd
668 case RISCV::VWABDA_VV:
669 case RISCV::VWABDAU_VV:
670 return IsMODef ? MILog2SEW + 1 : MILog2SEW;
671
672 // Def and Op1 uses EEW=2*SEW. Op2 uses EEW=SEW.
673 case RISCV::VWADDU_WV:
674 case RISCV::VWADDU_WX:
675 case RISCV::VWSUBU_WV:
676 case RISCV::VWSUBU_WX:
677 case RISCV::VWADD_WV:
678 case RISCV::VWADD_WX:
679 case RISCV::VWSUB_WV:
680 case RISCV::VWSUB_WX:
681 // Vector Widening Floating-Point Add/Subtract Instructions
682 case RISCV::VFWADD_WF:
683 case RISCV::VFWADD_WV:
684 case RISCV::VFWSUB_WF:
685 case RISCV::VFWSUB_WV: {
686 bool IsOp1 = (HasPassthru && !IsTied) ? MO.getOperandNo() == 2
687 : MO.getOperandNo() == 1;
688 bool TwoTimes = IsMODef || IsOp1;
689 return TwoTimes ? MILog2SEW + 1 : MILog2SEW;
690 }
691
692 // Vector Integer Extension
693 case RISCV::VZEXT_VF2:
694 case RISCV::VSEXT_VF2:
695 return getIntegerExtensionOperandEEW(Factor: 2, MI, MO);
696 case RISCV::VZEXT_VF4:
697 case RISCV::VSEXT_VF4:
698 return getIntegerExtensionOperandEEW(Factor: 4, MI, MO);
699 case RISCV::VZEXT_VF8:
700 case RISCV::VSEXT_VF8:
701 return getIntegerExtensionOperandEEW(Factor: 8, MI, MO);
702
703 // Vector Narrowing Integer Right Shift Instructions
704 // Destination EEW=SEW, Op 1 has EEW=2*SEW. Op2 has EEW=SEW
705 case RISCV::VNSRL_WX:
706 case RISCV::VNSRL_WI:
707 case RISCV::VNSRL_WV:
708 case RISCV::VNSRA_WI:
709 case RISCV::VNSRA_WV:
710 case RISCV::VNSRA_WX:
711 // Vector Narrowing Fixed-Point Clip Instructions
712 // Destination and Op1 EEW=SEW. Op2 EEW=2*SEW.
713 case RISCV::VNCLIPU_WI:
714 case RISCV::VNCLIPU_WV:
715 case RISCV::VNCLIPU_WX:
716 case RISCV::VNCLIP_WI:
717 case RISCV::VNCLIP_WV:
718 case RISCV::VNCLIP_WX:
719 // Narrowing Floating-Point/Integer Type-Convert Instructions
720 case RISCV::VFNCVT_XU_F_W:
721 case RISCV::VFNCVT_X_F_W:
722 case RISCV::VFNCVT_RTZ_XU_F_W:
723 case RISCV::VFNCVT_RTZ_X_F_W:
724 case RISCV::VFNCVT_F_XU_W:
725 case RISCV::VFNCVT_F_X_W:
726 case RISCV::VFNCVT_F_F_W:
727 case RISCV::VFNCVT_ROD_F_F_W:
728 case RISCV::VFNCVTBF16_F_F_W: {
729 assert(!IsTied);
730 bool IsOp1 = HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1;
731 bool TwoTimes = IsOp1;
732 return TwoTimes ? MILog2SEW + 1 : MILog2SEW;
733 }
734
735 // Vector Mask Instructions
736 // Vector Mask-Register Logical Instructions
737 // vmsbf.m set-before-first mask bit
738 // vmsif.m set-including-first mask bit
739 // vmsof.m set-only-first mask bit
740 // EEW=1
741 // We handle the cases when operand is a v0 mask operand above the switch,
742 // but these instructions may use non-v0 mask operands and need to be handled
743 // specifically.
744 case RISCV::VMAND_MM:
745 case RISCV::VMNAND_MM:
746 case RISCV::VMANDN_MM:
747 case RISCV::VMXOR_MM:
748 case RISCV::VMOR_MM:
749 case RISCV::VMNOR_MM:
750 case RISCV::VMORN_MM:
751 case RISCV::VMXNOR_MM:
752 case RISCV::VMSBF_M:
753 case RISCV::VMSIF_M:
754 case RISCV::VMSOF_M: {
755 return MILog2SEW;
756 }
757
758 // Vector Compress Instruction
759 // EEW=SEW, except the mask operand has EEW=1. Mask operand is not handled
760 // before this switch.
761 case RISCV::VCOMPRESS_VM:
762 return MO.getOperandNo() == 3 ? 0 : MILog2SEW;
763
764 // Vector Iota Instruction
765 // EEW=SEW, except the mask operand has EEW=1. Mask operand is not handled
766 // before this switch.
767 case RISCV::VIOTA_M: {
768 if (IsMODef || MO.getOperandNo() == 1)
769 return MILog2SEW;
770 return 0;
771 }
772
773 // Vector Integer Compare Instructions
774 // Dest EEW=1. Source EEW=SEW.
775 case RISCV::VMSEQ_VI:
776 case RISCV::VMSEQ_VV:
777 case RISCV::VMSEQ_VX:
778 case RISCV::VMSNE_VI:
779 case RISCV::VMSNE_VV:
780 case RISCV::VMSNE_VX:
781 case RISCV::VMSLTU_VV:
782 case RISCV::VMSLTU_VX:
783 case RISCV::VMSLT_VV:
784 case RISCV::VMSLT_VX:
785 case RISCV::VMSLEU_VV:
786 case RISCV::VMSLEU_VI:
787 case RISCV::VMSLEU_VX:
788 case RISCV::VMSLE_VV:
789 case RISCV::VMSLE_VI:
790 case RISCV::VMSLE_VX:
791 case RISCV::VMSGTU_VI:
792 case RISCV::VMSGTU_VX:
793 case RISCV::VMSGT_VI:
794 case RISCV::VMSGT_VX:
795 // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
796 // Dest EEW=1. Source EEW=SEW. Mask source operand handled above this switch.
797 case RISCV::VMADC_VIM:
798 case RISCV::VMADC_VVM:
799 case RISCV::VMADC_VXM:
800 case RISCV::VMSBC_VVM:
801 case RISCV::VMSBC_VXM:
802 // Dest EEW=1. Source EEW=SEW.
803 case RISCV::VMADC_VV:
804 case RISCV::VMADC_VI:
805 case RISCV::VMADC_VX:
806 case RISCV::VMSBC_VV:
807 case RISCV::VMSBC_VX:
808 // 13.13. Vector Floating-Point Compare Instructions
809 // Dest EEW=1. Source EEW=SEW
810 case RISCV::VMFEQ_VF:
811 case RISCV::VMFEQ_VV:
812 case RISCV::VMFNE_VF:
813 case RISCV::VMFNE_VV:
814 case RISCV::VMFLT_VF:
815 case RISCV::VMFLT_VV:
816 case RISCV::VMFLE_VF:
817 case RISCV::VMFLE_VV:
818 case RISCV::VMFGT_VF:
819 case RISCV::VMFGE_VF: {
820 if (IsMODef)
821 return 0;
822 return MILog2SEW;
823 }
824
825 // Vector Reduction Operations
826 // Vector Single-Width Integer Reduction Instructions
827 case RISCV::VREDAND_VS:
828 case RISCV::VREDMAX_VS:
829 case RISCV::VREDMAXU_VS:
830 case RISCV::VREDMIN_VS:
831 case RISCV::VREDMINU_VS:
832 case RISCV::VREDOR_VS:
833 case RISCV::VREDSUM_VS:
834 case RISCV::VREDXOR_VS:
835 // Vector Single-Width Floating-Point Reduction Instructions
836 case RISCV::VFREDMAX_VS:
837 case RISCV::VFREDMIN_VS:
838 case RISCV::VFREDOSUM_VS:
839 case RISCV::VFREDUSUM_VS: {
840 return MILog2SEW;
841 }
842
843 // Vector Widening Integer Reduction Instructions
844 // The Dest and VS1 read only element 0 for the vector register. Return
845 // 2*EEW for these. VS2 has EEW=SEW and EMUL=LMUL.
846 case RISCV::VWREDSUM_VS:
847 case RISCV::VWREDSUMU_VS:
848 // Vector Widening Floating-Point Reduction Instructions
849 case RISCV::VFWREDOSUM_VS:
850 case RISCV::VFWREDUSUM_VS: {
851 bool TwoTimes = IsMODef || MO.getOperandNo() == 3;
852 return TwoTimes ? MILog2SEW + 1 : MILog2SEW;
853 }
854
855 // Vector Register Gather with 16-bit Index Elements Instruction
856 // Dest and source data EEW=SEW. Index vector EEW=16.
857 case RISCV::VRGATHEREI16_VV: {
858 if (MO.getOperandNo() == 2)
859 return 4;
860 return MILog2SEW;
861 }
862
863 default:
864 return std::nullopt;
865 }
866}
867
868static std::optional<OperandInfo> getOperandInfo(const MachineOperand &MO) {
869 const MachineInstr &MI = *MO.getParent();
870 const RISCVVPseudosTable::PseudoInfo *RVV =
871 RISCVVPseudosTable::getPseudoInfo(Pseudo: MI.getOpcode());
872 assert(RVV && "Could not find MI in PseudoTable");
873
874 std::optional<unsigned> Log2EEW = getOperandLog2EEW(MO);
875 if (!Log2EEW)
876 return std::nullopt;
877
878 switch (RVV->BaseInstr) {
879 // Vector Reduction Operations
880 // Vector Single-Width Integer Reduction Instructions
881 // Vector Widening Integer Reduction Instructions
882 // Vector Widening Floating-Point Reduction Instructions
883 // The Dest and VS1 only read element 0 of the vector register. Return just
884 // the EEW for these.
885 case RISCV::VREDAND_VS:
886 case RISCV::VREDMAX_VS:
887 case RISCV::VREDMAXU_VS:
888 case RISCV::VREDMIN_VS:
889 case RISCV::VREDMINU_VS:
890 case RISCV::VREDOR_VS:
891 case RISCV::VREDSUM_VS:
892 case RISCV::VREDXOR_VS:
893 case RISCV::VWREDSUM_VS:
894 case RISCV::VWREDSUMU_VS:
895 case RISCV::VFWREDOSUM_VS:
896 case RISCV::VFWREDUSUM_VS:
897 if (MO.getOperandNo() != 2)
898 return OperandInfo(*Log2EEW);
899 break;
900 };
901
902 // All others have EMUL=EEW/SEW*LMUL
903 return OperandInfo(getEMULEqualsEEWDivSEWTimesLMUL(Log2EEW: *Log2EEW, MI), *Log2EEW);
904}
905
906static bool isTupleInsertInstr(const MachineInstr &MI);
907
908/// Return true if we can reason about demanded VLs elementwise for \p MI.
909bool RISCVVLOptimizer::isSupportedInstr(const MachineInstr &MI) const {
910 if (MI.isPHI() || MI.isFullCopy() || isTupleInsertInstr(MI))
911 return true;
912
913 unsigned RVVOpc = RISCV::getRVVMCOpcode(RVVPseudoOpcode: MI.getOpcode());
914 if (!RVVOpc)
915 return false;
916
917 assert(!(MI.getNumExplicitDefs() == 0 && !MI.mayStore() &&
918 !RISCVII::elementsDependOnVL(TII->get(RVVOpc).TSFlags)) &&
919 "No defs but elements don't depend on VL?");
920
921 // TODO: Reduce vl for vmv.s.x and vfmv.s.f. Currently this introduces more vl
922 // toggles, we need to extend PRE in RISCVInsertVSETVLI first.
923 if (RVVOpc == RISCV::VMV_S_X || RVVOpc == RISCV::VFMV_S_F)
924 return false;
925
926 if (RISCVII::elementsDependOnVL(TSFlags: TII->get(Opcode: RVVOpc).TSFlags))
927 return false;
928
929 if (MI.mayStore())
930 return false;
931
932 return true;
933}
934
935/// Return true if MO is a vector operand but is used as a scalar operand.
936static bool isVectorOpUsedAsScalarOp(const MachineOperand &MO) {
937 const MachineInstr *MI = MO.getParent();
938 const RISCVVPseudosTable::PseudoInfo *RVV =
939 RISCVVPseudosTable::getPseudoInfo(Pseudo: MI->getOpcode());
940
941 if (!RVV)
942 return false;
943
944 switch (RVV->BaseInstr) {
945 // Reductions only use vs1[0] of vs1
946 case RISCV::VREDAND_VS:
947 case RISCV::VREDMAX_VS:
948 case RISCV::VREDMAXU_VS:
949 case RISCV::VREDMIN_VS:
950 case RISCV::VREDMINU_VS:
951 case RISCV::VREDOR_VS:
952 case RISCV::VREDSUM_VS:
953 case RISCV::VREDXOR_VS:
954 case RISCV::VWREDSUM_VS:
955 case RISCV::VWREDSUMU_VS:
956 case RISCV::VFREDMAX_VS:
957 case RISCV::VFREDMIN_VS:
958 case RISCV::VFREDOSUM_VS:
959 case RISCV::VFREDUSUM_VS:
960 case RISCV::VFWREDOSUM_VS:
961 case RISCV::VFWREDUSUM_VS:
962 return MO.getOperandNo() == 3;
963 case RISCV::VMV_X_S:
964 case RISCV::VFMV_F_S:
965 return MO.getOperandNo() == 1;
966 default:
967 return false;
968 }
969}
970
971bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
972 const MCInstrDesc &Desc = MI.getDesc();
973 if (!RISCVII::hasVLOp(TSFlags: Desc.TSFlags) || !RISCVII::hasSEWOp(TSFlags: Desc.TSFlags))
974 return false;
975
976 if (MI.getNumExplicitDefs() != 1)
977 return false;
978
979 // Some instructions have implicit defs e.g. $vxsat. If they might be read
980 // later then we can't reduce VL.
981 if (!MI.allImplicitDefsAreDead()) {
982 LLVM_DEBUG(dbgs() << "Not a candidate because has non-dead implicit def\n");
983 return false;
984 }
985
986 if (MI.mayRaiseFPException()) {
987 LLVM_DEBUG(dbgs() << "Not a candidate because may raise FP exception\n");
988 return false;
989 }
990
991 for (const MachineMemOperand *MMO : MI.memoperands()) {
992 if (MMO->isVolatile()) {
993 LLVM_DEBUG(dbgs() << "Not a candidate because contains volatile MMO\n");
994 return false;
995 }
996 }
997
998 if (!isSupportedInstr(MI)) {
999 LLVM_DEBUG(dbgs() << "Not a candidate due to unsupported instruction: "
1000 << MI);
1001 return false;
1002 }
1003
1004 assert(!RISCVII::elementsDependOnVL(
1005 TII->get(RISCV::getRVVMCOpcode(MI.getOpcode())).TSFlags) &&
1006 "Instruction shouldn't be supported if elements depend on VL");
1007
1008 assert(RISCVRI::isVRegClass(
1009 MRI->getRegClass(MI.getOperand(0).getReg())->TSFlags) &&
1010 "All supported instructions produce a vector register result");
1011
1012 LLVM_DEBUG(dbgs() << "Found a candidate for VL reduction: " << MI << "\n");
1013 return true;
1014}
1015
1016/// Given a vslidedown.vx like:
1017///
1018/// %slideamt = ADDI %x, -1
1019/// %v = PseudoVSLIDEDOWN_VX %passthru, %src, %slideamt, avl=1
1020///
1021/// %v will only read the first %slideamt + 1 lanes of %src, which = %x.
1022/// This is a common case when lowering extractelement.
1023///
1024/// Note that if %x is 0, %slideamt will be all ones. In this case %src will be
1025/// completely slid down and none of its lanes will be read (since %slideamt is
1026/// greater than the largest VLMAX of 65536) so we can demand any minimum VL.
1027static std::optional<DemandedVL>
1028getMinimumVLForVSLIDEDOWN_VX(const MachineOperand &UserOp,
1029 const MachineRegisterInfo *MRI) {
1030 const MachineInstr &MI = *UserOp.getParent();
1031 if (RISCV::getRVVMCOpcode(RVVPseudoOpcode: MI.getOpcode()) != RISCV::VSLIDEDOWN_VX)
1032 return std::nullopt;
1033 // We're looking at what lanes are used from the src operand.
1034 if (UserOp.getOperandNo() != 2)
1035 return std::nullopt;
1036 // For now, the AVL must be 1.
1037 const MachineOperand &AVL = MI.getOperand(i: 4);
1038 if (!AVL.isImm() || AVL.getImm() != 1)
1039 return std::nullopt;
1040 // The slide amount must be %x - 1.
1041 const MachineOperand &SlideAmt = MI.getOperand(i: 3);
1042 if (!SlideAmt.getReg().isVirtual())
1043 return std::nullopt;
1044 MachineInstr *SlideAmtDef = MRI->getUniqueVRegDef(Reg: SlideAmt.getReg());
1045 if (SlideAmtDef->getOpcode() != RISCV::ADDI ||
1046 SlideAmtDef->getOperand(i: 2).getImm() != -AVL.getImm() ||
1047 !SlideAmtDef->getOperand(i: 1).getReg().isVirtual())
1048 return std::nullopt;
1049 return SlideAmtDef->getOperand(i: 1);
1050}
1051
/// Return the smallest VL that the instruction using \p UserOp demands from
/// the instruction that defines \p UserOp's register. Conservatively returns
/// VLMAX whenever the demand cannot be determined.
DemandedVL
RISCVVLOptimizer::getMinimumVLForUser(const MachineOperand &UserOp) const {
  const MachineInstr &UserMI = *UserOp.getParent();
  const MCInstrDesc &Desc = UserMI.getDesc();

  // PHIs, full copies and tuple inserts are "transparent": they demand
  // exactly what their own users demand, which the dataflow has already
  // recorded in DemandedVLs.
  if (UserMI.isPHI() || UserMI.isFullCopy() || isTupleInsertInstr(MI: UserMI))
    return DemandedVLs.lookup(Key: &UserMI);

  // Users without a VL operand give us nothing to reason about.
  if (!RISCVII::hasVLOp(TSFlags: Desc.TSFlags) || !RISCVII::hasSEWOp(TSFlags: Desc.TSFlags)) {
    LLVM_DEBUG(dbgs() << " Abort due to lack of VL, assume that"
                         " use VLMAX\n");
    return DemandedVL::vlmax();
  }

  // Special case: a vslidedown.vx with avl=1 whose slide amount is %x - 1
  // only reads the first %x lanes of its source.
  if (auto VL = getMinimumVLForVSLIDEDOWN_VX(UserOp, MRI))
    return *VL;

  // Instructions flagged as reading lanes past VL (e.g. vrgather-like
  // semantics) can observe more than their VL operand suggests.
  if (RISCVII::readsPastVL(
          TSFlags: TII->get(Opcode: RISCV::getRVVMCOpcode(RVVPseudoOpcode: UserMI.getOpcode())).TSFlags)) {
    LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n");
    return DemandedVL::vlmax();
  }

  unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
  const MachineOperand &VLOp = UserMI.getOperand(i: VLOpNum);
  // Looking for an immediate or a register VL that isn't X0.
  assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) &&
         "Did not expect X0 VL");

  // If the user is a passthru it will read the elements past VL, so
  // abort if any of the elements past VL are demanded.
  if (UserOp.isTied()) {
    assert(UserOp.getOperandNo() == UserMI.getNumExplicitDefs() &&
           RISCVII::isFirstDefTiedToFirstUse(UserMI.getDesc()));
    if (!RISCV::isVLKnownLE(LHS: DemandedVLs.lookup(Key: &UserMI).VL, RHS: VLOp)) {
      LLVM_DEBUG(dbgs() << " Abort because user is passthru in "
                           "instruction with demanded tail\n");
      return DemandedVL::vlmax();
    }
  }

  // Instructions like reductions may use a vector register as a scalar
  // register. In this case, we should treat it as only reading the first lane.
  if (isVectorOpUsedAsScalarOp(MO: UserOp)) {
    LLVM_DEBUG(dbgs() << " Used this operand as a scalar operand\n");
    return MachineOperand::CreateImm(Val: 1);
  }

  // If we know the demanded VL of UserMI, then we can reduce the VL it
  // requires.
  if (RISCV::isVLKnownLE(LHS: DemandedVLs.lookup(Key: &UserMI).VL, RHS: VLOp))
    return DemandedVLs.lookup(Key: &UserMI);

  // Fall back to the user's own VL operand.
  return VLOp;
}
1107
1108/// Return true if MI is an instruction used for assembling registers
1109/// for segmented store instructions, namely, RISCVISD::TUPLE_INSERT.
1110/// Currently it's lowered to INSERT_SUBREG.
1111static bool isTupleInsertInstr(const MachineInstr &MI) {
1112 if (!MI.isInsertSubreg())
1113 return false;
1114
1115 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1116 const TargetRegisterClass *DstRC = MRI.getRegClass(Reg: MI.getOperand(i: 0).getReg());
1117 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
1118 if (!RISCVRI::isVRegClass(TSFlags: DstRC->TSFlags))
1119 return false;
1120 unsigned NF = RISCVRI::getNF(TSFlags: DstRC->TSFlags);
1121 if (NF < 2)
1122 return false;
1123
1124 // Check whether INSERT_SUBREG has the correct subreg index for tuple inserts.
1125 auto VLMul = RISCVRI::getLMul(TSFlags: DstRC->TSFlags);
1126 unsigned SubRegIdx = MI.getOperand(i: 3).getImm();
1127 [[maybe_unused]] auto [LMul, IsFractional] = RISCVVType::decodeVLMUL(VLMul);
1128 assert(!IsFractional && "unexpected LMUL for tuple register classes");
1129 return TRI->getSubRegIdxSize(Idx: SubRegIdx) == RISCV::RVVBitsPerBlock * LMul;
1130}
1131
/// Return true if MI is one of the RVV segmented store pseudos.
/// Covers unit-stride (vsseg), strided (vssseg), unordered-indexed (vsuxseg)
/// and ordered-indexed (vsoxseg) variants; the macro argument is presumably
/// the element width (8/16/32/64) and each macro expands to the case labels
/// for every NF — the macros are defined earlier in this file.
static bool isSegmentedStoreInstr(const MachineInstr &MI) {
  switch (RISCV::getRVVMCOpcode(RVVPseudoOpcode: MI.getOpcode())) {
  case VSSEG_CASES(8):
  case VSSSEG_CASES(8):
  case VSUXSEG_CASES(8):
  case VSOXSEG_CASES(8):
  case VSSEG_CASES(16):
  case VSSSEG_CASES(16):
  case VSUXSEG_CASES(16):
  case VSOXSEG_CASES(16):
  case VSSEG_CASES(32):
  case VSSSEG_CASES(32):
  case VSUXSEG_CASES(32):
  case VSOXSEG_CASES(32):
  case VSSEG_CASES(64):
  case VSSSEG_CASES(64):
  case VSUXSEG_CASES(64):
  case VSOXSEG_CASES(64):
    return true;
  default:
    return false;
  }
}
1155
/// Check that every (transitive) user of \p MI's result is compatible with
/// reducing MI's VL: each user must have known EEW/EMUL operand information
/// that matches what MI produces. Copies, PHIs and tuple inserts are peeked
/// through; returns false if any user blocks the reduction.
bool RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const {
  // Transparent instructions: their users are checked when the transparent
  // instruction itself is processed, so nothing to do here.
  if (MI.isPHI() || MI.isFullCopy() || isTupleInsertInstr(MI))
    return true;

  // Worklist of use operands still to examine; SetVector avoids revisiting
  // the same operand, PHISeen breaks PHI cycles.
  SmallSetVector<MachineOperand *, 8> OpWorklist;
  SmallPtrSet<const MachineInstr *, 4> PHISeen;
  for (auto &UserOp : MRI->use_operands(Reg: MI.getOperand(i: 0).getReg()))
    OpWorklist.insert(X: &UserOp);

  while (!OpWorklist.empty()) {
    MachineOperand &UserOp = *OpWorklist.pop_back_val();
    const MachineInstr &UserMI = *UserOp.getParent();
    LLVM_DEBUG(dbgs() << " Checking user: " << UserMI << "\n");

    // Look through full copies into virtual registers and check their users
    // instead.
    if (UserMI.isFullCopy() && UserMI.getOperand(i: 0).getReg().isVirtual()) {
      LLVM_DEBUG(dbgs() << " Peeking through uses of COPY\n");
      OpWorklist.insert_range(R: llvm::make_pointer_range(
          Range: MRI->use_operands(Reg: UserMI.getOperand(i: 0).getReg())));
      continue;
    }

    // Look through tuple inserts (INSERT_SUBREG), but only if every user of
    // the tuple is another tuple insert or a segmented store.
    if (isTupleInsertInstr(MI: UserMI)) {
      LLVM_DEBUG(dbgs().indent(4) << "Peeking through uses of INSERT_SUBREG\n");
      for (MachineOperand &UseOp :
           MRI->use_operands(Reg: UserMI.getOperand(i: 0).getReg())) {
        const MachineInstr &CandidateMI = *UseOp.getParent();
        // We should not propagate the VL if the user is not a segmented store
        // or another INSERT_SUBREG, since VL just works differently
        // between segmented operations (per-field) v.s. other RVV ops (on the
        // whole register group).
        if (!isTupleInsertInstr(MI: CandidateMI) &&
            !isSegmentedStoreInstr(MI: CandidateMI))
          return false;
        OpWorklist.insert(X: &UseOp);
      }
      continue;
    }

    if (UserMI.isPHI()) {
      // Don't follow PHI cycles
      if (!PHISeen.insert(Ptr: &UserMI).second)
        continue;
      LLVM_DEBUG(dbgs() << " Peeking through uses of PHI\n");
      OpWorklist.insert_range(R: llvm::make_pointer_range(
          Range: MRI->use_operands(Reg: UserMI.getOperand(i: 0).getReg())));
      continue;
    }

    // Without an SEW operand we cannot compute the user's operand info.
    if (!RISCVII::hasSEWOp(TSFlags: UserMI.getDesc().TSFlags)) {
      LLVM_DEBUG(dbgs() << " Abort due to lack of SEW operand\n");
      return false;
    }

    // The EEW/EMUL of the value as produced must agree with how the user
    // consumes it, otherwise reducing VL would change which lanes are read.
    std::optional<OperandInfo> ConsumerInfo = getOperandInfo(MO: UserOp);
    std::optional<OperandInfo> ProducerInfo = getOperandInfo(MO: MI.getOperand(i: 0));
    if (!ConsumerInfo || !ProducerInfo) {
      LLVM_DEBUG(dbgs() << " Abort due to unknown operand information.\n");
      LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
      LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
      return false;
    }

    if (!OperandInfo::areCompatible(Def: *ProducerInfo, User: *ConsumerInfo)) {
      LLVM_DEBUG(
          dbgs()
          << " Abort due to incompatible information for EMUL or EEW.\n");
      LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
      LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
      return false;
    }
  }

  return true;
}
1230
/// Try to rewrite \p MI's VL operand to \p CommonVL (the VL its users
/// actually demand). Returns true if the VL operand was changed. May move MI
/// within its block to make a register VL's definition dominate it.
bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI,
                                   MachineOperand CommonVL) const {
  LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI);

  unsigned VLOpNum = RISCVII::getVLOpNum(Desc: MI.getDesc());
  MachineOperand &VLOp = MI.getOperand(i: VLOpNum);

  assert((CommonVL.isImm() || CommonVL.getReg().isVirtual()) &&
         "Expected VL to be an Imm or virtual Reg");

  // If the VL is defined by a vleff that doesn't dominate MI, try using the
  // vleff's AVL. It will be greater than or equal to the output VL.
  if (CommonVL.isReg()) {
    const MachineInstr *VLMI = MRI->getVRegDef(Reg: CommonVL.getReg());
    if (RISCVInstrInfo::isFaultOnlyFirstLoad(MI: *VLMI) &&
        !MDT->dominates(A: VLMI, B: &MI))
      CommonVL = VLMI->getOperand(i: RISCVII::getVLOpNum(Desc: VLMI->getDesc()));
  }

  // Only shrink VL; growing it (or not being able to compare) is unsound.
  if (!RISCV::isVLKnownLE(LHS: CommonVL, RHS: VLOp)) {
    LLVM_DEBUG(dbgs() << " Abort due to CommonVL not <= VLOp.\n");
    return false;
  }

  if (CommonVL.isIdenticalTo(Other: VLOp)) {
    LLVM_DEBUG(
        dbgs() << " Abort due to CommonVL == VLOp, no point in reducing.\n");
    return false;
  }

  // Immediate VLs need no dominance reasoning: just rewrite in place.
  if (CommonVL.isImm()) {
    LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to "
                      << CommonVL.getImm() << " for " << MI << "\n");
    VLOp.ChangeToImmediate(ImmVal: CommonVL.getImm());
    return true;
  }
  // Register VL: its definition must dominate MI. If it doesn't, try moving
  // MI to just after the VL's definition, which is legal only if both are in
  // the same block, all same-block users of MI's result stay dominated, and
  // the move itself is safe.
  MachineInstr *VLMI = MRI->getVRegDef(Reg: CommonVL.getReg());
  auto VLDominates = [this, &VLMI](const MachineInstr &MI) {
    return MDT->dominates(A: VLMI, B: &MI);
  };
  if (!VLDominates(MI)) {
    assert(MI.getNumExplicitDefs() == 1);
    auto Uses = MRI->use_instructions(Reg: MI.getOperand(i: 0).getReg());
    auto UsesSameBB = make_filter_range(Range&: Uses, Pred: [&MI](const MachineInstr &Use) {
      return Use.getParent() == MI.getParent();
    });
    if (VLMI->getParent() == MI.getParent() &&
        all_of(Range&: UsesSameBB, P: VLDominates) &&
        RISCVInstrInfo::isSafeToMove(From: MI, To: *VLMI->getNextNode())) {
      MI.moveBefore(MovePos: VLMI->getNextNode());
    } else {
      LLVM_DEBUG(dbgs() << " Abort due to VL not dominating.\n");
      return false;
    }
  }
  LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to "
                    << printReg(CommonVL.getReg(), MRI->getTargetRegisterInfo())
                    << " for " << MI << "\n");

  // All our checks passed. We can reduce VL.
  VLOp.ChangeToRegister(Reg: CommonVL.getReg(), isDef: false);
  // The VL operand must not be X0, so constrain to GPRNoX0.
  MRI->constrainRegClass(Reg: CommonVL.getReg(), RC: &RISCV::GPRNoX0RegClass);
  return true;
}
1295
1296static bool isPhysical(const MachineOperand &MO) {
1297 return MO.isReg() && MO.getReg().isPhysical();
1298}
1299
/// Look through \p MI's operands and propagate what it demands to its uses.
/// This is the dataflow transfer function: it may grow the demanded VL of
/// MI's operand definitions and re-queue them on the worklist.
void RISCVVLOptimizer::transfer(const MachineInstr &MI) {
  // If MI can't be analyzed, any user is incompatible, or MI writes a
  // physical register (readable anywhere), pin MI's own demand at VLMAX so
  // its inputs are demanded conservatively.
  if (!isSupportedInstr(MI) || !checkUsers(MI) || any_of(Range: MI.defs(), P: isPhysical))
    DemandedVLs[&MI] = DemandedVL::vlmax();

  // Join (take the max of) each operand definition's recorded demand with
  // what MI requires of that operand; re-queue defs whose demand grew so the
  // change propagates.
  for (const MachineOperand &MO : virtual_vec_uses(MI)) {
    const MachineInstr *Def = MRI->getVRegDef(Reg: MO.getReg());
    DemandedVL Prev = DemandedVLs[Def];
    DemandedVLs[Def] = DemandedVLs[Def].max(X: getMinimumVLForUser(UserOp: MO));
    if (DemandedVLs[Def] != Prev)
      Worklist.insert(X: Def);
  }
}
1313
/// Pass entry point. Phase 1: run the backward dataflow to compute the VL
/// demanded of every instruction. Phase 2: rewrite VL operands that can be
/// shrunk to the demanded value. Returns true if any VL was changed.
bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) {
  if (skipFunction(F: MF.getFunction()))
    return false;

  MRI = &MF.getRegInfo();
  MDT = &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();

  // Nothing to do on targets without the V extension.
  const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
  if (!ST.hasVInstructions())
    return false;

  TII = ST.getInstrInfo();

  assert(DemandedVLs.empty());

  // For each instruction that defines a vector, propagate the VL it
  // uses to its inputs.
  // Seed the worklist in post-order, visiting each block bottom-up, so the
  // backward analysis converges quickly.
  for (MachineBasicBlock *MBB : post_order(G: &MF)) {
    assert(MDT->isReachableFromEntry(MBB));
    for (MachineInstr &MI : reverse(C&: *MBB))
      if (!MI.isDebugInstr())
        Worklist.insert(X: &MI);
  }

  // Iterate to a fixed point: transfer() re-queues instructions whose
  // demanded VL grows.
  while (!Worklist.empty()) {
    const MachineInstr *MI = Worklist.front();
    Worklist.remove(X: MI);
    transfer(MI: *MI);
  }

  // Then go through and see if we can reduce the VL of any instructions to
  // only what's demanded.
  bool MadeChange = false;
  for (auto &[MI, VL] : DemandedVLs) {
    assert(MDT->isReachableFromEntry(MI->getParent()));
    if (!isCandidate(MI: *MI))
      continue;
    if (!tryReduceVL(MI&: *const_cast<MachineInstr *>(MI), CommonVL: VL.VL))
      continue;
    MadeChange = true;
  }

  DemandedVLs.clear();
  return MadeChange;
}
1359