1 | //===- CombinerHelperVectorOps.cpp-----------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements CombinerHelper for G_EXTRACT_VECTOR_ELT, |
10 | // G_INSERT_VECTOR_ELT, and G_VSCALE |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | #include "llvm/CodeGen/GlobalISel/CombinerHelper.h" |
14 | #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h" |
15 | #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" |
16 | #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h" |
17 | #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h" |
18 | #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" |
19 | #include "llvm/CodeGen/GlobalISel/Utils.h" |
20 | #include "llvm/CodeGen/LowLevelTypeUtils.h" |
21 | #include "llvm/CodeGen/MachineOperand.h" |
22 | #include "llvm/CodeGen/MachineRegisterInfo.h" |
23 | #include "llvm/CodeGen/TargetLowering.h" |
24 | #include "llvm/CodeGen/TargetOpcodes.h" |
25 | #include "llvm/Support/Casting.h" |
26 | #include <optional> |
27 | |
28 | #define DEBUG_TYPE "gi-combiner" |
29 | |
30 | using namespace llvm; |
31 | using namespace MIPatternMatch; |
32 | |
33 | bool CombinerHelper::(MachineInstr &MI, |
34 | BuildFnTy &MatchInfo) { |
35 | GExtractVectorElement * = cast<GExtractVectorElement>(Val: &MI); |
36 | |
37 | Register Dst = Extract->getReg(Idx: 0); |
38 | Register Vector = Extract->getVectorReg(); |
39 | Register Index = Extract->getIndexReg(); |
40 | LLT DstTy = MRI.getType(Reg: Dst); |
41 | LLT VectorTy = MRI.getType(Reg: Vector); |
42 | |
43 | // The vector register can be def'd by various ops that have vector as its |
44 | // type. They can all be used for constant folding, scalarizing, |
45 | // canonicalization, or combining based on symmetry. |
46 | // |
47 | // vector like ops |
48 | // * build vector |
49 | // * build vector trunc |
50 | // * shuffle vector |
51 | // * splat vector |
52 | // * concat vectors |
53 | // * insert/extract vector element |
54 | // * insert/extract subvector |
55 | // * vector loads |
56 | // * scalable vector loads |
57 | // |
58 | // compute like ops |
59 | // * binary ops |
60 | // * unary ops |
61 | // * exts and truncs |
62 | // * casts |
63 | // * fneg |
64 | // * select |
65 | // * phis |
66 | // * cmps |
67 | // * freeze |
68 | // * bitcast |
69 | // * undef |
70 | |
71 | // We try to get the value of the Index register. |
72 | std::optional<ValueAndVReg> MaybeIndex = |
73 | getIConstantVRegValWithLookThrough(VReg: Index, MRI); |
74 | std::optional<APInt> IndexC = std::nullopt; |
75 | |
76 | if (MaybeIndex) |
77 | IndexC = MaybeIndex->Value; |
78 | |
79 | // Fold extractVectorElement(Vector, TOOLARGE) -> undef |
80 | if (IndexC && VectorTy.isFixedVector() && |
81 | IndexC->uge(RHS: VectorTy.getNumElements()) && |
82 | isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_IMPLICIT_DEF, {DstTy}})) { |
83 | // For fixed-length vectors, it's invalid to extract out-of-range elements. |
84 | MatchInfo = [=](MachineIRBuilder &B) { B.buildUndef(Res: Dst); }; |
85 | return true; |
86 | } |
87 | |
88 | return false; |
89 | } |
90 | |
91 | bool CombinerHelper::matchExtractVectorElementWithDifferentIndices( |
92 | const MachineOperand &MO, BuildFnTy &MatchInfo) { |
93 | MachineInstr *Root = getDefIgnoringCopies(Reg: MO.getReg(), MRI); |
94 | GExtractVectorElement * = cast<GExtractVectorElement>(Val: Root); |
95 | |
96 | // |
97 | // %idx1:_(s64) = G_CONSTANT i64 1 |
98 | // %idx2:_(s64) = G_CONSTANT i64 2 |
99 | // %insert:_(<2 x s32>) = G_INSERT_VECTOR_ELT_ELT %bv(<2 x s32>), |
100 | // %value(s32), %idx2(s64) %extract:_(s32) = G_EXTRACT_VECTOR_ELT %insert(<2 |
101 | // x s32>), %idx1(s64) |
102 | // |
103 | // --> |
104 | // |
105 | // %insert:_(<2 x s32>) = G_INSERT_VECTOR_ELT_ELT %bv(<2 x s32>), |
106 | // %value(s32), %idx2(s64) %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x |
107 | // s32>), %idx1(s64) |
108 | // |
109 | // |
110 | |
111 | Register Index = Extract->getIndexReg(); |
112 | |
113 | // We try to get the value of the Index register. |
114 | std::optional<ValueAndVReg> MaybeIndex = |
115 | getIConstantVRegValWithLookThrough(VReg: Index, MRI); |
116 | std::optional<APInt> IndexC = std::nullopt; |
117 | |
118 | if (!MaybeIndex) |
119 | return false; |
120 | else |
121 | IndexC = MaybeIndex->Value; |
122 | |
123 | Register Vector = Extract->getVectorReg(); |
124 | |
125 | GInsertVectorElement *Insert = |
126 | getOpcodeDef<GInsertVectorElement>(Reg: Vector, MRI); |
127 | if (!Insert) |
128 | return false; |
129 | |
130 | Register Dst = Extract->getReg(Idx: 0); |
131 | |
132 | std::optional<ValueAndVReg> MaybeInsertIndex = |
133 | getIConstantVRegValWithLookThrough(VReg: Insert->getIndexReg(), MRI); |
134 | |
135 | if (MaybeInsertIndex && MaybeInsertIndex->Value != *IndexC) { |
136 | // There is no one-use check. We have to keep the insert. When both Index |
137 | // registers are constants and not equal, we can look into the Vector |
138 | // register of the insert. |
139 | MatchInfo = [=](MachineIRBuilder &B) { |
140 | B.buildExtractVectorElement(Res: Dst, Val: Insert->getVectorReg(), Idx: Index); |
141 | }; |
142 | return true; |
143 | } |
144 | |
145 | return false; |
146 | } |
147 | |
148 | bool CombinerHelper::matchExtractVectorElementWithBuildVector( |
149 | const MachineOperand &MO, BuildFnTy &MatchInfo) { |
150 | MachineInstr *Root = getDefIgnoringCopies(Reg: MO.getReg(), MRI); |
151 | GExtractVectorElement * = cast<GExtractVectorElement>(Val: Root); |
152 | |
153 | // |
154 | // %zero:_(s64) = G_CONSTANT i64 0 |
155 | // %bv:_(<2 x s32>) = G_BUILD_VECTOR %arg1(s32), %arg2(s32) |
156 | // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %zero(s64) |
157 | // |
158 | // --> |
159 | // |
160 | // %extract:_(32) = COPY %arg1(s32) |
161 | // |
162 | // |
163 | // |
164 | // %bv:_(<2 x s32>) = G_BUILD_VECTOR %arg1(s32), %arg2(s32) |
165 | // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %opaque(s64) |
166 | // |
167 | // --> |
168 | // |
169 | // %bv:_(<2 x s32>) = G_BUILD_VECTOR %arg1(s32), %arg2(s32) |
170 | // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %opaque(s64) |
171 | // |
172 | |
173 | Register Vector = Extract->getVectorReg(); |
174 | |
175 | // We expect a buildVector on the Vector register. |
176 | GBuildVector *Build = getOpcodeDef<GBuildVector>(Reg: Vector, MRI); |
177 | if (!Build) |
178 | return false; |
179 | |
180 | LLT VectorTy = MRI.getType(Reg: Vector); |
181 | |
182 | // There is a one-use check. There are more combines on build vectors. |
183 | EVT Ty(getMVTForLLT(Ty: VectorTy)); |
184 | if (!MRI.hasOneNonDBGUse(RegNo: Build->getReg(Idx: 0)) || |
185 | !getTargetLowering().aggressivelyPreferBuildVectorSources(VecVT: Ty)) |
186 | return false; |
187 | |
188 | Register Index = Extract->getIndexReg(); |
189 | |
190 | // If the Index is constant, then we can extract the element from the given |
191 | // offset. |
192 | std::optional<ValueAndVReg> MaybeIndex = |
193 | getIConstantVRegValWithLookThrough(VReg: Index, MRI); |
194 | if (!MaybeIndex) |
195 | return false; |
196 | |
197 | // We now know that there is a buildVector def'd on the Vector register and |
198 | // the index is const. The combine will succeed. |
199 | |
200 | Register Dst = Extract->getReg(Idx: 0); |
201 | |
202 | MatchInfo = [=](MachineIRBuilder &B) { |
203 | B.buildCopy(Res: Dst, Op: Build->getSourceReg(I: MaybeIndex->Value.getZExtValue())); |
204 | }; |
205 | |
206 | return true; |
207 | } |
208 | |
209 | bool CombinerHelper::matchExtractVectorElementWithBuildVectorTrunc( |
210 | const MachineOperand &MO, BuildFnTy &MatchInfo) { |
211 | MachineInstr *Root = getDefIgnoringCopies(Reg: MO.getReg(), MRI); |
212 | GExtractVectorElement * = cast<GExtractVectorElement>(Val: Root); |
213 | |
214 | // |
215 | // %zero:_(s64) = G_CONSTANT i64 0 |
216 | // %bv:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %arg1(s64), %arg2(s64) |
217 | // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %zero(s64) |
218 | // |
219 | // --> |
220 | // |
221 | // %extract:_(32) = G_TRUNC %arg1(s64) |
222 | // |
223 | // |
224 | // |
225 | // %bv:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %arg1(s64), %arg2(s64) |
226 | // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %opaque(s64) |
227 | // |
228 | // --> |
229 | // |
230 | // %bv:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %arg1(s64), %arg2(s64) |
231 | // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %opaque(s64) |
232 | // |
233 | |
234 | Register Vector = Extract->getVectorReg(); |
235 | |
236 | // We expect a buildVectorTrunc on the Vector register. |
237 | GBuildVectorTrunc *Build = getOpcodeDef<GBuildVectorTrunc>(Reg: Vector, MRI); |
238 | if (!Build) |
239 | return false; |
240 | |
241 | LLT VectorTy = MRI.getType(Reg: Vector); |
242 | |
243 | // There is a one-use check. There are more combines on build vectors. |
244 | EVT Ty(getMVTForLLT(Ty: VectorTy)); |
245 | if (!MRI.hasOneNonDBGUse(RegNo: Build->getReg(Idx: 0)) || |
246 | !getTargetLowering().aggressivelyPreferBuildVectorSources(VecVT: Ty)) |
247 | return false; |
248 | |
249 | Register Index = Extract->getIndexReg(); |
250 | |
251 | // If the Index is constant, then we can extract the element from the given |
252 | // offset. |
253 | std::optional<ValueAndVReg> MaybeIndex = |
254 | getIConstantVRegValWithLookThrough(VReg: Index, MRI); |
255 | if (!MaybeIndex) |
256 | return false; |
257 | |
258 | // We now know that there is a buildVectorTrunc def'd on the Vector register |
259 | // and the index is const. The combine will succeed. |
260 | |
261 | Register Dst = Extract->getReg(Idx: 0); |
262 | LLT DstTy = MRI.getType(Reg: Dst); |
263 | LLT SrcTy = MRI.getType(Reg: Build->getSourceReg(I: 0)); |
264 | |
265 | // For buildVectorTrunc, the inputs are truncated. |
266 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) |
267 | return false; |
268 | |
269 | MatchInfo = [=](MachineIRBuilder &B) { |
270 | B.buildTrunc(Res: Dst, Op: Build->getSourceReg(I: MaybeIndex->Value.getZExtValue())); |
271 | }; |
272 | |
273 | return true; |
274 | } |
275 | |
276 | bool CombinerHelper::matchExtractVectorElementWithShuffleVector( |
277 | const MachineOperand &MO, BuildFnTy &MatchInfo) { |
278 | GExtractVectorElement * = |
279 | cast<GExtractVectorElement>(Val: getDefIgnoringCopies(Reg: MO.getReg(), MRI)); |
280 | |
281 | // |
282 | // %zero:_(s64) = G_CONSTANT i64 0 |
283 | // %sv:_(<4 x s32>) = G_SHUFFLE_SHUFFLE %arg1(<4 x s32>), %arg2(<4 x s32>), |
284 | // shufflemask(0, 0, 0, 0) |
285 | // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %sv(<4 x s32>), %zero(s64) |
286 | // |
287 | // --> |
288 | // |
289 | // %zero1:_(s64) = G_CONSTANT i64 0 |
290 | // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %arg1(<4 x s32>), %zero1(s64) |
291 | // |
292 | // |
293 | // |
294 | // |
295 | // %three:_(s64) = G_CONSTANT i64 3 |
296 | // %sv:_(<4 x s32>) = G_SHUFFLE_SHUFFLE %arg1(<4 x s32>), %arg2(<4 x s32>), |
297 | // shufflemask(0, 0, 0, -1) |
298 | // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %sv(<4 x s32>), %three(s64) |
299 | // |
300 | // --> |
301 | // |
302 | // %extract:_(s32) = G_IMPLICIT_DEF |
303 | // |
304 | // |
305 | // |
306 | // |
307 | // |
308 | // %sv:_(<4 x s32>) = G_SHUFFLE_SHUFFLE %arg1(<4 x s32>), %arg2(<4 x s32>), |
309 | // shufflemask(0, 0, 0, -1) |
310 | // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %sv(<4 x s32>), %opaque(s64) |
311 | // |
312 | // --> |
313 | // |
314 | // %sv:_(<4 x s32>) = G_SHUFFLE_SHUFFLE %arg1(<4 x s32>), %arg2(<4 x s32>), |
315 | // shufflemask(0, 0, 0, -1) |
316 | // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %sv(<4 x s32>), %opaque(s64) |
317 | // |
318 | |
319 | // We try to get the value of the Index register. |
320 | std::optional<ValueAndVReg> MaybeIndex = |
321 | getIConstantVRegValWithLookThrough(VReg: Extract->getIndexReg(), MRI); |
322 | if (!MaybeIndex) |
323 | return false; |
324 | |
325 | GShuffleVector *Shuffle = |
326 | cast<GShuffleVector>(Val: getDefIgnoringCopies(Reg: Extract->getVectorReg(), MRI)); |
327 | |
328 | ArrayRef<int> Mask = Shuffle->getMask(); |
329 | |
330 | unsigned Offset = MaybeIndex->Value.getZExtValue(); |
331 | int SrcIdx = Mask[Offset]; |
332 | |
333 | LLT Src1Type = MRI.getType(Reg: Shuffle->getSrc1Reg()); |
334 | // At the IR level a <1 x ty> shuffle vector is valid, but we want to extract |
335 | // from a vector. |
336 | assert(Src1Type.isVector() && "expected to extract from a vector" ); |
337 | unsigned LHSWidth = Src1Type.isVector() ? Src1Type.getNumElements() : 1; |
338 | |
339 | // Note that there is no one use check. |
340 | Register Dst = Extract->getReg(Idx: 0); |
341 | LLT DstTy = MRI.getType(Reg: Dst); |
342 | |
343 | if (SrcIdx < 0 && |
344 | isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_IMPLICIT_DEF, {DstTy}})) { |
345 | MatchInfo = [=](MachineIRBuilder &B) { B.buildUndef(Res: Dst); }; |
346 | return true; |
347 | } |
348 | |
349 | // If the legality check failed, then we still have to abort. |
350 | if (SrcIdx < 0) |
351 | return false; |
352 | |
353 | Register NewVector; |
354 | |
355 | // We check in which vector and at what offset to look through. |
356 | if (SrcIdx < (int)LHSWidth) { |
357 | NewVector = Shuffle->getSrc1Reg(); |
358 | // SrcIdx unchanged |
359 | } else { // SrcIdx >= LHSWidth |
360 | NewVector = Shuffle->getSrc2Reg(); |
361 | SrcIdx -= LHSWidth; |
362 | } |
363 | |
364 | LLT IdxTy = MRI.getType(Reg: Extract->getIndexReg()); |
365 | LLT NewVectorTy = MRI.getType(Reg: NewVector); |
366 | |
367 | // We check the legality of the look through. |
368 | if (!isLegalOrBeforeLegalizer( |
369 | Query: {TargetOpcode::G_EXTRACT_VECTOR_ELT, {DstTy, NewVectorTy, IdxTy}}) || |
370 | !isConstantLegalOrBeforeLegalizer(Ty: {IdxTy})) |
371 | return false; |
372 | |
373 | // We look through the shuffle vector. |
374 | MatchInfo = [=](MachineIRBuilder &B) { |
375 | auto Idx = B.buildConstant(Res: IdxTy, Val: SrcIdx); |
376 | B.buildExtractVectorElement(Res: Dst, Val: NewVector, Idx); |
377 | }; |
378 | |
379 | return true; |
380 | } |
381 | |
382 | bool CombinerHelper::matchInsertVectorElementOOB(MachineInstr &MI, |
383 | BuildFnTy &MatchInfo) { |
384 | GInsertVectorElement *Insert = cast<GInsertVectorElement>(Val: &MI); |
385 | |
386 | Register Dst = Insert->getReg(Idx: 0); |
387 | LLT DstTy = MRI.getType(Reg: Dst); |
388 | Register Index = Insert->getIndexReg(); |
389 | |
390 | if (!DstTy.isFixedVector()) |
391 | return false; |
392 | |
393 | std::optional<ValueAndVReg> MaybeIndex = |
394 | getIConstantVRegValWithLookThrough(VReg: Index, MRI); |
395 | |
396 | if (MaybeIndex && MaybeIndex->Value.uge(RHS: DstTy.getNumElements()) && |
397 | isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_IMPLICIT_DEF, {DstTy}})) { |
398 | MatchInfo = [=](MachineIRBuilder &B) { B.buildUndef(Res: Dst); }; |
399 | return true; |
400 | } |
401 | |
402 | return false; |
403 | } |
404 | |
405 | bool CombinerHelper::matchAddOfVScale(const MachineOperand &MO, |
406 | BuildFnTy &MatchInfo) { |
407 | GAdd *Add = cast<GAdd>(Val: MRI.getVRegDef(Reg: MO.getReg())); |
408 | GVScale *LHSVScale = cast<GVScale>(Val: MRI.getVRegDef(Reg: Add->getLHSReg())); |
409 | GVScale *RHSVScale = cast<GVScale>(Val: MRI.getVRegDef(Reg: Add->getRHSReg())); |
410 | |
411 | Register Dst = Add->getReg(Idx: 0); |
412 | |
413 | if (!MRI.hasOneNonDBGUse(RegNo: LHSVScale->getReg(Idx: 0)) || |
414 | !MRI.hasOneNonDBGUse(RegNo: RHSVScale->getReg(Idx: 0))) |
415 | return false; |
416 | |
417 | MatchInfo = [=](MachineIRBuilder &B) { |
418 | B.buildVScale(Res: Dst, MinElts: LHSVScale->getSrc() + RHSVScale->getSrc()); |
419 | }; |
420 | |
421 | return true; |
422 | } |
423 | |
424 | bool CombinerHelper::matchMulOfVScale(const MachineOperand &MO, |
425 | BuildFnTy &MatchInfo) { |
426 | GMul *Mul = cast<GMul>(Val: MRI.getVRegDef(Reg: MO.getReg())); |
427 | GVScale *LHSVScale = cast<GVScale>(Val: MRI.getVRegDef(Reg: Mul->getLHSReg())); |
428 | |
429 | std::optional<APInt> MaybeRHS = getIConstantVRegVal(VReg: Mul->getRHSReg(), MRI); |
430 | if (!MaybeRHS) |
431 | return false; |
432 | |
433 | Register Dst = MO.getReg(); |
434 | |
435 | if (!MRI.hasOneNonDBGUse(RegNo: LHSVScale->getReg(Idx: 0))) |
436 | return false; |
437 | |
438 | MatchInfo = [=](MachineIRBuilder &B) { |
439 | B.buildVScale(Res: Dst, MinElts: LHSVScale->getSrc() * *MaybeRHS); |
440 | }; |
441 | |
442 | return true; |
443 | } |
444 | |
445 | bool CombinerHelper::matchSubOfVScale(const MachineOperand &MO, |
446 | BuildFnTy &MatchInfo) { |
447 | GSub *Sub = cast<GSub>(Val: MRI.getVRegDef(Reg: MO.getReg())); |
448 | GVScale *RHSVScale = cast<GVScale>(Val: MRI.getVRegDef(Reg: Sub->getRHSReg())); |
449 | |
450 | Register Dst = MO.getReg(); |
451 | LLT DstTy = MRI.getType(Reg: Dst); |
452 | |
453 | if (!MRI.hasOneNonDBGUse(RegNo: RHSVScale->getReg(Idx: 0)) || |
454 | !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ADD, DstTy})) |
455 | return false; |
456 | |
457 | MatchInfo = [=](MachineIRBuilder &B) { |
458 | auto VScale = B.buildVScale(Res: DstTy, MinElts: -RHSVScale->getSrc()); |
459 | B.buildAdd(Dst, Src0: Sub->getLHSReg(), Src1: VScale, Flags: Sub->getFlags()); |
460 | }; |
461 | |
462 | return true; |
463 | } |
464 | |
465 | bool CombinerHelper::matchShlOfVScale(const MachineOperand &MO, |
466 | BuildFnTy &MatchInfo) { |
467 | GShl *Shl = cast<GShl>(Val: MRI.getVRegDef(Reg: MO.getReg())); |
468 | GVScale *LHSVScale = cast<GVScale>(Val: MRI.getVRegDef(Reg: Shl->getSrcReg())); |
469 | |
470 | std::optional<APInt> MaybeRHS = getIConstantVRegVal(VReg: Shl->getShiftReg(), MRI); |
471 | if (!MaybeRHS) |
472 | return false; |
473 | |
474 | Register Dst = MO.getReg(); |
475 | LLT DstTy = MRI.getType(Reg: Dst); |
476 | |
477 | if (!MRI.hasOneNonDBGUse(RegNo: LHSVScale->getReg(Idx: 0)) || |
478 | !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_VSCALE, DstTy})) |
479 | return false; |
480 | |
481 | MatchInfo = [=](MachineIRBuilder &B) { |
482 | B.buildVScale(Res: Dst, MinElts: LHSVScale->getSrc().shl(ShiftAmt: *MaybeRHS)); |
483 | }; |
484 | |
485 | return true; |
486 | } |
487 | |