1//=== lib/CodeGen/GlobalISel/AArch64PreLegalizerCombiner.cpp --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass does combining of machine instructions at the generic MI level,
10// before the legalizer.
11//
12//===----------------------------------------------------------------------===//
13
14#include "AArch64GlobalISelUtils.h"
15#include "AArch64TargetMachine.h"
16#include "llvm/CodeGen/GlobalISel/CSEInfo.h"
17#include "llvm/CodeGen/GlobalISel/Combiner.h"
18#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
19#include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
20#include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h"
21#include "llvm/CodeGen/GlobalISel/GISelValueTracking.h"
22#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
23#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
24#include "llvm/CodeGen/GlobalISel/Utils.h"
25#include "llvm/CodeGen/MachineDominators.h"
26#include "llvm/CodeGen/MachineFunction.h"
27#include "llvm/CodeGen/MachineFunctionPass.h"
28#include "llvm/CodeGen/MachineRegisterInfo.h"
29#include "llvm/IR/Instructions.h"
30
31#define GET_GICOMBINER_DEPS
32#include "AArch64GenPreLegalizeGICombiner.inc"
33#undef GET_GICOMBINER_DEPS
34
35#define DEBUG_TYPE "aarch64-prelegalizer-combiner"
36
37using namespace llvm;
38using namespace MIPatternMatch;
39
40namespace {
41
42#define GET_GICOMBINER_TYPES
43#include "AArch64GenPreLegalizeGICombiner.inc"
44#undef GET_GICOMBINER_TYPES
45
46/// Try to match a G_ICMP of a G_TRUNC with zero, in which the truncated bits
47/// are sign bits. In this case, we can transform the G_ICMP to directly compare
48/// the wide value with a zero.
49bool matchICmpRedundantTrunc(MachineInstr &MI, MachineRegisterInfo &MRI,
50 GISelValueTracking *VT, Register &MatchInfo) {
51 assert(MI.getOpcode() == TargetOpcode::G_ICMP && VT);
52
53 auto Pred = (CmpInst::Predicate)MI.getOperand(i: 1).getPredicate();
54 if (!ICmpInst::isEquality(P: Pred))
55 return false;
56
57 Register LHS = MI.getOperand(i: 2).getReg();
58 LLT LHSTy = MRI.getType(Reg: LHS);
59 if (!LHSTy.isScalar())
60 return false;
61
62 Register RHS = MI.getOperand(i: 3).getReg();
63 Register WideReg;
64
65 if (!mi_match(R: LHS, MRI, P: m_GTrunc(Src: m_Reg(R&: WideReg))) ||
66 !mi_match(R: RHS, MRI, P: m_SpecificICst(RequestedValue: 0)))
67 return false;
68
69 LLT WideTy = MRI.getType(Reg: WideReg);
70 if (VT->computeNumSignBits(R: WideReg) <=
71 WideTy.getSizeInBits() - LHSTy.getSizeInBits())
72 return false;
73
74 MatchInfo = WideReg;
75 return true;
76}
77
78void applyICmpRedundantTrunc(MachineInstr &MI, MachineRegisterInfo &MRI,
79 MachineIRBuilder &Builder,
80 GISelChangeObserver &Observer, Register &WideReg) {
81 assert(MI.getOpcode() == TargetOpcode::G_ICMP);
82
83 LLT WideTy = MRI.getType(Reg: WideReg);
84 // We're going to directly use the wide register as the LHS, and then use an
85 // equivalent size zero for RHS.
86 Builder.setInstrAndDebugLoc(MI);
87 auto WideZero = Builder.buildConstant(Res: WideTy, Val: 0);
88 Observer.changingInstr(MI);
89 MI.getOperand(i: 2).setReg(WideReg);
90 MI.getOperand(i: 3).setReg(WideZero.getReg(Idx: 0));
91 Observer.changedInstr(MI);
92}
93
/// \returns true if it is possible to fold a constant into a G_GLOBAL_VALUE.
///
/// e.g.
///
/// %g = G_GLOBAL_VALUE @x -> %g = G_GLOBAL_VALUE @x + cst
///
/// On success \p MatchInfo holds {NewOffset, MinOffset}: the total offset to
/// fold into the G_GLOBAL_VALUE, and the smallest G_PTR_ADD constant that the
/// apply step must subtract back out to preserve the original value.
bool matchFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
                           std::pair<uint64_t, uint64_t> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
  MachineFunction &MF = *MI.getMF();
  auto &GlobalOp = MI.getOperand(i: 1);
  auto *GV = GlobalOp.getGlobal();
  // TLS globals are lowered via special sequences; don't fold offsets into
  // them.
  if (GV->isThreadLocal())
    return false;

  // Don't allow anything that could represent offsets etc.
  if (MF.getSubtarget<AArch64Subtarget>().ClassifyGlobalReference(
          GV, TM: MF.getTarget()) != AArch64II::MO_NO_FLAG)
    return false;

  // Look for a G_GLOBAL_VALUE only used by G_PTR_ADDs against constants:
  //
  // %g = G_GLOBAL_VALUE @x
  // %ptr1 = G_PTR_ADD %g, cst1
  // %ptr2 = G_PTR_ADD %g, cst2
  // ...
  // %ptrN = G_PTR_ADD %g, cstN
  //
  // Identify the *smallest* constant. We want to be able to form this:
  //
  // %offset_g = G_GLOBAL_VALUE @x + min_cst
  // %g = G_PTR_ADD %offset_g, -min_cst
  // %ptr1 = G_PTR_ADD %g, cst1
  // ...
  Register Dst = MI.getOperand(i: 0).getReg();
  // Start at the maximum uint64_t; if there are no G_PTR_ADD users the
  // NewOffset overflow check below rejects the fold.
  uint64_t MinOffset = -1ull;
  for (auto &UseInstr : MRI.use_nodbg_instructions(Reg: Dst)) {
    if (UseInstr.getOpcode() != TargetOpcode::G_PTR_ADD)
      return false;
    auto Cst = getIConstantVRegValWithLookThrough(
        VReg: UseInstr.getOperand(i: 2).getReg(), MRI);
    if (!Cst)
      return false;
    MinOffset = std::min(a: MinOffset, b: Cst->Value.getZExtValue());
  }

  // Require that the new offset is larger than the existing one to avoid
  // infinite loops.
  uint64_t CurrOffset = GlobalOp.getOffset();
  uint64_t NewOffset = MinOffset + CurrOffset;
  if (NewOffset <= CurrOffset)
    return false;

  // Check whether folding this offset is legal. It must not go out of bounds of
  // the referenced object to avoid violating the code model, and must be
  // smaller than 2^20 because this is the largest offset expressible in all
  // object formats. (The IMAGE_REL_ARM64_PAGEBASE_REL21 relocation in COFF
  // stores an immediate signed 21 bit offset.)
  //
  // This check also prevents us from folding negative offsets, which will end
  // up being treated in the same way as large positive ones. They could also
  // cause code model violations, and aren't really common enough to matter.
  if (NewOffset >= (1 << 20))
    return false;

  // Reject offsets past the end of the underlying object.
  Type *T = GV->getValueType();
  if (!T->isSized() ||
      NewOffset > GV->getDataLayout().getTypeAllocSize(Ty: T))
    return false;
  MatchInfo = std::make_pair(x&: NewOffset, y&: MinOffset);
  return true;
}
165
/// Fold the matched offset into the G_GLOBAL_VALUE and rebuild the original
/// value with a compensating G_PTR_ADD. \p MatchInfo is {NewOffset, MinOffset}
/// as produced by matchFoldGlobalOffset.
void applyFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
                           MachineIRBuilder &B, GISelChangeObserver &Observer,
                           std::pair<uint64_t, uint64_t> &MatchInfo) {
  // Change:
  //
  // %g = G_GLOBAL_VALUE @x
  // %ptr1 = G_PTR_ADD %g, cst1
  // %ptr2 = G_PTR_ADD %g, cst2
  // ...
  // %ptrN = G_PTR_ADD %g, cstN
  //
  // To:
  //
  // %offset_g = G_GLOBAL_VALUE @x + min_cst
  // %g = G_PTR_ADD %offset_g, -min_cst
  // %ptr1 = G_PTR_ADD %g, cst1
  // ...
  // %ptrN = G_PTR_ADD %g, cstN
  //
  // Then, the original G_PTR_ADDs should be folded later on so that they look
  // like this:
  //
  // %ptrN = G_PTR_ADD %offset_g, cstN - min_cst
  uint64_t Offset, MinOffset;
  std::tie(args&: Offset, args&: MinOffset) = MatchInfo;
  // Insert the compensating G_PTR_ADD right after the G_GLOBAL_VALUE.
  B.setInstrAndDebugLoc(*std::next(x: MI.getIterator()));
  Observer.changingInstr(MI);
  auto &GlobalOp = MI.getOperand(i: 1);
  auto *GV = GlobalOp.getGlobal();
  // Fold the combined offset into the global operand itself.
  GlobalOp.ChangeToGA(GV, Offset, TargetFlags: GlobalOp.getTargetFlags());
  Register Dst = MI.getOperand(i: 0).getReg();
  // Redirect the global to a fresh vreg so the old Dst can be redefined by the
  // compensating G_PTR_ADD below, keeping all existing users intact.
  Register NewGVDst = MRI.cloneVirtualRegister(VReg: Dst);
  MI.getOperand(i: 0).setReg(NewGVDst);
  Observer.changedInstr(MI);
  B.buildPtrAdd(
      Res: Dst, Op0: NewGVDst,
      Op1: B.buildConstant(Res: LLT::scalar(SizeInBits: 64), Val: -static_cast<int64_t>(MinOffset)));
}
204
// Combines vecreduce_add(mul(ext(x), ext(y))) -> vecreduce_add([us]dot(x, y))
// Or vecreduce_add(ext(mul(ext(x), ext(y)))) -> vecreduce_add([us]dot(x, y))
// Or vecreduce_add(ext(x)) -> vecreduce_add([us]dot(x, 1))
// Similar to performVecReduceAddCombine in SelectionDAG
//
// On success MatchInfo is {Src1, Src2, IsSigned}: the pre-extend dot-product
// sources (Src2 is the null register when only one source exists) and whether
// the extends were signed (selects G_SDOT vs G_UDOT in the apply step).
bool matchExtAddvToDotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
                           const AArch64Subtarget &STI,
                           std::tuple<Register, Register, bool> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
         "Expected a G_VECREDUCE_ADD instruction");
  assert(STI.hasDotProd() && "Target should have Dot Product feature");

  MachineInstr *I1 = getDefIgnoringCopies(Reg: MI.getOperand(i: 1).getReg(), MRI);
  Register DstReg = MI.getOperand(i: 0).getReg();
  Register MidReg = I1->getOperand(i: 0).getReg();
  LLT DstTy = MRI.getType(Reg: DstReg);
  LLT MidTy = MRI.getType(Reg: MidReg);
  // The dot product accumulates into 32-bit lanes; both the reduction result
  // and the value being reduced must be 32-bit-element typed.
  if (DstTy.getScalarSizeInBits() != 32 || MidTy.getScalarSizeInBits() != 32)
    return false;

  // Detect mul(ext, ext) with symmetric ext's. If I1Opc is G_ZEXT or G_SEXT
  // then the ext's must match the same opcode. It is set to the ext opcode on
  // output.
  auto tryMatchingMulOfExt = [&MRI](MachineInstr *MI, Register &Out1,
                                    Register &Out2, unsigned &I1Opc) {
    // If result of this has more than 1 use, then there is no point in creating
    // a dot instruction
    if (!MRI.hasOneNonDBGUse(RegNo: MI->getOperand(i: 0).getReg()))
      return false;

    MachineInstr *ExtMI1 =
        getDefIgnoringCopies(Reg: MI->getOperand(i: 1).getReg(), MRI);
    MachineInstr *ExtMI2 =
        getDefIgnoringCopies(Reg: MI->getOperand(i: 2).getReg(), MRI);
    LLT Ext1DstTy = MRI.getType(Reg: ExtMI1->getOperand(i: 0).getReg());
    LLT Ext2DstTy = MRI.getType(Reg: ExtMI2->getOperand(i: 0).getReg());

    if (ExtMI1->getOpcode() != ExtMI2->getOpcode() || Ext1DstTy != Ext2DstTy)
      return false;
    if ((I1Opc == TargetOpcode::G_ZEXT || I1Opc == TargetOpcode::G_SEXT) &&
        I1Opc != ExtMI1->getOpcode())
      return false;
    Out1 = ExtMI1->getOperand(i: 1).getReg();
    Out2 = ExtMI2->getOperand(i: 1).getReg();
    I1Opc = ExtMI1->getOpcode();
    return true;
  };

  LLT SrcTy;
  unsigned I1Opc = I1->getOpcode();
  if (I1Opc == TargetOpcode::G_MUL) {
    // Pattern 1: vecreduce_add(mul(ext(x), ext(y))).
    Register Out1, Out2;
    if (!tryMatchingMulOfExt(I1, Out1, Out2, I1Opc))
      return false;
    SrcTy = MRI.getType(Reg: Out1);
    std::get<0>(t&: MatchInfo) = Out1;
    std::get<1>(t&: MatchInfo) = Out2;
  } else if (I1Opc == TargetOpcode::G_ZEXT || I1Opc == TargetOpcode::G_SEXT) {
    Register I1Op = I1->getOperand(i: 1).getReg();
    MachineInstr *M = getDefIgnoringCopies(Reg: I1Op, MRI);
    Register Out1, Out2;
    if (M->getOpcode() == TargetOpcode::G_MUL &&
        tryMatchingMulOfExt(M, Out1, Out2, I1Opc)) {
      // Pattern 2: vecreduce_add(ext(mul(ext(x), ext(y)))).
      SrcTy = MRI.getType(Reg: Out1);
      std::get<0>(t&: MatchInfo) = Out1;
      std::get<1>(t&: MatchInfo) = Out2;
    } else {
      // Pattern 3: vecreduce_add(ext(x)) — single-source; Src2 left null.
      SrcTy = MRI.getType(Reg: I1Op);
      std::get<0>(t&: MatchInfo) = I1Op;
      std::get<1>(t&: MatchInfo) = 0;
    }
  } else {
    return false;
  }

  // Record signedness; anything other than zext/sext cannot form a dot.
  if (I1Opc == TargetOpcode::G_ZEXT)
    std::get<2>(t&: MatchInfo) = 0;
  else if (I1Opc == TargetOpcode::G_SEXT)
    std::get<2>(t&: MatchInfo) = 1;
  else
    return false;

  // [US]DOT consumes i8 lanes in groups of 8 (v8i8 / v16i8).
  if (SrcTy.getScalarSizeInBits() != 8 || SrcTy.getNumElements() % 8 != 0)
    return false;

  return true;
}
291
/// Rewrite the matched reduction as one or more [US]DOT instructions feeding a
/// G_VECREDUCE_ADD. \p MatchInfo is {Src1, Src2, IsSigned} from
/// matchExtAddvToDotAddv.
void applyExtAddvToDotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
                           MachineIRBuilder &Builder,
                           GISelChangeObserver &Observer,
                           const AArch64Subtarget &STI,
                           std::tuple<Register, Register, bool> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
         "Expected a G_VECREDUCE_ADD instruction");
  assert(STI.hasDotProd() && "Target should have Dot Product feature");

  // Initialise the variables
  unsigned DotOpcode =
      std::get<2>(t&: MatchInfo) ? AArch64::G_SDOT : AArch64::G_UDOT;
  Register Ext1SrcReg = std::get<0>(t&: MatchInfo);

  // If there is only one source register, use a vector of 1s as the second
  // source: dot(x, 1) sums the lanes of x, matching vecreduce_add(ext(x)).
  Register Ext2SrcReg;
  if (std::get<1>(t&: MatchInfo) == 0)
    Ext2SrcReg = Builder.buildConstant(Res: MRI.getType(Reg: Ext1SrcReg), Val: 1)
                     ->getOperand(i: 0)
                     .getReg();
  else
    Ext2SrcReg = std::get<1>(t&: MatchInfo);

  // Find out how many DOT instructions are needed
  LLT SrcTy = MRI.getType(Reg: Ext1SrcReg);
  LLT MidTy;
  unsigned NumOfDotMI;
  if (SrcTy.getNumElements() % 16 == 0) {
    NumOfDotMI = SrcTy.getNumElements() / 16;
    MidTy = LLT::fixed_vector(NumElements: 4, ScalarSizeInBits: 32);
  } else if (SrcTy.getNumElements() % 8 == 0) {
    NumOfDotMI = SrcTy.getNumElements() / 8;
    MidTy = LLT::fixed_vector(NumElements: 2, ScalarSizeInBits: 32);
  } else {
    // The matcher guarantees the element count is a multiple of 8.
    llvm_unreachable("Source type number of elements is not multiple of 8");
  }

  // Handle case where one DOT instruction is needed
  if (NumOfDotMI == 1) {
    auto Zeroes = Builder.buildConstant(Res: MidTy, Val: 0)->getOperand(i: 0).getReg();
    auto Dot = Builder.buildInstr(Opc: DotOpcode, DstOps: {MidTy},
                                  SrcOps: {Zeroes, Ext1SrcReg, Ext2SrcReg});
    Builder.buildVecReduceAdd(Dst: MI.getOperand(i: 0), Src: Dot->getOperand(i: 0));
  } else {
    // If the element count is not a multiple of 16, pad the trailing v8i8
    // part with zeroes up to a v16i8 so every DOT operates on v16i8.
    SmallVector<Register, 4> Ext1UnmergeReg;
    SmallVector<Register, 4> Ext2UnmergeReg;
    if (SrcTy.getNumElements() % 16 != 0) {
      SmallVector<Register> Leftover1;
      SmallVector<Register> Leftover2;

      // Split the elements into v16i8 and v8i8
      LLT MainTy = LLT::fixed_vector(NumElements: 16, ScalarSizeInBits: 8);
      LLT LeftoverTy1, LeftoverTy2;
      if ((!extractParts(Reg: Ext1SrcReg, RegTy: MRI.getType(Reg: Ext1SrcReg), MainTy,
                         LeftoverTy&: LeftoverTy1, VRegs&: Ext1UnmergeReg, LeftoverVRegs&: Leftover1, MIRBuilder&: Builder,
                         MRI)) ||
          (!extractParts(Reg: Ext2SrcReg, RegTy: MRI.getType(Reg: Ext2SrcReg), MainTy,
                         LeftoverTy&: LeftoverTy2, VRegs&: Ext2UnmergeReg, LeftoverVRegs&: Leftover2, MIRBuilder&: Builder,
                         MRI))) {
        llvm_unreachable("Unable to split this vector properly");
      }

      // Pad the leftover v8i8 vector with register of 0s of type v8i8
      Register v8Zeroes = Builder.buildConstant(Res: LLT::fixed_vector(NumElements: 8, ScalarSizeInBits: 8), Val: 0)
                              ->getOperand(i: 0)
                              .getReg();

      Ext1UnmergeReg.push_back(
          Elt: Builder
              .buildMergeLikeInstr(Res: LLT::fixed_vector(NumElements: 16, ScalarSizeInBits: 8),
                                   Ops: {Leftover1[0], v8Zeroes})
              .getReg(Idx: 0));
      Ext2UnmergeReg.push_back(
          Elt: Builder
              .buildMergeLikeInstr(Res: LLT::fixed_vector(NumElements: 16, ScalarSizeInBits: 8),
                                   Ops: {Leftover2[0], v8Zeroes})
              .getReg(Idx: 0));

    } else {
      // Unmerge the source vectors to v16i8
      unsigned SrcNumElts = SrcTy.getNumElements();
      extractParts(Reg: Ext1SrcReg, Ty: LLT::fixed_vector(NumElements: 16, ScalarSizeInBits: 8), NumParts: SrcNumElts / 16,
                   VRegs&: Ext1UnmergeReg, MIRBuilder&: Builder, MRI);
      extractParts(Reg: Ext2SrcReg, Ty: LLT::fixed_vector(NumElements: 16, ScalarSizeInBits: 8), NumParts: SrcNumElts / 16,
                   VRegs&: Ext2UnmergeReg, MIRBuilder&: Builder, MRI);
    }

    // Build the UDOT instructions
    SmallVector<Register, 2> DotReg;
    unsigned NumElements = 0;
    for (unsigned i = 0; i < Ext1UnmergeReg.size(); i++) {
      LLT ZeroesLLT;
      // Check if it is 16 or 8 elements. Set Zeroes to the according size
      if (MRI.getType(Reg: Ext1UnmergeReg[i]).getNumElements() == 16) {
        ZeroesLLT = LLT::fixed_vector(NumElements: 4, ScalarSizeInBits: 32);
        NumElements += 4;
      } else {
        ZeroesLLT = LLT::fixed_vector(NumElements: 2, ScalarSizeInBits: 32);
        NumElements += 2;
      }
      // Each DOT accumulates into a zero vector so results can simply be
      // concatenated afterwards.
      auto Zeroes = Builder.buildConstant(Res: ZeroesLLT, Val: 0)->getOperand(i: 0).getReg();
      DotReg.push_back(
          Elt: Builder
              .buildInstr(Opc: DotOpcode, DstOps: {MRI.getType(Reg: Zeroes)},
                          SrcOps: {Zeroes, Ext1UnmergeReg[i], Ext2UnmergeReg[i]})
              .getReg(Idx: 0));
    }

    // Merge the output
    auto ConcatMI =
        Builder.buildConcatVectors(Res: LLT::fixed_vector(NumElements, ScalarSizeInBits: 32), Ops: DotReg);

    // Put it through a vector reduction
    Builder.buildVecReduceAdd(Dst: MI.getOperand(i: 0).getReg(),
                              Src: ConcatMI->getOperand(i: 0).getReg());
  }

  // Erase the dead instructions
  MI.eraseFromParent();
}
414
// Matches {U/S}ADDV(ext(x)) => {U/S}ADDLV(x)
// Ensure that the type coming from the extend instruction is the right size
//
// On success MatchInfo is {ExtSrcReg, IsSigned}: the pre-extend source and
// whether the extend was signed (selects G_SADDLV vs G_UADDLV in the apply
// step).
bool matchExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
                           std::pair<Register, bool> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
         "Expected G_VECREDUCE_ADD Opcode");

  // Check if the last instruction is an extend
  MachineInstr *ExtMI = getDefIgnoringCopies(Reg: MI.getOperand(i: 1).getReg(), MRI);
  auto ExtOpc = ExtMI->getOpcode();

  if (ExtOpc == TargetOpcode::G_ZEXT)
    std::get<1>(in&: MatchInfo) = 0;
  else if (ExtOpc == TargetOpcode::G_SEXT)
    std::get<1>(in&: MatchInfo) = 1;
  else
    return false;

  // Check if the source register is a valid type
  Register ExtSrcReg = ExtMI->getOperand(i: 1).getReg();
  LLT ExtSrcTy = MRI.getType(Reg: ExtSrcReg);
  LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
  // ADDLV widens by exactly 2x; the destination scalar must be at least twice
  // the source scalar width.
  if (ExtSrcTy.getScalarSizeInBits() * 2 > DstTy.getScalarSizeInBits())
    return false;
  // Element-count constraints per destination width; these correspond to the
  // vector shapes the apply step can split into legal ADDLV operands.
  if ((DstTy.getScalarSizeInBits() == 16 &&
       ExtSrcTy.getNumElements() % 8 == 0 && ExtSrcTy.getNumElements() < 256) ||
      (DstTy.getScalarSizeInBits() == 32 &&
       ExtSrcTy.getNumElements() % 4 == 0) ||
      (DstTy.getScalarSizeInBits() == 64 &&
       ExtSrcTy.getNumElements() % 4 == 0)) {
    std::get<0>(in&: MatchInfo) = ExtSrcReg;
    return true;
  }
  return false;
}
450
/// Rewrite the matched reduction as one or more {U/S}ADDLV instructions whose
/// scalar results are summed. \p MatchInfo is {ExtSrcReg, IsSigned} from
/// matchExtUaddvToUaddlv.
void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
                           MachineIRBuilder &B, GISelChangeObserver &Observer,
                           std::pair<Register, bool> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
         "Expected G_VECREDUCE_ADD Opcode");

  unsigned Opc = std::get<1>(in&: MatchInfo) ? AArch64::G_SADDLV : AArch64::G_UADDLV;
  Register SrcReg = std::get<0>(in&: MatchInfo);
  Register DstReg = MI.getOperand(i: 0).getReg();
  LLT SrcTy = MRI.getType(Reg: SrcReg);
  LLT DstTy = MRI.getType(Reg: DstReg);

  // If SrcTy has more elements than expected, split them into multiple
  // instructions and sum the results
  LLT MainTy;
  SmallVector<Register, 1> WorkingRegisters;
  unsigned SrcScalSize = SrcTy.getScalarSizeInBits();
  unsigned SrcNumElem = SrcTy.getNumElements();
  if ((SrcScalSize == 8 && SrcNumElem > 16) ||
      (SrcScalSize == 16 && SrcNumElem > 8) ||
      (SrcScalSize == 32 && SrcNumElem > 4)) {

    LLT LeftoverTy;
    SmallVector<Register, 4> LeftoverRegs;
    // Choose the widest legal 128-bit vector shape for the scalar size.
    if (SrcScalSize == 8)
      MainTy = LLT::fixed_vector(NumElements: 16, ScalarSizeInBits: 8);
    else if (SrcScalSize == 16)
      MainTy = LLT::fixed_vector(NumElements: 8, ScalarSizeInBits: 16);
    else if (SrcScalSize == 32)
      MainTy = LLT::fixed_vector(NumElements: 4, ScalarSizeInBits: 32);
    else
      llvm_unreachable("Source's Scalar Size not supported");

    // Extract the parts and put each extracted sources through U/SADDLV and put
    // the values inside a small vec
    extractParts(Reg: SrcReg, RegTy: SrcTy, MainTy, LeftoverTy, VRegs&: WorkingRegisters,
                 LeftoverVRegs&: LeftoverRegs, MIRBuilder&: B, MRI);
    llvm::append_range(C&: WorkingRegisters, R&: LeftoverRegs);
  } else {
    WorkingRegisters.push_back(Elt: SrcReg);
    MainTy = SrcTy;
  }

  // ADDLV's result scalar is always twice as wide as its source scalar.
  unsigned MidScalarSize = MainTy.getScalarSizeInBits() * 2;
  LLT MidScalarLLT = LLT::scalar(SizeInBits: MidScalarSize);
  // Lane-0 index used to extract each ADDLV result below.
  Register ZeroReg = B.buildConstant(Res: LLT::scalar(SizeInBits: 64), Val: 0).getReg(Idx: 0);
  for (unsigned I = 0; I < WorkingRegisters.size(); I++) {
    // If the number of elements is too small to build an instruction, extend
    // its size before applying addlv
    LLT WorkingRegTy = MRI.getType(Reg: WorkingRegisters[I]);
    if ((WorkingRegTy.getScalarSizeInBits() == 8) &&
        (WorkingRegTy.getNumElements() == 4)) {
      WorkingRegisters[I] =
          B.buildInstr(Opc: std::get<1>(in&: MatchInfo) ? TargetOpcode::G_SEXT
                                                : TargetOpcode::G_ZEXT,
                       DstOps: {LLT::fixed_vector(NumElements: 4, ScalarSizeInBits: 16)}, SrcOps: {WorkingRegisters[I]})
              .getReg(Idx: 0);
    }

    // Generate the {U/S}ADDLV instruction, whose output is always double of the
    // Src's Scalar size
    LLT AddlvTy = MidScalarSize <= 32 ? LLT::fixed_vector(NumElements: 4, ScalarSizeInBits: 32)
                                      : LLT::fixed_vector(NumElements: 2, ScalarSizeInBits: 64);
    Register AddlvReg =
        B.buildInstr(Opc, DstOps: {AddlvTy}, SrcOps: {WorkingRegisters[I]}).getReg(Idx: 0);

    // The output from {U/S}ADDLV gets placed in the lowest lane of a v4i32 or
    // v2i64 register.
    // i16, i32 results uses v4i32 registers
    // i64 results uses v2i64 registers
    // Therefore we have to extract/truncate the value to the right type
    if (MidScalarSize == 32 || MidScalarSize == 64) {
      WorkingRegisters[I] = B.buildInstr(Opc: AArch64::G_EXTRACT_VECTOR_ELT,
                                         DstOps: {MidScalarLLT}, SrcOps: {AddlvReg, ZeroReg})
                                .getReg(Idx: 0);
    } else {
      Register ExtractReg = B.buildInstr(Opc: AArch64::G_EXTRACT_VECTOR_ELT,
                                         DstOps: {LLT::scalar(SizeInBits: 32)}, SrcOps: {AddlvReg, ZeroReg})
                                .getReg(Idx: 0);
      WorkingRegisters[I] =
          B.buildTrunc(Res: {MidScalarLLT}, Op: {ExtractReg}).getReg(Idx: 0);
    }
  }

  // Sum the per-part results into a single scalar.
  Register OutReg;
  if (WorkingRegisters.size() > 1) {
    OutReg = B.buildAdd(Dst: MidScalarLLT, Src0: WorkingRegisters[0], Src1: WorkingRegisters[1])
                 .getReg(Idx: 0);
    for (unsigned I = 2; I < WorkingRegisters.size(); I++) {
      OutReg = B.buildAdd(Dst: MidScalarLLT, Src0: OutReg, Src1: WorkingRegisters[I]).getReg(Idx: 0);
    }
  } else {
    OutReg = WorkingRegisters[0];
  }

  if (DstTy.getScalarSizeInBits() > MidScalarSize) {
    // Handle the scalar value if the DstTy's Scalar Size is more than double
    // Src's ScalarType
    B.buildInstr(Opc: std::get<1>(in&: MatchInfo) ? TargetOpcode::G_SEXT
                                          : TargetOpcode::G_ZEXT,
                 DstOps: {DstReg}, SrcOps: {OutReg});
  } else {
    B.buildCopy(Res: DstReg, Op: OutReg);
  }

  MI.eraseFromParent();
}
558
// Pushes ADD/SUB/MUL through extend instructions to decrease the number of
// extend instruction at the end by allowing selection of {s|u}addl sooner
// i32 add(i32 ext i8, i32 ext i8) => i32 ext(i16 add(i16 ext i8, i16 ext i8))
//
// \p DstReg is the binop's destination, \p SrcReg1 / \p SrcReg2 are the
// pre-extend sources (bound by the tablegen pattern).
bool matchPushAddSubExt(MachineInstr &MI, MachineRegisterInfo &MRI,
                        Register DstReg, Register SrcReg1, Register SrcReg2) {
  assert((MI.getOpcode() == TargetOpcode::G_ADD ||
          MI.getOpcode() == TargetOpcode::G_SUB ||
          MI.getOpcode() == TargetOpcode::G_MUL) &&
         "Expected a G_ADD, G_SUB or G_MUL instruction\n");

  // Deal with vector types only
  LLT DstTy = MRI.getType(Reg: DstReg);
  if (!DstTy.isVector())
    return false;

  // Return true if G_{S|Z}EXT instruction is more than 2* source
  // (i.e. the extend skips at least one intermediate width, leaving room for
  // the narrower intermediate binop this combine introduces).
  Register ExtDstReg = MI.getOperand(i: 1).getReg();
  LLT Ext1SrcTy = MRI.getType(Reg: SrcReg1);
  LLT Ext2SrcTy = MRI.getType(Reg: SrcReg2);
  unsigned ExtDstScal = MRI.getType(Reg: ExtDstReg).getScalarSizeInBits();
  unsigned Ext1SrcScal = Ext1SrcTy.getScalarSizeInBits();
  if (((Ext1SrcScal == 8 && ExtDstScal == 32) ||
       ((Ext1SrcScal == 8 || Ext1SrcScal == 16) && ExtDstScal == 64)) &&
      Ext1SrcTy == Ext2SrcTy)
    return true;

  return false;
}
587
588void applyPushAddSubExt(MachineInstr &MI, MachineRegisterInfo &MRI,
589 MachineIRBuilder &B, bool isSExt, Register DstReg,
590 Register SrcReg1, Register SrcReg2) {
591 LLT SrcTy = MRI.getType(Reg: SrcReg1);
592 LLT MidTy = SrcTy.changeElementSize(NewEltSize: SrcTy.getScalarSizeInBits() * 2);
593 unsigned Opc = isSExt ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT;
594 Register Ext1Reg = B.buildInstr(Opc, DstOps: {MidTy}, SrcOps: {SrcReg1}).getReg(Idx: 0);
595 Register Ext2Reg = B.buildInstr(Opc, DstOps: {MidTy}, SrcOps: {SrcReg2}).getReg(Idx: 0);
596 Register AddReg =
597 B.buildInstr(Opc: MI.getOpcode(), DstOps: {MidTy}, SrcOps: {Ext1Reg, Ext2Reg}).getReg(Idx: 0);
598
599 // G_SUB has to sign-extend the result.
600 // G_ADD needs to sext from sext and can sext or zext from zext, and G_MUL
601 // needs to use the original opcode so the original opcode is used for both.
602 if (MI.getOpcode() == TargetOpcode::G_ADD ||
603 MI.getOpcode() == TargetOpcode::G_MUL)
604 B.buildInstr(Opc, DstOps: {DstReg}, SrcOps: {AddReg});
605 else
606 B.buildSExt(Res: DstReg, Op: AddReg);
607
608 MI.eraseFromParent();
609}
610
611bool tryToSimplifyUADDO(MachineInstr &MI, MachineIRBuilder &B,
612 const CombinerHelper &Helper,
613 GISelChangeObserver &Observer) {
614 // Try simplify G_UADDO with 8 or 16 bit operands to wide G_ADD and TBNZ if
615 // result is only used in the no-overflow case. It is restricted to cases
616 // where we know that the high-bits of the operands are 0. If there's an
617 // overflow, then the 9th or 17th bit must be set, which can be checked
618 // using TBNZ.
619 //
620 // Change (for UADDOs on 8 and 16 bits):
621 //
622 // %z0 = G_ASSERT_ZEXT _
623 // %op0 = G_TRUNC %z0
624 // %z1 = G_ASSERT_ZEXT _
625 // %op1 = G_TRUNC %z1
626 // %val, %cond = G_UADDO %op0, %op1
627 // G_BRCOND %cond, %error.bb
628 //
629 // error.bb:
630 // (no successors and no uses of %val)
631 //
632 // To:
633 //
634 // %z0 = G_ASSERT_ZEXT _
635 // %z1 = G_ASSERT_ZEXT _
636 // %add = G_ADD %z0, %z1
637 // %val = G_TRUNC %add
638 // %bit = G_AND %add, 1 << scalar-size-in-bits(%op1)
639 // %cond = G_ICMP NE, %bit, 0
640 // G_BRCOND %cond, %error.bb
641
642 auto &MRI = *B.getMRI();
643
644 MachineOperand *DefOp0 = MRI.getOneDef(Reg: MI.getOperand(i: 2).getReg());
645 MachineOperand *DefOp1 = MRI.getOneDef(Reg: MI.getOperand(i: 3).getReg());
646 Register Op0Wide;
647 Register Op1Wide;
648 if (!mi_match(R: DefOp0->getParent(), MRI, P: m_GTrunc(Src: m_Reg(R&: Op0Wide))) ||
649 !mi_match(R: DefOp1->getParent(), MRI, P: m_GTrunc(Src: m_Reg(R&: Op1Wide))))
650 return false;
651 LLT WideTy0 = MRI.getType(Reg: Op0Wide);
652 LLT WideTy1 = MRI.getType(Reg: Op1Wide);
653 Register ResVal = MI.getOperand(i: 0).getReg();
654 LLT OpTy = MRI.getType(Reg: ResVal);
655 MachineInstr *Op0WideDef = MRI.getVRegDef(Reg: Op0Wide);
656 MachineInstr *Op1WideDef = MRI.getVRegDef(Reg: Op1Wide);
657
658 unsigned OpTySize = OpTy.getScalarSizeInBits();
659 // First check that the G_TRUNC feeding the G_UADDO are no-ops, because the
660 // inputs have been zero-extended.
661 if (Op0WideDef->getOpcode() != TargetOpcode::G_ASSERT_ZEXT ||
662 Op1WideDef->getOpcode() != TargetOpcode::G_ASSERT_ZEXT ||
663 OpTySize != Op0WideDef->getOperand(i: 2).getImm() ||
664 OpTySize != Op1WideDef->getOperand(i: 2).getImm())
665 return false;
666
667 // Only scalar UADDO with either 8 or 16 bit operands are handled.
668 if (!WideTy0.isScalar() || !WideTy1.isScalar() || WideTy0 != WideTy1 ||
669 OpTySize >= WideTy0.getScalarSizeInBits() ||
670 (OpTySize != 8 && OpTySize != 16))
671 return false;
672
673 // The overflow-status result must be used by a branch only.
674 Register ResStatus = MI.getOperand(i: 1).getReg();
675 if (!MRI.hasOneNonDBGUse(RegNo: ResStatus))
676 return false;
677 MachineInstr *CondUser = &*MRI.use_instr_nodbg_begin(RegNo: ResStatus);
678 if (CondUser->getOpcode() != TargetOpcode::G_BRCOND)
679 return false;
680
681 // Make sure the computed result is only used in the no-overflow blocks.
682 MachineBasicBlock *CurrentMBB = MI.getParent();
683 MachineBasicBlock *FailMBB = CondUser->getOperand(i: 1).getMBB();
684 if (!FailMBB->succ_empty() || CondUser->getParent() != CurrentMBB)
685 return false;
686 if (any_of(Range: MRI.use_nodbg_instructions(Reg: ResVal),
687 P: [&MI, FailMBB, CurrentMBB](MachineInstr &I) {
688 return &MI != &I &&
689 (I.getParent() == FailMBB || I.getParent() == CurrentMBB);
690 }))
691 return false;
692
693 // Remove G_ADDO.
694 B.setInstrAndDebugLoc(*MI.getNextNode());
695 MI.eraseFromParent();
696
697 // Emit wide add.
698 Register AddDst = MRI.cloneVirtualRegister(VReg: Op0Wide);
699 B.buildInstr(Opc: TargetOpcode::G_ADD, DstOps: {AddDst}, SrcOps: {Op0Wide, Op1Wide});
700
701 // Emit check of the 9th or 17th bit and update users (the branch). This will
702 // later be folded to TBNZ.
703 Register CondBit = MRI.cloneVirtualRegister(VReg: Op0Wide);
704 B.buildAnd(
705 Dst: CondBit, Src0: AddDst,
706 Src1: B.buildConstant(Res: LLT::scalar(SizeInBits: 32), Val: OpTySize == 8 ? 1 << 8 : 1 << 16));
707 B.buildICmp(Pred: CmpInst::ICMP_NE, Res: ResStatus, Op0: CondBit,
708 Op1: B.buildConstant(Res: LLT::scalar(SizeInBits: 32), Val: 0));
709
710 // Update ZEXts users of the result value. Because all uses are in the
711 // no-overflow case, we know that the top bits are 0 and we can ignore ZExts.
712 B.buildZExtOrTrunc(Res: ResVal, Op: AddDst);
713 for (MachineOperand &U : make_early_inc_range(Range: MRI.use_operands(Reg: ResVal))) {
714 Register WideReg;
715 if (mi_match(R: U.getParent(), MRI, P: m_GZExt(Src: m_Reg(R&: WideReg)))) {
716 auto OldR = U.getParent()->getOperand(i: 0).getReg();
717 Observer.erasingInstr(MI&: *U.getParent());
718 U.getParent()->eraseFromParent();
719 Helper.replaceRegWith(MRI, FromReg: OldR, ToReg: AddDst);
720 }
721 }
722
723 return true;
724}
725
726class AArch64PreLegalizerCombinerImpl : public Combiner {
727protected:
728 const CombinerHelper Helper;
729 const AArch64PreLegalizerCombinerImplRuleConfig &RuleConfig;
730 const AArch64Subtarget &STI;
731 const LibcallLoweringInfo &Libcalls;
732
733public:
734 AArch64PreLegalizerCombinerImpl(
735 MachineFunction &MF, CombinerInfo &CInfo, GISelValueTracking &VT,
736 GISelCSEInfo *CSEInfo,
737 const AArch64PreLegalizerCombinerImplRuleConfig &RuleConfig,
738 const AArch64Subtarget &STI, const LibcallLoweringInfo &Libcalls,
739 MachineDominatorTree *MDT, const LegalizerInfo *LI);
740
741 static const char *getName() { return "AArch6400PreLegalizerCombiner"; }
742
743 bool tryCombineAll(MachineInstr &I) const override;
744
745 bool tryCombineAllImpl(MachineInstr &I) const;
746
747private:
748#define GET_GICOMBINER_CLASS_MEMBERS
749#include "AArch64GenPreLegalizeGICombiner.inc"
750#undef GET_GICOMBINER_CLASS_MEMBERS
751};
752
753#define GET_GICOMBINER_IMPL
754#include "AArch64GenPreLegalizeGICombiner.inc"
755#undef GET_GICOMBINER_IMPL
756
// Constructor: wires the CombinerHelper to the base class's observer/builder
// and runs the tablegen-generated rule-member initializers.
AArch64PreLegalizerCombinerImpl::AArch64PreLegalizerCombinerImpl(
    MachineFunction &MF, CombinerInfo &CInfo, GISelValueTracking &VT,
    GISelCSEInfo *CSEInfo,
    const AArch64PreLegalizerCombinerImplRuleConfig &RuleConfig,
    const AArch64Subtarget &STI, const LibcallLoweringInfo &Libcalls,
    MachineDominatorTree *MDT, const LegalizerInfo *LI)
    : Combiner(MF, CInfo, &VT, CSEInfo),
      Helper(Observer, B, /*IsPreLegalize*/ true, &VT, MDT, LI),
      RuleConfig(RuleConfig), STI(STI), Libcalls(Libcalls),
