1//=== AArch64PostLegalizerLowering.cpp --------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// Post-legalization lowering for instructions.
11///
12/// This is used to offload pattern matching from the selector.
13///
14/// For example, this combiner will notice that a G_SHUFFLE_VECTOR is actually
15/// a G_ZIP, G_UZP, etc.
16///
17/// General optimization combines should be handled by either the
18/// AArch64PostLegalizerCombiner or the AArch64PreLegalizerCombiner.
19///
20//===----------------------------------------------------------------------===//
21
22#include "AArch64.h"
23#include "AArch64ExpandImm.h"
24#include "AArch64GlobalISelUtils.h"
25#include "AArch64PerfectShuffle.h"
26#include "AArch64Subtarget.h"
27#include "GISel/AArch64LegalizerInfo.h"
28#include "MCTargetDesc/AArch64MCTargetDesc.h"
29#include "Utils/AArch64BaseInfo.h"
30#include "llvm/CodeGen/GlobalISel/Combiner.h"
31#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
32#include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
33#include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h"
34#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
35#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
36#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
37#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
38#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
39#include "llvm/CodeGen/GlobalISel/Utils.h"
40#include "llvm/CodeGen/MachineFrameInfo.h"
41#include "llvm/CodeGen/MachineFunctionAnalysisManager.h"
42#include "llvm/CodeGen/MachineFunctionPass.h"
43#include "llvm/CodeGen/MachineInstrBuilder.h"
44#include "llvm/CodeGen/MachinePassManager.h"
45#include "llvm/CodeGen/MachineRegisterInfo.h"
46#include "llvm/CodeGen/TargetOpcodes.h"
47#include "llvm/IR/InstrTypes.h"
48#include "llvm/Support/ErrorHandling.h"
49#include <optional>
50
51#define GET_GICOMBINER_DEPS
52#include "AArch64GenPostLegalizeGILowering.inc"
53#undef GET_GICOMBINER_DEPS
54
55#define DEBUG_TYPE "aarch64-postlegalizer-lowering"
56
57using namespace llvm;
58using namespace MIPatternMatch;
59using namespace AArch64GISelUtils;
60
61#define GET_GICOMBINER_TYPES
62#include "AArch64GenPostLegalizeGILowering.inc"
63#undef GET_GICOMBINER_TYPES
64
65namespace {
66
67/// Represents a pseudo instruction which replaces a G_SHUFFLE_VECTOR.
68///
69/// Used for matching target-supported shuffles before codegen.
70struct ShuffleVectorPseudo {
71 unsigned Opc; ///< Opcode for the instruction. (E.g. G_ZIP1)
72 Register Dst; ///< Destination register.
73 SmallVector<SrcOp, 2> SrcOps; ///< Source registers.
74 ShuffleVectorPseudo(unsigned Opc, Register Dst,
75 std::initializer_list<SrcOp> SrcOps)
76 : Opc(Opc), Dst(Dst), SrcOps(SrcOps){};
77 ShuffleVectorPseudo() = default;
78};
79
80/// Check if a G_EXT instruction can handle a shuffle mask \p M when the vector
81/// sources of the shuffle are different.
82std::optional<std::pair<bool, uint64_t>> getExtMask(ArrayRef<int> M,
83 unsigned NumElts) {
84 // Look for the first non-undef element.
85 auto FirstRealElt = find_if(Range&: M, P: [](int Elt) { return Elt >= 0; });
86 if (FirstRealElt == M.end())
87 return std::nullopt;
88
89 // Use APInt to handle overflow when calculating expected element.
90 unsigned MaskBits = APInt(32, NumElts * 2).logBase2();
91 APInt ExpectedElt = APInt(MaskBits, *FirstRealElt + 1, false, true);
92
93 // The following shuffle indices must be the successive elements after the
94 // first real element.
95 if (any_of(
96 Range: make_range(x: std::next(x: FirstRealElt), y: M.end()),
97 P: [&ExpectedElt](int Elt) { return Elt != ExpectedElt++ && Elt >= 0; }))
98 return std::nullopt;
99
100 // The index of an EXT is the first element if it is not UNDEF.
101 // Watch out for the beginning UNDEFs. The EXT index should be the expected
102 // value of the first element. E.g.
103 // <-1, -1, 3, ...> is treated as <1, 2, 3, ...>.
104 // <-1, -1, 0, 1, ...> is treated as <2*NumElts-2, 2*NumElts-1, 0, 1, ...>.
105 // ExpectedElt is the last mask index plus 1.
106 uint64_t Imm = ExpectedElt.getZExtValue();
107 bool ReverseExt = false;
108
109 // There are two difference cases requiring to reverse input vectors.
110 // For example, for vector <4 x i32> we have the following cases,
111 // Case 1: shufflevector(<4 x i32>,<4 x i32>,<-1, -1, -1, 0>)
112 // Case 2: shufflevector(<4 x i32>,<4 x i32>,<-1, -1, 7, 0>)
113 // For both cases, we finally use mask <5, 6, 7, 0>, which requires
114 // to reverse two input vectors.
115 if (Imm < NumElts)
116 ReverseExt = true;
117 else
118 Imm -= NumElts;
119 return std::make_pair(x&: ReverseExt, y&: Imm);
120}
121
122/// Helper function for matchINS.
123///
124/// \returns a value when \p M is an ins mask for \p NumInputElements.
125///
126/// First element of the returned pair is true when the produced
127/// G_INSERT_VECTOR_ELT destination should be the LHS of the G_SHUFFLE_VECTOR.
128///
129/// Second element is the destination lane for the G_INSERT_VECTOR_ELT.
130std::optional<std::pair<bool, int>> isINSMask(ArrayRef<int> M,
131 int NumInputElements) {
132 if (M.size() != static_cast<size_t>(NumInputElements))
133 return std::nullopt;
134 int NumLHSMatch = 0, NumRHSMatch = 0;
135 int LastLHSMismatch = -1, LastRHSMismatch = -1;
136 for (int Idx = 0; Idx < NumInputElements; ++Idx) {
137 if (M[Idx] == -1) {
138 ++NumLHSMatch;
139 ++NumRHSMatch;
140 continue;
141 }
142 M[Idx] == Idx ? ++NumLHSMatch : LastLHSMismatch = Idx;
143 M[Idx] == Idx + NumInputElements ? ++NumRHSMatch : LastRHSMismatch = Idx;
144 }
145 const int NumNeededToMatch = NumInputElements - 1;
146 if (NumLHSMatch == NumNeededToMatch)
147 return std::make_pair(x: true, y&: LastLHSMismatch);
148 if (NumRHSMatch == NumNeededToMatch)
149 return std::make_pair(x: false, y&: LastRHSMismatch);
150 return std::nullopt;
151}
152
153/// \return true if a G_SHUFFLE_VECTOR instruction \p MI can be replaced with a
154/// G_REV instruction. Returns the appropriate G_REV opcode in \p Opc.
155bool matchREV(MachineInstr &MI, MachineRegisterInfo &MRI,
156 ShuffleVectorPseudo &MatchInfo) {
157 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
158 ArrayRef<int> ShuffleMask = MI.getOperand(i: 3).getShuffleMask();
159 Register Dst = MI.getOperand(i: 0).getReg();
160 Register Src = MI.getOperand(i: 1).getReg();
161 LLT Ty = MRI.getType(Reg: Dst);
162 unsigned EltSize = Ty.getScalarSizeInBits();
163
164 // Element size for a rev cannot be 64.
165 if (EltSize == 64)
166 return false;
167
168 unsigned NumElts = Ty.getNumElements();
169
170 // Try to produce a G_REV instruction
171 for (unsigned LaneSize : {64U, 32U, 16U}) {
172 if (isREVMask(M: ShuffleMask, EltSize, NumElts, BlockSize: LaneSize)) {
173 unsigned Opcode;
174 if (LaneSize == 64U)
175 Opcode = AArch64::G_REV64;
176 else if (LaneSize == 32U)
177 Opcode = AArch64::G_REV32;
178 else
179 Opcode = AArch64::G_BSWAP;
180
181 MatchInfo = ShuffleVectorPseudo(Opcode, Dst, {Src});
182 return true;
183 }
184 }
185
186 return false;
187}
188
189/// \return true if a G_SHUFFLE_VECTOR instruction \p MI can be replaced with
190/// a G_TRN1 or G_TRN2 instruction.
191bool matchTRN(MachineInstr &MI, MachineRegisterInfo &MRI,
192 ShuffleVectorPseudo &MatchInfo) {
193 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
194 unsigned WhichResult;
195 unsigned OperandOrder;
196 ArrayRef<int> ShuffleMask = MI.getOperand(i: 3).getShuffleMask();
197 Register Dst = MI.getOperand(i: 0).getReg();
198 unsigned NumElts = MRI.getType(Reg: Dst).getNumElements();
199 if (!isTRNMask(M: ShuffleMask, NumElts, WhichResultOut&: WhichResult, OperandOrderOut&: OperandOrder))
200 return false;
201 unsigned Opc = (WhichResult == 0) ? AArch64::G_TRN1 : AArch64::G_TRN2;
202 Register V1 = MI.getOperand(i: OperandOrder == 0 ? 1 : 2).getReg();
203 Register V2 = MI.getOperand(i: OperandOrder == 0 ? 2 : 1).getReg();
204 MatchInfo = ShuffleVectorPseudo(Opc, Dst, {V1, V2});
205 return true;
206}
207
208/// \return true if a G_SHUFFLE_VECTOR instruction \p MI can be replaced with
209/// a G_UZP1 or G_UZP2 instruction.
210///
211/// \param [in] MI - The shuffle vector instruction.
212/// \param [out] MatchInfo - Either G_UZP1 or G_UZP2 on success.
213bool matchUZP(MachineInstr &MI, MachineRegisterInfo &MRI,
214 ShuffleVectorPseudo &MatchInfo) {
215 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
216 unsigned WhichResult;
217 ArrayRef<int> ShuffleMask = MI.getOperand(i: 3).getShuffleMask();
218 Register Dst = MI.getOperand(i: 0).getReg();
219 unsigned NumElts = MRI.getType(Reg: Dst).getNumElements();
220 if (!isUZPMask(M: ShuffleMask, NumElts, WhichResultOut&: WhichResult))
221 return false;
222 unsigned Opc = (WhichResult == 0) ? AArch64::G_UZP1 : AArch64::G_UZP2;
223 Register V1 = MI.getOperand(i: 1).getReg();
224 Register V2 = MI.getOperand(i: 2).getReg();
225 MatchInfo = ShuffleVectorPseudo(Opc, Dst, {V1, V2});
226 return true;
227}
228
229bool matchZip(MachineInstr &MI, MachineRegisterInfo &MRI,
230 ShuffleVectorPseudo &MatchInfo) {
231 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
232 unsigned WhichResult;
233 unsigned OperandOrder;
234 ArrayRef<int> ShuffleMask = MI.getOperand(i: 3).getShuffleMask();
235 Register Dst = MI.getOperand(i: 0).getReg();
236 unsigned NumElts = MRI.getType(Reg: Dst).getNumElements();
237 if (!isZIPMask(M: ShuffleMask, NumElts, WhichResultOut&: WhichResult, OperandOrderOut&: OperandOrder))
238 return false;
239 unsigned Opc = (WhichResult == 0) ? AArch64::G_ZIP1 : AArch64::G_ZIP2;
240 Register V1 = MI.getOperand(i: OperandOrder == 0 ? 1 : 2).getReg();
241 Register V2 = MI.getOperand(i: OperandOrder == 0 ? 2 : 1).getReg();
242 MatchInfo = ShuffleVectorPseudo(Opc, Dst, {V1, V2});
243 return true;
244}
245
246/// Helper function for matchDup.
247bool matchDupFromInsertVectorElt(int Lane, MachineInstr &MI,
248 MachineRegisterInfo &MRI,
249 ShuffleVectorPseudo &MatchInfo) {
250 if (Lane != 0)
251 return false;
252
253 // Try to match a vector splat operation into a dup instruction.
254 // We're looking for this pattern:
255 //
256 // %scalar:gpr(s64) = COPY $x0
257 // %undef:fpr(<2 x s64>) = G_IMPLICIT_DEF
258 // %cst0:gpr(s32) = G_CONSTANT i32 0
259 // %zerovec:fpr(<2 x s32>) = G_BUILD_VECTOR %cst0(s32), %cst0(s32)
260 // %ins:fpr(<2 x s64>) = G_INSERT_VECTOR_ELT %undef, %scalar(s64), %cst0(s32)
261 // %splat:fpr(<2 x s64>) = G_SHUFFLE_VECTOR %ins(<2 x s64>), %undef,
262 // %zerovec(<2 x s32>)
263 //
264 // ...into:
265 // %splat = G_DUP %scalar
266
267 // Begin matching the insert.
268 auto *InsMI = getOpcodeDef(Opcode: TargetOpcode::G_INSERT_VECTOR_ELT,
269 Reg: MI.getOperand(i: 1).getReg(), MRI);
270 if (!InsMI)
271 return false;
272 // Match the undef vector operand.
273 if (!getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: InsMI->getOperand(i: 1).getReg(),
274 MRI))
275 return false;
276
277 // Match the index constant 0.
278 if (!mi_match(R: InsMI->getOperand(i: 3).getReg(), MRI, P: m_ZeroInt()))
279 return false;
280
281 MatchInfo = ShuffleVectorPseudo(AArch64::G_DUP, MI.getOperand(i: 0).getReg(),
282 {InsMI->getOperand(i: 2).getReg()});
283 return true;
284}
285
286/// Helper function for matchDup.
287bool matchDupFromBuildVector(int Lane, MachineInstr &MI,
288 MachineRegisterInfo &MRI,
289 ShuffleVectorPseudo &MatchInfo) {
290 assert(Lane >= 0 && "Expected positive lane?");
291 int NumElements = MRI.getType(Reg: MI.getOperand(i: 1).getReg()).getNumElements();
292 // Test if the LHS is a BUILD_VECTOR. If it is, then we can just reference the
293 // lane's definition directly.
294 auto *BuildVecMI =
295 getOpcodeDef(Opcode: TargetOpcode::G_BUILD_VECTOR,
296 Reg: MI.getOperand(i: Lane < NumElements ? 1 : 2).getReg(), MRI);
297 // If Lane >= NumElements then it is point to RHS, just check from RHS
298 if (NumElements <= Lane)
299 Lane -= NumElements;
300
301 if (!BuildVecMI)
302 return false;
303 Register Reg = BuildVecMI->getOperand(i: Lane + 1).getReg();
304 MatchInfo =
305 ShuffleVectorPseudo(AArch64::G_DUP, MI.getOperand(i: 0).getReg(), {Reg});
306 return true;
307}
308
309bool matchDup(MachineInstr &MI, MachineRegisterInfo &MRI,
310 ShuffleVectorPseudo &MatchInfo) {
311 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
312 auto MaybeLane = getSplatIndex(MI);
313 if (!MaybeLane)
314 return false;
315 int Lane = *MaybeLane;
316 // If this is undef splat, generate it via "just" vdup, if possible.
317 if (Lane < 0)
318 Lane = 0;
319 if (matchDupFromInsertVectorElt(Lane, MI, MRI, MatchInfo))
320 return true;
321 if (matchDupFromBuildVector(Lane, MI, MRI, MatchInfo))
322 return true;
323 return false;
324}
325
326// Check if an EXT instruction can handle the shuffle mask when the vector
327// sources of the shuffle are the same.
328bool isSingletonExtMask(ArrayRef<int> M, LLT Ty) {
329 unsigned NumElts = Ty.getNumElements();
330
331 // Assume that the first shuffle index is not UNDEF. Fail if it is.
332 if (M[0] < 0)
333 return false;
334
335 // If this is a VEXT shuffle, the immediate value is the index of the first
336 // element. The other shuffle indices must be the successive elements after
337 // the first one.
338 unsigned ExpectedElt = M[0];
339 for (unsigned I = 1; I < NumElts; ++I) {
340 // Increment the expected index. If it wraps around, just follow it
341 // back to index zero and keep going.
342 ++ExpectedElt;
343 if (ExpectedElt == NumElts)
344 ExpectedElt = 0;
345
346 if (M[I] < 0)
347 continue; // Ignore UNDEF indices.
348 if (ExpectedElt != static_cast<unsigned>(M[I]))
349 return false;
350 }
351
352 return true;
353}
354
355bool matchEXT(MachineInstr &MI, MachineRegisterInfo &MRI,
356 ShuffleVectorPseudo &MatchInfo) {
357 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
358 Register Dst = MI.getOperand(i: 0).getReg();
359 LLT DstTy = MRI.getType(Reg: Dst);
360 Register V1 = MI.getOperand(i: 1).getReg();
361 Register V2 = MI.getOperand(i: 2).getReg();
362 auto Mask = MI.getOperand(i: 3).getShuffleMask();
363 uint64_t Imm;
364 auto ExtInfo = getExtMask(M: Mask, NumElts: DstTy.getNumElements());
365 uint64_t ExtFactor = MRI.getType(Reg: V1).getScalarSizeInBits() / 8;
366
367 if (!ExtInfo) {
368 if (!getOpcodeDef<GImplicitDef>(Reg: V2, MRI) ||
369 !isSingletonExtMask(M: Mask, Ty: DstTy))
370 return false;
371
372 Imm = Mask[0] * ExtFactor;
373 MatchInfo = ShuffleVectorPseudo(AArch64::G_EXT, Dst, {V1, V1, Imm});
374 return true;
375 }
376 bool ReverseExt;
377 std::tie(args&: ReverseExt, args&: Imm) = *ExtInfo;
378 if (ReverseExt)
379 std::swap(a&: V1, b&: V2);
380 Imm *= ExtFactor;
381 MatchInfo = ShuffleVectorPseudo(AArch64::G_EXT, Dst, {V1, V2, Imm});
382 return true;
383}
384
385/// Replace a G_SHUFFLE_VECTOR instruction with a pseudo.
386/// \p Opc is the opcode to use. \p MI is the G_SHUFFLE_VECTOR.
387void applyShuffleVectorPseudo(MachineInstr &MI, MachineRegisterInfo &MRI,
388 ShuffleVectorPseudo &MatchInfo) {
389 MachineIRBuilder MIRBuilder(MI);
390 if (MatchInfo.Opc == TargetOpcode::G_BSWAP) {
391 assert(MatchInfo.SrcOps.size() == 1);
392 LLT DstTy = MRI.getType(Reg: MatchInfo.Dst);
393 assert(DstTy == LLT::fixed_vector(8, 8) ||
394 DstTy == LLT::fixed_vector(16, 8));
395 LLT BSTy = DstTy == LLT::fixed_vector(NumElements: 8, ScalarSizeInBits: 8)
396 ? LLT::fixed_vector(NumElements: 4, ScalarTy: LLT::integer(SizeInBits: 16))
397 : LLT::fixed_vector(NumElements: 8, ScalarTy: LLT::integer(SizeInBits: 16));
398 // FIXME: NVCAST
399 auto BS1 = MIRBuilder.buildInstr(Opc: TargetOpcode::G_BITCAST, DstOps: {BSTy},
400 SrcOps: MatchInfo.SrcOps[0]);
401 auto BS2 = MIRBuilder.buildInstr(Opc: MatchInfo.Opc, DstOps: {BSTy}, SrcOps: {BS1});
402 MIRBuilder.buildInstr(Opc: TargetOpcode::G_BITCAST, DstOps: {MatchInfo.Dst}, SrcOps: {BS2});
403 } else
404 MIRBuilder.buildInstr(Opc: MatchInfo.Opc, DstOps: {MatchInfo.Dst}, SrcOps: MatchInfo.SrcOps);
405 MI.eraseFromParent();
406}
407
408/// Replace a G_SHUFFLE_VECTOR instruction with G_EXT.
409/// Special-cased because the constant operand must be emitted as a G_CONSTANT
410/// for the imported tablegen patterns to work.
411void applyEXT(MachineInstr &MI, ShuffleVectorPseudo &MatchInfo) {
412 MachineIRBuilder MIRBuilder(MI);
413 if (MatchInfo.SrcOps[2].getImm() == 0)
414 MIRBuilder.buildCopy(Res: MatchInfo.Dst, Op: MatchInfo.SrcOps[0]);
415 else {
416 // Tablegen patterns expect an i32 G_CONSTANT as the final op.
417 auto Cst = MIRBuilder.buildConstant(Res: LLT::integer(SizeInBits: 32),
418 Val: MatchInfo.SrcOps[2].getImm());
419 MIRBuilder.buildInstr(Opc: MatchInfo.Opc, DstOps: {MatchInfo.Dst},
420 SrcOps: {MatchInfo.SrcOps[0], MatchInfo.SrcOps[1], Cst});
421 }
422 MI.eraseFromParent();
423}
424
425void applyFullRev(MachineInstr &MI, MachineRegisterInfo &MRI) {
426 Register Dst = MI.getOperand(i: 0).getReg();
427 Register Src = MI.getOperand(i: 1).getReg();
428 LLT DstTy = MRI.getType(Reg: Dst);
429 assert(DstTy.getSizeInBits() == 128 &&
430 "Expected 128bit vector in applyFullRev");
431 MachineIRBuilder MIRBuilder(MI);
432 auto Cst = MIRBuilder.buildConstant(Res: LLT::integer(SizeInBits: 32), Val: 8);
433 auto Rev = MIRBuilder.buildInstr(Opc: AArch64::G_REV64, DstOps: {DstTy}, SrcOps: {Src});
434 MIRBuilder.buildInstr(Opc: AArch64::G_EXT, DstOps: {Dst}, SrcOps: {Rev, Rev, Cst});
435 MI.eraseFromParent();
436}
437
438bool matchNonConstInsert(MachineInstr &MI, MachineRegisterInfo &MRI) {
439 assert(MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT);
440
441 auto ValAndVReg =
442 getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 3).getReg(), MRI);
443 return !ValAndVReg;
444}
445
446void applyNonConstInsert(MachineInstr &MI, MachineRegisterInfo &MRI,
447 MachineIRBuilder &Builder) {
448 auto &Insert = cast<GInsertVectorElement>(Val&: MI);
449 Builder.setInstrAndDebugLoc(Insert);
450
451 Register Offset = Insert.getIndexReg();
452 LLT VecTy = MRI.getType(Reg: Insert.getReg(Idx: 0));
453 LLT EltTy = MRI.getType(Reg: Insert.getElementReg());
454 LLT IdxTy = MRI.getType(Reg: Insert.getIndexReg());
455
456 if (VecTy.isScalableVector())
457 return;
458
459 // Create a stack slot and store the vector into it
460 MachineFunction &MF = Builder.getMF();
461 Align Alignment(
462 std::min<uint64_t>(a: VecTy.getSizeInBytes().getKnownMinValue(), b: 16));
463 int FrameIdx = MF.getFrameInfo().CreateStackObject(Size: VecTy.getSizeInBytes(),
464 Alignment, isSpillSlot: false);
465 LLT FramePtrTy = LLT::pointer(AddressSpace: 0, SizeInBits: 64);
466 MachinePointerInfo PtrInfo = MachinePointerInfo::getFixedStack(MF, FI: FrameIdx);
467 auto StackTemp = Builder.buildFrameIndex(Res: FramePtrTy, Idx: FrameIdx);
468
469 Builder.buildStore(Val: Insert.getOperand(i: 1), Addr: StackTemp, PtrInfo, Alignment: Align(8));
470
471 // Get the pointer to the element, and be sure not to hit undefined behavior
472 // if the index is out of bounds.
473 assert(isPowerOf2_64(VecTy.getNumElements()) &&
474 "Expected a power-2 vector size");
475 auto Mask = Builder.buildConstant(Res: IdxTy, Val: VecTy.getNumElements() - 1);
476 Register And = Builder.buildAnd(Dst: IdxTy, Src0: Offset, Src1: Mask).getReg(Idx: 0);
477 auto EltSize = Builder.buildConstant(Res: IdxTy, Val: EltTy.getSizeInBytes());
478 Register Mul = Builder.buildMul(Dst: IdxTy, Src0: And, Src1: EltSize).getReg(Idx: 0);
479 Register EltPtr =
480 Builder.buildPtrAdd(Res: MRI.getType(Reg: StackTemp.getReg(Idx: 0)), Op0: StackTemp, Op1: Mul)
481 .getReg(Idx: 0);
482
483 // Write the inserted element
484 Builder.buildStore(Val: Insert.getElementReg(), Addr: EltPtr, PtrInfo, Alignment: Align(1));
485 // Reload the whole vector.
486 Builder.buildLoad(Res: Insert.getReg(Idx: 0), Addr: StackTemp, PtrInfo, Alignment: Align(8));
487 Insert.eraseFromParent();
488}
489
490/// Match a G_SHUFFLE_VECTOR with a mask which corresponds to a
491/// G_INSERT_VECTOR_ELT and G_EXTRACT_VECTOR_ELT pair.
492///
493/// e.g.
494/// %shuf = G_SHUFFLE_VECTOR %left, %right, shufflemask(0, 0)
495///
496/// Can be represented as
497///
498/// %extract = G_EXTRACT_VECTOR_ELT %left, 0
499/// %ins = G_INSERT_VECTOR_ELT %left, %extract, 1
500///
501bool matchINS(MachineInstr &MI, MachineRegisterInfo &MRI,
502 std::tuple<Register, int, Register, int> &MatchInfo) {
503 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
504 ArrayRef<int> ShuffleMask = MI.getOperand(i: 3).getShuffleMask();
505 Register Dst = MI.getOperand(i: 0).getReg();
506 int NumElts = MRI.getType(Reg: Dst).getNumElements();
507 auto DstIsLeftAndDstLane = isINSMask(M: ShuffleMask, NumInputElements: NumElts);
508 if (!DstIsLeftAndDstLane)
509 return false;
510 bool DstIsLeft;
511 int DstLane;
512 std::tie(args&: DstIsLeft, args&: DstLane) = *DstIsLeftAndDstLane;
513 Register Left = MI.getOperand(i: 1).getReg();
514 Register Right = MI.getOperand(i: 2).getReg();
515 Register DstVec = DstIsLeft ? Left : Right;
516 Register SrcVec = Left;
517
518 int SrcLane = ShuffleMask[DstLane];
519 if (SrcLane >= NumElts) {
520 SrcVec = Right;
521 SrcLane -= NumElts;
522 }
523
524 MatchInfo = std::make_tuple(args&: DstVec, args&: DstLane, args&: SrcVec, args&: SrcLane);
525 return true;
526}
527
528void applyINS(MachineInstr &MI, MachineRegisterInfo &MRI,
529 MachineIRBuilder &Builder,
530 std::tuple<Register, int, Register, int> &MatchInfo) {
531 Builder.setInstrAndDebugLoc(MI);
532 Register Dst = MI.getOperand(i: 0).getReg();
533 auto ScalarTy = MRI.getType(Reg: Dst).getElementType();
534 Register DstVec, SrcVec;
535 int DstLane, SrcLane;
536 std::tie(args&: DstVec, args&: DstLane, args&: SrcVec, args&: SrcLane) = MatchInfo;
537 auto SrcCst = Builder.buildConstant(Res: LLT::integer(SizeInBits: 64), Val: SrcLane);
538 auto Extract = Builder.buildExtractVectorElement(Res: ScalarTy, Val: SrcVec, Idx: SrcCst);
539 auto DstCst = Builder.buildConstant(Res: LLT::integer(SizeInBits: 64), Val: DstLane);
540 Builder.buildInsertVectorElement(Res: Dst, Val: DstVec, Elt: Extract, Idx: DstCst);
541 MI.eraseFromParent();
542}
543
544/// isVShiftRImm - Check if this is a valid vector for the immediate
545/// operand of a vector shift right operation. The value must be in the range:
546/// 1 <= Value <= ElementBits for a right shift.
547bool isVShiftRImm(Register Reg, MachineRegisterInfo &MRI, LLT Ty,
548 int64_t &Cnt) {
549 assert(Ty.isVector() && "vector shift count is not a vector type");
550 MachineInstr *MI = MRI.getVRegDef(Reg);
551 auto Cst = getAArch64VectorSplatScalar(MI: *MI, MRI);
552 if (!Cst)
553 return false;
554 Cnt = *Cst;
555 int64_t ElementBits = Ty.getScalarSizeInBits();
556 return Cnt >= 1 && Cnt <= ElementBits;
557}
558
559/// Match a vector G_ASHR or G_LSHR with a valid immediate shift.
560bool matchVAshrLshrImm(MachineInstr &MI, MachineRegisterInfo &MRI,
561 int64_t &Imm) {
562 assert(MI.getOpcode() == TargetOpcode::G_ASHR ||
563 MI.getOpcode() == TargetOpcode::G_LSHR);
564 LLT Ty = MRI.getType(Reg: MI.getOperand(i: 1).getReg());
565 if (!Ty.isVector())
566 return false;
567 return isVShiftRImm(Reg: MI.getOperand(i: 2).getReg(), MRI, Ty, Cnt&: Imm);
568}
569
570void applyVAshrLshrImm(MachineInstr &MI, MachineRegisterInfo &MRI,
571 int64_t &Imm) {
572 unsigned Opc = MI.getOpcode();
573 assert(Opc == TargetOpcode::G_ASHR || Opc == TargetOpcode::G_LSHR);
574 unsigned NewOpc =
575 Opc == TargetOpcode::G_ASHR ? AArch64::G_VASHR : AArch64::G_VLSHR;
576 MachineIRBuilder MIB(MI);
577 MIB.buildInstr(Opc: NewOpc, DstOps: {MI.getOperand(i: 0)}, SrcOps: {MI.getOperand(i: 1)}).addImm(Val: Imm);
578 MI.eraseFromParent();
579}
580
581/// Determine whether an integer G_ICMP against 1 or -1 can compare
582/// against 0 instead.
583///
584/// AArch64 can fold a compare-with-zero more cheaply than some non-arithmetic
585/// immediates (SUBS/ADDS, or TST when the LHS is an AND). When the predicate
586/// can be adjusted without changing semantics, the RHS may become 0.
587///
588/// Supported transforms (signed predicates only):
589/// (and X, Y) slt 1 => (and X, Y) sle 0
590/// (and X, Y) sge 1 => (and X, Y) sgt 0
591/// X sle -1 => X slt 0
592/// X sgt -1 => X sge 0
593///
594/// The compare-against-1 cases require the LHS to be G_AND because the
595/// compare-with-zero path enables ANDS (TST) selection, and ANDS flags are
596/// only reliable for those signed comparisons. This mirrors SelectionDAG
597/// emitComparison().
598///
599/// For compare-against--1 on a non-AND LHS, \p LHS must have a single
600/// non-debug use so other users are not left with a different immediate.
601///
602/// \param LHS The compare LHS register.
603/// \param C The constant RHS (only 1 or all-ones are considered).
604/// \param P In/out predicate; updated when a transform applies.
605/// \param MRI Used to inspect the LHS definition and use count.
606/// \returns true if \p P was updated and comparing against 0 is equivalent.
607static bool shouldBeAdjustedToZero(Register LHS, const APInt &C,
608 CmpInst::Predicate &P,
609 const MachineRegisterInfo &MRI) {
610 const bool IsAndLHS = getOpcodeDef<GAnd>(Reg: LHS, MRI) != nullptr;
611
612 if (C.isOne() && (P == CmpInst::ICMP_SLT || P == CmpInst::ICMP_SGE) &&
613 IsAndLHS) {
614 P = (P == CmpInst::ICMP_SLT) ? CmpInst::ICMP_SLE : CmpInst::ICMP_SGT;
615 return true;
616 }
617
618 if (!IsAndLHS && !MRI.hasOneNonDBGUse(RegNo: LHS))
619 return false;
620
621 if (C.isAllOnes() && (P == CmpInst::ICMP_SLE || P == CmpInst::ICMP_SGT)) {
622 P = (P == CmpInst::ICMP_SLE) ? CmpInst::ICMP_SLT : CmpInst::ICMP_SGE;
623 return true;
624 }
625 return false;
626}
627
628/// Determine if it is possible to modify the \p RHS and predicate \p P of a
629/// G_ICMP instruction such that the right-hand side is an arithmetic immediate.
630///
631/// \returns A pair containing the updated immediate and predicate which may
632/// be used to optimize the instruction.
633///
634/// \note This assumes that the comparison has been legalized.
635std::optional<std::pair<uint64_t, CmpInst::Predicate>>
636tryAdjustICmpImmAndPred(Register LHS, Register RHS, CmpInst::Predicate P,
637 const MachineRegisterInfo &MRI) {
638 const auto &Ty = MRI.getType(Reg: RHS);
639 if (Ty.isVector())
640 return std::nullopt;
641 assert((Ty.getSizeInBits() == 32 || Ty.getSizeInBits() == 64) &&
642 "Expected 32 or 64 bit compare only?");
643
644 // If the RHS is not a constant, or the RHS is already a valid arithmetic
645 // immediate, then there is nothing to change.
646 auto ValAndVReg = getIConstantVRegValWithLookThrough(VReg: RHS, MRI);
647 if (!ValAndVReg)
648 return std::nullopt;
649 APInt C = ValAndVReg->Value;
650 if (shouldBeAdjustedToZero(LHS, C, P, MRI))
651 return {{0, P}};
652
653 if (AArch64_AM::isLegalCmpImmed(C))
654 return std::nullopt;
655
656 uint64_t OriginalC = C.getZExtValue();
657
658 // We have a non-arithmetic immediate. Check if adjusting the immediate and
659 // adjusting the predicate will result in a legal arithmetic immediate.
660 switch (P) {
661 default:
662 return std::nullopt;
663 case CmpInst::ICMP_SLT:
664 case CmpInst::ICMP_SGE:
665 // Check for
666 //
667 // x slt c => x sle c - 1
668 // x sge c => x sgt c - 1
669 //
670 // When c is not the smallest possible negative number.
671 if (C.isMinSignedValue())
672 return std::nullopt;
673 P = (P == CmpInst::ICMP_SLT) ? CmpInst::ICMP_SLE : CmpInst::ICMP_SGT;
674 C = C - 1;
675 break;
676 case CmpInst::ICMP_ULT:
677 case CmpInst::ICMP_UGE:
678 // Check for
679 //
680 // x ult c => x ule c - 1
681 // x uge c => x ugt c - 1
682 //
683 // When c is not zero.
684 assert(!C.isZero() && "C should not be zero here!");
685 P = (P == CmpInst::ICMP_ULT) ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT;
686 C = C - 1;
687 break;
688 case CmpInst::ICMP_SLE:
689 case CmpInst::ICMP_SGT:
690 // Check for
691 //
692 // x sle c => x slt c + 1
693 // x sgt c => s sge c + 1
694 //
695 // When c is not the largest possible signed integer.
696 if (C.isMaxSignedValue())
697 return std::nullopt;
698 P = (P == CmpInst::ICMP_SLE) ? CmpInst::ICMP_SLT : CmpInst::ICMP_SGE;
699 C = C + 1;
700 break;
701 case CmpInst::ICMP_ULE:
702 case CmpInst::ICMP_UGT:
703 // Check for
704 //
705 // x ule c => x ult c + 1
706 // x ugt c => s uge c + 1
707 //
708 // When c is not the largest possible unsigned integer.
709 if (C.isAllOnes())
710 return std::nullopt;
711 P = (P == CmpInst::ICMP_ULE) ? CmpInst::ICMP_ULT : CmpInst::ICMP_UGE;
712 C = C + 1;
713 break;
714 }
715
716 // Check if the new constant is valid, and return the updated constant and
717 // predicate if it is.
718 uint64_t NewC = C.getZExtValue();
719 if (AArch64_AM::isLegalCmpImmed(C))
720 return {{NewC, P}};
721
722 auto NumberOfInstrToLoadImm = [=](uint64_t Imm) {
723 SmallVector<AArch64_IMM::ImmInsnModel> Insn;
724 AArch64_IMM::expandMOVImm(Imm, BitSize: 32, Insn);
725 return Insn.size();
726 };
727
728 if (NumberOfInstrToLoadImm(OriginalC) > NumberOfInstrToLoadImm(NewC))
729 return {{NewC, P}};
730
731 return std::nullopt;
732}
733
734/// Determine whether or not it is possible to update the RHS and predicate of
735/// a G_ICMP instruction such that the RHS will be selected as an arithmetic
736/// immediate.
737///
738/// \p MI - The G_ICMP instruction
739/// \p MatchInfo - The new RHS immediate and predicate on success
740///
741/// See tryAdjustICmpImmAndPred for valid transformations.
742bool matchAdjustICmpImmAndPred(
743 MachineInstr &MI, const MachineRegisterInfo &MRI,
744 std::pair<uint64_t, CmpInst::Predicate> &MatchInfo) {
745 assert(MI.getOpcode() == TargetOpcode::G_ICMP);
746 Register LHS = MI.getOperand(i: 2).getReg();
747 Register RHS = MI.getOperand(i: 3).getReg();
748 auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(i: 1).getPredicate());
749 if (auto MaybeNewImmAndPred = tryAdjustICmpImmAndPred(LHS, RHS, P: Pred, MRI)) {
750 MatchInfo = *MaybeNewImmAndPred;
751 return true;
752 }
753 return false;
754}
755
756void applyAdjustICmpImmAndPred(
757 MachineInstr &MI, std::pair<uint64_t, CmpInst::Predicate> &MatchInfo,
758 MachineIRBuilder &MIB, GISelChangeObserver &Observer) {
759 MIB.setInstrAndDebugLoc(MI);
760 MachineOperand &RHS = MI.getOperand(i: 3);
761 MachineRegisterInfo &MRI = *MIB.getMRI();
762 auto Cst = MIB.buildConstant(Res: MRI.cloneVirtualRegister(VReg: RHS.getReg()),
763 Val: MatchInfo.first);
764 Observer.changingInstr(MI);
765 RHS.setReg(Cst->getOperand(i: 0).getReg());
766 MI.getOperand(i: 1).setPredicate(MatchInfo.second);
767 Observer.changedInstr(MI);
768}
769
770bool matchDupLane(MachineInstr &MI, MachineRegisterInfo &MRI,
771 std::pair<unsigned, int> &MatchInfo) {
772 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
773 Register Src1Reg = MI.getOperand(i: 1).getReg();
774 const LLT SrcTy = MRI.getType(Reg: Src1Reg);
775 const LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
776
777 auto LaneIdx = getSplatIndex(MI);
778 if (!LaneIdx)
779 return false;
780
781 // The lane idx should be within the first source vector.
782 if (*LaneIdx >= SrcTy.getNumElements())
783 return false;
784
785 if (DstTy != SrcTy)
786 return false;
787
788 LLT ScalarTy = SrcTy.getElementType();
789 unsigned ScalarSize = ScalarTy.getSizeInBits();
790
791 unsigned Opc = 0;
792 switch (SrcTy.getNumElements()) {
793 case 2:
794 if (ScalarSize == 64)
795 Opc = AArch64::G_DUPLANE64;
796 else if (ScalarSize == 32)
797 Opc = AArch64::G_DUPLANE32;
798 break;
799 case 4:
800 if (ScalarSize == 32)
801 Opc = AArch64::G_DUPLANE32;
802 else if (ScalarSize == 16)
803 Opc = AArch64::G_DUPLANE16;
804 break;
805 case 8:
806 if (ScalarSize == 8)
807 Opc = AArch64::G_DUPLANE8;
808 else if (ScalarSize == 16)
809 Opc = AArch64::G_DUPLANE16;
810 break;
811 case 16:
812 if (ScalarSize == 8)
813 Opc = AArch64::G_DUPLANE8;
814 break;
815 default:
816 break;
817 }
818 if (!Opc)
819 return false;
820
821 MatchInfo.first = Opc;
822 MatchInfo.second = *LaneIdx;
823 return true;
824}
825
826void applyDupLane(MachineInstr &MI, MachineRegisterInfo &MRI,
827 MachineIRBuilder &B, std::pair<unsigned, int> &MatchInfo) {
828 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
829 Register Src1Reg = MI.getOperand(i: 1).getReg();
830 const LLT SrcTy = MRI.getType(Reg: Src1Reg);
831
832 B.setInstrAndDebugLoc(MI);
833 auto Lane = B.buildConstant(Res: LLT::integer(SizeInBits: 64), Val: MatchInfo.second);
834
835 Register DupSrc = MI.getOperand(i: 1).getReg();
836 // For types like <2 x s32>, we can use G_DUPLANE32, with a <4 x s32> source.
837 // To do this, we can use a G_CONCAT_VECTORS to do the widening.
838 if (SrcTy.getSizeInBits() == 64) {
839 auto Undef = B.buildUndef(Res: SrcTy);
840 DupSrc = B.buildConcatVectors(Res: SrcTy.multiplyElements(Factor: 2),
841 Ops: {Src1Reg, Undef.getReg(Idx: 0)})
842 .getReg(Idx: 0);
843 }
844 B.buildInstr(Opc: MatchInfo.first, DstOps: {MI.getOperand(i: 0).getReg()}, SrcOps: {DupSrc, Lane});
845 MI.eraseFromParent();
846}
847
848bool matchScalarizeVectorUnmerge(MachineInstr &MI, MachineRegisterInfo &MRI) {
849 auto &Unmerge = cast<GUnmerge>(Val&: MI);
850 Register Src1Reg = Unmerge.getReg(Idx: Unmerge.getNumOperands() - 1);
851 const LLT SrcTy = MRI.getType(Reg: Src1Reg);
852 if (SrcTy.getSizeInBits() != 128 && SrcTy.getSizeInBits() != 64)
853 return false;
854 return SrcTy.isVector() && !SrcTy.isScalable() &&
855 (Unmerge.getNumOperands() == (unsigned)SrcTy.getNumElements() + 1 ||
856 (Unmerge.getNumDefs() == 2 && SrcTy.getSizeInBits() == 128 &&
857 MRI.getType(Reg: Unmerge.getReg(Idx: 0)).getSizeInBits() == 64));
858}
859
860void applyScalarizeVectorUnmerge(MachineInstr &MI, MachineRegisterInfo &MRI,
861 MachineIRBuilder &B) {
862 auto &Unmerge = cast<GUnmerge>(Val&: MI);
863 Register Src1Reg = Unmerge.getReg(Idx: Unmerge.getNumOperands() - 1);
864 const LLT SrcTy = MRI.getType(Reg: Src1Reg);
865 const LLT DstTy = MRI.getType(Reg: Unmerge.getReg(Idx: 0));
866 assert((SrcTy.isVector() && !SrcTy.isScalable()) &&
867 "Expected a fixed length vector");
868
869 if (DstTy.isVector()) {
870 assert(Unmerge.getNumDefs() == 2);
871 if (!MRI.use_nodbg_empty(RegNo: Unmerge.getReg(Idx: 0)))
872 B.buildExtractSubvector(Res: Unmerge.getReg(Idx: 0), Src: Src1Reg, Index: 0);
873 if (!MRI.use_nodbg_empty(RegNo: Unmerge.getReg(Idx: 1)))
874 B.buildExtractSubvector(Res: Unmerge.getReg(Idx: 1), Src: Src1Reg,
875 Index: SrcTy.getNumElements() / 2);
876 } else {
877 for (int I = 0; I < SrcTy.getNumElements(); ++I)
878 if (!MRI.use_nodbg_empty(RegNo: Unmerge.getReg(Idx: I)))
879 B.buildExtractVectorElementConstant(Res: Unmerge.getReg(Idx: I), Val: Src1Reg, Idx: I);
880 }
881 MI.eraseFromParent();
882}
883
884bool matchBuildVectorToDup(MachineInstr &MI, Register &Src,
885 MachineRegisterInfo &MRI) {
886 assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
887
888 // Later, during selection, we'll try to match imported patterns using
889 // immAllOnesV and immAllZerosV. These require G_BUILD_VECTOR. Don't lower
890 // G_BUILD_VECTORs which could match those patterns.
891 if (isBuildVectorAllZeros(MI, MRI) || isBuildVectorAllOnes(MI, MRI))
892 return false;
893
894 // Find buildvector which always uses the same register or undef. Return true
895 // so long as at least 2 registers were found (not all-undef or only 1
896 // non-undef entry).
897 Register Reg = 0;
898 unsigned NumNonUndef = 0;
899 for (const MachineOperand &Op : drop_begin(RangeOrContainer: MI.operands())) {
900 if (getOpcodeDef<GImplicitDef>(Reg: Op.getReg(), MRI))
901 continue;
902
903 if (!Reg)
904 Reg = Op.getReg();
905 else if (Op.getReg() != Reg)
906 return false;
907 NumNonUndef++;
908 }
909
910 Src = Reg;
911 return Reg && NumNonUndef > 1;
912}
913
914void applyBuildVectorToDup(MachineInstr &MI, Register Src,
915 MachineRegisterInfo &MRI, MachineIRBuilder &B) {
916 B.setInstrAndDebugLoc(MI);
917 B.buildInstr(Opc: AArch64::G_DUP, DstOps: {MI.getOperand(i: 0).getReg()}, SrcOps: {Src});
918 MI.eraseFromParent();
919}
920
921/// \returns how many instructions would be saved by folding a G_ICMP's shift
922/// and/or extension operations.
923static unsigned getCmpOperandFoldingProfit(Register CmpOp,
924 MachineRegisterInfo &MRI) {
925 // FIXME: This is duplicated with the selector. (See: selectShiftedRegister)
926 auto IsSupportedExtend = [&](const MachineInstr &MI) {
927 if (MI.getOpcode() == TargetOpcode::G_SEXT_INREG)
928 return true;
929 if (MI.getOpcode() == TargetOpcode::G_AND) {
930 auto ValAndVReg =
931 getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI);
932 if (ValAndVReg) {
933 uint64_t Mask = ValAndVReg->Value.getZExtValue();
934 return (Mask == 0xFF || Mask == 0xFFFF || Mask == 0xFFFFFFFF);
935 }
936 }
937 return false;
938 };
939
940 // No instructions to save if there's more than one use or no uses.
941 if (!MRI.hasOneNonDBGUse(RegNo: CmpOp))
942 return 0;
943
944 MachineInstr *Def = getDefIgnoringCopies(Reg: CmpOp, MRI);
945 if (IsSupportedExtend(*Def))
946 return 1;
947
948 unsigned Opc = Def->getOpcode();
949 if (Opc == TargetOpcode::G_SHL || Opc == TargetOpcode::G_LSHR ||
950 Opc == TargetOpcode::G_ASHR) {
951 auto MaybeShiftAmt =
952 getIConstantVRegValWithLookThrough(VReg: Def->getOperand(i: 2).getReg(), MRI);
953 if (MaybeShiftAmt) {
954 uint64_t ShiftAmt = MaybeShiftAmt->Value.getZExtValue();
955 MachineInstr *ShiftLHS =
956 getDefIgnoringCopies(Reg: Def->getOperand(i: 1).getReg(), MRI);
957 if (IsSupportedExtend(*ShiftLHS))
958 return (ShiftAmt <= 4) ? 2 : 1;
959 LLT Ty = MRI.getType(Reg: Def->getOperand(i: 0).getReg());
960 if (Ty.isVector())
961 return 0;
962 unsigned ShiftSize = Ty.getSizeInBits();
963 if ((ShiftSize == 32 && ShiftAmt <= 31) ||
964 (ShiftSize == 64 && ShiftAmt <= 63))
965 return 1;
966 }
967 }
968
969 return 0;
970}
971
972/// \returns true if it would be profitable to swap the LHS and RHS of a G_ICMP
973/// instruction \p MI.
974bool trySwapICmpOperands(MachineInstr &MI, MachineRegisterInfo &MRI) {
975 assert(MI.getOpcode() == TargetOpcode::G_ICMP);
976 // Swap the operands if it would introduce a profitable folding opportunity.
977 // (e.g. a shift + extend).
978 //
979 // For example:
980 // lsl w13, w11, #1
981 // cmp w13, w12
982 // can be turned into:
983 // cmp w12, w11, lsl #1
984
985 // Don't swap if there's a constant on the RHS and it is a legal compare
986 // immediate, because we know we can fold that.
987 Register RHS = MI.getOperand(i: 3).getReg();
988 auto RHSCst = getIConstantVRegValWithLookThrough(VReg: RHS, MRI);
989 if (RHSCst && AArch64_AM::isLegalCmpImmed(C: RHSCst->Value))
990 return false;
991
992 Register LHS = MI.getOperand(i: 2).getReg();
993 auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(i: 1).getPredicate());
994 auto GetRegForProfit = [&](Register Reg) {
995 MachineInstr *Def = getDefIgnoringCopies(Reg, MRI);
996 return isCMN(MaybeSub: Def, Pred, MRI) ? Def->getOperand(i: 2).getReg() : Reg;
997 };
998
999 // Don't have a constant on the RHS. If we swap the LHS and RHS of the
1000 // compare, would we be able to fold more instructions?
1001 Register TheLHS = GetRegForProfit(LHS);
1002 Register TheRHS = GetRegForProfit(RHS);
1003
1004 // If the LHS is more likely to give us a folding opportunity, then swap the
1005 // LHS and RHS.
1006 return (getCmpOperandFoldingProfit(CmpOp: TheLHS, MRI) >
1007 getCmpOperandFoldingProfit(CmpOp: TheRHS, MRI));
1008}
1009
1010void applySwapICmpOperands(MachineInstr &MI, GISelChangeObserver &Observer) {
1011 auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(i: 1).getPredicate());
1012 Register LHS = MI.getOperand(i: 2).getReg();
1013 Register RHS = MI.getOperand(i: 3).getReg();
1014 Observer.changingInstr(MI);
1015 MI.getOperand(i: 1).setPredicate(CmpInst::getSwappedPredicate(pred: Pred));
1016 MI.getOperand(i: 2).setReg(RHS);
1017 MI.getOperand(i: 3).setReg(LHS);
1018 Observer.changedInstr(MI);
1019}
1020
1021/// \returns a function which builds a vector floating point compare instruction
1022/// for a condition code \p CC.
1023/// \param [in] NoNans - True if the instruction has nnan flag.
1024std::function<Register(MachineIRBuilder &)>
1025getVectorFCMP(AArch64CC::CondCode CC, Register LHS, Register RHS, bool NoNans,
1026 MachineRegisterInfo &MRI) {
1027 LLT OldTy = MRI.getType(Reg: LHS);
1028 LLT DstTy = LLT::fixed_vector(NumElements: OldTy.getNumElements(),
1029 ScalarTy: LLT::integer(SizeInBits: OldTy.getScalarSizeInBits()));
1030 assert(DstTy.isVector() && "Expected vector types only?");
1031 switch (CC) {
1032 default:
1033 llvm_unreachable("Unexpected condition code!");
1034 case AArch64CC::NE:
1035 return [LHS, RHS, DstTy](MachineIRBuilder &MIB) {
1036 auto FCmp = MIB.buildInstr(Opc: AArch64::G_FCMEQ, DstOps: {DstTy}, SrcOps: {LHS, RHS});
1037 return MIB.buildNot(Dst: DstTy, Src0: FCmp).getReg(Idx: 0);
1038 };
1039 case AArch64CC::EQ:
1040 return [LHS, RHS, DstTy](MachineIRBuilder &MIB) {
1041 return MIB.buildInstr(Opc: AArch64::G_FCMEQ, DstOps: {DstTy}, SrcOps: {LHS, RHS}).getReg(Idx: 0);
1042 };
1043 case AArch64CC::GE:
1044 return [LHS, RHS, DstTy](MachineIRBuilder &MIB) {
1045 return MIB.buildInstr(Opc: AArch64::G_FCMGE, DstOps: {DstTy}, SrcOps: {LHS, RHS}).getReg(Idx: 0);
1046 };
1047 case AArch64CC::GT:
1048 return [LHS, RHS, DstTy](MachineIRBuilder &MIB) {
1049 return MIB.buildInstr(Opc: AArch64::G_FCMGT, DstOps: {DstTy}, SrcOps: {LHS, RHS}).getReg(Idx: 0);
1050 };
1051 case AArch64CC::LS:
1052 return [LHS, RHS, DstTy](MachineIRBuilder &MIB) {
1053 return MIB.buildInstr(Opc: AArch64::G_FCMGE, DstOps: {DstTy}, SrcOps: {RHS, LHS}).getReg(Idx: 0);
1054 };
1055 case AArch64CC::MI:
1056 return [LHS, RHS, DstTy](MachineIRBuilder &MIB) {
1057 return MIB.buildInstr(Opc: AArch64::G_FCMGT, DstOps: {DstTy}, SrcOps: {RHS, LHS}).getReg(Idx: 0);
1058 };
1059 }
1060}
1061
1062/// Try to lower a vector G_FCMP \p MI into an AArch64-specific pseudo.
1063bool matchLowerVectorFCMP(MachineInstr &MI, MachineRegisterInfo &MRI,
1064 MachineIRBuilder &MIB) {
1065 assert(MI.getOpcode() == TargetOpcode::G_FCMP);
1066 const auto &ST = MI.getMF()->getSubtarget<AArch64Subtarget>();
1067
1068 Register Dst = MI.getOperand(i: 0).getReg();
1069 LLT DstTy = MRI.getType(Reg: Dst);
1070 if (!DstTy.isVector() || !ST.hasNEON())
1071 return false;
1072 Register LHS = MI.getOperand(i: 2).getReg();
1073 unsigned EltSize = MRI.getType(Reg: LHS).getScalarSizeInBits();
1074 if (EltSize == 16 && !ST.hasFullFP16())
1075 return false;
1076 if (EltSize != 16 && EltSize != 32 && EltSize != 64)
1077 return false;
1078
1079 return true;
1080}
1081
1082/// Try to lower a vector G_FCMP \p MI into an AArch64-specific pseudo.
1083void applyLowerVectorFCMP(MachineInstr &MI, MachineRegisterInfo &MRI,
1084 MachineIRBuilder &MIB) {
1085 assert(MI.getOpcode() == TargetOpcode::G_FCMP);
1086
1087 const auto &CmpMI = cast<GFCmp>(Val&: MI);
1088
1089 Register Dst = CmpMI.getReg(Idx: 0);
1090 CmpInst::Predicate Pred = CmpMI.getCond();
1091 Register LHS = CmpMI.getLHSReg();
1092 Register RHS = CmpMI.getRHSReg();
1093
1094 LLT DstTy = MRI.getType(Reg: Dst);
1095
1096 bool Invert = false;
1097 AArch64CC::CondCode CC, CC2 = AArch64CC::AL;
1098 if ((Pred == CmpInst::Predicate::FCMP_ORD ||
1099 Pred == CmpInst::Predicate::FCMP_UNO) &&
1100 isBuildVectorAllZeros(MI: *MRI.getVRegDef(Reg: RHS), MRI)) {
1101 // The special case "fcmp ord %a, 0" is the canonical check that LHS isn't
1102 // NaN, so equivalent to a == a and doesn't need the two comparisons an
1103 // "ord" normally would.
1104 // Similarly, "fcmp uno %a, 0" is the canonical check that LHS is NaN and is
1105 // thus equivalent to a != a.
1106 RHS = LHS;
1107 CC = Pred == CmpInst::Predicate::FCMP_ORD ? AArch64CC::EQ : AArch64CC::NE;
1108 } else
1109 changeVectorFCMPPredToAArch64CC(P: Pred, CondCode&: CC, CondCode2&: CC2, Invert);
1110
1111 // Instead of having an apply function, just build here to simplify things.
1112 MIB.setInstrAndDebugLoc(MI);
1113
1114 // TODO: Also consider GISelValueTracking result if eligible.
1115 const bool NoNans = MI.getFlag(Flag: MachineInstr::FmNoNans);
1116
1117 auto Cmp = getVectorFCMP(CC, LHS, RHS, NoNans, MRI);
1118 Register CmpRes;
1119 if (CC2 == AArch64CC::AL)
1120 CmpRes = Cmp(MIB);
1121 else {
1122 auto Cmp2 = getVectorFCMP(CC: CC2, LHS, RHS, NoNans, MRI);
1123 auto Cmp2Dst = Cmp2(MIB);
1124 auto Cmp1Dst = Cmp(MIB);
1125 CmpRes = MIB.buildOr(Dst: DstTy, Src0: Cmp1Dst, Src1: Cmp2Dst).getReg(Idx: 0);
1126 }
1127 if (Invert)
1128 CmpRes = MIB.buildNot(Dst: DstTy, Src0: CmpRes).getReg(Idx: 0);
1129 MRI.replaceRegWith(FromReg: Dst, ToReg: CmpRes);
1130 MI.eraseFromParent();
1131}
1132
1133// Matches G_BUILD_VECTOR where at least one source operand is not a constant
1134bool matchLowerBuildToInsertVecElt(MachineInstr &MI, MachineRegisterInfo &MRI) {
1135 auto *GBuildVec = cast<GBuildVector>(Val: &MI);
1136
1137 // Check if the values are all constants
1138 for (unsigned I = 0; I < GBuildVec->getNumSources(); ++I) {
1139 auto ConstVal =
1140 getAnyConstantVRegValWithLookThrough(VReg: GBuildVec->getSourceReg(I), MRI);
1141
1142 if (!ConstVal.has_value())
1143 return true;
1144 }
1145
1146 return false;
1147}
1148
1149void applyLowerBuildToInsertVecElt(MachineInstr &MI, MachineRegisterInfo &MRI,
1150 MachineIRBuilder &B) {
1151 auto *GBuildVec = cast<GBuildVector>(Val: &MI);
1152 LLT DstTy = MRI.getType(Reg: GBuildVec->getReg(Idx: 0));
1153 Register DstReg = B.buildUndef(Res: DstTy).getReg(Idx: 0);
1154
1155 for (unsigned I = 0; I < GBuildVec->getNumSources(); ++I) {
1156 Register SrcReg = GBuildVec->getSourceReg(I);
1157 if (mi_match(R: SrcReg, MRI, P: m_GImplicitDef()))
1158 continue;
1159 auto IdxReg = B.buildConstant(Res: LLT::integer(SizeInBits: 64), Val: I);
1160 DstReg =
1161 B.buildInsertVectorElement(Res: DstTy, Val: DstReg, Elt: SrcReg, Idx: IdxReg).getReg(Idx: 0);
1162 }
1163 B.buildCopy(Res: GBuildVec->getReg(Idx: 0), Op: DstReg);
1164 GBuildVec->eraseFromParent();
1165}
1166
1167bool matchFormTruncstore(MachineInstr &MI, MachineRegisterInfo &MRI,
1168 Register &SrcReg) {
1169 assert(MI.getOpcode() == TargetOpcode::G_STORE);
1170 Register DstReg = MI.getOperand(i: 0).getReg();
1171 if (MRI.getType(Reg: DstReg).isVector())
1172 return false;
1173 // Match a store of a truncate.
1174 if (!mi_match(R: DstReg, MRI, P: m_GTrunc(Src: m_Reg(R&: SrcReg))))
1175 return false;
1176 // Only form truncstores for value types of max 64b.
1177 return MRI.getType(Reg: SrcReg).getSizeInBits() <= 64;
1178}
1179
1180void applyFormTruncstore(MachineInstr &MI, MachineRegisterInfo &MRI,
1181 MachineIRBuilder &B, GISelChangeObserver &Observer,
1182 Register &SrcReg) {
1183 assert(MI.getOpcode() == TargetOpcode::G_STORE);
1184 Observer.changingInstr(MI);
1185 MI.getOperand(i: 0).setReg(SrcReg);
1186 Observer.changedInstr(MI);
1187}
1188
1189// Lower vector G_SEXT_INREG back to shifts for selection. We allowed them to
1190// form in the first place for combine opportunities, so any remaining ones
1191// at this stage need be lowered back.
1192bool matchVectorSextInReg(MachineInstr &MI, MachineRegisterInfo &MRI) {
1193 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
1194 Register DstReg = MI.getOperand(i: 0).getReg();
1195 LLT DstTy = MRI.getType(Reg: DstReg);
1196 return DstTy.isVector();
1197}
1198
1199void applyVectorSextInReg(MachineInstr &MI, MachineRegisterInfo &MRI,
1200 MachineIRBuilder &B, GISelChangeObserver &Observer) {
1201 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
1202 B.setInstrAndDebugLoc(MI);
1203 LegalizerHelper Helper(*MI.getMF(), Observer, B);
1204 Helper.lower(MI, TypeIdx: 0, /* Unused hint type */ Ty: LLT());
1205}
1206
1207/// Combine <N x t>, unused = unmerge(G_EXT <2*N x t> v, undef, N)
1208/// => unused, <N x t> = unmerge v
1209bool matchUnmergeExtToUnmerge(MachineInstr &MI, MachineRegisterInfo &MRI,
1210 Register &MatchInfo) {
1211 auto &Unmerge = cast<GUnmerge>(Val&: MI);
1212 if (Unmerge.getNumDefs() != 2)
1213 return false;
1214 if (!MRI.use_nodbg_empty(RegNo: Unmerge.getReg(Idx: 1)))
1215 return false;
1216
1217 LLT DstTy = MRI.getType(Reg: Unmerge.getReg(Idx: 0));
1218 if (!DstTy.isVector())
1219 return false;
1220
1221 MachineInstr *Ext = getOpcodeDef(Opcode: AArch64::G_EXT, Reg: Unmerge.getSourceReg(), MRI);
1222 if (!Ext)
1223 return false;
1224
1225 Register ExtSrc1 = Ext->getOperand(i: 1).getReg();
1226 Register ExtSrc2 = Ext->getOperand(i: 2).getReg();
1227 auto LowestVal =
1228 getIConstantVRegValWithLookThrough(VReg: Ext->getOperand(i: 3).getReg(), MRI);
1229 if (!LowestVal || LowestVal->Value.getZExtValue() != DstTy.getSizeInBytes())
1230 return false;
1231
1232 if (!getOpcodeDef<GImplicitDef>(Reg: ExtSrc2, MRI))
1233 return false;
1234
1235 MatchInfo = ExtSrc1;
1236 return true;
1237}
1238
1239void applyUnmergeExtToUnmerge(MachineInstr &MI, MachineRegisterInfo &MRI,
1240 MachineIRBuilder &B,
1241 GISelChangeObserver &Observer, Register &SrcReg) {
1242 Observer.changingInstr(MI);
1243 // Swap dst registers.
1244 Register Dst1 = MI.getOperand(i: 0).getReg();
1245 MI.getOperand(i: 0).setReg(MI.getOperand(i: 1).getReg());
1246 MI.getOperand(i: 1).setReg(Dst1);
1247 MI.getOperand(i: 2).setReg(SrcReg);
1248 Observer.changedInstr(MI);
1249}
1250
1251// Match mul({z/s}ext , {z/s}ext) => {u/s}mull OR
1252// Match v2s64 mul instructions, which will then be scalarised later on
1253// Doing these two matches in one function to ensure that the order of matching
1254// will always be the same.
1255// Try lowering MUL to MULL before trying to scalarize if needed.
1256bool matchMulv2s64(MachineInstr &MI, MachineRegisterInfo &MRI) {
1257 // Get the instructions that defined the source operand
1258 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
1259 return DstTy == LLT::fixed_vector(NumElements: 2, ScalarSizeInBits: 64);
1260}
1261
1262void applyMulv2s64(MachineInstr &MI, MachineRegisterInfo &MRI,
1263 MachineIRBuilder &B, GISelChangeObserver &Observer) {
1264 assert(MI.getOpcode() == TargetOpcode::G_MUL &&
1265 "Expected a G_MUL instruction");
1266
1267 // Get the instructions that defined the source operand
1268 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
1269 assert(DstTy == LLT::fixed_vector(2, 64) && "Expected v2s64 Mul");
1270 LegalizerHelper Helper(*MI.getMF(), Observer, B);
1271 Helper.fewerElementsVector(
1272 MI, TypeIdx: 0,
1273 NarrowTy: DstTy.changeElementCount(EC: DstTy.getElementCount().divideCoefficientBy(RHS: 2)));
1274}
1275
1276class AArch64PostLegalizerLoweringImpl : public Combiner {
1277protected:
1278 const CombinerHelper Helper;
1279 const AArch64PostLegalizerLoweringImplRuleConfig &RuleConfig;
1280 const AArch64Subtarget &STI;
1281
1282public:
1283 AArch64PostLegalizerLoweringImpl(
1284 MachineFunction &MF, CombinerInfo &CInfo, GISelCSEInfo *CSEInfo,
1285 const AArch64PostLegalizerLoweringImplRuleConfig &RuleConfig,
1286 const AArch64Subtarget &STI);
1287
1288 static const char *getName() { return "AArch6400PreLegalizerCombiner"; }
1289
1290 bool tryCombineAll(MachineInstr &I) const override;
1291
1292private:
1293#define GET_GICOMBINER_CLASS_MEMBERS
1294#include "AArch64GenPostLegalizeGILowering.inc"
1295#undef GET_GICOMBINER_CLASS_MEMBERS
1296};
1297
1298#define GET_GICOMBINER_IMPL
1299#include "AArch64GenPostLegalizeGILowering.inc"
1300#undef GET_GICOMBINER_IMPL
1301
// Constructs the lowering combiner. The Combiner base is given no
// value-tracking analysis (nullptr is passed for VT) and the caller-provided
// CSEInfo; the remaining member initializers come from the generated include.
// NOTE(review): Helper is created with IsPreLegalize=true even though this
// class is the post-legalizer lowering — confirm this is intentional.
AArch64PostLegalizerLoweringImpl::AArch64PostLegalizerLoweringImpl(
    MachineFunction &MF, CombinerInfo &CInfo, GISelCSEInfo *CSEInfo,
    const AArch64PostLegalizerLoweringImplRuleConfig &RuleConfig,
    const AArch64Subtarget &STI)
    : Combiner(MF, CInfo, /*VT*/ nullptr, CSEInfo),
      Helper(Observer, B, /*IsPreLegalize*/ true), RuleConfig(RuleConfig),
      STI(STI),
