//===- RISCVFoldMemOffset.cpp - Fold ADDI into memory offsets ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Look for ADDIs that can be removed by folding their immediate into later
// load/store addresses. There may be other arithmetic instructions between the
// ADDI and the load/store that we need to reassociate through. If the final
// result of the arithmetic is only used by load/store addresses, we can fold
// the offset into all of the loads/stores as long as it doesn't create an
// offset that is too large.
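//
// For example (a hypothetical illustration, register names chosen freely):
//
//   addi   a1, a0, 8
//   sh2add a2, a3, a1
//   lw     a4, 0(a2)
//
// can be rewritten as
//
//   sh2add a2, a3, a0
//   lw     a4, 8(a2)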
//
//===----------------------------------------------------------------------===//

#include "RISCV.h"
#include "RISCVSubtarget.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include <optional>
#include <queue>

using namespace llvm;

#define DEBUG_TYPE "riscv-fold-mem-offset"
#define RISCV_FOLD_MEM_OFFSET_NAME "RISC-V Fold Memory Offset"

namespace {

class RISCVFoldMemOffset : public MachineFunctionPass {
public:
  static char ID;

  RISCVFoldMemOffset() : MachineFunctionPass(ID) {}

  bool runOnMachineFunction(MachineFunction &MF) override;

  bool foldOffset(Register OrigReg, int64_t InitialOffset,
                  const MachineRegisterInfo &MRI,
                  DenseMap<MachineInstr *, int64_t> &FoldableInstrs);

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  StringRef getPassName() const override { return RISCV_FOLD_MEM_OFFSET_NAME; }
};

// Wrapper class around a std::optional to allow accumulation.
class FoldableOffset {
  std::optional<int64_t> Offset;

public:
  bool hasValue() const { return Offset.has_value(); }
  int64_t getValue() const { return *Offset; }

  FoldableOffset &operator=(int64_t RHS) {
    Offset = RHS;
    return *this;
  }

  FoldableOffset &operator+=(int64_t RHS) {
    if (!Offset)
      Offset = 0;
    // Accumulate in unsigned arithmetic to avoid UB on signed overflow.
    Offset = (uint64_t)*Offset + (uint64_t)RHS;
    return *this;
  }

  int64_t operator*() { return *Offset; }
};

} // end anonymous namespace

char RISCVFoldMemOffset::ID = 0;
INITIALIZE_PASS(RISCVFoldMemOffset, DEBUG_TYPE, RISCV_FOLD_MEM_OFFSET_NAME,
                false, false)

FunctionPass *llvm::createRISCVFoldMemOffsetPass() {
  return new RISCVFoldMemOffset();
}

// Walk forward from the ADDI looking for arithmetic instructions we can
// analyze or memory instructions that use it as part of their address
// calculation. For each arithmetic instruction, we look up how the offset
// contributes to the value in each input register and use that information
// to calculate its contribution to the output of the instruction.
// Only addition and left shift are supported.
// FIXME: Add multiplication by constant. The constant will be in a register.
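//
// For instance (a hypothetical SSA snippet), starting from "%1 = ADDI %0, 8"
// with an initial contribution of 8:
//
//   %2 = SLLI %1, 2     ; contribution(%2) = 8 << 2 = 32
//   %3 = ADD %2, %4     ; contribution(%3) = 32
//   %5 = LW %3, 0       ; foldable: new immediate = 0 + 32 = 32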
bool RISCVFoldMemOffset::foldOffset(
    Register OrigReg, int64_t InitialOffset, const MachineRegisterInfo &MRI,
    DenseMap<MachineInstr *, int64_t> &FoldableInstrs) {
  // Map to hold how much the offset contributes to the value of this register.
  DenseMap<Register, int64_t> RegToOffsetMap;

  // Insert root offset into the map.
  RegToOffsetMap[OrigReg] = InitialOffset;

  std::queue<Register> Worklist;
  Worklist.push(OrigReg);

  while (!Worklist.empty()) {
    Register Reg = Worklist.front();
    Worklist.pop();

    if (!Reg.isVirtual())
      return false;

    for (auto &User : MRI.use_nodbg_instructions(Reg)) {
      FoldableOffset Offset;

      switch (User.getOpcode()) {
      default:
        return false;
      case RISCV::ADD:
        if (auto I = RegToOffsetMap.find(User.getOperand(1).getReg());
            I != RegToOffsetMap.end())
          Offset = I->second;
        if (auto I = RegToOffsetMap.find(User.getOperand(2).getReg());
            I != RegToOffsetMap.end())
          Offset += I->second;
        break;
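      // SHxADD rd, rs1, rs2 computes (rs1 << x) + rs2, so a contribution
      // reaching rs1 (operand 1) is scaled by 1 << x, while a contribution
      // reaching rs2 (operand 2) passes through unscaled.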
      case RISCV::SH1ADD:
        if (auto I = RegToOffsetMap.find(User.getOperand(1).getReg());
            I != RegToOffsetMap.end())
          Offset = (uint64_t)I->second << 1;
        if (auto I = RegToOffsetMap.find(User.getOperand(2).getReg());
            I != RegToOffsetMap.end())
          Offset += I->second;
        break;
      case RISCV::SH2ADD:
        if (auto I = RegToOffsetMap.find(User.getOperand(1).getReg());
            I != RegToOffsetMap.end())
          Offset = (uint64_t)I->second << 2;
        if (auto I = RegToOffsetMap.find(User.getOperand(2).getReg());
            I != RegToOffsetMap.end())
          Offset += I->second;
        break;
      case RISCV::SH3ADD:
        if (auto I = RegToOffsetMap.find(User.getOperand(1).getReg());
            I != RegToOffsetMap.end())
          Offset = (uint64_t)I->second << 3;
        if (auto I = RegToOffsetMap.find(User.getOperand(2).getReg());
            I != RegToOffsetMap.end())
          Offset += I->second;
        break;
      case RISCV::ADD_UW:
      case RISCV::SH1ADD_UW:
      case RISCV::SH2ADD_UW:
      case RISCV::SH3ADD_UW:
        // Don't fold through the zero extended input: operand 1 is truncated
        // to 32 bits and zero extended before the add, so an offset applied
        // to it doesn't propagate linearly to the result.
        if (User.getOperand(1).getReg() == Reg)
          return false;
        if (auto I = RegToOffsetMap.find(User.getOperand(2).getReg());
            I != RegToOffsetMap.end())
          Offset = I->second;
        break;
      case RISCV::SLLI: {
        unsigned ShAmt = User.getOperand(2).getImm();
        if (auto I = RegToOffsetMap.find(User.getOperand(1).getReg());
            I != RegToOffsetMap.end())
          Offset = (uint64_t)I->second << ShAmt;
        break;
      }
      case RISCV::LB:
      case RISCV::LBU:
      case RISCV::SB:
      case RISCV::LH:
      case RISCV::LH_INX:
      case RISCV::LHU:
      case RISCV::FLH:
      case RISCV::SH:
      case RISCV::SH_INX:
      case RISCV::FSH:
      case RISCV::LW:
      case RISCV::LW_INX:
      case RISCV::LWU:
      case RISCV::FLW:
      case RISCV::SW:
      case RISCV::SW_INX:
      case RISCV::FSW:
      case RISCV::LD:
      case RISCV::LD_RV32:
      case RISCV::FLD:
      case RISCV::SD:
      case RISCV::SD_RV32:
      case RISCV::FSD: {
        // Can't fold into store value.
        if (User.getOperand(0).getReg() == Reg)
          return false;

        // Existing offset must be immediate.
        if (!User.getOperand(2).isImm())
          return false;

        // Require at least one operation between the ADDI and the load/store.
        // We have other optimizations that should handle the simple case.
        if (User.getOperand(1).getReg() == OrigReg)
          return false;

        auto I = RegToOffsetMap.find(User.getOperand(1).getReg());
        if (I == RegToOffsetMap.end())
          return false;

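        // Loads and stores encode a 12-bit signed immediate, so the combined
        // offset must still fit. For example, an existing immediate of 4 plus
        // a folded contribution of 2044 would give 2048, which is out of
        // range and rejected.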
        int64_t LocalOffset = User.getOperand(2).getImm();
        assert(isInt<12>(LocalOffset));
        int64_t CombinedOffset = (uint64_t)LocalOffset + (uint64_t)I->second;
        if (!isInt<12>(CombinedOffset))
          return false;

        FoldableInstrs[&User] = CombinedOffset;
        continue;
      }
      }

      // If we reach here we should have an accumulated offset.
      assert(Offset.hasValue() && "Expected an offset");

      // If the offset is new or changed, add the destination register to the
      // worklist.
      int64_t OffsetVal = Offset.getValue();
      auto P =
          RegToOffsetMap.try_emplace(User.getOperand(0).getReg(), OffsetVal);
      if (P.second) {
        Worklist.push(User.getOperand(0).getReg());
      } else if (P.first->second != OffsetVal) {
        P.first->second = OffsetVal;
        Worklist.push(User.getOperand(0).getReg());
      }
    }
  }

  return true;
}

bool RISCVFoldMemOffset::runOnMachineFunction(MachineFunction &MF) {
  if (skipFunction(MF.getFunction()))
    return false;

  // This optimization may increase code size by preventing compression: a
  // folded offset that no longer fits a compressed load/store's smaller
  // immediate field forces a 32-bit encoding.
  if (MF.getFunction().hasOptSize())
    return false;

  MachineRegisterInfo &MRI = MF.getRegInfo();

  bool MadeChange = false;
  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : llvm::make_early_inc_range(MBB)) {
      // FIXME: We can support ADDIW from an LUI+ADDIW pair if the result is
      // equivalent to LUI+ADDI.
      if (MI.getOpcode() != RISCV::ADDI)
        continue;

      // We only want to optimize register ADDIs.
      if (!MI.getOperand(1).isReg() || !MI.getOperand(2).isImm())
        continue;

      // Ignore 'li' (ADDI from X0), which materializes a constant rather than
      // adjusting an address.
      if (MI.getOperand(1).getReg() == RISCV::X0)
        continue;

      int64_t Offset = MI.getOperand(2).getImm();
      assert(isInt<12>(Offset));

      DenseMap<MachineInstr *, int64_t> FoldableInstrs;

      if (!foldOffset(MI.getOperand(0).getReg(), Offset, MRI, FoldableInstrs))
        continue;

      if (FoldableInstrs.empty())
        continue;

      // We can fold this ADDI. Rewrite all the memory instructions, then
      // replace the ADDI's result with its source register and erase it.
      for (auto [MemMI, NewOffset] : FoldableInstrs)
        MemMI->getOperand(2).setImm(NewOffset);

      MRI.replaceRegWith(MI.getOperand(0).getReg(), MI.getOperand(1).getReg());
      MRI.clearKillFlags(MI.getOperand(1).getReg());
      MI.eraseFromParent();
      MadeChange = true;
    }
  }

  return MadeChange;
}