#define GET_GICOMBINER_CONSTRUCTOR_INITS
#include "AArch64GenPreLegalizeGICombiner.inc"
#undef GET_GICOMBINER_CONSTRUCTOR_INITS
{
}
771
// Entry point per instruction: first run the tablegen-generated rules, then
// fall back to the hand-written combines for a few specific opcodes.
bool AArch64PreLegalizerCombinerImpl::tryCombineAll(MachineInstr &MI) const {
  if (tryCombineAllImpl(I&: MI))
    return true;

  unsigned Opc = MI.getOpcode();
  switch (Opc) {
  case TargetOpcode::G_SHUFFLE_VECTOR:
    return Helper.tryCombineShuffleVector(MI);
  case TargetOpcode::G_UADDO:
    return tryToSimplifyUADDO(MI, B, Helper, Observer);
  case TargetOpcode::G_MEMCPY_INLINE:
    return Helper.tryEmitMemcpyInline(MI);
  case TargetOpcode::G_MEMCPY:
  case TargetOpcode::G_MEMMOVE:
  case TargetOpcode::G_MEMSET: {
    // If we're at -O0 set a maxlen of 32 to inline, otherwise let the other
    // heuristics decide.
    unsigned MaxLen = CInfo.EnableOpt ? 0 : 32;
    // Try to inline memcpy type calls if optimizations are enabled.
    if (Helper.tryCombineMemCpyFamily(MI, MaxLen))
      return true;
    // Small memsets of zero can be emitted as a bzero libcall instead.
    if (Opc == TargetOpcode::G_MEMSET)
      return llvm::AArch64GISelUtils::tryEmitBZero(MI, MIRBuilder&: B, Libcalls,
                                                   MinSize: CInfo.EnableMinSize);
    return false;
  }
  }

  return false;
}
802
// Pass boilerplate
// ================

