1//===- RISCVOptWInstrs.cpp - MI W instruction optimizations ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===---------------------------------------------------------------------===//
8//
// This pass does some optimizations for *W instructions at the MI level.
//
// First, it removes unneeded sext.w instructions, either because the
// sign-extended bits aren't consumed or because the input was already
// sign-extended by an earlier instruction.
//
// Then:
// 1. Unless explicitly disabled or the target prefers instructions with a W
// suffix, it removes the W suffix from opw instructions whenever all users
// depend only on the lower word of the instruction's result.
// The cases handled are:
// * addw because c.add has a larger register encoding than c.addw.
// * addiw because it helps reduce test differences between RV32 and RV64
// w/o being a pessimization.
// * mulw because c.mulw doesn't exist but c.mul does (w/ zcb)
// * slliw because c.slliw doesn't exist and c.slli does
//
// 2. If the target prefers instructions with a W suffix, it instead adds the
// W suffix to the instruction whenever all users depend only on the lower
// word of the instruction's result.
// The cases handled are:
// * add/addi/sub/mul.
// * slli with imm < 32.
// * ld/lwu.
33//===---------------------------------------------------------------------===//
34
35#include "RISCV.h"
36#include "RISCVMachineFunctionInfo.h"
37#include "RISCVSubtarget.h"
38#include "llvm/ADT/SmallSet.h"
39#include "llvm/ADT/Statistic.h"
40#include "llvm/CodeGen/MachineFunctionPass.h"
41#include "llvm/CodeGen/TargetInstrInfo.h"
42
43using namespace llvm;
44
45#define DEBUG_TYPE "riscv-opt-w-instrs"
46#define RISCV_OPT_W_INSTRS_NAME "RISC-V Optimize W Instructions"
47
48STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions");
49STATISTIC(NumTransformedToWInstrs,
50 "Number of instructions transformed to W-ops");
51
52static cl::opt<bool> DisableSExtWRemoval("riscv-disable-sextw-removal",
53 cl::desc("Disable removal of sext.w"),
54 cl::init(Val: false), cl::Hidden);
55static cl::opt<bool> DisableStripWSuffix("riscv-disable-strip-w-suffix",
56 cl::desc("Disable strip W suffix"),
57 cl::init(Val: false), cl::Hidden);
58
59namespace {
60
61class RISCVOptWInstrs : public MachineFunctionPass {
62public:
63 static char ID;
64
65 RISCVOptWInstrs() : MachineFunctionPass(ID) {}
66
67 bool runOnMachineFunction(MachineFunction &MF) override;
68 bool removeSExtWInstrs(MachineFunction &MF, const RISCVInstrInfo &TII,
69 const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
70 bool stripWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII,
71 const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
72 bool appendWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII,
73 const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
74
75 void getAnalysisUsage(AnalysisUsage &AU) const override {
76 AU.setPreservesCFG();
77 MachineFunctionPass::getAnalysisUsage(AU);
78 }
79
80 StringRef getPassName() const override { return RISCV_OPT_W_INSTRS_NAME; }
81};
82
83} // end anonymous namespace
84
85char RISCVOptWInstrs::ID = 0;
86INITIALIZE_PASS(RISCVOptWInstrs, DEBUG_TYPE, RISCV_OPT_W_INSTRS_NAME, false,
87 false)
88
89FunctionPass *llvm::createRISCVOptWInstrsPass() {
90 return new RISCVOptWInstrs();
91}
92
93static bool vectorPseudoHasAllNBitUsers(const MachineOperand &UserOp,
94 unsigned Bits) {
95 const MachineInstr &MI = *UserOp.getParent();
96 unsigned MCOpcode = RISCV::getRVVMCOpcode(RVVPseudoOpcode: MI.getOpcode());
97
98 if (!MCOpcode)
99 return false;
100
101 const MCInstrDesc &MCID = MI.getDesc();
102 const uint64_t TSFlags = MCID.TSFlags;
103 if (!RISCVII::hasSEWOp(TSFlags))
104 return false;
105 assert(RISCVII::hasVLOp(TSFlags));
106 const unsigned Log2SEW = MI.getOperand(i: RISCVII::getSEWOpNum(Desc: MCID)).getImm();
107
108 if (UserOp.getOperandNo() == RISCVII::getVLOpNum(Desc: MCID))
109 return false;
110
111 auto NumDemandedBits =
112 RISCV::getVectorLowDemandedScalarBits(Opcode: MCOpcode, Log2SEW);
113 return NumDemandedBits && Bits >= *NumDemandedBits;
114}
115
116// Checks if all users only demand the lower \p OrigBits of the original
117// instruction's result.
118// TODO: handle multiple interdependent transformations
119static bool hasAllNBitUsers(const MachineInstr &OrigMI,
120 const RISCVSubtarget &ST,
121 const MachineRegisterInfo &MRI, unsigned OrigBits) {
122
123 SmallSet<std::pair<const MachineInstr *, unsigned>, 4> Visited;
124 SmallVector<std::pair<const MachineInstr *, unsigned>, 4> Worklist;
125
126 Worklist.push_back(Elt: std::make_pair(x: &OrigMI, y&: OrigBits));
127
128 while (!Worklist.empty()) {
129 auto P = Worklist.pop_back_val();
130 const MachineInstr *MI = P.first;
131 unsigned Bits = P.second;
132
133 if (!Visited.insert(V: P).second)
134 continue;
135
136 // Only handle instructions with one def.
137 if (MI->getNumExplicitDefs() != 1)
138 return false;
139
140 Register DestReg = MI->getOperand(i: 0).getReg();
141 if (!DestReg.isVirtual())
142 return false;
143
144 for (auto &UserOp : MRI.use_nodbg_operands(Reg: DestReg)) {
145 const MachineInstr *UserMI = UserOp.getParent();
146 unsigned OpIdx = UserOp.getOperandNo();
147
148 switch (UserMI->getOpcode()) {
149 default:
150 if (vectorPseudoHasAllNBitUsers(UserOp, Bits))
151 break;
152 return false;
153
154 case RISCV::ADDIW:
155 case RISCV::ADDW:
156 case RISCV::DIVUW:
157 case RISCV::DIVW:
158 case RISCV::MULW:
159 case RISCV::REMUW:
160 case RISCV::REMW:
161 case RISCV::SLLIW:
162 case RISCV::SLLW:
163 case RISCV::SRAIW:
164 case RISCV::SRAW:
165 case RISCV::SRLIW:
166 case RISCV::SRLW:
167 case RISCV::SUBW:
168 case RISCV::ROLW:
169 case RISCV::RORW:
170 case RISCV::RORIW:
171 case RISCV::CLZW:
172 case RISCV::CTZW:
173 case RISCV::CPOPW:
174 case RISCV::SLLI_UW:
175 case RISCV::FMV_W_X:
176 case RISCV::FCVT_H_W:
177 case RISCV::FCVT_H_WU:
178 case RISCV::FCVT_S_W:
179 case RISCV::FCVT_S_WU:
180 case RISCV::FCVT_D_W:
181 case RISCV::FCVT_D_WU:
182 if (Bits >= 32)
183 break;
184 return false;
185 case RISCV::SEXT_B:
186 case RISCV::PACKH:
187 if (Bits >= 8)
188 break;
189 return false;
190 case RISCV::SEXT_H:
191 case RISCV::FMV_H_X:
192 case RISCV::ZEXT_H_RV32:
193 case RISCV::ZEXT_H_RV64:
194 case RISCV::PACKW:
195 if (Bits >= 16)
196 break;
197 return false;
198
199 case RISCV::PACK:
200 if (Bits >= (ST.getXLen() / 2))
201 break;
202 return false;
203
204 case RISCV::SRLI: {
205 // If we are shifting right by less than Bits, and users don't demand
206 // any bits that were shifted into [Bits-1:0], then we can consider this
207 // as an N-Bit user.
208 unsigned ShAmt = UserMI->getOperand(i: 2).getImm();
209 if (Bits > ShAmt) {
210 Worklist.push_back(Elt: std::make_pair(x&: UserMI, y: Bits - ShAmt));
211 break;
212 }
213 return false;
214 }
215
216 // these overwrite higher input bits, otherwise the lower word of output
217 // depends only on the lower word of input. So check their uses read W.
218 case RISCV::SLLI:
219 if (Bits >= (ST.getXLen() - UserMI->getOperand(i: 2).getImm()))
220 break;
221 Worklist.push_back(Elt: std::make_pair(x&: UserMI, y&: Bits));
222 break;
223 case RISCV::ANDI: {
224 uint64_t Imm = UserMI->getOperand(i: 2).getImm();
225 if (Bits >= (unsigned)llvm::bit_width(Value: Imm))
226 break;
227 Worklist.push_back(Elt: std::make_pair(x&: UserMI, y&: Bits));
228 break;
229 }
230 case RISCV::ORI: {
231 uint64_t Imm = UserMI->getOperand(i: 2).getImm();
232 if (Bits >= (unsigned)llvm::bit_width<uint64_t>(Value: ~Imm))
233 break;
234 Worklist.push_back(Elt: std::make_pair(x&: UserMI, y&: Bits));
235 break;
236 }
237
238 case RISCV::SLL:
239 case RISCV::BSET:
240 case RISCV::BCLR:
241 case RISCV::BINV:
242 // Operand 2 is the shift amount which uses log2(xlen) bits.
243 if (OpIdx == 2) {
244 if (Bits >= Log2_32(Value: ST.getXLen()))
245 break;
246 return false;
247 }
248 Worklist.push_back(Elt: std::make_pair(x&: UserMI, y&: Bits));
249 break;
250
251 case RISCV::SRA:
252 case RISCV::SRL:
253 case RISCV::ROL:
254 case RISCV::ROR:
255 // Operand 2 is the shift amount which uses 6 bits.
256 if (OpIdx == 2 && Bits >= Log2_32(Value: ST.getXLen()))
257 break;
258 return false;
259
260 case RISCV::ADD_UW:
261 case RISCV::SH1ADD_UW:
262 case RISCV::SH2ADD_UW:
263 case RISCV::SH3ADD_UW:
264 // Operand 1 is implicitly zero extended.
265 if (OpIdx == 1 && Bits >= 32)
266 break;
267 Worklist.push_back(Elt: std::make_pair(x&: UserMI, y&: Bits));
268 break;
269
270 case RISCV::BEXTI:
271 if (UserMI->getOperand(i: 2).getImm() >= Bits)
272 return false;
273 break;
274
275 case RISCV::SB:
276 // The first argument is the value to store.
277 if (OpIdx == 0 && Bits >= 8)
278 break;
279 return false;
280 case RISCV::SH:
281 // The first argument is the value to store.
282 if (OpIdx == 0 && Bits >= 16)
283 break;
284 return false;
285 case RISCV::SW:
286 // The first argument is the value to store.
287 if (OpIdx == 0 && Bits >= 32)
288 break;
289 return false;
290
291 // For these, lower word of output in these operations, depends only on
292 // the lower word of input. So, we check all uses only read lower word.
293 case RISCV::COPY:
294 case RISCV::PHI:
295
296 case RISCV::ADD:
297 case RISCV::ADDI:
298 case RISCV::AND:
299 case RISCV::MUL:
300 case RISCV::OR:
301 case RISCV::SUB:
302 case RISCV::XOR:
303 case RISCV::XORI:
304
305 case RISCV::ANDN:
306 case RISCV::BREV8:
307 case RISCV::CLMUL:
308 case RISCV::ORC_B:
309 case RISCV::ORN:
310 case RISCV::SH1ADD:
311 case RISCV::SH2ADD:
312 case RISCV::SH3ADD:
313 case RISCV::XNOR:
314 case RISCV::BSETI:
315 case RISCV::BCLRI:
316 case RISCV::BINVI:
317 Worklist.push_back(Elt: std::make_pair(x&: UserMI, y&: Bits));
318 break;
319
320 case RISCV::PseudoCCMOVGPR:
321 // Either operand 4 or operand 5 is returned by this instruction. If
322 // only the lower word of the result is used, then only the lower word
323 // of operand 4 and 5 is used.
324 if (OpIdx != 4 && OpIdx != 5)
325 return false;
326 Worklist.push_back(Elt: std::make_pair(x&: UserMI, y&: Bits));
327 break;
328
329 case RISCV::CZERO_EQZ:
330 case RISCV::CZERO_NEZ:
331 case RISCV::VT_MASKC:
332 case RISCV::VT_MASKCN:
333 if (OpIdx != 1)
334 return false;
335 Worklist.push_back(Elt: std::make_pair(x&: UserMI, y&: Bits));
336 break;
337 }
338 }
339 }
340
341 return true;
342}
343
344static bool hasAllWUsers(const MachineInstr &OrigMI, const RISCVSubtarget &ST,
345 const MachineRegisterInfo &MRI) {
346 return hasAllNBitUsers(OrigMI, ST, MRI, OrigBits: 32);
347}
348
349// This function returns true if the machine instruction always outputs a value
350// where bits 63:32 match bit 31.
351static bool isSignExtendingOpW(const MachineInstr &MI,
352 const MachineRegisterInfo &MRI, unsigned OpNo) {
353 uint64_t TSFlags = MI.getDesc().TSFlags;
354
355 // Instructions that can be determined from opcode are marked in tablegen.
356 if (TSFlags & RISCVII::IsSignExtendingOpWMask)
357 return true;
358
359 // Special cases that require checking operands.
360 switch (MI.getOpcode()) {
361 // shifting right sufficiently makes the value 32-bit sign-extended
362 case RISCV::SRAI:
363 return MI.getOperand(i: 2).getImm() >= 32;
364 case RISCV::SRLI:
365 return MI.getOperand(i: 2).getImm() > 32;
366 // The LI pattern ADDI rd, X0, imm is sign extended.
367 case RISCV::ADDI:
368 return MI.getOperand(i: 1).isReg() && MI.getOperand(i: 1).getReg() == RISCV::X0;
369 // An ANDI with an 11 bit immediate will zero bits 63:11.
370 case RISCV::ANDI:
371 return isUInt<11>(x: MI.getOperand(i: 2).getImm());
372 // An ORI with an >11 bit immediate (negative 12-bit) will set bits 63:11.
373 case RISCV::ORI:
374 return !isUInt<11>(x: MI.getOperand(i: 2).getImm());
375 // A bseti with X0 is sign extended if the immediate is less than 31.
376 case RISCV::BSETI:
377 return MI.getOperand(i: 2).getImm() < 31 &&
378 MI.getOperand(i: 1).getReg() == RISCV::X0;
379 // Copying from X0 produces zero.
380 case RISCV::COPY:
381 return MI.getOperand(i: 1).getReg() == RISCV::X0;
382 // Ignore the scratch register destination.
383 case RISCV::PseudoAtomicLoadNand32:
384 return OpNo == 0;
385 case RISCV::PseudoVMV_X_S: {
386 // vmv.x.s has at least 33 sign bits if log2(sew) <= 5.
387 int64_t Log2SEW = MI.getOperand(i: 2).getImm();
388 assert(Log2SEW >= 3 && Log2SEW <= 6 && "Unexpected Log2SEW");
389 return Log2SEW <= 5;
390 }
391 }
392
393 return false;
394}
395
396static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST,
397 const MachineRegisterInfo &MRI,
398 SmallPtrSetImpl<MachineInstr *> &FixableDef) {
399 SmallSet<Register, 4> Visited;
400 SmallVector<Register, 4> Worklist;
401
402 auto AddRegToWorkList = [&](Register SrcReg) {
403 if (!SrcReg.isVirtual())
404 return false;
405 Worklist.push_back(Elt: SrcReg);
406 return true;
407 };
408
409 if (!AddRegToWorkList(SrcReg))
410 return false;
411
412 while (!Worklist.empty()) {
413 Register Reg = Worklist.pop_back_val();
414
415 // If we already visited this register, we don't need to check it again.
416 if (!Visited.insert(V: Reg).second)
417 continue;
418
419 MachineInstr *MI = MRI.getVRegDef(Reg);
420 if (!MI)
421 continue;
422
423 int OpNo = MI->findRegisterDefOperandIdx(Reg, /*TRI=*/nullptr);
424 assert(OpNo != -1 && "Couldn't find register");
425
426 // If this is a sign extending operation we don't need to look any further.
427 if (isSignExtendingOpW(MI: *MI, MRI, OpNo))
428 continue;
429
430 // Is this an instruction that propagates sign extend?
431 switch (MI->getOpcode()) {
432 default:
433 // Unknown opcode, give up.
434 return false;
435 case RISCV::COPY: {
436 const MachineFunction *MF = MI->getMF();
437 const RISCVMachineFunctionInfo *RVFI =
438 MF->getInfo<RISCVMachineFunctionInfo>();
439
440 // If this is the entry block and the register is livein, see if we know
441 // it is sign extended.
442 if (MI->getParent() == &MF->front()) {
443 Register VReg = MI->getOperand(i: 0).getReg();
444 if (MF->getRegInfo().isLiveIn(Reg: VReg) && RVFI->isSExt32Register(Reg: VReg))
445 continue;
446 }
447
448 Register CopySrcReg = MI->getOperand(i: 1).getReg();
449 if (CopySrcReg == RISCV::X10) {
450 // For a method return value, we check the ZExt/SExt flags in attribute.
451 // We assume the following code sequence for method call.
452 // PseudoCALL @bar, ...
453 // ADJCALLSTACKUP 0, 0, implicit-def dead $x2, implicit $x2
454 // %0:gpr = COPY $x10
455 //
456 // We use the PseudoCall to look up the IR function being called to find
457 // its return attributes.
458 const MachineBasicBlock *MBB = MI->getParent();
459 auto II = MI->getIterator();
460 if (II == MBB->instr_begin() ||
461 (--II)->getOpcode() != RISCV::ADJCALLSTACKUP)
462 return false;
463
464 const MachineInstr &CallMI = *(--II);
465 if (!CallMI.isCall() || !CallMI.getOperand(i: 0).isGlobal())
466 return false;
467
468 auto *CalleeFn =
469 dyn_cast_if_present<Function>(Val: CallMI.getOperand(i: 0).getGlobal());
470 if (!CalleeFn)
471 return false;
472
473 auto *IntTy = dyn_cast<IntegerType>(Val: CalleeFn->getReturnType());
474 if (!IntTy)
475 return false;
476
477 const AttributeSet &Attrs = CalleeFn->getAttributes().getRetAttrs();
478 unsigned BitWidth = IntTy->getBitWidth();
479 if ((BitWidth <= 32 && Attrs.hasAttribute(Kind: Attribute::SExt)) ||
480 (BitWidth < 32 && Attrs.hasAttribute(Kind: Attribute::ZExt)))
481 continue;
482 }
483
484 if (!AddRegToWorkList(CopySrcReg))
485 return false;
486
487 break;
488 }
489
490 // For these, we just need to check if the 1st operand is sign extended.
491 case RISCV::BCLRI:
492 case RISCV::BINVI:
493 case RISCV::BSETI:
494 if (MI->getOperand(i: 2).getImm() >= 31)
495 return false;
496 [[fallthrough]];
497 case RISCV::REM:
498 case RISCV::ANDI:
499 case RISCV::ORI:
500 case RISCV::XORI:
501 // |Remainder| is always <= |Dividend|. If D is 32-bit, then so is R.
502 // DIV doesn't work because of the edge case 0xf..f 8000 0000 / (long)-1
503 // Logical operations use a sign extended 12-bit immediate.
504 if (!AddRegToWorkList(MI->getOperand(i: 1).getReg()))
505 return false;
506
507 break;
508 case RISCV::PseudoCCADDW:
509 case RISCV::PseudoCCADDIW:
510 case RISCV::PseudoCCSUBW:
511 case RISCV::PseudoCCSLLW:
512 case RISCV::PseudoCCSRLW:
513 case RISCV::PseudoCCSRAW:
514 case RISCV::PseudoCCSLLIW:
515 case RISCV::PseudoCCSRLIW:
516 case RISCV::PseudoCCSRAIW:
517 // Returns operand 4 or an ADDW/SUBW/etc. of operands 5 and 6. We only
518 // need to check if operand 4 is sign extended.
519 if (!AddRegToWorkList(MI->getOperand(i: 4).getReg()))
520 return false;
521 break;
522 case RISCV::REMU:
523 case RISCV::AND:
524 case RISCV::OR:
525 case RISCV::XOR:
526 case RISCV::ANDN:
527 case RISCV::ORN:
528 case RISCV::XNOR:
529 case RISCV::MAX:
530 case RISCV::MAXU:
531 case RISCV::MIN:
532 case RISCV::MINU:
533 case RISCV::PseudoCCMOVGPR:
534 case RISCV::PseudoCCAND:
535 case RISCV::PseudoCCOR:
536 case RISCV::PseudoCCXOR:
537 case RISCV::PHI: {
538 // If all incoming values are sign-extended, the output of AND, OR, XOR,
539 // MIN, MAX, or PHI is also sign-extended.
540
541 // The input registers for PHI are operand 1, 3, ...
542 // The input registers for PseudoCCMOVGPR are 4 and 5.
543 // The input registers for PseudoCCAND/OR/XOR are 4, 5, and 6.
544 // The input registers for others are operand 1 and 2.
545 unsigned B = 1, E = 3, D = 1;
546 switch (MI->getOpcode()) {
547 case RISCV::PHI:
548 E = MI->getNumOperands();
549 D = 2;
550 break;
551 case RISCV::PseudoCCMOVGPR:
552 B = 4;
553 E = 6;
554 break;
555 case RISCV::PseudoCCAND:
556 case RISCV::PseudoCCOR:
557 case RISCV::PseudoCCXOR:
558 B = 4;
559 E = 7;
560 break;
561 }
562
563 for (unsigned I = B; I != E; I += D) {
564 if (!MI->getOperand(i: I).isReg())
565 return false;
566
567 if (!AddRegToWorkList(MI->getOperand(i: I).getReg()))
568 return false;
569 }
570
571 break;
572 }
573
574 case RISCV::CZERO_EQZ:
575 case RISCV::CZERO_NEZ:
576 case RISCV::VT_MASKC:
577 case RISCV::VT_MASKCN:
578 // Instructions return zero or operand 1. Result is sign extended if
579 // operand 1 is sign extended.
580 if (!AddRegToWorkList(MI->getOperand(i: 1).getReg()))
581 return false;
582 break;
583
584 // With these opcode, we can "fix" them with the W-version
585 // if we know all users of the result only rely on bits 31:0
586 case RISCV::SLLI:
587 // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits
588 if (MI->getOperand(i: 2).getImm() >= 32)
589 return false;
590 [[fallthrough]];
591 case RISCV::ADDI:
592 case RISCV::ADD:
593 case RISCV::LD:
594 case RISCV::LWU:
595 case RISCV::MUL:
596 case RISCV::SUB:
597 if (hasAllWUsers(OrigMI: *MI, ST, MRI)) {
598 FixableDef.insert(Ptr: MI);
599 break;
600 }
601 return false;
602 }
603 }
604
605 // If we get here, then every node we visited produces a sign extended value
606 // or propagated sign extended values. So the result must be sign extended.
607 return true;
608}
609
610static unsigned getWOp(unsigned Opcode) {
611 switch (Opcode) {
612 case RISCV::ADDI:
613 return RISCV::ADDIW;
614 case RISCV::ADD:
615 return RISCV::ADDW;
616 case RISCV::LD:
617 case RISCV::LWU:
618 return RISCV::LW;
619 case RISCV::MUL:
620 return RISCV::MULW;
621 case RISCV::SLLI:
622 return RISCV::SLLIW;
623 case RISCV::SUB:
624 return RISCV::SUBW;
625 default:
626 llvm_unreachable("Unexpected opcode for replacement with W variant");
627 }
628}
629
630bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF,
631 const RISCVInstrInfo &TII,
632 const RISCVSubtarget &ST,
633 MachineRegisterInfo &MRI) {
634 if (DisableSExtWRemoval)
635 return false;
636
637 bool MadeChange = false;
638 for (MachineBasicBlock &MBB : MF) {
639 for (MachineInstr &MI : llvm::make_early_inc_range(Range&: MBB)) {
640 // We're looking for the sext.w pattern ADDIW rd, rs1, 0.
641 if (!RISCV::isSEXT_W(MI))
642 continue;
643
644 Register SrcReg = MI.getOperand(i: 1).getReg();
645
646 SmallPtrSet<MachineInstr *, 4> FixableDefs;
647
648 // If all users only use the lower bits, this sext.w is redundant.
649 // Or if all definitions reaching MI sign-extend their output,
650 // then sext.w is redundant.
651 if (!hasAllWUsers(OrigMI: MI, ST, MRI) &&
652 !isSignExtendedW(SrcReg, ST, MRI, FixableDef&: FixableDefs))
653 continue;
654
655 Register DstReg = MI.getOperand(i: 0).getReg();
656 if (!MRI.constrainRegClass(Reg: SrcReg, RC: MRI.getRegClass(Reg: DstReg)))
657 continue;
658
659 // Convert Fixable instructions to their W versions.
660 for (MachineInstr *Fixable : FixableDefs) {
661 LLVM_DEBUG(dbgs() << "Replacing " << *Fixable);
662 Fixable->setDesc(TII.get(Opcode: getWOp(Opcode: Fixable->getOpcode())));
663 Fixable->clearFlag(Flag: MachineInstr::MIFlag::NoSWrap);
664 Fixable->clearFlag(Flag: MachineInstr::MIFlag::NoUWrap);
665 Fixable->clearFlag(Flag: MachineInstr::MIFlag::IsExact);
666 LLVM_DEBUG(dbgs() << " with " << *Fixable);
667 ++NumTransformedToWInstrs;
668 }
669
670 LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n");
671 MRI.replaceRegWith(FromReg: DstReg, ToReg: SrcReg);
672 MRI.clearKillFlags(Reg: SrcReg);
673 MI.eraseFromParent();
674 ++NumRemovedSExtW;
675 MadeChange = true;
676 }
677 }
678
679 return MadeChange;
680}
681
682bool RISCVOptWInstrs::stripWSuffixes(MachineFunction &MF,
683 const RISCVInstrInfo &TII,
684 const RISCVSubtarget &ST,
685 MachineRegisterInfo &MRI) {
686 bool MadeChange = false;
687 for (MachineBasicBlock &MBB : MF) {
688 for (MachineInstr &MI : MBB) {
689 unsigned Opc;
690 switch (MI.getOpcode()) {
691 default:
692 continue;
693 case RISCV::ADDW: Opc = RISCV::ADD; break;
694 case RISCV::ADDIW: Opc = RISCV::ADDI; break;
695 case RISCV::MULW: Opc = RISCV::MUL; break;
696 case RISCV::SLLIW: Opc = RISCV::SLLI; break;
697 }
698
699 if (hasAllWUsers(OrigMI: MI, ST, MRI)) {
700 MI.setDesc(TII.get(Opcode: Opc));
701 MadeChange = true;
702 }
703 }
704 }
705
706 return MadeChange;
707}
708
709bool RISCVOptWInstrs::appendWSuffixes(MachineFunction &MF,
710 const RISCVInstrInfo &TII,
711 const RISCVSubtarget &ST,
712 MachineRegisterInfo &MRI) {
713 bool MadeChange = false;
714 for (MachineBasicBlock &MBB : MF) {
715 for (MachineInstr &MI : MBB) {
716 unsigned WOpc;
717 // TODO: Add more?
718 switch (MI.getOpcode()) {
719 default:
720 continue;
721 case RISCV::ADD:
722 WOpc = RISCV::ADDW;
723 break;
724 case RISCV::ADDI:
725 WOpc = RISCV::ADDIW;
726 break;
727 case RISCV::SUB:
728 WOpc = RISCV::SUBW;
729 break;
730 case RISCV::MUL:
731 WOpc = RISCV::MULW;
732 break;
733 case RISCV::SLLI:
734 // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits
735 if (MI.getOperand(i: 2).getImm() >= 32)
736 continue;
737 WOpc = RISCV::SLLIW;
738 break;
739 case RISCV::LD:
740 case RISCV::LWU:
741 WOpc = RISCV::LW;
742 break;
743 }
744
745 if (hasAllWUsers(OrigMI: MI, ST, MRI)) {
746 LLVM_DEBUG(dbgs() << "Replacing " << MI);
747 MI.setDesc(TII.get(Opcode: WOpc));
748 MI.clearFlag(Flag: MachineInstr::MIFlag::NoSWrap);
749 MI.clearFlag(Flag: MachineInstr::MIFlag::NoUWrap);
750 MI.clearFlag(Flag: MachineInstr::MIFlag::IsExact);
751 LLVM_DEBUG(dbgs() << " with " << MI);
752 ++NumTransformedToWInstrs;
753 MadeChange = true;
754 }
755 }
756 }
757
758 return MadeChange;
759}
760
761bool RISCVOptWInstrs::runOnMachineFunction(MachineFunction &MF) {
762 if (skipFunction(F: MF.getFunction()))
763 return false;
764
765 MachineRegisterInfo &MRI = MF.getRegInfo();
766 const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
767 const RISCVInstrInfo &TII = *ST.getInstrInfo();
768
769 if (!ST.is64Bit())
770 return false;
771
772 bool MadeChange = false;
773 MadeChange |= removeSExtWInstrs(MF, TII, ST, MRI);
774
775 if (!(DisableStripWSuffix || ST.preferWInst()))
776 MadeChange |= stripWSuffixes(MF, TII, ST, MRI);
777
778 if (ST.preferWInst())
779 MadeChange |= appendWSuffixes(MF, TII, ST, MRI);
780
781 return MadeChange;
782}
783