1//===- CombinerHelperVectorOps.cpp-----------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements CombinerHelper for G_EXTRACT_VECTOR_ELT,
10// G_INSERT_VECTOR_ELT, and G_VSCALE
11//
12//===----------------------------------------------------------------------===//
13#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
14#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
15#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
16#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
17#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
18#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
19#include "llvm/CodeGen/GlobalISel/Utils.h"
20#include "llvm/CodeGen/LowLevelTypeUtils.h"
21#include "llvm/CodeGen/MachineOperand.h"
22#include "llvm/CodeGen/MachineRegisterInfo.h"
23#include "llvm/CodeGen/TargetLowering.h"
24#include "llvm/CodeGen/TargetOpcodes.h"
25#include "llvm/Support/Casting.h"
26#include <optional>
27
28#define DEBUG_TYPE "gi-combiner"
29
30using namespace llvm;
31using namespace MIPatternMatch;
32
33bool CombinerHelper::matchExtractVectorElement(MachineInstr &MI,
34 BuildFnTy &MatchInfo) const {
35 GExtractVectorElement *Extract = cast<GExtractVectorElement>(Val: &MI);
36
37 Register Dst = Extract->getReg(Idx: 0);
38 Register Vector = Extract->getVectorReg();
39 Register Index = Extract->getIndexReg();
40 LLT DstTy = MRI.getType(Reg: Dst);
41 LLT VectorTy = MRI.getType(Reg: Vector);
42
43 // The vector register can be def'd by various ops that have vector as its
44 // type. They can all be used for constant folding, scalarizing,
45 // canonicalization, or combining based on symmetry.
46 //
47 // vector like ops
48 // * build vector
49 // * build vector trunc
50 // * shuffle vector
51 // * splat vector
52 // * concat vectors
53 // * insert/extract vector element
54 // * insert/extract subvector
55 // * vector loads
56 // * scalable vector loads
57 //
58 // compute like ops
59 // * binary ops
60 // * unary ops
61 // * exts and truncs
62 // * casts
63 // * fneg
64 // * select
65 // * phis
66 // * cmps
67 // * freeze
68 // * bitcast
69 // * undef
70
71 // We try to get the value of the Index register.
72 std::optional<ValueAndVReg> MaybeIndex =
73 getIConstantVRegValWithLookThrough(VReg: Index, MRI);
74 std::optional<APInt> IndexC = std::nullopt;
75
76 if (MaybeIndex)
77 IndexC = MaybeIndex->Value;
78
79 // Fold extractVectorElement(Vector, TOOLARGE) -> undef
80 if (IndexC && VectorTy.isFixedVector() &&
81 IndexC->uge(RHS: VectorTy.getNumElements()) &&
82 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_IMPLICIT_DEF, {DstTy}})) {
83 // For fixed-length vectors, it's invalid to extract out-of-range elements.
84 MatchInfo = [=](MachineIRBuilder &B) { B.buildUndef(Res: Dst); };
85 return true;
86 }
87
88 return false;
89}
90
91bool CombinerHelper::matchExtractVectorElementWithDifferentIndices(
92 const MachineOperand &MO, BuildFnTy &MatchInfo) const {
93 MachineInstr *Root = getDefIgnoringCopies(Reg: MO.getReg(), MRI);
94 GExtractVectorElement *Extract = cast<GExtractVectorElement>(Val: Root);
95
96 //
97 // %idx1:_(s64) = G_CONSTANT i64 1
98 // %idx2:_(s64) = G_CONSTANT i64 2
99 // %insert:_(<2 x s32>) = G_INSERT_VECTOR_ELT_ELT %bv(<2 x s32>),
100 // %value(s32), %idx2(s64) %extract:_(s32) = G_EXTRACT_VECTOR_ELT %insert(<2
101 // x s32>), %idx1(s64)
102 //
103 // -->
104 //
105 // %insert:_(<2 x s32>) = G_INSERT_VECTOR_ELT_ELT %bv(<2 x s32>),
106 // %value(s32), %idx2(s64) %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x
107 // s32>), %idx1(s64)
108 //
109 //
110
111 Register Index = Extract->getIndexReg();
112
113 // We try to get the value of the Index register.
114 std::optional<ValueAndVReg> MaybeIndex =
115 getIConstantVRegValWithLookThrough(VReg: Index, MRI);
116 std::optional<APInt> IndexC = std::nullopt;
117
118 if (!MaybeIndex)
119 return false;
120 else
121 IndexC = MaybeIndex->Value;
122
123 Register Vector = Extract->getVectorReg();
124
125 GInsertVectorElement *Insert =
126 getOpcodeDef<GInsertVectorElement>(Reg: Vector, MRI);
127 if (!Insert)
128 return false;
129
130 Register Dst = Extract->getReg(Idx: 0);
131
132 std::optional<ValueAndVReg> MaybeInsertIndex =
133 getIConstantVRegValWithLookThrough(VReg: Insert->getIndexReg(), MRI);
134
135 if (MaybeInsertIndex && MaybeInsertIndex->Value != *IndexC) {
136 // There is no one-use check. We have to keep the insert. When both Index
137 // registers are constants and not equal, we can look into the Vector
138 // register of the insert.
139 MatchInfo = [=](MachineIRBuilder &B) {
140 B.buildExtractVectorElement(Res: Dst, Val: Insert->getVectorReg(), Idx: Index);
141 };
142 return true;
143 }
144
145 return false;
146}
147
148bool CombinerHelper::matchExtractVectorElementWithBuildVector(
149 const MachineInstr &MI, const MachineInstr &MI2,
150 BuildFnTy &MatchInfo) const {
151 const GExtractVectorElement *Extract = cast<GExtractVectorElement>(Val: &MI);
152 const GBuildVector *Build = cast<GBuildVector>(Val: &MI2);
153
154 //
155 // %zero:_(s64) = G_CONSTANT i64 0
156 // %bv:_(<2 x s32>) = G_BUILD_VECTOR %arg1(s32), %arg2(s32)
157 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %zero(s64)
158 //
159 // -->
160 //
161 // %extract:_(32) = COPY %arg1(s32)
162 //
163 //
164
165 Register Vector = Extract->getVectorReg();
166 LLT VectorTy = MRI.getType(Reg: Vector);
167
168 // There is a one-use check. There are more combines on build vectors.
169 EVT Ty(getMVTForLLT(Ty: VectorTy));
170 if (!MRI.hasOneNonDBGUse(RegNo: Build->getReg(Idx: 0)) ||
171 !getTargetLowering().aggressivelyPreferBuildVectorSources(VecVT: Ty))
172 return false;
173
174 APInt Index = getIConstantFromReg(VReg: Extract->getIndexReg(), MRI);
175
176 // We now know that there is a buildVector def'd on the Vector register and
177 // the index is const. The combine will succeed.
178
179 Register Dst = Extract->getReg(Idx: 0);
180
181 MatchInfo = [=](MachineIRBuilder &B) {
182 B.buildCopy(Res: Dst, Op: Build->getSourceReg(I: Index.getZExtValue()));
183 };
184
185 return true;
186}
187
188bool CombinerHelper::matchExtractVectorElementWithBuildVectorTrunc(
189 const MachineOperand &MO, BuildFnTy &MatchInfo) const {
190 MachineInstr *Root = getDefIgnoringCopies(Reg: MO.getReg(), MRI);
191 GExtractVectorElement *Extract = cast<GExtractVectorElement>(Val: Root);
192
193 //
194 // %zero:_(s64) = G_CONSTANT i64 0
195 // %bv:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %arg1(s64), %arg2(s64)
196 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %zero(s64)
197 //
198 // -->
199 //
200 // %extract:_(32) = G_TRUNC %arg1(s64)
201 //
202 //
203 //
204 // %bv:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %arg1(s64), %arg2(s64)
205 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %opaque(s64)
206 //
207 // -->
208 //
209 // %bv:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %arg1(s64), %arg2(s64)
210 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %opaque(s64)
211 //
212
213 Register Vector = Extract->getVectorReg();
214
215 // We expect a buildVectorTrunc on the Vector register.
216 GBuildVectorTrunc *Build = getOpcodeDef<GBuildVectorTrunc>(Reg: Vector, MRI);
217 if (!Build)
218 return false;
219
220 LLT VectorTy = MRI.getType(Reg: Vector);
221
222 // There is a one-use check. There are more combines on build vectors.
223 EVT Ty(getMVTForLLT(Ty: VectorTy));
224 if (!MRI.hasOneNonDBGUse(RegNo: Build->getReg(Idx: 0)) ||
225 !getTargetLowering().aggressivelyPreferBuildVectorSources(VecVT: Ty))
226 return false;
227
228 Register Index = Extract->getIndexReg();
229
230 // If the Index is constant, then we can extract the element from the given
231 // offset.
232 std::optional<ValueAndVReg> MaybeIndex =
233 getIConstantVRegValWithLookThrough(VReg: Index, MRI);
234 if (!MaybeIndex)
235 return false;
236
237 // We now know that there is a buildVectorTrunc def'd on the Vector register
238 // and the index is const. The combine will succeed.
239
240 Register Dst = Extract->getReg(Idx: 0);
241 LLT DstTy = MRI.getType(Reg: Dst);
242 LLT SrcTy = MRI.getType(Reg: Build->getSourceReg(I: 0));
243
244 // For buildVectorTrunc, the inputs are truncated.
245 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_TRUNC, {DstTy, SrcTy}}))
246 return false;
247
248 MatchInfo = [=](MachineIRBuilder &B) {
249 B.buildTrunc(Res: Dst, Op: Build->getSourceReg(I: MaybeIndex->Value.getZExtValue()));
250 };
251
252 return true;
253}
254
255bool CombinerHelper::matchExtractVectorElementWithShuffleVector(
256 const MachineInstr &MI, const MachineInstr &MI2,
257 BuildFnTy &MatchInfo) const {
258 const GExtractVectorElement *Extract = cast<GExtractVectorElement>(Val: &MI);
259 const GShuffleVector *Shuffle = cast<GShuffleVector>(Val: &MI2);
260
261 //
262 // %zero:_(s64) = G_CONSTANT i64 0
263 // %sv:_(<4 x s32>) = G_SHUFFLE_SHUFFLE %arg1(<4 x s32>), %arg2(<4 x s32>),
264 // shufflemask(0, 0, 0, 0)
265 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %sv(<4 x s32>), %zero(s64)
266 //
267 // -->
268 //
269 // %zero1:_(s64) = G_CONSTANT i64 0
270 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %arg1(<4 x s32>), %zero1(s64)
271 //
272 //
273 //
274 //
275 // %three:_(s64) = G_CONSTANT i64 3
276 // %sv:_(<4 x s32>) = G_SHUFFLE_SHUFFLE %arg1(<4 x s32>), %arg2(<4 x s32>),
277 // shufflemask(0, 0, 0, -1)
278 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %sv(<4 x s32>), %three(s64)
279 //
280 // -->
281 //
282 // %extract:_(s32) = G_IMPLICIT_DEF
283 //
284 //
285
286 APInt Index = getIConstantFromReg(VReg: Extract->getIndexReg(), MRI);
287
288 ArrayRef<int> Mask = Shuffle->getMask();
289
290 unsigned Offset = Index.getZExtValue();
291 int SrcIdx = Mask[Offset];
292
293 LLT Src1Type = MRI.getType(Reg: Shuffle->getSrc1Reg());
294 // At the IR level a <1 x ty> shuffle vector is valid, but we want to extract
295 // from a vector.
296 assert(Src1Type.isVector() && "expected to extract from a vector");
297 unsigned LHSWidth = Src1Type.isVector() ? Src1Type.getNumElements() : 1;
298
299 // Note that there is no one use check.
300 Register Dst = Extract->getReg(Idx: 0);
301 LLT DstTy = MRI.getType(Reg: Dst);
302
303 if (SrcIdx < 0 &&
304 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_IMPLICIT_DEF, {DstTy}})) {
305 MatchInfo = [=](MachineIRBuilder &B) { B.buildUndef(Res: Dst); };
306 return true;
307 }
308
309 // If the legality check failed, then we still have to abort.
310 if (SrcIdx < 0)
311 return false;
312
313 Register NewVector;
314
315 // We check in which vector and at what offset to look through.
316 if (SrcIdx < (int)LHSWidth) {
317 NewVector = Shuffle->getSrc1Reg();
318 // SrcIdx unchanged
319 } else { // SrcIdx >= LHSWidth
320 NewVector = Shuffle->getSrc2Reg();
321 SrcIdx -= LHSWidth;
322 }
323
324 LLT IdxTy = MRI.getType(Reg: Extract->getIndexReg());
325 LLT NewVectorTy = MRI.getType(Reg: NewVector);
326
327 // We check the legality of the look through.
328 if (!isLegalOrBeforeLegalizer(
329 Query: {TargetOpcode::G_EXTRACT_VECTOR_ELT, {DstTy, NewVectorTy, IdxTy}}) ||
330 !isConstantLegalOrBeforeLegalizer(Ty: {IdxTy}))
331 return false;
332
333 // We look through the shuffle vector.
334 MatchInfo = [=](MachineIRBuilder &B) {
335 auto Idx = B.buildConstant(Res: IdxTy, Val: SrcIdx);
336 B.buildExtractVectorElement(Res: Dst, Val: NewVector, Idx);
337 };
338
339 return true;
340}
341
342bool CombinerHelper::matchInsertVectorElementOOB(MachineInstr &MI,
343 BuildFnTy &MatchInfo) const {
344 GInsertVectorElement *Insert = cast<GInsertVectorElement>(Val: &MI);
345
346 Register Dst = Insert->getReg(Idx: 0);
347 LLT DstTy = MRI.getType(Reg: Dst);
348 Register Index = Insert->getIndexReg();
349
350 if (!DstTy.isFixedVector())
351 return false;
352
353 std::optional<ValueAndVReg> MaybeIndex =
354 getIConstantVRegValWithLookThrough(VReg: Index, MRI);
355
356 if (MaybeIndex && MaybeIndex->Value.uge(RHS: DstTy.getNumElements()) &&
357 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_IMPLICIT_DEF, {DstTy}})) {
358 MatchInfo = [=](MachineIRBuilder &B) { B.buildUndef(Res: Dst); };
359 return true;
360 }
361
362 return false;
363}
364
365bool CombinerHelper::matchAddOfVScale(const MachineOperand &MO,
366 BuildFnTy &MatchInfo) const {
367 GAdd *Add = cast<GAdd>(Val: MRI.getVRegDef(Reg: MO.getReg()));
368 GVScale *LHSVScale = cast<GVScale>(Val: MRI.getVRegDef(Reg: Add->getLHSReg()));
369 GVScale *RHSVScale = cast<GVScale>(Val: MRI.getVRegDef(Reg: Add->getRHSReg()));
370
371 Register Dst = Add->getReg(Idx: 0);
372
373 if (!MRI.hasOneNonDBGUse(RegNo: LHSVScale->getReg(Idx: 0)) ||
374 !MRI.hasOneNonDBGUse(RegNo: RHSVScale->getReg(Idx: 0)))
375 return false;
376
377 MatchInfo = [=](MachineIRBuilder &B) {
378 B.buildVScale(Res: Dst, MinElts: LHSVScale->getSrc() + RHSVScale->getSrc());
379 };
380
381 return true;
382}
383
384bool CombinerHelper::matchMulOfVScale(const MachineOperand &MO,
385 BuildFnTy &MatchInfo) const {
386 GMul *Mul = cast<GMul>(Val: MRI.getVRegDef(Reg: MO.getReg()));
387 GVScale *LHSVScale = cast<GVScale>(Val: MRI.getVRegDef(Reg: Mul->getLHSReg()));
388
389 std::optional<APInt> MaybeRHS = getIConstantVRegVal(VReg: Mul->getRHSReg(), MRI);
390 if (!MaybeRHS)
391 return false;
392
393 Register Dst = MO.getReg();
394
395 if (!MRI.hasOneNonDBGUse(RegNo: LHSVScale->getReg(Idx: 0)))
396 return false;
397
398 MatchInfo = [=](MachineIRBuilder &B) {
399 B.buildVScale(Res: Dst, MinElts: LHSVScale->getSrc() * *MaybeRHS);
400 };
401
402 return true;
403}
404
405bool CombinerHelper::matchSubOfVScale(const MachineOperand &MO,
406 BuildFnTy &MatchInfo) const {
407 GSub *Sub = cast<GSub>(Val: MRI.getVRegDef(Reg: MO.getReg()));
408 GVScale *RHSVScale = cast<GVScale>(Val: MRI.getVRegDef(Reg: Sub->getRHSReg()));
409
410 Register Dst = MO.getReg();
411 LLT DstTy = MRI.getType(Reg: Dst);
412
413 if (!MRI.hasOneNonDBGUse(RegNo: RHSVScale->getReg(Idx: 0)) ||
414 !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ADD, DstTy}))
415 return false;
416
417 MatchInfo = [=](MachineIRBuilder &B) {
418 auto VScale = B.buildVScale(Res: DstTy, MinElts: -RHSVScale->getSrc());
419 B.buildAdd(Dst, Src0: Sub->getLHSReg(), Src1: VScale, Flags: Sub->getFlags());
420 };
421
422 return true;
423}
424
425bool CombinerHelper::matchShlOfVScale(const MachineOperand &MO,
426 BuildFnTy &MatchInfo) const {
427 GShl *Shl = cast<GShl>(Val: MRI.getVRegDef(Reg: MO.getReg()));
428 GVScale *LHSVScale = cast<GVScale>(Val: MRI.getVRegDef(Reg: Shl->getSrcReg()));
429
430 std::optional<APInt> MaybeRHS = getIConstantVRegVal(VReg: Shl->getShiftReg(), MRI);
431 if (!MaybeRHS)
432 return false;
433
434 Register Dst = MO.getReg();
435 LLT DstTy = MRI.getType(Reg: Dst);
436
437 if (!MRI.hasOneNonDBGUse(RegNo: LHSVScale->getReg(Idx: 0)) ||
438 !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_VSCALE, DstTy}))
439 return false;
440
441 MatchInfo = [=](MachineIRBuilder &B) {
442 B.buildVScale(Res: Dst, MinElts: LHSVScale->getSrc().shl(ShiftAmt: *MaybeRHS));
443 };
444
445 return true;
446}
447