//=== lib/CodeGen/GlobalISel/AArch64PreLegalizerCombiner.cpp --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass does combining of machine instructions at the generic MI level,
// before the legalizer.
//
//===----------------------------------------------------------------------===//

#include "AArch64GlobalISelUtils.h"
#include "AArch64TargetMachine.h"
#include "llvm/CodeGen/GlobalISel/CSEInfo.h"
#include "llvm/CodeGen/GlobalISel/Combiner.h"
#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
#include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
#include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h"
#include "llvm/CodeGen/GlobalISel/GISelValueTracking.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Instructions.h"

#define GET_GICOMBINER_DEPS
#include "AArch64GenPreLegalizeGICombiner.inc"
#undef GET_GICOMBINER_DEPS

#define DEBUG_TYPE "aarch64-prelegalizer-combiner"

using namespace llvm;
using namespace MIPatternMatch;

namespace {

#define GET_GICOMBINER_TYPES
#include "AArch64GenPreLegalizeGICombiner.inc"
#undef GET_GICOMBINER_TYPES

/// Return true if a G_FCONSTANT instruction is known to be better-represented
/// as a G_CONSTANT.
bool matchFConstantToConstant(MachineInstr &MI, MachineRegisterInfo &MRI) {
  assert(MI.getOpcode() == TargetOpcode::G_FCONSTANT);
  Register DstReg = MI.getOperand(0).getReg();
  const unsigned DstSize = MRI.getType(DstReg).getSizeInBits();
  if (DstSize != 32 && DstSize != 64)
    return false;

  // When we're storing a value, it doesn't matter what register bank it's on.
  // Since not all floating point constants can be materialized using a fmov,
  // it makes more sense to just use a GPR.
  return all_of(MRI.use_nodbg_instructions(DstReg),
                [](const MachineInstr &Use) { return Use.mayStore(); });
}

/// Change a G_FCONSTANT into a G_CONSTANT.
void applyFConstantToConstant(MachineInstr &MI) {
  assert(MI.getOpcode() == TargetOpcode::G_FCONSTANT);
  MachineIRBuilder MIB(MI);
  const APFloat &ImmValAPF = MI.getOperand(1).getFPImm()->getValueAPF();
  MIB.buildConstant(MI.getOperand(0).getReg(), ImmValAPF.bitcastToAPInt());
  MI.eraseFromParent();
}

/// Try to match a G_ICMP of a G_TRUNC with zero, in which the truncated bits
/// are sign bits. In this case, we can transform the G_ICMP to directly compare
/// the wide value with a zero.
bool matchICmpRedundantTrunc(MachineInstr &MI, MachineRegisterInfo &MRI,
                             GISelValueTracking *VT, Register &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_ICMP && VT);

  auto Pred = (CmpInst::Predicate)MI.getOperand(1).getPredicate();
  if (!ICmpInst::isEquality(Pred))
    return false;

  Register LHS = MI.getOperand(2).getReg();
  LLT LHSTy = MRI.getType(LHS);
  if (!LHSTy.isScalar())
    return false;

  Register RHS = MI.getOperand(3).getReg();
  Register WideReg;

  if (!mi_match(LHS, MRI, m_GTrunc(m_Reg(WideReg))) ||
      !mi_match(RHS, MRI, m_SpecificICst(0)))
    return false;

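  // The truncate is only redundant if every truncated bit is a copy of the
  // sign bit, i.e. the wide value has more sign bits than the number of bits
  // removed by the truncate.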
  LLT WideTy = MRI.getType(WideReg);
  if (VT->computeNumSignBits(WideReg) <=
      WideTy.getSizeInBits() - LHSTy.getSizeInBits())
    return false;

  MatchInfo = WideReg;
  return true;
}

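/// Replace the truncated LHS of the matched G_ICMP with the wide register and
/// compare it against a zero of the wide type.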
void applyICmpRedundantTrunc(MachineInstr &MI, MachineRegisterInfo &MRI,
                             MachineIRBuilder &Builder,
                             GISelChangeObserver &Observer, Register &WideReg) {
  assert(MI.getOpcode() == TargetOpcode::G_ICMP);

  LLT WideTy = MRI.getType(WideReg);
  // We're going to directly use the wide register as the LHS, and then use a
  // zero of the same width for the RHS.
  Builder.setInstrAndDebugLoc(MI);
  auto WideZero = Builder.buildConstant(WideTy, 0);
  Observer.changingInstr(MI);
  MI.getOperand(2).setReg(WideReg);
  MI.getOperand(3).setReg(WideZero.getReg(0));
  Observer.changedInstr(MI);
}

/// \returns true if it is possible to fold a constant into a G_GLOBAL_VALUE.
///
/// e.g.
///
/// %g = G_GLOBAL_VALUE @x -> %g = G_GLOBAL_VALUE @x + cst
bool matchFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
                           std::pair<uint64_t, uint64_t> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
  MachineFunction &MF = *MI.getMF();
  auto &GlobalOp = MI.getOperand(1);
  auto *GV = GlobalOp.getGlobal();
  if (GV->isThreadLocal())
    return false;

  // Don't allow anything that could represent offsets etc.
  if (MF.getSubtarget<AArch64Subtarget>().ClassifyGlobalReference(
          GV, MF.getTarget()) != AArch64II::MO_NO_FLAG)
    return false;

  // Look for a G_GLOBAL_VALUE only used by G_PTR_ADDs against constants:
  //
  //  %g = G_GLOBAL_VALUE @x
  //  %ptr1 = G_PTR_ADD %g, cst1
  //  %ptr2 = G_PTR_ADD %g, cst2
  //  ...
  //  %ptrN = G_PTR_ADD %g, cstN
  //
  // Identify the *smallest* constant. We want to be able to form this:
  //
  //  %offset_g = G_GLOBAL_VALUE @x + min_cst
  //  %g = G_PTR_ADD %offset_g, -min_cst
  //  %ptr1 = G_PTR_ADD %g, cst1
  //  ...
  Register Dst = MI.getOperand(0).getReg();
  uint64_t MinOffset = -1ull;
  for (auto &UseInstr : MRI.use_nodbg_instructions(Dst)) {
    if (UseInstr.getOpcode() != TargetOpcode::G_PTR_ADD)
      return false;
    auto Cst = getIConstantVRegValWithLookThrough(
        UseInstr.getOperand(2).getReg(), MRI);
    if (!Cst)
      return false;
    MinOffset = std::min(MinOffset, Cst->Value.getZExtValue());
  }

  // Require that the new offset is larger than the existing one to avoid
  // infinite loops.
  uint64_t CurrOffset = GlobalOp.getOffset();
  uint64_t NewOffset = MinOffset + CurrOffset;
  if (NewOffset <= CurrOffset)
    return false;

  // Check whether folding this offset is legal. It must not go out of bounds of
  // the referenced object to avoid violating the code model, and must be
  // smaller than 2^20 because this is the largest offset expressible in all
  // object formats. (The IMAGE_REL_ARM64_PAGEBASE_REL21 relocation in COFF
  // stores an immediate signed 21 bit offset.)
  //
  // This check also prevents us from folding negative offsets, which will end
  // up being treated in the same way as large positive ones. They could also
  // cause code model violations, and aren't really common enough to matter.
  if (NewOffset >= (1 << 20))
    return false;

  Type *T = GV->getValueType();
  if (!T->isSized() ||
      NewOffset > GV->getDataLayout().getTypeAllocSize(T))
    return false;
  MatchInfo = std::make_pair(NewOffset, MinOffset);
  return true;
}

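/// Fold the smallest G_PTR_ADD constant (MinOffset) into the G_GLOBAL_VALUE
/// and recreate the original pointer with a compensating G_PTR_ADD of
/// -MinOffset.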
void applyFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
                           MachineIRBuilder &B, GISelChangeObserver &Observer,
                           std::pair<uint64_t, uint64_t> &MatchInfo) {
  // Change:
  //
  //  %g = G_GLOBAL_VALUE @x
  //  %ptr1 = G_PTR_ADD %g, cst1
  //  %ptr2 = G_PTR_ADD %g, cst2
  //  ...
  //  %ptrN = G_PTR_ADD %g, cstN
  //
  // To:
  //
  //  %offset_g = G_GLOBAL_VALUE @x + min_cst
  //  %g = G_PTR_ADD %offset_g, -min_cst
  //  %ptr1 = G_PTR_ADD %g, cst1
  //  ...
  //  %ptrN = G_PTR_ADD %g, cstN
  //
  // Then, the original G_PTR_ADDs should be folded later on so that they look
  // like this:
  //
  //  %ptrN = G_PTR_ADD %offset_g, cstN - min_cst
  uint64_t Offset, MinOffset;
  std::tie(Offset, MinOffset) = MatchInfo;
  B.setInstrAndDebugLoc(*std::next(MI.getIterator()));
  Observer.changingInstr(MI);
  auto &GlobalOp = MI.getOperand(1);
  auto *GV = GlobalOp.getGlobal();
  GlobalOp.ChangeToGA(GV, Offset, GlobalOp.getTargetFlags());
  Register Dst = MI.getOperand(0).getReg();
  Register NewGVDst = MRI.cloneVirtualRegister(Dst);
  MI.getOperand(0).setReg(NewGVDst);
  Observer.changedInstr(MI);
  B.buildPtrAdd(
      Dst, NewGVDst,
      B.buildConstant(LLT::scalar(64), -static_cast<int64_t>(MinOffset)));
}

// Combines vecreduce_add(mul(ext(x), ext(y))) -> vecreduce_add(udot(x, y))
// Or vecreduce_add(ext(x)) -> vecreduce_add(udot(x, 1))
// Similar to performVecReduceAddCombine in SelectionDAG
bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
                            const AArch64Subtarget &STI,
                            std::tuple<Register, Register, bool> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
         "Expected a G_VECREDUCE_ADD instruction");
  assert(STI.hasDotProd() && "Target should have Dot Product feature");

  MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
  Register DstReg = MI.getOperand(0).getReg();
  Register MidReg = I1->getOperand(0).getReg();
  LLT DstTy = MRI.getType(DstReg);
  LLT MidTy = MRI.getType(MidReg);
  if (DstTy.getScalarSizeInBits() != 32 || MidTy.getScalarSizeInBits() != 32)
    return false;

  LLT SrcTy;
  auto I1Opc = I1->getOpcode();
  if (I1Opc == TargetOpcode::G_MUL) {
    // If the result of the multiply has more than one use, there is no point
    // in creating a udot instruction.
    if (!MRI.hasOneNonDBGUse(MidReg))
      return false;

    MachineInstr *ExtMI1 =
        getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI);
    MachineInstr *ExtMI2 =
        getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI);
    LLT Ext1DstTy = MRI.getType(ExtMI1->getOperand(0).getReg());
    LLT Ext2DstTy = MRI.getType(ExtMI2->getOperand(0).getReg());

    if (ExtMI1->getOpcode() != ExtMI2->getOpcode() || Ext1DstTy != Ext2DstTy)
      return false;
    I1Opc = ExtMI1->getOpcode();
    SrcTy = MRI.getType(ExtMI1->getOperand(1).getReg());
    std::get<0>(MatchInfo) = ExtMI1->getOperand(1).getReg();
    std::get<1>(MatchInfo) = ExtMI2->getOperand(1).getReg();
  } else if (I1Opc == TargetOpcode::G_ZEXT || I1Opc == TargetOpcode::G_SEXT) {
    SrcTy = MRI.getType(I1->getOperand(1).getReg());
    std::get<0>(MatchInfo) = I1->getOperand(1).getReg();
    std::get<1>(MatchInfo) = 0;
  } else {
    return false;
  }

  if (I1Opc == TargetOpcode::G_ZEXT)
    std::get<2>(MatchInfo) = 0;
  else if (I1Opc == TargetOpcode::G_SEXT)
    std::get<2>(MatchInfo) = 1;
  else
    return false;

  if (SrcTy.getScalarSizeInBits() != 8 || SrcTy.getNumElements() % 8 != 0)
    return false;

  return true;
}

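// Lower the matched pattern to G_UDOT/G_SDOT instructions accumulating into a
// zero vector, splitting the sources into 16-element chunks when necessary,
// and feed the result into a new G_VECREDUCE_ADD.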
void applyExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
                            MachineIRBuilder &Builder,
                            GISelChangeObserver &Observer,
                            const AArch64Subtarget &STI,
                            std::tuple<Register, Register, bool> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
         "Expected a G_VECREDUCE_ADD instruction");
  assert(STI.hasDotProd() && "Target should have Dot Product feature");

  // Initialise the variables
  unsigned DotOpcode =
      std::get<2>(MatchInfo) ? AArch64::G_SDOT : AArch64::G_UDOT;
  Register Ext1SrcReg = std::get<0>(MatchInfo);

  // If there is only one source register, create a vector of 1s as the second
  // source register so that the dot product computes a plain sum.
  Register Ext2SrcReg;
  if (std::get<1>(MatchInfo) == 0)
    Ext2SrcReg = Builder.buildConstant(MRI.getType(Ext1SrcReg), 1)
                     ->getOperand(0)
                     .getReg();
  else
    Ext2SrcReg = std::get<1>(MatchInfo);

  // Find out how many DOT instructions are needed
  LLT SrcTy = MRI.getType(Ext1SrcReg);
  LLT MidTy;
  unsigned NumOfDotMI;
  if (SrcTy.getNumElements() % 16 == 0) {
    NumOfDotMI = SrcTy.getNumElements() / 16;
    MidTy = LLT::fixed_vector(4, 32);
  } else if (SrcTy.getNumElements() % 8 == 0) {
    NumOfDotMI = SrcTy.getNumElements() / 8;
    MidTy = LLT::fixed_vector(2, 32);
  } else {
    llvm_unreachable("Source type number of elements is not multiple of 8");
  }

  // Handle the case where one DOT instruction is needed
  if (NumOfDotMI == 1) {
    auto Zeroes = Builder.buildConstant(MidTy, 0)->getOperand(0).getReg();
    auto Dot = Builder.buildInstr(DotOpcode, {MidTy},
                                  {Zeroes, Ext1SrcReg, Ext2SrcReg});
    Builder.buildVecReduceAdd(MI.getOperand(0), Dot->getOperand(0));
  } else {
    // If the element count is not a multiple of 16, pad the trailing v8i8 part
    // with 0s up to a v16i8
    SmallVector<Register, 4> Ext1UnmergeReg;
    SmallVector<Register, 4> Ext2UnmergeReg;
    if (SrcTy.getNumElements() % 16 != 0) {
      SmallVector<Register> Leftover1;
      SmallVector<Register> Leftover2;

      // Split the elements into v16i8 and v8i8
      LLT MainTy = LLT::fixed_vector(16, 8);
      LLT LeftoverTy1, LeftoverTy2;
      if ((!extractParts(Ext1SrcReg, MRI.getType(Ext1SrcReg), MainTy,
                         LeftoverTy1, Ext1UnmergeReg, Leftover1, Builder,
                         MRI)) ||
          (!extractParts(Ext2SrcReg, MRI.getType(Ext2SrcReg), MainTy,
                         LeftoverTy2, Ext2UnmergeReg, Leftover2, Builder,
                         MRI))) {
        llvm_unreachable("Unable to split this vector properly");
      }

      // Pad the leftover v8i8 vector with a register of 0s of type v8i8
      Register v8Zeroes = Builder.buildConstant(LLT::fixed_vector(8, 8), 0)
                              ->getOperand(0)
                              .getReg();

      Ext1UnmergeReg.push_back(
          Builder
              .buildMergeLikeInstr(LLT::fixed_vector(16, 8),
                                   {Leftover1[0], v8Zeroes})
              .getReg(0));
      Ext2UnmergeReg.push_back(
          Builder
              .buildMergeLikeInstr(LLT::fixed_vector(16, 8),
                                   {Leftover2[0], v8Zeroes})
              .getReg(0));

    } else {
      // Unmerge the source vectors to v16i8
      unsigned SrcNumElts = SrcTy.getNumElements();
      extractParts(Ext1SrcReg, LLT::fixed_vector(16, 8), SrcNumElts / 16,
                   Ext1UnmergeReg, Builder, MRI);
      extractParts(Ext2SrcReg, LLT::fixed_vector(16, 8), SrcNumElts / 16,
                   Ext2UnmergeReg, Builder, MRI);
    }

    // Build the UDOT instructions
    SmallVector<Register, 2> DotReg;
    unsigned NumElements = 0;
    for (unsigned i = 0; i < Ext1UnmergeReg.size(); i++) {
      LLT ZeroesLLT;
      // Check whether this part has 16 or 8 elements, and set Zeroes to the
      // corresponding size
      if (MRI.getType(Ext1UnmergeReg[i]).getNumElements() == 16) {
        ZeroesLLT = LLT::fixed_vector(4, 32);
        NumElements += 4;
      } else {
        ZeroesLLT = LLT::fixed_vector(2, 32);
        NumElements += 2;
      }
      auto Zeroes = Builder.buildConstant(ZeroesLLT, 0)->getOperand(0).getReg();
      DotReg.push_back(
          Builder
              .buildInstr(DotOpcode, {MRI.getType(Zeroes)},
                          {Zeroes, Ext1UnmergeReg[i], Ext2UnmergeReg[i]})
              .getReg(0));
    }

    // Merge the output
    auto ConcatMI =
        Builder.buildConcatVectors(LLT::fixed_vector(NumElements, 32), DotReg);

    // Put it through a vector reduction
    Builder.buildVecReduceAdd(MI.getOperand(0).getReg(),
                              ConcatMI->getOperand(0).getReg());
  }

  // Erase the dead instructions
  MI.eraseFromParent();
}

// Matches {U/S}ADDV(ext(x)) => {U/S}ADDLV(x)
// Ensure that the type coming from the extend instruction is the right size
bool matchExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
                           std::pair<Register, bool> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
         "Expected G_VECREDUCE_ADD Opcode");

  // Check if the last instruction is an extend
  MachineInstr *ExtMI = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
  auto ExtOpc = ExtMI->getOpcode();

  if (ExtOpc == TargetOpcode::G_ZEXT)
    std::get<1>(MatchInfo) = 0;
  else if (ExtOpc == TargetOpcode::G_SEXT)
    std::get<1>(MatchInfo) = 1;
  else
    return false;

  // Check if the source register is a valid type
  Register ExtSrcReg = ExtMI->getOperand(1).getReg();
  LLT ExtSrcTy = MRI.getType(ExtSrcReg);
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
  if ((DstTy.getScalarSizeInBits() == 16 &&
       ExtSrcTy.getNumElements() % 8 == 0 && ExtSrcTy.getNumElements() < 256) ||
      (DstTy.getScalarSizeInBits() == 32 &&
       ExtSrcTy.getNumElements() % 4 == 0) ||
      (DstTy.getScalarSizeInBits() == 64 &&
       ExtSrcTy.getNumElements() % 4 == 0)) {
    std::get<0>(MatchInfo) = ExtSrcReg;
    return true;
  }
  return false;
}

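// Lower the matched pattern to one or more {U/S}ADDLV instructions, summing
// the partial results if the source vector had to be split.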
void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
                           MachineIRBuilder &B, GISelChangeObserver &Observer,
                           std::pair<Register, bool> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
         "Expected G_VECREDUCE_ADD Opcode");

  unsigned Opc = std::get<1>(MatchInfo) ? AArch64::G_SADDLV : AArch64::G_UADDLV;
  Register SrcReg = std::get<0>(MatchInfo);
  Register DstReg = MI.getOperand(0).getReg();
  LLT SrcTy = MRI.getType(SrcReg);
  LLT DstTy = MRI.getType(DstReg);

  // If SrcTy has more elements than expected, split them into multiple
  // instructions and sum the results
  LLT MainTy;
  SmallVector<Register, 1> WorkingRegisters;
  unsigned SrcScalSize = SrcTy.getScalarSizeInBits();
  unsigned SrcNumElem = SrcTy.getNumElements();
  if ((SrcScalSize == 8 && SrcNumElem > 16) ||
      (SrcScalSize == 16 && SrcNumElem > 8) ||
      (SrcScalSize == 32 && SrcNumElem > 4)) {

    LLT LeftoverTy;
    SmallVector<Register, 4> LeftoverRegs;
    if (SrcScalSize == 8)
      MainTy = LLT::fixed_vector(16, 8);
    else if (SrcScalSize == 16)
      MainTy = LLT::fixed_vector(8, 16);
    else if (SrcScalSize == 32)
      MainTy = LLT::fixed_vector(4, 32);
    else
      llvm_unreachable("Source's Scalar Size not supported");

    // Extract the parts, put each extracted source through U/SADDLV, and
    // collect the values in a small vector
    extractParts(SrcReg, SrcTy, MainTy, LeftoverTy, WorkingRegisters,
                 LeftoverRegs, B, MRI);
    llvm::append_range(WorkingRegisters, LeftoverRegs);
  } else {
    WorkingRegisters.push_back(SrcReg);
    MainTy = SrcTy;
  }

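  // Each {U/S}ADDLV result is twice as wide as the source scalar elements.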
  unsigned MidScalarSize = MainTy.getScalarSizeInBits() * 2;
  LLT MidScalarLLT = LLT::scalar(MidScalarSize);
  Register zeroReg = B.buildConstant(LLT::scalar(64), 0).getReg(0);
  for (unsigned I = 0; I < WorkingRegisters.size(); I++) {
    // If the number of elements is too small to build an instruction, extend
    // its size before applying addlv
    LLT WorkingRegTy = MRI.getType(WorkingRegisters[I]);
    if ((WorkingRegTy.getScalarSizeInBits() == 8) &&
        (WorkingRegTy.getNumElements() == 4)) {
      WorkingRegisters[I] =
          B.buildInstr(std::get<1>(MatchInfo) ? TargetOpcode::G_SEXT
                                              : TargetOpcode::G_ZEXT,
                       {LLT::fixed_vector(4, 16)}, {WorkingRegisters[I]})
              .getReg(0);
    }

    // Generate the {U/S}ADDLV instruction, whose output is always twice the
    // size of the source's scalar type
    LLT addlvTy = MidScalarSize <= 32 ? LLT::fixed_vector(4, 32)
                                      : LLT::fixed_vector(2, 64);
    Register addlvReg =
        B.buildInstr(Opc, {addlvTy}, {WorkingRegisters[I]}).getReg(0);

    // The output from {U/S}ADDLV gets placed in the lowest lane of a v4i32 or
    // v2i64 register:
    //   i16 and i32 results use v4i32 registers
    //   i64 results use v2i64 registers
    // Therefore we have to extract/truncate the value to the right type
    if (MidScalarSize == 32 || MidScalarSize == 64) {
      WorkingRegisters[I] = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT,
                                         {MidScalarLLT}, {addlvReg, zeroReg})
                                .getReg(0);
    } else {
      Register extractReg = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT,
                                         {LLT::scalar(32)}, {addlvReg, zeroReg})
                                .getReg(0);
      WorkingRegisters[I] =
          B.buildTrunc({MidScalarLLT}, {extractReg}).getReg(0);
    }
  }

  Register outReg;
  if (WorkingRegisters.size() > 1) {
    outReg = B.buildAdd(MidScalarLLT, WorkingRegisters[0], WorkingRegisters[1])
                 .getReg(0);
    for (unsigned I = 2; I < WorkingRegisters.size(); I++) {
      outReg = B.buildAdd(MidScalarLLT, outReg, WorkingRegisters[I]).getReg(0);
    }
  } else {
    outReg = WorkingRegisters[0];
  }

  if (DstTy.getScalarSizeInBits() > MidScalarSize) {
    // Extend the scalar value if the DstTy's scalar size is more than double
    // the source's scalar type
    B.buildInstr(std::get<1>(MatchInfo) ? TargetOpcode::G_SEXT
                                        : TargetOpcode::G_ZEXT,
                 {DstReg}, {outReg});
  } else {
    B.buildCopy(DstReg, outReg);
  }

  MI.eraseFromParent();
}

// Pushes ADD/SUB through extend instructions to decrease the number of extend
// instructions at the end by allowing selection of {s|u}addl sooner

// i32 add(i32 ext i8, i32 ext i8) => i32 ext(i16 add(i16 ext i8, i16 ext i8))
bool matchPushAddSubExt(MachineInstr &MI, MachineRegisterInfo &MRI,
                        Register DstReg, Register SrcReg1, Register SrcReg2) {
  assert((MI.getOpcode() == TargetOpcode::G_ADD ||
          MI.getOpcode() == TargetOpcode::G_SUB) &&
         "Expected a G_ADD or G_SUB instruction\n");

  // Deal with vector types only
  LLT DstTy = MRI.getType(DstReg);
  if (!DstTy.isVector())
    return false;

  // Only combine if the G_{S|Z}EXT widens by more than 2x, i.e. the
  // destination scalar is more than twice the size of the source scalar, and
  // both sources have the same type
  Register ExtDstReg = MI.getOperand(1).getReg();
  LLT Ext1SrcTy = MRI.getType(SrcReg1);
  LLT Ext2SrcTy = MRI.getType(SrcReg2);
  unsigned ExtDstScal = MRI.getType(ExtDstReg).getScalarSizeInBits();
  unsigned Ext1SrcScal = Ext1SrcTy.getScalarSizeInBits();
  if (((Ext1SrcScal == 8 && ExtDstScal == 32) ||
       ((Ext1SrcScal == 8 || Ext1SrcScal == 16) && ExtDstScal == 64)) &&
      Ext1SrcTy == Ext2SrcTy)
    return true;

  return false;
}

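// Apply the push: extend both sources to double width, perform the G_ADD or
// G_SUB at that width, then extend the result to the original destination
// type.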
void applyPushAddSubExt(MachineInstr &MI, MachineRegisterInfo &MRI,
                        MachineIRBuilder &B, bool isSExt, Register DstReg,
                        Register SrcReg1, Register SrcReg2) {
  LLT SrcTy = MRI.getType(SrcReg1);
  LLT MidTy = SrcTy.changeElementSize(SrcTy.getScalarSizeInBits() * 2);
  unsigned Opc = isSExt ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT;
  Register Ext1Reg = B.buildInstr(Opc, {MidTy}, {SrcReg1}).getReg(0);
  Register Ext2Reg = B.buildInstr(Opc, {MidTy}, {SrcReg2}).getReg(0);
  Register AddReg =
      B.buildInstr(MI.getOpcode(), {MidTy}, {Ext1Reg, Ext2Reg}).getReg(0);

  // G_SUB has to sign-extend the result.
  // G_ADD needs to sext from sext and can sext or zext from zext, so the
  // original opcode is used.
  if (MI.getOpcode() == TargetOpcode::G_ADD)
    B.buildInstr(Opc, {DstReg}, {AddReg});
  else
    B.buildSExt(DstReg, AddReg);

  MI.eraseFromParent();
}

bool tryToSimplifyUADDO(MachineInstr &MI, MachineIRBuilder &B,
                        const CombinerHelper &Helper,
                        GISelChangeObserver &Observer) {
  // Try to simplify a G_UADDO with 8 or 16 bit operands to a wide G_ADD and a
  // TBNZ if the result is only used in the no-overflow case. It is restricted
  // to cases where we know that the high-bits of the operands are 0. If
  // there's an overflow, then the 9th or 17th bit must be set, which can be
  // checked using TBNZ.
  //
  // Change (for UADDOs on 8 and 16 bits):
  //
  //   %z0 = G_ASSERT_ZEXT _
  //   %op0 = G_TRUNC %z0
  //   %z1 = G_ASSERT_ZEXT _
  //   %op1 = G_TRUNC %z1
  //   %val, %cond = G_UADDO %op0, %op1
  //   G_BRCOND %cond, %error.bb
  //
  // error.bb:
  //   (no successors and no uses of %val)
  //
  // To:
  //
  //   %z0 = G_ASSERT_ZEXT _
  //   %z1 = G_ASSERT_ZEXT _
  //   %add = G_ADD %z0, %z1
  //   %val = G_TRUNC %add
  //   %bit = G_AND %add, 1 << scalar-size-in-bits(%op1)
  //   %cond = G_ICMP NE, %bit, 0
  //   G_BRCOND %cond, %error.bb

  auto &MRI = *B.getMRI();

  MachineOperand *DefOp0 = MRI.getOneDef(MI.getOperand(2).getReg());
  MachineOperand *DefOp1 = MRI.getOneDef(MI.getOperand(3).getReg());
  Register Op0Wide;
  Register Op1Wide;
  if (!mi_match(DefOp0->getParent(), MRI, m_GTrunc(m_Reg(Op0Wide))) ||
      !mi_match(DefOp1->getParent(), MRI, m_GTrunc(m_Reg(Op1Wide))))
    return false;
  LLT WideTy0 = MRI.getType(Op0Wide);
  LLT WideTy1 = MRI.getType(Op1Wide);
  Register ResVal = MI.getOperand(0).getReg();
  LLT OpTy = MRI.getType(ResVal);
  MachineInstr *Op0WideDef = MRI.getVRegDef(Op0Wide);
  MachineInstr *Op1WideDef = MRI.getVRegDef(Op1Wide);

  unsigned OpTySize = OpTy.getScalarSizeInBits();
  // First check that the G_TRUNCs feeding the G_UADDO are no-ops, because the
  // inputs have been zero-extended.
  if (Op0WideDef->getOpcode() != TargetOpcode::G_ASSERT_ZEXT ||
      Op1WideDef->getOpcode() != TargetOpcode::G_ASSERT_ZEXT ||
      OpTySize != Op0WideDef->getOperand(2).getImm() ||
      OpTySize != Op1WideDef->getOperand(2).getImm())
    return false;

  // Only scalar UADDO with either 8 or 16 bit operands are handled.
  if (!WideTy0.isScalar() || !WideTy1.isScalar() || WideTy0 != WideTy1 ||
      OpTySize >= WideTy0.getScalarSizeInBits() ||
      (OpTySize != 8 && OpTySize != 16))
    return false;

  // The overflow-status result must be used by a branch only.
  Register ResStatus = MI.getOperand(1).getReg();
  if (!MRI.hasOneNonDBGUse(ResStatus))
    return false;
  MachineInstr *CondUser = &*MRI.use_instr_nodbg_begin(ResStatus);
  if (CondUser->getOpcode() != TargetOpcode::G_BRCOND)
    return false;

  // Make sure the computed result is only used in the no-overflow blocks.
  MachineBasicBlock *CurrentMBB = MI.getParent();
  MachineBasicBlock *FailMBB = CondUser->getOperand(1).getMBB();
  if (!FailMBB->succ_empty() || CondUser->getParent() != CurrentMBB)
    return false;
  if (any_of(MRI.use_nodbg_instructions(ResVal),
             [&MI, FailMBB, CurrentMBB](MachineInstr &I) {
               return &MI != &I &&
                      (I.getParent() == FailMBB || I.getParent() == CurrentMBB);
             }))
    return false;

  // Remove the G_UADDO.
  B.setInstrAndDebugLoc(*MI.getNextNode());
  MI.eraseFromParent();

  // Emit the wide add.
  Register AddDst = MRI.cloneVirtualRegister(Op0Wide);
  B.buildInstr(TargetOpcode::G_ADD, {AddDst}, {Op0Wide, Op1Wide});

  // Emit the check of the 9th or 17th bit and update users (the branch). This
  // will later be folded to TBNZ.
  Register CondBit = MRI.cloneVirtualRegister(Op0Wide);
  B.buildAnd(
      CondBit, AddDst,
      B.buildConstant(LLT::scalar(32), OpTySize == 8 ? 1 << 8 : 1 << 16));
  B.buildICmp(CmpInst::ICMP_NE, ResStatus, CondBit,
              B.buildConstant(LLT::scalar(32), 0));

  // Update ZExt users of the result value. Because all uses are in the
  // no-overflow case, we know that the top bits are 0 and we can ignore ZExts.
  B.buildZExtOrTrunc(ResVal, AddDst);
  for (MachineOperand &U : make_early_inc_range(MRI.use_operands(ResVal))) {
    Register WideReg;
    if (mi_match(U.getParent(), MRI, m_GZExt(m_Reg(WideReg)))) {
      auto OldR = U.getParent()->getOperand(0).getReg();
      Observer.erasingInstr(*U.getParent());
      U.getParent()->eraseFromParent();
      Helper.replaceRegWith(MRI, OldR, AddDst);
    }
  }

  return true;
}

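// Combiner driver: runs the TableGen-erated prelegalizer combine rules and the
// hand-written combines dispatched from tryCombineAll() below.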
class AArch64PreLegalizerCombinerImpl : public Combiner {
protected:
  const CombinerHelper Helper;
  const AArch64PreLegalizerCombinerImplRuleConfig &RuleConfig;
  const AArch64Subtarget &STI;

public:
  AArch64PreLegalizerCombinerImpl(
      MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
      GISelValueTracking &VT, GISelCSEInfo *CSEInfo,
      const AArch64PreLegalizerCombinerImplRuleConfig &RuleConfig,
      const AArch64Subtarget &STI, MachineDominatorTree *MDT,
      const LegalizerInfo *LI);

  static const char *getName() { return "AArch64PreLegalizerCombiner"; }

  bool tryCombineAll(MachineInstr &I) const override;

  bool tryCombineAllImpl(MachineInstr &I) const;

private:
#define GET_GICOMBINER_CLASS_MEMBERS
#include "AArch64GenPreLegalizeGICombiner.inc"
#undef GET_GICOMBINER_CLASS_MEMBERS
};

#define GET_GICOMBINER_IMPL
#include "AArch64GenPreLegalizeGICombiner.inc"
#undef GET_GICOMBINER_IMPL

AArch64PreLegalizerCombinerImpl::AArch64PreLegalizerCombinerImpl(
    MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
    GISelValueTracking &VT, GISelCSEInfo *CSEInfo,
    const AArch64PreLegalizerCombinerImplRuleConfig &RuleConfig,
    const AArch64Subtarget &STI, MachineDominatorTree *MDT,
    const LegalizerInfo *LI)
    : Combiner(MF, CInfo, TPC, &VT, CSEInfo),
      Helper(Observer, B, /*IsPreLegalize*/ true, &VT, MDT, LI),
      RuleConfig(RuleConfig), STI(STI),
#define GET_GICOMBINER_CONSTRUCTOR_INITS
#include "AArch64GenPreLegalizeGICombiner.inc"
#undef GET_GICOMBINER_CONSTRUCTOR_INITS
{
}

bool AArch64PreLegalizerCombinerImpl::tryCombineAll(MachineInstr &MI) const {
  if (tryCombineAllImpl(MI))
    return true;

  unsigned Opc = MI.getOpcode();
  switch (Opc) {
  case TargetOpcode::G_SHUFFLE_VECTOR:
    return Helper.tryCombineShuffleVector(MI);
  case TargetOpcode::G_UADDO:
    return tryToSimplifyUADDO(MI, B, Helper, Observer);
  case TargetOpcode::G_MEMCPY_INLINE:
    return Helper.tryEmitMemcpyInline(MI);
  case TargetOpcode::G_MEMCPY:
  case TargetOpcode::G_MEMMOVE:
  case TargetOpcode::G_MEMSET: {
    // If we're at -O0 set a maxlen of 32 to inline, otherwise let the other
    // heuristics decide.
    unsigned MaxLen = CInfo.EnableOpt ? 0 : 32;
    // Try to inline memcpy type calls if optimizations are enabled.
    if (Helper.tryCombineMemCpyFamily(MI, MaxLen))
      return true;
    if (Opc == TargetOpcode::G_MEMSET)
      return llvm::AArch64GISelUtils::tryEmitBZero(MI, B, CInfo.EnableMinSize);
    return false;
  }
  }

  return false;
}

// Pass boilerplate
// ================

class AArch64PreLegalizerCombiner : public MachineFunctionPass {
public:
  static char ID;

  AArch64PreLegalizerCombiner();

  StringRef getPassName() const override {
    return "AArch64PreLegalizerCombiner";
  }

  bool runOnMachineFunction(MachineFunction &MF) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override;

private:
  AArch64PreLegalizerCombinerImplRuleConfig RuleConfig;
};
} // end anonymous namespace

void AArch64PreLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
  AU.setPreservesCFG();
  getSelectionDAGFallbackAnalysisUsage(AU);
  AU.addRequired<GISelValueTrackingAnalysisLegacy>();
  AU.addPreserved<GISelValueTrackingAnalysisLegacy>();
  AU.addRequired<MachineDominatorTreeWrapperPass>();
  AU.addPreserved<MachineDominatorTreeWrapperPass>();
  AU.addRequired<GISelCSEAnalysisWrapperPass>();
  AU.addPreserved<GISelCSEAnalysisWrapperPass>();
  MachineFunctionPass::getAnalysisUsage(AU);
}

AArch64PreLegalizerCombiner::AArch64PreLegalizerCombiner()
    : MachineFunctionPass(ID) {
  if (!RuleConfig.parseCommandLineOption())
    report_fatal_error("Invalid rule identifier");
}

bool AArch64PreLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) {
  if (MF.getProperties().hasFailedISel())
    return false;
  auto &TPC = getAnalysis<TargetPassConfig>();

  // Enable CSE.
  GISelCSEAnalysisWrapper &Wrapper =
      getAnalysis<GISelCSEAnalysisWrapperPass>().getCSEWrapper();
  auto *CSEInfo = &Wrapper.get(TPC.getCSEConfig());

  const AArch64Subtarget &ST = MF.getSubtarget<AArch64Subtarget>();
  const auto *LI = ST.getLegalizerInfo();

  const Function &F = MF.getFunction();
  bool EnableOpt =
      MF.getTarget().getOptLevel() != CodeGenOptLevel::None && !skipFunction(F);
  GISelValueTracking *VT =
      &getAnalysis<GISelValueTrackingAnalysisLegacy>().get(MF);
  MachineDominatorTree *MDT =
      &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
  CombinerInfo CInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false,
                     /*LegalizerInfo*/ nullptr, EnableOpt, F.hasOptSize(),
                     F.hasMinSize());
  // Disable fixed-point iteration to reduce compile-time
  CInfo.MaxIterations = 1;
  CInfo.ObserverLvl = CombinerInfo::ObserverLevel::SinglePass;
  // This is the first Combiner, so the input IR might contain dead
  // instructions.
  CInfo.EnableFullDCE = true;
  AArch64PreLegalizerCombinerImpl Impl(MF, CInfo, &TPC, *VT, CSEInfo,
                                       RuleConfig, ST, MDT, LI);
  return Impl.combineMachineInstrs();
}

char AArch64PreLegalizerCombiner::ID = 0;
INITIALIZE_PASS_BEGIN(AArch64PreLegalizerCombiner, DEBUG_TYPE,
                      "Combine AArch64 machine instrs before legalization",
                      false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(GISelValueTrackingAnalysisLegacy)
INITIALIZE_PASS_DEPENDENCY(GISelCSEAnalysisWrapperPass)
INITIALIZE_PASS_END(AArch64PreLegalizerCombiner, DEBUG_TYPE,
                    "Combine AArch64 machine instrs before legalization", false,
                    false)

namespace llvm {
FunctionPass *createAArch64PreLegalizerCombiner() {
  return new AArch64PreLegalizerCombiner();
}
} // end namespace llvm