1//===- RISCVOptWInstrs.cpp - MI W instruction optimizations ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===---------------------------------------------------------------------===//
8//
9// This pass does some optimizations for *W instructions at the MI level.
10//
11// First it removes unneeded sext.w instructions. Either because the sign
12// extended bits aren't consumed or because the input was already sign extended
13// by an earlier instruction.
14//
15// Then:
// 1. Unless explicitly disabled or the target prefers instructions with W suffix,
17// it removes the -w suffix from opw instructions whenever all users are
18// dependent only on the lower word of the result of the instruction.
19// The cases handled are:
20// * addw because c.add has a larger register encoding than c.addw.
21// * addiw because it helps reduce test differences between RV32 and RV64
22// w/o being a pessimization.
23// * mulw because c.mulw doesn't exist but c.mul does (w/ zcb)
24// * slliw because c.slliw doesn't exist and c.slli does
25//
// 2. Or if explicitly enabled or the target prefers instructions with W suffix,
27// it adds the W suffix to the instruction whenever all users are dependent
28// only on the lower word of the result of the instruction.
29// The cases handled are:
30// * add/addi/sub/mul.
31// * slli with imm < 32.
32// * ld/lwu.
33//===---------------------------------------------------------------------===//
34
35#include "RISCV.h"
36#include "RISCVMachineFunctionInfo.h"
37#include "RISCVSubtarget.h"
38#include "llvm/ADT/SmallSet.h"
39#include "llvm/ADT/Statistic.h"
40#include "llvm/CodeGen/MachineFunctionPass.h"
41#include "llvm/CodeGen/TargetInstrInfo.h"
42
43using namespace llvm;
44
45#define DEBUG_TYPE "riscv-opt-w-instrs"
46#define RISCV_OPT_W_INSTRS_NAME "RISC-V Optimize W Instructions"
47
STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions");
STATISTIC(NumTransformedToWInstrs,
          "Number of instructions transformed to W-ops");
STATISTIC(NumTransformedToNonWInstrs,
          "Number of instructions transformed to non-W-ops");

// Escape hatch to turn off the sext.w removal optimization entirely.
static cl::opt<bool> DisableSExtWRemoval("riscv-disable-sextw-removal",
                                         cl::desc("Disable removal of sext.w"),
                                         cl::init(Val: false), cl::Hidden);
// Escape hatch to keep W suffixes instead of stripping them.
static cl::opt<bool> DisableStripWSuffix("riscv-disable-strip-w-suffix",
                                         cl::desc("Disable strip W suffix"),
                                         cl::init(Val: false), cl::Hidden);
60
namespace {

// Machine function pass that removes redundant sext.w instructions and
// canonicalizes between the W and non-W forms of instructions based on the
// subtarget's preference. Runs on RV64 only (see runOnMachineFunction).
class RISCVOptWInstrs : public MachineFunctionPass {
public:
  static char ID; // Pass identification.

  RISCVOptWInstrs() : MachineFunctionPass(ID) {}

  bool runOnMachineFunction(MachineFunction &MF) override;
  // Removes sext.w (ADDIW rd, rs1, 0) instructions whose result is already
  // sign extended or whose sign bits are never consumed.
  bool removeSExtWInstrs(MachineFunction &MF, const RISCVInstrInfo &TII,
                         const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
  // Strips or adds W suffixes on eligible instructions depending on the
  // subtarget preference.
  bool canonicalizeWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII,
                             const RISCVSubtarget &ST,
                             MachineRegisterInfo &MRI);

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // This pass rewrites/erases individual instructions but never changes
    // the control flow graph.
    AU.setPreservesCFG();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  StringRef getPassName() const override { return RISCV_OPT_W_INSTRS_NAME; }
};

} // end anonymous namespace
85
char RISCVOptWInstrs::ID = 0;
INITIALIZE_PASS(RISCVOptWInstrs, DEBUG_TYPE, RISCV_OPT_W_INSTRS_NAME, false,
                false)

