//===- RISCVFoldMemOffset.cpp - Fold ADDI into memory offsets ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//
//
// Look for ADDIs that can be removed by folding their immediate into later
// load/store addresses. There may be other arithmetic instructions between the
// ADDI and load/store that we need to reassociate through. If the final result
// of the arithmetic is only used by load/store addresses, we can fold the
// offset into all the load/store instructions as long as it doesn't create an
// offset that is too large.
//
//===---------------------------------------------------------------------===//
17
#include "RISCV.h"
#include "RISCVSubtarget.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include <optional>
#include <queue>
22
23using namespace llvm;
24
25#define DEBUG_TYPE "riscv-fold-mem-offset"
26#define RISCV_FOLD_MEM_OFFSET_NAME "RISC-V Fold Memory Offset"
27
28namespace {
29
30class RISCVFoldMemOffset : public MachineFunctionPass {
31public:
32 static char ID;
33
34 RISCVFoldMemOffset() : MachineFunctionPass(ID) {}
35
36 bool runOnMachineFunction(MachineFunction &MF) override;
37
38 bool foldOffset(Register OrigReg, int64_t InitialOffset,
39 const MachineRegisterInfo &MRI,
40 DenseMap<MachineInstr *, int64_t> &FoldableInstrs);
41
42 void getAnalysisUsage(AnalysisUsage &AU) const override {
43 AU.setPreservesCFG();
44 MachineFunctionPass::getAnalysisUsage(AU);
45 }
46
47 StringRef getPassName() const override { return RISCV_FOLD_MEM_OFFSET_NAME; }
48};
49
// Wrapper class around a std::optional to allow accumulation. A
// default-constructed FoldableOffset holds no value; assigning to it or
// accumulating into it gives it one.
class FoldableOffset {
  std::optional<int64_t> Offset;

public:
  bool hasValue() const { return Offset.has_value(); }

  // Precondition: hasValue().
  int64_t getValue() const { return *Offset; }

  FoldableOffset &operator=(int64_t RHS) {
    Offset = RHS;
    return *this;
  }

  // Add RHS to the held value, treating "no value" as zero. The addition is
  // carried out on unsigned integers so that overflow wraps around instead of
  // being undefined behavior.
  FoldableOffset &operator+=(int64_t RHS) {
    const uint64_t Sum =
        static_cast<uint64_t>(Offset.value_or(0)) + static_cast<uint64_t>(RHS);
    Offset = static_cast<int64_t>(Sum);
    return *this;
  }

  // Precondition: hasValue().
  int64_t operator*() const { return *Offset; }
};
72
73} // end anonymous namespace
74
75char RISCVFoldMemOffset::ID = 0;
76INITIALIZE_PASS(RISCVFoldMemOffset, DEBUG_TYPE, RISCV_FOLD_MEM_OFFSET_NAME,
77 false, false)
78
79FunctionPass *llvm::createRISCVFoldMemOffsetPass() {
80 return new RISCVFoldMemOffset();
81}
82
83// Walk forward from the ADDI looking for arithmetic instructions we can
84// analyze or memory instructions that use it as part of their address
85// calculation. For each arithmetic instruction we lookup how the offset
86// contributes to the value in that register use that information to
87// calculate the contribution to the output of this instruction.
88// Only addition and left shift are supported.
89// FIXME: Add multiplication by constant. The constant will be in a register.
90bool RISCVFoldMemOffset::foldOffset(
91 Register OrigReg, int64_t InitialOffset, const MachineRegisterInfo &MRI,
92 DenseMap<MachineInstr *, int64_t> &FoldableInstrs) {
93 // Map to hold how much the offset contributes to the value of this register.
94 DenseMap<Register, int64_t> RegToOffsetMap;
95
96 // Insert root offset into the map.
97 RegToOffsetMap[OrigReg] = InitialOffset;
98
99 std::queue<Register> Worklist;
100 Worklist.push(x: OrigReg);
101
102 while (!Worklist.empty()) {
103 Register Reg = Worklist.front();
104 Worklist.pop();
105
106 if (!Reg.isVirtual())
107 return false;
108
109 for (auto &User : MRI.use_nodbg_instructions(Reg)) {
110 FoldableOffset Offset;
111
112 switch (User.getOpcode()) {
113 default:
114 return false;
115 case RISCV::ADD:
116 if (auto I = RegToOffsetMap.find(Val: User.getOperand(i: 1).getReg());
117 I != RegToOffsetMap.end())
118 Offset = I->second;
119 if (auto I = RegToOffsetMap.find(Val: User.getOperand(i: 2).getReg());
120 I != RegToOffsetMap.end())
121 Offset += I->second;
122 break;
123 case RISCV::SH1ADD:
124 if (auto I = RegToOffsetMap.find(Val: User.getOperand(i: 1).getReg());
125 I != RegToOffsetMap.end())
126 Offset = (uint64_t)I->second << 1;
127 if (auto I = RegToOffsetMap.find(Val: User.getOperand(i: 2).getReg());
128 I != RegToOffsetMap.end())
129 Offset += I->second;
130 break;
131 case RISCV::SH2ADD:
132 if (auto I = RegToOffsetMap.find(Val: User.getOperand(i: 1).getReg());
133 I != RegToOffsetMap.end())
134 Offset = (uint64_t)I->second << 2;
135 if (auto I = RegToOffsetMap.find(Val: User.getOperand(i: 2).getReg());
136 I != RegToOffsetMap.end())
137 Offset += I->second;
138 break;
139 case RISCV::SH3ADD:
140 if (auto I = RegToOffsetMap.find(Val: User.getOperand(i: 1).getReg());
141 I != RegToOffsetMap.end())
142 Offset = (uint64_t)I->second << 3;
143 if (auto I = RegToOffsetMap.find(Val: User.getOperand(i: 2).getReg());
144 I != RegToOffsetMap.end())
145 Offset += I->second;
146 break;
147 case RISCV::ADD_UW:
148 case RISCV::SH1ADD_UW:
149 case RISCV::SH2ADD_UW:
150 case RISCV::SH3ADD_UW:
151 // Don't fold through the zero extended input.
152 if (User.getOperand(i: 1).getReg() == Reg)
153 return false;
154 if (auto I = RegToOffsetMap.find(Val: User.getOperand(i: 2).getReg());
155 I != RegToOffsetMap.end())
156 Offset = I->second;
157 break;
158 case RISCV::SLLI: {
159 unsigned ShAmt = User.getOperand(i: 2).getImm();
160 if (auto I = RegToOffsetMap.find(Val: User.getOperand(i: 1).getReg());
161 I != RegToOffsetMap.end())
162 Offset = (uint64_t)I->second << ShAmt;
163 break;
164 }
165 case RISCV::LB:
166 case RISCV::LBU:
167 case RISCV::SB:
168 case RISCV::LH:
169 case RISCV::LH_INX:
170 case RISCV::LHU:
171 case RISCV::FLH:
172 case RISCV::SH:
173 case RISCV::SH_INX:
174 case RISCV::FSH:
175 case RISCV::LW:
176 case RISCV::LW_INX:
177 case RISCV::LWU:
178 case RISCV::FLW:
179 case RISCV::SW:
180 case RISCV::SW_INX:
181 case RISCV::FSW:
182 case RISCV::LD:
183 case RISCV::FLD:
184 case RISCV::SD:
185 case RISCV::FSD: {
186 // Can't fold into store value.
187 if (User.getOperand(i: 0).getReg() == Reg)
188 return false;
189
190 // Existing offset must be immediate.
191 if (!User.getOperand(i: 2).isImm())
192 return false;
193
194 // Require at least one operation between the ADDI and the load/store.
195 // We have other optimizations that should handle the simple case.
196 if (User.getOperand(i: 1).getReg() == OrigReg)
197 return false;
198
199 auto I = RegToOffsetMap.find(Val: User.getOperand(i: 1).getReg());
200 if (I == RegToOffsetMap.end())
201 return false;
202
203 int64_t LocalOffset = User.getOperand(i: 2).getImm();
204 assert(isInt<12>(LocalOffset));
205 int64_t CombinedOffset = (uint64_t)LocalOffset + (uint64_t)I->second;
206 if (!isInt<12>(x: CombinedOffset))
207 return false;
208
209 FoldableInstrs[&User] = CombinedOffset;
210 continue;
211 }
212 }
213
214 // If we reach here we should have an accumulated offset.
215 assert(Offset.hasValue() && "Expected an offset");
216
217 // If the offset is new or changed, add the destination register to the
218 // work list.
219 int64_t OffsetVal = Offset.getValue();
220 auto P =
221 RegToOffsetMap.try_emplace(Key: User.getOperand(i: 0).getReg(), Args&: OffsetVal);
222 if (P.second) {
223 Worklist.push(x: User.getOperand(i: 0).getReg());
224 } else if (P.first->second != OffsetVal) {
225 P.first->second = OffsetVal;
226 Worklist.push(x: User.getOperand(i: 0).getReg());
227 }
228 }
229 }
230
231 return true;
232}
233
234bool RISCVFoldMemOffset::runOnMachineFunction(MachineFunction &MF) {
235 if (skipFunction(F: MF.getFunction()))
236 return false;
237
238 // This optimization may increase size by preventing compression.
239 if (MF.getFunction().hasOptSize())
240 return false;
241
242 MachineRegisterInfo &MRI = MF.getRegInfo();
243
244 bool MadeChange = false;
245 for (MachineBasicBlock &MBB : MF) {
246 for (MachineInstr &MI : llvm::make_early_inc_range(Range&: MBB)) {
247 // FIXME: We can support ADDIW from an LUI+ADDIW pair if the result is
248 // equivalent to LUI+ADDI.
249 if (MI.getOpcode() != RISCV::ADDI)
250 continue;
251
252 // We only want to optimize register ADDIs.
253 if (!MI.getOperand(i: 1).isReg() || !MI.getOperand(i: 2).isImm())
254 continue;
255
256 // Ignore 'li'.
257 if (MI.getOperand(i: 1).getReg() == RISCV::X0)
258 continue;
259
260 int64_t Offset = MI.getOperand(i: 2).getImm();
261 assert(isInt<12>(Offset));
262
263 DenseMap<MachineInstr *, int64_t> FoldableInstrs;
264
265 if (!foldOffset(OrigReg: MI.getOperand(i: 0).getReg(), InitialOffset: Offset, MRI, FoldableInstrs))
266 continue;
267
268 if (FoldableInstrs.empty())
269 continue;
270
271 // We can fold this ADDI.
272 // Rewrite all the instructions.
273 for (auto [MemMI, NewOffset] : FoldableInstrs)
274 MemMI->getOperand(i: 2).setImm(NewOffset);
275
276 MRI.replaceRegWith(FromReg: MI.getOperand(i: 0).getReg(), ToReg: MI.getOperand(i: 1).getReg());
277 MRI.clearKillFlags(Reg: MI.getOperand(i: 1).getReg());
278 MI.eraseFromParent();
279 }
280 }
281
282 return MadeChange;
283}
284