/// Legacy-pass-manager wrapper that runs AArch64PreLegalizerCombinerImpl over
/// a machine function.
class AArch64PreLegalizerCombiner : public MachineFunctionPass {
public:
  static char ID;

  AArch64PreLegalizerCombiner();

  StringRef getPassName() const override {
    return "AArch64PreLegalizerCombiner";
  }

  bool runOnMachineFunction(MachineFunction &MF) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override;

private:
  // Per-rule enable/disable configuration, parsed from the command line in
  // the constructor.
  AArch64PreLegalizerCombinerImplRuleConfig RuleConfig;
};
823} // end anonymous namespace
824
// Declare the analyses the combiner consumes (value tracking, dominator tree,
// CSE info, libcall lowering); all are preserved since we only rewrite
// generic MI within blocks.
void AArch64PreLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesCFG();
  getSelectionDAGFallbackAnalysisUsage(AU);
  AU.addRequired<GISelValueTrackingAnalysisLegacy>();
  AU.addPreserved<GISelValueTrackingAnalysisLegacy>();
  AU.addRequired<MachineDominatorTreeWrapperPass>();
  AU.addPreserved<MachineDominatorTreeWrapperPass>();
  AU.addRequired<GISelCSEAnalysisWrapperPass>();
  AU.addPreserved<GISelCSEAnalysisWrapperPass>();
  AU.addRequired<LibcallLoweringInfoWrapper>();
  MachineFunctionPass::getAnalysisUsage(AU);
}
837
// Constructor: parse the command-line rule filter; an unknown rule identifier
// is a hard error.
AArch64PreLegalizerCombiner::AArch64PreLegalizerCombiner()
    : MachineFunctionPass(ID) {
  if (!RuleConfig.parseCommandLineOption())
    report_fatal_error(reason: "Invalid rule identifier");
}
843
// Gather the required analyses, configure the combiner for a single
// full-DCE pass, and run the implementation over the function.
bool AArch64PreLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) {
  // Don't touch functions where instruction selection already failed.
  if (MF.getProperties().hasFailedISel())
    return false;
  // Enable CSE.
  GISelCSEAnalysisWrapper &Wrapper =
      getAnalysis<GISelCSEAnalysisWrapperPass>().getCSEWrapper();
  auto *CSEInfo =
      &Wrapper.get(CSEOpt: getStandardCSEConfigForOpt(Level: MF.getTarget().getOptLevel()));

  const AArch64Subtarget &ST = MF.getSubtarget<AArch64Subtarget>();
  const auto *LI = ST.getLegalizerInfo();

  const Function &F = MF.getFunction();

  const LibcallLoweringInfo &Libcalls =
      getAnalysis<LibcallLoweringInfoWrapper>().getLibcallLowering(
          M: *F.getParent(), Subtarget: ST);

  bool EnableOpt =
      MF.getTarget().getOptLevel() != CodeGenOptLevel::None && !skipFunction(F);
  GISelValueTracking *VT =
      &getAnalysis<GISelValueTrackingAnalysisLegacy>().get(MF);
  MachineDominatorTree *MDT =
      &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
  CombinerInfo CInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false,
                     /*LegalizerInfo*/ nullptr, EnableOpt, F.hasOptSize(),
                     F.hasMinSize());
  // Disable fixed-point iteration to reduce compile-time
  CInfo.MaxIterations = 1;
  CInfo.ObserverLvl = CombinerInfo::ObserverLevel::SinglePass;
  // This is the first Combiner, so the input IR might contain dead
  // instructions.
  CInfo.EnableFullDCE = true;
  AArch64PreLegalizerCombinerImpl Impl(MF, CInfo, *VT, CSEInfo, RuleConfig, ST,
                                       Libcalls, MDT, LI);
  return Impl.combineMachineInstrs();
}
881
char AArch64PreLegalizerCombiner::ID = 0;
// Register the pass and its analysis dependencies with the legacy pass
// manager.
INITIALIZE_PASS_BEGIN(AArch64PreLegalizerCombiner, DEBUG_TYPE,
                      "Combine AArch64 machine instrs before legalization",
                      false, false)
INITIALIZE_PASS_DEPENDENCY(GISelValueTrackingAnalysisLegacy)
INITIALIZE_PASS_DEPENDENCY(GISelCSEAnalysisWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LibcallLoweringInfoWrapper)
INITIALIZE_PASS_END(AArch64PreLegalizerCombiner, DEBUG_TYPE,
                    "Combine AArch64 machine instrs before legalization", false,
                    false)
892
namespace llvm {
/// Factory used by the AArch64 target's pass-pipeline setup.
FunctionPass *createAArch64PreLegalizerCombiner() {
  return new AArch64PreLegalizerCombiner();
}
} // end namespace llvm
898