1//===-------------- RISCVVLOptimizer.cpp - VL Optimizer -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===---------------------------------------------------------------------===//
8//
9// This pass reduces the VL where possible at the MI level, before VSETVLI
10// instructions are inserted.
11//
12// The purpose of this optimization is to make the VL argument, for instructions
13// that have a VL argument, as small as possible. This is implemented by
// visiting each instruction in reverse order and, if it has a VL argument,
// checking whether that VL can be reduced.
16//
17//===---------------------------------------------------------------------===//
18
19#include "RISCV.h"
20#include "RISCVSubtarget.h"
21#include "llvm/ADT/PostOrderIterator.h"
22#include "llvm/CodeGen/MachineDominators.h"
23#include "llvm/CodeGen/MachineFunctionPass.h"
24#include "llvm/InitializePasses.h"
25
26using namespace llvm;
27
28#define DEBUG_TYPE "riscv-vl-optimizer"
29#define PASS_NAME "RISC-V VL Optimizer"
30
31namespace {
32
33class RISCVVLOptimizer : public MachineFunctionPass {
34 const MachineRegisterInfo *MRI;
35 const MachineDominatorTree *MDT;
36
37public:
38 static char ID;
39
40 RISCVVLOptimizer() : MachineFunctionPass(ID) {}
41
42 bool runOnMachineFunction(MachineFunction &MF) override;
43
44 void getAnalysisUsage(AnalysisUsage &AU) const override {
45 AU.setPreservesCFG();
46 AU.addRequired<MachineDominatorTreeWrapperPass>();
47 MachineFunctionPass::getAnalysisUsage(AU);
48 }
49
50 StringRef getPassName() const override { return PASS_NAME; }
51
52private:
53 std::optional<MachineOperand>
54 getMinimumVLForUser(const MachineOperand &UserOp) const;
55 /// Returns the largest common VL MachineOperand that may be used to optimize
56 /// MI. Returns std::nullopt if it failed to find a suitable VL.
57 std::optional<MachineOperand> checkUsers(const MachineInstr &MI) const;
58 bool tryReduceVL(MachineInstr &MI) const;
59 bool isCandidate(const MachineInstr &MI) const;
60
61 /// For a given instruction, records what elements of it are demanded by
62 /// downstream users.
63 DenseMap<const MachineInstr *, std::optional<MachineOperand>> DemandedVLs;
64};
65
66/// Represents the EMUL and EEW of a MachineOperand.
67struct OperandInfo {
68 // Represent as 1,2,4,8, ... and fractional indicator. This is because
69 // EMUL can take on values that don't map to RISCVVType::VLMUL values exactly.
70 // For example, a mask operand can have an EMUL less than MF8.
71 std::optional<std::pair<unsigned, bool>> EMUL;
72
73 unsigned Log2EEW;
74
75 OperandInfo(RISCVVType::VLMUL EMUL, unsigned Log2EEW)
76 : EMUL(RISCVVType::decodeVLMUL(VLMul: EMUL)), Log2EEW(Log2EEW) {}
77
78 OperandInfo(std::pair<unsigned, bool> EMUL, unsigned Log2EEW)
79 : EMUL(EMUL), Log2EEW(Log2EEW) {}
80
81 OperandInfo(unsigned Log2EEW) : Log2EEW(Log2EEW) {}
82
83 OperandInfo() = delete;
84
85 static bool EMULAndEEWAreEqual(const OperandInfo &A, const OperandInfo &B) {
86 return A.Log2EEW == B.Log2EEW && A.EMUL == B.EMUL;
87 }
88
89 static bool EEWAreEqual(const OperandInfo &A, const OperandInfo &B) {
90 return A.Log2EEW == B.Log2EEW;
91 }
92
93 void print(raw_ostream &OS) const {
94 if (EMUL) {
95 OS << "EMUL: m";
96 if (EMUL->second)
97 OS << "f";
98 OS << EMUL->first;
99 } else
100 OS << "EMUL: unknown\n";
101 OS << ", EEW: " << (1 << Log2EEW);
102 }
103};
104
105} // end anonymous namespace
106
/// Identifier for this pass; passed to the MachineFunctionPass base by the
/// RISCVVLOptimizer constructor.
char RISCVVLOptimizer::ID = 0;
// Pass-registration boilerplate: registers the pass under DEBUG_TYPE /
// PASS_NAME and declares its dependency on MachineDominatorTreeWrapperPass
// (matching the addRequired in getAnalysisUsage).
INITIALIZE_PASS_BEGIN(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
INITIALIZE_PASS_END(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false)
111
112FunctionPass *llvm::createRISCVVLOptimizerPass() {
113 return new RISCVVLOptimizer();
114}
115
116/// Return true if R is a physical or virtual vector register, false otherwise.
117static bool isVectorRegClass(Register R, const MachineRegisterInfo *MRI) {
118 if (R.isPhysical())
119 return RISCV::VRRegClass.contains(Reg: R);
120 const TargetRegisterClass *RC = MRI->getRegClass(Reg: R);
121 return RISCVRI::isVRegClass(TSFlags: RC->TSFlags);
122}
123
124LLVM_ATTRIBUTE_UNUSED
125static raw_ostream &operator<<(raw_ostream &OS, const OperandInfo &OI) {
126 OI.print(OS);
127 return OS;
128}
129
130LLVM_ATTRIBUTE_UNUSED
131static raw_ostream &operator<<(raw_ostream &OS,
132 const std::optional<OperandInfo> &OI) {
133 if (OI)
134 OI->print(OS);
135 else
136 OS << "nullopt";
137 return OS;
138}
139
140/// Return EMUL = (EEW / SEW) * LMUL where EEW comes from Log2EEW and LMUL and
141/// SEW are from the TSFlags of MI.
142static std::pair<unsigned, bool>
143getEMULEqualsEEWDivSEWTimesLMUL(unsigned Log2EEW, const MachineInstr &MI) {
144 RISCVVType::VLMUL MIVLMUL = RISCVII::getLMul(TSFlags: MI.getDesc().TSFlags);
145 auto [MILMUL, MILMULIsFractional] = RISCVVType::decodeVLMUL(VLMul: MIVLMUL);
146 unsigned MILog2SEW =
147 MI.getOperand(i: RISCVII::getSEWOpNum(Desc: MI.getDesc())).getImm();
148
149 // Mask instructions will have 0 as the SEW operand. But the LMUL of these
150 // instructions is calculated is as if the SEW operand was 3 (e8).
151 if (MILog2SEW == 0)
152 MILog2SEW = 3;
153
154 unsigned MISEW = 1 << MILog2SEW;
155
156 unsigned EEW = 1 << Log2EEW;
157 // Calculate (EEW/SEW)*LMUL preserving fractions less than 1. Use GCD
158 // to put fraction in simplest form.
159 unsigned Num = EEW, Denom = MISEW;
160 int GCD = MILMULIsFractional ? std::gcd(m: Num, n: Denom * MILMUL)
161 : std::gcd(m: Num * MILMUL, n: Denom);
162 Num = MILMULIsFractional ? Num / GCD : Num * MILMUL / GCD;
163 Denom = MILMULIsFractional ? Denom * MILMUL / GCD : Denom / GCD;
164 return std::make_pair(x&: Num > Denom ? Num : Denom, y: Denom > Num);
165}
166
167/// Dest has EEW=SEW. Source EEW=SEW/Factor (i.e. F2 => EEW/2).
168/// SEW comes from TSFlags of MI.
169static unsigned getIntegerExtensionOperandEEW(unsigned Factor,
170 const MachineInstr &MI,
171 const MachineOperand &MO) {
172 unsigned MILog2SEW =
173 MI.getOperand(i: RISCVII::getSEWOpNum(Desc: MI.getDesc())).getImm();
174
175 if (MO.getOperandNo() == 0)
176 return MILog2SEW;
177
178 unsigned MISEW = 1 << MILog2SEW;
179 unsigned EEW = MISEW / Factor;
180 unsigned Log2EEW = Log2_32(Value: EEW);
181
182 return Log2EEW;
183}
184
185/// Check whether MO is a mask operand of MI.
186static bool isMaskOperand(const MachineInstr &MI, const MachineOperand &MO,
187 const MachineRegisterInfo *MRI) {
188
189 if (!MO.isReg() || !isVectorRegClass(R: MO.getReg(), MRI))
190 return false;
191
192 const MCInstrDesc &Desc = MI.getDesc();
193 return Desc.operands()[MO.getOperandNo()].RegClass == RISCV::VMV0RegClassID;
194}
195
196static std::optional<unsigned>
197getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
198 const MachineInstr &MI = *MO.getParent();
199 const RISCVVPseudosTable::PseudoInfo *RVV =
200 RISCVVPseudosTable::getPseudoInfo(Pseudo: MI.getOpcode());
201 assert(RVV && "Could not find MI in PseudoTable");
202
203 // MI has a SEW associated with it. The RVV specification defines
204 // the EEW of each operand and definition in relation to MI.SEW.
205 unsigned MILog2SEW =
206 MI.getOperand(i: RISCVII::getSEWOpNum(Desc: MI.getDesc())).getImm();
207
208 const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(Desc: MI.getDesc());
209 const bool IsTied = RISCVII::isTiedPseudo(TSFlags: MI.getDesc().TSFlags);
210
211 bool IsMODef = MO.getOperandNo() == 0 ||
212 (HasPassthru && MO.getOperandNo() == MI.getNumExplicitDefs());
213
214 // All mask operands have EEW=1
215 if (isMaskOperand(MI, MO, MRI))
216 return 0;
217
218 // switch against BaseInstr to reduce number of cases that need to be
219 // considered.
220 switch (RVV->BaseInstr) {
221
222 // 6. Configuration-Setting Instructions
223 // Configuration setting instructions do not read or write vector registers
224 case RISCV::VSETIVLI:
225 case RISCV::VSETVL:
226 case RISCV::VSETVLI:
227 llvm_unreachable("Configuration setting instructions do not read or write "
228 "vector registers");
229
230 // Vector Loads and Stores
231 // Vector Unit-Stride Instructions
232 // Vector Strided Instructions
233 /// Dest EEW encoded in the instruction
234 case RISCV::VLM_V:
235 case RISCV::VSM_V:
236 return 0;
237 case RISCV::VLE8_V:
238 case RISCV::VSE8_V:
239 case RISCV::VLSE8_V:
240 case RISCV::VSSE8_V:
241 return 3;
242 case RISCV::VLE16_V:
243 case RISCV::VSE16_V:
244 case RISCV::VLSE16_V:
245 case RISCV::VSSE16_V:
246 return 4;
247 case RISCV::VLE32_V:
248 case RISCV::VSE32_V:
249 case RISCV::VLSE32_V:
250 case RISCV::VSSE32_V:
251 return 5;
252 case RISCV::VLE64_V:
253 case RISCV::VSE64_V:
254 case RISCV::VLSE64_V:
255 case RISCV::VSSE64_V:
256 return 6;
257
258 // Vector Indexed Instructions
259 // vs(o|u)xei<eew>.v
260 // Dest/Data (operand 0) EEW=SEW. Source EEW=<eew>.
261 case RISCV::VLUXEI8_V:
262 case RISCV::VLOXEI8_V:
263 case RISCV::VSUXEI8_V:
264 case RISCV::VSOXEI8_V: {
265 if (MO.getOperandNo() == 0)
266 return MILog2SEW;
267 return 3;
268 }
269 case RISCV::VLUXEI16_V:
270 case RISCV::VLOXEI16_V:
271 case RISCV::VSUXEI16_V:
272 case RISCV::VSOXEI16_V: {
273 if (MO.getOperandNo() == 0)
274 return MILog2SEW;
275 return 4;
276 }
277 case RISCV::VLUXEI32_V:
278 case RISCV::VLOXEI32_V:
279 case RISCV::VSUXEI32_V:
280 case RISCV::VSOXEI32_V: {
281 if (MO.getOperandNo() == 0)
282 return MILog2SEW;
283 return 5;
284 }
285 case RISCV::VLUXEI64_V:
286 case RISCV::VLOXEI64_V:
287 case RISCV::VSUXEI64_V:
288 case RISCV::VSOXEI64_V: {
289 if (MO.getOperandNo() == 0)
290 return MILog2SEW;
291 return 6;
292 }
293
294 // Vector Integer Arithmetic Instructions
295 // Vector Single-Width Integer Add and Subtract
296 case RISCV::VADD_VI:
297 case RISCV::VADD_VV:
298 case RISCV::VADD_VX:
299 case RISCV::VSUB_VV:
300 case RISCV::VSUB_VX:
301 case RISCV::VRSUB_VI:
302 case RISCV::VRSUB_VX:
303 // Vector Bitwise Logical Instructions
304 // Vector Single-Width Shift Instructions
305 // EEW=SEW.
306 case RISCV::VAND_VI:
307 case RISCV::VAND_VV:
308 case RISCV::VAND_VX:
309 case RISCV::VOR_VI:
310 case RISCV::VOR_VV:
311 case RISCV::VOR_VX:
312 case RISCV::VXOR_VI:
313 case RISCV::VXOR_VV:
314 case RISCV::VXOR_VX:
315 case RISCV::VSLL_VI:
316 case RISCV::VSLL_VV:
317 case RISCV::VSLL_VX:
318 case RISCV::VSRL_VI:
319 case RISCV::VSRL_VV:
320 case RISCV::VSRL_VX:
321 case RISCV::VSRA_VI:
322 case RISCV::VSRA_VV:
323 case RISCV::VSRA_VX:
324 // Vector Integer Min/Max Instructions
325 // EEW=SEW.
326 case RISCV::VMINU_VV:
327 case RISCV::VMINU_VX:
328 case RISCV::VMIN_VV:
329 case RISCV::VMIN_VX:
330 case RISCV::VMAXU_VV:
331 case RISCV::VMAXU_VX:
332 case RISCV::VMAX_VV:
333 case RISCV::VMAX_VX:
334 // Vector Single-Width Integer Multiply Instructions
335 // Source and Dest EEW=SEW.
336 case RISCV::VMUL_VV:
337 case RISCV::VMUL_VX:
338 case RISCV::VMULH_VV:
339 case RISCV::VMULH_VX:
340 case RISCV::VMULHU_VV:
341 case RISCV::VMULHU_VX:
342 case RISCV::VMULHSU_VV:
343 case RISCV::VMULHSU_VX:
344 // Vector Integer Divide Instructions
345 // EEW=SEW.
346 case RISCV::VDIVU_VV:
347 case RISCV::VDIVU_VX:
348 case RISCV::VDIV_VV:
349 case RISCV::VDIV_VX:
350 case RISCV::VREMU_VV:
351 case RISCV::VREMU_VX:
352 case RISCV::VREM_VV:
353 case RISCV::VREM_VX:
354 // Vector Single-Width Integer Multiply-Add Instructions
355 // EEW=SEW.
356 case RISCV::VMACC_VV:
357 case RISCV::VMACC_VX:
358 case RISCV::VNMSAC_VV:
359 case RISCV::VNMSAC_VX:
360 case RISCV::VMADD_VV:
361 case RISCV::VMADD_VX:
362 case RISCV::VNMSUB_VV:
363 case RISCV::VNMSUB_VX:
364 // Vector Integer Merge Instructions
365 // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
366 // EEW=SEW, except the mask operand has EEW=1. Mask operand is handled
367 // before this switch.
368 case RISCV::VMERGE_VIM:
369 case RISCV::VMERGE_VVM:
370 case RISCV::VMERGE_VXM:
371 case RISCV::VADC_VIM:
372 case RISCV::VADC_VVM:
373 case RISCV::VADC_VXM:
374 case RISCV::VSBC_VVM:
375 case RISCV::VSBC_VXM:
376 // Vector Integer Move Instructions
377 // Vector Fixed-Point Arithmetic Instructions
378 // Vector Single-Width Saturating Add and Subtract
379 // Vector Single-Width Averaging Add and Subtract
380 // EEW=SEW.
381 case RISCV::VMV_V_I:
382 case RISCV::VMV_V_V:
383 case RISCV::VMV_V_X:
384 case RISCV::VSADDU_VI:
385 case RISCV::VSADDU_VV:
386 case RISCV::VSADDU_VX:
387 case RISCV::VSADD_VI:
388 case RISCV::VSADD_VV:
389 case RISCV::VSADD_VX:
390 case RISCV::VSSUBU_VV:
391 case RISCV::VSSUBU_VX:
392 case RISCV::VSSUB_VV:
393 case RISCV::VSSUB_VX:
394 case RISCV::VAADDU_VV:
395 case RISCV::VAADDU_VX:
396 case RISCV::VAADD_VV:
397 case RISCV::VAADD_VX:
398 case RISCV::VASUBU_VV:
399 case RISCV::VASUBU_VX:
400 case RISCV::VASUB_VV:
401 case RISCV::VASUB_VX:
402 // Vector Single-Width Fractional Multiply with Rounding and Saturation
403 // EEW=SEW. The instruction produces 2*SEW product internally but
404 // saturates to fit into SEW bits.
405 case RISCV::VSMUL_VV:
406 case RISCV::VSMUL_VX:
407 // Vector Single-Width Scaling Shift Instructions
408 // EEW=SEW.
409 case RISCV::VSSRL_VI:
410 case RISCV::VSSRL_VV:
411 case RISCV::VSSRL_VX:
412 case RISCV::VSSRA_VI:
413 case RISCV::VSSRA_VV:
414 case RISCV::VSSRA_VX:
415 // Vector Permutation Instructions
416 // Integer Scalar Move Instructions
417 // Floating-Point Scalar Move Instructions
418 // EEW=SEW.
419 case RISCV::VMV_X_S:
420 case RISCV::VMV_S_X:
421 case RISCV::VFMV_F_S:
422 case RISCV::VFMV_S_F:
423 // Vector Slide Instructions
424 // EEW=SEW.
425 case RISCV::VSLIDEUP_VI:
426 case RISCV::VSLIDEUP_VX:
427 case RISCV::VSLIDEDOWN_VI:
428 case RISCV::VSLIDEDOWN_VX:
429 case RISCV::VSLIDE1UP_VX:
430 case RISCV::VFSLIDE1UP_VF:
431 case RISCV::VSLIDE1DOWN_VX:
432 case RISCV::VFSLIDE1DOWN_VF:
433 // Vector Register Gather Instructions
434 // EEW=SEW. For mask operand, EEW=1.
435 case RISCV::VRGATHER_VI:
436 case RISCV::VRGATHER_VV:
437 case RISCV::VRGATHER_VX:
438 // Vector Compress Instruction
439 // EEW=SEW.
440 case RISCV::VCOMPRESS_VM:
441 // Vector Element Index Instruction
442 case RISCV::VID_V:
443 // Vector Single-Width Floating-Point Add/Subtract Instructions
444 case RISCV::VFADD_VF:
445 case RISCV::VFADD_VV:
446 case RISCV::VFSUB_VF:
447 case RISCV::VFSUB_VV:
448 case RISCV::VFRSUB_VF:
449 // Vector Single-Width Floating-Point Multiply/Divide Instructions
450 case RISCV::VFMUL_VF:
451 case RISCV::VFMUL_VV:
452 case RISCV::VFDIV_VF:
453 case RISCV::VFDIV_VV:
454 case RISCV::VFRDIV_VF:
455 // Vector Single-Width Floating-Point Fused Multiply-Add Instructions
456 case RISCV::VFMACC_VV:
457 case RISCV::VFMACC_VF:
458 case RISCV::VFNMACC_VV:
459 case RISCV::VFNMACC_VF:
460 case RISCV::VFMSAC_VV:
461 case RISCV::VFMSAC_VF:
462 case RISCV::VFNMSAC_VV:
463 case RISCV::VFNMSAC_VF:
464 case RISCV::VFMADD_VV:
465 case RISCV::VFMADD_VF:
466 case RISCV::VFNMADD_VV:
467 case RISCV::VFNMADD_VF:
468 case RISCV::VFMSUB_VV:
469 case RISCV::VFMSUB_VF:
470 case RISCV::VFNMSUB_VV:
471 case RISCV::VFNMSUB_VF:
472 // Vector Floating-Point Square-Root Instruction
473 case RISCV::VFSQRT_V:
474 // Vector Floating-Point Reciprocal Square-Root Estimate Instruction
475 case RISCV::VFRSQRT7_V:
476 // Vector Floating-Point Reciprocal Estimate Instruction
477 case RISCV::VFREC7_V:
478 // Vector Floating-Point MIN/MAX Instructions
479 case RISCV::VFMIN_VF:
480 case RISCV::VFMIN_VV:
481 case RISCV::VFMAX_VF:
482 case RISCV::VFMAX_VV:
483 // Vector Floating-Point Sign-Injection Instructions
484 case RISCV::VFSGNJ_VF:
485 case RISCV::VFSGNJ_VV:
486 case RISCV::VFSGNJN_VV:
487 case RISCV::VFSGNJN_VF:
488 case RISCV::VFSGNJX_VF:
489 case RISCV::VFSGNJX_VV:
490 // Vector Floating-Point Classify Instruction
491 case RISCV::VFCLASS_V:
492 // Vector Floating-Point Move Instruction
493 case RISCV::VFMV_V_F:
494 // Single-Width Floating-Point/Integer Type-Convert Instructions
495 case RISCV::VFCVT_XU_F_V:
496 case RISCV::VFCVT_X_F_V:
497 case RISCV::VFCVT_RTZ_XU_F_V:
498 case RISCV::VFCVT_RTZ_X_F_V:
499 case RISCV::VFCVT_F_XU_V:
500 case RISCV::VFCVT_F_X_V:
501 // Vector Floating-Point Merge Instruction
502 case RISCV::VFMERGE_VFM:
503 // Vector count population in mask vcpop.m
504 // vfirst find-first-set mask bit
505 case RISCV::VCPOP_M:
506 case RISCV::VFIRST_M:
507 return MILog2SEW;
508
509 // Vector Widening Integer Add/Subtract
510 // Def uses EEW=2*SEW . Operands use EEW=SEW.
511 case RISCV::VWADDU_VV:
512 case RISCV::VWADDU_VX:
513 case RISCV::VWSUBU_VV:
514 case RISCV::VWSUBU_VX:
515 case RISCV::VWADD_VV:
516 case RISCV::VWADD_VX:
517 case RISCV::VWSUB_VV:
518 case RISCV::VWSUB_VX:
519 case RISCV::VWSLL_VI:
520 case RISCV::VWSLL_VX:
521 case RISCV::VWSLL_VV:
522 // Vector Widening Integer Multiply Instructions
523 // Destination EEW=2*SEW. Source EEW=SEW.
524 case RISCV::VWMUL_VV:
525 case RISCV::VWMUL_VX:
526 case RISCV::VWMULSU_VV:
527 case RISCV::VWMULSU_VX:
528 case RISCV::VWMULU_VV:
529 case RISCV::VWMULU_VX:
530 // Vector Widening Integer Multiply-Add Instructions
531 // Destination EEW=2*SEW. Source EEW=SEW.
532 // A SEW-bit*SEW-bit multiply of the sources forms a 2*SEW-bit value, which
533 // is then added to the 2*SEW-bit Dest. These instructions never have a
534 // passthru operand.
535 case RISCV::VWMACCU_VV:
536 case RISCV::VWMACCU_VX:
537 case RISCV::VWMACC_VV:
538 case RISCV::VWMACC_VX:
539 case RISCV::VWMACCSU_VV:
540 case RISCV::VWMACCSU_VX:
541 case RISCV::VWMACCUS_VX:
542 // Vector Widening Floating-Point Fused Multiply-Add Instructions
543 case RISCV::VFWMACC_VF:
544 case RISCV::VFWMACC_VV:
545 case RISCV::VFWNMACC_VF:
546 case RISCV::VFWNMACC_VV:
547 case RISCV::VFWMSAC_VF:
548 case RISCV::VFWMSAC_VV:
549 case RISCV::VFWNMSAC_VF:
550 case RISCV::VFWNMSAC_VV:
551 case RISCV::VFWMACCBF16_VV:
552 case RISCV::VFWMACCBF16_VF:
553 // Vector Widening Floating-Point Add/Subtract Instructions
554 // Dest EEW=2*SEW. Source EEW=SEW.
555 case RISCV::VFWADD_VV:
556 case RISCV::VFWADD_VF:
557 case RISCV::VFWSUB_VV:
558 case RISCV::VFWSUB_VF:
559 // Vector Widening Floating-Point Multiply
560 case RISCV::VFWMUL_VF:
561 case RISCV::VFWMUL_VV:
562 // Widening Floating-Point/Integer Type-Convert Instructions
563 case RISCV::VFWCVT_XU_F_V:
564 case RISCV::VFWCVT_X_F_V:
565 case RISCV::VFWCVT_RTZ_XU_F_V:
566 case RISCV::VFWCVT_RTZ_X_F_V:
567 case RISCV::VFWCVT_F_XU_V:
568 case RISCV::VFWCVT_F_X_V:
569 case RISCV::VFWCVT_F_F_V:
570 case RISCV::VFWCVTBF16_F_F_V:
571 return IsMODef ? MILog2SEW + 1 : MILog2SEW;
572
573 // Def and Op1 uses EEW=2*SEW. Op2 uses EEW=SEW.
574 case RISCV::VWADDU_WV:
575 case RISCV::VWADDU_WX:
576 case RISCV::VWSUBU_WV:
577 case RISCV::VWSUBU_WX:
578 case RISCV::VWADD_WV:
579 case RISCV::VWADD_WX:
580 case RISCV::VWSUB_WV:
581 case RISCV::VWSUB_WX:
582 // Vector Widening Floating-Point Add/Subtract Instructions
583 case RISCV::VFWADD_WF:
584 case RISCV::VFWADD_WV:
585 case RISCV::VFWSUB_WF:
586 case RISCV::VFWSUB_WV: {
587 bool IsOp1 = (HasPassthru && !IsTied) ? MO.getOperandNo() == 2
588 : MO.getOperandNo() == 1;
589 bool TwoTimes = IsMODef || IsOp1;
590 return TwoTimes ? MILog2SEW + 1 : MILog2SEW;
591 }
592
593 // Vector Integer Extension
594 case RISCV::VZEXT_VF2:
595 case RISCV::VSEXT_VF2:
596 return getIntegerExtensionOperandEEW(Factor: 2, MI, MO);
597 case RISCV::VZEXT_VF4:
598 case RISCV::VSEXT_VF4:
599 return getIntegerExtensionOperandEEW(Factor: 4, MI, MO);
600 case RISCV::VZEXT_VF8:
601 case RISCV::VSEXT_VF8:
602 return getIntegerExtensionOperandEEW(Factor: 8, MI, MO);
603
604 // Vector Narrowing Integer Right Shift Instructions
605 // Destination EEW=SEW, Op 1 has EEW=2*SEW. Op2 has EEW=SEW
606 case RISCV::VNSRL_WX:
607 case RISCV::VNSRL_WI:
608 case RISCV::VNSRL_WV:
609 case RISCV::VNSRA_WI:
610 case RISCV::VNSRA_WV:
611 case RISCV::VNSRA_WX:
612 // Vector Narrowing Fixed-Point Clip Instructions
613 // Destination and Op1 EEW=SEW. Op2 EEW=2*SEW.
614 case RISCV::VNCLIPU_WI:
615 case RISCV::VNCLIPU_WV:
616 case RISCV::VNCLIPU_WX:
617 case RISCV::VNCLIP_WI:
618 case RISCV::VNCLIP_WV:
619 case RISCV::VNCLIP_WX:
620 // Narrowing Floating-Point/Integer Type-Convert Instructions
621 case RISCV::VFNCVT_XU_F_W:
622 case RISCV::VFNCVT_X_F_W:
623 case RISCV::VFNCVT_RTZ_XU_F_W:
624 case RISCV::VFNCVT_RTZ_X_F_W:
625 case RISCV::VFNCVT_F_XU_W:
626 case RISCV::VFNCVT_F_X_W:
627 case RISCV::VFNCVT_F_F_W:
628 case RISCV::VFNCVT_ROD_F_F_W:
629 case RISCV::VFNCVTBF16_F_F_W: {
630 assert(!IsTied);
631 bool IsOp1 = HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1;
632 bool TwoTimes = IsOp1;
633 return TwoTimes ? MILog2SEW + 1 : MILog2SEW;
634 }
635
636 // Vector Mask Instructions
637 // Vector Mask-Register Logical Instructions
638 // vmsbf.m set-before-first mask bit
639 // vmsif.m set-including-first mask bit
640 // vmsof.m set-only-first mask bit
641 // EEW=1
642 // We handle the cases when operand is a v0 mask operand above the switch,
643 // but these instructions may use non-v0 mask operands and need to be handled
644 // specifically.
645 case RISCV::VMAND_MM:
646 case RISCV::VMNAND_MM:
647 case RISCV::VMANDN_MM:
648 case RISCV::VMXOR_MM:
649 case RISCV::VMOR_MM:
650 case RISCV::VMNOR_MM:
651 case RISCV::VMORN_MM:
652 case RISCV::VMXNOR_MM:
653 case RISCV::VMSBF_M:
654 case RISCV::VMSIF_M:
655 case RISCV::VMSOF_M: {
656 return MILog2SEW;
657 }
658
659 // Vector Iota Instruction
660 // EEW=SEW, except the mask operand has EEW=1. Mask operand is not handled
661 // before this switch.
662 case RISCV::VIOTA_M: {
663 if (IsMODef || MO.getOperandNo() == 1)
664 return MILog2SEW;
665 return 0;
666 }
667
668 // Vector Integer Compare Instructions
669 // Dest EEW=1. Source EEW=SEW.
670 case RISCV::VMSEQ_VI:
671 case RISCV::VMSEQ_VV:
672 case RISCV::VMSEQ_VX:
673 case RISCV::VMSNE_VI:
674 case RISCV::VMSNE_VV:
675 case RISCV::VMSNE_VX:
676 case RISCV::VMSLTU_VV:
677 case RISCV::VMSLTU_VX:
678 case RISCV::VMSLT_VV:
679 case RISCV::VMSLT_VX:
680 case RISCV::VMSLEU_VV:
681 case RISCV::VMSLEU_VI:
682 case RISCV::VMSLEU_VX:
683 case RISCV::VMSLE_VV:
684 case RISCV::VMSLE_VI:
685 case RISCV::VMSLE_VX:
686 case RISCV::VMSGTU_VI:
687 case RISCV::VMSGTU_VX:
688 case RISCV::VMSGT_VI:
689 case RISCV::VMSGT_VX:
690 // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
691 // Dest EEW=1. Source EEW=SEW. Mask source operand handled above this switch.
692 case RISCV::VMADC_VIM:
693 case RISCV::VMADC_VVM:
694 case RISCV::VMADC_VXM:
695 case RISCV::VMSBC_VVM:
696 case RISCV::VMSBC_VXM:
697 // Dest EEW=1. Source EEW=SEW.
698 case RISCV::VMADC_VV:
699 case RISCV::VMADC_VI:
700 case RISCV::VMADC_VX:
701 case RISCV::VMSBC_VV:
702 case RISCV::VMSBC_VX:
703 // 13.13. Vector Floating-Point Compare Instructions
704 // Dest EEW=1. Source EEW=SEW
705 case RISCV::VMFEQ_VF:
706 case RISCV::VMFEQ_VV:
707 case RISCV::VMFNE_VF:
708 case RISCV::VMFNE_VV:
709 case RISCV::VMFLT_VF:
710 case RISCV::VMFLT_VV:
711 case RISCV::VMFLE_VF:
712 case RISCV::VMFLE_VV:
713 case RISCV::VMFGT_VF:
714 case RISCV::VMFGE_VF: {
715 if (IsMODef)
716 return 0;
717 return MILog2SEW;
718 }
719
720 // Vector Reduction Operations
721 // Vector Single-Width Integer Reduction Instructions
722 case RISCV::VREDAND_VS:
723 case RISCV::VREDMAX_VS:
724 case RISCV::VREDMAXU_VS:
725 case RISCV::VREDMIN_VS:
726 case RISCV::VREDMINU_VS:
727 case RISCV::VREDOR_VS:
728 case RISCV::VREDSUM_VS:
729 case RISCV::VREDXOR_VS:
730 // Vector Single-Width Floating-Point Reduction Instructions
731 case RISCV::VFREDMAX_VS:
732 case RISCV::VFREDMIN_VS:
733 case RISCV::VFREDOSUM_VS:
734 case RISCV::VFREDUSUM_VS: {
735 return MILog2SEW;
736 }
737
738 // Vector Widening Integer Reduction Instructions
739 // The Dest and VS1 read only element 0 for the vector register. Return
740 // 2*EEW for these. VS2 has EEW=SEW and EMUL=LMUL.
741 case RISCV::VWREDSUM_VS:
742 case RISCV::VWREDSUMU_VS:
743 // Vector Widening Floating-Point Reduction Instructions
744 case RISCV::VFWREDOSUM_VS:
745 case RISCV::VFWREDUSUM_VS: {
746 bool TwoTimes = IsMODef || MO.getOperandNo() == 3;
747 return TwoTimes ? MILog2SEW + 1 : MILog2SEW;
748 }
749
750 default:
751 return std::nullopt;
752 }
753}
754
755static std::optional<OperandInfo>
756getOperandInfo(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
757 const MachineInstr &MI = *MO.getParent();
758 const RISCVVPseudosTable::PseudoInfo *RVV =
759 RISCVVPseudosTable::getPseudoInfo(Pseudo: MI.getOpcode());
760 assert(RVV && "Could not find MI in PseudoTable");
761
762 std::optional<unsigned> Log2EEW = getOperandLog2EEW(MO, MRI);
763 if (!Log2EEW)
764 return std::nullopt;
765
766 switch (RVV->BaseInstr) {
767 // Vector Reduction Operations
768 // Vector Single-Width Integer Reduction Instructions
769 // Vector Widening Integer Reduction Instructions
770 // Vector Widening Floating-Point Reduction Instructions
771 // The Dest and VS1 only read element 0 of the vector register. Return just
772 // the EEW for these.
773 case RISCV::VREDAND_VS:
774 case RISCV::VREDMAX_VS:
775 case RISCV::VREDMAXU_VS:
776 case RISCV::VREDMIN_VS:
777 case RISCV::VREDMINU_VS:
778 case RISCV::VREDOR_VS:
779 case RISCV::VREDSUM_VS:
780 case RISCV::VREDXOR_VS:
781 case RISCV::VWREDSUM_VS:
782 case RISCV::VWREDSUMU_VS:
783 case RISCV::VFWREDOSUM_VS:
784 case RISCV::VFWREDUSUM_VS:
785 if (MO.getOperandNo() != 2)
786 return OperandInfo(*Log2EEW);
787 break;
788 };
789
790 // All others have EMUL=EEW/SEW*LMUL
791 return OperandInfo(getEMULEqualsEEWDivSEWTimesLMUL(Log2EEW: *Log2EEW, MI), *Log2EEW);
792}
793
794/// Return true if this optimization should consider MI for VL reduction. This
795/// white-list approach simplifies this optimization for instructions that may
796/// have more complex semantics with relation to how it uses VL.
797static bool isSupportedInstr(const MachineInstr &MI) {
798 const RISCVVPseudosTable::PseudoInfo *RVV =
799 RISCVVPseudosTable::getPseudoInfo(Pseudo: MI.getOpcode());
800
801 if (!RVV)
802 return false;
803
804 switch (RVV->BaseInstr) {
805 // Vector Unit-Stride Instructions
806 // Vector Strided Instructions
807 case RISCV::VLM_V:
808 case RISCV::VLE8_V:
809 case RISCV::VLSE8_V:
810 case RISCV::VLE16_V:
811 case RISCV::VLSE16_V:
812 case RISCV::VLE32_V:
813 case RISCV::VLSE32_V:
814 case RISCV::VLE64_V:
815 case RISCV::VLSE64_V:
816 // Vector Indexed Instructions
817 case RISCV::VLUXEI8_V:
818 case RISCV::VLOXEI8_V:
819 case RISCV::VLUXEI16_V:
820 case RISCV::VLOXEI16_V:
821 case RISCV::VLUXEI32_V:
822 case RISCV::VLOXEI32_V:
823 case RISCV::VLUXEI64_V:
824 case RISCV::VLOXEI64_V: {
825 for (const MachineMemOperand *MMO : MI.memoperands())
826 if (MMO->isVolatile())
827 return false;
828 return true;
829 }
830
831 // Vector Single-Width Integer Add and Subtract
832 case RISCV::VADD_VI:
833 case RISCV::VADD_VV:
834 case RISCV::VADD_VX:
835 case RISCV::VSUB_VV:
836 case RISCV::VSUB_VX:
837 case RISCV::VRSUB_VI:
838 case RISCV::VRSUB_VX:
839 // Vector Bitwise Logical Instructions
840 // Vector Single-Width Shift Instructions
841 case RISCV::VAND_VI:
842 case RISCV::VAND_VV:
843 case RISCV::VAND_VX:
844 case RISCV::VOR_VI:
845 case RISCV::VOR_VV:
846 case RISCV::VOR_VX:
847 case RISCV::VXOR_VI:
848 case RISCV::VXOR_VV:
849 case RISCV::VXOR_VX:
850 case RISCV::VSLL_VI:
851 case RISCV::VSLL_VV:
852 case RISCV::VSLL_VX:
853 case RISCV::VSRL_VI:
854 case RISCV::VSRL_VV:
855 case RISCV::VSRL_VX:
856 case RISCV::VSRA_VI:
857 case RISCV::VSRA_VV:
858 case RISCV::VSRA_VX:
859 // Vector Widening Integer Add/Subtract
860 case RISCV::VWADDU_VV:
861 case RISCV::VWADDU_VX:
862 case RISCV::VWSUBU_VV:
863 case RISCV::VWSUBU_VX:
864 case RISCV::VWADD_VV:
865 case RISCV::VWADD_VX:
866 case RISCV::VWSUB_VV:
867 case RISCV::VWSUB_VX:
868 case RISCV::VWADDU_WV:
869 case RISCV::VWADDU_WX:
870 case RISCV::VWSUBU_WV:
871 case RISCV::VWSUBU_WX:
872 case RISCV::VWADD_WV:
873 case RISCV::VWADD_WX:
874 case RISCV::VWSUB_WV:
875 case RISCV::VWSUB_WX:
876 // Vector Integer Extension
877 case RISCV::VZEXT_VF2:
878 case RISCV::VSEXT_VF2:
879 case RISCV::VZEXT_VF4:
880 case RISCV::VSEXT_VF4:
881 case RISCV::VZEXT_VF8:
882 case RISCV::VSEXT_VF8:
883 // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
884 // FIXME: Add support
885 case RISCV::VMADC_VV:
886 case RISCV::VMADC_VI:
887 case RISCV::VMADC_VX:
888 case RISCV::VMSBC_VV:
889 case RISCV::VMSBC_VX:
890 // Vector Narrowing Integer Right Shift Instructions
891 case RISCV::VNSRL_WX:
892 case RISCV::VNSRL_WI:
893 case RISCV::VNSRL_WV:
894 case RISCV::VNSRA_WI:
895 case RISCV::VNSRA_WV:
896 case RISCV::VNSRA_WX:
897 // Vector Integer Compare Instructions
898 case RISCV::VMSEQ_VI:
899 case RISCV::VMSEQ_VV:
900 case RISCV::VMSEQ_VX:
901 case RISCV::VMSNE_VI:
902 case RISCV::VMSNE_VV:
903 case RISCV::VMSNE_VX:
904 case RISCV::VMSLTU_VV:
905 case RISCV::VMSLTU_VX:
906 case RISCV::VMSLT_VV:
907 case RISCV::VMSLT_VX:
908 case RISCV::VMSLEU_VV:
909 case RISCV::VMSLEU_VI:
910 case RISCV::VMSLEU_VX:
911 case RISCV::VMSLE_VV:
912 case RISCV::VMSLE_VI:
913 case RISCV::VMSLE_VX:
914 case RISCV::VMSGTU_VI:
915 case RISCV::VMSGTU_VX:
916 case RISCV::VMSGT_VI:
917 case RISCV::VMSGT_VX:
918 // Vector Integer Min/Max Instructions
919 case RISCV::VMINU_VV:
920 case RISCV::VMINU_VX:
921 case RISCV::VMIN_VV:
922 case RISCV::VMIN_VX:
923 case RISCV::VMAXU_VV:
924 case RISCV::VMAXU_VX:
925 case RISCV::VMAX_VV:
926 case RISCV::VMAX_VX:
927 // Vector Single-Width Integer Multiply Instructions
928 case RISCV::VMUL_VV:
929 case RISCV::VMUL_VX:
930 case RISCV::VMULH_VV:
931 case RISCV::VMULH_VX:
932 case RISCV::VMULHU_VV:
933 case RISCV::VMULHU_VX:
934 case RISCV::VMULHSU_VV:
935 case RISCV::VMULHSU_VX:
936 // Vector Integer Divide Instructions
937 case RISCV::VDIVU_VV:
938 case RISCV::VDIVU_VX:
939 case RISCV::VDIV_VV:
940 case RISCV::VDIV_VX:
941 case RISCV::VREMU_VV:
942 case RISCV::VREMU_VX:
943 case RISCV::VREM_VV:
944 case RISCV::VREM_VX:
945 // Vector Widening Integer Multiply Instructions
946 case RISCV::VWMUL_VV:
947 case RISCV::VWMUL_VX:
948 case RISCV::VWMULSU_VV:
949 case RISCV::VWMULSU_VX:
950 case RISCV::VWMULU_VV:
951 case RISCV::VWMULU_VX:
952 // Vector Single-Width Integer Multiply-Add Instructions
953 case RISCV::VMACC_VV:
954 case RISCV::VMACC_VX:
955 case RISCV::VNMSAC_VV:
956 case RISCV::VNMSAC_VX:
957 case RISCV::VMADD_VV:
958 case RISCV::VMADD_VX:
959 case RISCV::VNMSUB_VV:
960 case RISCV::VNMSUB_VX:
961 // Vector Integer Merge Instructions
962 case RISCV::VMERGE_VIM:
963 case RISCV::VMERGE_VVM:
964 case RISCV::VMERGE_VXM:
965 // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
966 case RISCV::VADC_VIM:
967 case RISCV::VADC_VVM:
968 case RISCV::VADC_VXM:
969 // Vector Widening Integer Multiply-Add Instructions
970 case RISCV::VWMACCU_VV:
971 case RISCV::VWMACCU_VX:
972 case RISCV::VWMACC_VV:
973 case RISCV::VWMACC_VX:
974 case RISCV::VWMACCSU_VV:
975 case RISCV::VWMACCSU_VX:
976 case RISCV::VWMACCUS_VX:
977 // Vector Integer Merge Instructions
978 // FIXME: Add support
979 // Vector Integer Move Instructions
980 // FIXME: Add support
981 case RISCV::VMV_V_I:
982 case RISCV::VMV_V_X:
983 case RISCV::VMV_V_V:
984 // Vector Single-Width Saturating Add and Subtract
985 case RISCV::VSADDU_VV:
986 case RISCV::VSADDU_VX:
987 case RISCV::VSADDU_VI:
988 case RISCV::VSADD_VV:
989 case RISCV::VSADD_VX:
990 case RISCV::VSADD_VI:
991 case RISCV::VSSUBU_VV:
992 case RISCV::VSSUBU_VX:
993 case RISCV::VSSUB_VV:
994 case RISCV::VSSUB_VX:
995 // Vector Single-Width Averaging Add and Subtract
996 case RISCV::VAADDU_VV:
997 case RISCV::VAADDU_VX:
998 case RISCV::VAADD_VV:
999 case RISCV::VAADD_VX:
1000 case RISCV::VASUBU_VV:
1001 case RISCV::VASUBU_VX:
1002 case RISCV::VASUB_VV:
1003 case RISCV::VASUB_VX:
1004 // Vector Single-Width Fractional Multiply with Rounding and Saturation
1005 case RISCV::VSMUL_VV:
1006 case RISCV::VSMUL_VX:
1007 // Vector Single-Width Scaling Shift Instructions
1008 case RISCV::VSSRL_VV:
1009 case RISCV::VSSRL_VX:
1010 case RISCV::VSSRL_VI:
1011 case RISCV::VSSRA_VV:
1012 case RISCV::VSSRA_VX:
1013 case RISCV::VSSRA_VI:
1014 // Vector Narrowing Fixed-Point Clip Instructions
1015 case RISCV::VNCLIPU_WV:
1016 case RISCV::VNCLIPU_WX:
1017 case RISCV::VNCLIPU_WI:
1018 case RISCV::VNCLIP_WV:
1019 case RISCV::VNCLIP_WX:
1020 case RISCV::VNCLIP_WI:
1021
1022 // Vector Crypto
1023 case RISCV::VWSLL_VI:
1024 case RISCV::VWSLL_VX:
1025 case RISCV::VWSLL_VV:
1026
1027 // Vector Mask Instructions
1028 // Vector Mask-Register Logical Instructions
1029 // vmsbf.m set-before-first mask bit
1030 // vmsif.m set-including-first mask bit
1031 // vmsof.m set-only-first mask bit
1032 // Vector Iota Instruction
1033 // Vector Element Index Instruction
1034 case RISCV::VMAND_MM:
1035 case RISCV::VMNAND_MM:
1036 case RISCV::VMANDN_MM:
1037 case RISCV::VMXOR_MM:
1038 case RISCV::VMOR_MM:
1039 case RISCV::VMNOR_MM:
1040 case RISCV::VMORN_MM:
1041 case RISCV::VMXNOR_MM:
1042 case RISCV::VMSBF_M:
1043 case RISCV::VMSIF_M:
1044 case RISCV::VMSOF_M:
1045 case RISCV::VIOTA_M:
1046 case RISCV::VID_V:
1047 // Vector Slide Instructions
1048 case RISCV::VSLIDEUP_VX:
1049 case RISCV::VSLIDEUP_VI:
1050 case RISCV::VSLIDEDOWN_VX:
1051 case RISCV::VSLIDEDOWN_VI:
1052 case RISCV::VSLIDE1UP_VX:
1053 case RISCV::VFSLIDE1UP_VF:
1054 // Vector Single-Width Floating-Point Add/Subtract Instructions
1055 case RISCV::VFADD_VF:
1056 case RISCV::VFADD_VV:
1057 case RISCV::VFSUB_VF:
1058 case RISCV::VFSUB_VV:
1059 case RISCV::VFRSUB_VF:
1060 // Vector Widening Floating-Point Add/Subtract Instructions
1061 case RISCV::VFWADD_VV:
1062 case RISCV::VFWADD_VF:
1063 case RISCV::VFWSUB_VV:
1064 case RISCV::VFWSUB_VF:
1065 case RISCV::VFWADD_WF:
1066 case RISCV::VFWADD_WV:
1067 case RISCV::VFWSUB_WF:
1068 case RISCV::VFWSUB_WV:
1069 // Vector Single-Width Floating-Point Multiply/Divide Instructions
1070 case RISCV::VFMUL_VF:
1071 case RISCV::VFMUL_VV:
1072 case RISCV::VFDIV_VF:
1073 case RISCV::VFDIV_VV:
1074 case RISCV::VFRDIV_VF:
1075 // Vector Widening Floating-Point Multiply
1076 case RISCV::VFWMUL_VF:
1077 case RISCV::VFWMUL_VV:
1078 // Vector Single-Width Floating-Point Fused Multiply-Add Instructions
1079 case RISCV::VFMACC_VV:
1080 case RISCV::VFMACC_VF:
1081 case RISCV::VFNMACC_VV:
1082 case RISCV::VFNMACC_VF:
1083 case RISCV::VFMSAC_VV:
1084 case RISCV::VFMSAC_VF:
1085 case RISCV::VFNMSAC_VV:
1086 case RISCV::VFNMSAC_VF:
1087 case RISCV::VFMADD_VV:
1088 case RISCV::VFMADD_VF:
1089 case RISCV::VFNMADD_VV:
1090 case RISCV::VFNMADD_VF:
1091 case RISCV::VFMSUB_VV:
1092 case RISCV::VFMSUB_VF:
1093 case RISCV::VFNMSUB_VV:
1094 case RISCV::VFNMSUB_VF:
1095 // Vector Widening Floating-Point Fused Multiply-Add Instructions
1096 case RISCV::VFWMACC_VV:
1097 case RISCV::VFWMACC_VF:
1098 case RISCV::VFWNMACC_VV:
1099 case RISCV::VFWNMACC_VF:
1100 case RISCV::VFWMSAC_VV:
1101 case RISCV::VFWMSAC_VF:
1102 case RISCV::VFWNMSAC_VV:
1103 case RISCV::VFWNMSAC_VF:
1104 case RISCV::VFWMACCBF16_VV:
1105 case RISCV::VFWMACCBF16_VF:
1106 // Vector Floating-Point Square-Root Instruction
1107 case RISCV::VFSQRT_V:
1108 // Vector Floating-Point Reciprocal Square-Root Estimate Instruction
1109 case RISCV::VFRSQRT7_V:
1110 // Vector Floating-Point MIN/MAX Instructions
1111 case RISCV::VFMIN_VF:
1112 case RISCV::VFMIN_VV:
1113 case RISCV::VFMAX_VF:
1114 case RISCV::VFMAX_VV:
1115 // Vector Floating-Point Sign-Injection Instructions
1116 case RISCV::VFSGNJ_VF:
1117 case RISCV::VFSGNJ_VV:
1118 case RISCV::VFSGNJN_VV:
1119 case RISCV::VFSGNJN_VF:
1120 case RISCV::VFSGNJX_VF:
1121 case RISCV::VFSGNJX_VV:
1122 // Vector Floating-Point Compare Instructions
1123 case RISCV::VMFEQ_VF:
1124 case RISCV::VMFEQ_VV:
1125 case RISCV::VMFNE_VF:
1126 case RISCV::VMFNE_VV:
1127 case RISCV::VMFLT_VF:
1128 case RISCV::VMFLT_VV:
1129 case RISCV::VMFLE_VF:
1130 case RISCV::VMFLE_VV:
1131 case RISCV::VMFGT_VF:
1132 case RISCV::VMFGE_VF:
1133 // Vector Floating-Point Merge Instruction
1134 case RISCV::VFMERGE_VFM:
1135 // Vector Floating-Point Move Instruction
1136 case RISCV::VFMV_V_F:
1137 // Single-Width Floating-Point/Integer Type-Convert Instructions
1138 case RISCV::VFCVT_XU_F_V:
1139 case RISCV::VFCVT_X_F_V:
1140 case RISCV::VFCVT_RTZ_XU_F_V:
1141 case RISCV::VFCVT_RTZ_X_F_V:
1142 case RISCV::VFCVT_F_XU_V:
1143 case RISCV::VFCVT_F_X_V:
1144 // Widening Floating-Point/Integer Type-Convert Instructions
1145 case RISCV::VFWCVT_XU_F_V:
1146 case RISCV::VFWCVT_X_F_V:
1147 case RISCV::VFWCVT_RTZ_XU_F_V:
1148 case RISCV::VFWCVT_RTZ_X_F_V:
1149 case RISCV::VFWCVT_F_XU_V:
1150 case RISCV::VFWCVT_F_X_V:
1151 case RISCV::VFWCVT_F_F_V:
1152 case RISCV::VFWCVTBF16_F_F_V:
1153 // Narrowing Floating-Point/Integer Type-Convert Instructions
1154 case RISCV::VFNCVT_XU_F_W:
1155 case RISCV::VFNCVT_X_F_W:
1156 case RISCV::VFNCVT_RTZ_XU_F_W:
1157 case RISCV::VFNCVT_RTZ_X_F_W:
1158 case RISCV::VFNCVT_F_XU_W:
1159 case RISCV::VFNCVT_F_X_W:
1160 case RISCV::VFNCVT_F_F_W:
1161 case RISCV::VFNCVT_ROD_F_F_W:
1162 case RISCV::VFNCVTBF16_F_F_W:
1163 return true;
1164 }
1165
1166 return false;
1167}
1168
1169/// Return true if MO is a vector operand but is used as a scalar operand.
1170static bool isVectorOpUsedAsScalarOp(const MachineOperand &MO) {
1171 const MachineInstr *MI = MO.getParent();
1172 const RISCVVPseudosTable::PseudoInfo *RVV =
1173 RISCVVPseudosTable::getPseudoInfo(Pseudo: MI->getOpcode());
1174
1175 if (!RVV)
1176 return false;
1177
1178 switch (RVV->BaseInstr) {
1179 // Reductions only use vs1[0] of vs1
1180 case RISCV::VREDAND_VS:
1181 case RISCV::VREDMAX_VS:
1182 case RISCV::VREDMAXU_VS:
1183 case RISCV::VREDMIN_VS:
1184 case RISCV::VREDMINU_VS:
1185 case RISCV::VREDOR_VS:
1186 case RISCV::VREDSUM_VS:
1187 case RISCV::VREDXOR_VS:
1188 case RISCV::VWREDSUM_VS:
1189 case RISCV::VWREDSUMU_VS:
1190 case RISCV::VFREDMAX_VS:
1191 case RISCV::VFREDMIN_VS:
1192 case RISCV::VFREDOSUM_VS:
1193 case RISCV::VFREDUSUM_VS:
1194 case RISCV::VFWREDOSUM_VS:
1195 case RISCV::VFWREDUSUM_VS:
1196 return MO.getOperandNo() == 3;
1197 case RISCV::VMV_X_S:
1198 case RISCV::VFMV_F_S:
1199 return MO.getOperandNo() == 1;
1200 default:
1201 return false;
1202 }
1203}
1204
1205/// Return true if MI may read elements past VL.
1206static bool mayReadPastVL(const MachineInstr &MI) {
1207 const RISCVVPseudosTable::PseudoInfo *RVV =
1208 RISCVVPseudosTable::getPseudoInfo(Pseudo: MI.getOpcode());
1209 if (!RVV)
1210 return true;
1211
1212 switch (RVV->BaseInstr) {
1213 // vslidedown instructions may read elements past VL. They are handled
1214 // according to current tail policy.
1215 case RISCV::VSLIDEDOWN_VI:
1216 case RISCV::VSLIDEDOWN_VX:
1217 case RISCV::VSLIDE1DOWN_VX:
1218 case RISCV::VFSLIDE1DOWN_VF:
1219
1220 // vrgather instructions may read the source vector at any index < VLMAX,
1221 // regardless of VL.
1222 case RISCV::VRGATHER_VI:
1223 case RISCV::VRGATHER_VV:
1224 case RISCV::VRGATHER_VX:
1225 case RISCV::VRGATHEREI16_VV:
1226 return true;
1227
1228 default:
1229 return false;
1230 }
1231}
1232
1233bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
1234 const MCInstrDesc &Desc = MI.getDesc();
1235 if (!RISCVII::hasVLOp(TSFlags: Desc.TSFlags) || !RISCVII::hasSEWOp(TSFlags: Desc.TSFlags))
1236 return false;
1237
1238 if (MI.getNumExplicitDefs() != 1)
1239 return false;
1240
1241 // Some instructions have implicit defs e.g. $vxsat. If they might be read
1242 // later then we can't reduce VL.
1243 if (!MI.allImplicitDefsAreDead()) {
1244 LLVM_DEBUG(dbgs() << "Not a candidate because has non-dead implicit def\n");
1245 return false;
1246 }
1247
1248 if (MI.mayRaiseFPException()) {
1249 LLVM_DEBUG(dbgs() << "Not a candidate because may raise FP exception\n");
1250 return false;
1251 }
1252
1253 // Some instructions that produce vectors have semantics that make it more
1254 // difficult to determine whether the VL can be reduced. For example, some
1255 // instructions, such as reductions, may write lanes past VL to a scalar
1256 // register. Other instructions, such as some loads or stores, may write
1257 // lower lanes using data from higher lanes. There may be other complex
1258 // semantics not mentioned here that make it hard to determine whether
1259 // the VL can be optimized. As a result, a white-list of supported
1260 // instructions is used. Over time, more instructions can be supported
1261 // upon careful examination of their semantics under the logic in this
1262 // optimization.
1263 // TODO: Use a better approach than a white-list, such as adding
1264 // properties to instructions using something like TSFlags.
1265 if (!isSupportedInstr(MI)) {
1266 LLVM_DEBUG(dbgs() << "Not a candidate due to unsupported instruction\n");
1267 return false;
1268 }
1269
1270 assert(!RISCVII::elementsDependOnVL(RISCV::getRVVMCOpcode(MI.getOpcode())) &&
1271 "Instruction shouldn't be supported if elements depend on VL");
1272
1273 assert(MI.getOperand(0).isReg() &&
1274 isVectorRegClass(MI.getOperand(0).getReg(), MRI) &&
1275 "All supported instructions produce a vector register result");
1276
1277 LLVM_DEBUG(dbgs() << "Found a candidate for VL reduction: " << MI << "\n");
1278 return true;
1279}
1280
1281std::optional<MachineOperand>
1282RISCVVLOptimizer::getMinimumVLForUser(const MachineOperand &UserOp) const {
1283 const MachineInstr &UserMI = *UserOp.getParent();
1284 const MCInstrDesc &Desc = UserMI.getDesc();
1285
1286 if (!RISCVII::hasVLOp(TSFlags: Desc.TSFlags) || !RISCVII::hasSEWOp(TSFlags: Desc.TSFlags)) {
1287 LLVM_DEBUG(dbgs() << " Abort due to lack of VL, assume that"
1288 " use VLMAX\n");
1289 return std::nullopt;
1290 }
1291
1292 if (mayReadPastVL(MI: UserMI)) {
1293 LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n");
1294 return std::nullopt;
1295 }
1296
1297 unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
1298 const MachineOperand &VLOp = UserMI.getOperand(i: VLOpNum);
1299 // Looking for an immediate or a register VL that isn't X0.
1300 assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) &&
1301 "Did not expect X0 VL");
1302
1303 // If the user is a passthru it will read the elements past VL, so
1304 // abort if any of the elements past VL are demanded.
1305 if (UserOp.isTied()) {
1306 assert(UserOp.getOperandNo() == UserMI.getNumExplicitDefs() &&
1307 RISCVII::isFirstDefTiedToFirstUse(UserMI.getDesc()));
1308 auto DemandedVL = DemandedVLs.lookup(Val: &UserMI);
1309 if (!DemandedVL || !RISCV::isVLKnownLE(LHS: *DemandedVL, RHS: VLOp)) {
1310 LLVM_DEBUG(dbgs() << " Abort because user is passthru in "
1311 "instruction with demanded tail\n");
1312 return std::nullopt;
1313 }
1314 }
1315
1316 // Instructions like reductions may use a vector register as a scalar
1317 // register. In this case, we should treat it as only reading the first lane.
1318 if (isVectorOpUsedAsScalarOp(MO: UserOp)) {
1319 LLVM_DEBUG(dbgs() << " Used this operand as a scalar operand\n");
1320 return MachineOperand::CreateImm(Val: 1);
1321 }
1322
1323 // If we know the demanded VL of UserMI, then we can reduce the VL it
1324 // requires.
1325 if (auto DemandedVL = DemandedVLs.lookup(Val: &UserMI)) {
1326 assert(isCandidate(UserMI));
1327 if (RISCV::isVLKnownLE(LHS: *DemandedVL, RHS: VLOp))
1328 return DemandedVL;
1329 }
1330
1331 return VLOp;
1332}
1333
1334std::optional<MachineOperand>
1335RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const {
1336 std::optional<MachineOperand> CommonVL;
1337 SmallSetVector<MachineOperand *, 8> Worklist;
1338 SmallPtrSet<const MachineInstr *, 4> PHISeen;
1339 for (auto &UserOp : MRI->use_operands(Reg: MI.getOperand(i: 0).getReg()))
1340 Worklist.insert(X: &UserOp);
1341
1342 while (!Worklist.empty()) {
1343 MachineOperand &UserOp = *Worklist.pop_back_val();
1344 const MachineInstr &UserMI = *UserOp.getParent();
1345 LLVM_DEBUG(dbgs() << " Checking user: " << UserMI << "\n");
1346
1347 if (UserMI.isCopy() && UserMI.getOperand(i: 0).getReg().isVirtual() &&
1348 UserMI.getOperand(i: 0).getSubReg() == RISCV::NoSubRegister &&
1349 UserMI.getOperand(i: 1).getSubReg() == RISCV::NoSubRegister) {
1350 LLVM_DEBUG(dbgs() << " Peeking through uses of COPY\n");
1351 Worklist.insert_range(R: llvm::make_pointer_range(
1352 Range: MRI->use_operands(Reg: UserMI.getOperand(i: 0).getReg())));
1353 continue;
1354 }
1355
1356 if (UserMI.isPHI()) {
1357 // Don't follow PHI cycles
1358 if (!PHISeen.insert(Ptr: &UserMI).second)
1359 continue;
1360 LLVM_DEBUG(dbgs() << " Peeking through uses of PHI\n");
1361 Worklist.insert_range(R: llvm::make_pointer_range(
1362 Range: MRI->use_operands(Reg: UserMI.getOperand(i: 0).getReg())));
1363 continue;
1364 }
1365
1366 auto VLOp = getMinimumVLForUser(UserOp);
1367 if (!VLOp)
1368 return std::nullopt;
1369
1370 // Use the largest VL among all the users. If we cannot determine this
1371 // statically, then we cannot optimize the VL.
1372 if (!CommonVL || RISCV::isVLKnownLE(LHS: *CommonVL, RHS: *VLOp)) {
1373 CommonVL = *VLOp;
1374 LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n");
1375 } else if (!RISCV::isVLKnownLE(LHS: *VLOp, RHS: *CommonVL)) {
1376 LLVM_DEBUG(dbgs() << " Abort because cannot determine a common VL\n");
1377 return std::nullopt;
1378 }
1379
1380 if (!RISCVII::hasSEWOp(TSFlags: UserMI.getDesc().TSFlags)) {
1381 LLVM_DEBUG(dbgs() << " Abort due to lack of SEW operand\n");
1382 return std::nullopt;
1383 }
1384
1385 std::optional<OperandInfo> ConsumerInfo = getOperandInfo(MO: UserOp, MRI);
1386 std::optional<OperandInfo> ProducerInfo =
1387 getOperandInfo(MO: MI.getOperand(i: 0), MRI);
1388 if (!ConsumerInfo || !ProducerInfo) {
1389 LLVM_DEBUG(dbgs() << " Abort due to unknown operand information.\n");
1390 LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
1391 LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
1392 return std::nullopt;
1393 }
1394
1395 // If the operand is used as a scalar operand, then the EEW must be
1396 // compatible. Otherwise, the EMUL *and* EEW must be compatible.
1397 bool IsVectorOpUsedAsScalarOp = isVectorOpUsedAsScalarOp(MO: UserOp);
1398 if ((IsVectorOpUsedAsScalarOp &&
1399 !OperandInfo::EEWAreEqual(A: *ConsumerInfo, B: *ProducerInfo)) ||
1400 (!IsVectorOpUsedAsScalarOp &&
1401 !OperandInfo::EMULAndEEWAreEqual(A: *ConsumerInfo, B: *ProducerInfo))) {
1402 LLVM_DEBUG(
1403 dbgs()
1404 << " Abort due to incompatible information for EMUL or EEW.\n");
1405 LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
1406 LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
1407 return std::nullopt;
1408 }
1409 }
1410
1411 return CommonVL;
1412}
1413
1414bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) const {
1415 LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n");
1416
1417 unsigned VLOpNum = RISCVII::getVLOpNum(Desc: MI.getDesc());
1418 MachineOperand &VLOp = MI.getOperand(i: VLOpNum);
1419
1420 // If the VL is 1, then there is no need to reduce it. This is an
1421 // optimization, not needed to preserve correctness.
1422 if (VLOp.isImm() && VLOp.getImm() == 1) {
1423 LLVM_DEBUG(dbgs() << " Abort due to VL == 1, no point in reducing.\n");
1424 return false;
1425 }
1426
1427 auto CommonVL = DemandedVLs.lookup(Val: &MI);
1428 if (!CommonVL)
1429 return false;
1430
1431 assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) &&
1432 "Expected VL to be an Imm or virtual Reg");
1433
1434 if (!RISCV::isVLKnownLE(LHS: *CommonVL, RHS: VLOp)) {
1435 LLVM_DEBUG(dbgs() << " Abort due to CommonVL not <= VLOp.\n");
1436 return false;
1437 }
1438
1439 if (CommonVL->isIdenticalTo(Other: VLOp)) {
1440 LLVM_DEBUG(
1441 dbgs() << " Abort due to CommonVL == VLOp, no point in reducing.\n");
1442 return false;
1443 }
1444
1445 if (CommonVL->isImm()) {
1446 LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to "
1447 << CommonVL->getImm() << " for " << MI << "\n");
1448 VLOp.ChangeToImmediate(ImmVal: CommonVL->getImm());
1449 return true;
1450 }
1451 const MachineInstr *VLMI = MRI->getVRegDef(Reg: CommonVL->getReg());
1452 if (!MDT->dominates(A: VLMI, B: &MI))
1453 return false;
1454 LLVM_DEBUG(
1455 dbgs() << " Reduce VL from " << VLOp << " to "
1456 << printReg(CommonVL->getReg(), MRI->getTargetRegisterInfo())
1457 << " for " << MI << "\n");
1458
1459 // All our checks passed. We can reduce VL.
1460 VLOp.ChangeToRegister(Reg: CommonVL->getReg(), isDef: false);
1461 return true;
1462}
1463
1464bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) {
1465 assert(DemandedVLs.size() == 0);
1466 if (skipFunction(F: MF.getFunction()))
1467 return false;
1468
1469 MRI = &MF.getRegInfo();
1470 MDT = &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
1471
1472 const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
1473 if (!ST.hasVInstructions())
1474 return false;
1475
1476 // For each instruction that defines a vector, compute what VL its
1477 // downstream users demand.
1478 for (MachineBasicBlock *MBB : post_order(G: &MF)) {
1479 assert(MDT->isReachableFromEntry(MBB));
1480 for (MachineInstr &MI : reverse(C&: *MBB)) {
1481 if (!isCandidate(MI))
1482 continue;
1483 DemandedVLs.insert(KV: {&MI, checkUsers(MI)});
1484 }
1485 }
1486
1487 // Then go through and see if we can reduce the VL of any instructions to
1488 // only what's demanded.
1489 bool MadeChange = false;
1490 for (MachineBasicBlock &MBB : MF) {
1491 // Avoid unreachable blocks as they have degenerate dominance
1492 if (!MDT->isReachableFromEntry(A: &MBB))
1493 continue;
1494
1495 for (auto &MI : reverse(C&: MBB)) {
1496 if (!isCandidate(MI))
1497 continue;
1498 if (!tryReduceVL(MI))
1499 continue;
1500 MadeChange = true;
1501 }
1502 }
1503
1504 DemandedVLs.clear();
1505 return MadeChange;
1506}
1507