//=== AArch64PostLegalizerCombiner.cpp --------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// Post-legalization combines on generic MachineInstrs.
///
/// The combines here must preserve instruction legality.
///
/// Lowering combines (e.g. pseudo matching) should be handled by
/// AArch64PostLegalizerLowering.
///
/// Combines which don't rely on instruction legality should go in the
/// AArch64PreLegalizerCombiner.
///
//===----------------------------------------------------------------------===//

#include "AArch64TargetMachine.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/CodeGen/GlobalISel/CSEInfo.h"
#include "llvm/CodeGen/GlobalISel/CSEMIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Combiner.h"
#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
#include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
#include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h"
#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
#include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/Support/Debug.h"

#define GET_GICOMBINER_DEPS
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef GET_GICOMBINER_DEPS

#define DEBUG_TYPE "aarch64-postlegalizer-combiner"

using namespace llvm;
using namespace MIPatternMatch;

namespace {

#define GET_GICOMBINER_TYPES
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef GET_GICOMBINER_TYPES
/// This combine tries to do what performExtractVectorEltCombine does in SDAG.
/// Rewrite for pairwise fadd pattern
///   (s32 (g_extract_vector_elt
///           (g_fadd (vXs32 Other)
///                   (g_vector_shuffle (vXs32 Other) undef <1,X,...>)) 0))
/// ->
///   (s32 (g_fadd (g_extract_vector_elt (vXs32 Other) 0)
///                (g_extract_vector_elt (vXs32 Other) 1))
bool matchExtractVecEltPairwiseAdd(
    MachineInstr &MI, MachineRegisterInfo &MRI,
    std::tuple<unsigned, LLT, Register> &MatchInfo) {
  Register Src1 = MI.getOperand(1).getReg();
  Register Src2 = MI.getOperand(2).getReg();
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());

  auto Cst = getIConstantVRegValWithLookThrough(Src2, MRI);
  if (!Cst || Cst->Value != 0)
    return false;
  // SDAG also checks for FullFP16, but this looks to be beneficial anyway.

  // Now check for an fadd operation. TODO: expand this for integer add?
  auto *FAddMI = getOpcodeDef(TargetOpcode::G_FADD, Src1, MRI);
  if (!FAddMI)
    return false;

  // If we add support for integer add, must restrict these types to just s64.
  unsigned DstSize = DstTy.getSizeInBits();
  if (DstSize != 16 && DstSize != 32 && DstSize != 64)
    return false;

  Register Src1Op1 = FAddMI->getOperand(1).getReg();
  Register Src1Op2 = FAddMI->getOperand(2).getReg();
  MachineInstr *Shuffle =
      getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op2, MRI);
  MachineInstr *Other = MRI.getVRegDef(Src1Op1);
  if (!Shuffle) {
    Shuffle = getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op1, MRI);
    Other = MRI.getVRegDef(Src1Op2);
  }

  // We're looking for a shuffle that moves the second element to index 0.
  if (Shuffle && Shuffle->getOperand(3).getShuffleMask()[0] == 1 &&
      Other == MRI.getVRegDef(Shuffle->getOperand(1).getReg())) {
    std::get<0>(MatchInfo) = TargetOpcode::G_FADD;
    std::get<1>(MatchInfo) = DstTy;
    std::get<2>(MatchInfo) = Other->getOperand(0).getReg();
    return true;
  }
  return false;
}

void applyExtractVecEltPairwiseAdd(
    MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
    std::tuple<unsigned, LLT, Register> &MatchInfo) {
  unsigned Opc = std::get<0>(MatchInfo);
  assert(Opc == TargetOpcode::G_FADD && "Unexpected opcode!");
  // We want to generate two extracts of elements 0 and 1, and add them.
  LLT Ty = std::get<1>(MatchInfo);
  Register Src = std::get<2>(MatchInfo);
  LLT s64 = LLT::scalar(64);
  B.setInstrAndDebugLoc(MI);
  auto Elt0 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 0));
  auto Elt1 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 1));
  B.buildInstr(Opc, {MI.getOperand(0).getReg()}, {Elt0, Elt1});
  MI.eraseFromParent();
}

bool isSignExtended(Register R, MachineRegisterInfo &MRI) {
  // TODO: check if extended build vector as well.
  unsigned Opc = MRI.getVRegDef(R)->getOpcode();
  return Opc == TargetOpcode::G_SEXT || Opc == TargetOpcode::G_SEXT_INREG;
}

bool isZeroExtended(Register R, MachineRegisterInfo &MRI) {
  // TODO: check if extended build vector as well.
  return MRI.getVRegDef(R)->getOpcode() == TargetOpcode::G_ZEXT;
}

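/// Match a G_MUL by a constant that can instead be materialized as a cheaper
/// sequence of shifts and adds/subs (ported from the equivalent combine in
/// AArch64ISelLowering). For example (illustrative generic MIR, register
/// names hypothetical):
///   %c7:_(s64) = G_CONSTANT i64 7    ; 7 == 2^3 - 1
///   %r:_(s64)  = G_MUL %x, %c7
/// becomes
///   %c3:_(s64) = G_CONSTANT i64 3
///   %s:_(s64)  = G_SHL %x, %c3(s64)
///   %r:_(s64)  = G_SUB %s, %x
/// On success, \p ApplyFn is set to a callback that builds the replacement
/// sequence for a given destination register.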
bool matchAArch64MulConstCombine(
    MachineInstr &MI, MachineRegisterInfo &MRI,
    std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
  assert(MI.getOpcode() == TargetOpcode::G_MUL);
  Register LHS = MI.getOperand(1).getReg();
  Register RHS = MI.getOperand(2).getReg();
  Register Dst = MI.getOperand(0).getReg();
  const LLT Ty = MRI.getType(LHS);

  // The below optimizations require a constant RHS.
  auto Const = getIConstantVRegValWithLookThrough(RHS, MRI);
  if (!Const)
    return false;

  APInt ConstValue = Const->Value.sext(Ty.getSizeInBits());
  // The following code is ported from AArch64ISelLowering.
  // Multiplication of a power of two plus/minus one can be done more
  // cheaply as shift+add/sub. For now, this is true unilaterally. If
  // future CPUs have a cheaper MADD instruction, this may need to be
  // gated on a subtarget feature. For Cyclone, 32-bit MADD is 4 cycles and
  // 64-bit is 5 cycles, so this is always a win.
  // More aggressively, some multiplications N0 * C can be lowered to
  // shift+add+shift if the constant C = A * B where A = 2^N + 1 and B = 2^M,
  // e.g. 6=3*2=(2+1)*2.
  // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45,
  // which equals (1+2)*16-(1+2).
  // TrailingZeroes is used to test if the mul can be lowered to
  // shift+add+shift.
  unsigned TrailingZeroes = ConstValue.countr_zero();
  if (TrailingZeroes) {
    // Conservatively do not lower to shift+add+shift if the mul might be
    // folded into smul or umul.
    if (MRI.hasOneNonDBGUse(LHS) &&
        (isSignExtended(LHS, MRI) || isZeroExtended(LHS, MRI)))
      return false;
    // Conservatively do not lower to shift+add+shift if the mul might be
    // folded into madd or msub.
    if (MRI.hasOneNonDBGUse(Dst)) {
      MachineInstr &UseMI = *MRI.use_instr_begin(Dst);
      unsigned UseOpc = UseMI.getOpcode();
      if (UseOpc == TargetOpcode::G_ADD || UseOpc == TargetOpcode::G_PTR_ADD ||
          UseOpc == TargetOpcode::G_SUB)
        return false;
    }
  }
  // Use ShiftedConstValue instead of ConstValue to support both shift+add/sub
  // and shift+add+shift.
  APInt ShiftedConstValue = ConstValue.ashr(TrailingZeroes);

  unsigned ShiftAmt, AddSubOpc;
  // Is the shifted value the LHS operand of the add/sub?
  bool ShiftValUseIsLHS = true;
  // Do we need to negate the result?
  bool NegateResult = false;

  if (ConstValue.isNonNegative()) {
    // (mul x, 2^N + 1) => (add (shl x, N), x)
    // (mul x, 2^N - 1) => (sub (shl x, N), x)
    // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
    APInt SCVMinus1 = ShiftedConstValue - 1;
    APInt CVPlus1 = ConstValue + 1;
    if (SCVMinus1.isPowerOf2()) {
      ShiftAmt = SCVMinus1.logBase2();
      AddSubOpc = TargetOpcode::G_ADD;
    } else if (CVPlus1.isPowerOf2()) {
      ShiftAmt = CVPlus1.logBase2();
      AddSubOpc = TargetOpcode::G_SUB;
    } else
      return false;
  } else {
    // (mul x, -(2^N - 1)) => (sub x, (shl x, N))
    // (mul x, -(2^N + 1)) => - (add (shl x, N), x)
    APInt CVNegPlus1 = -ConstValue + 1;
    APInt CVNegMinus1 = -ConstValue - 1;
    if (CVNegPlus1.isPowerOf2()) {
      ShiftAmt = CVNegPlus1.logBase2();
      AddSubOpc = TargetOpcode::G_SUB;
      ShiftValUseIsLHS = false;
    } else if (CVNegMinus1.isPowerOf2()) {
      ShiftAmt = CVNegMinus1.logBase2();
      AddSubOpc = TargetOpcode::G_ADD;
      NegateResult = true;
    } else
      return false;
  }

  if (NegateResult && TrailingZeroes)
    return false;

  ApplyFn = [=](MachineIRBuilder &B, Register DstReg) {
    auto Shift = B.buildConstant(LLT::scalar(64), ShiftAmt);
    auto ShiftedVal = B.buildShl(Ty, LHS, Shift);

    Register AddSubLHS = ShiftValUseIsLHS ? ShiftedVal.getReg(0) : LHS;
    Register AddSubRHS = ShiftValUseIsLHS ? LHS : ShiftedVal.getReg(0);
    auto Res = B.buildInstr(AddSubOpc, {Ty}, {AddSubLHS, AddSubRHS});
    assert(!(NegateResult && TrailingZeroes) &&
           "NegateResult and TrailingZeroes cannot both be true for now.");
    // Negate the result.
    if (NegateResult) {
      B.buildSub(DstReg, B.buildConstant(Ty, 0), Res);
      return;
    }
    // Shift the result.
    if (TrailingZeroes) {
      B.buildShl(DstReg, Res, B.buildConstant(LLT::scalar(64), TrailingZeroes));
      return;
    }
    B.buildCopy(DstReg, Res.getReg(0));
  };
  return true;
}

void applyAArch64MulConstCombine(
    MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
    std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
  B.setInstrAndDebugLoc(MI);
  ApplyFn(B, MI.getOperand(0).getReg());
  MI.eraseFromParent();
}

/// Try to fold a G_MERGE_VALUES of 2 s32 sources, where the second source
/// is a zero, into a G_ZEXT of the first.
bool matchFoldMergeToZext(MachineInstr &MI, MachineRegisterInfo &MRI) {
  auto &Merge = cast<GMerge>(MI);
  LLT SrcTy = MRI.getType(Merge.getSourceReg(0));
  if (SrcTy != LLT::scalar(32) || Merge.getNumSources() != 2)
    return false;
  return mi_match(Merge.getSourceReg(1), MRI, m_SpecificICst(0));
}

void applyFoldMergeToZext(MachineInstr &MI, MachineRegisterInfo &MRI,
                          MachineIRBuilder &B, GISelChangeObserver &Observer) {
  // Mutate %d(s64) = G_MERGE_VALUES %a(s32), 0(s32)
  //  ->
  //        %d(s64) = G_ZEXT %a(s32)
  Observer.changingInstr(MI);
  MI.setDesc(B.getTII().get(TargetOpcode::G_ZEXT));
  MI.removeOperand(2);
  Observer.changedInstr(MI);
}

/// \returns True if a G_ANYEXT instruction \p MI should be mutated to a G_ZEXT
/// instruction.
bool matchMutateAnyExtToZExt(MachineInstr &MI, MachineRegisterInfo &MRI) {
  // If this is coming from a scalar compare then we can use a G_ZEXT instead
  // of a G_ANYEXT:
  //
  // %cmp:_(s32) = G_[I|F]CMP ... <-- produces 0/1.
  // %ext:_(s64) = G_ANYEXT %cmp(s32)
  //
  // By doing this, we can leverage more KnownBits combines.
  assert(MI.getOpcode() == TargetOpcode::G_ANYEXT);
  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();
  return MRI.getType(Dst).isScalar() &&
         mi_match(Src, MRI,
                  m_any_of(m_GICmp(m_Pred(), m_Reg(), m_Reg()),
                           m_GFCmp(m_Pred(), m_Reg(), m_Reg())));
}

void applyMutateAnyExtToZExt(MachineInstr &MI, MachineRegisterInfo &MRI,
                             MachineIRBuilder &B,
                             GISelChangeObserver &Observer) {
  Observer.changingInstr(MI);
  MI.setDesc(B.getTII().get(TargetOpcode::G_ZEXT));
  Observer.changedInstr(MI);
}

/// Match a 128b store of zero and split it into two 64 bit stores, for
/// size/performance reasons.
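/// For example (illustrative generic MIR, register names hypothetical):
///   G_STORE %zerovec(<2 x s64>), %ptr(p0) :: (store (<2 x s64>))
/// becomes two s64 stores of zero at %ptr and %ptr + 8, which typically
/// selects to a single "stp xzr, xzr" instead of materializing a zero vector
/// register.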
bool matchSplitStoreZero128(MachineInstr &MI, MachineRegisterInfo &MRI) {
  GStore &Store = cast<GStore>(MI);
  if (!Store.isSimple())
    return false;
  LLT ValTy = MRI.getType(Store.getValueReg());
  if (ValTy.isScalableVector())
    return false;
  if (!ValTy.isVector() || ValTy.getSizeInBits() != 128)
    return false;
  if (Store.getMemSizeInBits() != ValTy.getSizeInBits())
    return false; // Don't split truncating stores.
  if (!MRI.hasOneNonDBGUse(Store.getValueReg()))
    return false;
  auto MaybeCst = isConstantOrConstantSplatVector(
      *MRI.getVRegDef(Store.getValueReg()), MRI);
  return MaybeCst && MaybeCst->isZero();
}

void applySplitStoreZero128(MachineInstr &MI, MachineRegisterInfo &MRI,
                            MachineIRBuilder &B,
                            GISelChangeObserver &Observer) {
  B.setInstrAndDebugLoc(MI);
  GStore &Store = cast<GStore>(MI);
  assert(MRI.getType(Store.getValueReg()).isVector() &&
         "Expected a vector store value");
  LLT NewTy = LLT::scalar(64);
  Register PtrReg = Store.getPointerReg();
  auto Zero = B.buildConstant(NewTy, 0);
  auto HighPtr = B.buildPtrAdd(MRI.getType(PtrReg), PtrReg,
                               B.buildConstant(LLT::scalar(64), 8));
  auto &MF = *MI.getMF();
  auto *LowMMO = MF.getMachineMemOperand(&Store.getMMO(), 0, NewTy);
  auto *HighMMO = MF.getMachineMemOperand(&Store.getMMO(), 8, NewTy);
  B.buildStore(Zero, PtrReg, *LowMMO);
  B.buildStore(Zero, HighPtr, *HighMMO);
  Store.eraseFromParent();
}

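/// Match (or (and X, MaskBV), (and Y, OtherBV)) where MaskBV and OtherBV are
/// constant G_BUILD_VECTORs whose corresponding elements are bitwise
/// complements of each other, i.e. the expression is a lane-wise bitwise
/// select between X and Y. MatchInfo receives {X, Y, MaskBV}, and the apply
/// step emits a single AArch64::G_BSP.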
bool matchOrToBSP(MachineInstr &MI, MachineRegisterInfo &MRI,
                  std::tuple<Register, Register, Register> &MatchInfo) {
  const LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
  if (!DstTy.isVector())
    return false;

  Register AO1, AO2, BVO1, BVO2;
  if (!mi_match(MI, MRI,
                m_GOr(m_GAnd(m_Reg(AO1), m_Reg(BVO1)),
                      m_GAnd(m_Reg(AO2), m_Reg(BVO2)))))
    return false;

  auto *BV1 = getOpcodeDef<GBuildVector>(BVO1, MRI);
  auto *BV2 = getOpcodeDef<GBuildVector>(BVO2, MRI);
  if (!BV1 || !BV2)
    return false;

  for (int I = 0, E = DstTy.getNumElements(); I < E; I++) {
    auto ValAndVReg1 =
        getIConstantVRegValWithLookThrough(BV1->getSourceReg(I), MRI);
    auto ValAndVReg2 =
        getIConstantVRegValWithLookThrough(BV2->getSourceReg(I), MRI);
    if (!ValAndVReg1 || !ValAndVReg2 ||
        ValAndVReg1->Value != ~ValAndVReg2->Value)
      return false;
  }

  MatchInfo = {AO1, AO2, BVO1};
  return true;
}

void applyOrToBSP(MachineInstr &MI, MachineRegisterInfo &MRI,
                  MachineIRBuilder &B,
                  std::tuple<Register, Register, Register> &MatchInfo) {
  B.setInstrAndDebugLoc(MI);
  B.buildInstr(
      AArch64::G_BSP, {MI.getOperand(0).getReg()},
      {std::get<2>(MatchInfo), std::get<0>(MatchInfo), std::get<1>(MatchInfo)});
  MI.eraseFromParent();
}

// Combines Mul(And(Srl(X, 15), 0x10001), 0xffff) into CMLTz
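// The constants above are the v4s32 instance of the pattern; the match below
// accepts the analogous splats for the other supported vector types (shift
// amount == half the element size minus one, AND mask == 1 | (1 << half-size),
// and MUL operand == all ones in the low half). Per half-width lane the idiom
// produces 0 or all-ones depending on the sign bit, which is exactly a signed
// compare-less-than-zero on the half-width reinterpretation of the source.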
bool matchCombineMulCMLT(MachineInstr &MI, MachineRegisterInfo &MRI,
                         Register &SrcReg) {
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());

  if (DstTy != LLT::fixed_vector(2, 64) && DstTy != LLT::fixed_vector(2, 32) &&
      DstTy != LLT::fixed_vector(4, 32) && DstTy != LLT::fixed_vector(4, 16) &&
      DstTy != LLT::fixed_vector(8, 16))
    return false;

  auto AndMI = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
  if (AndMI->getOpcode() != TargetOpcode::G_AND)
    return false;
  auto LShrMI = getDefIgnoringCopies(AndMI->getOperand(1).getReg(), MRI);
  if (LShrMI->getOpcode() != TargetOpcode::G_LSHR)
    return false;

  // Check the constant splat values
  auto V1 = isConstantOrConstantSplatVector(
      *MRI.getVRegDef(MI.getOperand(2).getReg()), MRI);
  auto V2 = isConstantOrConstantSplatVector(
      *MRI.getVRegDef(AndMI->getOperand(2).getReg()), MRI);
  auto V3 = isConstantOrConstantSplatVector(
      *MRI.getVRegDef(LShrMI->getOperand(2).getReg()), MRI);
  if (!V1.has_value() || !V2.has_value() || !V3.has_value())
    return false;
  unsigned HalfSize = DstTy.getScalarSizeInBits() / 2;
  if (!V1.value().isMask(HalfSize) || V2.value() != (1ULL | 1ULL << HalfSize) ||
      V3 != (HalfSize - 1))
    return false;

  SrcReg = LShrMI->getOperand(1).getReg();

  return true;
}

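// Rewrites the matched pattern by reinterpreting the source as a vector with
// twice as many elements at half the element size, comparing it
// signed-less-than against zero (which selects to CMLT #0), and bitcasting
// the result back to the original type.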
void applyCombineMulCMLT(MachineInstr &MI, MachineRegisterInfo &MRI,
                         MachineIRBuilder &B, Register &SrcReg) {
  Register DstReg = MI.getOperand(0).getReg();
  LLT DstTy = MRI.getType(DstReg);
  LLT HalfTy =
      DstTy.changeElementCount(
               DstTy.getElementCount().multiplyCoefficientBy(2))
          .changeElementSize(DstTy.getScalarSizeInBits() / 2);

  Register ZeroVec = B.buildConstant(HalfTy, 0).getReg(0);
  Register CastReg =
      B.buildInstr(TargetOpcode::G_BITCAST, {HalfTy}, {SrcReg}).getReg(0);
  Register CMLTReg =
      B.buildICmp(CmpInst::Predicate::ICMP_SLT, HalfTy, CastReg, ZeroVec)
          .getReg(0);

  B.buildInstr(TargetOpcode::G_BITCAST, {DstReg}, {CMLTReg}).getReg(0);
  MI.eraseFromParent();
}

class AArch64PostLegalizerCombinerImpl : public Combiner {
protected:
  // TODO: Make CombinerHelper methods const.
  mutable CombinerHelper Helper;
  const AArch64PostLegalizerCombinerImplRuleConfig &RuleConfig;
  const AArch64Subtarget &STI;

public:
  AArch64PostLegalizerCombinerImpl(
      MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
      GISelKnownBits &KB, GISelCSEInfo *CSEInfo,
      const AArch64PostLegalizerCombinerImplRuleConfig &RuleConfig,
      const AArch64Subtarget &STI, MachineDominatorTree *MDT,
      const LegalizerInfo *LI);

  static const char *getName() { return "AArch64PostLegalizerCombiner"; }

  bool tryCombineAll(MachineInstr &I) const override;

private:
#define GET_GICOMBINER_CLASS_MEMBERS
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef GET_GICOMBINER_CLASS_MEMBERS
};

#define GET_GICOMBINER_IMPL
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef GET_GICOMBINER_IMPL

AArch64PostLegalizerCombinerImpl::AArch64PostLegalizerCombinerImpl(
    MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
    GISelKnownBits &KB, GISelCSEInfo *CSEInfo,
    const AArch64PostLegalizerCombinerImplRuleConfig &RuleConfig,
    const AArch64Subtarget &STI, MachineDominatorTree *MDT,
    const LegalizerInfo *LI)
    : Combiner(MF, CInfo, TPC, &KB, CSEInfo),
      Helper(Observer, B, /*IsPreLegalize*/ false, &KB, MDT, LI),
      RuleConfig(RuleConfig), STI(STI),
#define GET_GICOMBINER_CONSTRUCTOR_INITS
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef GET_GICOMBINER_CONSTRUCTOR_INITS
{
}

class AArch64PostLegalizerCombiner : public MachineFunctionPass {
public:
  static char ID;

  AArch64PostLegalizerCombiner(bool IsOptNone = false);

  StringRef getPassName() const override {
    return "AArch64PostLegalizerCombiner";
  }

  bool runOnMachineFunction(MachineFunction &MF) override;
  void getAnalysisUsage(AnalysisUsage &AU) const override;

private:
  bool IsOptNone;
  AArch64PostLegalizerCombinerImplRuleConfig RuleConfig;

  struct StoreInfo {
    GStore *St = nullptr;
    // The G_PTR_ADD that's used by the store. We keep this to cache the
    // MachineInstr def.
    GPtrAdd *Ptr = nullptr;
    // The signed offset to the Ptr instruction.
    int64_t Offset = 0;
    LLT StoredType;
  };
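  /// Given a run of consecutive stores (collected in \p Stores) addressing a
  /// common base with large immediate offsets, rewrite them to use the first
  /// store's pointer as a shared base plus small offsets so that STP
  /// formation is no longer blocked by out-of-range immediates. \returns true
  /// if any store was rewritten.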
  bool tryOptimizeConsecStores(SmallVectorImpl<StoreInfo> &Stores,
                               CSEMIRBuilder &MIB);

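  /// Scan each basic block for runs of consecutive stores whose folded
  /// G_PTR_ADD offsets are too large for STP addressing and pass them to
  /// tryOptimizeConsecStores. \returns true if the function was modified.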
  bool optimizeConsecutiveMemOpAddressing(MachineFunction &MF,
                                          CSEMIRBuilder &MIB);
};
} // end anonymous namespace

void AArch64PostLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
  AU.setPreservesCFG();
  getSelectionDAGFallbackAnalysisUsage(AU);
  AU.addRequired<GISelKnownBitsAnalysis>();
  AU.addPreserved<GISelKnownBitsAnalysis>();
  if (!IsOptNone) {
    AU.addRequired<MachineDominatorTreeWrapperPass>();
    AU.addPreserved<MachineDominatorTreeWrapperPass>();
    AU.addRequired<GISelCSEAnalysisWrapperPass>();
    AU.addPreserved<GISelCSEAnalysisWrapperPass>();
  }
  MachineFunctionPass::getAnalysisUsage(AU);
}

AArch64PostLegalizerCombiner::AArch64PostLegalizerCombiner(bool IsOptNone)
    : MachineFunctionPass(ID), IsOptNone(IsOptNone) {
  initializeAArch64PostLegalizerCombinerPass(*PassRegistry::getPassRegistry());

  if (!RuleConfig.parseCommandLineOption())
    report_fatal_error("Invalid rule identifier");
}

bool AArch64PostLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) {
  if (MF.getProperties().hasProperty(
          MachineFunctionProperties::Property::FailedISel))
    return false;
  assert(MF.getProperties().hasProperty(
             MachineFunctionProperties::Property::Legalized) &&
         "Expected a legalized function?");
  auto *TPC = &getAnalysis<TargetPassConfig>();
  const Function &F = MF.getFunction();
  bool EnableOpt =
      MF.getTarget().getOptLevel() != CodeGenOptLevel::None && !skipFunction(F);

  const AArch64Subtarget &ST = MF.getSubtarget<AArch64Subtarget>();
  const auto *LI = ST.getLegalizerInfo();

  GISelKnownBits *KB = &getAnalysis<GISelKnownBitsAnalysis>().get(MF);
  MachineDominatorTree *MDT =
      IsOptNone ? nullptr
                : &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
  GISelCSEAnalysisWrapper &Wrapper =
      getAnalysis<GISelCSEAnalysisWrapperPass>().getCSEWrapper();
  auto *CSEInfo = &Wrapper.get(TPC->getCSEConfig());

  CombinerInfo CInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false,
                     /*LegalizerInfo*/ nullptr, EnableOpt, F.hasOptSize(),
                     F.hasMinSize());
  AArch64PostLegalizerCombinerImpl Impl(MF, CInfo, TPC, *KB, CSEInfo,
                                        RuleConfig, ST, MDT, LI);
  bool Changed = Impl.combineMachineInstrs();

  auto MIB = CSEMIRBuilder(MF);
  MIB.setCSEInfo(CSEInfo);
  Changed |= optimizeConsecutiveMemOpAddressing(MF, MIB);
  return Changed;
}

bool AArch64PostLegalizerCombiner::tryOptimizeConsecStores(
    SmallVectorImpl<StoreInfo> &Stores, CSEMIRBuilder &MIB) {
  if (Stores.size() <= 2)
    return false;

  // Profitability checks:
  int64_t BaseOffset = Stores[0].Offset;
  unsigned NumPairsExpected = Stores.size() / 2;
  unsigned TotalInstsExpected = NumPairsExpected + (Stores.size() % 2);
  // Size savings will depend on whether we can fold the offset, as an
  // immediate of an ADD.
  auto &TLI = *MIB.getMF().getSubtarget().getTargetLowering();
  if (!TLI.isLegalAddImmediate(BaseOffset))
    TotalInstsExpected++;
  int SavingsExpected = Stores.size() - TotalInstsExpected;
  if (SavingsExpected <= 0)
    return false;

  auto &MRI = MIB.getMF().getRegInfo();

  // We have a series of consecutive stores. Factor out the common base
  // pointer and rewrite the offsets.
  Register NewBase = Stores[0].Ptr->getReg(0);
  for (auto &SInfo : Stores) {
    // Compute a new pointer with the new base ptr and adjusted offset.
    MIB.setInstrAndDebugLoc(*SInfo.St);
    auto NewOff = MIB.buildConstant(LLT::scalar(64), SInfo.Offset - BaseOffset);
    auto NewPtr = MIB.buildPtrAdd(MRI.getType(SInfo.St->getPointerReg()),
                                  NewBase, NewOff);
    if (MIB.getObserver())
      MIB.getObserver()->changingInstr(*SInfo.St);
    SInfo.St->getOperand(1).setReg(NewPtr.getReg(0));
    if (MIB.getObserver())
      MIB.getObserver()->changedInstr(*SInfo.St);
  }
  LLVM_DEBUG(dbgs() << "Split a series of " << Stores.size()
                    << " stores into a base pointer and offsets.\n");
  return true;
}

static cl::opt<bool>
    EnableConsecutiveMemOpOpt("aarch64-postlegalizer-consecutive-memops",
                              cl::init(true), cl::Hidden,
                              cl::desc("Enable consecutive memop optimization "
                                       "in AArch64PostLegalizerCombiner"));

bool AArch64PostLegalizerCombiner::optimizeConsecutiveMemOpAddressing(
    MachineFunction &MF, CSEMIRBuilder &MIB) {
  // This combine needs to run after all reassociations/folds on pointer
  // addressing have been done, specifically those that combine two G_PTR_ADDs
  // with constant offsets into a single G_PTR_ADD with a combined offset.
  // The goal of this optimization is to undo that combine in the case where
  // doing so has prevented the formation of pair stores due to illegal
  // addressing modes of STP. The reason that we do it here is because
  // it's much easier to undo the transformation of a series of consecutive
  // mem ops, than it is to detect when doing it would be a bad idea looking
  // at a single G_PTR_ADD in the reassociation/ptradd_immed_chain combine.
  //
  // An example:
  //   G_STORE %11:_(<2 x s64>), %base:_(p0) :: (store (<2 x s64>), align 1)
  //   %off1:_(s64) = G_CONSTANT i64 4128
  //   %p1:_(p0) = G_PTR_ADD %base:_, %off1:_(s64)
  //   G_STORE %11:_(<2 x s64>), %p1:_(p0) :: (store (<2 x s64>), align 1)
  //   %off2:_(s64) = G_CONSTANT i64 4144
  //   %p2:_(p0) = G_PTR_ADD %base:_, %off2:_(s64)
  //   G_STORE %11:_(<2 x s64>), %p2:_(p0) :: (store (<2 x s64>), align 1)
  //   %off3:_(s64) = G_CONSTANT i64 4160
  //   %p3:_(p0) = G_PTR_ADD %base:_, %off3:_(s64)
  //   G_STORE %11:_(<2 x s64>), %p3:_(p0) :: (store (<2 x s64>), align 1)
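  // When such a run is long enough to be profitable, tryOptimizeConsecStores
  // rewrites the offset stores to share %p1 as a common base and address it
  // with small offsets (+0, +16, +32 here), which are legal STP immediates.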
  bool Changed = false;
  auto &MRI = MF.getRegInfo();

  if (!EnableConsecutiveMemOpOpt)
    return Changed;

  SmallVector<StoreInfo, 8> Stores;
  // If we see a load, then we keep track of any values defined by it.
  // In the following example, STP formation will fail anyway because
  // the latter store is using a load result that appears after the
  // prior store. In this situation if we factor out the offset then
  // we increase code size for no benefit.
  //   G_STORE %v1:_(s64), %base:_(p0) :: (store (s64))
  //   %v2:_(s64) = G_LOAD %ldptr:_(p0) :: (load (s64))
  //   G_STORE %v2:_(s64), %base:_(p0) :: (store (s64))
  SmallVector<Register> LoadValsSinceLastStore;

  auto storeIsValid = [&](StoreInfo &Last, StoreInfo New) {
    // Check if this store is consecutive to the last one.
    if (Last.Ptr->getBaseReg() != New.Ptr->getBaseReg() ||
        (Last.Offset + static_cast<int64_t>(Last.StoredType.getSizeInBytes()) !=
         New.Offset) ||
        Last.StoredType != New.StoredType)
      return false;

    // Check if this store is using a load result that appears after the
    // last store. If so, bail out.
    if (any_of(LoadValsSinceLastStore, [&](Register LoadVal) {
          return New.St->getValueReg() == LoadVal;
        }))
      return false;

    // Check if the current offset would be too large for STP.
    // If not, then STP formation should be able to handle it, so we don't
    // need to do anything.
    int64_t MaxLegalOffset;
    switch (New.StoredType.getSizeInBits()) {
    case 32:
      MaxLegalOffset = 252;
      break;
    case 64:
      MaxLegalOffset = 504;
      break;
    case 128:
      MaxLegalOffset = 1008;
      break;
    default:
      llvm_unreachable("Unexpected stored type size");
    }
    if (New.Offset < MaxLegalOffset)
      return false;

    // If factoring it out still wouldn't help then don't bother.
    return New.Offset - Stores[0].Offset <= MaxLegalOffset;
  };

  auto resetState = [&]() {
    Stores.clear();
    LoadValsSinceLastStore.clear();
  };

  for (auto &MBB : MF) {
    // We're looking inside a single BB at a time since the memset pattern
    // should only be in a single block.
    resetState();
    for (auto &MI : MBB) {
      // Skip for scalable vectors
      if (auto *LdSt = dyn_cast<GLoadStore>(&MI);
          LdSt && MRI.getType(LdSt->getOperand(0).getReg()).isScalableVector())
        continue;

      if (auto *St = dyn_cast<GStore>(&MI)) {
        Register PtrBaseReg;
        APInt Offset;
        LLT StoredValTy = MRI.getType(St->getValueReg());
        unsigned ValSize = StoredValTy.getSizeInBits();
        if (ValSize < 32 || St->getMMO().getSizeInBits() != ValSize)
          continue;

        Register PtrReg = St->getPointerReg();
        if (mi_match(
                PtrReg, MRI,
                m_OneNonDBGUse(m_GPtrAdd(m_Reg(PtrBaseReg), m_ICst(Offset))))) {
          GPtrAdd *PtrAdd = cast<GPtrAdd>(MRI.getVRegDef(PtrReg));
          StoreInfo New = {St, PtrAdd, Offset.getSExtValue(), StoredValTy};

          if (Stores.empty()) {
            Stores.push_back(New);
            continue;
          }

          // Check if this store is a valid continuation of the sequence.
          auto &Last = Stores.back();
          if (storeIsValid(Last, New)) {
            Stores.push_back(New);
            LoadValsSinceLastStore.clear(); // Reset the load value tracking.
          } else {
            // The store isn't valid to consider for the prior sequence,
            // so try to optimize what we have so far and start a new sequence.
            Changed |= tryOptimizeConsecStores(Stores, MIB);
            resetState();
            Stores.push_back(New);
          }
        }
      } else if (auto *Ld = dyn_cast<GLoad>(&MI)) {
        LoadValsSinceLastStore.push_back(Ld->getDstReg());
      }
    }
    Changed |= tryOptimizeConsecStores(Stores, MIB);
    resetState();
  }

  return Changed;
}

char AArch64PostLegalizerCombiner::ID = 0;
INITIALIZE_PASS_BEGIN(AArch64PostLegalizerCombiner, DEBUG_TYPE,
                      "Combine AArch64 MachineInstrs after legalization", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(GISelKnownBitsAnalysis)
INITIALIZE_PASS_END(AArch64PostLegalizerCombiner, DEBUG_TYPE,
                    "Combine AArch64 MachineInstrs after legalization", false,
                    false)

namespace llvm {
FunctionPass *createAArch64PostLegalizerCombiner(bool IsOptNone) {
  return new AArch64PostLegalizerCombiner(IsOptNone);
}
} // end namespace llvm