1 | //===- RISCVOptWInstrs.cpp - MI W instruction optimizations ---------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===---------------------------------------------------------------------===// |
8 | // |
9 | // This pass does some optimizations for *W instructions at the MI level. |
10 | // |
11 | // First it removes unneeded sext.w instructions. Either because the sign |
12 | // extended bits aren't consumed or because the input was already sign extended |
13 | // by an earlier instruction. |
14 | // |
15 | // Then: |
//  1. Unless explicitly disabled or the target prefers instructions with W suffix,
17 | // it removes the -w suffix from opw instructions whenever all users are |
18 | // dependent only on the lower word of the result of the instruction. |
19 | // The cases handled are: |
20 | // * addw because c.add has a larger register encoding than c.addw. |
21 | // * addiw because it helps reduce test differences between RV32 and RV64 |
22 | // w/o being a pessimization. |
23 | // * mulw because c.mulw doesn't exist but c.mul does (w/ zcb) |
24 | // * slliw because c.slliw doesn't exist and c.slli does |
25 | // |
//  2. Or if explicitly enabled or the target prefers instructions with W suffix,
27 | // it adds the W suffix to the instruction whenever all users are dependent |
28 | // only on the lower word of the result of the instruction. |
29 | // The cases handled are: |
30 | // * add/addi/sub/mul. |
31 | // * slli with imm < 32. |
32 | // * ld/lwu. |
33 | //===---------------------------------------------------------------------===// |
34 | |
35 | #include "RISCV.h" |
36 | #include "RISCVMachineFunctionInfo.h" |
37 | #include "RISCVSubtarget.h" |
38 | #include "llvm/ADT/SmallSet.h" |
39 | #include "llvm/ADT/Statistic.h" |
40 | #include "llvm/CodeGen/MachineFunctionPass.h" |
41 | #include "llvm/CodeGen/TargetInstrInfo.h" |
42 | |
43 | using namespace llvm; |
44 | |
45 | #define DEBUG_TYPE "riscv-opt-w-instrs" |
46 | #define RISCV_OPT_W_INSTRS_NAME "RISC-V Optimize W Instructions" |
47 | |
48 | STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions" ); |
49 | STATISTIC(NumTransformedToWInstrs, |
50 | "Number of instructions transformed to W-ops" ); |
51 | |
52 | static cl::opt<bool> DisableSExtWRemoval("riscv-disable-sextw-removal" , |
53 | cl::desc("Disable removal of sext.w" ), |
54 | cl::init(Val: false), cl::Hidden); |
55 | static cl::opt<bool> DisableStripWSuffix("riscv-disable-strip-w-suffix" , |
56 | cl::desc("Disable strip W suffix" ), |
57 | cl::init(Val: false), cl::Hidden); |
58 | |
59 | namespace { |
60 | |
61 | class RISCVOptWInstrs : public MachineFunctionPass { |
62 | public: |
63 | static char ID; |
64 | |
65 | RISCVOptWInstrs() : MachineFunctionPass(ID) {} |
66 | |
67 | bool runOnMachineFunction(MachineFunction &MF) override; |
68 | bool removeSExtWInstrs(MachineFunction &MF, const RISCVInstrInfo &TII, |
69 | const RISCVSubtarget &ST, MachineRegisterInfo &MRI); |
70 | bool stripWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII, |
71 | const RISCVSubtarget &ST, MachineRegisterInfo &MRI); |
72 | bool appendWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII, |
73 | const RISCVSubtarget &ST, MachineRegisterInfo &MRI); |
74 | |
75 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
76 | AU.setPreservesCFG(); |
77 | MachineFunctionPass::getAnalysisUsage(AU); |
78 | } |
79 | |
80 | StringRef getPassName() const override { return RISCV_OPT_W_INSTRS_NAME; } |
81 | }; |
82 | |
83 | } // end anonymous namespace |
84 | |
85 | char RISCVOptWInstrs::ID = 0; |
86 | INITIALIZE_PASS(RISCVOptWInstrs, DEBUG_TYPE, RISCV_OPT_W_INSTRS_NAME, false, |
87 | false) |
88 | |
89 | FunctionPass *llvm::createRISCVOptWInstrsPass() { |
90 | return new RISCVOptWInstrs(); |
91 | } |
92 | |
93 | static bool vectorPseudoHasAllNBitUsers(const MachineOperand &UserOp, |
94 | unsigned Bits) { |
95 | const MachineInstr &MI = *UserOp.getParent(); |
96 | unsigned MCOpcode = RISCV::getRVVMCOpcode(RVVPseudoOpcode: MI.getOpcode()); |
97 | |
98 | if (!MCOpcode) |
99 | return false; |
100 | |
101 | const MCInstrDesc &MCID = MI.getDesc(); |
102 | const uint64_t TSFlags = MCID.TSFlags; |
103 | if (!RISCVII::hasSEWOp(TSFlags)) |
104 | return false; |
105 | assert(RISCVII::hasVLOp(TSFlags)); |
106 | const unsigned Log2SEW = MI.getOperand(i: RISCVII::getSEWOpNum(Desc: MCID)).getImm(); |
107 | |
108 | if (UserOp.getOperandNo() == RISCVII::getVLOpNum(Desc: MCID)) |
109 | return false; |
110 | |
111 | auto NumDemandedBits = |
112 | RISCV::getVectorLowDemandedScalarBits(Opcode: MCOpcode, Log2SEW); |
113 | return NumDemandedBits && Bits >= *NumDemandedBits; |
114 | } |
115 | |
116 | // Checks if all users only demand the lower \p OrigBits of the original |
117 | // instruction's result. |
118 | // TODO: handle multiple interdependent transformations |
119 | static bool hasAllNBitUsers(const MachineInstr &OrigMI, |
120 | const RISCVSubtarget &ST, |
121 | const MachineRegisterInfo &MRI, unsigned OrigBits) { |
122 | |
123 | SmallSet<std::pair<const MachineInstr *, unsigned>, 4> Visited; |
124 | SmallVector<std::pair<const MachineInstr *, unsigned>, 4> Worklist; |
125 | |
126 | Worklist.push_back(Elt: std::make_pair(x: &OrigMI, y&: OrigBits)); |
127 | |
128 | while (!Worklist.empty()) { |
129 | auto P = Worklist.pop_back_val(); |
130 | const MachineInstr *MI = P.first; |
131 | unsigned Bits = P.second; |
132 | |
133 | if (!Visited.insert(V: P).second) |
134 | continue; |
135 | |
136 | // Only handle instructions with one def. |
137 | if (MI->getNumExplicitDefs() != 1) |
138 | return false; |
139 | |
140 | Register DestReg = MI->getOperand(i: 0).getReg(); |
141 | if (!DestReg.isVirtual()) |
142 | return false; |
143 | |
144 | for (auto &UserOp : MRI.use_nodbg_operands(Reg: DestReg)) { |
145 | const MachineInstr *UserMI = UserOp.getParent(); |
146 | unsigned OpIdx = UserOp.getOperandNo(); |
147 | |
148 | switch (UserMI->getOpcode()) { |
149 | default: |
150 | if (vectorPseudoHasAllNBitUsers(UserOp, Bits)) |
151 | break; |
152 | return false; |
153 | |
154 | case RISCV::ADDIW: |
155 | case RISCV::ADDW: |
156 | case RISCV::DIVUW: |
157 | case RISCV::DIVW: |
158 | case RISCV::MULW: |
159 | case RISCV::REMUW: |
160 | case RISCV::REMW: |
161 | case RISCV::SLLIW: |
162 | case RISCV::SLLW: |
163 | case RISCV::SRAIW: |
164 | case RISCV::SRAW: |
165 | case RISCV::SRLIW: |
166 | case RISCV::SRLW: |
167 | case RISCV::SUBW: |
168 | case RISCV::ROLW: |
169 | case RISCV::RORW: |
170 | case RISCV::RORIW: |
171 | case RISCV::CLZW: |
172 | case RISCV::CTZW: |
173 | case RISCV::CPOPW: |
174 | case RISCV::SLLI_UW: |
175 | case RISCV::FMV_W_X: |
176 | case RISCV::FCVT_H_W: |
177 | case RISCV::FCVT_H_W_INX: |
178 | case RISCV::FCVT_H_WU: |
179 | case RISCV::FCVT_H_WU_INX: |
180 | case RISCV::FCVT_S_W: |
181 | case RISCV::FCVT_S_W_INX: |
182 | case RISCV::FCVT_S_WU: |
183 | case RISCV::FCVT_S_WU_INX: |
184 | case RISCV::FCVT_D_W: |
185 | case RISCV::FCVT_D_W_INX: |
186 | case RISCV::FCVT_D_WU: |
187 | case RISCV::FCVT_D_WU_INX: |
188 | if (Bits >= 32) |
189 | break; |
190 | return false; |
191 | case RISCV::SEXT_B: |
192 | case RISCV::PACKH: |
193 | if (Bits >= 8) |
194 | break; |
195 | return false; |
196 | case RISCV::SEXT_H: |
197 | case RISCV::FMV_H_X: |
198 | case RISCV::ZEXT_H_RV32: |
199 | case RISCV::ZEXT_H_RV64: |
200 | case RISCV::PACKW: |
201 | if (Bits >= 16) |
202 | break; |
203 | return false; |
204 | |
205 | case RISCV::PACK: |
206 | if (Bits >= (ST.getXLen() / 2)) |
207 | break; |
208 | return false; |
209 | |
210 | case RISCV::SRLI: { |
211 | // If we are shifting right by less than Bits, and users don't demand |
212 | // any bits that were shifted into [Bits-1:0], then we can consider this |
213 | // as an N-Bit user. |
214 | unsigned ShAmt = UserMI->getOperand(i: 2).getImm(); |
215 | if (Bits > ShAmt) { |
216 | Worklist.push_back(Elt: std::make_pair(x&: UserMI, y: Bits - ShAmt)); |
217 | break; |
218 | } |
219 | return false; |
220 | } |
221 | |
222 | // these overwrite higher input bits, otherwise the lower word of output |
223 | // depends only on the lower word of input. So check their uses read W. |
224 | case RISCV::SLLI: { |
225 | unsigned ShAmt = UserMI->getOperand(i: 2).getImm(); |
226 | if (Bits >= (ST.getXLen() - ShAmt)) |
227 | break; |
228 | Worklist.push_back(Elt: std::make_pair(x&: UserMI, y: Bits + ShAmt)); |
229 | break; |
230 | } |
231 | case RISCV::ANDI: { |
232 | uint64_t Imm = UserMI->getOperand(i: 2).getImm(); |
233 | if (Bits >= (unsigned)llvm::bit_width(Value: Imm)) |
234 | break; |
235 | Worklist.push_back(Elt: std::make_pair(x&: UserMI, y&: Bits)); |
236 | break; |
237 | } |
238 | case RISCV::ORI: { |
239 | uint64_t Imm = UserMI->getOperand(i: 2).getImm(); |
240 | if (Bits >= (unsigned)llvm::bit_width<uint64_t>(Value: ~Imm)) |
241 | break; |
242 | Worklist.push_back(Elt: std::make_pair(x&: UserMI, y&: Bits)); |
243 | break; |
244 | } |
245 | |
246 | case RISCV::SLL: |
247 | case RISCV::BSET: |
248 | case RISCV::BCLR: |
249 | case RISCV::BINV: |
250 | // Operand 2 is the shift amount which uses log2(xlen) bits. |
251 | if (OpIdx == 2) { |
252 | if (Bits >= Log2_32(Value: ST.getXLen())) |
253 | break; |
254 | return false; |
255 | } |
256 | Worklist.push_back(Elt: std::make_pair(x&: UserMI, y&: Bits)); |
257 | break; |
258 | |
259 | case RISCV::SRA: |
260 | case RISCV::SRL: |
261 | case RISCV::ROL: |
262 | case RISCV::ROR: |
263 | // Operand 2 is the shift amount which uses 6 bits. |
264 | if (OpIdx == 2 && Bits >= Log2_32(Value: ST.getXLen())) |
265 | break; |
266 | return false; |
267 | |
268 | case RISCV::ADD_UW: |
269 | case RISCV::SH1ADD_UW: |
270 | case RISCV::SH2ADD_UW: |
271 | case RISCV::SH3ADD_UW: |
272 | // Operand 1 is implicitly zero extended. |
273 | if (OpIdx == 1 && Bits >= 32) |
274 | break; |
275 | Worklist.push_back(Elt: std::make_pair(x&: UserMI, y&: Bits)); |
276 | break; |
277 | |
278 | case RISCV::BEXTI: |
279 | if (UserMI->getOperand(i: 2).getImm() >= Bits) |
280 | return false; |
281 | break; |
282 | |
283 | case RISCV::SB: |
284 | // The first argument is the value to store. |
285 | if (OpIdx == 0 && Bits >= 8) |
286 | break; |
287 | return false; |
288 | case RISCV::SH: |
289 | // The first argument is the value to store. |
290 | if (OpIdx == 0 && Bits >= 16) |
291 | break; |
292 | return false; |
293 | case RISCV::SW: |
294 | // The first argument is the value to store. |
295 | if (OpIdx == 0 && Bits >= 32) |
296 | break; |
297 | return false; |
298 | |
299 | // For these, lower word of output in these operations, depends only on |
300 | // the lower word of input. So, we check all uses only read lower word. |
301 | case RISCV::COPY: |
302 | case RISCV::PHI: |
303 | |
304 | case RISCV::ADD: |
305 | case RISCV::ADDI: |
306 | case RISCV::AND: |
307 | case RISCV::MUL: |
308 | case RISCV::OR: |
309 | case RISCV::SUB: |
310 | case RISCV::XOR: |
311 | case RISCV::XORI: |
312 | |
313 | case RISCV::ANDN: |
314 | case RISCV::BREV8: |
315 | case RISCV::CLMUL: |
316 | case RISCV::ORC_B: |
317 | case RISCV::ORN: |
318 | case RISCV::SH1ADD: |
319 | case RISCV::SH2ADD: |
320 | case RISCV::SH3ADD: |
321 | case RISCV::XNOR: |
322 | case RISCV::BSETI: |
323 | case RISCV::BCLRI: |
324 | case RISCV::BINVI: |
325 | Worklist.push_back(Elt: std::make_pair(x&: UserMI, y&: Bits)); |
326 | break; |
327 | |
328 | case RISCV::PseudoCCMOVGPR: |
329 | case RISCV::PseudoCCMOVGPRNoX0: |
330 | // Either operand 4 or operand 5 is returned by this instruction. If |
331 | // only the lower word of the result is used, then only the lower word |
332 | // of operand 4 and 5 is used. |
333 | if (OpIdx != 4 && OpIdx != 5) |
334 | return false; |
335 | Worklist.push_back(Elt: std::make_pair(x&: UserMI, y&: Bits)); |
336 | break; |
337 | |
338 | case RISCV::CZERO_EQZ: |
339 | case RISCV::CZERO_NEZ: |
340 | case RISCV::VT_MASKC: |
341 | case RISCV::VT_MASKCN: |
342 | if (OpIdx != 1) |
343 | return false; |
344 | Worklist.push_back(Elt: std::make_pair(x&: UserMI, y&: Bits)); |
345 | break; |
346 | } |
347 | } |
348 | } |
349 | |
350 | return true; |
351 | } |
352 | |
353 | static bool hasAllWUsers(const MachineInstr &OrigMI, const RISCVSubtarget &ST, |
354 | const MachineRegisterInfo &MRI) { |
355 | return hasAllNBitUsers(OrigMI, ST, MRI, OrigBits: 32); |
356 | } |
357 | |
358 | // This function returns true if the machine instruction always outputs a value |
359 | // where bits 63:32 match bit 31. |
360 | static bool isSignExtendingOpW(const MachineInstr &MI, unsigned OpNo) { |
361 | uint64_t TSFlags = MI.getDesc().TSFlags; |
362 | |
363 | // Instructions that can be determined from opcode are marked in tablegen. |
364 | if (TSFlags & RISCVII::IsSignExtendingOpWMask) |
365 | return true; |
366 | |
367 | // Special cases that require checking operands. |
368 | switch (MI.getOpcode()) { |
369 | // shifting right sufficiently makes the value 32-bit sign-extended |
370 | case RISCV::SRAI: |
371 | return MI.getOperand(i: 2).getImm() >= 32; |
372 | case RISCV::SRLI: |
373 | return MI.getOperand(i: 2).getImm() > 32; |
374 | // The LI pattern ADDI rd, X0, imm is sign extended. |
375 | case RISCV::ADDI: |
376 | return MI.getOperand(i: 1).isReg() && MI.getOperand(i: 1).getReg() == RISCV::X0; |
377 | // An ANDI with an 11 bit immediate will zero bits 63:11. |
378 | case RISCV::ANDI: |
379 | return isUInt<11>(x: MI.getOperand(i: 2).getImm()); |
380 | // An ORI with an >11 bit immediate (negative 12-bit) will set bits 63:11. |
381 | case RISCV::ORI: |
382 | return !isUInt<11>(x: MI.getOperand(i: 2).getImm()); |
383 | // A bseti with X0 is sign extended if the immediate is less than 31. |
384 | case RISCV::BSETI: |
385 | return MI.getOperand(i: 2).getImm() < 31 && |
386 | MI.getOperand(i: 1).getReg() == RISCV::X0; |
387 | // Copying from X0 produces zero. |
388 | case RISCV::COPY: |
389 | return MI.getOperand(i: 1).getReg() == RISCV::X0; |
390 | // Ignore the scratch register destination. |
391 | case RISCV::PseudoAtomicLoadNand32: |
392 | return OpNo == 0; |
393 | case RISCV::PseudoVMV_X_S: { |
394 | // vmv.x.s has at least 33 sign bits if log2(sew) <= 5. |
395 | int64_t Log2SEW = MI.getOperand(i: 2).getImm(); |
396 | assert(Log2SEW >= 3 && Log2SEW <= 6 && "Unexpected Log2SEW" ); |
397 | return Log2SEW <= 5; |
398 | } |
399 | } |
400 | |
401 | return false; |
402 | } |
403 | |
404 | static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST, |
405 | const MachineRegisterInfo &MRI, |
406 | SmallPtrSetImpl<MachineInstr *> &FixableDef) { |
407 | SmallSet<Register, 4> Visited; |
408 | SmallVector<Register, 4> Worklist; |
409 | |
410 | auto AddRegToWorkList = [&](Register SrcReg) { |
411 | if (!SrcReg.isVirtual()) |
412 | return false; |
413 | Worklist.push_back(Elt: SrcReg); |
414 | return true; |
415 | }; |
416 | |
417 | if (!AddRegToWorkList(SrcReg)) |
418 | return false; |
419 | |
420 | while (!Worklist.empty()) { |
421 | Register Reg = Worklist.pop_back_val(); |
422 | |
423 | // If we already visited this register, we don't need to check it again. |
424 | if (!Visited.insert(V: Reg).second) |
425 | continue; |
426 | |
427 | MachineInstr *MI = MRI.getVRegDef(Reg); |
428 | if (!MI) |
429 | continue; |
430 | |
431 | int OpNo = MI->findRegisterDefOperandIdx(Reg, /*TRI=*/nullptr); |
432 | assert(OpNo != -1 && "Couldn't find register" ); |
433 | |
434 | // If this is a sign extending operation we don't need to look any further. |
435 | if (isSignExtendingOpW(MI: *MI, OpNo)) |
436 | continue; |
437 | |
438 | // Is this an instruction that propagates sign extend? |
439 | switch (MI->getOpcode()) { |
440 | default: |
441 | // Unknown opcode, give up. |
442 | return false; |
443 | case RISCV::COPY: { |
444 | const MachineFunction *MF = MI->getMF(); |
445 | const RISCVMachineFunctionInfo *RVFI = |
446 | MF->getInfo<RISCVMachineFunctionInfo>(); |
447 | |
448 | // If this is the entry block and the register is livein, see if we know |
449 | // it is sign extended. |
450 | if (MI->getParent() == &MF->front()) { |
451 | Register VReg = MI->getOperand(i: 0).getReg(); |
452 | if (MF->getRegInfo().isLiveIn(Reg: VReg) && RVFI->isSExt32Register(Reg: VReg)) |
453 | continue; |
454 | } |
455 | |
456 | Register CopySrcReg = MI->getOperand(i: 1).getReg(); |
457 | if (CopySrcReg == RISCV::X10) { |
458 | // For a method return value, we check the ZExt/SExt flags in attribute. |
459 | // We assume the following code sequence for method call. |
460 | // PseudoCALL @bar, ... |
461 | // ADJCALLSTACKUP 0, 0, implicit-def dead $x2, implicit $x2 |
462 | // %0:gpr = COPY $x10 |
463 | // |
464 | // We use the PseudoCall to look up the IR function being called to find |
465 | // its return attributes. |
466 | const MachineBasicBlock *MBB = MI->getParent(); |
467 | auto II = MI->getIterator(); |
468 | if (II == MBB->instr_begin() || |
469 | (--II)->getOpcode() != RISCV::ADJCALLSTACKUP) |
470 | return false; |
471 | |
472 | const MachineInstr &CallMI = *(--II); |
473 | if (!CallMI.isCall() || !CallMI.getOperand(i: 0).isGlobal()) |
474 | return false; |
475 | |
476 | auto *CalleeFn = |
477 | dyn_cast_if_present<Function>(Val: CallMI.getOperand(i: 0).getGlobal()); |
478 | if (!CalleeFn) |
479 | return false; |
480 | |
481 | auto *IntTy = dyn_cast<IntegerType>(Val: CalleeFn->getReturnType()); |
482 | if (!IntTy) |
483 | return false; |
484 | |
485 | const AttributeSet &Attrs = CalleeFn->getAttributes().getRetAttrs(); |
486 | unsigned BitWidth = IntTy->getBitWidth(); |
487 | if ((BitWidth <= 32 && Attrs.hasAttribute(Kind: Attribute::SExt)) || |
488 | (BitWidth < 32 && Attrs.hasAttribute(Kind: Attribute::ZExt))) |
489 | continue; |
490 | } |
491 | |
492 | if (!AddRegToWorkList(CopySrcReg)) |
493 | return false; |
494 | |
495 | break; |
496 | } |
497 | |
498 | // For these, we just need to check if the 1st operand is sign extended. |
499 | case RISCV::BCLRI: |
500 | case RISCV::BINVI: |
501 | case RISCV::BSETI: |
502 | if (MI->getOperand(i: 2).getImm() >= 31) |
503 | return false; |
504 | [[fallthrough]]; |
505 | case RISCV::REM: |
506 | case RISCV::ANDI: |
507 | case RISCV::ORI: |
508 | case RISCV::XORI: |
509 | // |Remainder| is always <= |Dividend|. If D is 32-bit, then so is R. |
510 | // DIV doesn't work because of the edge case 0xf..f 8000 0000 / (long)-1 |
511 | // Logical operations use a sign extended 12-bit immediate. |
512 | if (!AddRegToWorkList(MI->getOperand(i: 1).getReg())) |
513 | return false; |
514 | |
515 | break; |
516 | case RISCV::PseudoCCADDW: |
517 | case RISCV::PseudoCCADDIW: |
518 | case RISCV::PseudoCCSUBW: |
519 | case RISCV::PseudoCCSLLW: |
520 | case RISCV::PseudoCCSRLW: |
521 | case RISCV::PseudoCCSRAW: |
522 | case RISCV::PseudoCCSLLIW: |
523 | case RISCV::PseudoCCSRLIW: |
524 | case RISCV::PseudoCCSRAIW: |
525 | // Returns operand 4 or an ADDW/SUBW/etc. of operands 5 and 6. We only |
526 | // need to check if operand 4 is sign extended. |
527 | if (!AddRegToWorkList(MI->getOperand(i: 4).getReg())) |
528 | return false; |
529 | break; |
530 | case RISCV::REMU: |
531 | case RISCV::AND: |
532 | case RISCV::OR: |
533 | case RISCV::XOR: |
534 | case RISCV::ANDN: |
535 | case RISCV::ORN: |
536 | case RISCV::XNOR: |
537 | case RISCV::MAX: |
538 | case RISCV::MAXU: |
539 | case RISCV::MIN: |
540 | case RISCV::MINU: |
541 | case RISCV::PseudoCCMOVGPR: |
542 | case RISCV::PseudoCCMOVGPRNoX0: |
543 | case RISCV::PseudoCCAND: |
544 | case RISCV::PseudoCCOR: |
545 | case RISCV::PseudoCCXOR: |
546 | case RISCV::PHI: { |
547 | // If all incoming values are sign-extended, the output of AND, OR, XOR, |
548 | // MIN, MAX, or PHI is also sign-extended. |
549 | |
550 | // The input registers for PHI are operand 1, 3, ... |
551 | // The input registers for PseudoCCMOVGPR(NoX0) are 4 and 5. |
552 | // The input registers for PseudoCCAND/OR/XOR are 4, 5, and 6. |
553 | // The input registers for others are operand 1 and 2. |
554 | unsigned B = 1, E = 3, D = 1; |
555 | switch (MI->getOpcode()) { |
556 | case RISCV::PHI: |
557 | E = MI->getNumOperands(); |
558 | D = 2; |
559 | break; |
560 | case RISCV::PseudoCCMOVGPR: |
561 | case RISCV::PseudoCCMOVGPRNoX0: |
562 | B = 4; |
563 | E = 6; |
564 | break; |
565 | case RISCV::PseudoCCAND: |
566 | case RISCV::PseudoCCOR: |
567 | case RISCV::PseudoCCXOR: |
568 | B = 4; |
569 | E = 7; |
570 | break; |
571 | } |
572 | |
573 | for (unsigned I = B; I != E; I += D) { |
574 | if (!MI->getOperand(i: I).isReg()) |
575 | return false; |
576 | |
577 | if (!AddRegToWorkList(MI->getOperand(i: I).getReg())) |
578 | return false; |
579 | } |
580 | |
581 | break; |
582 | } |
583 | |
584 | case RISCV::CZERO_EQZ: |
585 | case RISCV::CZERO_NEZ: |
586 | case RISCV::VT_MASKC: |
587 | case RISCV::VT_MASKCN: |
588 | // Instructions return zero or operand 1. Result is sign extended if |
589 | // operand 1 is sign extended. |
590 | if (!AddRegToWorkList(MI->getOperand(i: 1).getReg())) |
591 | return false; |
592 | break; |
593 | |
594 | case RISCV::ADDI: { |
595 | if (MI->getOperand(i: 1).isReg() && MI->getOperand(i: 1).getReg().isVirtual()) { |
596 | if (MachineInstr *SrcMI = MRI.getVRegDef(Reg: MI->getOperand(i: 1).getReg())) { |
597 | if (SrcMI->getOpcode() == RISCV::LUI && |
598 | SrcMI->getOperand(i: 1).isImm()) { |
599 | uint64_t Imm = SrcMI->getOperand(i: 1).getImm(); |
600 | Imm = SignExtend64<32>(x: Imm << 12); |
601 | Imm += (uint64_t)MI->getOperand(i: 2).getImm(); |
602 | if (isInt<32>(x: Imm)) |
603 | continue; |
604 | } |
605 | } |
606 | } |
607 | |
608 | if (hasAllWUsers(OrigMI: *MI, ST, MRI)) { |
609 | FixableDef.insert(Ptr: MI); |
610 | break; |
611 | } |
612 | return false; |
613 | } |
614 | |
615 | // With these opcode, we can "fix" them with the W-version |
616 | // if we know all users of the result only rely on bits 31:0 |
617 | case RISCV::SLLI: |
618 | // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits |
619 | if (MI->getOperand(i: 2).getImm() >= 32) |
620 | return false; |
621 | [[fallthrough]]; |
622 | case RISCV::ADD: |
623 | case RISCV::LD: |
624 | case RISCV::LWU: |
625 | case RISCV::MUL: |
626 | case RISCV::SUB: |
627 | if (hasAllWUsers(OrigMI: *MI, ST, MRI)) { |
628 | FixableDef.insert(Ptr: MI); |
629 | break; |
630 | } |
631 | return false; |
632 | } |
633 | } |
634 | |
635 | // If we get here, then every node we visited produces a sign extended value |
636 | // or propagated sign extended values. So the result must be sign extended. |
637 | return true; |
638 | } |
639 | |
640 | static unsigned getWOp(unsigned Opcode) { |
641 | switch (Opcode) { |
642 | case RISCV::ADDI: |
643 | return RISCV::ADDIW; |
644 | case RISCV::ADD: |
645 | return RISCV::ADDW; |
646 | case RISCV::LD: |
647 | case RISCV::LWU: |
648 | return RISCV::LW; |
649 | case RISCV::MUL: |
650 | return RISCV::MULW; |
651 | case RISCV::SLLI: |
652 | return RISCV::SLLIW; |
653 | case RISCV::SUB: |
654 | return RISCV::SUBW; |
655 | default: |
656 | llvm_unreachable("Unexpected opcode for replacement with W variant" ); |
657 | } |
658 | } |
659 | |
660 | bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF, |
661 | const RISCVInstrInfo &TII, |
662 | const RISCVSubtarget &ST, |
663 | MachineRegisterInfo &MRI) { |
664 | if (DisableSExtWRemoval) |
665 | return false; |
666 | |
667 | bool MadeChange = false; |
668 | for (MachineBasicBlock &MBB : MF) { |
669 | for (MachineInstr &MI : llvm::make_early_inc_range(Range&: MBB)) { |
670 | // We're looking for the sext.w pattern ADDIW rd, rs1, 0. |
671 | if (!RISCVInstrInfo::isSEXT_W(MI)) |
672 | continue; |
673 | |
674 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
675 | |
676 | SmallPtrSet<MachineInstr *, 4> FixableDefs; |
677 | |
678 | // If all users only use the lower bits, this sext.w is redundant. |
679 | // Or if all definitions reaching MI sign-extend their output, |
680 | // then sext.w is redundant. |
681 | if (!hasAllWUsers(OrigMI: MI, ST, MRI) && |
682 | !isSignExtendedW(SrcReg, ST, MRI, FixableDef&: FixableDefs)) |
683 | continue; |
684 | |
685 | Register DstReg = MI.getOperand(i: 0).getReg(); |
686 | if (!MRI.constrainRegClass(Reg: SrcReg, RC: MRI.getRegClass(Reg: DstReg))) |
687 | continue; |
688 | |
689 | // Convert Fixable instructions to their W versions. |
690 | for (MachineInstr *Fixable : FixableDefs) { |
691 | LLVM_DEBUG(dbgs() << "Replacing " << *Fixable); |
692 | Fixable->setDesc(TII.get(Opcode: getWOp(Opcode: Fixable->getOpcode()))); |
693 | Fixable->clearFlag(Flag: MachineInstr::MIFlag::NoSWrap); |
694 | Fixable->clearFlag(Flag: MachineInstr::MIFlag::NoUWrap); |
695 | Fixable->clearFlag(Flag: MachineInstr::MIFlag::IsExact); |
696 | LLVM_DEBUG(dbgs() << " with " << *Fixable); |
697 | ++NumTransformedToWInstrs; |
698 | } |
699 | |
700 | LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n" ); |
701 | MRI.replaceRegWith(FromReg: DstReg, ToReg: SrcReg); |
702 | MRI.clearKillFlags(Reg: SrcReg); |
703 | MI.eraseFromParent(); |
704 | ++NumRemovedSExtW; |
705 | MadeChange = true; |
706 | } |
707 | } |
708 | |
709 | return MadeChange; |
710 | } |
711 | |
712 | bool RISCVOptWInstrs::stripWSuffixes(MachineFunction &MF, |
713 | const RISCVInstrInfo &TII, |
714 | const RISCVSubtarget &ST, |
715 | MachineRegisterInfo &MRI) { |
716 | bool MadeChange = false; |
717 | for (MachineBasicBlock &MBB : MF) { |
718 | for (MachineInstr &MI : MBB) { |
719 | unsigned Opc; |
720 | switch (MI.getOpcode()) { |
721 | default: |
722 | continue; |
723 | case RISCV::ADDW: Opc = RISCV::ADD; break; |
724 | case RISCV::ADDIW: Opc = RISCV::ADDI; break; |
725 | case RISCV::MULW: Opc = RISCV::MUL; break; |
726 | case RISCV::SLLIW: Opc = RISCV::SLLI; break; |
727 | } |
728 | |
729 | if (hasAllWUsers(OrigMI: MI, ST, MRI)) { |
730 | MI.setDesc(TII.get(Opcode: Opc)); |
731 | MadeChange = true; |
732 | } |
733 | } |
734 | } |
735 | |
736 | return MadeChange; |
737 | } |
738 | |
739 | bool RISCVOptWInstrs::appendWSuffixes(MachineFunction &MF, |
740 | const RISCVInstrInfo &TII, |
741 | const RISCVSubtarget &ST, |
742 | MachineRegisterInfo &MRI) { |
743 | bool MadeChange = false; |
744 | for (MachineBasicBlock &MBB : MF) { |
745 | for (MachineInstr &MI : MBB) { |
746 | unsigned WOpc; |
747 | // TODO: Add more? |
748 | switch (MI.getOpcode()) { |
749 | default: |
750 | continue; |
751 | case RISCV::ADD: |
752 | WOpc = RISCV::ADDW; |
753 | break; |
754 | case RISCV::ADDI: |
755 | WOpc = RISCV::ADDIW; |
756 | break; |
757 | case RISCV::SUB: |
758 | WOpc = RISCV::SUBW; |
759 | break; |
760 | case RISCV::MUL: |
761 | WOpc = RISCV::MULW; |
762 | break; |
763 | case RISCV::SLLI: |
764 | // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits |
765 | if (MI.getOperand(i: 2).getImm() >= 32) |
766 | continue; |
767 | WOpc = RISCV::SLLIW; |
768 | break; |
769 | case RISCV::LD: |
770 | case RISCV::LWU: |
771 | WOpc = RISCV::LW; |
772 | break; |
773 | } |
774 | |
775 | if (hasAllWUsers(OrigMI: MI, ST, MRI)) { |
776 | LLVM_DEBUG(dbgs() << "Replacing " << MI); |
777 | MI.setDesc(TII.get(Opcode: WOpc)); |
778 | MI.clearFlag(Flag: MachineInstr::MIFlag::NoSWrap); |
779 | MI.clearFlag(Flag: MachineInstr::MIFlag::NoUWrap); |
780 | MI.clearFlag(Flag: MachineInstr::MIFlag::IsExact); |
781 | LLVM_DEBUG(dbgs() << " with " << MI); |
782 | ++NumTransformedToWInstrs; |
783 | MadeChange = true; |
784 | } |
785 | } |
786 | } |
787 | |
788 | return MadeChange; |
789 | } |
790 | |
791 | bool RISCVOptWInstrs::runOnMachineFunction(MachineFunction &MF) { |
792 | if (skipFunction(F: MF.getFunction())) |
793 | return false; |
794 | |
795 | MachineRegisterInfo &MRI = MF.getRegInfo(); |
796 | const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>(); |
797 | const RISCVInstrInfo &TII = *ST.getInstrInfo(); |
798 | |
799 | if (!ST.is64Bit()) |
800 | return false; |
801 | |
802 | bool MadeChange = false; |
803 | MadeChange |= removeSExtWInstrs(MF, TII, ST, MRI); |
804 | |
805 | if (!(DisableStripWSuffix || ST.preferWInst())) |
806 | MadeChange |= stripWSuffixes(MF, TII, ST, MRI); |
807 | |
808 | if (ST.preferWInst()) |
809 | MadeChange |= appendWSuffixes(MF, TII, ST, MRI); |
810 | |
811 | return MadeChange; |
812 | } |
813 | |