// Factory used by the RISC-V target to add this pass to its pipeline.
FunctionPass *llvm::createRISCVOptWInstrsPass() {
  return new RISCVOptWInstrs();
}
93
// Returns true if the use \p UserOp is an RVV pseudo that only demands the
// low \p Bits bits of the scalar register it reads. Conservatively returns
// false for non-RVV instructions, for a use as the VL operand, and for
// opcode/SEW combinations where the demanded scalar bits are unknown.
static bool vectorPseudoHasAllNBitUsers(const MachineOperand &UserOp,
                                        unsigned Bits) {
  const MachineInstr &MI = *UserOp.getParent();
  unsigned MCOpcode = RISCV::getRVVMCOpcode(RVVPseudoOpcode: MI.getOpcode());

  // Zero means this pseudo has no corresponding RVV MC opcode.
  if (!MCOpcode)
    return false;

  const MCInstrDesc &MCID = MI.getDesc();
  const uint64_t TSFlags = MCID.TSFlags;
  if (!RISCVII::hasSEWOp(TSFlags))
    return false;
  assert(RISCVII::hasVLOp(TSFlags));
  const unsigned Log2SEW = MI.getOperand(i: RISCVII::getSEWOpNum(Desc: MCID)).getImm();

  // A use as the VL operand is not a low-bits-only data use; be conservative.
  if (UserOp.getOperandNo() == RISCVII::getVLOpNum(Desc: MCID))
    return false;

  // Query how many low scalar bits this opcode reads at the given SEW.
  auto NumDemandedBits =
      RISCV::getVectorLowDemandedScalarBits(Opcode: MCOpcode, Log2SEW);
  return NumDemandedBits && Bits >= *NumDemandedBits;
}
116
// Checks if all users only demand the lower \p OrigBits of the original
// instruction's result.
// Works a transitive closure: if a user passes the value through (e.g. ADD,
// PHI, COPY), its own users are pushed onto a worklist with the number of
// low bits they may demand.
// TODO: handle multiple interdependent transformations
static bool hasAllNBitUsers(const MachineInstr &OrigMI,
                            const RISCVSubtarget &ST,
                            const MachineRegisterInfo &MRI, unsigned OrigBits) {

  // Each worklist entry is (instruction, demanded low-bit count): prove that
  // every user of that instruction's result reads at most that many low bits.
  SmallSet<std::pair<const MachineInstr *, unsigned>, 4> Visited;
  SmallVector<std::pair<const MachineInstr *, unsigned>, 4> Worklist;

  Worklist.emplace_back(Args: &OrigMI, Args&: OrigBits);

  while (!Worklist.empty()) {
    auto P = Worklist.pop_back_val();
    const MachineInstr *MI = P.first;
    unsigned Bits = P.second;

    // Skip already-proven pairs; this also terminates cycles through PHIs.
    if (!Visited.insert(V: P).second)
      continue;

    // Only handle instructions with one def.
    if (MI->getNumExplicitDefs() != 1)
      return false;

    // Uses of a physical register cannot be enumerated reliably here.
    Register DestReg = MI->getOperand(i: 0).getReg();
    if (!DestReg.isVirtual())
      return false;

    for (auto &UserOp : MRI.use_nodbg_operands(Reg: DestReg)) {
      const MachineInstr *UserMI = UserOp.getParent();
      unsigned OpIdx = UserOp.getOperandNo();

      switch (UserMI->getOpcode()) {
      default:
        // Unknown scalar user: the only remaining possibility is an RVV
        // pseudo that demands few enough low scalar bits.
        if (vectorPseudoHasAllNBitUsers(UserOp, Bits))
          break;
        return false;

      // W instructions and FP conversions from a 32-bit GPR value only read
      // the low 32 bits of their GPR inputs.
      case RISCV::ADDIW:
      case RISCV::ADDW:
      case RISCV::DIVUW:
      case RISCV::DIVW:
      case RISCV::MULW:
      case RISCV::REMUW:
      case RISCV::REMW:
      case RISCV::SLLW:
      case RISCV::SRAIW:
      case RISCV::SRAW:
      case RISCV::SRLIW:
      case RISCV::SRLW:
      case RISCV::SUBW:
      case RISCV::ROLW:
      case RISCV::RORW:
      case RISCV::RORIW:
      case RISCV::CLSW:
      case RISCV::CLZW:
      case RISCV::CTZW:
      case RISCV::CPOPW:
      case RISCV::SLLI_UW:
      case RISCV::ABSW:
      case RISCV::FMV_W_X:
      case RISCV::FCVT_H_W:
      case RISCV::FCVT_H_W_INX:
      case RISCV::FCVT_H_WU:
      case RISCV::FCVT_H_WU_INX:
      case RISCV::FCVT_S_W:
      case RISCV::FCVT_S_W_INX:
      case RISCV::FCVT_S_WU:
      case RISCV::FCVT_S_WU_INX:
      case RISCV::FCVT_D_W:
      case RISCV::FCVT_D_W_INX:
      case RISCV::FCVT_D_WU:
      case RISCV::FCVT_D_WU_INX:
        if (Bits >= 32)
          break;
        return false;

      // These only read the low 8 bits of their input.
      case RISCV::SEXT_B:
      case RISCV::PACKH:
        if (Bits >= 8)
          break;
        return false;
      // These only read the low 16 bits of their input.
      case RISCV::SEXT_H:
      case RISCV::FMV_H_X:
      case RISCV::ZEXT_H_RV32:
      case RISCV::ZEXT_H_RV64:
      case RISCV::PACKW:
        if (Bits >= 16)
          break;
        return false;

      // PACK reads only the low half (xlen/2 bits) of each operand.
      case RISCV::PACK:
        if (Bits >= (ST.getXLen() / 2))
          break;
        return false;

      case RISCV::SRLI: {
        // If we are shifting right by less than Bits, and users don't demand
        // any bits that were shifted into [Bits-1:0], then we can consider this
        // as an N-Bit user.
        unsigned ShAmt = UserMI->getOperand(i: 2).getImm();
        if (Bits > ShAmt) {
          Worklist.emplace_back(Args&: UserMI, Args: Bits - ShAmt);
          break;
        }
        return false;
      }

      // these overwrite higher input bits, otherwise the lower word of output
      // depends only on the lower word of input. So check their uses read W.
      case RISCV::SLLI: {
        // Input bits at position >= xlen-ShAmt are shifted out entirely.
        // Otherwise the users of the shift result may demand ShAmt more low
        // bits of the input.
        unsigned ShAmt = UserMI->getOperand(i: 2).getImm();
        if (Bits >= (ST.getXLen() - ShAmt))
          break;
        Worklist.emplace_back(Args&: UserMI, Args: Bits + ShAmt);
        break;
      }
      case RISCV::SLLIW: {
        // Same as SLLI, but the shift operates on a 32-bit value.
        unsigned ShAmt = UserMI->getOperand(i: 2).getImm();
        if (Bits >= 32 - ShAmt)
          break;
        Worklist.emplace_back(Args&: UserMI, Args: Bits + ShAmt);
        break;
      }

      case RISCV::ANDI: {
        // ANDI masks off all input bits at or above bit_width(Imm).
        uint64_t Imm = UserMI->getOperand(i: 2).getImm();
        if (Bits >= (unsigned)llvm::bit_width(Value: Imm))
          break;
        Worklist.emplace_back(Args&: UserMI, Args&: Bits);
        break;
      }
      case RISCV::ORI: {
        // ORI forces all input bits at or above bit_width(~Imm) to one.
        uint64_t Imm = UserMI->getOperand(i: 2).getImm();
        if (Bits >= (unsigned)llvm::bit_width<uint64_t>(Value: ~Imm))
          break;
        Worklist.emplace_back(Args&: UserMI, Args&: Bits);
        break;
      }

      case RISCV::SLL:
      case RISCV::BSET:
      case RISCV::BCLR:
      case RISCV::BINV:
        // Operand 2 is the shift amount which uses log2(xlen) bits.
        if (OpIdx == 2) {
          if (Bits >= Log2_32(Value: ST.getXLen()))
            break;
          return false;
        }
        // Operand 1: the low word of the result depends only on the low word
        // of the input, so defer to the uses of the result.
        Worklist.emplace_back(Args&: UserMI, Args&: Bits);
        break;

      case RISCV::SRA:
      case RISCV::SRL:
      case RISCV::ROL:
      case RISCV::ROR:
        // Operand 2 is the shift amount which uses log2(xlen) bits. A use as
        // operand 1 demands all bits (right shifts/rotates move high bits
        // down), so any other use fails.
        if (OpIdx == 2 && Bits >= Log2_32(Value: ST.getXLen()))
          break;
        return false;

      case RISCV::ADD_UW:
      case RISCV::SH1ADD_UW:
      case RISCV::SH2ADD_UW:
      case RISCV::SH3ADD_UW:
        // Operand 1 is implicitly zero extended.
        if (OpIdx == 1 && Bits >= 32)
          break;
        Worklist.emplace_back(Args&: UserMI, Args&: Bits);
        break;

      case RISCV::BEXTI:
        // Extracts the single bit at the immediate position, so only that
        // bit of the input is read.
        if (UserMI->getOperand(i: 2).getImm() >= Bits)
          return false;
        break;

      case RISCV::SB:
        // The first argument is the value to store.
        if (OpIdx == 0 && Bits >= 8)
          break;
        return false;
      case RISCV::SH:
        // The first argument is the value to store.
        if (OpIdx == 0 && Bits >= 16)
          break;
        return false;
      case RISCV::SW:
        // The first argument is the value to store.
        if (OpIdx == 0 && Bits >= 32)
          break;
        return false;

      // For these operations, each low output bit depends only on input bits
      // at the same or lower positions, so defer to the uses of the result.
      case RISCV::COPY:
      case RISCV::PHI:

      case RISCV::ADD:
      case RISCV::ADDI:
      case RISCV::AND:
      case RISCV::MUL:
      case RISCV::OR:
      case RISCV::SUB:
      case RISCV::XOR:
      case RISCV::XORI:

      case RISCV::ANDN:
      case RISCV::CLMUL:
      case RISCV::ORN:
      case RISCV::SH1ADD:
      case RISCV::SH2ADD:
      case RISCV::SH3ADD:
      case RISCV::XNOR:
      case RISCV::BSETI:
      case RISCV::BCLRI:
      case RISCV::BINVI:
        Worklist.emplace_back(Args&: UserMI, Args&: Bits);
        break;

      case RISCV::BREV8:
      case RISCV::ORC_B:
        // BREV8 and ORC_B work on bytes. Round Bits down to the nearest byte.
        Worklist.emplace_back(Args&: UserMI, Args: alignDown(Value: Bits, Align: 8));
        break;

      case RISCV::PseudoCCMOVGPR:
      case RISCV::PseudoCCMOVGPRNoX0:
        // Either operand 1 or operand 2 is returned by this instruction. If
        // only the lower word of the result is used, then only the lower word
        // of operand 1 and 2 is used. The condition operands are read in full.
        if (OpIdx != 1 && OpIdx != 2)
          return false;
        Worklist.emplace_back(Args&: UserMI, Args&: Bits);
        break;

      case RISCV::CZERO_EQZ:
      case RISCV::CZERO_NEZ:
      case RISCV::VT_MASKC:
      case RISCV::VT_MASKCN:
        // These return zero or operand 1; the condition (operand 2) is read
        // in full.
        if (OpIdx != 1)
          return false;
        Worklist.emplace_back(Args&: UserMI, Args&: Bits);
        break;
      case RISCV::TH_EXT:
      case RISCV::TH_EXTU:
        // Field extract [Msb:Lsb]; only bits up to Msb are read.
        unsigned Msb = UserMI->getOperand(i: 2).getImm();
        unsigned Lsb = UserMI->getOperand(i: 3).getImm();
        // Behavior of Msb < Lsb is not well documented.
        if (Msb >= Lsb && Bits > Msb)
          break;
        return false;
      }
    }
  }

  return true;
}
375
// Convenience wrapper: returns true if all users of \p OrigMI only read the
// lower 32 bits (one word) of its result.
static bool hasAllWUsers(const MachineInstr &OrigMI, const RISCVSubtarget &ST,
                         const MachineRegisterInfo &MRI) {
  return hasAllNBitUsers(OrigMI, ST, MRI, OrigBits: 32);
}
380
// This function returns true if the machine instruction always outputs a value
// where bits 63:32 match bit 31. \p OpNo is the def operand being queried.
static bool isSignExtendingOpW(const MachineInstr &MI, unsigned OpNo) {
  uint64_t TSFlags = MI.getDesc().TSFlags;

  // Instructions that can be determined from opcode are marked in tablegen.
  if (TSFlags & RISCVII::IsSignExtendingOpWMask)
    return true;

  // Special cases that require checking operands.
  switch (MI.getOpcode()) {
  // shifting right sufficiently makes the value 32-bit sign-extended
  case RISCV::SRAI:
    return MI.getOperand(i: 2).getImm() >= 32;
  // SRLI needs a strictly greater shift: a shift of exactly 32 can leave
  // bit 31 set while bits 63:32 are zero.
  case RISCV::SRLI:
    return MI.getOperand(i: 2).getImm() > 32;
  // The LI pattern ADDI rd, X0, imm is sign extended.
  case RISCV::ADDI:
    return MI.getOperand(i: 1).isReg() && MI.getOperand(i: 1).getReg() == RISCV::X0;
  // An ANDI with an 11 bit immediate will zero bits 63:11.
  case RISCV::ANDI:
    return isUInt<11>(x: MI.getOperand(i: 2).getImm());
  // An ORI with an >11 bit immediate (negative 12-bit) will set bits 63:11.
  case RISCV::ORI:
    return !isUInt<11>(x: MI.getOperand(i: 2).getImm());
  // A bseti with X0 is sign extended if the immediate is less than 31.
  case RISCV::BSETI:
    return MI.getOperand(i: 2).getImm() < 31 &&
           MI.getOperand(i: 1).getReg() == RISCV::X0;
  // Copying from X0 produces zero.
  case RISCV::COPY:
    return MI.getOperand(i: 1).getReg() == RISCV::X0;
  // Ignore the scratch register destination.
  case RISCV::PseudoAtomicLoadNand32:
    return OpNo == 0;
  case RISCV::PseudoVMV_X_S: {
    // vmv.x.s has at least 33 sign bits if log2(sew) <= 5.
    int64_t Log2SEW = MI.getOperand(i: 2).getImm();
    assert(Log2SEW >= 3 && Log2SEW <= 6 && "Unexpected Log2SEW");
    return Log2SEW <= 5;
  }
  case RISCV::TH_EXT: {
    // Sign extract of a field no wider than 32 bits is sign extended.
    unsigned Msb = MI.getOperand(i: 2).getImm();
    unsigned Lsb = MI.getOperand(i: 3).getImm();
    return Msb >= Lsb && (Msb - Lsb + 1) <= 32;
  }
  case RISCV::TH_EXTU: {
    // Zero extract of a field narrower than 32 bits leaves bit 31 clear.
    unsigned Msb = MI.getOperand(i: 2).getImm();
    unsigned Lsb = MI.getOperand(i: 3).getImm();
    return Msb >= Lsb && (Msb - Lsb + 1) < 32;
  }
  }

  return false;
}
436
// Returns true if every definition reaching \p SrcReg produces a value whose
// bits 63:32 match bit 31 (i.e. is sign extended from 32 bits). Instructions
// that are not currently sign extending but could be converted to their W
// variant to become so are collected in \p FixableDef; the caller must
// perform those conversions for the result to hold.
static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST,
                            const MachineRegisterInfo &MRI,
                            SmallPtrSetImpl<MachineInstr *> &FixableDef) {
  SmallSet<Register, 4> Visited;
  SmallVector<Register, 4> Worklist;

  // Only virtual registers have a single reaching def we can inspect via
  // getVRegDef; bail out on physical registers.
  auto AddRegToWorkList = [&](Register SrcReg) {
    if (!SrcReg.isVirtual())
      return false;
    Worklist.push_back(Elt: SrcReg);
    return true;
  };

  if (!AddRegToWorkList(SrcReg))
    return false;

  while (!Worklist.empty()) {
    Register Reg = Worklist.pop_back_val();

    // If we already visited this register, we don't need to check it again.
    if (!Visited.insert(V: Reg).second)
      continue;

    MachineInstr *MI = MRI.getVRegDef(Reg);
    if (!MI)
      continue;

    int OpNo = MI->findRegisterDefOperandIdx(Reg, /*TRI=*/nullptr);
    assert(OpNo != -1 && "Couldn't find register");

    // If this is a sign extending operation we don't need to look any further.
    if (isSignExtendingOpW(MI: *MI, OpNo))
      continue;

    // Is this an instruction that propagates sign extend?
    switch (MI->getOpcode()) {
    default:
      // Unknown opcode, give up.
      return false;
    case RISCV::COPY: {
      const MachineFunction *MF = MI->getMF();
      const RISCVMachineFunctionInfo *RVFI =
          MF->getInfo<RISCVMachineFunctionInfo>();

      // If this is the entry block and the register is livein, see if we know
      // it is sign extended.
      if (MI->getParent() == &MF->front()) {
        Register VReg = MI->getOperand(i: 0).getReg();
        if (MF->getRegInfo().isLiveIn(Reg: VReg) && RVFI->isSExt32Register(Reg: VReg))
          continue;
      }

      Register CopySrcReg = MI->getOperand(i: 1).getReg();
      if (CopySrcReg == RISCV::X10) {
        // For a method return value, we check the ZExt/SExt flags in attribute.
        // We assume the following code sequence for method call.
        // PseudoCALL @bar, ...
        // ADJCALLSTACKUP 0, 0, implicit-def dead $x2, implicit $x2
        // %0:gpr = COPY $x10
        //
        // We use the PseudoCall to look up the IR function being called to find
        // its return attributes.
        const MachineBasicBlock *MBB = MI->getParent();
        auto II = MI->getIterator();
        if (II == MBB->instr_begin() ||
            (--II)->getOpcode() != RISCV::ADJCALLSTACKUP)
          return false;

        const MachineInstr &CallMI = *(--II);
        if (!CallMI.isCall() || !CallMI.getOperand(i: 0).isGlobal())
          return false;

        auto *CalleeFn =
            dyn_cast_if_present<Function>(Val: CallMI.getOperand(i: 0).getGlobal());
        if (!CalleeFn)
          return false;

        auto *IntTy = dyn_cast<IntegerType>(Val: CalleeFn->getReturnType());
        if (!IntTy)
          return false;

        // A signext return of <= 32 bits, or a zeroext return of < 32 bits
        // (bit 31 is then known zero), guarantees sign extension.
        const AttributeSet &Attrs = CalleeFn->getAttributes().getRetAttrs();
        unsigned BitWidth = IntTy->getBitWidth();
        if ((BitWidth <= 32 && Attrs.hasAttribute(Kind: Attribute::SExt)) ||
            (BitWidth < 32 && Attrs.hasAttribute(Kind: Attribute::ZExt)))
          continue;
      }

      if (!AddRegToWorkList(CopySrcReg))
        return false;

      break;
    }

    // For these, we just need to check if the 1st operand is sign extended.
    case RISCV::BCLRI:
    case RISCV::BINVI:
    case RISCV::BSETI:
      // Touching bit 31 or above can change the sign bit or the upper bits.
      if (MI->getOperand(i: 2).getImm() >= 31)
        return false;
      [[fallthrough]];
    case RISCV::REM:
    case RISCV::ANDI:
    case RISCV::ORI:
    case RISCV::XORI:
    case RISCV::SRAI:
      // |Remainder| is always <= |Dividend|. If D is 32-bit, then so is R.
      // DIV doesn't work because of the edge case 0xf..f 8000 0000 / (long)-1
      // Logical operations use a sign extended 12-bit immediate.
      // Arithmetic shift right can only increase the number of sign bits.
      if (!AddRegToWorkList(MI->getOperand(i: 1).getReg()))
        return false;

      break;
    case RISCV::PseudoCCADDW:
    case RISCV::PseudoCCADDIW:
    case RISCV::PseudoCCSUBW:
    case RISCV::PseudoCCSLLW:
    case RISCV::PseudoCCSRLW:
    case RISCV::PseudoCCSRAW:
    case RISCV::PseudoCCSLLIW:
    case RISCV::PseudoCCSRLIW:
    case RISCV::PseudoCCSRAIW:
      // Returns operand 1 or an ADDW/SUBW/etc. of operands 2 and 3. We only
      // need to check if operand 1 is sign extended.
      if (!AddRegToWorkList(MI->getOperand(i: 1).getReg()))
        return false;
      break;
    case RISCV::REMU:
    case RISCV::AND:
    case RISCV::OR:
    case RISCV::XOR:
    case RISCV::ANDN:
    case RISCV::ORN:
    case RISCV::XNOR:
    case RISCV::MAX:
    case RISCV::MAXU:
    case RISCV::MIN:
    case RISCV::MINU:
    case RISCV::PseudoCCMOVGPR:
    case RISCV::PseudoCCMOVGPRNoX0:
    case RISCV::PseudoCCAND:
    case RISCV::PseudoCCOR:
    case RISCV::PseudoCCXOR:
    case RISCV::PseudoCCANDN:
    case RISCV::PseudoCCORN:
    case RISCV::PseudoCCXNOR:
    case RISCV::PHI:
    case RISCV::MERGE:
    case RISCV::MVM:
    case RISCV::MVMN: {
      // If all incoming values are sign-extended, the output of AND, OR, XOR,
      // MIN, MAX, PHI, or bitwise merge instructions is also sign-extended.

      // The input registers for PHI are operand 1, 3, ...
      // The input registers for PseudoCCMOVGPR(NoX0) are 1 and 2.
      // The input registers for PseudoCCAND/OR/XOR are 1, 2, and 3.
      // The input registers for MERGE/MVM/MVMN are 1, 2, and 3.
      // The input registers for others are operand 1 and 2.
      // B = first input operand, E = one past the last, D = stride.
      unsigned B = 1, E = 3, D = 1;
      switch (MI->getOpcode()) {
      case RISCV::PHI:
        E = MI->getNumOperands();
        D = 2; // PHI operands alternate register, block.
        break;
      case RISCV::PseudoCCMOVGPR:
      case RISCV::PseudoCCMOVGPRNoX0:
        B = 1;
        E = 3;
        break;
      case RISCV::PseudoCCAND:
      case RISCV::PseudoCCOR:
      case RISCV::PseudoCCXOR:
      case RISCV::PseudoCCANDN:
      case RISCV::PseudoCCORN:
      case RISCV::PseudoCCXNOR:
        B = 1;
        E = 4;
        break;
      case RISCV::MERGE:
      case RISCV::MVM:
      case RISCV::MVMN:
        B = 1;
        E = 4;
        break;
      }

      for (unsigned I = B; I != E; I += D) {
        if (!MI->getOperand(i: I).isReg())
          return false;

        if (!AddRegToWorkList(MI->getOperand(i: I).getReg()))
          return false;
      }

      break;
    }

    case RISCV::CZERO_EQZ:
    case RISCV::CZERO_NEZ:
    case RISCV::VT_MASKC:
    case RISCV::VT_MASKCN:
      // Instructions return zero or operand 1. Result is sign extended if
      // operand 1 is sign extended.
      if (!AddRegToWorkList(MI->getOperand(i: 1).getReg()))
        return false;
      break;

    case RISCV::ADDI: {
      // An LUI+ADDI pair that materializes a constant fitting in a signed
      // 32 bits produces a sign extended value.
      if (MI->getOperand(i: 1).isReg() && MI->getOperand(i: 1).getReg().isVirtual()) {
        if (MachineInstr *SrcMI = MRI.getVRegDef(Reg: MI->getOperand(i: 1).getReg())) {
          if (SrcMI->getOpcode() == RISCV::LUI &&
              SrcMI->getOperand(i: 1).isImm()) {
            uint64_t Imm = SrcMI->getOperand(i: 1).getImm();
            Imm = SignExtend64<32>(x: Imm << 12);
            Imm += (uint64_t)MI->getOperand(i: 2).getImm();
            if (isInt<32>(x: Imm))
              continue;
          }
        }
      }

      // Otherwise ADDI can still be converted to ADDIW if all users only
      // read the low word of its result.
      if (hasAllWUsers(OrigMI: *MI, ST, MRI)) {
        FixableDef.insert(Ptr: MI);
        break;
      }
      return false;
    }

    // With these opcode, we can "fix" them with the W-version
    // if we know all users of the result only rely on bits 31:0
    case RISCV::SLLI:
      // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits
      if (MI->getOperand(i: 2).getImm() >= 32)
        return false;
      [[fallthrough]];
    case RISCV::ADD:
    case RISCV::LD:
    case RISCV::LWU:
    case RISCV::MUL:
    case RISCV::SUB:
      if (hasAllWUsers(OrigMI: *MI, ST, MRI)) {
        FixableDef.insert(Ptr: MI);
        break;
      }
      return false;
    }
  }

  // If we get here, then every node we visited produces a sign extended value
  // or propagated sign extended values. So the result must be sign extended.
  return true;
}
690
691static unsigned getWOp(unsigned Opcode) {
692 switch (Opcode) {
693 case RISCV::ADDI:
694 return RISCV::ADDIW;
695 case RISCV::ADD:
696 return RISCV::ADDW;
697 case RISCV::LD:
698 case RISCV::LWU:
699 return RISCV::LW;
700 case RISCV::MUL:
701 return RISCV::MULW;
702 case RISCV::SLLI:
703 return RISCV::SLLIW;
704 case RISCV::SUB:
705 return RISCV::SUBW;
706 default:
707 llvm_unreachable("Unexpected opcode for replacement with W variant");
708 }
709}
710
// Removes redundant sext.w (ADDIW rd, rs1, 0) instructions. A sext.w is
// redundant if all users only read the lower 32 bits of its result, or if
// every definition reaching its source is already sign extended (possibly
// after converting some fixable defs to their W forms).
bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF,
                                        const RISCVInstrInfo &TII,
                                        const RISCVSubtarget &ST,
                                        MachineRegisterInfo &MRI) {
  if (DisableSExtWRemoval)
    return false;

  bool MadeChange = false;
  for (MachineBasicBlock &MBB : MF) {
    // Early-inc iteration makes erasing MI inside the loop safe.
    for (MachineInstr &MI : llvm::make_early_inc_range(Range&: MBB)) {
      // We're looking for the sext.w pattern ADDIW rd, rs1, 0.
      if (!RISCVInstrInfo::isSEXT_W(MI))
        continue;

      Register SrcReg = MI.getOperand(i: 1).getReg();

      // Defs that must be rewritten to W-ops for the removal to be valid,
      // collected by isSignExtendedW.
      SmallPtrSet<MachineInstr *, 4> FixableDefs;

      // If all users only use the lower bits, this sext.w is redundant.
      // Or if all definitions reaching MI sign-extend their output,
      // then sext.w is redundant.
      if (!hasAllWUsers(OrigMI: MI, ST, MRI) &&
          !isSignExtendedW(SrcReg, ST, MRI, FixableDef&: FixableDefs))
        continue;

      Register DstReg = MI.getOperand(i: 0).getReg();
      // SrcReg will replace DstReg everywhere, so it must be usable in the
      // destination's register class.
      if (!MRI.constrainRegClass(Reg: SrcReg, RC: MRI.getRegClass(Reg: DstReg)))
        continue;

      // Convert Fixable instructions to their W versions.
      for (MachineInstr *Fixable : FixableDefs) {
        LLVM_DEBUG(dbgs() << "Replacing " << *Fixable);
        Fixable->setDesc(TII.get(Opcode: getWOp(Opcode: Fixable->getOpcode())));
        // Wrap/exact flags described the old opcode and may not hold for the
        // W form; conservatively drop them.
        Fixable->clearFlag(Flag: MachineInstr::MIFlag::NoSWrap);
        Fixable->clearFlag(Flag: MachineInstr::MIFlag::NoUWrap);
        Fixable->clearFlag(Flag: MachineInstr::MIFlag::IsExact);
        LLVM_DEBUG(dbgs() << " with " << *Fixable);
        ++NumTransformedToWInstrs;
      }

      LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n");
      MRI.replaceRegWith(FromReg: DstReg, ToReg: SrcReg);
      MRI.clearKillFlags(Reg: SrcReg);
      MI.eraseFromParent();
      ++NumRemovedSExtW;
      MadeChange = true;
    }
  }

  return MadeChange;
}
762
// Strips or adds W suffixes to eligible instructions depending on the
// subtarget preferences. At most one direction is active for a given
// subtarget: stripping unless the target prefers W instructions (or the
// strip is disabled), adding when it does. LWU->LW is done regardless of
// the preference.
bool RISCVOptWInstrs::canonicalizeWSuffixes(MachineFunction &MF,
                                            const RISCVInstrInfo &TII,
                                            const RISCVSubtarget &ST,
                                            MachineRegisterInfo &MRI) {
  bool ShouldStripW = !(DisableStripWSuffix || ST.preferWInst());
  bool ShouldPreferW = ST.preferWInst();
  bool MadeChange = false;

  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      // WOpc: W-form to convert to when adding suffixes.
      // NonWOpc: non-W form to convert to when stripping suffixes.
      // Each opcode populates at most one of the two.
      std::optional<unsigned> WOpc;
      std::optional<unsigned> NonWOpc;
      unsigned OrigOpc = MI.getOpcode();
      switch (OrigOpc) {
      default:
        continue;
      case RISCV::ADDW:
        NonWOpc = RISCV::ADD;
        break;
      case RISCV::ADDIW:
        NonWOpc = RISCV::ADDI;
        break;
      case RISCV::MULW:
        NonWOpc = RISCV::MUL;
        break;
      case RISCV::SLLIW:
        NonWOpc = RISCV::SLLI;
        break;
      case RISCV::SUBW:
        NonWOpc = RISCV::SUB;
        break;
      case RISCV::ADD:
        WOpc = RISCV::ADDW;
        break;
      case RISCV::ADDI:
        WOpc = RISCV::ADDIW;
        break;
      case RISCV::SUB:
        WOpc = RISCV::SUBW;
        break;
      case RISCV::MUL:
        WOpc = RISCV::MULW;
        break;
      case RISCV::SLLI:
        // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits.
        if (MI.getOperand(i: 2).getImm() >= 32)
          continue;
        WOpc = RISCV::SLLIW;
        break;
      case RISCV::LD:
      case RISCV::LWU:
        WOpc = RISCV::LW;
        break;
      }

      // Either direction is only legal if every user reads just the low word.
      if (ShouldStripW && NonWOpc.has_value() && hasAllWUsers(OrigMI: MI, ST, MRI)) {
        LLVM_DEBUG(dbgs() << "Replacing " << MI);
        MI.setDesc(TII.get(Opcode: NonWOpc.value()));
        LLVM_DEBUG(dbgs() << " with " << MI);
        ++NumTransformedToNonWInstrs;
        MadeChange = true;
        continue;
      }
      // LWU is always converted to LW when possible as 1) LW is compressible
      // and 2) it helps minimise differences vs RV32.
      if ((ShouldPreferW || OrigOpc == RISCV::LWU) && WOpc.has_value() &&
          hasAllWUsers(OrigMI: MI, ST, MRI)) {
        LLVM_DEBUG(dbgs() << "Replacing " << MI);
        MI.setDesc(TII.get(Opcode: WOpc.value()));
        // Wrap/exact flags described the old opcode and may not hold for the
        // W form; conservatively drop them.
        MI.clearFlag(Flag: MachineInstr::MIFlag::NoSWrap);
        MI.clearFlag(Flag: MachineInstr::MIFlag::NoUWrap);
        MI.clearFlag(Flag: MachineInstr::MIFlag::IsExact);
        LLVM_DEBUG(dbgs() << " with " << MI);
        ++NumTransformedToWInstrs;
        MadeChange = true;
        continue;
      }
    }
  }
  return MadeChange;
}
846
847bool RISCVOptWInstrs::runOnMachineFunction(MachineFunction &MF) {
848 if (skipFunction(F: MF.getFunction()))
849 return false;
850
851 MachineRegisterInfo &MRI = MF.getRegInfo();
852 const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
853 const RISCVInstrInfo &TII = *ST.getInstrInfo();
854
855 if (!ST.is64Bit())
856 return false;
857
858 bool MadeChange = false;
859 MadeChange |= removeSExtWInstrs(MF, TII, ST, MRI);
860 MadeChange |= canonicalizeWSuffixes(MF, TII, ST, MRI);
861 return MadeChange;
862}
863