1 | //===- RISCVOptWInstrs.cpp - MI W instruction optimizations ---------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===---------------------------------------------------------------------===// |
8 | // |
9 | // This pass does some optimizations for *W instructions at the MI level. |
10 | // |
11 | // First it removes unneeded sext.w instructions. Either because the sign |
12 | // extended bits aren't consumed or because the input was already sign extended |
13 | // by an earlier instruction. |
14 | // |
15 | // Then: |
// 1. Unless explicitly disabled or the target prefers instructions with W
//    suffix, it removes the -w suffix from opw instructions whenever all users
//    are dependent only on the lower word of the result of the instruction.
19 | // The cases handled are: |
20 | // * addw because c.add has a larger register encoding than c.addw. |
21 | // * addiw because it helps reduce test differences between RV32 and RV64 |
22 | // w/o being a pessimization. |
23 | // * mulw because c.mulw doesn't exist but c.mul does (w/ zcb) |
24 | // * slliw because c.slliw doesn't exist and c.slli does |
25 | // |
// 2. Or, if explicitly enabled or the target prefers instructions with W
//    suffix, it adds the W suffix to the instruction whenever all users are
//    dependent only on the lower word of the result of the instruction.
29 | // The cases handled are: |
30 | // * add/addi/sub/mul. |
31 | // * slli with imm < 32. |
32 | // * ld/lwu. |
33 | //===---------------------------------------------------------------------===// |
34 | |
35 | #include "RISCV.h" |
36 | #include "RISCVMachineFunctionInfo.h" |
37 | #include "RISCVSubtarget.h" |
38 | #include "llvm/ADT/SmallSet.h" |
39 | #include "llvm/ADT/Statistic.h" |
40 | #include "llvm/CodeGen/MachineFunctionPass.h" |
41 | #include "llvm/CodeGen/TargetInstrInfo.h" |
42 | |
43 | using namespace llvm; |
44 | |
45 | #define DEBUG_TYPE "riscv-opt-w-instrs" |
46 | #define RISCV_OPT_W_INSTRS_NAME "RISC-V Optimize W Instructions" |
47 | |
48 | STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions" ); |
49 | STATISTIC(NumTransformedToWInstrs, |
50 | "Number of instructions transformed to W-ops" ); |
51 | |
52 | static cl::opt<bool> DisableSExtWRemoval("riscv-disable-sextw-removal" , |
53 | cl::desc("Disable removal of sext.w" ), |
54 | cl::init(Val: false), cl::Hidden); |
55 | static cl::opt<bool> DisableStripWSuffix("riscv-disable-strip-w-suffix" , |
56 | cl::desc("Disable strip W suffix" ), |
57 | cl::init(Val: false), cl::Hidden); |
58 | |
59 | namespace { |
60 | |
61 | class RISCVOptWInstrs : public MachineFunctionPass { |
62 | public: |
63 | static char ID; |
64 | |
65 | RISCVOptWInstrs() : MachineFunctionPass(ID) {} |
66 | |
67 | bool runOnMachineFunction(MachineFunction &MF) override; |
68 | bool removeSExtWInstrs(MachineFunction &MF, const RISCVInstrInfo &TII, |
69 | const RISCVSubtarget &ST, MachineRegisterInfo &MRI); |
70 | bool stripWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII, |
71 | const RISCVSubtarget &ST, MachineRegisterInfo &MRI); |
72 | bool appendWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII, |
73 | const RISCVSubtarget &ST, MachineRegisterInfo &MRI); |
74 | |
75 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
76 | AU.setPreservesCFG(); |
77 | MachineFunctionPass::getAnalysisUsage(AU); |
78 | } |
79 | |
80 | StringRef getPassName() const override { return RISCV_OPT_W_INSTRS_NAME; } |
81 | }; |
82 | |
83 | } // end anonymous namespace |
84 | |
85 | char RISCVOptWInstrs::ID = 0; |
86 | INITIALIZE_PASS(RISCVOptWInstrs, DEBUG_TYPE, RISCV_OPT_W_INSTRS_NAME, false, |
87 | false) |
88 | |
89 | FunctionPass *llvm::createRISCVOptWInstrsPass() { |
90 | return new RISCVOptWInstrs(); |
91 | } |
92 | |
93 | static bool vectorPseudoHasAllNBitUsers(const MachineOperand &UserOp, |
94 | unsigned Bits) { |
95 | const MachineInstr &MI = *UserOp.getParent(); |
96 | unsigned MCOpcode = RISCV::getRVVMCOpcode(RVVPseudoOpcode: MI.getOpcode()); |
97 | |
98 | if (!MCOpcode) |
99 | return false; |
100 | |
101 | const MCInstrDesc &MCID = MI.getDesc(); |
102 | const uint64_t TSFlags = MCID.TSFlags; |
103 | if (!RISCVII::hasSEWOp(TSFlags)) |
104 | return false; |
105 | assert(RISCVII::hasVLOp(TSFlags)); |
106 | const unsigned Log2SEW = MI.getOperand(i: RISCVII::getSEWOpNum(Desc: MCID)).getImm(); |
107 | |
108 | if (UserOp.getOperandNo() == RISCVII::getVLOpNum(Desc: MCID)) |
109 | return false; |
110 | |
111 | auto NumDemandedBits = |
112 | RISCV::getVectorLowDemandedScalarBits(Opcode: MCOpcode, Log2SEW); |
113 | return NumDemandedBits && Bits >= *NumDemandedBits; |
114 | } |
115 | |
116 | // Checks if all users only demand the lower \p OrigBits of the original |
117 | // instruction's result. |
118 | // TODO: handle multiple interdependent transformations |
119 | static bool hasAllNBitUsers(const MachineInstr &OrigMI, |
120 | const RISCVSubtarget &ST, |
121 | const MachineRegisterInfo &MRI, unsigned OrigBits) { |
122 | |
123 | SmallSet<std::pair<const MachineInstr *, unsigned>, 4> Visited; |
124 | SmallVector<std::pair<const MachineInstr *, unsigned>, 4> Worklist; |
125 | |
126 | Worklist.push_back(Elt: std::make_pair(x: &OrigMI, y&: OrigBits)); |
127 | |
128 | while (!Worklist.empty()) { |
129 | auto P = Worklist.pop_back_val(); |
130 | const MachineInstr *MI = P.first; |
131 | unsigned Bits = P.second; |
132 | |
133 | if (!Visited.insert(V: P).second) |
134 | continue; |
135 | |
136 | // Only handle instructions with one def. |
137 | if (MI->getNumExplicitDefs() != 1) |
138 | return false; |
139 | |
140 | Register DestReg = MI->getOperand(i: 0).getReg(); |
141 | if (!DestReg.isVirtual()) |
142 | return false; |
143 | |
144 | for (auto &UserOp : MRI.use_nodbg_operands(Reg: DestReg)) { |
145 | const MachineInstr *UserMI = UserOp.getParent(); |
146 | unsigned OpIdx = UserOp.getOperandNo(); |
147 | |
148 | switch (UserMI->getOpcode()) { |
149 | default: |
150 | if (vectorPseudoHasAllNBitUsers(UserOp, Bits)) |
151 | break; |
152 | return false; |
153 | |
154 | case RISCV::ADDIW: |
155 | case RISCV::ADDW: |
156 | case RISCV::DIVUW: |
157 | case RISCV::DIVW: |
158 | case RISCV::MULW: |
159 | case RISCV::REMUW: |
160 | case RISCV::REMW: |
161 | case RISCV::SLLIW: |
162 | case RISCV::SLLW: |
163 | case RISCV::SRAIW: |
164 | case RISCV::SRAW: |
165 | case RISCV::SRLIW: |
166 | case RISCV::SRLW: |
167 | case RISCV::SUBW: |
168 | case RISCV::ROLW: |
169 | case RISCV::RORW: |
170 | case RISCV::RORIW: |
171 | case RISCV::CLZW: |
172 | case RISCV::CTZW: |
173 | case RISCV::CPOPW: |
174 | case RISCV::SLLI_UW: |
175 | case RISCV::FMV_W_X: |
176 | case RISCV::FCVT_H_W: |
177 | case RISCV::FCVT_H_WU: |
178 | case RISCV::FCVT_S_W: |
179 | case RISCV::FCVT_S_WU: |
180 | case RISCV::FCVT_D_W: |
181 | case RISCV::FCVT_D_WU: |
182 | if (Bits >= 32) |
183 | break; |
184 | return false; |
185 | case RISCV::SEXT_B: |
186 | case RISCV::PACKH: |
187 | if (Bits >= 8) |
188 | break; |
189 | return false; |
190 | case RISCV::SEXT_H: |
191 | case RISCV::FMV_H_X: |
192 | case RISCV::ZEXT_H_RV32: |
193 | case RISCV::ZEXT_H_RV64: |
194 | case RISCV::PACKW: |
195 | if (Bits >= 16) |
196 | break; |
197 | return false; |
198 | |
199 | case RISCV::PACK: |
200 | if (Bits >= (ST.getXLen() / 2)) |
201 | break; |
202 | return false; |
203 | |
204 | case RISCV::SRLI: { |
205 | // If we are shifting right by less than Bits, and users don't demand |
206 | // any bits that were shifted into [Bits-1:0], then we can consider this |
207 | // as an N-Bit user. |
208 | unsigned ShAmt = UserMI->getOperand(i: 2).getImm(); |
209 | if (Bits > ShAmt) { |
210 | Worklist.push_back(Elt: std::make_pair(x&: UserMI, y: Bits - ShAmt)); |
211 | break; |
212 | } |
213 | return false; |
214 | } |
215 | |
216 | // these overwrite higher input bits, otherwise the lower word of output |
217 | // depends only on the lower word of input. So check their uses read W. |
218 | case RISCV::SLLI: |
219 | if (Bits >= (ST.getXLen() - UserMI->getOperand(i: 2).getImm())) |
220 | break; |
221 | Worklist.push_back(Elt: std::make_pair(x&: UserMI, y&: Bits)); |
222 | break; |
223 | case RISCV::ANDI: { |
224 | uint64_t Imm = UserMI->getOperand(i: 2).getImm(); |
225 | if (Bits >= (unsigned)llvm::bit_width(Value: Imm)) |
226 | break; |
227 | Worklist.push_back(Elt: std::make_pair(x&: UserMI, y&: Bits)); |
228 | break; |
229 | } |
230 | case RISCV::ORI: { |
231 | uint64_t Imm = UserMI->getOperand(i: 2).getImm(); |
232 | if (Bits >= (unsigned)llvm::bit_width<uint64_t>(Value: ~Imm)) |
233 | break; |
234 | Worklist.push_back(Elt: std::make_pair(x&: UserMI, y&: Bits)); |
235 | break; |
236 | } |
237 | |
238 | case RISCV::SLL: |
239 | case RISCV::BSET: |
240 | case RISCV::BCLR: |
241 | case RISCV::BINV: |
242 | // Operand 2 is the shift amount which uses log2(xlen) bits. |
243 | if (OpIdx == 2) { |
244 | if (Bits >= Log2_32(Value: ST.getXLen())) |
245 | break; |
246 | return false; |
247 | } |
248 | Worklist.push_back(Elt: std::make_pair(x&: UserMI, y&: Bits)); |
249 | break; |
250 | |
251 | case RISCV::SRA: |
252 | case RISCV::SRL: |
253 | case RISCV::ROL: |
254 | case RISCV::ROR: |
255 | // Operand 2 is the shift amount which uses 6 bits. |
256 | if (OpIdx == 2 && Bits >= Log2_32(Value: ST.getXLen())) |
257 | break; |
258 | return false; |
259 | |
260 | case RISCV::ADD_UW: |
261 | case RISCV::SH1ADD_UW: |
262 | case RISCV::SH2ADD_UW: |
263 | case RISCV::SH3ADD_UW: |
264 | // Operand 1 is implicitly zero extended. |
265 | if (OpIdx == 1 && Bits >= 32) |
266 | break; |
267 | Worklist.push_back(Elt: std::make_pair(x&: UserMI, y&: Bits)); |
268 | break; |
269 | |
270 | case RISCV::BEXTI: |
271 | if (UserMI->getOperand(i: 2).getImm() >= Bits) |
272 | return false; |
273 | break; |
274 | |
275 | case RISCV::SB: |
276 | // The first argument is the value to store. |
277 | if (OpIdx == 0 && Bits >= 8) |
278 | break; |
279 | return false; |
280 | case RISCV::SH: |
281 | // The first argument is the value to store. |
282 | if (OpIdx == 0 && Bits >= 16) |
283 | break; |
284 | return false; |
285 | case RISCV::SW: |
286 | // The first argument is the value to store. |
287 | if (OpIdx == 0 && Bits >= 32) |
288 | break; |
289 | return false; |
290 | |
291 | // For these, lower word of output in these operations, depends only on |
292 | // the lower word of input. So, we check all uses only read lower word. |
293 | case RISCV::COPY: |
294 | case RISCV::PHI: |
295 | |
296 | case RISCV::ADD: |
297 | case RISCV::ADDI: |
298 | case RISCV::AND: |
299 | case RISCV::MUL: |
300 | case RISCV::OR: |
301 | case RISCV::SUB: |
302 | case RISCV::XOR: |
303 | case RISCV::XORI: |
304 | |
305 | case RISCV::ANDN: |
306 | case RISCV::BREV8: |
307 | case RISCV::CLMUL: |
308 | case RISCV::ORC_B: |
309 | case RISCV::ORN: |
310 | case RISCV::SH1ADD: |
311 | case RISCV::SH2ADD: |
312 | case RISCV::SH3ADD: |
313 | case RISCV::XNOR: |
314 | case RISCV::BSETI: |
315 | case RISCV::BCLRI: |
316 | case RISCV::BINVI: |
317 | Worklist.push_back(Elt: std::make_pair(x&: UserMI, y&: Bits)); |
318 | break; |
319 | |
320 | case RISCV::PseudoCCMOVGPR: |
321 | // Either operand 4 or operand 5 is returned by this instruction. If |
322 | // only the lower word of the result is used, then only the lower word |
323 | // of operand 4 and 5 is used. |
324 | if (OpIdx != 4 && OpIdx != 5) |
325 | return false; |
326 | Worklist.push_back(Elt: std::make_pair(x&: UserMI, y&: Bits)); |
327 | break; |
328 | |
329 | case RISCV::CZERO_EQZ: |
330 | case RISCV::CZERO_NEZ: |
331 | case RISCV::VT_MASKC: |
332 | case RISCV::VT_MASKCN: |
333 | if (OpIdx != 1) |
334 | return false; |
335 | Worklist.push_back(Elt: std::make_pair(x&: UserMI, y&: Bits)); |
336 | break; |
337 | } |
338 | } |
339 | } |
340 | |
341 | return true; |
342 | } |
343 | |
344 | static bool hasAllWUsers(const MachineInstr &OrigMI, const RISCVSubtarget &ST, |
345 | const MachineRegisterInfo &MRI) { |
346 | return hasAllNBitUsers(OrigMI, ST, MRI, OrigBits: 32); |
347 | } |
348 | |
349 | // This function returns true if the machine instruction always outputs a value |
350 | // where bits 63:32 match bit 31. |
351 | static bool isSignExtendingOpW(const MachineInstr &MI, |
352 | const MachineRegisterInfo &MRI, unsigned OpNo) { |
353 | uint64_t TSFlags = MI.getDesc().TSFlags; |
354 | |
355 | // Instructions that can be determined from opcode are marked in tablegen. |
356 | if (TSFlags & RISCVII::IsSignExtendingOpWMask) |
357 | return true; |
358 | |
359 | // Special cases that require checking operands. |
360 | switch (MI.getOpcode()) { |
361 | // shifting right sufficiently makes the value 32-bit sign-extended |
362 | case RISCV::SRAI: |
363 | return MI.getOperand(i: 2).getImm() >= 32; |
364 | case RISCV::SRLI: |
365 | return MI.getOperand(i: 2).getImm() > 32; |
366 | // The LI pattern ADDI rd, X0, imm is sign extended. |
367 | case RISCV::ADDI: |
368 | return MI.getOperand(i: 1).isReg() && MI.getOperand(i: 1).getReg() == RISCV::X0; |
369 | // An ANDI with an 11 bit immediate will zero bits 63:11. |
370 | case RISCV::ANDI: |
371 | return isUInt<11>(x: MI.getOperand(i: 2).getImm()); |
372 | // An ORI with an >11 bit immediate (negative 12-bit) will set bits 63:11. |
373 | case RISCV::ORI: |
374 | return !isUInt<11>(x: MI.getOperand(i: 2).getImm()); |
375 | // A bseti with X0 is sign extended if the immediate is less than 31. |
376 | case RISCV::BSETI: |
377 | return MI.getOperand(i: 2).getImm() < 31 && |
378 | MI.getOperand(i: 1).getReg() == RISCV::X0; |
379 | // Copying from X0 produces zero. |
380 | case RISCV::COPY: |
381 | return MI.getOperand(i: 1).getReg() == RISCV::X0; |
382 | // Ignore the scratch register destination. |
383 | case RISCV::PseudoAtomicLoadNand32: |
384 | return OpNo == 0; |
385 | case RISCV::PseudoVMV_X_S: { |
386 | // vmv.x.s has at least 33 sign bits if log2(sew) <= 5. |
387 | int64_t Log2SEW = MI.getOperand(i: 2).getImm(); |
388 | assert(Log2SEW >= 3 && Log2SEW <= 6 && "Unexpected Log2SEW" ); |
389 | return Log2SEW <= 5; |
390 | } |
391 | } |
392 | |
393 | return false; |
394 | } |
395 | |
396 | static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST, |
397 | const MachineRegisterInfo &MRI, |
398 | SmallPtrSetImpl<MachineInstr *> &FixableDef) { |
399 | SmallSet<Register, 4> Visited; |
400 | SmallVector<Register, 4> Worklist; |
401 | |
402 | auto AddRegToWorkList = [&](Register SrcReg) { |
403 | if (!SrcReg.isVirtual()) |
404 | return false; |
405 | Worklist.push_back(Elt: SrcReg); |
406 | return true; |
407 | }; |
408 | |
409 | if (!AddRegToWorkList(SrcReg)) |
410 | return false; |
411 | |
412 | while (!Worklist.empty()) { |
413 | Register Reg = Worklist.pop_back_val(); |
414 | |
415 | // If we already visited this register, we don't need to check it again. |
416 | if (!Visited.insert(V: Reg).second) |
417 | continue; |
418 | |
419 | MachineInstr *MI = MRI.getVRegDef(Reg); |
420 | if (!MI) |
421 | continue; |
422 | |
423 | int OpNo = MI->findRegisterDefOperandIdx(Reg, /*TRI=*/nullptr); |
424 | assert(OpNo != -1 && "Couldn't find register" ); |
425 | |
426 | // If this is a sign extending operation we don't need to look any further. |
427 | if (isSignExtendingOpW(MI: *MI, MRI, OpNo)) |
428 | continue; |
429 | |
430 | // Is this an instruction that propagates sign extend? |
431 | switch (MI->getOpcode()) { |
432 | default: |
433 | // Unknown opcode, give up. |
434 | return false; |
435 | case RISCV::COPY: { |
436 | const MachineFunction *MF = MI->getMF(); |
437 | const RISCVMachineFunctionInfo *RVFI = |
438 | MF->getInfo<RISCVMachineFunctionInfo>(); |
439 | |
440 | // If this is the entry block and the register is livein, see if we know |
441 | // it is sign extended. |
442 | if (MI->getParent() == &MF->front()) { |
443 | Register VReg = MI->getOperand(i: 0).getReg(); |
444 | if (MF->getRegInfo().isLiveIn(Reg: VReg) && RVFI->isSExt32Register(Reg: VReg)) |
445 | continue; |
446 | } |
447 | |
448 | Register CopySrcReg = MI->getOperand(i: 1).getReg(); |
449 | if (CopySrcReg == RISCV::X10) { |
450 | // For a method return value, we check the ZExt/SExt flags in attribute. |
451 | // We assume the following code sequence for method call. |
452 | // PseudoCALL @bar, ... |
453 | // ADJCALLSTACKUP 0, 0, implicit-def dead $x2, implicit $x2 |
454 | // %0:gpr = COPY $x10 |
455 | // |
456 | // We use the PseudoCall to look up the IR function being called to find |
457 | // its return attributes. |
458 | const MachineBasicBlock *MBB = MI->getParent(); |
459 | auto II = MI->getIterator(); |
460 | if (II == MBB->instr_begin() || |
461 | (--II)->getOpcode() != RISCV::ADJCALLSTACKUP) |
462 | return false; |
463 | |
464 | const MachineInstr &CallMI = *(--II); |
465 | if (!CallMI.isCall() || !CallMI.getOperand(i: 0).isGlobal()) |
466 | return false; |
467 | |
468 | auto *CalleeFn = |
469 | dyn_cast_if_present<Function>(Val: CallMI.getOperand(i: 0).getGlobal()); |
470 | if (!CalleeFn) |
471 | return false; |
472 | |
473 | auto *IntTy = dyn_cast<IntegerType>(Val: CalleeFn->getReturnType()); |
474 | if (!IntTy) |
475 | return false; |
476 | |
477 | const AttributeSet &Attrs = CalleeFn->getAttributes().getRetAttrs(); |
478 | unsigned BitWidth = IntTy->getBitWidth(); |
479 | if ((BitWidth <= 32 && Attrs.hasAttribute(Kind: Attribute::SExt)) || |
480 | (BitWidth < 32 && Attrs.hasAttribute(Kind: Attribute::ZExt))) |
481 | continue; |
482 | } |
483 | |
484 | if (!AddRegToWorkList(CopySrcReg)) |
485 | return false; |
486 | |
487 | break; |
488 | } |
489 | |
490 | // For these, we just need to check if the 1st operand is sign extended. |
491 | case RISCV::BCLRI: |
492 | case RISCV::BINVI: |
493 | case RISCV::BSETI: |
494 | if (MI->getOperand(i: 2).getImm() >= 31) |
495 | return false; |
496 | [[fallthrough]]; |
497 | case RISCV::REM: |
498 | case RISCV::ANDI: |
499 | case RISCV::ORI: |
500 | case RISCV::XORI: |
501 | // |Remainder| is always <= |Dividend|. If D is 32-bit, then so is R. |
502 | // DIV doesn't work because of the edge case 0xf..f 8000 0000 / (long)-1 |
503 | // Logical operations use a sign extended 12-bit immediate. |
504 | if (!AddRegToWorkList(MI->getOperand(i: 1).getReg())) |
505 | return false; |
506 | |
507 | break; |
508 | case RISCV::PseudoCCADDW: |
509 | case RISCV::PseudoCCADDIW: |
510 | case RISCV::PseudoCCSUBW: |
511 | case RISCV::PseudoCCSLLW: |
512 | case RISCV::PseudoCCSRLW: |
513 | case RISCV::PseudoCCSRAW: |
514 | case RISCV::PseudoCCSLLIW: |
515 | case RISCV::PseudoCCSRLIW: |
516 | case RISCV::PseudoCCSRAIW: |
517 | // Returns operand 4 or an ADDW/SUBW/etc. of operands 5 and 6. We only |
518 | // need to check if operand 4 is sign extended. |
519 | if (!AddRegToWorkList(MI->getOperand(i: 4).getReg())) |
520 | return false; |
521 | break; |
522 | case RISCV::REMU: |
523 | case RISCV::AND: |
524 | case RISCV::OR: |
525 | case RISCV::XOR: |
526 | case RISCV::ANDN: |
527 | case RISCV::ORN: |
528 | case RISCV::XNOR: |
529 | case RISCV::MAX: |
530 | case RISCV::MAXU: |
531 | case RISCV::MIN: |
532 | case RISCV::MINU: |
533 | case RISCV::PseudoCCMOVGPR: |
534 | case RISCV::PseudoCCAND: |
535 | case RISCV::PseudoCCOR: |
536 | case RISCV::PseudoCCXOR: |
537 | case RISCV::PHI: { |
538 | // If all incoming values are sign-extended, the output of AND, OR, XOR, |
539 | // MIN, MAX, or PHI is also sign-extended. |
540 | |
541 | // The input registers for PHI are operand 1, 3, ... |
542 | // The input registers for PseudoCCMOVGPR are 4 and 5. |
543 | // The input registers for PseudoCCAND/OR/XOR are 4, 5, and 6. |
544 | // The input registers for others are operand 1 and 2. |
545 | unsigned B = 1, E = 3, D = 1; |
546 | switch (MI->getOpcode()) { |
547 | case RISCV::PHI: |
548 | E = MI->getNumOperands(); |
549 | D = 2; |
550 | break; |
551 | case RISCV::PseudoCCMOVGPR: |
552 | B = 4; |
553 | E = 6; |
554 | break; |
555 | case RISCV::PseudoCCAND: |
556 | case RISCV::PseudoCCOR: |
557 | case RISCV::PseudoCCXOR: |
558 | B = 4; |
559 | E = 7; |
560 | break; |
561 | } |
562 | |
563 | for (unsigned I = B; I != E; I += D) { |
564 | if (!MI->getOperand(i: I).isReg()) |
565 | return false; |
566 | |
567 | if (!AddRegToWorkList(MI->getOperand(i: I).getReg())) |
568 | return false; |
569 | } |
570 | |
571 | break; |
572 | } |
573 | |
574 | case RISCV::CZERO_EQZ: |
575 | case RISCV::CZERO_NEZ: |
576 | case RISCV::VT_MASKC: |
577 | case RISCV::VT_MASKCN: |
578 | // Instructions return zero or operand 1. Result is sign extended if |
579 | // operand 1 is sign extended. |
580 | if (!AddRegToWorkList(MI->getOperand(i: 1).getReg())) |
581 | return false; |
582 | break; |
583 | |
584 | // With these opcode, we can "fix" them with the W-version |
585 | // if we know all users of the result only rely on bits 31:0 |
586 | case RISCV::SLLI: |
587 | // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits |
588 | if (MI->getOperand(i: 2).getImm() >= 32) |
589 | return false; |
590 | [[fallthrough]]; |
591 | case RISCV::ADDI: |
592 | case RISCV::ADD: |
593 | case RISCV::LD: |
594 | case RISCV::LWU: |
595 | case RISCV::MUL: |
596 | case RISCV::SUB: |
597 | if (hasAllWUsers(OrigMI: *MI, ST, MRI)) { |
598 | FixableDef.insert(Ptr: MI); |
599 | break; |
600 | } |
601 | return false; |
602 | } |
603 | } |
604 | |
605 | // If we get here, then every node we visited produces a sign extended value |
606 | // or propagated sign extended values. So the result must be sign extended. |
607 | return true; |
608 | } |
609 | |
610 | static unsigned getWOp(unsigned Opcode) { |
611 | switch (Opcode) { |
612 | case RISCV::ADDI: |
613 | return RISCV::ADDIW; |
614 | case RISCV::ADD: |
615 | return RISCV::ADDW; |
616 | case RISCV::LD: |
617 | case RISCV::LWU: |
618 | return RISCV::LW; |
619 | case RISCV::MUL: |
620 | return RISCV::MULW; |
621 | case RISCV::SLLI: |
622 | return RISCV::SLLIW; |
623 | case RISCV::SUB: |
624 | return RISCV::SUBW; |
625 | default: |
626 | llvm_unreachable("Unexpected opcode for replacement with W variant" ); |
627 | } |
628 | } |
629 | |
630 | bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF, |
631 | const RISCVInstrInfo &TII, |
632 | const RISCVSubtarget &ST, |
633 | MachineRegisterInfo &MRI) { |
634 | if (DisableSExtWRemoval) |
635 | return false; |
636 | |
637 | bool MadeChange = false; |
638 | for (MachineBasicBlock &MBB : MF) { |
639 | for (MachineInstr &MI : llvm::make_early_inc_range(Range&: MBB)) { |
640 | // We're looking for the sext.w pattern ADDIW rd, rs1, 0. |
641 | if (!RISCV::isSEXT_W(MI)) |
642 | continue; |
643 | |
644 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
645 | |
646 | SmallPtrSet<MachineInstr *, 4> FixableDefs; |
647 | |
648 | // If all users only use the lower bits, this sext.w is redundant. |
649 | // Or if all definitions reaching MI sign-extend their output, |
650 | // then sext.w is redundant. |
651 | if (!hasAllWUsers(OrigMI: MI, ST, MRI) && |
652 | !isSignExtendedW(SrcReg, ST, MRI, FixableDef&: FixableDefs)) |
653 | continue; |
654 | |
655 | Register DstReg = MI.getOperand(i: 0).getReg(); |
656 | if (!MRI.constrainRegClass(Reg: SrcReg, RC: MRI.getRegClass(Reg: DstReg))) |
657 | continue; |
658 | |
659 | // Convert Fixable instructions to their W versions. |
660 | for (MachineInstr *Fixable : FixableDefs) { |
661 | LLVM_DEBUG(dbgs() << "Replacing " << *Fixable); |
662 | Fixable->setDesc(TII.get(Opcode: getWOp(Opcode: Fixable->getOpcode()))); |
663 | Fixable->clearFlag(Flag: MachineInstr::MIFlag::NoSWrap); |
664 | Fixable->clearFlag(Flag: MachineInstr::MIFlag::NoUWrap); |
665 | Fixable->clearFlag(Flag: MachineInstr::MIFlag::IsExact); |
666 | LLVM_DEBUG(dbgs() << " with " << *Fixable); |
667 | ++NumTransformedToWInstrs; |
668 | } |
669 | |
670 | LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n" ); |
671 | MRI.replaceRegWith(FromReg: DstReg, ToReg: SrcReg); |
672 | MRI.clearKillFlags(Reg: SrcReg); |
673 | MI.eraseFromParent(); |
674 | ++NumRemovedSExtW; |
675 | MadeChange = true; |
676 | } |
677 | } |
678 | |
679 | return MadeChange; |
680 | } |
681 | |
682 | bool RISCVOptWInstrs::stripWSuffixes(MachineFunction &MF, |
683 | const RISCVInstrInfo &TII, |
684 | const RISCVSubtarget &ST, |
685 | MachineRegisterInfo &MRI) { |
686 | bool MadeChange = false; |
687 | for (MachineBasicBlock &MBB : MF) { |
688 | for (MachineInstr &MI : MBB) { |
689 | unsigned Opc; |
690 | switch (MI.getOpcode()) { |
691 | default: |
692 | continue; |
693 | case RISCV::ADDW: Opc = RISCV::ADD; break; |
694 | case RISCV::ADDIW: Opc = RISCV::ADDI; break; |
695 | case RISCV::MULW: Opc = RISCV::MUL; break; |
696 | case RISCV::SLLIW: Opc = RISCV::SLLI; break; |
697 | } |
698 | |
699 | if (hasAllWUsers(OrigMI: MI, ST, MRI)) { |
700 | MI.setDesc(TII.get(Opcode: Opc)); |
701 | MadeChange = true; |
702 | } |
703 | } |
704 | } |
705 | |
706 | return MadeChange; |
707 | } |
708 | |
709 | bool RISCVOptWInstrs::appendWSuffixes(MachineFunction &MF, |
710 | const RISCVInstrInfo &TII, |
711 | const RISCVSubtarget &ST, |
712 | MachineRegisterInfo &MRI) { |
713 | bool MadeChange = false; |
714 | for (MachineBasicBlock &MBB : MF) { |
715 | for (MachineInstr &MI : MBB) { |
716 | unsigned WOpc; |
717 | // TODO: Add more? |
718 | switch (MI.getOpcode()) { |
719 | default: |
720 | continue; |
721 | case RISCV::ADD: |
722 | WOpc = RISCV::ADDW; |
723 | break; |
724 | case RISCV::ADDI: |
725 | WOpc = RISCV::ADDIW; |
726 | break; |
727 | case RISCV::SUB: |
728 | WOpc = RISCV::SUBW; |
729 | break; |
730 | case RISCV::MUL: |
731 | WOpc = RISCV::MULW; |
732 | break; |
733 | case RISCV::SLLI: |
734 | // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits |
735 | if (MI.getOperand(i: 2).getImm() >= 32) |
736 | continue; |
737 | WOpc = RISCV::SLLIW; |
738 | break; |
739 | case RISCV::LD: |
740 | case RISCV::LWU: |
741 | WOpc = RISCV::LW; |
742 | break; |
743 | } |
744 | |
745 | if (hasAllWUsers(OrigMI: MI, ST, MRI)) { |
746 | LLVM_DEBUG(dbgs() << "Replacing " << MI); |
747 | MI.setDesc(TII.get(Opcode: WOpc)); |
748 | MI.clearFlag(Flag: MachineInstr::MIFlag::NoSWrap); |
749 | MI.clearFlag(Flag: MachineInstr::MIFlag::NoUWrap); |
750 | MI.clearFlag(Flag: MachineInstr::MIFlag::IsExact); |
751 | LLVM_DEBUG(dbgs() << " with " << MI); |
752 | ++NumTransformedToWInstrs; |
753 | MadeChange = true; |
754 | } |
755 | } |
756 | } |
757 | |
758 | return MadeChange; |
759 | } |
760 | |
761 | bool RISCVOptWInstrs::runOnMachineFunction(MachineFunction &MF) { |
762 | if (skipFunction(F: MF.getFunction())) |
763 | return false; |
764 | |
765 | MachineRegisterInfo &MRI = MF.getRegInfo(); |
766 | const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>(); |
767 | const RISCVInstrInfo &TII = *ST.getInstrInfo(); |
768 | |
769 | if (!ST.is64Bit()) |
770 | return false; |
771 | |
772 | bool MadeChange = false; |
773 | MadeChange |= removeSExtWInstrs(MF, TII, ST, MRI); |
774 | |
775 | if (!(DisableStripWSuffix || ST.preferWInst())) |
776 | MadeChange |= stripWSuffixes(MF, TII, ST, MRI); |
777 | |
778 | if (ST.preferWInst()) |
779 | MadeChange |= appendWSuffixes(MF, TII, ST, MRI); |
780 | |
781 | return MadeChange; |
782 | } |
783 | |