#define GET_GICOMBINER_CONSTRUCTOR_INITS
#include "AArch64GenPostLegalizeGILowering.inc"
#undef GET_GICOMBINER_CONSTRUCTOR_INITS
{
}
1314
1315bool runPostLegalizerLowering(
1316 MachineFunction &MF,
1317 const AArch64PostLegalizerLoweringImplRuleConfig &RuleConfig) {
1318 if (MF.getProperties().hasFailedISel())
1319 return false;
1320 const Function &F = MF.getFunction();
1321
1322 const AArch64Subtarget &ST = MF.getSubtarget<AArch64Subtarget>();
1323 CombinerInfo CInfo(/*AllowIllegalOps=*/true, /*ShouldLegalizeIllegal=*/false,
1324 /*LegalizerInfo=*/nullptr, /*OptEnabled=*/true,
1325 F.hasOptSize(), F.hasMinSize());
1326 // Disable fixed-point iteration to reduce compile-time
1327 CInfo.MaxIterations = 1;
1328 CInfo.ObserverLvl = CombinerInfo::ObserverLevel::SinglePass;
1329 // PostLegalizerCombiner performs DCE, so a full DCE pass is unnecessary.
1330 CInfo.EnableFullDCE = false;
1331 AArch64PostLegalizerLoweringImpl Impl(MF, CInfo, /*CSEInfo=*/nullptr,
1332 RuleConfig, ST);
1333 return Impl.combineMachineInstrs();
1334}
1335
1336class AArch64PostLegalizerLoweringLegacy : public MachineFunctionPass {
1337public:
1338 static char ID;
1339
1340 AArch64PostLegalizerLoweringLegacy();
1341
1342 StringRef getPassName() const override {
1343 return "AArch64PostLegalizerLowering";
1344 }
1345
1346 bool runOnMachineFunction(MachineFunction &MF) override;
1347 void getAnalysisUsage(AnalysisUsage &AU) const override;
1348
1349private:
1350 AArch64PostLegalizerLoweringImplRuleConfig RuleConfig;
1351};
1352} // end anonymous namespace
1353
// Declares this pass's analysis usage: it preserves the CFG, adds whatever
// getSelectionDAGFallbackAnalysisUsage registers, and then defers to the
// MachineFunctionPass defaults.
void AArch64PostLegalizerLoweringLegacy::getAnalysisUsage(
    AnalysisUsage &AU) const {
  AU.setPreservesCFG();
  getSelectionDAGFallbackAnalysisUsage(AU);
  MachineFunctionPass::getAnalysisUsage(AU);
}
1360
1361AArch64PostLegalizerLoweringLegacy::AArch64PostLegalizerLoweringLegacy()
1362 : MachineFunctionPass(ID) {
1363 if (!RuleConfig.parseCommandLineOption())
1364 report_fatal_error(reason: "Invalid rule identifier");
1365}
1366
1367bool AArch64PostLegalizerLoweringLegacy::runOnMachineFunction(
1368 MachineFunction &MF) {
1369 assert(MF.getProperties().hasLegalized() && "Expected a legalized function?");
1370 return runPostLegalizerLowering(MF, RuleConfig);
1371}
1372
1373char AArch64PostLegalizerLoweringLegacy::ID = 0;
1374INITIALIZE_PASS_BEGIN(AArch64PostLegalizerLoweringLegacy, DEBUG_TYPE,
1375 "Lower AArch64 MachineInstrs after legalization", false,
1376 false)
1377INITIALIZE_PASS_END(AArch64PostLegalizerLoweringLegacy, DEBUG_TYPE,
1378 "Lower AArch64 MachineInstrs after legalization", false,
1379 false)
1380
1381AArch64PostLegalizerLoweringPass::AArch64PostLegalizerLoweringPass()
1382 : RuleConfig(
1383 std::make_unique<AArch64PostLegalizerLoweringImplRuleConfig>()) {
1384 if (!RuleConfig->parseCommandLineOption())
1385 reportFatalUsageError(reason: "invalid rule identifier");
1386}
1387
// Move constructor and destructor are defaulted here, out of line.
// NOTE(review): presumably so they are emitted where the rule-config type
// held by RuleConfig is complete — confirm against the class declaration.
AArch64PostLegalizerLoweringPass::AArch64PostLegalizerLoweringPass(
    AArch64PostLegalizerLoweringPass &&) = default;

AArch64PostLegalizerLoweringPass::~AArch64PostLegalizerLoweringPass() = default;
1392
1393PreservedAnalyses
1394AArch64PostLegalizerLoweringPass::run(MachineFunction &MF,
1395 MachineFunctionAnalysisManager &MFAM) {
1396 MFPropsModifier _(*this, MF);
1397 const bool Changed = runPostLegalizerLowering(MF, RuleConfig: *RuleConfig);
1398
1399 if (!Changed)
1400 return PreservedAnalyses::all();
1401
1402 PreservedAnalyses PA = getMachineFunctionPassPreservedAnalyses();
1403 PA.preserveSet<CFGAnalyses>();
1404 return PA;
1405}
1406
1407namespace llvm {
1408FunctionPass *createAArch64PostLegalizerLowering() {
1409 return new AArch64PostLegalizerLoweringLegacy();
1410}
1411} // end namespace llvm
1412