1//===-------------- RISCVVLOptimizer.cpp - VL Optimizer -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===---------------------------------------------------------------------===//
8//
9// This pass reduces the VL where possible at the MI level, before VSETVLI
10// instructions are inserted.
11//
12// The purpose of this optimization is to make the VL argument, for instructions
13// that have a VL argument, as small as possible.
14//
15// This is split into a sparse dataflow analysis where we determine what VL is
16// demanded by each instruction first, and then afterwards try to reduce the VL
17// of each instruction if it demands less than its VL operand.
18//
19// The analysis is explained in more detail in the 2025 EuroLLVM Developers'
20// Meeting talk "Accidental Dataflow Analysis: Extending the RISC-V VL
21// Optimizer", which is available on YouTube at
22// https://www.youtube.com/watch?v=Mfb5fRSdJAc
23//
24// The slides for the talk are available at
25// https://llvm.org/devmtg/2025-04/slides/technical_talk/lau_accidental_dataflow.pdf
26//
27//===---------------------------------------------------------------------===//
28
29#include "RISCV.h"
30#include "RISCVSubtarget.h"
31#include "llvm/ADT/PostOrderIterator.h"
32#include "llvm/ADT/SetVector.h"
33#include "llvm/CodeGen/MachineDominators.h"
34#include "llvm/CodeGen/MachineFunctionPass.h"
35#include "llvm/InitializePasses.h"
36
37using namespace llvm;
38
39#define DEBUG_TYPE "riscv-vl-optimizer"
40#define PASS_NAME "RISC-V VL Optimizer"
41
42namespace {
43
44/// Wrapper around MachineOperand that defaults to immediate 0.
45struct DemandedVL {
46 MachineOperand VL;
47 DemandedVL() : VL(MachineOperand::CreateImm(Val: 0)) {}
48 DemandedVL(MachineOperand VL) : VL(VL) {}
49 static DemandedVL vlmax() {
50 return DemandedVL(MachineOperand::CreateImm(Val: RISCV::VLMaxSentinel));
51 }
52 bool operator!=(const DemandedVL &Other) const {
53 return !VL.isIdenticalTo(Other: Other.VL);
54 }
55
56 DemandedVL max(const DemandedVL &X) const {
57 if (RISCV::isVLKnownLE(LHS: VL, RHS: X.VL))
58 return X;
59 if (RISCV::isVLKnownLE(LHS: X.VL, RHS: VL))
60 return *this;
61 return DemandedVL::vlmax();
62 }
63};
64
/// Machine function pass that shrinks the VL operand of RVV pseudos to the
/// minimum VL actually demanded by their users.
class RISCVVLOptimizer : public MachineFunctionPass {
  MachineRegisterInfo *MRI;
  const MachineDominatorTree *MDT;
  const TargetInstrInfo *TII;

public:
  static char ID;

  RISCVVLOptimizer() : MachineFunctionPass(ID) {}

  bool runOnMachineFunction(MachineFunction &MF) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // Only VL operands are rewritten; the CFG is never modified.
    AU.setPreservesCFG();
    AU.addRequired<MachineDominatorTreeWrapperPass>();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  StringRef getPassName() const override { return PASS_NAME; }

private:
  /// Compute the VL that \p UserOp's instruction demands of the operand's
  /// defining instruction.
  DemandedVL getMinimumVLForUser(const MachineOperand &UserOp) const;
  /// Returns true if the users of \p MI have compatible EEWs and SEWs.
  bool checkUsers(const MachineInstr &MI) const;
  /// Rewrite \p MI's VL operand to \p VL if that is a strict reduction.
  /// Returns true if the instruction was changed.
  bool tryReduceVL(MachineInstr &MI, MachineOperand VL) const;
  /// Whether the demanded-VL transfer function understands \p MI.
  bool isSupportedInstr(const MachineInstr &MI) const;
  /// Whether \p MI is eligible to have its VL reduced at all.
  bool isCandidate(const MachineInstr &MI) const;
  /// Dataflow transfer function: propagate \p MI's demanded VL to the
  /// instructions defining its vector operands.
  void transfer(const MachineInstr &MI);

  /// For a given instruction, records what elements of it are demanded by
  /// downstream users.
  MapVector<const MachineInstr *, DemandedVL> DemandedVLs;
  /// Instructions whose demanded VL may still change; drained by the
  /// fixed-point iteration.
  SetVector<const MachineInstr *> Worklist;

  /// \returns all vector virtual registers that \p MI uses.
  auto virtual_vec_uses(const MachineInstr &MI) const {
    return make_filter_range(Range: MI.uses(), Pred: [this](const MachineOperand &MO) {
      return MO.isReg() && MO.getReg().isVirtual() &&
             RISCVRegisterInfo::isRVVRegClass(RC: MRI->getRegClass(Reg: MO.getReg()));
    });
  }
};
107
108/// Represents the EMUL and EEW of a MachineOperand.
109struct OperandInfo {
110 // Represent as 1,2,4,8, ... and fractional indicator. This is because
111 // EMUL can take on values that don't map to RISCVVType::VLMUL values exactly.
112 // For example, a mask operand can have an EMUL less than MF8.
113 // If nullopt, then EMUL isn't used (i.e. only a single scalar is read).
114 std::optional<std::pair<unsigned, bool>> EMUL;
115
116 unsigned Log2EEW;
117
118 OperandInfo(RISCVVType::VLMUL EMUL, unsigned Log2EEW)
119 : EMUL(RISCVVType::decodeVLMUL(VLMul: EMUL)), Log2EEW(Log2EEW) {}
120
121 OperandInfo(std::pair<unsigned, bool> EMUL, unsigned Log2EEW)
122 : EMUL(EMUL), Log2EEW(Log2EEW) {}
123
124 OperandInfo(unsigned Log2EEW) : Log2EEW(Log2EEW) {}
125
126 OperandInfo() = delete;
127
128 /// Return true if the EMUL and EEW produced by \p Def is compatible with the
129 /// EMUL and EEW used by \p User.
130 static bool areCompatible(const OperandInfo &Def, const OperandInfo &User) {
131 if (Def.Log2EEW != User.Log2EEW)
132 return false;
133 if (User.EMUL && Def.EMUL != User.EMUL)
134 return false;
135 return true;
136 }
137
138 void print(raw_ostream &OS) const {
139 if (EMUL) {
140 OS << "EMUL: m";
141 if (EMUL->second)
142 OS << "f";
143 OS << EMUL->first;
144 } else
145 OS << "EMUL: none\n";
146 OS << ", EEW: " << (1 << Log2EEW);
147 }
148};
149
150} // end anonymous namespace
151
char RISCVVLOptimizer::ID = 0;
// Register the pass with the legacy pass manager and record its dependency on
// the machine dominator tree (required in getAnalysisUsage above).
INITIALIZE_PASS_BEGIN(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
INITIALIZE_PASS_END(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false)

/// Factory hook used by the RISC-V target to add this pass to the pipeline.
FunctionPass *llvm::createRISCVVLOptimizerPass() {
  return new RISCVVLOptimizer();
}
160
/// Debug-print helper so OperandInfo can be streamed directly in LLVM_DEBUG.
[[maybe_unused]]
static raw_ostream &operator<<(raw_ostream &OS, const OperandInfo &OI) {
  OI.print(OS);
  return OS;
}
166
167[[maybe_unused]]
168static raw_ostream &operator<<(raw_ostream &OS,
169 const std::optional<OperandInfo> &OI) {
170 if (OI)
171 OI->print(OS);
172 else
173 OS << "nullopt";
174 return OS;
175}
176
/// Return EMUL = (EEW / SEW) * LMUL where EEW comes from Log2EEW and LMUL and
/// SEW are from the TSFlags of MI.
/// The result uses the same encoding as OperandInfo::EMUL: a multiplier and a
/// flag that is true when the EMUL is fractional (i.e. mf<N>).
static std::pair<unsigned, bool>
getEMULEqualsEEWDivSEWTimesLMUL(unsigned Log2EEW, const MachineInstr &MI) {
  RISCVVType::VLMUL MIVLMUL = RISCVII::getLMul(TSFlags: MI.getDesc().TSFlags);
  auto [MILMUL, MILMULIsFractional] = RISCVVType::decodeVLMUL(VLMul: MIVLMUL);
  unsigned MILog2SEW =
      MI.getOperand(i: RISCVII::getSEWOpNum(Desc: MI.getDesc())).getImm();

  // Mask instructions will have 0 as the SEW operand. But the LMUL of these
  // instructions is calculated as if the SEW operand was 3 (e8).
  if (MILog2SEW == 0)
    MILog2SEW = 3;

  unsigned MISEW = 1 << MILog2SEW;

  unsigned EEW = 1 << Log2EEW;
  // Calculate (EEW/SEW)*LMUL preserving fractions less than 1. Use GCD
  // to put fraction in simplest form.
  unsigned Num = EEW, Denom = MISEW;
  // Fold LMUL into the denominator when it is fractional, otherwise into the
  // numerator, then reduce by the GCD.
  int GCD = MILMULIsFractional ? std::gcd(m: Num, n: Denom * MILMUL)
                               : std::gcd(m: Num * MILMUL, n: Denom);
  Num = MILMULIsFractional ? Num / GCD : Num * MILMUL / GCD;
  Denom = MILMULIsFractional ? Denom * MILMUL / GCD : Denom / GCD;
  // After reduction exactly one of Num/Denom is > 1; return that value, with
  // the second element true when the result is the fraction 1/Denom.
  return std::make_pair(x&: Num > Denom ? Num : Denom, y: Denom > Num);
}
203
204/// Dest has EEW=SEW. Source EEW=SEW/Factor (i.e. F2 => EEW/2).
205/// SEW comes from TSFlags of MI.
206static unsigned getIntegerExtensionOperandEEW(unsigned Factor,
207 const MachineInstr &MI,
208 const MachineOperand &MO) {
209 unsigned MILog2SEW =
210 MI.getOperand(i: RISCVII::getSEWOpNum(Desc: MI.getDesc())).getImm();
211
212 if (MO.getOperandNo() == 0)
213 return MILog2SEW;
214
215 unsigned MISEW = 1 << MILog2SEW;
216 unsigned EEW = MISEW / Factor;
217 unsigned Log2EEW = Log2_32(Value: EEW);
218
219 return Log2EEW;
220}
221
// Expands to case labels for all NF=2..8 segment variants of a store pseudo
// with the given prefix and EEW, e.g. VSEG_CASES(VS, 8) yields VSSEG2E8_V
// through VSSEG8E8_V. Note the first label deliberately omits "case" so the
// macro can be used directly after an existing "case" keyword.
#define VSEG_CASES(Prefix, EEW) \
  RISCV::Prefix##SEG2E##EEW##_V: \
  case RISCV::Prefix##SEG3E##EEW##_V: \
  case RISCV::Prefix##SEG4E##EEW##_V: \
  case RISCV::Prefix##SEG5E##EEW##_V: \
  case RISCV::Prefix##SEG6E##EEW##_V: \
  case RISCV::Prefix##SEG7E##EEW##_V: \
  case RISCV::Prefix##SEG8E##EEW##_V
// Convenience wrappers for unit-stride, strided, and indexed segment stores.
#define VSSEG_CASES(EEW) VSEG_CASES(VS, EEW)
#define VSSSEG_CASES(EEW) VSEG_CASES(VSS, EEW)
#define VSUXSEG_CASES(EEW) VSEG_CASES(VSUX, I##EEW)
#define VSOXSEG_CASES(EEW) VSEG_CASES(VSOX, I##EEW)
234
/// Return the log2 of the EEW (effective element width) that \p MO is read or
/// written with by its parent RVV pseudo, derived from the instruction's SEW
/// operand and the per-opcode rules of the RVV specification. Returns
/// std::nullopt for opcodes this function does not model.
static std::optional<unsigned> getOperandLog2EEW(const MachineOperand &MO) {
  const MachineInstr &MI = *MO.getParent();
  const MCInstrDesc &Desc = MI.getDesc();
  const RISCVVPseudosTable::PseudoInfo *RVV =
      RISCVVPseudosTable::getPseudoInfo(Pseudo: MI.getOpcode());
  assert(RVV && "Could not find MI in PseudoTable");

  // MI has a SEW associated with it. The RVV specification defines
  // the EEW of each operand and definition in relation to MI.SEW.
  unsigned MILog2SEW = MI.getOperand(i: RISCVII::getSEWOpNum(Desc)).getImm();

  const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(Desc);
  const bool IsTied = RISCVII::isTiedPseudo(TSFlags: Desc.TSFlags);

  // The passthru operand (first use, tied to the def) shares the def's EEW.
  bool IsMODef = MO.getOperandNo() == 0 ||
                 (HasPassthru && MO.getOperandNo() == MI.getNumExplicitDefs());

  // All mask operands have EEW=1
  const MCOperandInfo &Info = Desc.operands()[MO.getOperandNo()];
  if (Info.OperandType == MCOI::OPERAND_REGISTER &&
      Info.RegClass == RISCV::VMV0RegClassID)
    return 0;

  // switch against BaseInstr to reduce number of cases that need to be
  // considered.
  switch (RVV->BaseInstr) {

  // 6. Configuration-Setting Instructions
  // Configuration setting instructions do not read or write vector registers
  case RISCV::VSETIVLI:
  case RISCV::VSETVL:
  case RISCV::VSETVLI:
    llvm_unreachable("Configuration setting instructions do not read or write "
                     "vector registers");

  // Vector Loads and Stores
  // Vector Unit-Stride Instructions
  // Vector Strided Instructions
  /// Dest EEW encoded in the instruction
  case RISCV::VLM_V:
  case RISCV::VSM_V:
    return 0;
  case RISCV::VLE8_V:
  case RISCV::VSE8_V:
  case RISCV::VLSE8_V:
  case RISCV::VSSE8_V:
  case VSSEG_CASES(8):
  case VSSSEG_CASES(8):
    return 3;
  case RISCV::VLE16_V:
  case RISCV::VSE16_V:
  case RISCV::VLSE16_V:
  case RISCV::VSSE16_V:
  case VSSEG_CASES(16):
  case VSSSEG_CASES(16):
    return 4;
  case RISCV::VLE32_V:
  case RISCV::VSE32_V:
  case RISCV::VLSE32_V:
  case RISCV::VSSE32_V:
  case VSSEG_CASES(32):
  case VSSSEG_CASES(32):
    return 5;
  case RISCV::VLE64_V:
  case RISCV::VSE64_V:
  case RISCV::VLSE64_V:
  case RISCV::VSSE64_V:
  case VSSEG_CASES(64):
  case VSSSEG_CASES(64):
    return 6;

  // Vector Indexed Instructions
  // vs(o|u)xei<eew>.v
  // Dest/Data (operand 0) EEW=SEW. Source EEW=<eew>.
  case RISCV::VLUXEI8_V:
  case RISCV::VLOXEI8_V:
  case RISCV::VSUXEI8_V:
  case RISCV::VSOXEI8_V:
  case VSUXSEG_CASES(8):
  case VSOXSEG_CASES(8): {
    if (MO.getOperandNo() == 0)
      return MILog2SEW;
    return 3;
  }
  case RISCV::VLUXEI16_V:
  case RISCV::VLOXEI16_V:
  case RISCV::VSUXEI16_V:
  case RISCV::VSOXEI16_V:
  case VSUXSEG_CASES(16):
  case VSOXSEG_CASES(16): {
    if (MO.getOperandNo() == 0)
      return MILog2SEW;
    return 4;
  }
  case RISCV::VLUXEI32_V:
  case RISCV::VLOXEI32_V:
  case RISCV::VSUXEI32_V:
  case RISCV::VSOXEI32_V:
  case VSUXSEG_CASES(32):
  case VSOXSEG_CASES(32): {
    if (MO.getOperandNo() == 0)
      return MILog2SEW;
    return 5;
  }
  case RISCV::VLUXEI64_V:
  case RISCV::VLOXEI64_V:
  case RISCV::VSUXEI64_V:
  case RISCV::VSOXEI64_V:
  case VSUXSEG_CASES(64):
  case VSOXSEG_CASES(64): {
    if (MO.getOperandNo() == 0)
      return MILog2SEW;
    return 6;
  }

  // Vector Integer Arithmetic Instructions
  // Vector Single-Width Integer Add and Subtract
  case RISCV::VADD_VI:
  case RISCV::VADD_VV:
  case RISCV::VADD_VX:
  case RISCV::VSUB_VV:
  case RISCV::VSUB_VX:
  case RISCV::VRSUB_VI:
  case RISCV::VRSUB_VX:
  // Vector Bitwise Logical Instructions
  // Vector Single-Width Shift Instructions
  // EEW=SEW.
  case RISCV::VAND_VI:
  case RISCV::VAND_VV:
  case RISCV::VAND_VX:
  case RISCV::VOR_VI:
  case RISCV::VOR_VV:
  case RISCV::VOR_VX:
  case RISCV::VXOR_VI:
  case RISCV::VXOR_VV:
  case RISCV::VXOR_VX:
  case RISCV::VSLL_VI:
  case RISCV::VSLL_VV:
  case RISCV::VSLL_VX:
  case RISCV::VSRL_VI:
  case RISCV::VSRL_VV:
  case RISCV::VSRL_VX:
  case RISCV::VSRA_VI:
  case RISCV::VSRA_VV:
  case RISCV::VSRA_VX:
  // Vector Integer Min/Max Instructions
  // EEW=SEW.
  case RISCV::VMINU_VV:
  case RISCV::VMINU_VX:
  case RISCV::VMIN_VV:
  case RISCV::VMIN_VX:
  case RISCV::VMAXU_VV:
  case RISCV::VMAXU_VX:
  case RISCV::VMAX_VV:
  case RISCV::VMAX_VX:
  // Vector Single-Width Integer Multiply Instructions
  // Source and Dest EEW=SEW.
  case RISCV::VMUL_VV:
  case RISCV::VMUL_VX:
  case RISCV::VMULH_VV:
  case RISCV::VMULH_VX:
  case RISCV::VMULHU_VV:
  case RISCV::VMULHU_VX:
  case RISCV::VMULHSU_VV:
  case RISCV::VMULHSU_VX:
  // Vector Integer Divide Instructions
  // EEW=SEW.
  case RISCV::VDIVU_VV:
  case RISCV::VDIVU_VX:
  case RISCV::VDIV_VV:
  case RISCV::VDIV_VX:
  case RISCV::VREMU_VV:
  case RISCV::VREMU_VX:
  case RISCV::VREM_VV:
  case RISCV::VREM_VX:
  // Vector Single-Width Integer Multiply-Add Instructions
  // EEW=SEW.
  case RISCV::VMACC_VV:
  case RISCV::VMACC_VX:
  case RISCV::VNMSAC_VV:
  case RISCV::VNMSAC_VX:
  case RISCV::VMADD_VV:
  case RISCV::VMADD_VX:
  case RISCV::VNMSUB_VV:
  case RISCV::VNMSUB_VX:
  // Vector Integer Merge Instructions
  // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
  // EEW=SEW, except the mask operand has EEW=1. Mask operand is handled
  // before this switch.
  case RISCV::VMERGE_VIM:
  case RISCV::VMERGE_VVM:
  case RISCV::VMERGE_VXM:
  case RISCV::VADC_VIM:
  case RISCV::VADC_VVM:
  case RISCV::VADC_VXM:
  case RISCV::VSBC_VVM:
  case RISCV::VSBC_VXM:
  // Vector Integer Move Instructions
  // Vector Fixed-Point Arithmetic Instructions
  // Vector Single-Width Saturating Add and Subtract
  // Vector Single-Width Averaging Add and Subtract
  // EEW=SEW.
  case RISCV::VMV_V_I:
  case RISCV::VMV_V_V:
  case RISCV::VMV_V_X:
  case RISCV::VSADDU_VI:
  case RISCV::VSADDU_VV:
  case RISCV::VSADDU_VX:
  case RISCV::VSADD_VI:
  case RISCV::VSADD_VV:
  case RISCV::VSADD_VX:
  case RISCV::VSSUBU_VV:
  case RISCV::VSSUBU_VX:
  case RISCV::VSSUB_VV:
  case RISCV::VSSUB_VX:
  case RISCV::VAADDU_VV:
  case RISCV::VAADDU_VX:
  case RISCV::VAADD_VV:
  case RISCV::VAADD_VX:
  case RISCV::VASUBU_VV:
  case RISCV::VASUBU_VX:
  case RISCV::VASUB_VV:
  case RISCV::VASUB_VX:
  // Vector Single-Width Fractional Multiply with Rounding and Saturation
  // EEW=SEW. The instruction produces 2*SEW product internally but
  // saturates to fit into SEW bits.
  case RISCV::VSMUL_VV:
  case RISCV::VSMUL_VX:
  // Vector Single-Width Scaling Shift Instructions
  // EEW=SEW.
  case RISCV::VSSRL_VI:
  case RISCV::VSSRL_VV:
  case RISCV::VSSRL_VX:
  case RISCV::VSSRA_VI:
  case RISCV::VSSRA_VV:
  case RISCV::VSSRA_VX:
  // Vector Permutation Instructions
  // Integer Scalar Move Instructions
  // Floating-Point Scalar Move Instructions
  // EEW=SEW.
  case RISCV::VMV_X_S:
  case RISCV::VMV_S_X:
  case RISCV::VFMV_F_S:
  case RISCV::VFMV_S_F:
  // Vector Slide Instructions
  // EEW=SEW.
  case RISCV::VSLIDEUP_VI:
  case RISCV::VSLIDEUP_VX:
  case RISCV::VSLIDEDOWN_VI:
  case RISCV::VSLIDEDOWN_VX:
  case RISCV::VSLIDE1UP_VX:
  case RISCV::VFSLIDE1UP_VF:
  case RISCV::VSLIDE1DOWN_VX:
  case RISCV::VFSLIDE1DOWN_VF:
  // Vector Register Gather Instructions
  // EEW=SEW. For mask operand, EEW=1.
  case RISCV::VRGATHER_VI:
  case RISCV::VRGATHER_VV:
  case RISCV::VRGATHER_VX:
  // Vector Element Index Instruction
  case RISCV::VID_V:
  // Vector Single-Width Floating-Point Add/Subtract Instructions
  case RISCV::VFADD_VF:
  case RISCV::VFADD_VV:
  case RISCV::VFSUB_VF:
  case RISCV::VFSUB_VV:
  case RISCV::VFRSUB_VF:
  // Vector Single-Width Floating-Point Multiply/Divide Instructions
  case RISCV::VFMUL_VF:
  case RISCV::VFMUL_VV:
  case RISCV::VFDIV_VF:
  case RISCV::VFDIV_VV:
  case RISCV::VFRDIV_VF:
  // Vector Single-Width Floating-Point Fused Multiply-Add Instructions
  case RISCV::VFMACC_VV:
  case RISCV::VFMACC_VF:
  case RISCV::VFNMACC_VV:
  case RISCV::VFNMACC_VF:
  case RISCV::VFMSAC_VV:
  case RISCV::VFMSAC_VF:
  case RISCV::VFNMSAC_VV:
  case RISCV::VFNMSAC_VF:
  case RISCV::VFMADD_VV:
  case RISCV::VFMADD_VF:
  case RISCV::VFNMADD_VV:
  case RISCV::VFNMADD_VF:
  case RISCV::VFMSUB_VV:
  case RISCV::VFMSUB_VF:
  case RISCV::VFNMSUB_VV:
  case RISCV::VFNMSUB_VF:
  // Vector Floating-Point Square-Root Instruction
  case RISCV::VFSQRT_V:
  // Vector Floating-Point Reciprocal Square-Root Estimate Instruction
  case RISCV::VFRSQRT7_V:
  // Vector Floating-Point Reciprocal Estimate Instruction
  case RISCV::VFREC7_V:
  // Vector Floating-Point MIN/MAX Instructions
  case RISCV::VFMIN_VF:
  case RISCV::VFMIN_VV:
  case RISCV::VFMAX_VF:
  case RISCV::VFMAX_VV:
  // Vector Floating-Point Sign-Injection Instructions
  case RISCV::VFSGNJ_VF:
  case RISCV::VFSGNJ_VV:
  case RISCV::VFSGNJN_VV:
  case RISCV::VFSGNJN_VF:
  case RISCV::VFSGNJX_VF:
  case RISCV::VFSGNJX_VV:
  // Vector Floating-Point Classify Instruction
  case RISCV::VFCLASS_V:
  // Vector Floating-Point Move Instruction
  case RISCV::VFMV_V_F:
  // Single-Width Floating-Point/Integer Type-Convert Instructions
  case RISCV::VFCVT_XU_F_V:
  case RISCV::VFCVT_X_F_V:
  case RISCV::VFCVT_RTZ_XU_F_V:
  case RISCV::VFCVT_RTZ_X_F_V:
  case RISCV::VFCVT_F_XU_V:
  case RISCV::VFCVT_F_X_V:
  // Vector Floating-Point Merge Instruction
  case RISCV::VFMERGE_VFM:
  // Vector count population in mask vcpop.m
  // vfirst find-first-set mask bit
  case RISCV::VCPOP_M:
  case RISCV::VFIRST_M:
  // Vector Bit-manipulation Instructions (Zvbb)
  // Vector And-Not
  case RISCV::VANDN_VV:
  case RISCV::VANDN_VX:
  // Vector Reverse Bits in Elements
  case RISCV::VBREV_V:
  // Vector Reverse Bits in Bytes
  case RISCV::VBREV8_V:
  // Vector Reverse Bytes
  case RISCV::VREV8_V:
  // Vector Count Leading Zeros
  case RISCV::VCLZ_V:
  // Vector Count Trailing Zeros
  case RISCV::VCTZ_V:
  // Vector Population Count
  case RISCV::VCPOP_V:
  // Vector Rotate Left
  case RISCV::VROL_VV:
  case RISCV::VROL_VX:
  // Vector Rotate Right
  case RISCV::VROR_VI:
  case RISCV::VROR_VV:
  case RISCV::VROR_VX:
  // Vector Carry-less Multiplication Instructions (Zvbc)
  // Vector Carry-less Multiply
  case RISCV::VCLMUL_VV:
  case RISCV::VCLMUL_VX:
  // Vector Carry-less Multiply Return High Half
  case RISCV::VCLMULH_VV:
  case RISCV::VCLMULH_VX:

  // Zvabd
  case RISCV::VABS_V:
  case RISCV::VABD_VV:
  case RISCV::VABDU_VV:

  // XRivosVizip
  case RISCV::RI_VZIPEVEN_VV:
  case RISCV::RI_VZIPODD_VV:
  case RISCV::RI_VZIP2A_VV:
  case RISCV::RI_VZIP2B_VV:
  case RISCV::RI_VUNZIP2A_VV:
  case RISCV::RI_VUNZIP2B_VV:
    return MILog2SEW;

  // Vector Widening Shift Left Logical (Zvbb)
  case RISCV::VWSLL_VI:
  case RISCV::VWSLL_VX:
  case RISCV::VWSLL_VV:
  // Vector Widening Integer Add/Subtract
  // Def uses EEW=2*SEW . Operands use EEW=SEW.
  case RISCV::VWADDU_VV:
  case RISCV::VWADDU_VX:
  case RISCV::VWSUBU_VV:
  case RISCV::VWSUBU_VX:
  case RISCV::VWADD_VV:
  case RISCV::VWADD_VX:
  case RISCV::VWSUB_VV:
  case RISCV::VWSUB_VX:
  // Vector Widening Integer Multiply Instructions
  // Destination EEW=2*SEW. Source EEW=SEW.
  case RISCV::VWMUL_VV:
  case RISCV::VWMUL_VX:
  case RISCV::VWMULSU_VV:
  case RISCV::VWMULSU_VX:
  case RISCV::VWMULU_VV:
  case RISCV::VWMULU_VX:
  // Vector Widening Integer Multiply-Add Instructions
  // Destination EEW=2*SEW. Source EEW=SEW.
  // A SEW-bit*SEW-bit multiply of the sources forms a 2*SEW-bit value, which
  // is then added to the 2*SEW-bit Dest. These instructions never have a
  // passthru operand.
  case RISCV::VWMACCU_VV:
  case RISCV::VWMACCU_VX:
  case RISCV::VWMACC_VV:
  case RISCV::VWMACC_VX:
  case RISCV::VWMACCSU_VV:
  case RISCV::VWMACCSU_VX:
  case RISCV::VWMACCUS_VX:
  // Vector Widening Floating-Point Fused Multiply-Add Instructions
  case RISCV::VFWMACC_VF:
  case RISCV::VFWMACC_VV:
  case RISCV::VFWNMACC_VF:
  case RISCV::VFWNMACC_VV:
  case RISCV::VFWMSAC_VF:
  case RISCV::VFWMSAC_VV:
  case RISCV::VFWNMSAC_VF:
  case RISCV::VFWNMSAC_VV:
  case RISCV::VFWMACCBF16_VV:
  case RISCV::VFWMACCBF16_VF:
  // Vector Widening Floating-Point Add/Subtract Instructions
  // Dest EEW=2*SEW. Source EEW=SEW.
  case RISCV::VFWADD_VV:
  case RISCV::VFWADD_VF:
  case RISCV::VFWSUB_VV:
  case RISCV::VFWSUB_VF:
  // Vector Widening Floating-Point Multiply
  case RISCV::VFWMUL_VF:
  case RISCV::VFWMUL_VV:
  // Widening Floating-Point/Integer Type-Convert Instructions
  case RISCV::VFWCVT_XU_F_V:
  case RISCV::VFWCVT_X_F_V:
  case RISCV::VFWCVT_RTZ_XU_F_V:
  case RISCV::VFWCVT_RTZ_X_F_V:
  case RISCV::VFWCVT_F_XU_V:
  case RISCV::VFWCVT_F_X_V:
  case RISCV::VFWCVT_F_F_V:
  case RISCV::VFWCVTBF16_F_F_V:
  // Zvabd
  case RISCV::VWABDA_VV:
  case RISCV::VWABDAU_VV:
    return IsMODef ? MILog2SEW + 1 : MILog2SEW;

  // Def and Op1 uses EEW=2*SEW. Op2 uses EEW=SEW.
  case RISCV::VWADDU_WV:
  case RISCV::VWADDU_WX:
  case RISCV::VWSUBU_WV:
  case RISCV::VWSUBU_WX:
  case RISCV::VWADD_WV:
  case RISCV::VWADD_WX:
  case RISCV::VWSUB_WV:
  case RISCV::VWSUB_WX:
  // Vector Widening Floating-Point Add/Subtract Instructions
  case RISCV::VFWADD_WF:
  case RISCV::VFWADD_WV:
  case RISCV::VFWSUB_WF:
  case RISCV::VFWSUB_WV: {
    // Op1's position shifts by one when an untied passthru operand precedes it.
    bool IsOp1 = (HasPassthru && !IsTied) ? MO.getOperandNo() == 2
                                          : MO.getOperandNo() == 1;
    bool TwoTimes = IsMODef || IsOp1;
    return TwoTimes ? MILog2SEW + 1 : MILog2SEW;
  }

  // Vector Integer Extension
  case RISCV::VZEXT_VF2:
  case RISCV::VSEXT_VF2:
    return getIntegerExtensionOperandEEW(Factor: 2, MI, MO);
  case RISCV::VZEXT_VF4:
  case RISCV::VSEXT_VF4:
    return getIntegerExtensionOperandEEW(Factor: 4, MI, MO);
  case RISCV::VZEXT_VF8:
  case RISCV::VSEXT_VF8:
    return getIntegerExtensionOperandEEW(Factor: 8, MI, MO);

  // Vector Narrowing Integer Right Shift Instructions
  // Destination EEW=SEW, Op 1 has EEW=2*SEW. Op2 has EEW=SEW
  case RISCV::VNSRL_WX:
  case RISCV::VNSRL_WI:
  case RISCV::VNSRL_WV:
  case RISCV::VNSRA_WI:
  case RISCV::VNSRA_WV:
  case RISCV::VNSRA_WX:
  // Vector Narrowing Fixed-Point Clip Instructions
  // Destination and Op1 EEW=SEW. Op2 EEW=2*SEW.
  case RISCV::VNCLIPU_WI:
  case RISCV::VNCLIPU_WV:
  case RISCV::VNCLIPU_WX:
  case RISCV::VNCLIP_WI:
  case RISCV::VNCLIP_WV:
  case RISCV::VNCLIP_WX:
  // Narrowing Floating-Point/Integer Type-Convert Instructions
  case RISCV::VFNCVT_XU_F_W:
  case RISCV::VFNCVT_X_F_W:
  case RISCV::VFNCVT_RTZ_XU_F_W:
  case RISCV::VFNCVT_RTZ_X_F_W:
  case RISCV::VFNCVT_F_XU_W:
  case RISCV::VFNCVT_F_X_W:
  case RISCV::VFNCVT_F_F_W:
  case RISCV::VFNCVT_ROD_F_F_W:
  case RISCV::VFNCVTBF16_F_F_W: {
    assert(!IsTied);
    bool IsOp1 = HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1;
    bool TwoTimes = IsOp1;
    return TwoTimes ? MILog2SEW + 1 : MILog2SEW;
  }

  // Vector Mask Instructions
  // Vector Mask-Register Logical Instructions
  // vmsbf.m set-before-first mask bit
  // vmsif.m set-including-first mask bit
  // vmsof.m set-only-first mask bit
  // EEW=1
  // We handle the cases when operand is a v0 mask operand above the switch,
  // but these instructions may use non-v0 mask operands and need to be handled
  // specifically.
  case RISCV::VMAND_MM:
  case RISCV::VMNAND_MM:
  case RISCV::VMANDN_MM:
  case RISCV::VMXOR_MM:
  case RISCV::VMOR_MM:
  case RISCV::VMNOR_MM:
  case RISCV::VMORN_MM:
  case RISCV::VMXNOR_MM:
  case RISCV::VMSBF_M:
  case RISCV::VMSIF_M:
  case RISCV::VMSOF_M: {
    // Mask pseudos carry 0 in their SEW operand, so this returns log2(EEW)=0.
    return MILog2SEW;
  }

  // Vector Compress Instruction
  // EEW=SEW, except the mask operand has EEW=1. Mask operand is not handled
  // before this switch.
  case RISCV::VCOMPRESS_VM:
    return MO.getOperandNo() == 3 ? 0 : MILog2SEW;

  // Vector Iota Instruction
  // EEW=SEW, except the mask operand has EEW=1. Mask operand is not handled
  // before this switch.
  case RISCV::VIOTA_M: {
    if (IsMODef || MO.getOperandNo() == 1)
      return MILog2SEW;
    return 0;
  }

  // Vector Integer Compare Instructions
  // Dest EEW=1. Source EEW=SEW.
  case RISCV::VMSEQ_VI:
  case RISCV::VMSEQ_VV:
  case RISCV::VMSEQ_VX:
  case RISCV::VMSNE_VI:
  case RISCV::VMSNE_VV:
  case RISCV::VMSNE_VX:
  case RISCV::VMSLTU_VV:
  case RISCV::VMSLTU_VX:
  case RISCV::VMSLT_VV:
  case RISCV::VMSLT_VX:
  case RISCV::VMSLEU_VV:
  case RISCV::VMSLEU_VI:
  case RISCV::VMSLEU_VX:
  case RISCV::VMSLE_VV:
  case RISCV::VMSLE_VI:
  case RISCV::VMSLE_VX:
  case RISCV::VMSGTU_VI:
  case RISCV::VMSGTU_VX:
  case RISCV::VMSGT_VI:
  case RISCV::VMSGT_VX:
  // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
  // Dest EEW=1. Source EEW=SEW. Mask source operand handled above this switch.
  case RISCV::VMADC_VIM:
  case RISCV::VMADC_VVM:
  case RISCV::VMADC_VXM:
  case RISCV::VMSBC_VVM:
  case RISCV::VMSBC_VXM:
  // Dest EEW=1. Source EEW=SEW.
  case RISCV::VMADC_VV:
  case RISCV::VMADC_VI:
  case RISCV::VMADC_VX:
  case RISCV::VMSBC_VV:
  case RISCV::VMSBC_VX:
  // 13.13. Vector Floating-Point Compare Instructions
  // Dest EEW=1. Source EEW=SEW
  case RISCV::VMFEQ_VF:
  case RISCV::VMFEQ_VV:
  case RISCV::VMFNE_VF:
  case RISCV::VMFNE_VV:
  case RISCV::VMFLT_VF:
  case RISCV::VMFLT_VV:
  case RISCV::VMFLE_VF:
  case RISCV::VMFLE_VV:
  case RISCV::VMFGT_VF:
  case RISCV::VMFGE_VF: {
    if (IsMODef)
      return 0;
    return MILog2SEW;
  }

  // Vector Reduction Operations
  // Vector Single-Width Integer Reduction Instructions
  case RISCV::VREDAND_VS:
  case RISCV::VREDMAX_VS:
  case RISCV::VREDMAXU_VS:
  case RISCV::VREDMIN_VS:
  case RISCV::VREDMINU_VS:
  case RISCV::VREDOR_VS:
  case RISCV::VREDSUM_VS:
  case RISCV::VREDXOR_VS:
  // Vector Single-Width Floating-Point Reduction Instructions
  case RISCV::VFREDMAX_VS:
  case RISCV::VFREDMIN_VS:
  case RISCV::VFREDOSUM_VS:
  case RISCV::VFREDUSUM_VS: {
    return MILog2SEW;
  }

  // Vector Widening Integer Reduction Instructions
  // The Dest and VS1 read only element 0 for the vector register. Return
  // 2*EEW for these. VS2 has EEW=SEW and EMUL=LMUL.
  case RISCV::VWREDSUM_VS:
  case RISCV::VWREDSUMU_VS:
  // Vector Widening Floating-Point Reduction Instructions
  case RISCV::VFWREDOSUM_VS:
  case RISCV::VFWREDUSUM_VS: {
    bool TwoTimes = IsMODef || MO.getOperandNo() == 3;
    return TwoTimes ? MILog2SEW + 1 : MILog2SEW;
  }

  // Vector Register Gather with 16-bit Index Elements Instruction
  // Dest and source data EEW=SEW. Index vector EEW=16.
  case RISCV::VRGATHEREI16_VV: {
    if (MO.getOperandNo() == 2)
      return 4;
    return MILog2SEW;
  }

  default:
    return std::nullopt;
  }
}
868
/// Compute the full OperandInfo (EEW plus, where meaningful, EMUL) for \p MO,
/// or std::nullopt if the parent instruction's EEW rules are not modeled.
static std::optional<OperandInfo> getOperandInfo(const MachineOperand &MO) {
  const MachineInstr &MI = *MO.getParent();
  const RISCVVPseudosTable::PseudoInfo *RVV =
      RISCVVPseudosTable::getPseudoInfo(Pseudo: MI.getOpcode());
  assert(RVV && "Could not find MI in PseudoTable");

  std::optional<unsigned> Log2EEW = getOperandLog2EEW(MO);
  if (!Log2EEW)
    return std::nullopt;

  switch (RVV->BaseInstr) {
  // Vector Reduction Operations
  // Vector Single-Width Integer Reduction Instructions
  // Vector Widening Integer Reduction Instructions
  // Vector Widening Floating-Point Reduction Instructions
  // The Dest and VS1 only read element 0 of the vector register. Return just
  // the EEW for these.
  case RISCV::VREDAND_VS:
  case RISCV::VREDMAX_VS:
  case RISCV::VREDMAXU_VS:
  case RISCV::VREDMIN_VS:
  case RISCV::VREDMINU_VS:
  case RISCV::VREDOR_VS:
  case RISCV::VREDSUM_VS:
  case RISCV::VREDXOR_VS:
  case RISCV::VWREDSUM_VS:
  case RISCV::VWREDSUMU_VS:
  case RISCV::VFWREDOSUM_VS:
  case RISCV::VFWREDUSUM_VS:
    // Operand 2 is VS2, which is read as a full vector; all other operands
    // of a reduction are effectively scalar (EEW only, no EMUL).
    if (MO.getOperandNo() != 2)
      return OperandInfo(*Log2EEW);
    break;
  };

  // All others have EMUL=EEW/SEW*LMUL
  return OperandInfo(getEMULEqualsEEWDivSEWTimesLMUL(Log2EEW: *Log2EEW, MI), *Log2EEW);
}
906
907static bool isTupleInsertInstr(const MachineInstr &MI);
908
/// Return true if we can reason about demanded VLs elementwise for \p MI.
bool RISCVVLOptimizer::isSupportedInstr(const MachineInstr &MI) const {
  // PHIs, full copies, and tuple inserts just forward their inputs lanewise.
  if (MI.isPHI() || MI.isFullCopy() || isTupleInsertInstr(MI))
    return true;

  // Only RVV pseudos with a known MC opcode are analyzable.
  unsigned RVVOpc = RISCV::getRVVMCOpcode(RVVPseudoOpcode: MI.getOpcode());
  if (!RVVOpc)
    return false;

  // Sanity check: an instruction with no defs and no stores whose elements
  // are independent of VL would have no observable effect at all.
  assert(!(MI.getNumExplicitDefs() == 0 && !MI.mayStore() &&
           !RISCVII::elementsDependOnVL(TII->get(RVVOpc).TSFlags)) &&
         "No defs but elements don't depend on VL?");

  // TODO: Reduce vl for vmv.s.x and vfmv.s.f. Currently this introduces more vl
  // toggles, we need to extend PRE in RISCVInsertVSETVLI first.
  if (RVVOpc == RISCV::VMV_S_X || RVVOpc == RISCV::VFMV_S_F)
    return false;

  // If the result's elements depend on VL (e.g. vector length dependent
  // semantics), demanded VLs cannot be reasoned about elementwise.
  if (RISCVII::elementsDependOnVL(TSFlags: TII->get(Opcode: RVVOpc).TSFlags))
    return false;

  // Stores have no vector def to propagate demand through.
  if (MI.mayStore())
    return false;

  return true;
}
935
936/// Return true if MO is a vector operand but is used as a scalar operand.
937static bool isVectorOpUsedAsScalarOp(const MachineOperand &MO) {
938 const MachineInstr *MI = MO.getParent();
939 const RISCVVPseudosTable::PseudoInfo *RVV =
940 RISCVVPseudosTable::getPseudoInfo(Pseudo: MI->getOpcode());
941
942 if (!RVV)
943 return false;
944
945 switch (RVV->BaseInstr) {
946 // Reductions only use vs1[0] of vs1
947 case RISCV::VREDAND_VS:
948 case RISCV::VREDMAX_VS:
949 case RISCV::VREDMAXU_VS:
950 case RISCV::VREDMIN_VS:
951 case RISCV::VREDMINU_VS:
952 case RISCV::VREDOR_VS:
953 case RISCV::VREDSUM_VS:
954 case RISCV::VREDXOR_VS:
955 case RISCV::VWREDSUM_VS:
956 case RISCV::VWREDSUMU_VS:
957 case RISCV::VFREDMAX_VS:
958 case RISCV::VFREDMIN_VS:
959 case RISCV::VFREDOSUM_VS:
960 case RISCV::VFREDUSUM_VS:
961 case RISCV::VFWREDOSUM_VS:
962 case RISCV::VFWREDUSUM_VS:
963 return MO.getOperandNo() == 3;
964 case RISCV::VMV_X_S:
965 case RISCV::VFMV_F_S:
966 return MO.getOperandNo() == 1;
967 default:
968 return false;
969 }
970}
971
/// Return true if \p MI is eligible to have its VL operand reduced: it must
/// have VL/SEW operands, a single explicit vector def, and no side effects
/// that VL reduction could change.
bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
  const MCInstrDesc &Desc = MI.getDesc();
  // Without both a VL and a SEW operand there is nothing to reduce.
  if (!RISCVII::hasVLOp(TSFlags: Desc.TSFlags) || !RISCVII::hasSEWOp(TSFlags: Desc.TSFlags))
    return false;

  // Exactly one explicit def keeps the demanded-VL bookkeeping simple.
  if (MI.getNumExplicitDefs() != 1)
    return false;

  // Some instructions have implicit defs e.g. $vxsat. If they might be read
  // later then we can't reduce VL.
  if (!MI.allImplicitDefsAreDead()) {
    LLVM_DEBUG(dbgs() << "Not a candidate because has non-dead implicit def\n");
    return false;
  }

  // Reducing VL could suppress FP exceptions the program depends on.
  if (MI.mayRaiseFPException()) {
    LLVM_DEBUG(dbgs() << "Not a candidate because may raise FP exception\n");
    return false;
  }

  // Volatile accesses must execute with their original element count.
  for (const MachineMemOperand *MMO : MI.memoperands()) {
    if (MMO->isVolatile()) {
      LLVM_DEBUG(dbgs() << "Not a candidate because contains volatile MMO\n");
      return false;
    }
  }

  if (!isSupportedInstr(MI)) {
    LLVM_DEBUG(dbgs() << "Not a candidate due to unsupported instruction: "
                      << MI);
    return false;
  }

  assert(!RISCVII::elementsDependOnVL(
             TII->get(RISCV::getRVVMCOpcode(MI.getOpcode())).TSFlags) &&
         "Instruction shouldn't be supported if elements depend on VL");

  assert(RISCVRI::isVRegClass(
             MRI->getRegClass(MI.getOperand(0).getReg())->TSFlags) &&
         "All supported instructions produce a vector register result");

  LLVM_DEBUG(dbgs() << "Found a candidate for VL reduction: " << MI << "\n");
  return true;
}
1016
/// Given a vslidedown.vx like:
///
/// %slideamt = ADDI %x, -1
/// %v = PseudoVSLIDEDOWN_VX %passthru, %src, %slideamt, avl=1
///
/// %v will only read the first %slideamt + 1 lanes of %src, which = %x.
/// This is a common case when lowering extractelement.
///
/// Note that if %x is 0, %slideamt will be all ones. In this case %src will be
/// completely slid down and none of its lanes will be read (since %slideamt is
/// greater than the largest VLMAX of 65536) so we can demand any minimum VL.
///
/// \returns the demanded VL for \p UserOp (the %src operand) when the pattern
/// matches, or std::nullopt when it doesn't.
static std::optional<DemandedVL>
getMinimumVLForVSLIDEDOWN_VX(const MachineOperand &UserOp,
                             const MachineRegisterInfo *MRI) {
  const MachineInstr &MI = *UserOp.getParent();
  if (RISCV::getRVVMCOpcode(RVVPseudoOpcode: MI.getOpcode()) != RISCV::VSLIDEDOWN_VX)
    return std::nullopt;
  // We're looking at what lanes are used from the src operand.
  if (UserOp.getOperandNo() != 2)
    return std::nullopt;
  // For now, the AVL must be 1.
  const MachineOperand &AVL = MI.getOperand(i: 4);
  if (!AVL.isImm() || AVL.getImm() != 1)
    return std::nullopt;
  // The slide amount must be %x - 1.
  const MachineOperand &SlideAmt = MI.getOperand(i: 3);
  if (!SlideAmt.getReg().isVirtual())
    return std::nullopt;
  // NOTE(review): assumes machine SSA here, so the virtual register has
  // exactly one def and getUniqueVRegDef is non-null -- confirm this pass
  // always runs pre-regalloc in SSA form.
  MachineInstr *SlideAmtDef = MRI->getUniqueVRegDef(Reg: SlideAmt.getReg());
  // Match %slideamt = ADDI %x, -AVL (i.e. -1) with a virtual %x, so the
  // lanes read from %src are exactly [0, %x).
  if (SlideAmtDef->getOpcode() != RISCV::ADDI ||
      SlideAmtDef->getOperand(i: 2).getImm() != -AVL.getImm() ||
      !SlideAmtDef->getOperand(i: 1).getReg().isVirtual())
    return std::nullopt;
  // Demand %x lanes: the ADDI's source operand is the minimum VL.
  return SlideAmtDef->getOperand(i: 1);
}
1052
/// Return the minimum VL that \p UserOp's instruction demands of the value
/// feeding that operand, i.e. an upper bound on how many lanes of the operand
/// the user may read. Returns VLMAX whenever the lanes read can't be bounded.
DemandedVL
RISCVVLOptimizer::getMinimumVLForUser(const MachineOperand &UserOp) const {
  const MachineInstr &UserMI = *UserOp.getParent();
  const MCInstrDesc &Desc = UserMI.getDesc();

  // Pass-through instructions demand of their operands exactly what their own
  // users demand of them.
  if (UserMI.isPHI() || UserMI.isFullCopy() || isTupleInsertInstr(MI: UserMI))
    return DemandedVLs.lookup(Key: &UserMI);

  // Users with no VL/SEW operand (non-RVV instructions) may read everything.
  if (!RISCVII::hasVLOp(TSFlags: Desc.TSFlags) || !RISCVII::hasSEWOp(TSFlags: Desc.TSFlags)) {
    LLVM_DEBUG(dbgs() << " Abort due to lack of VL, assume that"
                         " use VLMAX\n");
    return DemandedVL::vlmax();
  }

  // A vslidedown.vx with AVL=1 only reads slideamt+1 lanes of its source; try
  // to recover that tighter bound.
  if (auto VL = getMinimumVLForVSLIDEDOWN_VX(UserOp, MRI))
    return *VL;

  // Instructions flagged as reading past VL can read lanes beyond their VL
  // operand, so no bound can be derived from it.
  if (RISCVII::readsPastVL(
          TSFlags: TII->get(Opcode: RISCV::getRVVMCOpcode(RVVPseudoOpcode: UserMI.getOpcode())).TSFlags)) {
    LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n");
    return DemandedVL::vlmax();
  }

  unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
  const MachineOperand &VLOp = UserMI.getOperand(i: VLOpNum);
  // Looking for an immediate or a register VL that isn't X0.
  assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) &&
         "Did not expect X0 VL");

  // If the user is a passthru it will read the elements past VL, so
  // abort if any of the elements past VL are demanded.
  if (UserOp.isTied()) {
    assert(UserOp.getOperandNo() == UserMI.getNumExplicitDefs() &&
           RISCVII::isFirstDefTiedToFirstUse(UserMI.getDesc()));
    if (!RISCV::isVLKnownLE(LHS: DemandedVLs.lookup(Key: &UserMI).VL, RHS: VLOp)) {
      LLVM_DEBUG(dbgs() << " Abort because user is passthru in "
                           "instruction with demanded tail\n");
      return DemandedVL::vlmax();
    }
  }

  // Instructions like reductions may use a vector register as a scalar
  // register. In this case, we should treat it as only reading the first lane.
  if (isVectorOpUsedAsScalarOp(MO: UserOp)) {
    LLVM_DEBUG(dbgs() << " Used this operand as a scalar operand\n");
    return MachineOperand::CreateImm(Val: 1);
  }

  // If we know the demanded VL of UserMI, then we can reduce the VL it
  // requires.
  if (RISCV::isVLKnownLE(LHS: DemandedVLs.lookup(Key: &UserMI).VL, RHS: VLOp))
    return DemandedVLs.lookup(Key: &UserMI);

  // Otherwise the user reads up to its own VL operand's worth of lanes.
  return VLOp;
}
1108
1109/// Return true if MI is an instruction used for assembling registers
1110/// for segmented store instructions, namely, RISCVISD::TUPLE_INSERT.
1111/// Currently it's lowered to INSERT_SUBREG.
1112static bool isTupleInsertInstr(const MachineInstr &MI) {
1113 if (!MI.isInsertSubreg())
1114 return false;
1115
1116 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1117 const TargetRegisterClass *DstRC = MRI.getRegClass(Reg: MI.getOperand(i: 0).getReg());
1118 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
1119 if (!RISCVRI::isVRegClass(TSFlags: DstRC->TSFlags))
1120 return false;
1121 unsigned NF = RISCVRI::getNF(TSFlags: DstRC->TSFlags);
1122 if (NF < 2)
1123 return false;
1124
1125 // Check whether INSERT_SUBREG has the correct subreg index for tuple inserts.
1126 auto VLMul = RISCVRI::getLMul(TSFlags: DstRC->TSFlags);
1127 unsigned SubRegIdx = MI.getOperand(i: 3).getImm();
1128 [[maybe_unused]] auto [LMul, IsFractional] = RISCVVType::decodeVLMUL(VLMul);
1129 assert(!IsFractional && "unexpected LMUL for tuple register classes");
1130 return TRI->getSubRegIdxSize(Idx: SubRegIdx) == RISCV::RVVBitsPerBlock * LMul;
1131}
1132
/// Return true if MI is a segmented store pseudo (unit-stride, strided,
/// indexed-unordered or indexed-ordered) for any supported EEW.
/// NOTE(review): the VS*SEG_CASES(eew) macros presumably expand to the case
/// labels for every NF variant of that EEW -- they are defined elsewhere in
/// this file; verify there.
static bool isSegmentedStoreInstr(const MachineInstr &MI) {
  // Map the pseudo to its MC opcode before matching.
  switch (RISCV::getRVVMCOpcode(RVVPseudoOpcode: MI.getOpcode())) {
  case VSSEG_CASES(8):
  case VSSSEG_CASES(8):
  case VSUXSEG_CASES(8):
  case VSOXSEG_CASES(8):
  case VSSEG_CASES(16):
  case VSSSEG_CASES(16):
  case VSUXSEG_CASES(16):
  case VSOXSEG_CASES(16):
  case VSSEG_CASES(32):
  case VSSSEG_CASES(32):
  case VSUXSEG_CASES(32):
  case VSOXSEG_CASES(32):
  case VSSEG_CASES(64):
  case VSSSEG_CASES(64):
  case VSUXSEG_CASES(64):
  case VSOXSEG_CASES(64):
    return true;
  default:
    return false;
  }
}
1156
/// Return true if it's safe to reduce \p MI's VL: every user of its result
/// (looking transitively through COPYs, PHIs and tuple inserts) must have a
/// SEW operand and operand info (EEW/EMUL) compatible with MI's result.
bool RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const {
  // Pass-through nodes are never reduced themselves; their users are checked
  // when those users' producers are considered.
  if (MI.isPHI() || MI.isFullCopy() || isTupleInsertInstr(MI))
    return true;

  // Worklist of use operands still to inspect; a SetVector so each operand is
  // visited at most once even when reached through several paths.
  SmallSetVector<MachineOperand *, 8> OpWorklist;
  SmallPtrSet<const MachineInstr *, 4> PHISeen;
  for (auto &UserOp : MRI->use_operands(Reg: MI.getOperand(i: 0).getReg()))
    OpWorklist.insert(X: &UserOp);

  while (!OpWorklist.empty()) {
    MachineOperand &UserOp = *OpWorklist.pop_back_val();
    const MachineInstr &UserMI = *UserOp.getParent();
    LLVM_DEBUG(dbgs() << " Checking user: " << UserMI << "\n");

    // Full copies to a virtual register are transparent: check the copy's
    // users instead.
    if (UserMI.isFullCopy() && UserMI.getOperand(i: 0).getReg().isVirtual()) {
      LLVM_DEBUG(dbgs() << " Peeking through uses of COPY\n");
      OpWorklist.insert_range(R: llvm::make_pointer_range(
          Range: MRI->use_operands(Reg: UserMI.getOperand(i: 0).getReg())));
      continue;
    }

    if (isTupleInsertInstr(MI: UserMI)) {
      LLVM_DEBUG(dbgs().indent(4) << "Peeking through uses of INSERT_SUBREG\n");
      for (MachineOperand &UseOp :
           MRI->use_operands(Reg: UserMI.getOperand(i: 0).getReg())) {
        const MachineInstr &CandidateMI = *UseOp.getParent();
        // We should not propagate the VL if the user is not a segmented store
        // or another INSERT_SUBREG, since VL just works differently
        // between segmented operations (per-field) v.s. other RVV ops (on the
        // whole register group).
        if (!isTupleInsertInstr(MI: CandidateMI) &&
            !isSegmentedStoreInstr(MI: CandidateMI))
          return false;
        OpWorklist.insert(X: &UseOp);
      }
      continue;
    }

    if (UserMI.isPHI()) {
      // Don't follow PHI cycles
      if (!PHISeen.insert(Ptr: &UserMI).second)
        continue;
      LLVM_DEBUG(dbgs() << " Peeking through uses of PHI\n");
      OpWorklist.insert_range(R: llvm::make_pointer_range(
          Range: MRI->use_operands(Reg: UserMI.getOperand(i: 0).getReg())));
      continue;
    }

    // A user with no SEW operand is a non-RVV instruction we can't reason
    // about.
    if (!RISCVII::hasSEWOp(TSFlags: UserMI.getDesc().TSFlags)) {
      LLVM_DEBUG(dbgs() << " Abort due to lack of SEW operand\n");
      return false;
    }

    // The element/group geometry of the use must match the def, otherwise a
    // smaller VL on the def would cover a different set of lanes.
    std::optional<OperandInfo> ConsumerInfo = getOperandInfo(MO: UserOp);
    std::optional<OperandInfo> ProducerInfo = getOperandInfo(MO: MI.getOperand(i: 0));
    if (!ConsumerInfo || !ProducerInfo) {
      LLVM_DEBUG(dbgs() << " Abort due to unknown operand information.\n");
      LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
      LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
      return false;
    }

    if (!OperandInfo::areCompatible(Def: *ProducerInfo, User: *ConsumerInfo)) {
      LLVM_DEBUG(
          dbgs()
          << " Abort due to incompatible information for EMUL or EEW.\n");
      LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
      LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
      return false;
    }
  }

  return true;
}
1231
/// Try to shrink \p MI's VL operand to \p CommonVL, the VL demanded by its
/// users as computed by the dataflow analysis. Returns true if the operand
/// was changed. When CommonVL is a register whose def doesn't dominate MI,
/// this may instead sink MI to just after that def.
bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI,
                                   MachineOperand CommonVL) const {
  LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI);

  unsigned VLOpNum = RISCVII::getVLOpNum(Desc: MI.getDesc());
  MachineOperand &VLOp = MI.getOperand(i: VLOpNum);

  assert((CommonVL.isImm() || CommonVL.getReg().isVirtual()) &&
         "Expected VL to be an Imm or virtual Reg");

  // If the VL is defined by a vleff that doesn't dominate MI, try using the
  // vleff's AVL. It will be greater than or equal to the output VL.
  if (CommonVL.isReg()) {
    const MachineInstr *VLMI = MRI->getVRegDef(Reg: CommonVL.getReg());
    if (RISCVInstrInfo::isFaultOnlyFirstLoad(MI: *VLMI) &&
        !MDT->dominates(A: VLMI, B: &MI))
      CommonVL = VLMI->getOperand(i: RISCVII::getVLOpNum(Desc: VLMI->getDesc()));
  }

  // Never increase VL; only a smaller (or equal) VL is a valid replacement.
  if (!RISCV::isVLKnownLE(LHS: CommonVL, RHS: VLOp)) {
    LLVM_DEBUG(dbgs() << " Abort due to CommonVL not <= VLOp.\n");
    return false;
  }

  if (CommonVL.isIdenticalTo(Other: VLOp)) {
    LLVM_DEBUG(
        dbgs() << " Abort due to CommonVL == VLOp, no point in reducing.\n");
    return false;
  }

  // Immediates need no dominance reasoning; rewrite in place.
  if (CommonVL.isImm()) {
    LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to "
                      << CommonVL.getImm() << " for " << MI << "\n");
    VLOp.ChangeToImmediate(ImmVal: CommonVL.getImm());
    return true;
  }
  // CommonVL is a register: its def must dominate MI for the use to be valid.
  MachineInstr *VLMI = MRI->getVRegDef(Reg: CommonVL.getReg());
  auto VLDominates = [this, &VLMI](const MachineInstr &MI) {
    return MDT->dominates(A: VLMI, B: &MI);
  };
  // If the def doesn't dominate MI, we may still be able to sink MI to just
  // after the def: both must be in the same block, every same-block user of
  // MI must already sit below the def, and the move must be otherwise safe.
  if (!VLDominates(MI)) {
    assert(MI.getNumExplicitDefs() == 1);
    auto Uses = MRI->use_instructions(Reg: MI.getOperand(i: 0).getReg());
    auto UsesSameBB = make_filter_range(Range&: Uses, Pred: [&MI](const MachineInstr &Use) {
      return Use.getParent() == MI.getParent();
    });
    // NOTE(review): relies on VLMI->getNextNode() being non-null --
    // presumably VLMI can't be the block's last instruction here since it
    // defines a virtual register; confirm for terminator-defined VLs.
    if (VLMI->getParent() == MI.getParent() &&
        all_of(Range&: UsesSameBB, P: VLDominates) &&
        RISCVInstrInfo::isSafeToMove(From: MI, To: *VLMI->getNextNode())) {
      MI.moveBefore(MovePos: VLMI->getNextNode());
    } else {
      LLVM_DEBUG(dbgs() << " Abort due to VL not dominating.\n");
      return false;
    }
  }
  LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to "
                    << printReg(CommonVL.getReg(), MRI->getTargetRegisterInfo())
                    << " for " << MI << "\n");

  // All our checks passed. We can reduce VL.
  VLOp.ChangeToRegister(Reg: CommonVL.getReg(), isDef: false);
  MRI->constrainRegClass(Reg: CommonVL.getReg(), RC: &RISCV::GPRNoX0RegClass);
  return true;
}
1296
1297static bool isPhysical(const MachineOperand &MO) {
1298 return MO.isReg() && MO.getReg().isPhysical();
1299}
1300
1301/// Look through \p MI's operands and propagate what it demands to its uses.
1302void RISCVVLOptimizer::transfer(const MachineInstr &MI) {
1303 if (!isSupportedInstr(MI) || !checkUsers(MI) || any_of(Range: MI.defs(), P: isPhysical))
1304 DemandedVLs[&MI] = DemandedVL::vlmax();
1305
1306 for (const MachineOperand &MO : virtual_vec_uses(MI)) {
1307 const MachineInstr *Def = MRI->getVRegDef(Reg: MO.getReg());
1308 DemandedVL Prev = DemandedVLs[Def];
1309 DemandedVLs[Def] = DemandedVLs[Def].max(X: getMinimumVLForUser(UserOp: MO));
1310 if (DemandedVLs[Def] != Prev)
1311 Worklist.insert(X: Def);
1312 }
1313}
1314
/// Pass entry point: run the demanded-VL dataflow analysis over the whole
/// function, then rewrite the VL operand of every candidate instruction to
/// the (smaller) VL its users actually demand.
bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) {
  if (skipFunction(F: MF.getFunction()))
    return false;

  MRI = &MF.getRegInfo();
  MDT = &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();

  // Nothing to do without the vector extension.
  const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
  if (!ST.hasVInstructions())
    return false;

  TII = ST.getInstrInfo();

  assert(DemandedVLs.empty());

  // For each instruction that defines a vector, propagate the VL it
  // uses to its inputs.
  // Seed the worklist bottom-up (post-order over blocks, instructions in
  // reverse) so users tend to be visited before their defs; transfer()
  // re-queues any def whose demanded VL later grows, so the fixpoint does not
  // depend on this seeding order.
  for (MachineBasicBlock *MBB : post_order(G: &MF)) {
    assert(MDT->isReachableFromEntry(MBB));
    for (MachineInstr &MI : reverse(C&: *MBB))
      if (!MI.isDebugInstr())
        Worklist.insert(X: &MI);
  }

  // Drain the worklist until the demanded VLs stabilize.
  while (!Worklist.empty()) {
    const MachineInstr *MI = Worklist.front();
    Worklist.remove(X: MI);
    transfer(MI: *MI);
  }

  // Then go through and see if we can reduce the VL of any instructions to
  // only what's demanded.
  bool MadeChange = false;
  for (auto &[MI, VL] : DemandedVLs) {
    assert(MDT->isReachableFromEntry(MI->getParent()));
    if (!isCandidate(MI: *MI))
      continue;
    if (!tryReduceVL(MI&: *const_cast<MachineInstr *>(MI), CommonVL: VL.VL))
      continue;
    MadeChange = true;
  }

  DemandedVLs.clear();
  return MadeChange;
}
1360