1//===- CombinerHelperVectorOps.cpp-----------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements CombinerHelper for G_EXTRACT_VECTOR_ELT,
// G_INSERT_VECTOR_ELT, and G_VSCALE.
11//
12//===----------------------------------------------------------------------===//
13#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
14#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
15#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
16#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
17#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
18#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
19#include "llvm/CodeGen/GlobalISel/Utils.h"
20#include "llvm/CodeGen/LowLevelTypeUtils.h"
21#include "llvm/CodeGen/MachineOperand.h"
22#include "llvm/CodeGen/MachineRegisterInfo.h"
23#include "llvm/CodeGen/TargetLowering.h"
24#include "llvm/CodeGen/TargetOpcodes.h"
25#include "llvm/Support/Casting.h"
26#include <optional>
27
28#define DEBUG_TYPE "gi-combiner"
29
30using namespace llvm;
31using namespace MIPatternMatch;
32
33bool CombinerHelper::matchExtractVectorElement(MachineInstr &MI,
34 BuildFnTy &MatchInfo) {
35 GExtractVectorElement *Extract = cast<GExtractVectorElement>(Val: &MI);
36
37 Register Dst = Extract->getReg(Idx: 0);
38 Register Vector = Extract->getVectorReg();
39 Register Index = Extract->getIndexReg();
40 LLT DstTy = MRI.getType(Reg: Dst);
41 LLT VectorTy = MRI.getType(Reg: Vector);
42
43 // The vector register can be def'd by various ops that have vector as its
44 // type. They can all be used for constant folding, scalarizing,
45 // canonicalization, or combining based on symmetry.
46 //
47 // vector like ops
48 // * build vector
49 // * build vector trunc
50 // * shuffle vector
51 // * splat vector
52 // * concat vectors
53 // * insert/extract vector element
54 // * insert/extract subvector
55 // * vector loads
56 // * scalable vector loads
57 //
58 // compute like ops
59 // * binary ops
60 // * unary ops
61 // * exts and truncs
62 // * casts
63 // * fneg
64 // * select
65 // * phis
66 // * cmps
67 // * freeze
68 // * bitcast
69 // * undef
70
71 // We try to get the value of the Index register.
72 std::optional<ValueAndVReg> MaybeIndex =
73 getIConstantVRegValWithLookThrough(VReg: Index, MRI);
74 std::optional<APInt> IndexC = std::nullopt;
75
76 if (MaybeIndex)
77 IndexC = MaybeIndex->Value;
78
79 // Fold extractVectorElement(Vector, TOOLARGE) -> undef
80 if (IndexC && VectorTy.isFixedVector() &&
81 IndexC->uge(RHS: VectorTy.getNumElements()) &&
82 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_IMPLICIT_DEF, {DstTy}})) {
83 // For fixed-length vectors, it's invalid to extract out-of-range elements.
84 MatchInfo = [=](MachineIRBuilder &B) { B.buildUndef(Res: Dst); };
85 return true;
86 }
87
88 return false;
89}
90
91bool CombinerHelper::matchExtractVectorElementWithDifferentIndices(
92 const MachineOperand &MO, BuildFnTy &MatchInfo) {
93 MachineInstr *Root = getDefIgnoringCopies(Reg: MO.getReg(), MRI);
94 GExtractVectorElement *Extract = cast<GExtractVectorElement>(Val: Root);
95
96 //
97 // %idx1:_(s64) = G_CONSTANT i64 1
98 // %idx2:_(s64) = G_CONSTANT i64 2
99 // %insert:_(<2 x s32>) = G_INSERT_VECTOR_ELT_ELT %bv(<2 x s32>),
100 // %value(s32), %idx2(s64) %extract:_(s32) = G_EXTRACT_VECTOR_ELT %insert(<2
101 // x s32>), %idx1(s64)
102 //
103 // -->
104 //
105 // %insert:_(<2 x s32>) = G_INSERT_VECTOR_ELT_ELT %bv(<2 x s32>),
106 // %value(s32), %idx2(s64) %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x
107 // s32>), %idx1(s64)
108 //
109 //
110
111 Register Index = Extract->getIndexReg();
112
113 // We try to get the value of the Index register.
114 std::optional<ValueAndVReg> MaybeIndex =
115 getIConstantVRegValWithLookThrough(VReg: Index, MRI);
116 std::optional<APInt> IndexC = std::nullopt;
117
118 if (!MaybeIndex)
119 return false;
120 else
121 IndexC = MaybeIndex->Value;
122
123 Register Vector = Extract->getVectorReg();
124
125 GInsertVectorElement *Insert =
126 getOpcodeDef<GInsertVectorElement>(Reg: Vector, MRI);
127 if (!Insert)
128 return false;
129
130 Register Dst = Extract->getReg(Idx: 0);
131
132 std::optional<ValueAndVReg> MaybeInsertIndex =
133 getIConstantVRegValWithLookThrough(VReg: Insert->getIndexReg(), MRI);
134
135 if (MaybeInsertIndex && MaybeInsertIndex->Value != *IndexC) {
136 // There is no one-use check. We have to keep the insert. When both Index
137 // registers are constants and not equal, we can look into the Vector
138 // register of the insert.
139 MatchInfo = [=](MachineIRBuilder &B) {
140 B.buildExtractVectorElement(Res: Dst, Val: Insert->getVectorReg(), Idx: Index);
141 };
142 return true;
143 }
144
145 return false;
146}
147
148bool CombinerHelper::matchExtractVectorElementWithBuildVector(
149 const MachineOperand &MO, BuildFnTy &MatchInfo) {
150 MachineInstr *Root = getDefIgnoringCopies(Reg: MO.getReg(), MRI);
151 GExtractVectorElement *Extract = cast<GExtractVectorElement>(Val: Root);
152
153 //
154 // %zero:_(s64) = G_CONSTANT i64 0
155 // %bv:_(<2 x s32>) = G_BUILD_VECTOR %arg1(s32), %arg2(s32)
156 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %zero(s64)
157 //
158 // -->
159 //
160 // %extract:_(32) = COPY %arg1(s32)
161 //
162 //
163 //
164 // %bv:_(<2 x s32>) = G_BUILD_VECTOR %arg1(s32), %arg2(s32)
165 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %opaque(s64)
166 //
167 // -->
168 //
169 // %bv:_(<2 x s32>) = G_BUILD_VECTOR %arg1(s32), %arg2(s32)
170 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %opaque(s64)
171 //
172
173 Register Vector = Extract->getVectorReg();
174
175 // We expect a buildVector on the Vector register.
176 GBuildVector *Build = getOpcodeDef<GBuildVector>(Reg: Vector, MRI);
177 if (!Build)
178 return false;
179
180 LLT VectorTy = MRI.getType(Reg: Vector);
181
182 // There is a one-use check. There are more combines on build vectors.
183 EVT Ty(getMVTForLLT(Ty: VectorTy));
184 if (!MRI.hasOneNonDBGUse(RegNo: Build->getReg(Idx: 0)) ||
185 !getTargetLowering().aggressivelyPreferBuildVectorSources(VecVT: Ty))
186 return false;
187
188 Register Index = Extract->getIndexReg();
189
190 // If the Index is constant, then we can extract the element from the given
191 // offset.
192 std::optional<ValueAndVReg> MaybeIndex =
193 getIConstantVRegValWithLookThrough(VReg: Index, MRI);
194 if (!MaybeIndex)
195 return false;
196
197 // We now know that there is a buildVector def'd on the Vector register and
198 // the index is const. The combine will succeed.
199
200 Register Dst = Extract->getReg(Idx: 0);
201
202 MatchInfo = [=](MachineIRBuilder &B) {
203 B.buildCopy(Res: Dst, Op: Build->getSourceReg(I: MaybeIndex->Value.getZExtValue()));
204 };
205
206 return true;
207}
208
209bool CombinerHelper::matchExtractVectorElementWithBuildVectorTrunc(
210 const MachineOperand &MO, BuildFnTy &MatchInfo) {
211 MachineInstr *Root = getDefIgnoringCopies(Reg: MO.getReg(), MRI);
212 GExtractVectorElement *Extract = cast<GExtractVectorElement>(Val: Root);
213
214 //
215 // %zero:_(s64) = G_CONSTANT i64 0
216 // %bv:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %arg1(s64), %arg2(s64)
217 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %zero(s64)
218 //
219 // -->
220 //
221 // %extract:_(32) = G_TRUNC %arg1(s64)
222 //
223 //
224 //
225 // %bv:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %arg1(s64), %arg2(s64)
226 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %opaque(s64)
227 //
228 // -->
229 //
230 // %bv:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %arg1(s64), %arg2(s64)
231 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %opaque(s64)
232 //
233
234 Register Vector = Extract->getVectorReg();
235
236 // We expect a buildVectorTrunc on the Vector register.
237 GBuildVectorTrunc *Build = getOpcodeDef<GBuildVectorTrunc>(Reg: Vector, MRI);
238 if (!Build)
239 return false;
240
241 LLT VectorTy = MRI.getType(Reg: Vector);
242
243 // There is a one-use check. There are more combines on build vectors.
244 EVT Ty(getMVTForLLT(Ty: VectorTy));
245 if (!MRI.hasOneNonDBGUse(RegNo: Build->getReg(Idx: 0)) ||
246 !getTargetLowering().aggressivelyPreferBuildVectorSources(VecVT: Ty))
247 return false;
248
249 Register Index = Extract->getIndexReg();
250
251 // If the Index is constant, then we can extract the element from the given
252 // offset.
253 std::optional<ValueAndVReg> MaybeIndex =
254 getIConstantVRegValWithLookThrough(VReg: Index, MRI);
255 if (!MaybeIndex)
256 return false;
257
258 // We now know that there is a buildVectorTrunc def'd on the Vector register
259 // and the index is const. The combine will succeed.
260
261 Register Dst = Extract->getReg(Idx: 0);
262 LLT DstTy = MRI.getType(Reg: Dst);
263 LLT SrcTy = MRI.getType(Reg: Build->getSourceReg(I: 0));
264
265 // For buildVectorTrunc, the inputs are truncated.
266 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_TRUNC, {DstTy, SrcTy}}))
267 return false;
268
269 MatchInfo = [=](MachineIRBuilder &B) {
270 B.buildTrunc(Res: Dst, Op: Build->getSourceReg(I: MaybeIndex->Value.getZExtValue()));
271 };
272
273 return true;
274}
275
276bool CombinerHelper::matchExtractVectorElementWithShuffleVector(
277 const MachineOperand &MO, BuildFnTy &MatchInfo) {
278 GExtractVectorElement *Extract =
279 cast<GExtractVectorElement>(Val: getDefIgnoringCopies(Reg: MO.getReg(), MRI));
280
281 //
282 // %zero:_(s64) = G_CONSTANT i64 0
283 // %sv:_(<4 x s32>) = G_SHUFFLE_SHUFFLE %arg1(<4 x s32>), %arg2(<4 x s32>),
284 // shufflemask(0, 0, 0, 0)
285 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %sv(<4 x s32>), %zero(s64)
286 //
287 // -->
288 //
289 // %zero1:_(s64) = G_CONSTANT i64 0
290 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %arg1(<4 x s32>), %zero1(s64)
291 //
292 //
293 //
294 //
295 // %three:_(s64) = G_CONSTANT i64 3
296 // %sv:_(<4 x s32>) = G_SHUFFLE_SHUFFLE %arg1(<4 x s32>), %arg2(<4 x s32>),
297 // shufflemask(0, 0, 0, -1)
298 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %sv(<4 x s32>), %three(s64)
299 //
300 // -->
301 //
302 // %extract:_(s32) = G_IMPLICIT_DEF
303 //
304 //
305 //
306 //
307 //
308 // %sv:_(<4 x s32>) = G_SHUFFLE_SHUFFLE %arg1(<4 x s32>), %arg2(<4 x s32>),
309 // shufflemask(0, 0, 0, -1)
310 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %sv(<4 x s32>), %opaque(s64)
311 //
312 // -->
313 //
314 // %sv:_(<4 x s32>) = G_SHUFFLE_SHUFFLE %arg1(<4 x s32>), %arg2(<4 x s32>),
315 // shufflemask(0, 0, 0, -1)
316 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %sv(<4 x s32>), %opaque(s64)
317 //
318
319 // We try to get the value of the Index register.
320 std::optional<ValueAndVReg> MaybeIndex =
321 getIConstantVRegValWithLookThrough(VReg: Extract->getIndexReg(), MRI);
322 if (!MaybeIndex)
323 return false;
324
325 GShuffleVector *Shuffle =
326 cast<GShuffleVector>(Val: getDefIgnoringCopies(Reg: Extract->getVectorReg(), MRI));
327
328 ArrayRef<int> Mask = Shuffle->getMask();
329
330 unsigned Offset = MaybeIndex->Value.getZExtValue();
331 int SrcIdx = Mask[Offset];
332
333 LLT Src1Type = MRI.getType(Reg: Shuffle->getSrc1Reg());
334 // At the IR level a <1 x ty> shuffle vector is valid, but we want to extract
335 // from a vector.
336 assert(Src1Type.isVector() && "expected to extract from a vector");
337 unsigned LHSWidth = Src1Type.isVector() ? Src1Type.getNumElements() : 1;
338
339 // Note that there is no one use check.
340 Register Dst = Extract->getReg(Idx: 0);
341 LLT DstTy = MRI.getType(Reg: Dst);
342
343 if (SrcIdx < 0 &&
344 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_IMPLICIT_DEF, {DstTy}})) {
345 MatchInfo = [=](MachineIRBuilder &B) { B.buildUndef(Res: Dst); };
346 return true;
347 }
348
349 // If the legality check failed, then we still have to abort.
350 if (SrcIdx < 0)
351 return false;
352
353 Register NewVector;
354
355 // We check in which vector and at what offset to look through.
356 if (SrcIdx < (int)LHSWidth) {
357 NewVector = Shuffle->getSrc1Reg();
358 // SrcIdx unchanged
359 } else { // SrcIdx >= LHSWidth
360 NewVector = Shuffle->getSrc2Reg();
361 SrcIdx -= LHSWidth;
362 }
363
364 LLT IdxTy = MRI.getType(Reg: Extract->getIndexReg());
365 LLT NewVectorTy = MRI.getType(Reg: NewVector);
366
367 // We check the legality of the look through.
368 if (!isLegalOrBeforeLegalizer(
369 Query: {TargetOpcode::G_EXTRACT_VECTOR_ELT, {DstTy, NewVectorTy, IdxTy}}) ||
370 !isConstantLegalOrBeforeLegalizer(Ty: {IdxTy}))
371 return false;
372
373 // We look through the shuffle vector.
374 MatchInfo = [=](MachineIRBuilder &B) {
375 auto Idx = B.buildConstant(Res: IdxTy, Val: SrcIdx);
376 B.buildExtractVectorElement(Res: Dst, Val: NewVector, Idx);
377 };
378
379 return true;
380}
381
382bool CombinerHelper::matchInsertVectorElementOOB(MachineInstr &MI,
383 BuildFnTy &MatchInfo) {
384 GInsertVectorElement *Insert = cast<GInsertVectorElement>(Val: &MI);
385
386 Register Dst = Insert->getReg(Idx: 0);
387 LLT DstTy = MRI.getType(Reg: Dst);
388 Register Index = Insert->getIndexReg();
389
390 if (!DstTy.isFixedVector())
391 return false;
392
393 std::optional<ValueAndVReg> MaybeIndex =
394 getIConstantVRegValWithLookThrough(VReg: Index, MRI);
395
396 if (MaybeIndex && MaybeIndex->Value.uge(RHS: DstTy.getNumElements()) &&
397 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_IMPLICIT_DEF, {DstTy}})) {
398 MatchInfo = [=](MachineIRBuilder &B) { B.buildUndef(Res: Dst); };
399 return true;
400 }
401
402 return false;
403}
404
405bool CombinerHelper::matchAddOfVScale(const MachineOperand &MO,
406 BuildFnTy &MatchInfo) {
407 GAdd *Add = cast<GAdd>(Val: MRI.getVRegDef(Reg: MO.getReg()));
408 GVScale *LHSVScale = cast<GVScale>(Val: MRI.getVRegDef(Reg: Add->getLHSReg()));
409 GVScale *RHSVScale = cast<GVScale>(Val: MRI.getVRegDef(Reg: Add->getRHSReg()));
410
411 Register Dst = Add->getReg(Idx: 0);
412
413 if (!MRI.hasOneNonDBGUse(RegNo: LHSVScale->getReg(Idx: 0)) ||
414 !MRI.hasOneNonDBGUse(RegNo: RHSVScale->getReg(Idx: 0)))
415 return false;
416
417 MatchInfo = [=](MachineIRBuilder &B) {
418 B.buildVScale(Res: Dst, MinElts: LHSVScale->getSrc() + RHSVScale->getSrc());
419 };
420
421 return true;
422}
423
424bool CombinerHelper::matchMulOfVScale(const MachineOperand &MO,
425 BuildFnTy &MatchInfo) {
426 GMul *Mul = cast<GMul>(Val: MRI.getVRegDef(Reg: MO.getReg()));
427 GVScale *LHSVScale = cast<GVScale>(Val: MRI.getVRegDef(Reg: Mul->getLHSReg()));
428
429 std::optional<APInt> MaybeRHS = getIConstantVRegVal(VReg: Mul->getRHSReg(), MRI);
430 if (!MaybeRHS)
431 return false;
432
433 Register Dst = MO.getReg();
434
435 if (!MRI.hasOneNonDBGUse(RegNo: LHSVScale->getReg(Idx: 0)))
436 return false;
437
438 MatchInfo = [=](MachineIRBuilder &B) {
439 B.buildVScale(Res: Dst, MinElts: LHSVScale->getSrc() * *MaybeRHS);
440 };
441
442 return true;
443}
444
445bool CombinerHelper::matchSubOfVScale(const MachineOperand &MO,
446 BuildFnTy &MatchInfo) {
447 GSub *Sub = cast<GSub>(Val: MRI.getVRegDef(Reg: MO.getReg()));
448 GVScale *RHSVScale = cast<GVScale>(Val: MRI.getVRegDef(Reg: Sub->getRHSReg()));
449
450 Register Dst = MO.getReg();
451 LLT DstTy = MRI.getType(Reg: Dst);
452
453 if (!MRI.hasOneNonDBGUse(RegNo: RHSVScale->getReg(Idx: 0)) ||
454 !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ADD, DstTy}))
455 return false;
456
457 MatchInfo = [=](MachineIRBuilder &B) {
458 auto VScale = B.buildVScale(Res: DstTy, MinElts: -RHSVScale->getSrc());
459 B.buildAdd(Dst, Src0: Sub->getLHSReg(), Src1: VScale, Flags: Sub->getFlags());
460 };
461
462 return true;
463}
464
465bool CombinerHelper::matchShlOfVScale(const MachineOperand &MO,
466 BuildFnTy &MatchInfo) {
467 GShl *Shl = cast<GShl>(Val: MRI.getVRegDef(Reg: MO.getReg()));
468 GVScale *LHSVScale = cast<GVScale>(Val: MRI.getVRegDef(Reg: Shl->getSrcReg()));
469
470 std::optional<APInt> MaybeRHS = getIConstantVRegVal(VReg: Shl->getShiftReg(), MRI);
471 if (!MaybeRHS)
472 return false;
473
474 Register Dst = MO.getReg();
475 LLT DstTy = MRI.getType(Reg: Dst);
476
477 if (!MRI.hasOneNonDBGUse(RegNo: LHSVScale->getReg(Idx: 0)) ||
478 !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_VSCALE, DstTy}))
479 return false;
480
481 MatchInfo = [=](MachineIRBuilder &B) {
482 B.buildVScale(Res: Dst, MinElts: LHSVScale->getSrc().shl(ShiftAmt: *MaybeRHS));
483 };
484
485 return true;
486}
487