1//===-- lib/CodeGen/GlobalISel/GICombinerHelper.cpp -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
9#include "llvm/ADT/APFloat.h"
10#include "llvm/ADT/STLExtras.h"
11#include "llvm/ADT/SetVector.h"
12#include "llvm/ADT/SmallBitVector.h"
13#include "llvm/Analysis/CmpInstAnalysis.h"
14#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
15#include "llvm/CodeGen/GlobalISel/GISelValueTracking.h"
16#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
17#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
18#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
19#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
20#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
21#include "llvm/CodeGen/GlobalISel/Utils.h"
22#include "llvm/CodeGen/LowLevelTypeUtils.h"
23#include "llvm/CodeGen/MachineBasicBlock.h"
24#include "llvm/CodeGen/MachineDominators.h"
25#include "llvm/CodeGen/MachineInstr.h"
26#include "llvm/CodeGen/MachineMemOperand.h"
27#include "llvm/CodeGen/MachineRegisterInfo.h"
28#include "llvm/CodeGen/Register.h"
29#include "llvm/CodeGen/RegisterBankInfo.h"
30#include "llvm/CodeGen/TargetInstrInfo.h"
31#include "llvm/CodeGen/TargetLowering.h"
32#include "llvm/CodeGen/TargetOpcodes.h"
33#include "llvm/IR/ConstantRange.h"
34#include "llvm/IR/DataLayout.h"
35#include "llvm/IR/InstrTypes.h"
36#include "llvm/IR/PatternMatch.h"
37#include "llvm/Support/Casting.h"
38#include "llvm/Support/DivisionByConstantInfo.h"
39#include "llvm/Support/ErrorHandling.h"
40#include "llvm/Support/MathExtras.h"
41#include "llvm/Target/TargetMachine.h"
42#include <cmath>
43#include <optional>
44#include <tuple>
45
46#define DEBUG_TYPE "gi-combiner"
47
48using namespace llvm;
49using namespace MIPatternMatch;
50
51// Option to allow testing of the combiner while no targets know about indexed
52// addressing.
53static cl::opt<bool>
54 ForceLegalIndexing("force-legal-indexing", cl::Hidden, cl::init(Val: false),
55 cl::desc("Force all indexed operations to be "
56 "legal for the GlobalISel combiner"));
57
58CombinerHelper::CombinerHelper(GISelChangeObserver &Observer,
59 MachineIRBuilder &B, bool IsPreLegalize,
60 GISelValueTracking *VT,
61 MachineDominatorTree *MDT,
62 const LegalizerInfo *LI)
63 : Builder(B), MRI(Builder.getMF().getRegInfo()), Observer(Observer), VT(VT),
64 MDT(MDT), IsPreLegalize(IsPreLegalize), LI(LI),
65 RBI(Builder.getMF().getSubtarget().getRegBankInfo()),
66 TRI(Builder.getMF().getSubtarget().getRegisterInfo()) {
67 (void)this->VT;
68}
69
70const TargetLowering &CombinerHelper::getTargetLowering() const {
71 return *Builder.getMF().getSubtarget().getTargetLowering();
72}
73
74const MachineFunction &CombinerHelper::getMachineFunction() const {
75 return Builder.getMF();
76}
77
78const DataLayout &CombinerHelper::getDataLayout() const {
79 return getMachineFunction().getDataLayout();
80}
81
82LLVMContext &CombinerHelper::getContext() const { return Builder.getContext(); }
83
/// Maps logical byte \p I of a \p ByteWidth-byte value to its in-memory
/// position under little-endian byte order. The order is unchanged: for a
/// 4-byte value x, x[0] sits at memory byte 0.
static unsigned littleEndianByteAt(const unsigned ByteWidth, const unsigned I) {
  assert(I < ByteWidth && "I must be in [0, ByteWidth)");
  const unsigned Position = I;
  return Position;
}
92
93/// Determines the LogBase2 value for a non-null input value using the
94/// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
95static Register buildLogBase2(Register V, MachineIRBuilder &MIB) {
96 auto &MRI = *MIB.getMRI();
97 LLT Ty = MRI.getType(Reg: V);
98 auto Ctlz = MIB.buildCTLZ(Dst: Ty, Src0: V);
99 auto Base = MIB.buildConstant(Res: Ty, Val: Ty.getScalarSizeInBits() - 1);
100 return MIB.buildSub(Dst: Ty, Src0: Base, Src1: Ctlz).getReg(Idx: 0);
101}
102
/// Maps logical byte \p I of a \p ByteWidth-byte value to its in-memory
/// position under big-endian byte order, which reverses the bytes: for a
/// 4-byte value x, x[0] sits at memory byte 3.
static unsigned bigEndianByteAt(const unsigned ByteWidth, const unsigned I) {
  assert(I < ByteWidth && "I must be in [0, ByteWidth)");
  const unsigned LastByte = ByteWidth - 1;
  return LastByte - I;
}
111
112/// Given a map from byte offsets in memory to indices in a load/store,
113/// determine if that map corresponds to a little or big endian byte pattern.
114///
115/// \param MemOffset2Idx maps memory offsets to address offsets.
116/// \param LowestIdx is the lowest index in \p MemOffset2Idx.
117///
118/// \returns true if the map corresponds to a big endian byte pattern, false if
119/// it corresponds to a little endian byte pattern, and std::nullopt otherwise.
120///
121/// E.g. given a 32-bit type x, and x[AddrOffset], the in-memory byte patterns
122/// are as follows:
123///
124/// AddrOffset Little endian Big endian
125/// 0 0 3
126/// 1 1 2
127/// 2 2 1
128/// 3 3 0
129static std::optional<bool>
130isBigEndian(const SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx,
131 int64_t LowestIdx) {
132 // Need at least two byte positions to decide on endianness.
133 unsigned Width = MemOffset2Idx.size();
134 if (Width < 2)
135 return std::nullopt;
136 bool BigEndian = true, LittleEndian = true;
137 for (unsigned MemOffset = 0; MemOffset < Width; ++ MemOffset) {
138 auto MemOffsetAndIdx = MemOffset2Idx.find(Val: MemOffset);
139 if (MemOffsetAndIdx == MemOffset2Idx.end())
140 return std::nullopt;
141 const int64_t Idx = MemOffsetAndIdx->second - LowestIdx;
142 assert(Idx >= 0 && "Expected non-negative byte offset?");
143 LittleEndian &= Idx == littleEndianByteAt(ByteWidth: Width, I: MemOffset);
144 BigEndian &= Idx == bigEndianByteAt(ByteWidth: Width, I: MemOffset);
145 if (!BigEndian && !LittleEndian)
146 return std::nullopt;
147 }
148
149 assert((BigEndian != LittleEndian) &&
150 "Pattern cannot be both big and little endian!");
151 return BigEndian;
152}
153
154bool CombinerHelper::isPreLegalize() const { return IsPreLegalize; }
155
156bool CombinerHelper::isLegal(const LegalityQuery &Query) const {
157 assert(LI && "Must have LegalizerInfo to query isLegal!");
158 return LI->getAction(Query).Action == LegalizeActions::Legal;
159}
160
161bool CombinerHelper::isLegalOrBeforeLegalizer(
162 const LegalityQuery &Query) const {
163 return isPreLegalize() || isLegal(Query);
164}
165
166bool CombinerHelper::isLegalOrHasWidenScalar(const LegalityQuery &Query) const {
167 return isLegal(Query) ||
168 LI->getAction(Query).Action == LegalizeActions::WidenScalar;
169}
170
171bool CombinerHelper::isConstantLegalOrBeforeLegalizer(const LLT Ty) const {
172 if (!Ty.isVector())
173 return isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_CONSTANT, {Ty}});
174 // Vector constants are represented as a G_BUILD_VECTOR of scalar G_CONSTANTs.
175 if (isPreLegalize())
176 return true;
177 LLT EltTy = Ty.getElementType();
178 return isLegal(Query: {TargetOpcode::G_BUILD_VECTOR, {Ty, EltTy}}) &&
179 isLegal(Query: {TargetOpcode::G_CONSTANT, {EltTy}});
180}
181
182void CombinerHelper::replaceRegWith(MachineRegisterInfo &MRI, Register FromReg,
183 Register ToReg) const {
184 Observer.changingAllUsesOfReg(MRI, Reg: FromReg);
185
186 if (MRI.constrainRegAttrs(Reg: ToReg, ConstrainingReg: FromReg))
187 MRI.replaceRegWith(FromReg, ToReg);
188 else
189 Builder.buildCopy(Res: FromReg, Op: ToReg);
190
191 Observer.finishedChangingAllUsesOfReg();
192}
193
194void CombinerHelper::replaceRegOpWith(MachineRegisterInfo &MRI,
195 MachineOperand &FromRegOp,
196 Register ToReg) const {
197 assert(FromRegOp.getParent() && "Expected an operand in an MI");
198 Observer.changingInstr(MI&: *FromRegOp.getParent());
199
200 FromRegOp.setReg(ToReg);
201
202 Observer.changedInstr(MI&: *FromRegOp.getParent());
203}
204
205void CombinerHelper::replaceOpcodeWith(MachineInstr &FromMI,
206 unsigned ToOpcode) const {
207 Observer.changingInstr(MI&: FromMI);
208
209 FromMI.setDesc(Builder.getTII().get(Opcode: ToOpcode));
210
211 Observer.changedInstr(MI&: FromMI);
212}
213
214const RegisterBank *CombinerHelper::getRegBank(Register Reg) const {
215 return RBI->getRegBank(Reg, MRI, TRI: *TRI);
216}
217
218void CombinerHelper::setRegBank(Register Reg,
219 const RegisterBank *RegBank) const {
220 if (RegBank)
221 MRI.setRegBank(Reg, RegBank: *RegBank);
222}
223
224bool CombinerHelper::tryCombineCopy(MachineInstr &MI) const {
225 if (matchCombineCopy(MI)) {
226 applyCombineCopy(MI);
227 return true;
228 }
229 return false;
230}
231bool CombinerHelper::matchCombineCopy(MachineInstr &MI) const {
232 if (MI.getOpcode() != TargetOpcode::COPY)
233 return false;
234 Register DstReg = MI.getOperand(i: 0).getReg();
235 Register SrcReg = MI.getOperand(i: 1).getReg();
236 return canReplaceReg(DstReg, SrcReg, MRI);
237}
238void CombinerHelper::applyCombineCopy(MachineInstr &MI) const {
239 Register DstReg = MI.getOperand(i: 0).getReg();
240 Register SrcReg = MI.getOperand(i: 1).getReg();
241 replaceRegWith(MRI, FromReg: DstReg, ToReg: SrcReg);
242 MI.eraseFromParent();
243}
244
245bool CombinerHelper::matchFreezeOfSingleMaybePoisonOperand(
246 MachineInstr &MI, BuildFnTy &MatchInfo) const {
247 // Ported from InstCombinerImpl::pushFreezeToPreventPoisonFromPropagating.
248 Register DstOp = MI.getOperand(i: 0).getReg();
249 Register OrigOp = MI.getOperand(i: 1).getReg();
250
251 if (!MRI.hasOneNonDBGUse(RegNo: OrigOp))
252 return false;
253
254 MachineInstr *OrigDef = MRI.getUniqueVRegDef(Reg: OrigOp);
255 // Even if only a single operand of the PHI is not guaranteed non-poison,
256 // moving freeze() backwards across a PHI can cause optimization issues for
257 // other users of that operand.
258 //
259 // Moving freeze() from one of the output registers of a G_UNMERGE_VALUES to
260 // the source register is unprofitable because it makes the freeze() more
261 // strict than is necessary (it would affect the whole register instead of
262 // just the subreg being frozen).
263 if (OrigDef->isPHI() || isa<GUnmerge>(Val: OrigDef))
264 return false;
265
266 if (canCreateUndefOrPoison(Reg: OrigOp, MRI,
267 /*ConsiderFlagsAndMetadata=*/false))
268 return false;
269
270 std::optional<MachineOperand> MaybePoisonOperand;
271 for (MachineOperand &Operand : OrigDef->uses()) {
272 if (!Operand.isReg())
273 return false;
274
275 if (isGuaranteedNotToBeUndefOrPoison(Reg: Operand.getReg(), MRI))
276 continue;
277
278 if (!MaybePoisonOperand)
279 MaybePoisonOperand = Operand;
280 else {
281 // We have more than one maybe-poison operand. Moving the freeze is
282 // unsafe.
283 return false;
284 }
285 }
286
287 // Eliminate freeze if all operands are guaranteed non-poison.
288 if (!MaybePoisonOperand) {
289 MatchInfo = [=](MachineIRBuilder &B) {
290 Observer.changingInstr(MI&: *OrigDef);
291 cast<GenericMachineInstr>(Val: OrigDef)->dropPoisonGeneratingFlags();
292 Observer.changedInstr(MI&: *OrigDef);
293 B.buildCopy(Res: DstOp, Op: OrigOp);
294 };
295 return true;
296 }
297
298 Register MaybePoisonOperandReg = MaybePoisonOperand->getReg();
299 LLT MaybePoisonOperandRegTy = MRI.getType(Reg: MaybePoisonOperandReg);
300
301 MatchInfo = [=](MachineIRBuilder &B) mutable {
302 Observer.changingInstr(MI&: *OrigDef);
303 cast<GenericMachineInstr>(Val: OrigDef)->dropPoisonGeneratingFlags();
304 Observer.changedInstr(MI&: *OrigDef);
305 B.setInsertPt(MBB&: *OrigDef->getParent(), II: OrigDef->getIterator());
306 auto Freeze = B.buildFreeze(Dst: MaybePoisonOperandRegTy, Src: MaybePoisonOperandReg);
307 replaceRegOpWith(
308 MRI, FromRegOp&: *OrigDef->findRegisterUseOperand(Reg: MaybePoisonOperandReg, TRI),
309 ToReg: Freeze.getReg(Idx: 0));
310 replaceRegWith(MRI, FromReg: DstOp, ToReg: OrigOp);
311 };
312 return true;
313}
314
315bool CombinerHelper::matchCombineConcatVectors(
316 MachineInstr &MI, SmallVector<Register> &Ops) const {
317 assert(MI.getOpcode() == TargetOpcode::G_CONCAT_VECTORS &&
318 "Invalid instruction");
319 bool IsUndef = true;
320 MachineInstr *Undef = nullptr;
321
322 // Walk over all the operands of concat vectors and check if they are
323 // build_vector themselves or undef.
324 // Then collect their operands in Ops.
325 for (const MachineOperand &MO : MI.uses()) {
326 Register Reg = MO.getReg();
327 MachineInstr *Def = MRI.getVRegDef(Reg);
328 assert(Def && "Operand not defined");
329 if (!MRI.hasOneNonDBGUse(RegNo: Reg))
330 return false;
331 switch (Def->getOpcode()) {
332 case TargetOpcode::G_BUILD_VECTOR:
333 IsUndef = false;
334 // Remember the operands of the build_vector to fold
335 // them into the yet-to-build flattened concat vectors.
336 for (const MachineOperand &BuildVecMO : Def->uses())
337 Ops.push_back(Elt: BuildVecMO.getReg());
338 break;
339 case TargetOpcode::G_IMPLICIT_DEF: {
340 LLT OpType = MRI.getType(Reg);
341 // Keep one undef value for all the undef operands.
342 if (!Undef) {
343 Builder.setInsertPt(MBB&: *MI.getParent(), II: MI);
344 Undef = Builder.buildUndef(Res: OpType.getScalarType());
345 }
346 assert(MRI.getType(Undef->getOperand(0).getReg()) ==
347 OpType.getScalarType() &&
348 "All undefs should have the same type");
349 // Break the undef vector in as many scalar elements as needed
350 // for the flattening.
351 for (unsigned EltIdx = 0, EltEnd = OpType.getNumElements();
352 EltIdx != EltEnd; ++EltIdx)
353 Ops.push_back(Elt: Undef->getOperand(i: 0).getReg());
354 break;
355 }
356 default:
357 return false;
358 }
359 }
360
361 // Check if the combine is illegal
362 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
363 if (!isLegalOrBeforeLegalizer(
364 Query: {TargetOpcode::G_BUILD_VECTOR, {DstTy, MRI.getType(Reg: Ops[0])}})) {
365 return false;
366 }
367
368 if (IsUndef)
369 Ops.clear();
370
371 return true;
372}
373void CombinerHelper::applyCombineConcatVectors(
374 MachineInstr &MI, SmallVector<Register> &Ops) const {
375 // We determined that the concat_vectors can be flatten.
376 // Generate the flattened build_vector.
377 Register DstReg = MI.getOperand(i: 0).getReg();
378 Builder.setInsertPt(MBB&: *MI.getParent(), II: MI);
379 Register NewDstReg = MRI.cloneVirtualRegister(VReg: DstReg);
380
381 // Note: IsUndef is sort of redundant. We could have determine it by
382 // checking that at all Ops are undef. Alternatively, we could have
383 // generate a build_vector of undefs and rely on another combine to
384 // clean that up. For now, given we already gather this information
385 // in matchCombineConcatVectors, just save compile time and issue the
386 // right thing.
387 if (Ops.empty())
388 Builder.buildUndef(Res: NewDstReg);
389 else
390 Builder.buildBuildVector(Res: NewDstReg, Ops);
391 replaceRegWith(MRI, FromReg: DstReg, ToReg: NewDstReg);
392 MI.eraseFromParent();
393}
394
395void CombinerHelper::applyCombineShuffleToBuildVector(MachineInstr &MI) const {
396 auto &Shuffle = cast<GShuffleVector>(Val&: MI);
397
398 Register SrcVec1 = Shuffle.getSrc1Reg();
399 Register SrcVec2 = Shuffle.getSrc2Reg();
400 LLT EltTy = MRI.getType(Reg: SrcVec1).getElementType();
401 int Width = MRI.getType(Reg: SrcVec1).getNumElements();
402
403 auto Unmerge1 = Builder.buildUnmerge(Res: EltTy, Op: SrcVec1);
404 auto Unmerge2 = Builder.buildUnmerge(Res: EltTy, Op: SrcVec2);
405
406 SmallVector<Register> Extracts;
407 // Select only applicable elements from unmerged values.
408 for (int Val : Shuffle.getMask()) {
409 if (Val == -1)
410 Extracts.push_back(Elt: Builder.buildUndef(Res: EltTy).getReg(Idx: 0));
411 else if (Val < Width)
412 Extracts.push_back(Elt: Unmerge1.getReg(Idx: Val));
413 else
414 Extracts.push_back(Elt: Unmerge2.getReg(Idx: Val - Width));
415 }
416 assert(Extracts.size() > 0 && "Expected at least one element in the shuffle");
417 if (Extracts.size() == 1)
418 Builder.buildCopy(Res: MI.getOperand(i: 0).getReg(), Op: Extracts[0]);
419 else
420 Builder.buildBuildVector(Res: MI.getOperand(i: 0).getReg(), Ops: Extracts);
421 MI.eraseFromParent();
422}
423
424bool CombinerHelper::matchCombineShuffleConcat(
425 MachineInstr &MI, SmallVector<Register> &Ops) const {
426 ArrayRef<int> Mask = MI.getOperand(i: 3).getShuffleMask();
427 auto ConcatMI1 =
428 dyn_cast<GConcatVectors>(Val: MRI.getVRegDef(Reg: MI.getOperand(i: 1).getReg()));
429 auto ConcatMI2 =
430 dyn_cast<GConcatVectors>(Val: MRI.getVRegDef(Reg: MI.getOperand(i: 2).getReg()));
431 if (!ConcatMI1 || !ConcatMI2)
432 return false;
433
434 // Check that the sources of the Concat instructions have the same type
435 if (MRI.getType(Reg: ConcatMI1->getSourceReg(I: 0)) !=
436 MRI.getType(Reg: ConcatMI2->getSourceReg(I: 0)))
437 return false;
438
439 LLT ConcatSrcTy = MRI.getType(Reg: ConcatMI1->getReg(Idx: 1));
440 LLT ShuffleSrcTy1 = MRI.getType(Reg: MI.getOperand(i: 1).getReg());
441 unsigned ConcatSrcNumElt = ConcatSrcTy.getNumElements();
442 for (unsigned i = 0; i < Mask.size(); i += ConcatSrcNumElt) {
443 // Check if the index takes a whole source register from G_CONCAT_VECTORS
444 // Assumes that all Sources of G_CONCAT_VECTORS are the same type
445 if (Mask[i] == -1) {
446 for (unsigned j = 1; j < ConcatSrcNumElt; j++) {
447 if (i + j >= Mask.size())
448 return false;
449 if (Mask[i + j] != -1)
450 return false;
451 }
452 if (!isLegalOrBeforeLegalizer(
453 Query: {TargetOpcode::G_IMPLICIT_DEF, {ConcatSrcTy}}))
454 return false;
455 Ops.push_back(Elt: 0);
456 } else if (Mask[i] % ConcatSrcNumElt == 0) {
457 for (unsigned j = 1; j < ConcatSrcNumElt; j++) {
458 if (i + j >= Mask.size())
459 return false;
460 if (Mask[i + j] != Mask[i] + static_cast<int>(j))
461 return false;
462 }
463 // Retrieve the source register from its respective G_CONCAT_VECTORS
464 // instruction
465 if (Mask[i] < ShuffleSrcTy1.getNumElements()) {
466 Ops.push_back(Elt: ConcatMI1->getSourceReg(I: Mask[i] / ConcatSrcNumElt));
467 } else {
468 Ops.push_back(Elt: ConcatMI2->getSourceReg(I: Mask[i] / ConcatSrcNumElt -
469 ConcatMI1->getNumSources()));
470 }
471 } else {
472 return false;
473 }
474 }
475
476 if (!isLegalOrBeforeLegalizer(
477 Query: {TargetOpcode::G_CONCAT_VECTORS,
478 {MRI.getType(Reg: MI.getOperand(i: 0).getReg()), ConcatSrcTy}}))
479 return false;
480
481 return !Ops.empty();
482}
483
484void CombinerHelper::applyCombineShuffleConcat(
485 MachineInstr &MI, SmallVector<Register> &Ops) const {
486 LLT SrcTy;
487 for (Register &Reg : Ops) {
488 if (Reg != 0)
489 SrcTy = MRI.getType(Reg);
490 }
491 assert(SrcTy.isValid() && "Unexpected full undef vector in concat combine");
492
493 Register UndefReg = 0;
494
495 for (Register &Reg : Ops) {
496 if (Reg == 0) {
497 if (UndefReg == 0)
498 UndefReg = Builder.buildUndef(Res: SrcTy).getReg(Idx: 0);
499 Reg = UndefReg;
500 }
501 }
502
503 if (Ops.size() > 1)
504 Builder.buildConcatVectors(Res: MI.getOperand(i: 0).getReg(), Ops);
505 else
506 Builder.buildCopy(Res: MI.getOperand(i: 0).getReg(), Op: Ops[0]);
507 MI.eraseFromParent();
508}
509
510bool CombinerHelper::tryCombineShuffleVector(MachineInstr &MI) const {
511 SmallVector<Register, 4> Ops;
512 if (matchCombineShuffleVector(MI, Ops)) {
513 applyCombineShuffleVector(MI, Ops);
514 return true;
515 }
516 return false;
517}
518
519bool CombinerHelper::matchCombineShuffleVector(
520 MachineInstr &MI, SmallVectorImpl<Register> &Ops) const {
521 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR &&
522 "Invalid instruction kind");
523 LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
524 Register Src1 = MI.getOperand(i: 1).getReg();
525 LLT SrcType = MRI.getType(Reg: Src1);
526
527 unsigned DstNumElts = DstType.getNumElements();
528 unsigned SrcNumElts = SrcType.getNumElements();
529
530 // If the resulting vector is smaller than the size of the source
531 // vectors being concatenated, we won't be able to replace the
532 // shuffle vector into a concat_vectors.
533 //
534 // Note: We may still be able to produce a concat_vectors fed by
535 // extract_vector_elt and so on. It is less clear that would
536 // be better though, so don't bother for now.
537 //
538 // If the destination is a scalar, the size of the sources doesn't
539 // matter. we will lower the shuffle to a plain copy. This will
540 // work only if the source and destination have the same size. But
541 // that's covered by the next condition.
542 //
543 // TODO: If the size between the source and destination don't match
544 // we could still emit an extract vector element in that case.
545 if (DstNumElts < 2 * SrcNumElts)
546 return false;
547
548 // Check that the shuffle mask can be broken evenly between the
549 // different sources.
550 if (DstNumElts % SrcNumElts != 0)
551 return false;
552
553 // Mask length is a multiple of the source vector length.
554 // Check if the shuffle is some kind of concatenation of the input
555 // vectors.
556 unsigned NumConcat = DstNumElts / SrcNumElts;
557 SmallVector<int, 8> ConcatSrcs(NumConcat, -1);
558 ArrayRef<int> Mask = MI.getOperand(i: 3).getShuffleMask();
559 for (unsigned i = 0; i != DstNumElts; ++i) {
560 int Idx = Mask[i];
561 // Undef value.
562 if (Idx < 0)
563 continue;
564 // Ensure the indices in each SrcType sized piece are sequential and that
565 // the same source is used for the whole piece.
566 if ((Idx % SrcNumElts != (i % SrcNumElts)) ||
567 (ConcatSrcs[i / SrcNumElts] >= 0 &&
568 ConcatSrcs[i / SrcNumElts] != (int)(Idx / SrcNumElts)))
569 return false;
570 // Remember which source this index came from.
571 ConcatSrcs[i / SrcNumElts] = Idx / SrcNumElts;
572 }
573
574 // The shuffle is concatenating multiple vectors together.
575 // Collect the different operands for that.
576 Register UndefReg;
577 Register Src2 = MI.getOperand(i: 2).getReg();
578 for (auto Src : ConcatSrcs) {
579 if (Src < 0) {
580 if (!UndefReg) {
581 Builder.setInsertPt(MBB&: *MI.getParent(), II: MI);
582 UndefReg = Builder.buildUndef(Res: SrcType).getReg(Idx: 0);
583 }
584 Ops.push_back(Elt: UndefReg);
585 } else if (Src == 0)
586 Ops.push_back(Elt: Src1);
587 else
588 Ops.push_back(Elt: Src2);
589 }
590 return true;
591}
592
593void CombinerHelper::applyCombineShuffleVector(MachineInstr &MI,
594 ArrayRef<Register> Ops) const {
595 Register DstReg = MI.getOperand(i: 0).getReg();
596 Builder.setInsertPt(MBB&: *MI.getParent(), II: MI);
597 Register NewDstReg = MRI.cloneVirtualRegister(VReg: DstReg);
598
599 if (Ops.size() == 1)
600 Builder.buildCopy(Res: NewDstReg, Op: Ops[0]);
601 else
602 Builder.buildMergeLikeInstr(Res: NewDstReg, Ops);
603
604 replaceRegWith(MRI, FromReg: DstReg, ToReg: NewDstReg);
605 MI.eraseFromParent();
606}
607
608namespace {
609
610/// Select a preference between two uses. CurrentUse is the current preference
611/// while *ForCandidate is attributes of the candidate under consideration.
612PreferredTuple ChoosePreferredUse(MachineInstr &LoadMI,
613 PreferredTuple &CurrentUse,
614 const LLT TyForCandidate,
615 unsigned OpcodeForCandidate,
616 MachineInstr *MIForCandidate) {
617 if (!CurrentUse.Ty.isValid()) {
618 if (CurrentUse.ExtendOpcode == OpcodeForCandidate ||
619 CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT)
620 return {.Ty: TyForCandidate, .ExtendOpcode: OpcodeForCandidate, .MI: MIForCandidate};
621 return CurrentUse;
622 }
623
624 // We permit the extend to hoist through basic blocks but this is only
625 // sensible if the target has extending loads. If you end up lowering back
626 // into a load and extend during the legalizer then the end result is
627 // hoisting the extend up to the load.
628
629 // Prefer defined extensions to undefined extensions as these are more
630 // likely to reduce the number of instructions.
631 if (OpcodeForCandidate == TargetOpcode::G_ANYEXT &&
632 CurrentUse.ExtendOpcode != TargetOpcode::G_ANYEXT)
633 return CurrentUse;
634 else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT &&
635 OpcodeForCandidate != TargetOpcode::G_ANYEXT)
636 return {.Ty: TyForCandidate, .ExtendOpcode: OpcodeForCandidate, .MI: MIForCandidate};
637
638 // Prefer sign extensions to zero extensions as sign-extensions tend to be
639 // more expensive. Don't do this if the load is already a zero-extend load
640 // though, otherwise we'll rewrite a zero-extend load into a sign-extend
641 // later.
642 if (!isa<GZExtLoad>(Val: LoadMI) && CurrentUse.Ty == TyForCandidate) {
643 if (CurrentUse.ExtendOpcode == TargetOpcode::G_SEXT &&
644 OpcodeForCandidate == TargetOpcode::G_ZEXT)
645 return CurrentUse;
646 else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ZEXT &&
647 OpcodeForCandidate == TargetOpcode::G_SEXT)
648 return {.Ty: TyForCandidate, .ExtendOpcode: OpcodeForCandidate, .MI: MIForCandidate};
649 }
650
651 // This is potentially target specific. We've chosen the largest type
652 // because G_TRUNC is usually free. One potential catch with this is that
653 // some targets have a reduced number of larger registers than smaller
654 // registers and this choice potentially increases the live-range for the
655 // larger value.
656 if (TyForCandidate.getSizeInBits() > CurrentUse.Ty.getSizeInBits()) {
657 return {.Ty: TyForCandidate, .ExtendOpcode: OpcodeForCandidate, .MI: MIForCandidate};
658 }
659 return CurrentUse;
660}
661
662/// Find a suitable place to insert some instructions and insert them. This
663/// function accounts for special cases like inserting before a PHI node.
664/// The current strategy for inserting before PHI's is to duplicate the
665/// instructions for each predecessor. However, while that's ok for G_TRUNC
666/// on most targets since it generally requires no code, other targets/cases may
667/// want to try harder to find a dominating block.
668static void InsertInsnsWithoutSideEffectsBeforeUse(
669 MachineIRBuilder &Builder, MachineInstr &DefMI, MachineOperand &UseMO,
670 std::function<void(MachineBasicBlock *, MachineBasicBlock::iterator,
671 MachineOperand &UseMO)>
672 Inserter) {
673 MachineInstr &UseMI = *UseMO.getParent();
674
675 MachineBasicBlock *InsertBB = UseMI.getParent();
676
677 // If the use is a PHI then we want the predecessor block instead.
678 if (UseMI.isPHI()) {
679 MachineOperand *PredBB = std::next(x: &UseMO);
680 InsertBB = PredBB->getMBB();
681 }
682
683 // If the block is the same block as the def then we want to insert just after
684 // the def instead of at the start of the block.
685 if (InsertBB == DefMI.getParent()) {
686 MachineBasicBlock::iterator InsertPt = &DefMI;
687 Inserter(InsertBB, std::next(x: InsertPt), UseMO);
688 return;
689 }
690
691 // Otherwise we want the start of the BB
692 Inserter(InsertBB, InsertBB->getFirstNonPHI(), UseMO);
693}
694} // end anonymous namespace
695
696bool CombinerHelper::tryCombineExtendingLoads(MachineInstr &MI) const {
697 PreferredTuple Preferred;
698 if (matchCombineExtendingLoads(MI, MatchInfo&: Preferred)) {
699 applyCombineExtendingLoads(MI, MatchInfo&: Preferred);
700 return true;
701 }
702 return false;
703}
704
705static unsigned getExtLoadOpcForExtend(unsigned ExtOpc) {
706 unsigned CandidateLoadOpc;
707 switch (ExtOpc) {
708 case TargetOpcode::G_ANYEXT:
709 CandidateLoadOpc = TargetOpcode::G_LOAD;
710 break;
711 case TargetOpcode::G_SEXT:
712 CandidateLoadOpc = TargetOpcode::G_SEXTLOAD;
713 break;
714 case TargetOpcode::G_ZEXT:
715 CandidateLoadOpc = TargetOpcode::G_ZEXTLOAD;
716 break;
717 default:
718 llvm_unreachable("Unexpected extend opc");
719 }
720 return CandidateLoadOpc;
721}
722
723bool CombinerHelper::matchCombineExtendingLoads(
724 MachineInstr &MI, PreferredTuple &Preferred) const {
725 // We match the loads and follow the uses to the extend instead of matching
726 // the extends and following the def to the load. This is because the load
727 // must remain in the same position for correctness (unless we also add code
728 // to find a safe place to sink it) whereas the extend is freely movable.
729 // It also prevents us from duplicating the load for the volatile case or just
730 // for performance.
731 GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(Val: &MI);
732 if (!LoadMI)
733 return false;
734
735 Register LoadReg = LoadMI->getDstReg();
736
737 LLT LoadValueTy = MRI.getType(Reg: LoadReg);
738 if (!LoadValueTy.isScalar())
739 return false;
740
741 // Most architectures are going to legalize <s8 loads into at least a 1 byte
742 // load, and the MMOs can only describe memory accesses in multiples of bytes.
743 // If we try to perform extload combining on those, we can end up with
744 // %a(s8) = extload %ptr (load 1 byte from %ptr)
745 // ... which is an illegal extload instruction.
746 if (LoadValueTy.getSizeInBits() < 8)
747 return false;
748
749 // For non power-of-2 types, they will very likely be legalized into multiple
750 // loads. Don't bother trying to match them into extending loads.
751 if (!llvm::has_single_bit<uint32_t>(Value: LoadValueTy.getSizeInBits()))
752 return false;
753
754 // Find the preferred type aside from the any-extends (unless it's the only
755 // one) and non-extending ops. We'll emit an extending load to that type and
756 // and emit a variant of (extend (trunc X)) for the others according to the
757 // relative type sizes. At the same time, pick an extend to use based on the
758 // extend involved in the chosen type.
759 unsigned PreferredOpcode =
760 isa<GLoad>(Val: &MI)
761 ? TargetOpcode::G_ANYEXT
762 : isa<GSExtLoad>(Val: &MI) ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT;
763 Preferred = {.Ty: LLT(), .ExtendOpcode: PreferredOpcode, .MI: nullptr};
764 for (auto &UseMI : MRI.use_nodbg_instructions(Reg: LoadReg)) {
765 if (UseMI.getOpcode() == TargetOpcode::G_SEXT ||
766 UseMI.getOpcode() == TargetOpcode::G_ZEXT ||
767 (UseMI.getOpcode() == TargetOpcode::G_ANYEXT)) {
768 const auto &MMO = LoadMI->getMMO();
769 // Don't do anything for atomics.
770 if (MMO.isAtomic())
771 continue;
772 // Check for legality.
773 if (!isPreLegalize()) {
774 LegalityQuery::MemDesc MMDesc(MMO);
775 unsigned CandidateLoadOpc = getExtLoadOpcForExtend(ExtOpc: UseMI.getOpcode());
776 LLT UseTy = MRI.getType(Reg: UseMI.getOperand(i: 0).getReg());
777 LLT SrcTy = MRI.getType(Reg: LoadMI->getPointerReg());
778 if (LI->getAction(Query: {CandidateLoadOpc, {UseTy, SrcTy}, {MMDesc}})
779 .Action != LegalizeActions::Legal)
780 continue;
781 }
782 Preferred = ChoosePreferredUse(LoadMI&: MI, CurrentUse&: Preferred,
783 TyForCandidate: MRI.getType(Reg: UseMI.getOperand(i: 0).getReg()),
784 OpcodeForCandidate: UseMI.getOpcode(), MIForCandidate: &UseMI);
785 }
786 }
787
788 // There were no extends
789 if (!Preferred.MI)
790 return false;
791 // It should be impossible to chose an extend without selecting a different
792 // type since by definition the result of an extend is larger.
793 assert(Preferred.Ty != LoadValueTy && "Extending to same type?");
794
795 LLVM_DEBUG(dbgs() << "Preferred use is: " << *Preferred.MI);
796 return true;
797}
798
799void CombinerHelper::applyCombineExtendingLoads(
800 MachineInstr &MI, PreferredTuple &Preferred) const {
801 // Rewrite the load to the chosen extending load.
802 Register ChosenDstReg = Preferred.MI->getOperand(i: 0).getReg();
803
804 // Inserter to insert a truncate back to the original type at a given point
805 // with some basic CSE to limit truncate duplication to one per BB.
806 DenseMap<MachineBasicBlock *, MachineInstr *> EmittedInsns;
807 auto InsertTruncAt = [&](MachineBasicBlock *InsertIntoBB,
808 MachineBasicBlock::iterator InsertBefore,
809 MachineOperand &UseMO) {
810 MachineInstr *PreviouslyEmitted = EmittedInsns.lookup(Val: InsertIntoBB);
811 if (PreviouslyEmitted) {
812 Observer.changingInstr(MI&: *UseMO.getParent());
813 UseMO.setReg(PreviouslyEmitted->getOperand(i: 0).getReg());
814 Observer.changedInstr(MI&: *UseMO.getParent());
815 return;
816 }
817
818 Builder.setInsertPt(MBB&: *InsertIntoBB, II: InsertBefore);
819 Register NewDstReg = MRI.cloneVirtualRegister(VReg: MI.getOperand(i: 0).getReg());
820 MachineInstr *NewMI = Builder.buildTrunc(Res: NewDstReg, Op: ChosenDstReg);
821 EmittedInsns[InsertIntoBB] = NewMI;
822 replaceRegOpWith(MRI, FromRegOp&: UseMO, ToReg: NewDstReg);
823 };
824
825 Observer.changingInstr(MI);
826 unsigned LoadOpc = getExtLoadOpcForExtend(ExtOpc: Preferred.ExtendOpcode);
827 MI.setDesc(Builder.getTII().get(Opcode: LoadOpc));
828
829 // Rewrite all the uses to fix up the types.
830 auto &LoadValue = MI.getOperand(i: 0);
831 SmallVector<MachineOperand *, 4> Uses(
832 llvm::make_pointer_range(Range: MRI.use_operands(Reg: LoadValue.getReg())));
833
834 for (auto *UseMO : Uses) {
835 MachineInstr *UseMI = UseMO->getParent();
836
837 // If the extend is compatible with the preferred extend then we should fix
838 // up the type and extend so that it uses the preferred use.
839 if (UseMI->getOpcode() == Preferred.ExtendOpcode ||
840 UseMI->getOpcode() == TargetOpcode::G_ANYEXT) {
841 Register UseDstReg = UseMI->getOperand(i: 0).getReg();
842 MachineOperand &UseSrcMO = UseMI->getOperand(i: 1);
843 const LLT UseDstTy = MRI.getType(Reg: UseDstReg);
844 if (UseDstReg != ChosenDstReg) {
845 if (Preferred.Ty == UseDstTy) {
846 // If the use has the same type as the preferred use, then merge
847 // the vregs and erase the extend. For example:
848 // %1:_(s8) = G_LOAD ...
849 // %2:_(s32) = G_SEXT %1(s8)
850 // %3:_(s32) = G_ANYEXT %1(s8)
851 // ... = ... %3(s32)
852 // rewrites to:
853 // %2:_(s32) = G_SEXTLOAD ...
854 // ... = ... %2(s32)
855 replaceRegWith(MRI, FromReg: UseDstReg, ToReg: ChosenDstReg);
856 Observer.erasingInstr(MI&: *UseMO->getParent());
857 UseMO->getParent()->eraseFromParent();
858 } else if (Preferred.Ty.getSizeInBits() < UseDstTy.getSizeInBits()) {
859 // If the preferred size is smaller, then keep the extend but extend
860 // from the result of the extending load. For example:
861 // %1:_(s8) = G_LOAD ...
862 // %2:_(s32) = G_SEXT %1(s8)
863 // %3:_(s64) = G_ANYEXT %1(s8)
864 // ... = ... %3(s64)
865 /// rewrites to:
866 // %2:_(s32) = G_SEXTLOAD ...
867 // %3:_(s64) = G_ANYEXT %2:_(s32)
868 // ... = ... %3(s64)
869 replaceRegOpWith(MRI, FromRegOp&: UseSrcMO, ToReg: ChosenDstReg);
870 } else {
871 // If the preferred size is large, then insert a truncate. For
872 // example:
873 // %1:_(s8) = G_LOAD ...
874 // %2:_(s64) = G_SEXT %1(s8)
875 // %3:_(s32) = G_ZEXT %1(s8)
876 // ... = ... %3(s32)
877 /// rewrites to:
878 // %2:_(s64) = G_SEXTLOAD ...
879 // %4:_(s8) = G_TRUNC %2:_(s32)
880 // %3:_(s64) = G_ZEXT %2:_(s8)
881 // ... = ... %3(s64)
882 InsertInsnsWithoutSideEffectsBeforeUse(Builder, DefMI&: MI, UseMO&: *UseMO,
883 Inserter: InsertTruncAt);
884 }
885 continue;
886 }
887 // The use is (one of) the uses of the preferred use we chose earlier.
888 // We're going to update the load to def this value later so just erase
889 // the old extend.
890 Observer.erasingInstr(MI&: *UseMO->getParent());
891 UseMO->getParent()->eraseFromParent();
892 continue;
893 }
894
895 // The use isn't an extend. Truncate back to the type we originally loaded.
896 // This is free on many targets.
897 InsertInsnsWithoutSideEffectsBeforeUse(Builder, DefMI&: MI, UseMO&: *UseMO, Inserter: InsertTruncAt);
898 }
899
900 MI.getOperand(i: 0).setReg(ChosenDstReg);
901 Observer.changedInstr(MI);
902}
903
904bool CombinerHelper::matchCombineLoadWithAndMask(MachineInstr &MI,
905 BuildFnTy &MatchInfo) const {
906 assert(MI.getOpcode() == TargetOpcode::G_AND);
907
908 // If we have the following code:
909 // %mask = G_CONSTANT 255
910 // %ld = G_LOAD %ptr, (load s16)
911 // %and = G_AND %ld, %mask
912 //
913 // Try to fold it into
914 // %ld = G_ZEXTLOAD %ptr, (load s8)
915
916 Register Dst = MI.getOperand(i: 0).getReg();
917 if (MRI.getType(Reg: Dst).isVector())
918 return false;
919
920 auto MaybeMask =
921 getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI);
922 if (!MaybeMask)
923 return false;
924
925 APInt MaskVal = MaybeMask->Value;
926
927 if (!MaskVal.isMask())
928 return false;
929
930 Register SrcReg = MI.getOperand(i: 1).getReg();
931 // Don't use getOpcodeDef() here since intermediate instructions may have
932 // multiple users.
933 GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(Val: MRI.getVRegDef(Reg: SrcReg));
934 if (!LoadMI || !MRI.hasOneNonDBGUse(RegNo: LoadMI->getDstReg()))
935 return false;
936
937 Register LoadReg = LoadMI->getDstReg();
938 LLT RegTy = MRI.getType(Reg: LoadReg);
939 Register PtrReg = LoadMI->getPointerReg();
940 unsigned RegSize = RegTy.getSizeInBits();
941 LocationSize LoadSizeBits = LoadMI->getMemSizeInBits();
942 unsigned MaskSizeBits = MaskVal.countr_one();
943
944 // The mask may not be larger than the in-memory type, as it might cover sign
945 // extended bits
946 if (MaskSizeBits > LoadSizeBits.getValue())
947 return false;
948
949 // If the mask covers the whole destination register, there's nothing to
950 // extend
951 if (MaskSizeBits >= RegSize)
952 return false;
953
954 // Most targets cannot deal with loads of size < 8 and need to re-legalize to
955 // at least byte loads. Avoid creating such loads here
956 if (MaskSizeBits < 8 || !isPowerOf2_32(Value: MaskSizeBits))
957 return false;
958
959 const MachineMemOperand &MMO = LoadMI->getMMO();
960 LegalityQuery::MemDesc MemDesc(MMO);
961
962 // Don't modify the memory access size if this is atomic/volatile, but we can
963 // still adjust the opcode to indicate the high bit behavior.
964 if (LoadMI->isSimple())
965 MemDesc.MemoryTy = LLT::scalar(SizeInBits: MaskSizeBits);
966 else if (LoadSizeBits.getValue() > MaskSizeBits ||
967 LoadSizeBits.getValue() == RegSize)
968 return false;
969
970 // TODO: Could check if it's legal with the reduced or original memory size.
971 if (!isLegalOrBeforeLegalizer(
972 Query: {TargetOpcode::G_ZEXTLOAD, {RegTy, MRI.getType(Reg: PtrReg)}, {MemDesc}}))
973 return false;
974
975 MatchInfo = [=](MachineIRBuilder &B) {
976 B.setInstrAndDebugLoc(*LoadMI);
977 auto &MF = B.getMF();
978 auto PtrInfo = MMO.getPointerInfo();
979 auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Ty: MemDesc.MemoryTy);
980 B.buildLoadInstr(Opcode: TargetOpcode::G_ZEXTLOAD, Res: Dst, Addr: PtrReg, MMO&: *NewMMO);
981 LoadMI->eraseFromParent();
982 };
983 return true;
984}
985
986bool CombinerHelper::isPredecessor(const MachineInstr &DefMI,
987 const MachineInstr &UseMI) const {
988 assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() &&
989 "shouldn't consider debug uses");
990 assert(DefMI.getParent() == UseMI.getParent());
991 if (&DefMI == &UseMI)
992 return true;
993 const MachineBasicBlock &MBB = *DefMI.getParent();
994 auto DefOrUse = find_if(Range: MBB, P: [&DefMI, &UseMI](const MachineInstr &MI) {
995 return &MI == &DefMI || &MI == &UseMI;
996 });
997 if (DefOrUse == MBB.end())
998 llvm_unreachable("Block must contain both DefMI and UseMI!");
999 return &*DefOrUse == &DefMI;
1000}
1001
1002bool CombinerHelper::dominates(const MachineInstr &DefMI,
1003 const MachineInstr &UseMI) const {
1004 assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() &&
1005 "shouldn't consider debug uses");
1006 if (MDT)
1007 return MDT->dominates(A: &DefMI, B: &UseMI);
1008 else if (DefMI.getParent() != UseMI.getParent())
1009 return false;
1010
1011 return isPredecessor(DefMI, UseMI);
1012}
1013
1014bool CombinerHelper::matchSextTruncSextLoad(MachineInstr &MI) const {
1015 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
1016 Register SrcReg = MI.getOperand(i: 1).getReg();
1017 Register LoadUser = SrcReg;
1018
1019 if (MRI.getType(Reg: SrcReg).isVector())
1020 return false;
1021
1022 Register TruncSrc;
1023 if (mi_match(R: SrcReg, MRI, P: m_GTrunc(Src: m_Reg(R&: TruncSrc))))
1024 LoadUser = TruncSrc;
1025
1026 uint64_t SizeInBits = MI.getOperand(i: 2).getImm();
1027 // If the source is a G_SEXTLOAD from the same bit width, then we don't
1028 // need any extend at all, just a truncate.
1029 if (auto *LoadMI = getOpcodeDef<GSExtLoad>(Reg: LoadUser, MRI)) {
1030 // If truncating more than the original extended value, abort.
1031 auto LoadSizeBits = LoadMI->getMemSizeInBits();
1032 if (TruncSrc &&
1033 MRI.getType(Reg: TruncSrc).getSizeInBits() < LoadSizeBits.getValue())
1034 return false;
1035 if (LoadSizeBits == SizeInBits)
1036 return true;
1037 }
1038 return false;
1039}
1040
1041void CombinerHelper::applySextTruncSextLoad(MachineInstr &MI) const {
1042 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
1043 Builder.buildCopy(Res: MI.getOperand(i: 0).getReg(), Op: MI.getOperand(i: 1).getReg());
1044 MI.eraseFromParent();
1045}
1046
1047bool CombinerHelper::matchSextInRegOfLoad(
1048 MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) const {
1049 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
1050
1051 Register DstReg = MI.getOperand(i: 0).getReg();
1052 LLT RegTy = MRI.getType(Reg: DstReg);
1053
1054 // Only supports scalars for now.
1055 if (RegTy.isVector())
1056 return false;
1057
1058 Register SrcReg = MI.getOperand(i: 1).getReg();
1059 auto *LoadDef = getOpcodeDef<GLoad>(Reg: SrcReg, MRI);
1060 if (!LoadDef || !MRI.hasOneNonDBGUse(RegNo: SrcReg))
1061 return false;
1062
1063 uint64_t MemBits = LoadDef->getMemSizeInBits().getValue();
1064
1065 // If the sign extend extends from a narrower width than the load's width,
1066 // then we can narrow the load width when we combine to a G_SEXTLOAD.
1067 // Avoid widening the load at all.
1068 unsigned NewSizeBits = std::min(a: (uint64_t)MI.getOperand(i: 2).getImm(), b: MemBits);
1069
1070 // Don't generate G_SEXTLOADs with a < 1 byte width.
1071 if (NewSizeBits < 8)
1072 return false;
1073 // Don't bother creating a non-power-2 sextload, it will likely be broken up
1074 // anyway for most targets.
1075 if (!isPowerOf2_32(Value: NewSizeBits))
1076 return false;
1077
1078 const MachineMemOperand &MMO = LoadDef->getMMO();
1079 LegalityQuery::MemDesc MMDesc(MMO);
1080
1081 // Don't modify the memory access size if this is atomic/volatile, but we can
1082 // still adjust the opcode to indicate the high bit behavior.
1083 if (LoadDef->isSimple())
1084 MMDesc.MemoryTy = LLT::scalar(SizeInBits: NewSizeBits);
1085 else if (MemBits > NewSizeBits || MemBits == RegTy.getSizeInBits())
1086 return false;
1087
1088 // TODO: Could check if it's legal with the reduced or original memory size.
1089 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SEXTLOAD,
1090 {MRI.getType(Reg: LoadDef->getDstReg()),
1091 MRI.getType(Reg: LoadDef->getPointerReg())},
1092 {MMDesc}}))
1093 return false;
1094
1095 MatchInfo = std::make_tuple(args: LoadDef->getDstReg(), args&: NewSizeBits);
1096 return true;
1097}
1098
1099void CombinerHelper::applySextInRegOfLoad(
1100 MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) const {
1101 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
1102 Register LoadReg;
1103 unsigned ScalarSizeBits;
1104 std::tie(args&: LoadReg, args&: ScalarSizeBits) = MatchInfo;
1105 GLoad *LoadDef = cast<GLoad>(Val: MRI.getVRegDef(Reg: LoadReg));
1106
1107 // If we have the following:
1108 // %ld = G_LOAD %ptr, (load 2)
1109 // %ext = G_SEXT_INREG %ld, 8
1110 // ==>
1111 // %ld = G_SEXTLOAD %ptr (load 1)
1112
1113 auto &MMO = LoadDef->getMMO();
1114 Builder.setInstrAndDebugLoc(*LoadDef);
1115 auto &MF = Builder.getMF();
1116 auto PtrInfo = MMO.getPointerInfo();
1117 auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Size: ScalarSizeBits / 8);
1118 Builder.buildLoadInstr(Opcode: TargetOpcode::G_SEXTLOAD, Res: MI.getOperand(i: 0).getReg(),
1119 Addr: LoadDef->getPointerReg(), MMO&: *NewMMO);
1120 MI.eraseFromParent();
1121
1122 // Not all loads can be deleted, so make sure the old one is removed.
1123 LoadDef->eraseFromParent();
1124}
1125
1126/// Return true if 'MI' is a load or a store that may be fold it's address
1127/// operand into the load / store addressing mode.
1128static bool canFoldInAddressingMode(GLoadStore *MI, const TargetLowering &TLI,
1129 MachineRegisterInfo &MRI) {
1130 TargetLowering::AddrMode AM;
1131 auto *MF = MI->getMF();
1132 auto *Addr = getOpcodeDef<GPtrAdd>(Reg: MI->getPointerReg(), MRI);
1133 if (!Addr)
1134 return false;
1135
1136 AM.HasBaseReg = true;
1137 if (auto CstOff = getIConstantVRegVal(VReg: Addr->getOffsetReg(), MRI))
1138 AM.BaseOffs = CstOff->getSExtValue(); // [reg +/- imm]
1139 else
1140 AM.Scale = 1; // [reg +/- reg]
1141
1142 return TLI.isLegalAddressingMode(
1143 DL: MF->getDataLayout(), AM,
1144 Ty: getTypeForLLT(Ty: MI->getMMO().getMemoryType(),
1145 C&: MF->getFunction().getContext()),
1146 AddrSpace: MI->getMMO().getAddrSpace());
1147}
1148
1149static unsigned getIndexedOpc(unsigned LdStOpc) {
1150 switch (LdStOpc) {
1151 case TargetOpcode::G_LOAD:
1152 return TargetOpcode::G_INDEXED_LOAD;
1153 case TargetOpcode::G_STORE:
1154 return TargetOpcode::G_INDEXED_STORE;
1155 case TargetOpcode::G_ZEXTLOAD:
1156 return TargetOpcode::G_INDEXED_ZEXTLOAD;
1157 case TargetOpcode::G_SEXTLOAD:
1158 return TargetOpcode::G_INDEXED_SEXTLOAD;
1159 default:
1160 llvm_unreachable("Unexpected opcode");
1161 }
1162}
1163
1164bool CombinerHelper::isIndexedLoadStoreLegal(GLoadStore &LdSt) const {
1165 // Check for legality.
1166 LLT PtrTy = MRI.getType(Reg: LdSt.getPointerReg());
1167 LLT Ty = MRI.getType(Reg: LdSt.getReg(Idx: 0));
1168 LLT MemTy = LdSt.getMMO().getMemoryType();
1169 SmallVector<LegalityQuery::MemDesc, 2> MemDescrs(
1170 {{MemTy, MemTy.getSizeInBits().getKnownMinValue(),
1171 AtomicOrdering::NotAtomic, AtomicOrdering::NotAtomic}});
1172 unsigned IndexedOpc = getIndexedOpc(LdStOpc: LdSt.getOpcode());
1173 SmallVector<LLT> OpTys;
1174 if (IndexedOpc == TargetOpcode::G_INDEXED_STORE)
1175 OpTys = {PtrTy, Ty, Ty};
1176 else
1177 OpTys = {Ty, PtrTy}; // For G_INDEXED_LOAD, G_INDEXED_[SZ]EXTLOAD
1178
1179 LegalityQuery Q(IndexedOpc, OpTys, MemDescrs);
1180 return isLegal(Query: Q);
1181}
1182
1183static cl::opt<unsigned> PostIndexUseThreshold(
1184 "post-index-use-threshold", cl::Hidden, cl::init(Val: 32),
1185 cl::desc("Number of uses of a base pointer to check before it is no longer "
1186 "considered for post-indexing."));
1187
1188bool CombinerHelper::findPostIndexCandidate(GLoadStore &LdSt, Register &Addr,
1189 Register &Base, Register &Offset,
1190 bool &RematOffset) const {
1191 // We're looking for the following pattern, for either load or store:
1192 // %baseptr:_(p0) = ...
1193 // G_STORE %val(s64), %baseptr(p0)
1194 // %offset:_(s64) = G_CONSTANT i64 -256
1195 // %new_addr:_(p0) = G_PTR_ADD %baseptr, %offset(s64)
1196 const auto &TLI = getTargetLowering();
1197
1198 Register Ptr = LdSt.getPointerReg();
1199 // If the store is the only use, don't bother.
1200 if (MRI.hasOneNonDBGUse(RegNo: Ptr))
1201 return false;
1202
1203 if (!isIndexedLoadStoreLegal(LdSt))
1204 return false;
1205
1206 if (getOpcodeDef(Opcode: TargetOpcode::G_FRAME_INDEX, Reg: Ptr, MRI))
1207 return false;
1208
1209 MachineInstr *StoredValDef = getDefIgnoringCopies(Reg: LdSt.getReg(Idx: 0), MRI);
1210 auto *PtrDef = MRI.getVRegDef(Reg: Ptr);
1211
1212 unsigned NumUsesChecked = 0;
1213 for (auto &Use : MRI.use_nodbg_instructions(Reg: Ptr)) {
1214 if (++NumUsesChecked > PostIndexUseThreshold)
1215 return false; // Try to avoid exploding compile time.
1216
1217 auto *PtrAdd = dyn_cast<GPtrAdd>(Val: &Use);
1218 // The use itself might be dead. This can happen during combines if DCE
1219 // hasn't had a chance to run yet. Don't allow it to form an indexed op.
1220 if (!PtrAdd || MRI.use_nodbg_empty(RegNo: PtrAdd->getReg(Idx: 0)))
1221 continue;
1222
1223 // Check the user of this isn't the store, otherwise we'd be generate a
1224 // indexed store defining its own use.
1225 if (StoredValDef == &Use)
1226 continue;
1227
1228 Offset = PtrAdd->getOffsetReg();
1229 if (!ForceLegalIndexing &&
1230 !TLI.isIndexingLegal(MI&: LdSt, Base: PtrAdd->getBaseReg(), Offset,
1231 /*IsPre*/ false, MRI))
1232 continue;
1233
1234 // Make sure the offset calculation is before the potentially indexed op.
1235 MachineInstr *OffsetDef = MRI.getVRegDef(Reg: Offset);
1236 RematOffset = false;
1237 if (!dominates(DefMI: *OffsetDef, UseMI: LdSt)) {
1238 // If the offset however is just a G_CONSTANT, we can always just
1239 // rematerialize it where we need it.
1240 if (OffsetDef->getOpcode() != TargetOpcode::G_CONSTANT)
1241 continue;
1242 RematOffset = true;
1243 }
1244
1245 for (auto &BasePtrUse : MRI.use_nodbg_instructions(Reg: PtrAdd->getBaseReg())) {
1246 if (&BasePtrUse == PtrDef)
1247 continue;
1248
1249 // If the user is a later load/store that can be post-indexed, then don't
1250 // combine this one.
1251 auto *BasePtrLdSt = dyn_cast<GLoadStore>(Val: &BasePtrUse);
1252 if (BasePtrLdSt && BasePtrLdSt != &LdSt &&
1253 dominates(DefMI: LdSt, UseMI: *BasePtrLdSt) &&
1254 isIndexedLoadStoreLegal(LdSt&: *BasePtrLdSt))
1255 return false;
1256
1257 // Now we're looking for the key G_PTR_ADD instruction, which contains
1258 // the offset add that we want to fold.
1259 if (auto *BasePtrUseDef = dyn_cast<GPtrAdd>(Val: &BasePtrUse)) {
1260 Register PtrAddDefReg = BasePtrUseDef->getReg(Idx: 0);
1261 for (auto &BaseUseUse : MRI.use_nodbg_instructions(Reg: PtrAddDefReg)) {
1262 // If the use is in a different block, then we may produce worse code
1263 // due to the extra register pressure.
1264 if (BaseUseUse.getParent() != LdSt.getParent())
1265 return false;
1266
1267 if (auto *UseUseLdSt = dyn_cast<GLoadStore>(Val: &BaseUseUse))
1268 if (canFoldInAddressingMode(MI: UseUseLdSt, TLI, MRI))
1269 return false;
1270 }
1271 if (!dominates(DefMI: LdSt, UseMI: BasePtrUse))
1272 return false; // All use must be dominated by the load/store.
1273 }
1274 }
1275
1276 Addr = PtrAdd->getReg(Idx: 0);
1277 Base = PtrAdd->getBaseReg();
1278 return true;
1279 }
1280
1281 return false;
1282}
1283
1284bool CombinerHelper::findPreIndexCandidate(GLoadStore &LdSt, Register &Addr,
1285 Register &Base,
1286 Register &Offset) const {
1287 auto &MF = *LdSt.getParent()->getParent();
1288 const auto &TLI = *MF.getSubtarget().getTargetLowering();
1289
1290 Addr = LdSt.getPointerReg();
1291 if (!mi_match(R: Addr, MRI, P: m_GPtrAdd(L: m_Reg(R&: Base), R: m_Reg(R&: Offset))) ||
1292 MRI.hasOneNonDBGUse(RegNo: Addr))
1293 return false;
1294
1295 if (!ForceLegalIndexing &&
1296 !TLI.isIndexingLegal(MI&: LdSt, Base, Offset, /*IsPre*/ true, MRI))
1297 return false;
1298
1299 if (!isIndexedLoadStoreLegal(LdSt))
1300 return false;
1301
1302 MachineInstr *BaseDef = getDefIgnoringCopies(Reg: Base, MRI);
1303 if (BaseDef->getOpcode() == TargetOpcode::G_FRAME_INDEX)
1304 return false;
1305
1306 if (auto *St = dyn_cast<GStore>(Val: &LdSt)) {
1307 // Would require a copy.
1308 if (Base == St->getValueReg())
1309 return false;
1310
1311 // We're expecting one use of Addr in MI, but it could also be the
1312 // value stored, which isn't actually dominated by the instruction.
1313 if (St->getValueReg() == Addr)
1314 return false;
1315 }
1316
1317 // Avoid increasing cross-block register pressure.
1318 for (auto &AddrUse : MRI.use_nodbg_instructions(Reg: Addr))
1319 if (AddrUse.getParent() != LdSt.getParent())
1320 return false;
1321
1322 // FIXME: check whether all uses of the base pointer are constant PtrAdds.
1323 // That might allow us to end base's liveness here by adjusting the constant.
1324 bool RealUse = false;
1325 for (auto &AddrUse : MRI.use_nodbg_instructions(Reg: Addr)) {
1326 if (!dominates(DefMI: LdSt, UseMI: AddrUse))
1327 return false; // All use must be dominated by the load/store.
1328
1329 // If Ptr may be folded in addressing mode of other use, then it's
1330 // not profitable to do this transformation.
1331 if (auto *UseLdSt = dyn_cast<GLoadStore>(Val: &AddrUse)) {
1332 if (!canFoldInAddressingMode(MI: UseLdSt, TLI, MRI))
1333 RealUse = true;
1334 } else {
1335 RealUse = true;
1336 }
1337 }
1338 return RealUse;
1339}
1340
1341bool CombinerHelper::matchCombineExtractedVectorLoad(
1342 MachineInstr &MI, BuildFnTy &MatchInfo) const {
1343 assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT);
1344
1345 // Check if there is a load that defines the vector being extracted from.
1346 auto *LoadMI = getOpcodeDef<GLoad>(Reg: MI.getOperand(i: 1).getReg(), MRI);
1347 if (!LoadMI)
1348 return false;
1349
1350 Register Vector = MI.getOperand(i: 1).getReg();
1351 LLT VecEltTy = MRI.getType(Reg: Vector).getElementType();
1352
1353 assert(MRI.getType(MI.getOperand(0).getReg()) == VecEltTy);
1354
1355 // Checking whether we should reduce the load width.
1356 if (!MRI.hasOneNonDBGUse(RegNo: Vector))
1357 return false;
1358
1359 // Check if the defining load is simple.
1360 if (!LoadMI->isSimple())
1361 return false;
1362
1363 // If the vector element type is not a multiple of a byte then we are unable
1364 // to correctly compute an address to load only the extracted element as a
1365 // scalar.
1366 if (!VecEltTy.isByteSized())
1367 return false;
1368
1369 // Check for load fold barriers between the extraction and the load.
1370 if (MI.getParent() != LoadMI->getParent())
1371 return false;
1372 const unsigned MaxIter = 20;
1373 unsigned Iter = 0;
1374 for (auto II = LoadMI->getIterator(), IE = MI.getIterator(); II != IE; ++II) {
1375 if (II->isLoadFoldBarrier())
1376 return false;
1377 if (Iter++ == MaxIter)
1378 return false;
1379 }
1380
1381 // Check if the new load that we are going to create is legal
1382 // if we are in the post-legalization phase.
1383 MachineMemOperand MMO = LoadMI->getMMO();
1384 Align Alignment = MMO.getAlign();
1385 MachinePointerInfo PtrInfo;
1386 uint64_t Offset;
1387
1388 // Finding the appropriate PtrInfo if offset is a known constant.
1389 // This is required to create the memory operand for the narrowed load.
1390 // This machine memory operand object helps us infer about legality
1391 // before we proceed to combine the instruction.
1392 if (auto CVal = getIConstantVRegVal(VReg: Vector, MRI)) {
1393 int Elt = CVal->getZExtValue();
1394 // FIXME: should be (ABI size)*Elt.
1395 Offset = VecEltTy.getSizeInBits() * Elt / 8;
1396 PtrInfo = MMO.getPointerInfo().getWithOffset(O: Offset);
1397 } else {
1398 // Discard the pointer info except the address space because the memory
1399 // operand can't represent this new access since the offset is variable.
1400 Offset = VecEltTy.getSizeInBits() / 8;
1401 PtrInfo = MachinePointerInfo(MMO.getPointerInfo().getAddrSpace());
1402 }
1403
1404 Alignment = commonAlignment(A: Alignment, Offset);
1405
1406 Register VecPtr = LoadMI->getPointerReg();
1407 LLT PtrTy = MRI.getType(Reg: VecPtr);
1408
1409 MachineFunction &MF = *MI.getMF();
1410 auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Ty: VecEltTy);
1411
1412 LegalityQuery::MemDesc MMDesc(*NewMMO);
1413
1414 if (!isLegalOrBeforeLegalizer(
1415 Query: {TargetOpcode::G_LOAD, {VecEltTy, PtrTy}, {MMDesc}}))
1416 return false;
1417
1418 // Load must be allowed and fast on the target.
1419 LLVMContext &C = MF.getFunction().getContext();
1420 auto &DL = MF.getDataLayout();
1421 unsigned Fast = 0;
1422 if (!getTargetLowering().allowsMemoryAccess(Context&: C, DL, Ty: VecEltTy, MMO: *NewMMO,
1423 Fast: &Fast) ||
1424 !Fast)
1425 return false;
1426
1427 Register Result = MI.getOperand(i: 0).getReg();
1428 Register Index = MI.getOperand(i: 2).getReg();
1429
1430 MatchInfo = [=](MachineIRBuilder &B) {
1431 GISelObserverWrapper DummyObserver;
1432 LegalizerHelper Helper(B.getMF(), DummyObserver, B);
1433 //// Get pointer to the vector element.
1434 Register finalPtr = Helper.getVectorElementPointer(
1435 VecPtr: LoadMI->getPointerReg(), VecTy: MRI.getType(Reg: LoadMI->getOperand(i: 0).getReg()),
1436 Index);
1437 // New G_LOAD instruction.
1438 B.buildLoad(Res: Result, Addr: finalPtr, PtrInfo, Alignment);
1439 // Remove original GLOAD instruction.
1440 LoadMI->eraseFromParent();
1441 };
1442
1443 return true;
1444}
1445
1446bool CombinerHelper::matchCombineIndexedLoadStore(
1447 MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) const {
1448 auto &LdSt = cast<GLoadStore>(Val&: MI);
1449
1450 if (LdSt.isAtomic())
1451 return false;
1452
1453 MatchInfo.IsPre = findPreIndexCandidate(LdSt, Addr&: MatchInfo.Addr, Base&: MatchInfo.Base,
1454 Offset&: MatchInfo.Offset);
1455 if (!MatchInfo.IsPre &&
1456 !findPostIndexCandidate(LdSt, Addr&: MatchInfo.Addr, Base&: MatchInfo.Base,
1457 Offset&: MatchInfo.Offset, RematOffset&: MatchInfo.RematOffset))
1458 return false;
1459
1460 return true;
1461}
1462
1463void CombinerHelper::applyCombineIndexedLoadStore(
1464 MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) const {
1465 MachineInstr &AddrDef = *MRI.getUniqueVRegDef(Reg: MatchInfo.Addr);
1466 unsigned Opcode = MI.getOpcode();
1467 bool IsStore = Opcode == TargetOpcode::G_STORE;
1468 unsigned NewOpcode = getIndexedOpc(LdStOpc: Opcode);
1469
1470 // If the offset constant didn't happen to dominate the load/store, we can
1471 // just clone it as needed.
1472 if (MatchInfo.RematOffset) {
1473 auto *OldCst = MRI.getVRegDef(Reg: MatchInfo.Offset);
1474 auto NewCst = Builder.buildConstant(Res: MRI.getType(Reg: MatchInfo.Offset),
1475 Val: *OldCst->getOperand(i: 1).getCImm());
1476 MatchInfo.Offset = NewCst.getReg(Idx: 0);
1477 }
1478
1479 auto MIB = Builder.buildInstr(Opcode: NewOpcode);
1480 if (IsStore) {
1481 MIB.addDef(RegNo: MatchInfo.Addr);
1482 MIB.addUse(RegNo: MI.getOperand(i: 0).getReg());
1483 } else {
1484 MIB.addDef(RegNo: MI.getOperand(i: 0).getReg());
1485 MIB.addDef(RegNo: MatchInfo.Addr);
1486 }
1487
1488 MIB.addUse(RegNo: MatchInfo.Base);
1489 MIB.addUse(RegNo: MatchInfo.Offset);
1490 MIB.addImm(Val: MatchInfo.IsPre);
1491 MIB->cloneMemRefs(MF&: *MI.getMF(), MI);
1492 MI.eraseFromParent();
1493 AddrDef.eraseFromParent();
1494
1495 LLVM_DEBUG(dbgs() << " Combinined to indexed operation");
1496}
1497
1498bool CombinerHelper::matchCombineDivRem(MachineInstr &MI,
1499 MachineInstr *&OtherMI) const {
1500 unsigned Opcode = MI.getOpcode();
1501 bool IsDiv, IsSigned;
1502
1503 switch (Opcode) {
1504 default:
1505 llvm_unreachable("Unexpected opcode!");
1506 case TargetOpcode::G_SDIV:
1507 case TargetOpcode::G_UDIV: {
1508 IsDiv = true;
1509 IsSigned = Opcode == TargetOpcode::G_SDIV;
1510 break;
1511 }
1512 case TargetOpcode::G_SREM:
1513 case TargetOpcode::G_UREM: {
1514 IsDiv = false;
1515 IsSigned = Opcode == TargetOpcode::G_SREM;
1516 break;
1517 }
1518 }
1519
1520 Register Src1 = MI.getOperand(i: 1).getReg();
1521 unsigned DivOpcode, RemOpcode, DivremOpcode;
1522 if (IsSigned) {
1523 DivOpcode = TargetOpcode::G_SDIV;
1524 RemOpcode = TargetOpcode::G_SREM;
1525 DivremOpcode = TargetOpcode::G_SDIVREM;
1526 } else {
1527 DivOpcode = TargetOpcode::G_UDIV;
1528 RemOpcode = TargetOpcode::G_UREM;
1529 DivremOpcode = TargetOpcode::G_UDIVREM;
1530 }
1531
1532 if (!isLegalOrBeforeLegalizer(Query: {DivremOpcode, {MRI.getType(Reg: Src1)}}))
1533 return false;
1534
1535 // Combine:
1536 // %div:_ = G_[SU]DIV %src1:_, %src2:_
1537 // %rem:_ = G_[SU]REM %src1:_, %src2:_
1538 // into:
1539 // %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_
1540
1541 // Combine:
1542 // %rem:_ = G_[SU]REM %src1:_, %src2:_
1543 // %div:_ = G_[SU]DIV %src1:_, %src2:_
1544 // into:
1545 // %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_
1546
1547 for (auto &UseMI : MRI.use_nodbg_instructions(Reg: Src1)) {
1548 if (MI.getParent() == UseMI.getParent() &&
1549 ((IsDiv && UseMI.getOpcode() == RemOpcode) ||
1550 (!IsDiv && UseMI.getOpcode() == DivOpcode)) &&
1551 matchEqualDefs(MOP1: MI.getOperand(i: 2), MOP2: UseMI.getOperand(i: 2)) &&
1552 matchEqualDefs(MOP1: MI.getOperand(i: 1), MOP2: UseMI.getOperand(i: 1))) {
1553 OtherMI = &UseMI;
1554 return true;
1555 }
1556 }
1557
1558 return false;
1559}
1560
1561void CombinerHelper::applyCombineDivRem(MachineInstr &MI,
1562 MachineInstr *&OtherMI) const {
1563 unsigned Opcode = MI.getOpcode();
1564 assert(OtherMI && "OtherMI shouldn't be empty.");
1565
1566 Register DestDivReg, DestRemReg;
1567 if (Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_UDIV) {
1568 DestDivReg = MI.getOperand(i: 0).getReg();
1569 DestRemReg = OtherMI->getOperand(i: 0).getReg();
1570 } else {
1571 DestDivReg = OtherMI->getOperand(i: 0).getReg();
1572 DestRemReg = MI.getOperand(i: 0).getReg();
1573 }
1574
1575 bool IsSigned =
1576 Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_SREM;
1577
1578 // Check which instruction is first in the block so we don't break def-use
1579 // deps by "moving" the instruction incorrectly. Also keep track of which
1580 // instruction is first so we pick it's operands, avoiding use-before-def
1581 // bugs.
1582 MachineInstr *FirstInst = dominates(DefMI: MI, UseMI: *OtherMI) ? &MI : OtherMI;
1583 Builder.setInstrAndDebugLoc(*FirstInst);
1584
1585 Builder.buildInstr(Opc: IsSigned ? TargetOpcode::G_SDIVREM
1586 : TargetOpcode::G_UDIVREM,
1587 DstOps: {DestDivReg, DestRemReg},
1588 SrcOps: { FirstInst->getOperand(i: 1), FirstInst->getOperand(i: 2) });
1589 MI.eraseFromParent();
1590 OtherMI->eraseFromParent();
1591}
1592
1593bool CombinerHelper::matchOptBrCondByInvertingCond(
1594 MachineInstr &MI, MachineInstr *&BrCond) const {
1595 assert(MI.getOpcode() == TargetOpcode::G_BR);
1596
1597 // Try to match the following:
1598 // bb1:
1599 // G_BRCOND %c1, %bb2
1600 // G_BR %bb3
1601 // bb2:
1602 // ...
1603 // bb3:
1604
1605 // The above pattern does not have a fall through to the successor bb2, always
1606 // resulting in a branch no matter which path is taken. Here we try to find
1607 // and replace that pattern with conditional branch to bb3 and otherwise
1608 // fallthrough to bb2. This is generally better for branch predictors.
1609
1610 MachineBasicBlock *MBB = MI.getParent();
1611 MachineBasicBlock::iterator BrIt(MI);
1612 if (BrIt == MBB->begin())
1613 return false;
1614 assert(std::next(BrIt) == MBB->end() && "expected G_BR to be a terminator");
1615
1616 BrCond = &*std::prev(x: BrIt);
1617 if (BrCond->getOpcode() != TargetOpcode::G_BRCOND)
1618 return false;
1619
1620 // Check that the next block is the conditional branch target. Also make sure
1621 // that it isn't the same as the G_BR's target (otherwise, this will loop.)
1622 MachineBasicBlock *BrCondTarget = BrCond->getOperand(i: 1).getMBB();
1623 return BrCondTarget != MI.getOperand(i: 0).getMBB() &&
1624 MBB->isLayoutSuccessor(MBB: BrCondTarget);
1625}
1626
1627void CombinerHelper::applyOptBrCondByInvertingCond(
1628 MachineInstr &MI, MachineInstr *&BrCond) const {
1629 MachineBasicBlock *BrTarget = MI.getOperand(i: 0).getMBB();
1630 Builder.setInstrAndDebugLoc(*BrCond);
1631 LLT Ty = MRI.getType(Reg: BrCond->getOperand(i: 0).getReg());
1632 // FIXME: Does int/fp matter for this? If so, we might need to restrict
1633 // this to i1 only since we might not know for sure what kind of
1634 // compare generated the condition value.
1635 auto True = Builder.buildConstant(
1636 Res: Ty, Val: getICmpTrueVal(TLI: getTargetLowering(), IsVector: false, IsFP: false));
1637 auto Xor = Builder.buildXor(Dst: Ty, Src0: BrCond->getOperand(i: 0), Src1: True);
1638
1639 auto *FallthroughBB = BrCond->getOperand(i: 1).getMBB();
1640 Observer.changingInstr(MI);
1641 MI.getOperand(i: 0).setMBB(FallthroughBB);
1642 Observer.changedInstr(MI);
1643
1644 // Change the conditional branch to use the inverted condition and
1645 // new target block.
1646 Observer.changingInstr(MI&: *BrCond);
1647 BrCond->getOperand(i: 0).setReg(Xor.getReg(Idx: 0));
1648 BrCond->getOperand(i: 1).setMBB(BrTarget);
1649 Observer.changedInstr(MI&: *BrCond);
1650}
1651
1652bool CombinerHelper::tryEmitMemcpyInline(MachineInstr &MI) const {
1653 MachineIRBuilder HelperBuilder(MI);
1654 GISelObserverWrapper DummyObserver;
1655 LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder);
1656 return Helper.lowerMemcpyInline(MI) ==
1657 LegalizerHelper::LegalizeResult::Legalized;
1658}
1659
1660bool CombinerHelper::tryCombineMemCpyFamily(MachineInstr &MI,
1661 unsigned MaxLen) const {
1662 MachineIRBuilder HelperBuilder(MI);
1663 GISelObserverWrapper DummyObserver;
1664 LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder);
1665 return Helper.lowerMemCpyFamily(MI, MaxLen) ==
1666 LegalizerHelper::LegalizeResult::Legalized;
1667}
1668
1669static APFloat constantFoldFpUnary(const MachineInstr &MI,
1670 const MachineRegisterInfo &MRI,
1671 const APFloat &Val) {
1672 APFloat Result(Val);
1673 switch (MI.getOpcode()) {
1674 default:
1675 llvm_unreachable("Unexpected opcode!");
1676 case TargetOpcode::G_FNEG: {
1677 Result.changeSign();
1678 return Result;
1679 }
1680 case TargetOpcode::G_FABS: {
1681 Result.clearSign();
1682 return Result;
1683 }
1684 case TargetOpcode::G_FCEIL:
1685 Result.roundToIntegral(RM: APFloat::rmTowardPositive);
1686 return Result;
1687 case TargetOpcode::G_FFLOOR:
1688 Result.roundToIntegral(RM: APFloat::rmTowardNegative);
1689 return Result;
1690 case TargetOpcode::G_INTRINSIC_TRUNC:
1691 Result.roundToIntegral(RM: APFloat::rmTowardZero);
1692 return Result;
1693 case TargetOpcode::G_INTRINSIC_ROUND:
1694 Result.roundToIntegral(RM: APFloat::rmNearestTiesToAway);
1695 return Result;
1696 case TargetOpcode::G_INTRINSIC_ROUNDEVEN:
1697 Result.roundToIntegral(RM: APFloat::rmNearestTiesToEven);
1698 return Result;
1699 case TargetOpcode::G_FRINT:
1700 case TargetOpcode::G_FNEARBYINT:
1701 // Use default rounding mode (round to nearest, ties to even)
1702 Result.roundToIntegral(RM: APFloat::rmNearestTiesToEven);
1703 return Result;
1704 case TargetOpcode::G_FPEXT:
1705 case TargetOpcode::G_FPTRUNC: {
1706 bool Unused;
1707 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
1708 Result.convert(ToSemantics: getFltSemanticForLLT(Ty: DstTy), RM: APFloat::rmNearestTiesToEven,
1709 losesInfo: &Unused);
1710 return Result;
1711 }
1712 case TargetOpcode::G_FSQRT: {
1713 bool Unused;
1714 Result.convert(ToSemantics: APFloat::IEEEdouble(), RM: APFloat::rmNearestTiesToEven,
1715 losesInfo: &Unused);
1716 Result = APFloat(sqrt(x: Result.convertToDouble()));
1717 break;
1718 }
1719 case TargetOpcode::G_FLOG2: {
1720 bool Unused;
1721 Result.convert(ToSemantics: APFloat::IEEEdouble(), RM: APFloat::rmNearestTiesToEven,
1722 losesInfo: &Unused);
1723 Result = APFloat(log2(x: Result.convertToDouble()));
1724 break;
1725 }
1726 }
1727 // Convert `APFloat` to appropriate IEEE type depending on `DstTy`. Otherwise,
1728 // `buildFConstant` will assert on size mismatch. Only `G_FSQRT`, and
1729 // `G_FLOG2` reach here.
1730 bool Unused;
1731 Result.convert(ToSemantics: Val.getSemantics(), RM: APFloat::rmNearestTiesToEven, losesInfo: &Unused);
1732 return Result;
1733}
1734
1735void CombinerHelper::applyCombineConstantFoldFpUnary(
1736 MachineInstr &MI, const ConstantFP *Cst) const {
1737 APFloat Folded = constantFoldFpUnary(MI, MRI, Val: Cst->getValue());
1738 const ConstantFP *NewCst = ConstantFP::get(Context&: Builder.getContext(), V: Folded);
1739 Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: *NewCst);
1740 MI.eraseFromParent();
1741}
1742
1743bool CombinerHelper::matchPtrAddImmedChain(MachineInstr &MI,
1744 PtrAddChain &MatchInfo) const {
1745 // We're trying to match the following pattern:
1746 // %t1 = G_PTR_ADD %base, G_CONSTANT imm1
1747 // %root = G_PTR_ADD %t1, G_CONSTANT imm2
1748 // -->
1749 // %root = G_PTR_ADD %base, G_CONSTANT (imm1 + imm2)
1750
1751 if (MI.getOpcode() != TargetOpcode::G_PTR_ADD)
1752 return false;
1753
1754 Register Add2 = MI.getOperand(i: 1).getReg();
1755 Register Imm1 = MI.getOperand(i: 2).getReg();
1756 auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: Imm1, MRI);
1757 if (!MaybeImmVal)
1758 return false;
1759
1760 MachineInstr *Add2Def = MRI.getVRegDef(Reg: Add2);
1761 if (!Add2Def || Add2Def->getOpcode() != TargetOpcode::G_PTR_ADD)
1762 return false;
1763
1764 Register Base = Add2Def->getOperand(i: 1).getReg();
1765 Register Imm2 = Add2Def->getOperand(i: 2).getReg();
1766 auto MaybeImm2Val = getIConstantVRegValWithLookThrough(VReg: Imm2, MRI);
1767 if (!MaybeImm2Val)
1768 return false;
1769
1770 // Check if the new combined immediate forms an illegal addressing mode.
1771 // Do not combine if it was legal before but would get illegal.
1772 // To do so, we need to find a load/store user of the pointer to get
1773 // the access type.
1774 Type *AccessTy = nullptr;
1775 auto &MF = *MI.getMF();
1776 for (auto &UseMI : MRI.use_nodbg_instructions(Reg: MI.getOperand(i: 0).getReg())) {
1777 if (auto *LdSt = dyn_cast<GLoadStore>(Val: &UseMI)) {
1778 AccessTy = getTypeForLLT(Ty: MRI.getType(Reg: LdSt->getReg(Idx: 0)),
1779 C&: MF.getFunction().getContext());
1780 break;
1781 }
1782 }
1783 TargetLoweringBase::AddrMode AMNew;
1784 APInt CombinedImm = MaybeImmVal->Value + MaybeImm2Val->Value;
1785 AMNew.BaseOffs = CombinedImm.getSExtValue();
1786 if (AccessTy) {
1787 AMNew.HasBaseReg = true;
1788 TargetLoweringBase::AddrMode AMOld;
1789 AMOld.BaseOffs = MaybeImmVal->Value.getSExtValue();
1790 AMOld.HasBaseReg = true;
1791 unsigned AS = MRI.getType(Reg: Add2).getAddressSpace();
1792 const auto &TLI = *MF.getSubtarget().getTargetLowering();
1793 if (TLI.isLegalAddressingMode(DL: MF.getDataLayout(), AM: AMOld, Ty: AccessTy, AddrSpace: AS) &&
1794 !TLI.isLegalAddressingMode(DL: MF.getDataLayout(), AM: AMNew, Ty: AccessTy, AddrSpace: AS))
1795 return false;
1796 }
1797
1798 // Reassociating nuw additions preserves nuw. If both original G_PTR_ADDs are
1799 // inbounds, reaching the same result in one G_PTR_ADD is also inbounds.
1800 // The nusw constraints are satisfied because imm1+imm2 cannot exceed the
1801 // largest signed integer that fits into the index type, which is the maximum
1802 // size of allocated objects according to the IR Language Reference.
1803 unsigned PtrAddFlags = MI.getFlags();
1804 unsigned LHSPtrAddFlags = Add2Def->getFlags();
1805 bool IsNoUWrap = PtrAddFlags & LHSPtrAddFlags & MachineInstr::MIFlag::NoUWrap;
1806 bool IsInBounds =
1807 PtrAddFlags & LHSPtrAddFlags & MachineInstr::MIFlag::InBounds;
1808 unsigned Flags = 0;
1809 if (IsNoUWrap)
1810 Flags |= MachineInstr::MIFlag::NoUWrap;
1811 if (IsInBounds) {
1812 Flags |= MachineInstr::MIFlag::InBounds;
1813 Flags |= MachineInstr::MIFlag::NoUSWrap;
1814 }
1815
1816 // Pass the combined immediate to the apply function.
1817 MatchInfo.Imm = AMNew.BaseOffs;
1818 MatchInfo.Base = Base;
1819 MatchInfo.Bank = getRegBank(Reg: Imm2);
1820 MatchInfo.Flags = Flags;
1821 return true;
1822}
1823
1824void CombinerHelper::applyPtrAddImmedChain(MachineInstr &MI,
1825 PtrAddChain &MatchInfo) const {
1826 assert(MI.getOpcode() == TargetOpcode::G_PTR_ADD && "Expected G_PTR_ADD");
1827 MachineIRBuilder MIB(MI);
1828 LLT OffsetTy = MRI.getType(Reg: MI.getOperand(i: 2).getReg());
1829 auto NewOffset = MIB.buildConstant(Res: OffsetTy, Val: MatchInfo.Imm);
1830 setRegBank(Reg: NewOffset.getReg(Idx: 0), RegBank: MatchInfo.Bank);
1831 Observer.changingInstr(MI);
1832 MI.getOperand(i: 1).setReg(MatchInfo.Base);
1833 MI.getOperand(i: 2).setReg(NewOffset.getReg(Idx: 0));
1834 MI.setFlags(MatchInfo.Flags);
1835 Observer.changedInstr(MI);
1836}
1837
1838bool CombinerHelper::matchShiftImmedChain(MachineInstr &MI,
1839 RegisterImmPair &MatchInfo) const {
1840 // We're trying to match the following pattern with any of
1841 // G_SHL/G_ASHR/G_LSHR/G_SSHLSAT/G_USHLSAT shift instructions:
1842 // %t1 = SHIFT %base, G_CONSTANT imm1
1843 // %root = SHIFT %t1, G_CONSTANT imm2
1844 // -->
1845 // %root = SHIFT %base, G_CONSTANT (imm1 + imm2)
1846
1847 unsigned Opcode = MI.getOpcode();
1848 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR ||
1849 Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT ||
1850 Opcode == TargetOpcode::G_USHLSAT) &&
1851 "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT");
1852
1853 Register Shl2 = MI.getOperand(i: 1).getReg();
1854 Register Imm1 = MI.getOperand(i: 2).getReg();
1855 auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: Imm1, MRI);
1856 if (!MaybeImmVal)
1857 return false;
1858
1859 MachineInstr *Shl2Def = MRI.getUniqueVRegDef(Reg: Shl2);
1860 if (Shl2Def->getOpcode() != Opcode)
1861 return false;
1862
1863 Register Base = Shl2Def->getOperand(i: 1).getReg();
1864 Register Imm2 = Shl2Def->getOperand(i: 2).getReg();
1865 auto MaybeImm2Val = getIConstantVRegValWithLookThrough(VReg: Imm2, MRI);
1866 if (!MaybeImm2Val)
1867 return false;
1868
1869 // Pass the combined immediate to the apply function.
1870 MatchInfo.Imm =
1871 (MaybeImmVal->Value.getZExtValue() + MaybeImm2Val->Value).getZExtValue();
1872 MatchInfo.Reg = Base;
1873
1874 // There is no simple replacement for a saturating unsigned left shift that
1875 // exceeds the scalar size.
1876 if (Opcode == TargetOpcode::G_USHLSAT &&
1877 MatchInfo.Imm >= MRI.getType(Reg: Shl2).getScalarSizeInBits())
1878 return false;
1879
1880 return true;
1881}
1882
1883void CombinerHelper::applyShiftImmedChain(MachineInstr &MI,
1884 RegisterImmPair &MatchInfo) const {
1885 unsigned Opcode = MI.getOpcode();
1886 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR ||
1887 Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT ||
1888 Opcode == TargetOpcode::G_USHLSAT) &&
1889 "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT");
1890
1891 LLT Ty = MRI.getType(Reg: MI.getOperand(i: 1).getReg());
1892 unsigned const ScalarSizeInBits = Ty.getScalarSizeInBits();
1893 auto Imm = MatchInfo.Imm;
1894
1895 if (Imm >= ScalarSizeInBits) {
1896 // Any logical shift that exceeds scalar size will produce zero.
1897 if (Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_LSHR) {
1898 Builder.buildConstant(Res: MI.getOperand(i: 0), Val: 0);
1899 MI.eraseFromParent();
1900 return;
1901 }
1902 // Arithmetic shift and saturating signed left shift have no effect beyond
1903 // scalar size.
1904 Imm = ScalarSizeInBits - 1;
1905 }
1906
1907 LLT ImmTy = MRI.getType(Reg: MI.getOperand(i: 2).getReg());
1908 Register NewImm = Builder.buildConstant(Res: ImmTy, Val: Imm).getReg(Idx: 0);
1909 Observer.changingInstr(MI);
1910 MI.getOperand(i: 1).setReg(MatchInfo.Reg);
1911 MI.getOperand(i: 2).setReg(NewImm);
1912 Observer.changedInstr(MI);
1913}
1914
1915bool CombinerHelper::matchShiftOfShiftedLogic(
1916 MachineInstr &MI, ShiftOfShiftedLogic &MatchInfo) const {
1917 // We're trying to match the following pattern with any of
1918 // G_SHL/G_ASHR/G_LSHR/G_USHLSAT/G_SSHLSAT shift instructions in combination
1919 // with any of G_AND/G_OR/G_XOR logic instructions.
1920 // %t1 = SHIFT %X, G_CONSTANT C0
1921 // %t2 = LOGIC %t1, %Y
1922 // %root = SHIFT %t2, G_CONSTANT C1
1923 // -->
1924 // %t3 = SHIFT %X, G_CONSTANT (C0+C1)
1925 // %t4 = SHIFT %Y, G_CONSTANT C1
1926 // %root = LOGIC %t3, %t4
1927 unsigned ShiftOpcode = MI.getOpcode();
1928 assert((ShiftOpcode == TargetOpcode::G_SHL ||
1929 ShiftOpcode == TargetOpcode::G_ASHR ||
1930 ShiftOpcode == TargetOpcode::G_LSHR ||
1931 ShiftOpcode == TargetOpcode::G_USHLSAT ||
1932 ShiftOpcode == TargetOpcode::G_SSHLSAT) &&
1933 "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT and G_SSHLSAT");
1934
1935 // Match a one-use bitwise logic op.
1936 Register LogicDest = MI.getOperand(i: 1).getReg();
1937 if (!MRI.hasOneNonDBGUse(RegNo: LogicDest))
1938 return false;
1939
1940 MachineInstr *LogicMI = MRI.getUniqueVRegDef(Reg: LogicDest);
1941 unsigned LogicOpcode = LogicMI->getOpcode();
1942 if (LogicOpcode != TargetOpcode::G_AND && LogicOpcode != TargetOpcode::G_OR &&
1943 LogicOpcode != TargetOpcode::G_XOR)
1944 return false;
1945
1946 // Find a matching one-use shift by constant.
1947 const Register C1 = MI.getOperand(i: 2).getReg();
1948 auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: C1, MRI);
1949 if (!MaybeImmVal || MaybeImmVal->Value == 0)
1950 return false;
1951
1952 const uint64_t C1Val = MaybeImmVal->Value.getZExtValue();
1953
1954 auto matchFirstShift = [&](const MachineInstr *MI, uint64_t &ShiftVal) {
1955 // Shift should match previous one and should be a one-use.
1956 if (MI->getOpcode() != ShiftOpcode ||
1957 !MRI.hasOneNonDBGUse(RegNo: MI->getOperand(i: 0).getReg()))
1958 return false;
1959
1960 // Must be a constant.
1961 auto MaybeImmVal =
1962 getIConstantVRegValWithLookThrough(VReg: MI->getOperand(i: 2).getReg(), MRI);
1963 if (!MaybeImmVal)
1964 return false;
1965
1966 ShiftVal = MaybeImmVal->Value.getSExtValue();
1967 return true;
1968 };
1969
1970 // Logic ops are commutative, so check each operand for a match.
1971 Register LogicMIReg1 = LogicMI->getOperand(i: 1).getReg();
1972 MachineInstr *LogicMIOp1 = MRI.getUniqueVRegDef(Reg: LogicMIReg1);
1973 Register LogicMIReg2 = LogicMI->getOperand(i: 2).getReg();
1974 MachineInstr *LogicMIOp2 = MRI.getUniqueVRegDef(Reg: LogicMIReg2);
1975 uint64_t C0Val;
1976
1977 if (matchFirstShift(LogicMIOp1, C0Val)) {
1978 MatchInfo.LogicNonShiftReg = LogicMIReg2;
1979 MatchInfo.Shift2 = LogicMIOp1;
1980 } else if (matchFirstShift(LogicMIOp2, C0Val)) {
1981 MatchInfo.LogicNonShiftReg = LogicMIReg1;
1982 MatchInfo.Shift2 = LogicMIOp2;
1983 } else
1984 return false;
1985
1986 MatchInfo.ValSum = C0Val + C1Val;
1987
1988 // The fold is not valid if the sum of the shift values exceeds bitwidth.
1989 if (MatchInfo.ValSum >= MRI.getType(Reg: LogicDest).getScalarSizeInBits())
1990 return false;
1991
1992 MatchInfo.Logic = LogicMI;
1993 return true;
1994}
1995
1996void CombinerHelper::applyShiftOfShiftedLogic(
1997 MachineInstr &MI, ShiftOfShiftedLogic &MatchInfo) const {
1998 unsigned Opcode = MI.getOpcode();
1999 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR ||
2000 Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_USHLSAT ||
2001 Opcode == TargetOpcode::G_SSHLSAT) &&
2002 "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT and G_SSHLSAT");
2003
2004 LLT ShlType = MRI.getType(Reg: MI.getOperand(i: 2).getReg());
2005 LLT DestType = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
2006
2007 Register Const = Builder.buildConstant(Res: ShlType, Val: MatchInfo.ValSum).getReg(Idx: 0);
2008
2009 Register Shift1Base = MatchInfo.Shift2->getOperand(i: 1).getReg();
2010 Register Shift1 =
2011 Builder.buildInstr(Opc: Opcode, DstOps: {DestType}, SrcOps: {Shift1Base, Const}).getReg(Idx: 0);
2012
2013 // If LogicNonShiftReg is the same to Shift1Base, and shift1 const is the same
2014 // to MatchInfo.Shift2 const, CSEMIRBuilder will reuse the old shift1 when
2015 // build shift2. So, if we erase MatchInfo.Shift2 at the end, actually we
2016 // remove old shift1. And it will cause crash later. So erase it earlier to
2017 // avoid the crash.
2018 MatchInfo.Shift2->eraseFromParent();
2019
2020 Register Shift2Const = MI.getOperand(i: 2).getReg();
2021 Register Shift2 = Builder
2022 .buildInstr(Opc: Opcode, DstOps: {DestType},
2023 SrcOps: {MatchInfo.LogicNonShiftReg, Shift2Const})
2024 .getReg(Idx: 0);
2025
2026 Register Dest = MI.getOperand(i: 0).getReg();
2027 Builder.buildInstr(Opc: MatchInfo.Logic->getOpcode(), DstOps: {Dest}, SrcOps: {Shift1, Shift2});
2028
2029 // This was one use so it's safe to remove it.
2030 MatchInfo.Logic->eraseFromParent();
2031
2032 MI.eraseFromParent();
2033}
2034
2035bool CombinerHelper::matchCommuteShift(MachineInstr &MI,
2036 BuildFnTy &MatchInfo) const {
2037 assert(MI.getOpcode() == TargetOpcode::G_SHL && "Expected G_SHL");
2038 // Combine (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
2039 // Combine (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
2040 auto &Shl = cast<GenericMachineInstr>(Val&: MI);
2041 Register DstReg = Shl.getReg(Idx: 0);
2042 Register SrcReg = Shl.getReg(Idx: 1);
2043 Register ShiftReg = Shl.getReg(Idx: 2);
2044 Register X, C1;
2045
2046 if (!getTargetLowering().isDesirableToCommuteWithShift(MI, IsAfterLegal: !isPreLegalize()))
2047 return false;
2048
2049 if (!mi_match(R: SrcReg, MRI,
2050 P: m_OneNonDBGUse(SP: m_any_of(preds: m_GAdd(L: m_Reg(R&: X), R: m_Reg(R&: C1)),
2051 preds: m_GOr(L: m_Reg(R&: X), R: m_Reg(R&: C1))))))
2052 return false;
2053
2054 APInt C1Val, C2Val;
2055 if (!mi_match(R: C1, MRI, P: m_ICstOrSplat(Cst&: C1Val)) ||
2056 !mi_match(R: ShiftReg, MRI, P: m_ICstOrSplat(Cst&: C2Val)))
2057 return false;
2058
2059 auto *SrcDef = MRI.getVRegDef(Reg: SrcReg);
2060 assert((SrcDef->getOpcode() == TargetOpcode::G_ADD ||
2061 SrcDef->getOpcode() == TargetOpcode::G_OR) && "Unexpected op");
2062 LLT SrcTy = MRI.getType(Reg: SrcReg);
2063 MatchInfo = [=](MachineIRBuilder &B) {
2064 auto S1 = B.buildShl(Dst: SrcTy, Src0: X, Src1: ShiftReg);
2065 auto S2 = B.buildShl(Dst: SrcTy, Src0: C1, Src1: ShiftReg);
2066 B.buildInstr(Opc: SrcDef->getOpcode(), DstOps: {DstReg}, SrcOps: {S1, S2});
2067 };
2068 return true;
2069}
2070
2071bool CombinerHelper::matchLshrOfTruncOfLshr(MachineInstr &MI,
2072 LshrOfTruncOfLshr &MatchInfo,
2073 MachineInstr &ShiftMI) const {
2074 assert(MI.getOpcode() == TargetOpcode::G_LSHR && "Expected a G_LSHR");
2075
2076 Register N0 = MI.getOperand(i: 1).getReg();
2077 Register N1 = MI.getOperand(i: 2).getReg();
2078 unsigned OpSizeInBits = MRI.getType(Reg: N0).getScalarSizeInBits();
2079
2080 APInt N1C, N001C;
2081 if (!mi_match(R: N1, MRI, P: m_ICstOrSplat(Cst&: N1C)))
2082 return false;
2083 auto N001 = ShiftMI.getOperand(i: 2).getReg();
2084 if (!mi_match(R: N001, MRI, P: m_ICstOrSplat(Cst&: N001C)))
2085 return false;
2086
2087 if (N001C.getBitWidth() > N1C.getBitWidth())
2088 N1C = N1C.zext(width: N001C.getBitWidth());
2089 else
2090 N001C = N001C.zext(width: N1C.getBitWidth());
2091
2092 Register InnerShift = ShiftMI.getOperand(i: 0).getReg();
2093 LLT InnerShiftTy = MRI.getType(Reg: InnerShift);
2094 uint64_t InnerShiftSize = InnerShiftTy.getScalarSizeInBits();
2095 if ((N1C + N001C).ult(RHS: InnerShiftSize)) {
2096 MatchInfo.Src = ShiftMI.getOperand(i: 1).getReg();
2097 MatchInfo.ShiftAmt = N1C + N001C;
2098 MatchInfo.ShiftAmtTy = MRI.getType(Reg: N001);
2099 MatchInfo.InnerShiftTy = InnerShiftTy;
2100
2101 if ((N001C + OpSizeInBits) == InnerShiftSize)
2102 return true;
2103 if (MRI.hasOneUse(RegNo: N0) && MRI.hasOneUse(RegNo: InnerShift)) {
2104 MatchInfo.Mask = true;
2105 MatchInfo.MaskVal = APInt(N1C.getBitWidth(), OpSizeInBits) - N1C;
2106 return true;
2107 }
2108 }
2109 return false;
2110}
2111
2112void CombinerHelper::applyLshrOfTruncOfLshr(
2113 MachineInstr &MI, LshrOfTruncOfLshr &MatchInfo) const {
2114 assert(MI.getOpcode() == TargetOpcode::G_LSHR && "Expected a G_LSHR");
2115
2116 Register Dst = MI.getOperand(i: 0).getReg();
2117 auto ShiftAmt =
2118 Builder.buildConstant(Res: MatchInfo.ShiftAmtTy, Val: MatchInfo.ShiftAmt);
2119 auto Shift =
2120 Builder.buildLShr(Dst: MatchInfo.InnerShiftTy, Src0: MatchInfo.Src, Src1: ShiftAmt);
2121 if (MatchInfo.Mask == true) {
2122 APInt MaskVal =
2123 APInt::getLowBitsSet(numBits: MatchInfo.InnerShiftTy.getScalarSizeInBits(),
2124 loBitsSet: MatchInfo.MaskVal.getZExtValue());
2125 auto Mask = Builder.buildConstant(Res: MatchInfo.InnerShiftTy, Val: MaskVal);
2126 auto And = Builder.buildAnd(Dst: MatchInfo.InnerShiftTy, Src0: Shift, Src1: Mask);
2127 Builder.buildTrunc(Res: Dst, Op: And);
2128 } else
2129 Builder.buildTrunc(Res: Dst, Op: Shift);
2130 MI.eraseFromParent();
2131}
2132
2133bool CombinerHelper::matchCombineMulToShl(MachineInstr &MI,
2134 unsigned &ShiftVal) const {
2135 assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL");
2136 auto MaybeImmVal =
2137 getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI);
2138 if (!MaybeImmVal)
2139 return false;
2140
2141 ShiftVal = MaybeImmVal->Value.exactLogBase2();
2142 return (static_cast<int32_t>(ShiftVal) != -1);
2143}
2144
2145void CombinerHelper::applyCombineMulToShl(MachineInstr &MI,
2146 unsigned &ShiftVal) const {
2147 assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL");
2148 MachineIRBuilder MIB(MI);
2149 LLT ShiftTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
2150 auto ShiftCst = MIB.buildConstant(Res: ShiftTy, Val: ShiftVal);
2151 Observer.changingInstr(MI);
2152 MI.setDesc(MIB.getTII().get(Opcode: TargetOpcode::G_SHL));
2153 MI.getOperand(i: 2).setReg(ShiftCst.getReg(Idx: 0));
2154 if (ShiftVal == ShiftTy.getScalarSizeInBits() - 1)
2155 MI.clearFlag(Flag: MachineInstr::MIFlag::NoSWrap);
2156 Observer.changedInstr(MI);
2157}
2158
2159bool CombinerHelper::matchCombineSubToAdd(MachineInstr &MI,
2160 BuildFnTy &MatchInfo) const {
2161 GSub &Sub = cast<GSub>(Val&: MI);
2162
2163 LLT Ty = MRI.getType(Reg: Sub.getReg(Idx: 0));
2164
2165 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ADD, {Ty}}))
2166 return false;
2167
2168 if (!isConstantLegalOrBeforeLegalizer(Ty))
2169 return false;
2170
2171 APInt Imm = getIConstantFromReg(VReg: Sub.getRHSReg(), MRI);
2172
2173 MatchInfo = [=, &MI](MachineIRBuilder &B) {
2174 auto NegCst = B.buildConstant(Res: Ty, Val: -Imm);
2175 Observer.changingInstr(MI);
2176 MI.setDesc(B.getTII().get(Opcode: TargetOpcode::G_ADD));
2177 MI.getOperand(i: 2).setReg(NegCst.getReg(Idx: 0));
2178 MI.clearFlag(Flag: MachineInstr::MIFlag::NoUWrap);
2179 if (Imm.isMinSignedValue())
2180 MI.clearFlags(flags: MachineInstr::MIFlag::NoSWrap);
2181 Observer.changedInstr(MI);
2182 };
2183 return true;
2184}
2185
2186// shl ([sza]ext x), y => zext (shl x, y), if shift does not overflow source
2187bool CombinerHelper::matchCombineShlOfExtend(MachineInstr &MI,
2188 RegisterImmPair &MatchData) const {
2189 assert(MI.getOpcode() == TargetOpcode::G_SHL && VT);
2190 if (!getTargetLowering().isDesirableToPullExtFromShl(MI))
2191 return false;
2192
2193 Register LHS = MI.getOperand(i: 1).getReg();
2194
2195 Register ExtSrc;
2196 if (!mi_match(R: LHS, MRI, P: m_GAnyExt(Src: m_Reg(R&: ExtSrc))) &&
2197 !mi_match(R: LHS, MRI, P: m_GZExt(Src: m_Reg(R&: ExtSrc))) &&
2198 !mi_match(R: LHS, MRI, P: m_GSExt(Src: m_Reg(R&: ExtSrc))))
2199 return false;
2200
2201 Register RHS = MI.getOperand(i: 2).getReg();
2202 MachineInstr *MIShiftAmt = MRI.getVRegDef(Reg: RHS);
2203 auto MaybeShiftAmtVal = isConstantOrConstantSplatVector(MI&: *MIShiftAmt, MRI);
2204 if (!MaybeShiftAmtVal)
2205 return false;
2206
2207 if (LI) {
2208 LLT SrcTy = MRI.getType(Reg: ExtSrc);
2209
2210 // We only really care about the legality with the shifted value. We can
2211 // pick any type the constant shift amount, so ask the target what to
2212 // use. Otherwise we would have to guess and hope it is reported as legal.
2213 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: SrcTy);
2214 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SHL, {SrcTy, ShiftAmtTy}}))
2215 return false;
2216 }
2217
2218 int64_t ShiftAmt = MaybeShiftAmtVal->getSExtValue();
2219 MatchData.Reg = ExtSrc;
2220 MatchData.Imm = ShiftAmt;
2221
2222 unsigned MinLeadingZeros = VT->getKnownZeroes(R: ExtSrc).countl_one();
2223 unsigned SrcTySize = MRI.getType(Reg: ExtSrc).getScalarSizeInBits();
2224 return MinLeadingZeros >= ShiftAmt && ShiftAmt < SrcTySize;
2225}
2226
2227void CombinerHelper::applyCombineShlOfExtend(
2228 MachineInstr &MI, const RegisterImmPair &MatchData) const {
2229 Register ExtSrcReg = MatchData.Reg;
2230 int64_t ShiftAmtVal = MatchData.Imm;
2231
2232 LLT ExtSrcTy = MRI.getType(Reg: ExtSrcReg);
2233 auto ShiftAmt = Builder.buildConstant(Res: ExtSrcTy, Val: ShiftAmtVal);
2234 auto NarrowShift =
2235 Builder.buildShl(Dst: ExtSrcTy, Src0: ExtSrcReg, Src1: ShiftAmt, Flags: MI.getFlags());
2236 Builder.buildZExt(Res: MI.getOperand(i: 0), Op: NarrowShift);
2237 MI.eraseFromParent();
2238}
2239
2240bool CombinerHelper::matchCombineMergeUnmerge(MachineInstr &MI,
2241 Register &MatchInfo) const {
2242 GMerge &Merge = cast<GMerge>(Val&: MI);
2243 SmallVector<Register, 16> MergedValues;
2244 for (unsigned I = 0; I < Merge.getNumSources(); ++I)
2245 MergedValues.emplace_back(Args: Merge.getSourceReg(I));
2246
2247 auto *Unmerge = getOpcodeDef<GUnmerge>(Reg: MergedValues[0], MRI);
2248 if (!Unmerge || Unmerge->getNumDefs() != Merge.getNumSources())
2249 return false;
2250
2251 for (unsigned I = 0; I < MergedValues.size(); ++I)
2252 if (MergedValues[I] != Unmerge->getReg(Idx: I))
2253 return false;
2254
2255 MatchInfo = Unmerge->getSourceReg();
2256 return true;
2257}
2258
2259static Register peekThroughBitcast(Register Reg,
2260 const MachineRegisterInfo &MRI) {
2261 while (mi_match(R: Reg, MRI, P: m_GBitcast(Src: m_Reg(R&: Reg))))
2262 ;
2263
2264 return Reg;
2265}
2266
2267bool CombinerHelper::matchCombineUnmergeMergeToPlainValues(
2268 MachineInstr &MI, SmallVectorImpl<Register> &Operands) const {
2269 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2270 "Expected an unmerge");
2271 auto &Unmerge = cast<GUnmerge>(Val&: MI);
2272 Register SrcReg = peekThroughBitcast(Reg: Unmerge.getSourceReg(), MRI);
2273
2274 auto *SrcInstr = getOpcodeDef<GMergeLikeInstr>(Reg: SrcReg, MRI);
2275 if (!SrcInstr)
2276 return false;
2277
2278 // Check the source type of the merge.
2279 LLT SrcMergeTy = MRI.getType(Reg: SrcInstr->getSourceReg(I: 0));
2280 LLT Dst0Ty = MRI.getType(Reg: Unmerge.getReg(Idx: 0));
2281 bool SameSize = Dst0Ty.getSizeInBits() == SrcMergeTy.getSizeInBits();
2282 if (SrcMergeTy != Dst0Ty && !SameSize)
2283 return false;
2284 // They are the same now (modulo a bitcast).
2285 // We can collect all the src registers.
2286 for (unsigned Idx = 0; Idx < SrcInstr->getNumSources(); ++Idx)
2287 Operands.push_back(Elt: SrcInstr->getSourceReg(I: Idx));
2288 return true;
2289}
2290
2291void CombinerHelper::applyCombineUnmergeMergeToPlainValues(
2292 MachineInstr &MI, SmallVectorImpl<Register> &Operands) const {
2293 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2294 "Expected an unmerge");
2295 assert((MI.getNumOperands() - 1 == Operands.size()) &&
2296 "Not enough operands to replace all defs");
2297 unsigned NumElems = MI.getNumOperands() - 1;
2298
2299 LLT SrcTy = MRI.getType(Reg: Operands[0]);
2300 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
2301 bool CanReuseInputDirectly = DstTy == SrcTy;
2302 for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
2303 Register DstReg = MI.getOperand(i: Idx).getReg();
2304 Register SrcReg = Operands[Idx];
2305
2306 // This combine may run after RegBankSelect, so we need to be aware of
2307 // register banks.
2308 const auto &DstCB = MRI.getRegClassOrRegBank(Reg: DstReg);
2309 if (!DstCB.isNull() && DstCB != MRI.getRegClassOrRegBank(Reg: SrcReg)) {
2310 SrcReg = Builder.buildCopy(Res: MRI.getType(Reg: SrcReg), Op: SrcReg).getReg(Idx: 0);
2311 MRI.setRegClassOrRegBank(Reg: SrcReg, RCOrRB: DstCB);
2312 }
2313
2314 if (CanReuseInputDirectly)
2315 replaceRegWith(MRI, FromReg: DstReg, ToReg: SrcReg);
2316 else
2317 Builder.buildCast(Dst: DstReg, Src: SrcReg);
2318 }
2319 MI.eraseFromParent();
2320}
2321
2322bool CombinerHelper::matchCombineUnmergeConstant(
2323 MachineInstr &MI, SmallVectorImpl<APInt> &Csts) const {
2324 unsigned SrcIdx = MI.getNumOperands() - 1;
2325 Register SrcReg = MI.getOperand(i: SrcIdx).getReg();
2326 MachineInstr *SrcInstr = MRI.getVRegDef(Reg: SrcReg);
2327 if (SrcInstr->getOpcode() != TargetOpcode::G_CONSTANT &&
2328 SrcInstr->getOpcode() != TargetOpcode::G_FCONSTANT)
2329 return false;
2330 // Break down the big constant in smaller ones.
2331 const MachineOperand &CstVal = SrcInstr->getOperand(i: 1);
2332 APInt Val = SrcInstr->getOpcode() == TargetOpcode::G_CONSTANT
2333 ? CstVal.getCImm()->getValue()
2334 : CstVal.getFPImm()->getValueAPF().bitcastToAPInt();
2335
2336 LLT Dst0Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
2337 unsigned ShiftAmt = Dst0Ty.getSizeInBits();
2338 // Unmerge a constant.
2339 for (unsigned Idx = 0; Idx != SrcIdx; ++Idx) {
2340 Csts.emplace_back(Args: Val.trunc(width: ShiftAmt));
2341 Val = Val.lshr(shiftAmt: ShiftAmt);
2342 }
2343
2344 return true;
2345}
2346
2347void CombinerHelper::applyCombineUnmergeConstant(
2348 MachineInstr &MI, SmallVectorImpl<APInt> &Csts) const {
2349 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2350 "Expected an unmerge");
2351 assert((MI.getNumOperands() - 1 == Csts.size()) &&
2352 "Not enough operands to replace all defs");
2353 unsigned NumElems = MI.getNumOperands() - 1;
2354 for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
2355 Register DstReg = MI.getOperand(i: Idx).getReg();
2356 Builder.buildConstant(Res: DstReg, Val: Csts[Idx]);
2357 }
2358
2359 MI.eraseFromParent();
2360}
2361
2362bool CombinerHelper::matchCombineUnmergeUndef(
2363 MachineInstr &MI,
2364 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
2365 unsigned SrcIdx = MI.getNumOperands() - 1;
2366 Register SrcReg = MI.getOperand(i: SrcIdx).getReg();
2367 MatchInfo = [&MI](MachineIRBuilder &B) {
2368 unsigned NumElems = MI.getNumOperands() - 1;
2369 for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
2370 Register DstReg = MI.getOperand(i: Idx).getReg();
2371 B.buildUndef(Res: DstReg);
2372 }
2373 };
2374 return isa<GImplicitDef>(Val: MRI.getVRegDef(Reg: SrcReg));
2375}
2376
2377bool CombinerHelper::matchCombineUnmergeWithDeadLanesToTrunc(
2378 MachineInstr &MI) const {
2379 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2380 "Expected an unmerge");
2381 if (MRI.getType(Reg: MI.getOperand(i: 0).getReg()).isVector() ||
2382 MRI.getType(Reg: MI.getOperand(i: MI.getNumDefs()).getReg()).isVector())
2383 return false;
2384 // Check that all the lanes are dead except the first one.
2385 for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) {
2386 if (!MRI.use_nodbg_empty(RegNo: MI.getOperand(i: Idx).getReg()))
2387 return false;
2388 }
2389 return true;
2390}
2391
2392void CombinerHelper::applyCombineUnmergeWithDeadLanesToTrunc(
2393 MachineInstr &MI) const {
2394 Register SrcReg = MI.getOperand(i: MI.getNumDefs()).getReg();
2395 Register Dst0Reg = MI.getOperand(i: 0).getReg();
2396 Builder.buildTrunc(Res: Dst0Reg, Op: SrcReg);
2397 MI.eraseFromParent();
2398}
2399
2400bool CombinerHelper::matchCombineUnmergeZExtToZExt(MachineInstr &MI) const {
2401 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2402 "Expected an unmerge");
2403 Register Dst0Reg = MI.getOperand(i: 0).getReg();
2404 LLT Dst0Ty = MRI.getType(Reg: Dst0Reg);
2405 // G_ZEXT on vector applies to each lane, so it will
2406 // affect all destinations. Therefore we won't be able
2407 // to simplify the unmerge to just the first definition.
2408 if (Dst0Ty.isVector())
2409 return false;
2410 Register SrcReg = MI.getOperand(i: MI.getNumDefs()).getReg();
2411 LLT SrcTy = MRI.getType(Reg: SrcReg);
2412 if (SrcTy.isVector())
2413 return false;
2414
2415 Register ZExtSrcReg;
2416 if (!mi_match(R: SrcReg, MRI, P: m_GZExt(Src: m_Reg(R&: ZExtSrcReg))))
2417 return false;
2418
2419 // Finally we can replace the first definition with
2420 // a zext of the source if the definition is big enough to hold
2421 // all of ZExtSrc bits.
2422 LLT ZExtSrcTy = MRI.getType(Reg: ZExtSrcReg);
2423 return ZExtSrcTy.getSizeInBits() <= Dst0Ty.getSizeInBits();
2424}
2425
2426void CombinerHelper::applyCombineUnmergeZExtToZExt(MachineInstr &MI) const {
2427 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2428 "Expected an unmerge");
2429
2430 Register Dst0Reg = MI.getOperand(i: 0).getReg();
2431
2432 MachineInstr *ZExtInstr =
2433 MRI.getVRegDef(Reg: MI.getOperand(i: MI.getNumDefs()).getReg());
2434 assert(ZExtInstr && ZExtInstr->getOpcode() == TargetOpcode::G_ZEXT &&
2435 "Expecting a G_ZEXT");
2436
2437 Register ZExtSrcReg = ZExtInstr->getOperand(i: 1).getReg();
2438 LLT Dst0Ty = MRI.getType(Reg: Dst0Reg);
2439 LLT ZExtSrcTy = MRI.getType(Reg: ZExtSrcReg);
2440
2441 if (Dst0Ty.getSizeInBits() > ZExtSrcTy.getSizeInBits()) {
2442 Builder.buildZExt(Res: Dst0Reg, Op: ZExtSrcReg);
2443 } else {
2444 assert(Dst0Ty.getSizeInBits() == ZExtSrcTy.getSizeInBits() &&
2445 "ZExt src doesn't fit in destination");
2446 replaceRegWith(MRI, FromReg: Dst0Reg, ToReg: ZExtSrcReg);
2447 }
2448
2449 Register ZeroReg;
2450 for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) {
2451 if (!ZeroReg)
2452 ZeroReg = Builder.buildConstant(Res: Dst0Ty, Val: 0).getReg(Idx: 0);
2453 replaceRegWith(MRI, FromReg: MI.getOperand(i: Idx).getReg(), ToReg: ZeroReg);
2454 }
2455 MI.eraseFromParent();
2456}
2457
2458bool CombinerHelper::matchCombineShiftToUnmerge(MachineInstr &MI,
2459 unsigned TargetShiftSize,
2460 unsigned &ShiftVal) const {
2461 assert((MI.getOpcode() == TargetOpcode::G_SHL ||
2462 MI.getOpcode() == TargetOpcode::G_LSHR ||
2463 MI.getOpcode() == TargetOpcode::G_ASHR) && "Expected a shift");
2464
2465 LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
2466 if (Ty.isVector()) // TODO:
2467 return false;
2468
2469 // Don't narrow further than the requested size.
2470 unsigned Size = Ty.getSizeInBits();
2471 if (Size <= TargetShiftSize)
2472 return false;
2473
2474 auto MaybeImmVal =
2475 getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI);
2476 if (!MaybeImmVal)
2477 return false;
2478
2479 ShiftVal = MaybeImmVal->Value.getSExtValue();
2480 return ShiftVal >= Size / 2 && ShiftVal < Size;
2481}
2482
2483void CombinerHelper::applyCombineShiftToUnmerge(
2484 MachineInstr &MI, const unsigned &ShiftVal) const {
2485 Register DstReg = MI.getOperand(i: 0).getReg();
2486 Register SrcReg = MI.getOperand(i: 1).getReg();
2487 LLT Ty = MRI.getType(Reg: SrcReg);
2488 unsigned Size = Ty.getSizeInBits();
2489 unsigned HalfSize = Size / 2;
2490 assert(ShiftVal >= HalfSize);
2491
2492 LLT HalfTy = LLT::scalar(SizeInBits: HalfSize);
2493
2494 auto Unmerge = Builder.buildUnmerge(Res: HalfTy, Op: SrcReg);
2495 unsigned NarrowShiftAmt = ShiftVal - HalfSize;
2496
2497 if (MI.getOpcode() == TargetOpcode::G_LSHR) {
2498 Register Narrowed = Unmerge.getReg(Idx: 1);
2499
2500 // dst = G_LSHR s64:x, C for C >= 32
2501 // =>
2502 // lo, hi = G_UNMERGE_VALUES x
2503 // dst = G_MERGE_VALUES (G_LSHR hi, C - 32), 0
2504
2505 if (NarrowShiftAmt != 0) {
2506 Narrowed = Builder.buildLShr(Dst: HalfTy, Src0: Narrowed,
2507 Src1: Builder.buildConstant(Res: HalfTy, Val: NarrowShiftAmt)).getReg(Idx: 0);
2508 }
2509
2510 auto Zero = Builder.buildConstant(Res: HalfTy, Val: 0);
2511 Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Narrowed, Zero});
2512 } else if (MI.getOpcode() == TargetOpcode::G_SHL) {
2513 Register Narrowed = Unmerge.getReg(Idx: 0);
2514 // dst = G_SHL s64:x, C for C >= 32
2515 // =>
2516 // lo, hi = G_UNMERGE_VALUES x
2517 // dst = G_MERGE_VALUES 0, (G_SHL hi, C - 32)
2518 if (NarrowShiftAmt != 0) {
2519 Narrowed = Builder.buildShl(Dst: HalfTy, Src0: Narrowed,
2520 Src1: Builder.buildConstant(Res: HalfTy, Val: NarrowShiftAmt)).getReg(Idx: 0);
2521 }
2522
2523 auto Zero = Builder.buildConstant(Res: HalfTy, Val: 0);
2524 Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Zero, Narrowed});
2525 } else {
2526 assert(MI.getOpcode() == TargetOpcode::G_ASHR);
2527 auto Hi = Builder.buildAShr(
2528 Dst: HalfTy, Src0: Unmerge.getReg(Idx: 1),
2529 Src1: Builder.buildConstant(Res: HalfTy, Val: HalfSize - 1));
2530
2531 if (ShiftVal == HalfSize) {
2532 // (G_ASHR i64:x, 32) ->
2533 // G_MERGE_VALUES hi_32(x), (G_ASHR hi_32(x), 31)
2534 Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Unmerge.getReg(Idx: 1), Hi});
2535 } else if (ShiftVal == Size - 1) {
2536 // Don't need a second shift.
2537 // (G_ASHR i64:x, 63) ->
2538 // %narrowed = (G_ASHR hi_32(x), 31)
2539 // G_MERGE_VALUES %narrowed, %narrowed
2540 Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Hi, Hi});
2541 } else {
2542 auto Lo = Builder.buildAShr(
2543 Dst: HalfTy, Src0: Unmerge.getReg(Idx: 1),
2544 Src1: Builder.buildConstant(Res: HalfTy, Val: ShiftVal - HalfSize));
2545
2546 // (G_ASHR i64:x, C) ->, for C >= 32
2547 // G_MERGE_VALUES (G_ASHR hi_32(x), C - 32), (G_ASHR hi_32(x), 31)
2548 Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Lo, Hi});
2549 }
2550 }
2551
2552 MI.eraseFromParent();
2553}
2554
2555bool CombinerHelper::tryCombineShiftToUnmerge(
2556 MachineInstr &MI, unsigned TargetShiftAmount) const {
2557 unsigned ShiftAmt;
2558 if (matchCombineShiftToUnmerge(MI, TargetShiftSize: TargetShiftAmount, ShiftVal&: ShiftAmt)) {
2559 applyCombineShiftToUnmerge(MI, ShiftVal: ShiftAmt);
2560 return true;
2561 }
2562
2563 return false;
2564}
2565
2566bool CombinerHelper::matchCombineI2PToP2I(MachineInstr &MI,
2567 Register &Reg) const {
2568 assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR");
2569 Register DstReg = MI.getOperand(i: 0).getReg();
2570 LLT DstTy = MRI.getType(Reg: DstReg);
2571 Register SrcReg = MI.getOperand(i: 1).getReg();
2572 return mi_match(R: SrcReg, MRI,
2573 P: m_GPtrToInt(Src: m_all_of(preds: m_SpecificType(Ty: DstTy), preds: m_Reg(R&: Reg))));
2574}
2575
2576void CombinerHelper::applyCombineI2PToP2I(MachineInstr &MI,
2577 Register &Reg) const {
2578 assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR");
2579 Register DstReg = MI.getOperand(i: 0).getReg();
2580 Builder.buildCopy(Res: DstReg, Op: Reg);
2581 MI.eraseFromParent();
2582}
2583
2584void CombinerHelper::applyCombineP2IToI2P(MachineInstr &MI,
2585 Register &Reg) const {
2586 assert(MI.getOpcode() == TargetOpcode::G_PTRTOINT && "Expected a G_PTRTOINT");
2587 Register DstReg = MI.getOperand(i: 0).getReg();
2588 Builder.buildZExtOrTrunc(Res: DstReg, Op: Reg);
2589 MI.eraseFromParent();
2590}
2591
2592bool CombinerHelper::matchCombineAddP2IToPtrAdd(
2593 MachineInstr &MI, std::pair<Register, bool> &PtrReg) const {
2594 assert(MI.getOpcode() == TargetOpcode::G_ADD);
2595 Register LHS = MI.getOperand(i: 1).getReg();
2596 Register RHS = MI.getOperand(i: 2).getReg();
2597 LLT IntTy = MRI.getType(Reg: LHS);
2598
2599 // G_PTR_ADD always has the pointer in the LHS, so we may need to commute the
2600 // instruction.
2601 PtrReg.second = false;
2602 for (Register SrcReg : {LHS, RHS}) {
2603 if (mi_match(R: SrcReg, MRI, P: m_GPtrToInt(Src: m_Reg(R&: PtrReg.first)))) {
2604 // Don't handle cases where the integer is implicitly converted to the
2605 // pointer width.
2606 LLT PtrTy = MRI.getType(Reg: PtrReg.first);
2607 if (PtrTy.getScalarSizeInBits() == IntTy.getScalarSizeInBits())
2608 return true;
2609 }
2610
2611 PtrReg.second = true;
2612 }
2613
2614 return false;
2615}
2616
2617void CombinerHelper::applyCombineAddP2IToPtrAdd(
2618 MachineInstr &MI, std::pair<Register, bool> &PtrReg) const {
2619 Register Dst = MI.getOperand(i: 0).getReg();
2620 Register LHS = MI.getOperand(i: 1).getReg();
2621 Register RHS = MI.getOperand(i: 2).getReg();
2622
2623 const bool DoCommute = PtrReg.second;
2624 if (DoCommute)
2625 std::swap(a&: LHS, b&: RHS);
2626 LHS = PtrReg.first;
2627
2628 LLT PtrTy = MRI.getType(Reg: LHS);
2629
2630 auto PtrAdd = Builder.buildPtrAdd(Res: PtrTy, Op0: LHS, Op1: RHS);
2631 Builder.buildPtrToInt(Dst, Src: PtrAdd);
2632 MI.eraseFromParent();
2633}
2634
2635bool CombinerHelper::matchCombineConstPtrAddToI2P(MachineInstr &MI,
2636 APInt &NewCst) const {
2637 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
2638 Register LHS = PtrAdd.getBaseReg();
2639 Register RHS = PtrAdd.getOffsetReg();
2640 MachineRegisterInfo &MRI = Builder.getMF().getRegInfo();
2641
2642 if (auto RHSCst = getIConstantVRegVal(VReg: RHS, MRI)) {
2643 APInt Cst;
2644 if (mi_match(R: LHS, MRI, P: m_GIntToPtr(Src: m_ICst(Cst)))) {
2645 auto DstTy = MRI.getType(Reg: PtrAdd.getReg(Idx: 0));
2646 // G_INTTOPTR uses zero-extension
2647 NewCst = Cst.zextOrTrunc(width: DstTy.getSizeInBits());
2648 NewCst += RHSCst->sextOrTrunc(width: DstTy.getSizeInBits());
2649 return true;
2650 }
2651 }
2652
2653 return false;
2654}
2655
2656void CombinerHelper::applyCombineConstPtrAddToI2P(MachineInstr &MI,
2657 APInt &NewCst) const {
2658 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
2659 Register Dst = PtrAdd.getReg(Idx: 0);
2660
2661 Builder.buildConstant(Res: Dst, Val: NewCst);
2662 PtrAdd.eraseFromParent();
2663}
2664
2665bool CombinerHelper::matchCombineAnyExtTrunc(MachineInstr &MI,
2666 Register &Reg) const {
2667 assert(MI.getOpcode() == TargetOpcode::G_ANYEXT && "Expected a G_ANYEXT");
2668 Register DstReg = MI.getOperand(i: 0).getReg();
2669 Register SrcReg = MI.getOperand(i: 1).getReg();
2670 Register OriginalSrcReg = getSrcRegIgnoringCopies(Reg: SrcReg, MRI);
2671 if (OriginalSrcReg.isValid())
2672 SrcReg = OriginalSrcReg;
2673 LLT DstTy = MRI.getType(Reg: DstReg);
2674 return mi_match(R: SrcReg, MRI,
2675 P: m_GTrunc(Src: m_all_of(preds: m_Reg(R&: Reg), preds: m_SpecificType(Ty: DstTy)))) &&
2676 canReplaceReg(DstReg, SrcReg: Reg, MRI);
2677}
2678
2679bool CombinerHelper::matchCombineZextTrunc(MachineInstr &MI,
2680 Register &Reg) const {
2681 assert(MI.getOpcode() == TargetOpcode::G_ZEXT && "Expected a G_ZEXT");
2682 Register DstReg = MI.getOperand(i: 0).getReg();
2683 Register SrcReg = MI.getOperand(i: 1).getReg();
2684 LLT DstTy = MRI.getType(Reg: DstReg);
2685 if (mi_match(R: SrcReg, MRI,
2686 P: m_GTrunc(Src: m_all_of(preds: m_Reg(R&: Reg), preds: m_SpecificType(Ty: DstTy)))) &&
2687 canReplaceReg(DstReg, SrcReg: Reg, MRI)) {
2688 unsigned DstSize = DstTy.getScalarSizeInBits();
2689 unsigned SrcSize = MRI.getType(Reg: SrcReg).getScalarSizeInBits();
2690 return VT->getKnownBits(R: Reg).countMinLeadingZeros() >= DstSize - SrcSize;
2691 }
2692 return false;
2693}
2694
2695static LLT getMidVTForTruncRightShiftCombine(LLT ShiftTy, LLT TruncTy) {
2696 const unsigned ShiftSize = ShiftTy.getScalarSizeInBits();
2697 const unsigned TruncSize = TruncTy.getScalarSizeInBits();
2698
2699 // ShiftTy > 32 > TruncTy -> 32
2700 if (ShiftSize > 32 && TruncSize < 32)
2701 return ShiftTy.changeElementSize(NewEltSize: 32);
2702
2703 // TODO: We could also reduce to 16 bits, but that's more target-dependent.
2704 // Some targets like it, some don't, some only like it under certain
2705 // conditions/processor versions, etc.
2706 // A TL hook might be needed for this.
2707
2708 // Don't combine
2709 return ShiftTy;
2710}
2711
2712bool CombinerHelper::matchCombineTruncOfShift(
2713 MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) const {
2714 assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC");
2715 Register DstReg = MI.getOperand(i: 0).getReg();
2716 Register SrcReg = MI.getOperand(i: 1).getReg();
2717
2718 if (!MRI.hasOneNonDBGUse(RegNo: SrcReg))
2719 return false;
2720
2721 LLT SrcTy = MRI.getType(Reg: SrcReg);
2722 LLT DstTy = MRI.getType(Reg: DstReg);
2723
2724 MachineInstr *SrcMI = getDefIgnoringCopies(Reg: SrcReg, MRI);
2725 const auto &TL = getTargetLowering();
2726
2727 LLT NewShiftTy;
2728 switch (SrcMI->getOpcode()) {
2729 default:
2730 return false;
2731 case TargetOpcode::G_SHL: {
2732 NewShiftTy = DstTy;
2733
2734 // Make sure new shift amount is legal.
2735 KnownBits Known = VT->getKnownBits(R: SrcMI->getOperand(i: 2).getReg());
2736 if (Known.getMaxValue().uge(RHS: NewShiftTy.getScalarSizeInBits()))
2737 return false;
2738 break;
2739 }
2740 case TargetOpcode::G_LSHR:
2741 case TargetOpcode::G_ASHR: {
2742 // For right shifts, we conservatively do not do the transform if the TRUNC
2743 // has any STORE users. The reason is that if we change the type of the
2744 // shift, we may break the truncstore combine.
2745 //
2746 // TODO: Fix truncstore combine to handle (trunc(lshr (trunc x), k)).
2747 for (auto &User : MRI.use_instructions(Reg: DstReg))
2748 if (User.getOpcode() == TargetOpcode::G_STORE)
2749 return false;
2750
2751 NewShiftTy = getMidVTForTruncRightShiftCombine(ShiftTy: SrcTy, TruncTy: DstTy);
2752 if (NewShiftTy == SrcTy)
2753 return false;
2754
2755 // Make sure we won't lose information by truncating the high bits.
2756 KnownBits Known = VT->getKnownBits(R: SrcMI->getOperand(i: 2).getReg());
2757 if (Known.getMaxValue().ugt(RHS: NewShiftTy.getScalarSizeInBits() -
2758 DstTy.getScalarSizeInBits()))
2759 return false;
2760 break;
2761 }
2762 }
2763
2764 if (!isLegalOrBeforeLegalizer(
2765 Query: {SrcMI->getOpcode(),
2766 {NewShiftTy, TL.getPreferredShiftAmountTy(ShiftValueTy: NewShiftTy)}}))
2767 return false;
2768
2769 MatchInfo = std::make_pair(x&: SrcMI, y&: NewShiftTy);
2770 return true;
2771}
2772
2773void CombinerHelper::applyCombineTruncOfShift(
2774 MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) const {
2775 MachineInstr *ShiftMI = MatchInfo.first;
2776 LLT NewShiftTy = MatchInfo.second;
2777
2778 Register Dst = MI.getOperand(i: 0).getReg();
2779 LLT DstTy = MRI.getType(Reg: Dst);
2780
2781 Register ShiftAmt = ShiftMI->getOperand(i: 2).getReg();
2782 Register ShiftSrc = ShiftMI->getOperand(i: 1).getReg();
2783 ShiftSrc = Builder.buildTrunc(Res: NewShiftTy, Op: ShiftSrc).getReg(Idx: 0);
2784
2785 Register NewShift =
2786 Builder
2787 .buildInstr(Opc: ShiftMI->getOpcode(), DstOps: {NewShiftTy}, SrcOps: {ShiftSrc, ShiftAmt})
2788 .getReg(Idx: 0);
2789
2790 if (NewShiftTy == DstTy)
2791 replaceRegWith(MRI, FromReg: Dst, ToReg: NewShift);
2792 else
2793 Builder.buildTrunc(Res: Dst, Op: NewShift);
2794
2795 eraseInst(MI);
2796}
2797
2798bool CombinerHelper::matchAnyExplicitUseIsUndef(MachineInstr &MI) const {
2799 return any_of(Range: MI.explicit_uses(), P: [this](const MachineOperand &MO) {
2800 return MO.isReg() &&
2801 getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI);
2802 });
2803}
2804
2805bool CombinerHelper::matchAllExplicitUsesAreUndef(MachineInstr &MI) const {
2806 return all_of(Range: MI.explicit_uses(), P: [this](const MachineOperand &MO) {
2807 return !MO.isReg() ||
2808 getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI);
2809 });
2810}
2811
2812bool CombinerHelper::matchUndefShuffleVectorMask(MachineInstr &MI) const {
2813 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
2814 ArrayRef<int> Mask = MI.getOperand(i: 3).getShuffleMask();
2815 return all_of(Range&: Mask, P: [](int Elt) { return Elt < 0; });
2816}
2817
2818bool CombinerHelper::matchUndefStore(MachineInstr &MI) const {
2819 assert(MI.getOpcode() == TargetOpcode::G_STORE);
2820 return getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MI.getOperand(i: 0).getReg(),
2821 MRI);
2822}
2823
2824bool CombinerHelper::matchUndefSelectCmp(MachineInstr &MI) const {
2825 assert(MI.getOpcode() == TargetOpcode::G_SELECT);
2826 return getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MI.getOperand(i: 1).getReg(),
2827 MRI);
2828}
2829
2830bool CombinerHelper::matchInsertExtractVecEltOutOfBounds(
2831 MachineInstr &MI) const {
2832 assert((MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT ||
2833 MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) &&
2834 "Expected an insert/extract element op");
2835 LLT VecTy = MRI.getType(Reg: MI.getOperand(i: 1).getReg());
2836 if (VecTy.isScalableVector())
2837 return false;
2838
2839 unsigned IdxIdx =
2840 MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT ? 2 : 3;
2841 auto Idx = getIConstantVRegVal(VReg: MI.getOperand(i: IdxIdx).getReg(), MRI);
2842 if (!Idx)
2843 return false;
2844 return Idx->getZExtValue() >= VecTy.getNumElements();
2845}
2846
2847bool CombinerHelper::matchConstantSelectCmp(MachineInstr &MI,
2848 unsigned &OpIdx) const {
2849 GSelect &SelMI = cast<GSelect>(Val&: MI);
2850 auto Cst =
2851 isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: SelMI.getCondReg()), MRI);
2852 if (!Cst)
2853 return false;
2854 OpIdx = Cst->isZero() ? 3 : 2;
2855 return true;
2856}
2857
2858void CombinerHelper::eraseInst(MachineInstr &MI) const { MI.eraseFromParent(); }
2859
2860bool CombinerHelper::matchEqualDefs(const MachineOperand &MOP1,
2861 const MachineOperand &MOP2) const {
2862 if (!MOP1.isReg() || !MOP2.isReg())
2863 return false;
2864 auto InstAndDef1 = getDefSrcRegIgnoringCopies(Reg: MOP1.getReg(), MRI);
2865 if (!InstAndDef1)
2866 return false;
2867 auto InstAndDef2 = getDefSrcRegIgnoringCopies(Reg: MOP2.getReg(), MRI);
2868 if (!InstAndDef2)
2869 return false;
2870 MachineInstr *I1 = InstAndDef1->MI;
2871 MachineInstr *I2 = InstAndDef2->MI;
2872
2873 // Handle a case like this:
2874 //
2875 // %0:_(s64), %1:_(s64) = G_UNMERGE_VALUES %2:_(<2 x s64>)
2876 //
2877 // Even though %0 and %1 are produced by the same instruction they are not
2878 // the same values.
2879 if (I1 == I2)
2880 return MOP1.getReg() == MOP2.getReg();
2881
2882 // If we have an instruction which loads or stores, we can't guarantee that
2883 // it is identical.
2884 //
2885 // For example, we may have
2886 //
2887 // %x1 = G_LOAD %addr (load N from @somewhere)
2888 // ...
2889 // call @foo
2890 // ...
2891 // %x2 = G_LOAD %addr (load N from @somewhere)
2892 // ...
2893 // %or = G_OR %x1, %x2
2894 //
2895 // It's possible that @foo will modify whatever lives at the address we're
2896 // loading from. To be safe, let's just assume that all loads and stores
2897 // are different (unless we have something which is guaranteed to not
2898 // change.)
2899 if (I1->mayLoadOrStore() && !I1->isDereferenceableInvariantLoad())
2900 return false;
2901
2902 // If both instructions are loads or stores, they are equal only if both
2903 // are dereferenceable invariant loads with the same number of bits.
2904 if (I1->mayLoadOrStore() && I2->mayLoadOrStore()) {
2905 GLoadStore *LS1 = dyn_cast<GLoadStore>(Val: I1);
2906 GLoadStore *LS2 = dyn_cast<GLoadStore>(Val: I2);
2907 if (!LS1 || !LS2)
2908 return false;
2909
2910 if (!I2->isDereferenceableInvariantLoad() ||
2911 (LS1->getMemSizeInBits() != LS2->getMemSizeInBits()))
2912 return false;
2913 }
2914
2915 // Check for physical registers on the instructions first to avoid cases
2916 // like this:
2917 //
2918 // %a = COPY $physreg
2919 // ...
2920 // SOMETHING implicit-def $physreg
2921 // ...
2922 // %b = COPY $physreg
2923 //
2924 // These copies are not equivalent.
2925 if (any_of(Range: I1->uses(), P: [](const MachineOperand &MO) {
2926 return MO.isReg() && MO.getReg().isPhysical();
2927 })) {
2928 // Check if we have a case like this:
2929 //
2930 // %a = COPY $physreg
2931 // %b = COPY %a
2932 //
2933 // In this case, I1 and I2 will both be equal to %a = COPY $physreg.
2934 // From that, we know that they must have the same value, since they must
2935 // have come from the same COPY.
2936 return I1->isIdenticalTo(Other: *I2);
2937 }
2938
2939 // We don't have any physical registers, so we don't necessarily need the
2940 // same vreg defs.
2941 //
2942 // On the off-chance that there's some target instruction feeding into the
2943 // instruction, let's use produceSameValue instead of isIdenticalTo.
2944 if (Builder.getTII().produceSameValue(MI0: *I1, MI1: *I2, MRI: &MRI)) {
2945 // Handle instructions with multiple defs that produce same values. Values
2946 // are same for operands with same index.
2947 // %0:_(s8), %1:_(s8), %2:_(s8), %3:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>)
2948 // %5:_(s8), %6:_(s8), %7:_(s8), %8:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>)
2949 // I1 and I2 are different instructions but produce same values,
2950 // %1 and %6 are same, %1 and %7 are not the same value.
2951 return I1->findRegisterDefOperandIdx(Reg: InstAndDef1->Reg, /*TRI=*/nullptr) ==
2952 I2->findRegisterDefOperandIdx(Reg: InstAndDef2->Reg, /*TRI=*/nullptr);
2953 }
2954 return false;
2955}
2956
2957bool CombinerHelper::matchConstantOp(const MachineOperand &MOP,
2958 int64_t C) const {
2959 if (!MOP.isReg())
2960 return false;
2961 auto *MI = MRI.getVRegDef(Reg: MOP.getReg());
2962 auto MaybeCst = isConstantOrConstantSplatVector(MI&: *MI, MRI);
2963 return MaybeCst && MaybeCst->getBitWidth() <= 64 &&
2964 MaybeCst->getSExtValue() == C;
2965}
2966
2967bool CombinerHelper::matchConstantFPOp(const MachineOperand &MOP,
2968 double C) const {
2969 if (!MOP.isReg())
2970 return false;
2971 std::optional<FPValueAndVReg> MaybeCst;
2972 if (!mi_match(R: MOP.getReg(), MRI, P: m_GFCstOrSplat(FPValReg&: MaybeCst)))
2973 return false;
2974
2975 return MaybeCst->Value.isExactlyValue(V: C);
2976}
2977
2978void CombinerHelper::replaceSingleDefInstWithOperand(MachineInstr &MI,
2979 unsigned OpIdx) const {
2980 assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?");
2981 Register OldReg = MI.getOperand(i: 0).getReg();
2982 Register Replacement = MI.getOperand(i: OpIdx).getReg();
2983 assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?");
2984 replaceRegWith(MRI, FromReg: OldReg, ToReg: Replacement);
2985 MI.eraseFromParent();
2986}
2987
2988void CombinerHelper::replaceSingleDefInstWithReg(MachineInstr &MI,
2989 Register Replacement) const {
2990 assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?");
2991 Register OldReg = MI.getOperand(i: 0).getReg();
2992 assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?");
2993 replaceRegWith(MRI, FromReg: OldReg, ToReg: Replacement);
2994 MI.eraseFromParent();
2995}
2996
2997bool CombinerHelper::matchConstantLargerBitWidth(MachineInstr &MI,
2998 unsigned ConstIdx) const {
2999 Register ConstReg = MI.getOperand(i: ConstIdx).getReg();
3000 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
3001
3002 // Get the shift amount
3003 auto VRegAndVal = getIConstantVRegValWithLookThrough(VReg: ConstReg, MRI);
3004 if (!VRegAndVal)
3005 return false;
3006
3007 // Return true of shift amount >= Bitwidth
3008 return (VRegAndVal->Value.uge(RHS: DstTy.getSizeInBits()));
3009}
3010
3011void CombinerHelper::applyFunnelShiftConstantModulo(MachineInstr &MI) const {
3012 assert((MI.getOpcode() == TargetOpcode::G_FSHL ||
3013 MI.getOpcode() == TargetOpcode::G_FSHR) &&
3014 "This is not a funnel shift operation");
3015
3016 Register ConstReg = MI.getOperand(i: 3).getReg();
3017 LLT ConstTy = MRI.getType(Reg: ConstReg);
3018 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
3019
3020 auto VRegAndVal = getIConstantVRegValWithLookThrough(VReg: ConstReg, MRI);
3021 assert((VRegAndVal) && "Value is not a constant");
3022
3023 // Calculate the new Shift Amount = Old Shift Amount % BitWidth
3024 APInt NewConst = VRegAndVal->Value.urem(
3025 RHS: APInt(ConstTy.getSizeInBits(), DstTy.getScalarSizeInBits()));
3026
3027 auto NewConstInstr = Builder.buildConstant(Res: ConstTy, Val: NewConst.getZExtValue());
3028 Builder.buildInstr(
3029 Opc: MI.getOpcode(), DstOps: {MI.getOperand(i: 0)},
3030 SrcOps: {MI.getOperand(i: 1), MI.getOperand(i: 2), NewConstInstr.getReg(Idx: 0)});
3031
3032 MI.eraseFromParent();
3033}
3034
3035bool CombinerHelper::matchSelectSameVal(MachineInstr &MI) const {
3036 assert(MI.getOpcode() == TargetOpcode::G_SELECT);
3037 // Match (cond ? x : x)
3038 return matchEqualDefs(MOP1: MI.getOperand(i: 2), MOP2: MI.getOperand(i: 3)) &&
3039 canReplaceReg(DstReg: MI.getOperand(i: 0).getReg(), SrcReg: MI.getOperand(i: 2).getReg(),
3040 MRI);
3041}
3042
3043bool CombinerHelper::matchBinOpSameVal(MachineInstr &MI) const {
3044 return matchEqualDefs(MOP1: MI.getOperand(i: 1), MOP2: MI.getOperand(i: 2)) &&
3045 canReplaceReg(DstReg: MI.getOperand(i: 0).getReg(), SrcReg: MI.getOperand(i: 1).getReg(),
3046 MRI);
3047}
3048
3049bool CombinerHelper::matchOperandIsUndef(MachineInstr &MI,
3050 unsigned OpIdx) const {
3051 MachineOperand &MO = MI.getOperand(i: OpIdx);
3052 return MO.isReg() &&
3053 getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI);
3054}
3055
3056bool CombinerHelper::matchOperandIsKnownToBeAPowerOfTwo(MachineInstr &MI,
3057 unsigned OpIdx) const {
3058 MachineOperand &MO = MI.getOperand(i: OpIdx);
3059 return isKnownToBeAPowerOfTwo(Val: MO.getReg(), MRI, ValueTracking: VT);
3060}
3061
3062void CombinerHelper::replaceInstWithFConstant(MachineInstr &MI,
3063 double C) const {
3064 assert(MI.getNumDefs() == 1 && "Expected only one def?");
3065 Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: C);
3066 MI.eraseFromParent();
3067}
3068
3069void CombinerHelper::replaceInstWithConstant(MachineInstr &MI,
3070 int64_t C) const {
3071 assert(MI.getNumDefs() == 1 && "Expected only one def?");
3072 Builder.buildConstant(Res: MI.getOperand(i: 0), Val: C);
3073 MI.eraseFromParent();
3074}
3075
3076void CombinerHelper::replaceInstWithConstant(MachineInstr &MI, APInt C) const {
3077 assert(MI.getNumDefs() == 1 && "Expected only one def?");
3078 Builder.buildConstant(Res: MI.getOperand(i: 0), Val: C);
3079 MI.eraseFromParent();
3080}
3081
3082void CombinerHelper::replaceInstWithFConstant(MachineInstr &MI,
3083 ConstantFP *CFP) const {
3084 assert(MI.getNumDefs() == 1 && "Expected only one def?");
3085 Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: CFP->getValueAPF());
3086 MI.eraseFromParent();
3087}
3088
3089void CombinerHelper::replaceInstWithUndef(MachineInstr &MI) const {
3090 assert(MI.getNumDefs() == 1 && "Expected only one def?");
3091 Builder.buildUndef(Res: MI.getOperand(i: 0));
3092 MI.eraseFromParent();
3093}
3094
3095bool CombinerHelper::matchSimplifyAddToSub(
3096 MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) const {
3097 Register LHS = MI.getOperand(i: 1).getReg();
3098 Register RHS = MI.getOperand(i: 2).getReg();
3099 Register &NewLHS = std::get<0>(t&: MatchInfo);
3100 Register &NewRHS = std::get<1>(t&: MatchInfo);
3101
3102 // Helper lambda to check for opportunities for
3103 // ((0-A) + B) -> B - A
3104 // (A + (0-B)) -> A - B
3105 auto CheckFold = [&](Register &MaybeSub, Register &MaybeNewLHS) {
3106 if (!mi_match(R: MaybeSub, MRI, P: m_Neg(Src: m_Reg(R&: NewRHS))))
3107 return false;
3108 NewLHS = MaybeNewLHS;
3109 return true;
3110 };
3111
3112 return CheckFold(LHS, RHS) || CheckFold(RHS, LHS);
3113}
3114
3115bool CombinerHelper::matchCombineInsertVecElts(
3116 MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) const {
3117 assert(MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT &&
3118 "Invalid opcode");
3119 Register DstReg = MI.getOperand(i: 0).getReg();
3120 LLT DstTy = MRI.getType(Reg: DstReg);
3121 assert(DstTy.isVector() && "Invalid G_INSERT_VECTOR_ELT?");
3122
3123 if (DstTy.isScalableVector())
3124 return false;
3125
3126 unsigned NumElts = DstTy.getNumElements();
3127 // If this MI is part of a sequence of insert_vec_elts, then
3128 // don't do the combine in the middle of the sequence.
3129 if (MRI.hasOneUse(RegNo: DstReg) && MRI.use_instr_begin(RegNo: DstReg)->getOpcode() ==
3130 TargetOpcode::G_INSERT_VECTOR_ELT)
3131 return false;
3132 MachineInstr *CurrInst = &MI;
3133 MachineInstr *TmpInst;
3134 int64_t IntImm;
3135 Register TmpReg;
3136 MatchInfo.resize(N: NumElts);
3137 while (mi_match(
3138 R: CurrInst->getOperand(i: 0).getReg(), MRI,
3139 P: m_GInsertVecElt(Src0: m_MInstr(MI&: TmpInst), Src1: m_Reg(R&: TmpReg), Src2: m_ICst(Cst&: IntImm)))) {
3140 if (IntImm >= NumElts || IntImm < 0)
3141 return false;
3142 if (!MatchInfo[IntImm])
3143 MatchInfo[IntImm] = TmpReg;
3144 CurrInst = TmpInst;
3145 }
3146 // Variable index.
3147 if (CurrInst->getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT)
3148 return false;
3149 if (TmpInst->getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
3150 for (unsigned I = 1; I < TmpInst->getNumOperands(); ++I) {
3151 if (!MatchInfo[I - 1].isValid())
3152 MatchInfo[I - 1] = TmpInst->getOperand(i: I).getReg();
3153 }
3154 return true;
3155 }
3156 // If we didn't end in a G_IMPLICIT_DEF and the source is not fully
3157 // overwritten, bail out.
3158 return TmpInst->getOpcode() == TargetOpcode::G_IMPLICIT_DEF ||
3159 all_of(Range&: MatchInfo, P: [](Register Reg) { return !!Reg; });
3160}
3161
3162void CombinerHelper::applyCombineInsertVecElts(
3163 MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) const {
3164 Register UndefReg;
3165 auto GetUndef = [&]() {
3166 if (UndefReg)
3167 return UndefReg;
3168 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
3169 UndefReg = Builder.buildUndef(Res: DstTy.getScalarType()).getReg(Idx: 0);
3170 return UndefReg;
3171 };
3172 for (Register &Reg : MatchInfo) {
3173 if (!Reg)
3174 Reg = GetUndef();
3175 }
3176 Builder.buildBuildVector(Res: MI.getOperand(i: 0).getReg(), Ops: MatchInfo);
3177 MI.eraseFromParent();
3178}
3179
3180void CombinerHelper::applySimplifyAddToSub(
3181 MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) const {
3182 Register SubLHS, SubRHS;
3183 std::tie(args&: SubLHS, args&: SubRHS) = MatchInfo;
3184 Builder.buildSub(Dst: MI.getOperand(i: 0).getReg(), Src0: SubLHS, Src1: SubRHS);
3185 MI.eraseFromParent();
3186}
3187
3188bool CombinerHelper::matchHoistLogicOpWithSameOpcodeHands(
3189 MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) const {
3190 // Matches: logic (hand x, ...), (hand y, ...) -> hand (logic x, y), ...
3191 //
3192 // Creates the new hand + logic instruction (but does not insert them.)
3193 //
3194 // On success, MatchInfo is populated with the new instructions. These are
3195 // inserted in applyHoistLogicOpWithSameOpcodeHands.
3196 unsigned LogicOpcode = MI.getOpcode();
3197 assert(LogicOpcode == TargetOpcode::G_AND ||
3198 LogicOpcode == TargetOpcode::G_OR ||
3199 LogicOpcode == TargetOpcode::G_XOR);
3200 MachineIRBuilder MIB(MI);
3201 Register Dst = MI.getOperand(i: 0).getReg();
3202 Register LHSReg = MI.getOperand(i: 1).getReg();
3203 Register RHSReg = MI.getOperand(i: 2).getReg();
3204
3205 // Don't recompute anything.
3206 if (!MRI.hasOneNonDBGUse(RegNo: LHSReg) || !MRI.hasOneNonDBGUse(RegNo: RHSReg))
3207 return false;
3208
3209 // Make sure we have (hand x, ...), (hand y, ...)
3210 MachineInstr *LeftHandInst = getDefIgnoringCopies(Reg: LHSReg, MRI);
3211 MachineInstr *RightHandInst = getDefIgnoringCopies(Reg: RHSReg, MRI);
3212 if (!LeftHandInst || !RightHandInst)
3213 return false;
3214 unsigned HandOpcode = LeftHandInst->getOpcode();
3215 if (HandOpcode != RightHandInst->getOpcode())
3216 return false;
3217 if (LeftHandInst->getNumOperands() < 2 ||
3218 !LeftHandInst->getOperand(i: 1).isReg() ||
3219 RightHandInst->getNumOperands() < 2 ||
3220 !RightHandInst->getOperand(i: 1).isReg())
3221 return false;
3222
3223 // Make sure the types match up, and if we're doing this post-legalization,
3224 // we end up with legal types.
3225 Register X = LeftHandInst->getOperand(i: 1).getReg();
3226 Register Y = RightHandInst->getOperand(i: 1).getReg();
3227 LLT XTy = MRI.getType(Reg: X);
3228 LLT YTy = MRI.getType(Reg: Y);
3229 if (!XTy.isValid() || XTy != YTy)
3230 return false;
3231
3232 // Optional extra source register.
3233 Register ExtraHandOpSrcReg;
3234 switch (HandOpcode) {
3235 default:
3236 return false;
3237 case TargetOpcode::G_ANYEXT:
3238 case TargetOpcode::G_SEXT:
3239 case TargetOpcode::G_ZEXT: {
3240 // Match: logic (ext X), (ext Y) --> ext (logic X, Y)
3241 break;
3242 }
3243 case TargetOpcode::G_TRUNC: {
3244 // Match: logic (trunc X), (trunc Y) -> trunc (logic X, Y)
3245 const MachineFunction *MF = MI.getMF();
3246 LLVMContext &Ctx = MF->getFunction().getContext();
3247
3248 LLT DstTy = MRI.getType(Reg: Dst);
3249 const TargetLowering &TLI = getTargetLowering();
3250
3251 // Be extra careful sinking truncate. If it's free, there's no benefit in
3252 // widening a binop.
3253 if (TLI.isZExtFree(FromTy: DstTy, ToTy: XTy, Ctx) && TLI.isTruncateFree(FromTy: XTy, ToTy: DstTy, Ctx))
3254 return false;
3255 break;
3256 }
3257 case TargetOpcode::G_AND:
3258 case TargetOpcode::G_ASHR:
3259 case TargetOpcode::G_LSHR:
3260 case TargetOpcode::G_SHL: {
3261 // Match: logic (binop x, z), (binop y, z) -> binop (logic x, y), z
3262 MachineOperand &ZOp = LeftHandInst->getOperand(i: 2);
3263 if (!matchEqualDefs(MOP1: ZOp, MOP2: RightHandInst->getOperand(i: 2)))
3264 return false;
3265 ExtraHandOpSrcReg = ZOp.getReg();
3266 break;
3267 }
3268 }
3269
3270 if (!isLegalOrBeforeLegalizer(Query: {LogicOpcode, {XTy, YTy}}))
3271 return false;
3272
3273 // Record the steps to build the new instructions.
3274 //
3275 // Steps to build (logic x, y)
3276 auto NewLogicDst = MRI.createGenericVirtualRegister(Ty: XTy);
3277 OperandBuildSteps LogicBuildSteps = {
3278 [=](MachineInstrBuilder &MIB) { MIB.addDef(RegNo: NewLogicDst); },
3279 [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: X); },
3280 [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: Y); }};
3281 InstructionBuildSteps LogicSteps(LogicOpcode, LogicBuildSteps);
3282
3283 // Steps to build hand (logic x, y), ...z
3284 OperandBuildSteps HandBuildSteps = {
3285 [=](MachineInstrBuilder &MIB) { MIB.addDef(RegNo: Dst); },
3286 [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: NewLogicDst); }};
3287 if (ExtraHandOpSrcReg.isValid())
3288 HandBuildSteps.push_back(
3289 Elt: [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: ExtraHandOpSrcReg); });
3290 InstructionBuildSteps HandSteps(HandOpcode, HandBuildSteps);
3291
3292 MatchInfo = InstructionStepsMatchInfo({LogicSteps, HandSteps});
3293 return true;
3294}
3295
3296void CombinerHelper::applyBuildInstructionSteps(
3297 MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) const {
3298 assert(MatchInfo.InstrsToBuild.size() &&
3299 "Expected at least one instr to build?");
3300 for (auto &InstrToBuild : MatchInfo.InstrsToBuild) {
3301 assert(InstrToBuild.Opcode && "Expected a valid opcode?");
3302 assert(InstrToBuild.OperandFns.size() && "Expected at least one operand?");
3303 MachineInstrBuilder Instr = Builder.buildInstr(Opcode: InstrToBuild.Opcode);
3304 for (auto &OperandFn : InstrToBuild.OperandFns)
3305 OperandFn(Instr);
3306 }
3307 MI.eraseFromParent();
3308}
3309
3310bool CombinerHelper::matchAshrShlToSextInreg(
3311 MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) const {
3312 assert(MI.getOpcode() == TargetOpcode::G_ASHR);
3313 int64_t ShlCst, AshrCst;
3314 Register Src;
3315 if (!mi_match(R: MI.getOperand(i: 0).getReg(), MRI,
3316 P: m_GAShr(L: m_GShl(L: m_Reg(R&: Src), R: m_ICstOrSplat(Cst&: ShlCst)),
3317 R: m_ICstOrSplat(Cst&: AshrCst))))
3318 return false;
3319 if (ShlCst != AshrCst)
3320 return false;
3321 if (!isLegalOrBeforeLegalizer(
3322 Query: {TargetOpcode::G_SEXT_INREG, {MRI.getType(Reg: Src)}}))
3323 return false;
3324 MatchInfo = std::make_tuple(args&: Src, args&: ShlCst);
3325 return true;
3326}
3327
3328void CombinerHelper::applyAshShlToSextInreg(
3329 MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) const {
3330 assert(MI.getOpcode() == TargetOpcode::G_ASHR);
3331 Register Src;
3332 int64_t ShiftAmt;
3333 std::tie(args&: Src, args&: ShiftAmt) = MatchInfo;
3334 unsigned Size = MRI.getType(Reg: Src).getScalarSizeInBits();
3335 Builder.buildSExtInReg(Res: MI.getOperand(i: 0).getReg(), Op: Src, ImmOp: Size - ShiftAmt);
3336 MI.eraseFromParent();
3337}
3338
3339/// and(and(x, C1), C2) -> C1&C2 ? and(x, C1&C2) : 0
3340bool CombinerHelper::matchOverlappingAnd(
3341 MachineInstr &MI,
3342 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
3343 assert(MI.getOpcode() == TargetOpcode::G_AND);
3344
3345 Register Dst = MI.getOperand(i: 0).getReg();
3346 LLT Ty = MRI.getType(Reg: Dst);
3347
3348 Register R;
3349 int64_t C1;
3350 int64_t C2;
3351 if (!mi_match(
3352 R: Dst, MRI,
3353 P: m_GAnd(L: m_GAnd(L: m_Reg(R), R: m_ICst(Cst&: C1)), R: m_ICst(Cst&: C2))))
3354 return false;
3355
3356 MatchInfo = [=](MachineIRBuilder &B) {
3357 if (C1 & C2) {
3358 B.buildAnd(Dst, Src0: R, Src1: B.buildConstant(Res: Ty, Val: C1 & C2));
3359 return;
3360 }
3361 auto Zero = B.buildConstant(Res: Ty, Val: 0);
3362 replaceRegWith(MRI, FromReg: Dst, ToReg: Zero->getOperand(i: 0).getReg());
3363 };
3364 return true;
3365}
3366
3367bool CombinerHelper::matchRedundantAnd(MachineInstr &MI,
3368 Register &Replacement) const {
3369 // Given
3370 //
3371 // %y:_(sN) = G_SOMETHING
3372 // %x:_(sN) = G_SOMETHING
3373 // %res:_(sN) = G_AND %x, %y
3374 //
3375 // Eliminate the G_AND when it is known that x & y == x or x & y == y.
3376 //
3377 // Patterns like this can appear as a result of legalization. E.g.
3378 //
3379 // %cmp:_(s32) = G_ICMP intpred(pred), %x(s32), %y
3380 // %one:_(s32) = G_CONSTANT i32 1
3381 // %and:_(s32) = G_AND %cmp, %one
3382 //
3383 // In this case, G_ICMP only produces a single bit, so x & 1 == x.
3384 assert(MI.getOpcode() == TargetOpcode::G_AND);
3385 if (!VT)
3386 return false;
3387
3388 Register AndDst = MI.getOperand(i: 0).getReg();
3389 Register LHS = MI.getOperand(i: 1).getReg();
3390 Register RHS = MI.getOperand(i: 2).getReg();
3391
3392 // Check the RHS (maybe a constant) first, and if we have no KnownBits there,
3393 // we can't do anything. If we do, then it depends on whether we have
3394 // KnownBits on the LHS.
3395 KnownBits RHSBits = VT->getKnownBits(R: RHS);
3396 if (RHSBits.isUnknown())
3397 return false;
3398
3399 KnownBits LHSBits = VT->getKnownBits(R: LHS);
3400
3401 // Check that x & Mask == x.
3402 // x & 1 == x, always
3403 // x & 0 == x, only if x is also 0
3404 // Meaning Mask has no effect if every bit is either one in Mask or zero in x.
3405 //
3406 // Check if we can replace AndDst with the LHS of the G_AND
3407 if (canReplaceReg(DstReg: AndDst, SrcReg: LHS, MRI) &&
3408 (LHSBits.Zero | RHSBits.One).isAllOnes()) {
3409 Replacement = LHS;
3410 return true;
3411 }
3412
3413 // Check if we can replace AndDst with the RHS of the G_AND
3414 if (canReplaceReg(DstReg: AndDst, SrcReg: RHS, MRI) &&
3415 (LHSBits.One | RHSBits.Zero).isAllOnes()) {
3416 Replacement = RHS;
3417 return true;
3418 }
3419
3420 return false;
3421}
3422
3423bool CombinerHelper::matchRedundantOr(MachineInstr &MI,
3424 Register &Replacement) const {
3425 // Given
3426 //
3427 // %y:_(sN) = G_SOMETHING
3428 // %x:_(sN) = G_SOMETHING
3429 // %res:_(sN) = G_OR %x, %y
3430 //
3431 // Eliminate the G_OR when it is known that x | y == x or x | y == y.
3432 assert(MI.getOpcode() == TargetOpcode::G_OR);
3433 if (!VT)
3434 return false;
3435
3436 Register OrDst = MI.getOperand(i: 0).getReg();
3437 Register LHS = MI.getOperand(i: 1).getReg();
3438 Register RHS = MI.getOperand(i: 2).getReg();
3439
3440 KnownBits LHSBits = VT->getKnownBits(R: LHS);
3441 KnownBits RHSBits = VT->getKnownBits(R: RHS);
3442
3443 // Check that x | Mask == x.
3444 // x | 0 == x, always
3445 // x | 1 == x, only if x is also 1
3446 // Meaning Mask has no effect if every bit is either zero in Mask or one in x.
3447 //
3448 // Check if we can replace OrDst with the LHS of the G_OR
3449 if (canReplaceReg(DstReg: OrDst, SrcReg: LHS, MRI) &&
3450 (LHSBits.One | RHSBits.Zero).isAllOnes()) {
3451 Replacement = LHS;
3452 return true;
3453 }
3454
3455 // Check if we can replace OrDst with the RHS of the G_OR
3456 if (canReplaceReg(DstReg: OrDst, SrcReg: RHS, MRI) &&
3457 (LHSBits.Zero | RHSBits.One).isAllOnes()) {
3458 Replacement = RHS;
3459 return true;
3460 }
3461
3462 return false;
3463}
3464
3465bool CombinerHelper::matchRedundantSExtInReg(MachineInstr &MI) const {
3466 // If the input is already sign extended, just drop the extension.
3467 Register Src = MI.getOperand(i: 1).getReg();
3468 unsigned ExtBits = MI.getOperand(i: 2).getImm();
3469 unsigned TypeSize = MRI.getType(Reg: Src).getScalarSizeInBits();
3470 return VT->computeNumSignBits(R: Src) >= (TypeSize - ExtBits + 1);
3471}
3472
3473static bool isConstValidTrue(const TargetLowering &TLI, unsigned ScalarSizeBits,
3474 int64_t Cst, bool IsVector, bool IsFP) {
3475 // For i1, Cst will always be -1 regardless of boolean contents.
3476 return (ScalarSizeBits == 1 && Cst == -1) ||
3477 isConstTrueVal(TLI, Val: Cst, IsVector, IsFP);
3478}
3479
3480// This pattern aims to match the following shape to avoid extra mov
3481// instructions
3482// G_BUILD_VECTOR(
3483// G_UNMERGE_VALUES(src, 0)
3484// G_UNMERGE_VALUES(src, 1)
3485// G_IMPLICIT_DEF
3486// G_IMPLICIT_DEF
3487// )
3488// ->
3489// G_CONCAT_VECTORS(
3490// src,
3491// undef
3492// )
3493bool CombinerHelper::matchCombineBuildUnmerge(MachineInstr &MI,
3494 MachineRegisterInfo &MRI,
3495 Register &UnmergeSrc) const {
3496 auto &BV = cast<GBuildVector>(Val&: MI);
3497
3498 unsigned BuildUseCount = BV.getNumSources();
3499 if (BuildUseCount % 2 != 0)
3500 return false;
3501
3502 unsigned NumUnmerge = BuildUseCount / 2;
3503
3504 auto *Unmerge = getOpcodeDef<GUnmerge>(Reg: BV.getSourceReg(I: 0), MRI);
3505
3506 // Check the first operand is an unmerge and has the correct number of
3507 // operands
3508 if (!Unmerge || Unmerge->getNumDefs() != NumUnmerge)
3509 return false;
3510
3511 UnmergeSrc = Unmerge->getSourceReg();
3512
3513 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
3514 LLT UnmergeSrcTy = MRI.getType(Reg: UnmergeSrc);
3515
3516 if (!UnmergeSrcTy.isVector())
3517 return false;
3518
3519 // Ensure we only generate legal instructions post-legalizer
3520 if (!IsPreLegalize &&
3521 !isLegal(Query: {TargetOpcode::G_CONCAT_VECTORS, {DstTy, UnmergeSrcTy}}))
3522 return false;
3523
3524 // Check that all of the operands before the midpoint come from the same
3525 // unmerge and are in the same order as they are used in the build_vector
3526 for (unsigned I = 0; I < NumUnmerge; ++I) {
3527 auto MaybeUnmergeReg = BV.getSourceReg(I);
3528 auto *LoopUnmerge = getOpcodeDef<GUnmerge>(Reg: MaybeUnmergeReg, MRI);
3529
3530 if (!LoopUnmerge || LoopUnmerge != Unmerge)
3531 return false;
3532
3533 if (LoopUnmerge->getOperand(i: I).getReg() != MaybeUnmergeReg)
3534 return false;
3535 }
3536
3537 // Check that all of the unmerged values are used
3538 if (Unmerge->getNumDefs() != NumUnmerge)
3539 return false;
3540
3541 // Check that all of the operands after the mid point are undefs.
3542 for (unsigned I = NumUnmerge; I < BuildUseCount; ++I) {
3543 auto *Undef = getDefIgnoringCopies(Reg: BV.getSourceReg(I), MRI);
3544
3545 if (Undef->getOpcode() != TargetOpcode::G_IMPLICIT_DEF)
3546 return false;
3547 }
3548
3549 return true;
3550}
3551
3552void CombinerHelper::applyCombineBuildUnmerge(MachineInstr &MI,
3553 MachineRegisterInfo &MRI,
3554 MachineIRBuilder &B,
3555 Register &UnmergeSrc) const {
3556 assert(UnmergeSrc && "Expected there to be one matching G_UNMERGE_VALUES");
3557 B.setInstrAndDebugLoc(MI);
3558
3559 Register UndefVec = B.buildUndef(Res: MRI.getType(Reg: UnmergeSrc)).getReg(Idx: 0);
3560 B.buildConcatVectors(Res: MI.getOperand(i: 0), Ops: {UnmergeSrc, UndefVec});
3561
3562 MI.eraseFromParent();
3563}
3564
3565// This combine tries to reduce the number of scalarised G_TRUNC instructions by
3566// using vector truncates instead
3567//
3568// EXAMPLE:
3569// %a(i32), %b(i32) = G_UNMERGE_VALUES %src(<2 x i32>)
3570// %T_a(i16) = G_TRUNC %a(i32)
3571// %T_b(i16) = G_TRUNC %b(i32)
3572// %Undef(i16) = G_IMPLICIT_DEF(i16)
3573// %dst(v4i16) = G_BUILD_VECTORS %T_a(i16), %T_b(i16), %Undef(i16), %Undef(i16)
3574//
3575// ===>
3576// %Undef(<2 x i32>) = G_IMPLICIT_DEF(<2 x i32>)
3577// %Mid(<4 x s32>) = G_CONCAT_VECTORS %src(<2 x i32>), %Undef(<2 x i32>)
3578// %dst(<4 x s16>) = G_TRUNC %Mid(<4 x s32>)
3579//
3580// Only matches sources made up of G_TRUNCs followed by G_IMPLICIT_DEFs
3581bool CombinerHelper::matchUseVectorTruncate(MachineInstr &MI,
3582 Register &MatchInfo) const {
3583 auto BuildMI = cast<GBuildVector>(Val: &MI);
3584 unsigned NumOperands = BuildMI->getNumSources();
3585 LLT DstTy = MRI.getType(Reg: BuildMI->getReg(Idx: 0));
3586
3587 // Check the G_BUILD_VECTOR sources
3588 unsigned I;
3589 MachineInstr *UnmergeMI = nullptr;
3590
3591 // Check all source TRUNCs come from the same UNMERGE instruction
3592 for (I = 0; I < NumOperands; ++I) {
3593 auto SrcMI = MRI.getVRegDef(Reg: BuildMI->getSourceReg(I));
3594 auto SrcMIOpc = SrcMI->getOpcode();
3595
3596 // Check if the G_TRUNC instructions all come from the same MI
3597 if (SrcMIOpc == TargetOpcode::G_TRUNC) {
3598 if (!UnmergeMI) {
3599 UnmergeMI = MRI.getVRegDef(Reg: SrcMI->getOperand(i: 1).getReg());
3600 if (UnmergeMI->getOpcode() != TargetOpcode::G_UNMERGE_VALUES)
3601 return false;
3602 } else {
3603 auto UnmergeSrcMI = MRI.getVRegDef(Reg: SrcMI->getOperand(i: 1).getReg());
3604 if (UnmergeMI != UnmergeSrcMI)
3605 return false;
3606 }
3607 } else {
3608 break;
3609 }
3610 }
3611 if (I < 2)
3612 return false;
3613
3614 // Check the remaining source elements are only G_IMPLICIT_DEF
3615 for (; I < NumOperands; ++I) {
3616 auto SrcMI = MRI.getVRegDef(Reg: BuildMI->getSourceReg(I));
3617 auto SrcMIOpc = SrcMI->getOpcode();
3618
3619 if (SrcMIOpc != TargetOpcode::G_IMPLICIT_DEF)
3620 return false;
3621 }
3622
3623 // Check the size of unmerge source
3624 MatchInfo = cast<GUnmerge>(Val: UnmergeMI)->getSourceReg();
3625 LLT UnmergeSrcTy = MRI.getType(Reg: MatchInfo);
3626 if (!DstTy.getElementCount().isKnownMultipleOf(RHS: UnmergeSrcTy.getNumElements()))
3627 return false;
3628
3629 // Check the unmerge source and destination element types match
3630 LLT UnmergeSrcEltTy = UnmergeSrcTy.getElementType();
3631 Register UnmergeDstReg = UnmergeMI->getOperand(i: 0).getReg();
3632 LLT UnmergeDstEltTy = MRI.getType(Reg: UnmergeDstReg);
3633 if (UnmergeSrcEltTy != UnmergeDstEltTy)
3634 return false;
3635
3636 // Only generate legal instructions post-legalizer
3637 if (!IsPreLegalize) {
3638 LLT MidTy = DstTy.changeElementType(NewEltTy: UnmergeSrcTy.getScalarType());
3639
3640 if (DstTy.getElementCount() != UnmergeSrcTy.getElementCount() &&
3641 !isLegal(Query: {TargetOpcode::G_CONCAT_VECTORS, {MidTy, UnmergeSrcTy}}))
3642 return false;
3643
3644 if (!isLegal(Query: {TargetOpcode::G_TRUNC, {DstTy, MidTy}}))
3645 return false;
3646 }
3647
3648 return true;
3649}
3650
3651void CombinerHelper::applyUseVectorTruncate(MachineInstr &MI,
3652 Register &MatchInfo) const {
3653 Register MidReg;
3654 auto BuildMI = cast<GBuildVector>(Val: &MI);
3655 Register DstReg = BuildMI->getReg(Idx: 0);
3656 LLT DstTy = MRI.getType(Reg: DstReg);
3657 LLT UnmergeSrcTy = MRI.getType(Reg: MatchInfo);
3658 unsigned DstTyNumElt = DstTy.getNumElements();
3659 unsigned UnmergeSrcTyNumElt = UnmergeSrcTy.getNumElements();
3660
3661 // No need to pad vector if only G_TRUNC is needed
3662 if (DstTyNumElt / UnmergeSrcTyNumElt == 1) {
3663 MidReg = MatchInfo;
3664 } else {
3665 Register UndefReg = Builder.buildUndef(Res: UnmergeSrcTy).getReg(Idx: 0);
3666 SmallVector<Register> ConcatRegs = {MatchInfo};
3667 for (unsigned I = 1; I < DstTyNumElt / UnmergeSrcTyNumElt; ++I)
3668 ConcatRegs.push_back(Elt: UndefReg);
3669
3670 auto MidTy = DstTy.changeElementType(NewEltTy: UnmergeSrcTy.getScalarType());
3671 MidReg = Builder.buildConcatVectors(Res: MidTy, Ops: ConcatRegs).getReg(Idx: 0);
3672 }
3673
3674 Builder.buildTrunc(Res: DstReg, Op: MidReg);
3675 MI.eraseFromParent();
3676}
3677
3678bool CombinerHelper::matchNotCmp(
3679 MachineInstr &MI, SmallVectorImpl<Register> &RegsToNegate) const {
3680 assert(MI.getOpcode() == TargetOpcode::G_XOR);
3681 LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
3682 const auto &TLI = *Builder.getMF().getSubtarget().getTargetLowering();
3683 Register XorSrc;
3684 Register CstReg;
3685 // We match xor(src, true) here.
3686 if (!mi_match(R: MI.getOperand(i: 0).getReg(), MRI,
3687 P: m_GXor(L: m_Reg(R&: XorSrc), R: m_Reg(R&: CstReg))))
3688 return false;
3689
3690 if (!MRI.hasOneNonDBGUse(RegNo: XorSrc))
3691 return false;
3692
3693 // Check that XorSrc is the root of a tree of comparisons combined with ANDs
3694 // and ORs. The suffix of RegsToNegate starting from index I is used a work
3695 // list of tree nodes to visit.
3696 RegsToNegate.push_back(Elt: XorSrc);
3697 // Remember whether the comparisons are all integer or all floating point.
3698 bool IsInt = false;
3699 bool IsFP = false;
3700 for (unsigned I = 0; I < RegsToNegate.size(); ++I) {
3701 Register Reg = RegsToNegate[I];
3702 if (!MRI.hasOneNonDBGUse(RegNo: Reg))
3703 return false;
3704 MachineInstr *Def = MRI.getVRegDef(Reg);
3705 switch (Def->getOpcode()) {
3706 default:
3707 // Don't match if the tree contains anything other than ANDs, ORs and
3708 // comparisons.
3709 return false;
3710 case TargetOpcode::G_ICMP:
3711 if (IsFP)
3712 return false;
3713 IsInt = true;
3714 // When we apply the combine we will invert the predicate.
3715 break;
3716 case TargetOpcode::G_FCMP:
3717 if (IsInt)
3718 return false;
3719 IsFP = true;
3720 // When we apply the combine we will invert the predicate.
3721 break;
3722 case TargetOpcode::G_AND:
3723 case TargetOpcode::G_OR:
3724 // Implement De Morgan's laws:
3725 // ~(x & y) -> ~x | ~y
3726 // ~(x | y) -> ~x & ~y
3727 // When we apply the combine we will change the opcode and recursively
3728 // negate the operands.
3729 RegsToNegate.push_back(Elt: Def->getOperand(i: 1).getReg());
3730 RegsToNegate.push_back(Elt: Def->getOperand(i: 2).getReg());
3731 break;
3732 }
3733 }
3734
3735 // Now we know whether the comparisons are integer or floating point, check
3736 // the constant in the xor.
3737 int64_t Cst;
3738 if (Ty.isVector()) {
3739 MachineInstr *CstDef = MRI.getVRegDef(Reg: CstReg);
3740 auto MaybeCst = getIConstantSplatSExtVal(MI: *CstDef, MRI);
3741 if (!MaybeCst)
3742 return false;
3743 if (!isConstValidTrue(TLI, ScalarSizeBits: Ty.getScalarSizeInBits(), Cst: *MaybeCst, IsVector: true, IsFP))
3744 return false;
3745 } else {
3746 if (!mi_match(R: CstReg, MRI, P: m_ICst(Cst)))
3747 return false;
3748 if (!isConstValidTrue(TLI, ScalarSizeBits: Ty.getSizeInBits(), Cst, IsVector: false, IsFP))
3749 return false;
3750 }
3751
3752 return true;
3753}
3754
3755void CombinerHelper::applyNotCmp(
3756 MachineInstr &MI, SmallVectorImpl<Register> &RegsToNegate) const {
3757 for (Register Reg : RegsToNegate) {
3758 MachineInstr *Def = MRI.getVRegDef(Reg);
3759 Observer.changingInstr(MI&: *Def);
3760 // For each comparison, invert the opcode. For each AND and OR, change the
3761 // opcode.
3762 switch (Def->getOpcode()) {
3763 default:
3764 llvm_unreachable("Unexpected opcode");
3765 case TargetOpcode::G_ICMP:
3766 case TargetOpcode::G_FCMP: {
3767 MachineOperand &PredOp = Def->getOperand(i: 1);
3768 CmpInst::Predicate NewP = CmpInst::getInversePredicate(
3769 pred: (CmpInst::Predicate)PredOp.getPredicate());
3770 PredOp.setPredicate(NewP);
3771 break;
3772 }
3773 case TargetOpcode::G_AND:
3774 Def->setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_OR));
3775 break;
3776 case TargetOpcode::G_OR:
3777 Def->setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_AND));
3778 break;
3779 }
3780 Observer.changedInstr(MI&: *Def);
3781 }
3782
3783 replaceRegWith(MRI, FromReg: MI.getOperand(i: 0).getReg(), ToReg: MI.getOperand(i: 1).getReg());
3784 MI.eraseFromParent();
3785}
3786
3787bool CombinerHelper::matchXorOfAndWithSameReg(
3788 MachineInstr &MI, std::pair<Register, Register> &MatchInfo) const {
3789 // Match (xor (and x, y), y) (or any of its commuted cases)
3790 assert(MI.getOpcode() == TargetOpcode::G_XOR);
3791 Register &X = MatchInfo.first;
3792 Register &Y = MatchInfo.second;
3793 Register AndReg = MI.getOperand(i: 1).getReg();
3794 Register SharedReg = MI.getOperand(i: 2).getReg();
3795
3796 // Find a G_AND on either side of the G_XOR.
3797 // Look for one of
3798 //
3799 // (xor (and x, y), SharedReg)
3800 // (xor SharedReg, (and x, y))
3801 if (!mi_match(R: AndReg, MRI, P: m_GAnd(L: m_Reg(R&: X), R: m_Reg(R&: Y)))) {
3802 std::swap(a&: AndReg, b&: SharedReg);
3803 if (!mi_match(R: AndReg, MRI, P: m_GAnd(L: m_Reg(R&: X), R: m_Reg(R&: Y))))
3804 return false;
3805 }
3806
3807 // Only do this if we'll eliminate the G_AND.
3808 if (!MRI.hasOneNonDBGUse(RegNo: AndReg))
3809 return false;
3810
3811 // We can combine if SharedReg is the same as either the LHS or RHS of the
3812 // G_AND.
3813 if (Y != SharedReg)
3814 std::swap(a&: X, b&: Y);
3815 return Y == SharedReg;
3816}
3817
3818void CombinerHelper::applyXorOfAndWithSameReg(
3819 MachineInstr &MI, std::pair<Register, Register> &MatchInfo) const {
3820 // Fold (xor (and x, y), y) -> (and (not x), y)
3821 Register X, Y;
3822 std::tie(args&: X, args&: Y) = MatchInfo;
3823 auto Not = Builder.buildNot(Dst: MRI.getType(Reg: X), Src0: X);
3824 Observer.changingInstr(MI);
3825 MI.setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_AND));
3826 MI.getOperand(i: 1).setReg(Not->getOperand(i: 0).getReg());
3827 MI.getOperand(i: 2).setReg(Y);
3828 Observer.changedInstr(MI);
3829}
3830
3831bool CombinerHelper::matchPtrAddZero(MachineInstr &MI) const {
3832 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
3833 Register DstReg = PtrAdd.getReg(Idx: 0);
3834 LLT Ty = MRI.getType(Reg: DstReg);
3835 const DataLayout &DL = Builder.getMF().getDataLayout();
3836
3837 if (DL.isNonIntegralAddressSpace(AddrSpace: Ty.getScalarType().getAddressSpace()))
3838 return false;
3839
3840 if (Ty.isPointer()) {
3841 auto ConstVal = getIConstantVRegVal(VReg: PtrAdd.getBaseReg(), MRI);
3842 return ConstVal && *ConstVal == 0;
3843 }
3844
3845 assert(Ty.isVector() && "Expecting a vector type");
3846 const MachineInstr *VecMI = MRI.getVRegDef(Reg: PtrAdd.getBaseReg());
3847 return isBuildVectorAllZeros(MI: *VecMI, MRI);
3848}
3849
3850void CombinerHelper::applyPtrAddZero(MachineInstr &MI) const {
3851 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
3852 Builder.buildIntToPtr(Dst: PtrAdd.getReg(Idx: 0), Src: PtrAdd.getOffsetReg());
3853 PtrAdd.eraseFromParent();
3854}
3855
3856/// The second source operand is known to be a power of 2.
3857void CombinerHelper::applySimplifyURemByPow2(MachineInstr &MI) const {
3858 Register DstReg = MI.getOperand(i: 0).getReg();
3859 Register Src0 = MI.getOperand(i: 1).getReg();
3860 Register Pow2Src1 = MI.getOperand(i: 2).getReg();
3861 LLT Ty = MRI.getType(Reg: DstReg);
3862
3863 // Fold (urem x, pow2) -> (and x, pow2-1)
3864 auto NegOne = Builder.buildConstant(Res: Ty, Val: -1);
3865 auto Add = Builder.buildAdd(Dst: Ty, Src0: Pow2Src1, Src1: NegOne);
3866 Builder.buildAnd(Dst: DstReg, Src0, Src1: Add);
3867 MI.eraseFromParent();
3868}
3869
3870bool CombinerHelper::matchFoldBinOpIntoSelect(MachineInstr &MI,
3871 unsigned &SelectOpNo) const {
3872 Register LHS = MI.getOperand(i: 1).getReg();
3873 Register RHS = MI.getOperand(i: 2).getReg();
3874
3875 Register OtherOperandReg = RHS;
3876 SelectOpNo = 1;
3877 MachineInstr *Select = MRI.getVRegDef(Reg: LHS);
3878
3879 // Don't do this unless the old select is going away. We want to eliminate the
3880 // binary operator, not replace a binop with a select.
3881 if (Select->getOpcode() != TargetOpcode::G_SELECT ||
3882 !MRI.hasOneNonDBGUse(RegNo: LHS)) {
3883 OtherOperandReg = LHS;
3884 SelectOpNo = 2;
3885 Select = MRI.getVRegDef(Reg: RHS);
3886 if (Select->getOpcode() != TargetOpcode::G_SELECT ||
3887 !MRI.hasOneNonDBGUse(RegNo: RHS))
3888 return false;
3889 }
3890
3891 MachineInstr *SelectLHS = MRI.getVRegDef(Reg: Select->getOperand(i: 2).getReg());
3892 MachineInstr *SelectRHS = MRI.getVRegDef(Reg: Select->getOperand(i: 3).getReg());
3893
3894 if (!isConstantOrConstantVector(MI: *SelectLHS, MRI,
3895 /*AllowFP*/ true,
3896 /*AllowOpaqueConstants*/ false))
3897 return false;
3898 if (!isConstantOrConstantVector(MI: *SelectRHS, MRI,
3899 /*AllowFP*/ true,
3900 /*AllowOpaqueConstants*/ false))
3901 return false;
3902
3903 unsigned BinOpcode = MI.getOpcode();
3904
3905 // We know that one of the operands is a select of constants. Now verify that
3906 // the other binary operator operand is either a constant, or we can handle a
3907 // variable.
3908 bool CanFoldNonConst =
3909 (BinOpcode == TargetOpcode::G_AND || BinOpcode == TargetOpcode::G_OR) &&
3910 (isNullOrNullSplat(MI: *SelectLHS, MRI) ||
3911 isAllOnesOrAllOnesSplat(MI: *SelectLHS, MRI)) &&
3912 (isNullOrNullSplat(MI: *SelectRHS, MRI) ||
3913 isAllOnesOrAllOnesSplat(MI: *SelectRHS, MRI));
3914 if (CanFoldNonConst)
3915 return true;
3916
3917 return isConstantOrConstantVector(MI: *MRI.getVRegDef(Reg: OtherOperandReg), MRI,
3918 /*AllowFP*/ true,
3919 /*AllowOpaqueConstants*/ false);
3920}
3921
3922/// \p SelectOperand is the operand in binary operator \p MI that is the select
3923/// to fold.
3924void CombinerHelper::applyFoldBinOpIntoSelect(
3925 MachineInstr &MI, const unsigned &SelectOperand) const {
3926 Register Dst = MI.getOperand(i: 0).getReg();
3927 Register LHS = MI.getOperand(i: 1).getReg();
3928 Register RHS = MI.getOperand(i: 2).getReg();
3929 MachineInstr *Select = MRI.getVRegDef(Reg: MI.getOperand(i: SelectOperand).getReg());
3930
3931 Register SelectCond = Select->getOperand(i: 1).getReg();
3932 Register SelectTrue = Select->getOperand(i: 2).getReg();
3933 Register SelectFalse = Select->getOperand(i: 3).getReg();
3934
3935 LLT Ty = MRI.getType(Reg: Dst);
3936 unsigned BinOpcode = MI.getOpcode();
3937
3938 Register FoldTrue, FoldFalse;
3939
3940 // We have a select-of-constants followed by a binary operator with a
3941 // constant. Eliminate the binop by pulling the constant math into the select.
3942 // Example: add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO
3943 if (SelectOperand == 1) {
3944 // TODO: SelectionDAG verifies this actually constant folds before
3945 // committing to the combine.
3946
3947 FoldTrue = Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {SelectTrue, RHS}).getReg(Idx: 0);
3948 FoldFalse =
3949 Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {SelectFalse, RHS}).getReg(Idx: 0);
3950 } else {
3951 FoldTrue = Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {LHS, SelectTrue}).getReg(Idx: 0);
3952 FoldFalse =
3953 Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {LHS, SelectFalse}).getReg(Idx: 0);
3954 }
3955
3956 Builder.buildSelect(Res: Dst, Tst: SelectCond, Op0: FoldTrue, Op1: FoldFalse, Flags: MI.getFlags());
3957 MI.eraseFromParent();
3958}
3959
3960std::optional<SmallVector<Register, 8>>
3961CombinerHelper::findCandidatesForLoadOrCombine(const MachineInstr *Root) const {
3962 assert(Root->getOpcode() == TargetOpcode::G_OR && "Expected G_OR only!");
3963 // We want to detect if Root is part of a tree which represents a bunch
3964 // of loads being merged into a larger load. We'll try to recognize patterns
3965 // like, for example:
3966 //
3967 // Reg Reg
3968 // \ /
3969 // OR_1 Reg
3970 // \ /
3971 // OR_2
3972 // \ Reg
3973 // .. /
3974 // Root
3975 //
3976 // Reg Reg Reg Reg
3977 // \ / \ /
3978 // OR_1 OR_2
3979 // \ /
3980 // \ /
3981 // ...
3982 // Root
3983 //
3984 // Each "Reg" may have been produced by a load + some arithmetic. This
3985 // function will save each of them.
3986 SmallVector<Register, 8> RegsToVisit;
3987 SmallVector<const MachineInstr *, 7> Ors = {Root};
3988
3989 // In the "worst" case, we're dealing with a load for each byte. So, there
3990 // are at most #bytes - 1 ORs.
3991 const unsigned MaxIter =
3992 MRI.getType(Reg: Root->getOperand(i: 0).getReg()).getSizeInBytes() - 1;
3993 for (unsigned Iter = 0; Iter < MaxIter; ++Iter) {
3994 if (Ors.empty())
3995 break;
3996 const MachineInstr *Curr = Ors.pop_back_val();
3997 Register OrLHS = Curr->getOperand(i: 1).getReg();
3998 Register OrRHS = Curr->getOperand(i: 2).getReg();
3999
4000 // In the combine, we want to elimate the entire tree.
4001 if (!MRI.hasOneNonDBGUse(RegNo: OrLHS) || !MRI.hasOneNonDBGUse(RegNo: OrRHS))
4002 return std::nullopt;
4003
4004 // If it's a G_OR, save it and continue to walk. If it's not, then it's
4005 // something that may be a load + arithmetic.
4006 if (const MachineInstr *Or = getOpcodeDef(Opcode: TargetOpcode::G_OR, Reg: OrLHS, MRI))
4007 Ors.push_back(Elt: Or);
4008 else
4009 RegsToVisit.push_back(Elt: OrLHS);
4010 if (const MachineInstr *Or = getOpcodeDef(Opcode: TargetOpcode::G_OR, Reg: OrRHS, MRI))
4011 Ors.push_back(Elt: Or);
4012 else
4013 RegsToVisit.push_back(Elt: OrRHS);
4014 }
4015
4016 // We're going to try and merge each register into a wider power-of-2 type,
4017 // so we ought to have an even number of registers.
4018 if (RegsToVisit.empty() || RegsToVisit.size() % 2 != 0)
4019 return std::nullopt;
4020 return RegsToVisit;
4021}
4022
4023/// Helper function for findLoadOffsetsForLoadOrCombine.
4024///
4025/// Check if \p Reg is the result of loading a \p MemSizeInBits wide value,
4026/// and then moving that value into a specific byte offset.
4027///
4028/// e.g. x[i] << 24
4029///
4030/// \returns The load instruction and the byte offset it is moved into.
4031static std::optional<std::pair<GZExtLoad *, int64_t>>
4032matchLoadAndBytePosition(Register Reg, unsigned MemSizeInBits,
4033 const MachineRegisterInfo &MRI) {
4034 assert(MRI.hasOneNonDBGUse(Reg) &&
4035 "Expected Reg to only have one non-debug use?");
4036 Register MaybeLoad;
4037 int64_t Shift;
4038 if (!mi_match(R: Reg, MRI,
4039 P: m_OneNonDBGUse(SP: m_GShl(L: m_Reg(R&: MaybeLoad), R: m_ICst(Cst&: Shift))))) {
4040 Shift = 0;
4041 MaybeLoad = Reg;
4042 }
4043
4044 if (Shift % MemSizeInBits != 0)
4045 return std::nullopt;
4046
4047 // TODO: Handle other types of loads.
4048 auto *Load = getOpcodeDef<GZExtLoad>(Reg: MaybeLoad, MRI);
4049 if (!Load)
4050 return std::nullopt;
4051
4052 if (!Load->isUnordered() || Load->getMemSizeInBits() != MemSizeInBits)
4053 return std::nullopt;
4054
4055 return std::make_pair(x&: Load, y: Shift / MemSizeInBits);
4056}
4057
4058std::optional<std::tuple<GZExtLoad *, int64_t, GZExtLoad *>>
4059CombinerHelper::findLoadOffsetsForLoadOrCombine(
4060 SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx,
4061 const SmallVector<Register, 8> &RegsToVisit,
4062 const unsigned MemSizeInBits) const {
4063
4064 // Each load found for the pattern. There should be one for each RegsToVisit.
4065 SmallSetVector<const MachineInstr *, 8> Loads;
4066
4067 // The lowest index used in any load. (The lowest "i" for each x[i].)
4068 int64_t LowestIdx = INT64_MAX;
4069
4070 // The load which uses the lowest index.
4071 GZExtLoad *LowestIdxLoad = nullptr;
4072
4073 // Keeps track of the load indices we see. We shouldn't see any indices twice.
4074 SmallSet<int64_t, 8> SeenIdx;
4075
4076 // Ensure each load is in the same MBB.
4077 // TODO: Support multiple MachineBasicBlocks.
4078 MachineBasicBlock *MBB = nullptr;
4079 const MachineMemOperand *MMO = nullptr;
4080
4081 // Earliest instruction-order load in the pattern.
4082 GZExtLoad *EarliestLoad = nullptr;
4083
4084 // Latest instruction-order load in the pattern.
4085 GZExtLoad *LatestLoad = nullptr;
4086
4087 // Base pointer which every load should share.
4088 Register BasePtr;
4089
4090 // We want to find a load for each register. Each load should have some
4091 // appropriate bit twiddling arithmetic. During this loop, we will also keep
4092 // track of the load which uses the lowest index. Later, we will check if we
4093 // can use its pointer in the final, combined load.
4094 for (auto Reg : RegsToVisit) {
4095 // Find the load, and find the position that it will end up in (e.g. a
4096 // shifted) value.
4097 auto LoadAndPos = matchLoadAndBytePosition(Reg, MemSizeInBits, MRI);
4098 if (!LoadAndPos)
4099 return std::nullopt;
4100 GZExtLoad *Load;
4101 int64_t DstPos;
4102 std::tie(args&: Load, args&: DstPos) = *LoadAndPos;
4103
4104 // TODO: Handle multiple MachineBasicBlocks. Currently not handled because
4105 // it is difficult to check for stores/calls/etc between loads.
4106 MachineBasicBlock *LoadMBB = Load->getParent();
4107 if (!MBB)
4108 MBB = LoadMBB;
4109 if (LoadMBB != MBB)
4110 return std::nullopt;
4111
4112 // Make sure that the MachineMemOperands of every seen load are compatible.
4113 auto &LoadMMO = Load->getMMO();
4114 if (!MMO)
4115 MMO = &LoadMMO;
4116 if (MMO->getAddrSpace() != LoadMMO.getAddrSpace())
4117 return std::nullopt;
4118
4119 // Find out what the base pointer and index for the load is.
4120 Register LoadPtr;
4121 int64_t Idx;
4122 if (!mi_match(R: Load->getOperand(i: 1).getReg(), MRI,
4123 P: m_GPtrAdd(L: m_Reg(R&: LoadPtr), R: m_ICst(Cst&: Idx)))) {
4124 LoadPtr = Load->getOperand(i: 1).getReg();
4125 Idx = 0;
4126 }
4127
4128 // Don't combine things like a[i], a[i] -> a bigger load.
4129 if (!SeenIdx.insert(V: Idx).second)
4130 return std::nullopt;
4131
4132 // Every load must share the same base pointer; don't combine things like:
4133 //
4134 // a[i], b[i + 1] -> a bigger load.
4135 if (!BasePtr.isValid())
4136 BasePtr = LoadPtr;
4137 if (BasePtr != LoadPtr)
4138 return std::nullopt;
4139
4140 if (Idx < LowestIdx) {
4141 LowestIdx = Idx;
4142 LowestIdxLoad = Load;
4143 }
4144
4145 // Keep track of the byte offset that this load ends up at. If we have seen
4146 // the byte offset, then stop here. We do not want to combine:
4147 //
4148 // a[i] << 16, a[i + k] << 16 -> a bigger load.
4149 if (!MemOffset2Idx.try_emplace(Key: DstPos, Args&: Idx).second)
4150 return std::nullopt;
4151 Loads.insert(X: Load);
4152
4153 // Keep track of the position of the earliest/latest loads in the pattern.
4154 // We will check that there are no load fold barriers between them later
4155 // on.
4156 //
4157 // FIXME: Is there a better way to check for load fold barriers?
4158 if (!EarliestLoad || dominates(DefMI: *Load, UseMI: *EarliestLoad))
4159 EarliestLoad = Load;
4160 if (!LatestLoad || dominates(DefMI: *LatestLoad, UseMI: *Load))
4161 LatestLoad = Load;
4162 }
4163
4164 // We found a load for each register. Let's check if each load satisfies the
4165 // pattern.
4166 assert(Loads.size() == RegsToVisit.size() &&
4167 "Expected to find a load for each register?");
4168 assert(EarliestLoad != LatestLoad && EarliestLoad &&
4169 LatestLoad && "Expected at least two loads?");
4170
4171 // Check if there are any stores, calls, etc. between any of the loads. If
4172 // there are, then we can't safely perform the combine.
4173 //
4174 // MaxIter is chosen based off the (worst case) number of iterations it
4175 // typically takes to succeed in the LLVM test suite plus some padding.
4176 //
4177 // FIXME: Is there a better way to check for load fold barriers?
4178 const unsigned MaxIter = 20;
4179 unsigned Iter = 0;
4180 for (const auto &MI : instructionsWithoutDebug(It: EarliestLoad->getIterator(),
4181 End: LatestLoad->getIterator())) {
4182 if (Loads.count(key: &MI))
4183 continue;
4184 if (MI.isLoadFoldBarrier())
4185 return std::nullopt;
4186 if (Iter++ == MaxIter)
4187 return std::nullopt;
4188 }
4189
4190 return std::make_tuple(args&: LowestIdxLoad, args&: LowestIdx, args&: LatestLoad);
4191}
4192
/// Match a G_OR tree that assembles a wide scalar out of narrow loads
/// (represented here as GZExtLoad) off a common base pointer and, when the
/// byte layout matches a plain wide load, replace it with one load of the full
/// width, followed by a G_BSWAP when the OR tree's byte order is the opposite
/// of the target's endianness.
///
/// \param MI        The root G_OR instruction.
/// \param MatchInfo On success, set to a callback that emits the wide load (and
///                  the optional G_BSWAP) defining MI's destination register.
/// \returns true if the combine applies and MatchInfo was populated.
bool CombinerHelper::matchLoadOrCombine(
    MachineInstr &MI,
    std::function<void(MachineIRBuilder &)> &MatchInfo) const {
  assert(MI.getOpcode() == TargetOpcode::G_OR);
  MachineFunction &MF = *MI.getMF();
  // Assuming a little-endian target, transform:
  //  s8 *a = ...
  //  s32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
  // =>
  //  s32 val = *((s32 *)a)
  //
  //  s8 *a = ...
  //  s32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
  // =>
  //  s32 val = BSWAP(*((s32 *)a))
  Register Dst = MI.getOperand(i: 0).getReg();
  LLT Ty = MRI.getType(Reg: Dst);
  // Only scalar results are handled.
  if (Ty.isVector())
    return false;

  // We need to combine at least two loads into this type. Since the smallest
  // possible load is into a byte, we need at least a 16-bit wide type.
  const unsigned WideMemSizeInBits = Ty.getSizeInBits();
  if (WideMemSizeInBits < 16 || WideMemSizeInBits % 8 != 0)
    return false;

  // Match a collection of non-OR instructions in the pattern.
  auto RegsToVisit = findCandidatesForLoadOrCombine(Root: &MI);
  if (!RegsToVisit)
    return false;

  // We have a collection of non-OR instructions. Figure out how wide each of
  // the small loads should be based off of the number of potential loads we
  // found. Each narrow load must cover a whole number of bytes.
  const unsigned NarrowMemSizeInBits = WideMemSizeInBits / RegsToVisit->size();
  if (NarrowMemSizeInBits % 8 != 0)
    return false;

  // Check if each register feeding into each OR is a load from the same
  // base pointer + some arithmetic.
  //
  // e.g. a[0], a[1] << 8, a[2] << 16, etc.
  //
  // Also verify that each of these ends up putting a[i] into the same memory
  // offset as a load into a wide type would.
  //
  // The helper fills MemOffset2Idx with one entry per load, mapping the
  // position the load's value lands at in the result to the load's index off
  // the base pointer.
  SmallDenseMap<int64_t, int64_t, 8> MemOffset2Idx;
  GZExtLoad *LowestIdxLoad, *LatestLoad;
  int64_t LowestIdx;
  auto MaybeLoadInfo = findLoadOffsetsForLoadOrCombine(
      MemOffset2Idx, RegsToVisit: *RegsToVisit, MemSizeInBits: NarrowMemSizeInBits);
  if (!MaybeLoadInfo)
    return false;
  // The helper returns the load with the lowest index, that index, and the
  // latest load of the pattern (used below as the insertion point).
  std::tie(args&: LowestIdxLoad, args&: LowestIdx, args&: LatestLoad) = *MaybeLoadInfo;

  // We have a bunch of loads being OR'd together. Using the addresses + offsets
  // we found before, check if this corresponds to a big or little endian byte
  // pattern. If it does, then we can represent it using a load + possibly a
  // BSWAP.
  bool IsBigEndianTarget = MF.getDataLayout().isBigEndian();
  std::optional<bool> IsBigEndian = isBigEndian(MemOffset2Idx, LowestIdx);
  if (!IsBigEndian)
    return false;
  // A byte swap is needed exactly when the pattern's byte order differs from
  // the target's.
  bool NeedsBSwap = IsBigEndianTarget != *IsBigEndian;
  if (NeedsBSwap && !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_BSWAP, {Ty}}))
    return false;

  // Make sure that the load from the lowest index produces offset 0 in the
  // final value.
  //
  // This ensures that we won't combine something like this:
  //
  // load x[i] -> byte 2
  // load x[i+1] -> byte 0 ---> wide_load x[i]
  // load x[i+2] -> byte 1
  const unsigned NumLoadsInTy = WideMemSizeInBits / NarrowMemSizeInBits;
  const unsigned ZeroByteOffset =
      *IsBigEndian
          ? bigEndianByteAt(ByteWidth: NumLoadsInTy, I: 0)
          : littleEndianByteAt(ByteWidth: NumLoadsInTy, I: 0);
  auto ZeroOffsetIdx = MemOffset2Idx.find(Val: ZeroByteOffset);
  if (ZeroOffsetIdx == MemOffset2Idx.end() ||
      ZeroOffsetIdx->second != LowestIdx)
    return false;

  // We will reuse the pointer from the load which ends up at byte offset 0. It
  // may not use index 0.
  Register Ptr = LowestIdxLoad->getPointerReg();
  const MachineMemOperand &MMO = LowestIdxLoad->getMMO();
  // Query legality of a full-width G_LOAD, describing memory with the narrow
  // load's memory operand but with its memory type widened to Ty.
  LegalityQuery::MemDesc MMDesc(MMO);
  MMDesc.MemoryTy = Ty;
  if (!isLegalOrBeforeLegalizer(
          Query: {TargetOpcode::G_LOAD, {Ty, MRI.getType(Reg: Ptr)}, {MMDesc}}))
    return false;
  // New memory operand: derived from the lowest-index load's operand with the
  // same pointer info, sized for the whole wide value.
  auto PtrInfo = MMO.getPointerInfo();
  auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Size: WideMemSizeInBits / 8);

  // Load must be allowed and fast on the target.
  LLVMContext &C = MF.getFunction().getContext();
  auto &DL = MF.getDataLayout();
  unsigned Fast = 0;
  if (!getTargetLowering().allowsMemoryAccess(Context&: C, DL, Ty, MMO: *NewMMO, Fast: &Fast) ||
      !Fast)
    return false;

  MatchInfo = [=](MachineIRBuilder &MIB) {
    // Emit at the latest narrow load; the helper has already rejected patterns
    // with a load-fold barrier between the earliest and latest loads.
    MIB.setInstrAndDebugLoc(*LatestLoad);
    // With a byte swap, load into a fresh register and let G_BSWAP define Dst.
    Register LoadDst = NeedsBSwap ? MRI.cloneVirtualRegister(VReg: Dst) : Dst;
    MIB.buildLoad(Res: LoadDst, Addr: Ptr, MMO&: *NewMMO);
    if (NeedsBSwap)
      MIB.buildBSwap(Dst, Src0: LoadDst);
  };
  return true;
}
4306
4307bool CombinerHelper::matchExtendThroughPhis(MachineInstr &MI,
4308 MachineInstr *&ExtMI) const {
4309 auto &PHI = cast<GPhi>(Val&: MI);
4310 Register DstReg = PHI.getReg(Idx: 0);
4311
4312 // TODO: Extending a vector may be expensive, don't do this until heuristics
4313 // are better.
4314 if (MRI.getType(Reg: DstReg).isVector())
4315 return false;
4316
4317 // Try to match a phi, whose only use is an extend.
4318 if (!MRI.hasOneNonDBGUse(RegNo: DstReg))
4319 return false;
4320 ExtMI = &*MRI.use_instr_nodbg_begin(RegNo: DstReg);
4321 switch (ExtMI->getOpcode()) {
4322 case TargetOpcode::G_ANYEXT:
4323 return true; // G_ANYEXT is usually free.
4324 case TargetOpcode::G_ZEXT:
4325 case TargetOpcode::G_SEXT:
4326 break;
4327 default:
4328 return false;
4329 }
4330
4331 // If the target is likely to fold this extend away, don't propagate.
4332 if (Builder.getTII().isExtendLikelyToBeFolded(ExtMI&: *ExtMI, MRI))
4333 return false;
4334
4335 // We don't want to propagate the extends unless there's a good chance that
4336 // they'll be optimized in some way.
4337 // Collect the unique incoming values.
4338 SmallPtrSet<MachineInstr *, 4> InSrcs;
4339 for (unsigned I = 0; I < PHI.getNumIncomingValues(); ++I) {
4340 auto *DefMI = getDefIgnoringCopies(Reg: PHI.getIncomingValue(I), MRI);
4341 switch (DefMI->getOpcode()) {
4342 case TargetOpcode::G_LOAD:
4343 case TargetOpcode::G_TRUNC:
4344 case TargetOpcode::G_SEXT:
4345 case TargetOpcode::G_ZEXT:
4346 case TargetOpcode::G_ANYEXT:
4347 case TargetOpcode::G_CONSTANT:
4348 InSrcs.insert(Ptr: DefMI);
4349 // Don't try to propagate if there are too many places to create new
4350 // extends, chances are it'll increase code size.
4351 if (InSrcs.size() > 2)
4352 return false;
4353 break;
4354 default:
4355 return false;
4356 }
4357 }
4358 return true;
4359}
4360
4361void CombinerHelper::applyExtendThroughPhis(MachineInstr &MI,
4362 MachineInstr *&ExtMI) const {
4363 auto &PHI = cast<GPhi>(Val&: MI);
4364 Register DstReg = ExtMI->getOperand(i: 0).getReg();
4365 LLT ExtTy = MRI.getType(Reg: DstReg);
4366
4367 // Propagate the extension into the block of each incoming reg's block.
4368 // Use a SetVector here because PHIs can have duplicate edges, and we want
4369 // deterministic iteration order.
4370 SmallSetVector<MachineInstr *, 8> SrcMIs;
4371 SmallDenseMap<MachineInstr *, MachineInstr *, 8> OldToNewSrcMap;
4372 for (unsigned I = 0; I < PHI.getNumIncomingValues(); ++I) {
4373 auto SrcReg = PHI.getIncomingValue(I);
4374 auto *SrcMI = MRI.getVRegDef(Reg: SrcReg);
4375 if (!SrcMIs.insert(X: SrcMI))
4376 continue;
4377
4378 // Build an extend after each src inst.
4379 auto *MBB = SrcMI->getParent();
4380 MachineBasicBlock::iterator InsertPt = ++SrcMI->getIterator();
4381 if (InsertPt != MBB->end() && InsertPt->isPHI())
4382 InsertPt = MBB->getFirstNonPHI();
4383
4384 Builder.setInsertPt(MBB&: *SrcMI->getParent(), II: InsertPt);
4385 Builder.setDebugLoc(MI.getDebugLoc());
4386 auto NewExt = Builder.buildExtOrTrunc(ExtOpc: ExtMI->getOpcode(), Res: ExtTy, Op: SrcReg);
4387 OldToNewSrcMap[SrcMI] = NewExt;
4388 }
4389
4390 // Create a new phi with the extended inputs.
4391 Builder.setInstrAndDebugLoc(MI);
4392 auto NewPhi = Builder.buildInstrNoInsert(Opcode: TargetOpcode::G_PHI);
4393 NewPhi.addDef(RegNo: DstReg);
4394 for (const MachineOperand &MO : llvm::drop_begin(RangeOrContainer: MI.operands())) {
4395 if (!MO.isReg()) {
4396 NewPhi.addMBB(MBB: MO.getMBB());
4397 continue;
4398 }
4399 auto *NewSrc = OldToNewSrcMap[MRI.getVRegDef(Reg: MO.getReg())];
4400 NewPhi.addUse(RegNo: NewSrc->getOperand(i: 0).getReg());
4401 }
4402 Builder.insertInstr(MIB: NewPhi);
4403 ExtMI->eraseFromParent();
4404}
4405
4406bool CombinerHelper::matchExtractVecEltBuildVec(MachineInstr &MI,
4407 Register &Reg) const {
4408 assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT);
4409 // If we have a constant index, look for a G_BUILD_VECTOR source
4410 // and find the source register that the index maps to.
4411 Register SrcVec = MI.getOperand(i: 1).getReg();
4412 LLT SrcTy = MRI.getType(Reg: SrcVec);
4413 if (SrcTy.isScalableVector())
4414 return false;
4415
4416 auto Cst = getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI);
4417 if (!Cst || Cst->Value.getZExtValue() >= SrcTy.getNumElements())
4418 return false;
4419
4420 unsigned VecIdx = Cst->Value.getZExtValue();
4421
4422 // Check if we have a build_vector or build_vector_trunc with an optional
4423 // trunc in front.
4424 MachineInstr *SrcVecMI = MRI.getVRegDef(Reg: SrcVec);
4425 if (SrcVecMI->getOpcode() == TargetOpcode::G_TRUNC) {
4426 SrcVecMI = MRI.getVRegDef(Reg: SrcVecMI->getOperand(i: 1).getReg());
4427 }
4428
4429 if (SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR &&
4430 SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR_TRUNC)
4431 return false;
4432
4433 EVT Ty(getMVTForLLT(Ty: SrcTy));
4434 if (!MRI.hasOneNonDBGUse(RegNo: SrcVec) &&
4435 !getTargetLowering().aggressivelyPreferBuildVectorSources(VecVT: Ty))
4436 return false;
4437
4438 Reg = SrcVecMI->getOperand(i: VecIdx + 1).getReg();
4439 return true;
4440}
4441
4442void CombinerHelper::applyExtractVecEltBuildVec(MachineInstr &MI,
4443 Register &Reg) const {
4444 // Check the type of the register, since it may have come from a
4445 // G_BUILD_VECTOR_TRUNC.
4446 LLT ScalarTy = MRI.getType(Reg);
4447 Register DstReg = MI.getOperand(i: 0).getReg();
4448 LLT DstTy = MRI.getType(Reg: DstReg);
4449
4450 if (ScalarTy != DstTy) {
4451 assert(ScalarTy.getSizeInBits() > DstTy.getSizeInBits());
4452 Builder.buildTrunc(Res: DstReg, Op: Reg);
4453 MI.eraseFromParent();
4454 return;
4455 }
4456 replaceSingleDefInstWithReg(MI, Replacement: Reg);
4457}
4458
4459bool CombinerHelper::matchExtractAllEltsFromBuildVector(
4460 MachineInstr &MI,
4461 SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) const {
4462 assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
4463 // This combine tries to find build_vector's which have every source element
4464 // extracted using G_EXTRACT_VECTOR_ELT. This can happen when transforms like
4465 // the masked load scalarization is run late in the pipeline. There's already
4466 // a combine for a similar pattern starting from the extract, but that
4467 // doesn't attempt to do it if there are multiple uses of the build_vector,
4468 // which in this case is true. Starting the combine from the build_vector
4469 // feels more natural than trying to find sibling nodes of extracts.
4470 // E.g.
4471 // %vec(<4 x s32>) = G_BUILD_VECTOR %s1(s32), %s2, %s3, %s4
4472 // %ext1 = G_EXTRACT_VECTOR_ELT %vec, 0
4473 // %ext2 = G_EXTRACT_VECTOR_ELT %vec, 1
4474 // %ext3 = G_EXTRACT_VECTOR_ELT %vec, 2
4475 // %ext4 = G_EXTRACT_VECTOR_ELT %vec, 3
4476 // ==>
4477 // replace ext{1,2,3,4} with %s{1,2,3,4}
4478
4479 Register DstReg = MI.getOperand(i: 0).getReg();
4480 LLT DstTy = MRI.getType(Reg: DstReg);
4481 unsigned NumElts = DstTy.getNumElements();
4482
4483 SmallBitVector ExtractedElts(NumElts);
4484 for (MachineInstr &II : MRI.use_nodbg_instructions(Reg: DstReg)) {
4485 if (II.getOpcode() != TargetOpcode::G_EXTRACT_VECTOR_ELT)
4486 return false;
4487 auto Cst = getIConstantVRegVal(VReg: II.getOperand(i: 2).getReg(), MRI);
4488 if (!Cst)
4489 return false;
4490 unsigned Idx = Cst->getZExtValue();
4491 if (Idx >= NumElts)
4492 return false; // Out of range.
4493 ExtractedElts.set(Idx);
4494 SrcDstPairs.emplace_back(
4495 Args: std::make_pair(x: MI.getOperand(i: Idx + 1).getReg(), y: &II));
4496 }
4497 // Match if every element was extracted.
4498 return ExtractedElts.all();
4499}
4500
4501void CombinerHelper::applyExtractAllEltsFromBuildVector(
4502 MachineInstr &MI,
4503 SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) const {
4504 assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
4505 for (auto &Pair : SrcDstPairs) {
4506 auto *ExtMI = Pair.second;
4507 replaceRegWith(MRI, FromReg: ExtMI->getOperand(i: 0).getReg(), ToReg: Pair.first);
4508 ExtMI->eraseFromParent();
4509 }
4510 MI.eraseFromParent();
4511}
4512
4513void CombinerHelper::applyBuildFn(
4514 MachineInstr &MI,
4515 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4516 applyBuildFnNoErase(MI, MatchInfo);
4517 MI.eraseFromParent();
4518}
4519
4520void CombinerHelper::applyBuildFnNoErase(
4521 MachineInstr &MI,
4522 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4523 MatchInfo(Builder);
4524}
4525
4526bool CombinerHelper::matchOrShiftToFunnelShift(MachineInstr &MI,
4527 bool AllowScalarConstants,
4528 BuildFnTy &MatchInfo) const {
4529 assert(MI.getOpcode() == TargetOpcode::G_OR);
4530
4531 Register Dst = MI.getOperand(i: 0).getReg();
4532 LLT Ty = MRI.getType(Reg: Dst);
4533 unsigned BitWidth = Ty.getScalarSizeInBits();
4534
4535 Register ShlSrc, ShlAmt, LShrSrc, LShrAmt, Amt;
4536 unsigned FshOpc = 0;
4537
4538 // Match (or (shl ...), (lshr ...)).
4539 if (!mi_match(R: Dst, MRI,
4540 // m_GOr() handles the commuted version as well.
4541 P: m_GOr(L: m_GShl(L: m_Reg(R&: ShlSrc), R: m_Reg(R&: ShlAmt)),
4542 R: m_GLShr(L: m_Reg(R&: LShrSrc), R: m_Reg(R&: LShrAmt)))))
4543 return false;
4544
4545 // Given constants C0 and C1 such that C0 + C1 is bit-width:
4546 // (or (shl x, C0), (lshr y, C1)) -> (fshl x, y, C0) or (fshr x, y, C1)
4547 int64_t CstShlAmt = 0, CstLShrAmt;
4548 if (mi_match(R: ShlAmt, MRI, P: m_ICstOrSplat(Cst&: CstShlAmt)) &&
4549 mi_match(R: LShrAmt, MRI, P: m_ICstOrSplat(Cst&: CstLShrAmt)) &&
4550 CstShlAmt + CstLShrAmt == BitWidth) {
4551 FshOpc = TargetOpcode::G_FSHR;
4552 Amt = LShrAmt;
4553 } else if (mi_match(R: LShrAmt, MRI,
4554 P: m_GSub(L: m_SpecificICstOrSplat(RequestedValue: BitWidth), R: m_Reg(R&: Amt))) &&
4555 ShlAmt == Amt) {
4556 // (or (shl x, amt), (lshr y, (sub bw, amt))) -> (fshl x, y, amt)
4557 FshOpc = TargetOpcode::G_FSHL;
4558 } else if (mi_match(R: ShlAmt, MRI,
4559 P: m_GSub(L: m_SpecificICstOrSplat(RequestedValue: BitWidth), R: m_Reg(R&: Amt))) &&
4560 LShrAmt == Amt) {
4561 // (or (shl x, (sub bw, amt)), (lshr y, amt)) -> (fshr x, y, amt)
4562 FshOpc = TargetOpcode::G_FSHR;
4563 } else {
4564 return false;
4565 }
4566
4567 LLT AmtTy = MRI.getType(Reg: Amt);
4568 if (!isLegalOrBeforeLegalizer(Query: {FshOpc, {Ty, AmtTy}}) &&
4569 (!AllowScalarConstants || CstShlAmt == 0 || !Ty.isScalar()))
4570 return false;
4571
4572 MatchInfo = [=](MachineIRBuilder &B) {
4573 B.buildInstr(Opc: FshOpc, DstOps: {Dst}, SrcOps: {ShlSrc, LShrSrc, Amt});
4574 };
4575 return true;
4576}
4577
4578/// Match an FSHL or FSHR that can be combined to a ROTR or ROTL rotate.
4579bool CombinerHelper::matchFunnelShiftToRotate(MachineInstr &MI) const {
4580 unsigned Opc = MI.getOpcode();
4581 assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR);
4582 Register X = MI.getOperand(i: 1).getReg();
4583 Register Y = MI.getOperand(i: 2).getReg();
4584 if (X != Y)
4585 return false;
4586 unsigned RotateOpc =
4587 Opc == TargetOpcode::G_FSHL ? TargetOpcode::G_ROTL : TargetOpcode::G_ROTR;
4588 return isLegalOrBeforeLegalizer(Query: {RotateOpc, {MRI.getType(Reg: X), MRI.getType(Reg: Y)}});
4589}
4590
4591void CombinerHelper::applyFunnelShiftToRotate(MachineInstr &MI) const {
4592 unsigned Opc = MI.getOpcode();
4593 assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR);
4594 bool IsFSHL = Opc == TargetOpcode::G_FSHL;
4595 Observer.changingInstr(MI);
4596 MI.setDesc(Builder.getTII().get(Opcode: IsFSHL ? TargetOpcode::G_ROTL
4597 : TargetOpcode::G_ROTR));
4598 MI.removeOperand(OpNo: 2);
4599 Observer.changedInstr(MI);
4600}
4601
4602// Fold (rot x, c) -> (rot x, c % BitSize)
4603bool CombinerHelper::matchRotateOutOfRange(MachineInstr &MI) const {
4604 assert(MI.getOpcode() == TargetOpcode::G_ROTL ||
4605 MI.getOpcode() == TargetOpcode::G_ROTR);
4606 unsigned Bitsize =
4607 MRI.getType(Reg: MI.getOperand(i: 0).getReg()).getScalarSizeInBits();
4608 Register AmtReg = MI.getOperand(i: 2).getReg();
4609 bool OutOfRange = false;
4610 auto MatchOutOfRange = [Bitsize, &OutOfRange](const Constant *C) {
4611 if (auto *CI = dyn_cast<ConstantInt>(Val: C))
4612 OutOfRange |= CI->getValue().uge(RHS: Bitsize);
4613 return true;
4614 };
4615 return matchUnaryPredicate(MRI, Reg: AmtReg, Match: MatchOutOfRange) && OutOfRange;
4616}
4617
4618void CombinerHelper::applyRotateOutOfRange(MachineInstr &MI) const {
4619 assert(MI.getOpcode() == TargetOpcode::G_ROTL ||
4620 MI.getOpcode() == TargetOpcode::G_ROTR);
4621 unsigned Bitsize =
4622 MRI.getType(Reg: MI.getOperand(i: 0).getReg()).getScalarSizeInBits();
4623 Register Amt = MI.getOperand(i: 2).getReg();
4624 LLT AmtTy = MRI.getType(Reg: Amt);
4625 auto Bits = Builder.buildConstant(Res: AmtTy, Val: Bitsize);
4626 Amt = Builder.buildURem(Dst: AmtTy, Src0: MI.getOperand(i: 2).getReg(), Src1: Bits).getReg(Idx: 0);
4627 Observer.changingInstr(MI);
4628 MI.getOperand(i: 2).setReg(Amt);
4629 Observer.changedInstr(MI);
4630}
4631
4632bool CombinerHelper::matchICmpToTrueFalseKnownBits(MachineInstr &MI,
4633 int64_t &MatchInfo) const {
4634 assert(MI.getOpcode() == TargetOpcode::G_ICMP);
4635 auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(i: 1).getPredicate());
4636
4637 // We want to avoid calling KnownBits on the LHS if possible, as this combine
4638 // has no filter and runs on every G_ICMP instruction. We can avoid calling
4639 // KnownBits on the LHS in two cases:
4640 //
4641 // - The RHS is unknown: Constants are always on RHS. If the RHS is unknown
4642 // we cannot do any transforms so we can safely bail out early.
4643 // - The RHS is zero: we don't need to know the LHS to do unsigned <0 and
4644 // >=0.
4645 auto KnownRHS = VT->getKnownBits(R: MI.getOperand(i: 3).getReg());
4646 if (KnownRHS.isUnknown())
4647 return false;
4648
4649 std::optional<bool> KnownVal;
4650 if (KnownRHS.isZero()) {
4651 // ? uge 0 -> always true
4652 // ? ult 0 -> always false
4653 if (Pred == CmpInst::ICMP_UGE)
4654 KnownVal = true;
4655 else if (Pred == CmpInst::ICMP_ULT)
4656 KnownVal = false;
4657 }
4658
4659 if (!KnownVal) {
4660 auto KnownLHS = VT->getKnownBits(R: MI.getOperand(i: 2).getReg());
4661 KnownVal = ICmpInst::compare(LHS: KnownLHS, RHS: KnownRHS, Pred);
4662 }
4663
4664 if (!KnownVal)
4665 return false;
4666 MatchInfo =
4667 *KnownVal
4668 ? getICmpTrueVal(TLI: getTargetLowering(),
4669 /*IsVector = */
4670 MRI.getType(Reg: MI.getOperand(i: 0).getReg()).isVector(),
4671 /* IsFP = */ false)
4672 : 0;
4673 return true;
4674}
4675
4676bool CombinerHelper::matchICmpToLHSKnownBits(
4677 MachineInstr &MI,
4678 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4679 assert(MI.getOpcode() == TargetOpcode::G_ICMP);
4680 // Given:
4681 //
4682 // %x = G_WHATEVER (... x is known to be 0 or 1 ...)
4683 // %cmp = G_ICMP ne %x, 0
4684 //
4685 // Or:
4686 //
4687 // %x = G_WHATEVER (... x is known to be 0 or 1 ...)
4688 // %cmp = G_ICMP eq %x, 1
4689 //
4690 // We can replace %cmp with %x assuming true is 1 on the target.
4691 auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(i: 1).getPredicate());
4692 if (!CmpInst::isEquality(pred: Pred))
4693 return false;
4694 Register Dst = MI.getOperand(i: 0).getReg();
4695 LLT DstTy = MRI.getType(Reg: Dst);
4696 if (getICmpTrueVal(TLI: getTargetLowering(), IsVector: DstTy.isVector(),
4697 /* IsFP = */ false) != 1)
4698 return false;
4699 int64_t OneOrZero = Pred == CmpInst::ICMP_EQ;
4700 if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICst(RequestedValue: OneOrZero)))
4701 return false;
4702 Register LHS = MI.getOperand(i: 2).getReg();
4703 auto KnownLHS = VT->getKnownBits(R: LHS);
4704 if (KnownLHS.getMinValue() != 0 || KnownLHS.getMaxValue() != 1)
4705 return false;
4706 // Make sure replacing Dst with the LHS is a legal operation.
4707 LLT LHSTy = MRI.getType(Reg: LHS);
4708 unsigned LHSSize = LHSTy.getSizeInBits();
4709 unsigned DstSize = DstTy.getSizeInBits();
4710 unsigned Op = TargetOpcode::COPY;
4711 if (DstSize != LHSSize)
4712 Op = DstSize < LHSSize ? TargetOpcode::G_TRUNC : TargetOpcode::G_ZEXT;
4713 if (!isLegalOrBeforeLegalizer(Query: {Op, {DstTy, LHSTy}}))
4714 return false;
4715 MatchInfo = [=](MachineIRBuilder &B) { B.buildInstr(Opc: Op, DstOps: {Dst}, SrcOps: {LHS}); };
4716 return true;
4717}
4718
4719// Replace (and (or x, c1), c2) with (and x, c2) iff c1 & c2 == 0
4720bool CombinerHelper::matchAndOrDisjointMask(
4721 MachineInstr &MI,
4722 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4723 assert(MI.getOpcode() == TargetOpcode::G_AND);
4724
4725 // Ignore vector types to simplify matching the two constants.
4726 // TODO: do this for vectors and scalars via a demanded bits analysis.
4727 LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
4728 if (Ty.isVector())
4729 return false;
4730
4731 Register Src;
4732 Register AndMaskReg;
4733 int64_t AndMaskBits;
4734 int64_t OrMaskBits;
4735 if (!mi_match(MI, MRI,
4736 P: m_GAnd(L: m_GOr(L: m_Reg(R&: Src), R: m_ICst(Cst&: OrMaskBits)),
4737 R: m_all_of(preds: m_ICst(Cst&: AndMaskBits), preds: m_Reg(R&: AndMaskReg)))))
4738 return false;
4739
4740 // Check if OrMask could turn on any bits in Src.
4741 if (AndMaskBits & OrMaskBits)
4742 return false;
4743
4744 MatchInfo = [=, &MI](MachineIRBuilder &B) {
4745 Observer.changingInstr(MI);
4746 // Canonicalize the result to have the constant on the RHS.
4747 if (MI.getOperand(i: 1).getReg() == AndMaskReg)
4748 MI.getOperand(i: 2).setReg(AndMaskReg);
4749 MI.getOperand(i: 1).setReg(Src);
4750 Observer.changedInstr(MI);
4751 };
4752 return true;
4753}
4754
4755/// Form a G_SBFX from a G_SEXT_INREG fed by a right shift.
4756bool CombinerHelper::matchBitfieldExtractFromSExtInReg(
4757 MachineInstr &MI,
4758 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4759 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
4760 Register Dst = MI.getOperand(i: 0).getReg();
4761 Register Src = MI.getOperand(i: 1).getReg();
4762 LLT Ty = MRI.getType(Reg: Src);
4763 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4764 if (!LI || !LI->isLegalOrCustom(Query: {TargetOpcode::G_SBFX, {Ty, ExtractTy}}))
4765 return false;
4766 int64_t Width = MI.getOperand(i: 2).getImm();
4767 Register ShiftSrc;
4768 int64_t ShiftImm;
4769 if (!mi_match(
4770 R: Src, MRI,
4771 P: m_OneNonDBGUse(SP: m_any_of(preds: m_GAShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: ShiftImm)),
4772 preds: m_GLShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: ShiftImm))))))
4773 return false;
4774 if (ShiftImm < 0 || ShiftImm + Width > Ty.getScalarSizeInBits())
4775 return false;
4776
4777 MatchInfo = [=](MachineIRBuilder &B) {
4778 auto Cst1 = B.buildConstant(Res: ExtractTy, Val: ShiftImm);
4779 auto Cst2 = B.buildConstant(Res: ExtractTy, Val: Width);
4780 B.buildSbfx(Dst, Src: ShiftSrc, LSB: Cst1, Width: Cst2);
4781 };
4782 return true;
4783}
4784
4785/// Form a G_UBFX from "(a srl b) & mask", where b and mask are constants.
4786bool CombinerHelper::matchBitfieldExtractFromAnd(MachineInstr &MI,
4787 BuildFnTy &MatchInfo) const {
4788 GAnd *And = cast<GAnd>(Val: &MI);
4789 Register Dst = And->getReg(Idx: 0);
4790 LLT Ty = MRI.getType(Reg: Dst);
4791 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4792 // Note that isLegalOrBeforeLegalizer is stricter and does not take custom
4793 // into account.
4794 if (LI && !LI->isLegalOrCustom(Query: {TargetOpcode::G_UBFX, {Ty, ExtractTy}}))
4795 return false;
4796
4797 int64_t AndImm, LSBImm;
4798 Register ShiftSrc;
4799 const unsigned Size = Ty.getScalarSizeInBits();
4800 if (!mi_match(R: And->getReg(Idx: 0), MRI,
4801 P: m_GAnd(L: m_OneNonDBGUse(SP: m_GLShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: LSBImm))),
4802 R: m_ICst(Cst&: AndImm))))
4803 return false;
4804
4805 // The mask is a mask of the low bits iff imm & (imm+1) == 0.
4806 auto MaybeMask = static_cast<uint64_t>(AndImm);
4807 if (MaybeMask & (MaybeMask + 1))
4808 return false;
4809
4810 // LSB must fit within the register.
4811 if (static_cast<uint64_t>(LSBImm) >= Size)
4812 return false;
4813
4814 uint64_t Width = APInt(Size, AndImm).countr_one();
4815 MatchInfo = [=](MachineIRBuilder &B) {
4816 auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width);
4817 auto LSBCst = B.buildConstant(Res: ExtractTy, Val: LSBImm);
4818 B.buildInstr(Opc: TargetOpcode::G_UBFX, DstOps: {Dst}, SrcOps: {ShiftSrc, LSBCst, WidthCst});
4819 };
4820 return true;
4821}
4822
4823bool CombinerHelper::matchBitfieldExtractFromShr(
4824 MachineInstr &MI,
4825 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4826 const unsigned Opcode = MI.getOpcode();
4827 assert(Opcode == TargetOpcode::G_ASHR || Opcode == TargetOpcode::G_LSHR);
4828
4829 const Register Dst = MI.getOperand(i: 0).getReg();
4830
4831 const unsigned ExtrOpcode = Opcode == TargetOpcode::G_ASHR
4832 ? TargetOpcode::G_SBFX
4833 : TargetOpcode::G_UBFX;
4834
4835 // Check if the type we would use for the extract is legal
4836 LLT Ty = MRI.getType(Reg: Dst);
4837 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4838 if (!LI || !LI->isLegalOrCustom(Query: {ExtrOpcode, {Ty, ExtractTy}}))
4839 return false;
4840
4841 Register ShlSrc;
4842 int64_t ShrAmt;
4843 int64_t ShlAmt;
4844 const unsigned Size = Ty.getScalarSizeInBits();
4845
4846 // Try to match shr (shl x, c1), c2
4847 if (!mi_match(R: Dst, MRI,
4848 P: m_BinOp(Opcode,
4849 L: m_OneNonDBGUse(SP: m_GShl(L: m_Reg(R&: ShlSrc), R: m_ICst(Cst&: ShlAmt))),
4850 R: m_ICst(Cst&: ShrAmt))))
4851 return false;
4852
4853 // Make sure that the shift sizes can fit a bitfield extract
4854 if (ShlAmt < 0 || ShlAmt > ShrAmt || ShrAmt >= Size)
4855 return false;
4856
4857 // Skip this combine if the G_SEXT_INREG combine could handle it
4858 if (Opcode == TargetOpcode::G_ASHR && ShlAmt == ShrAmt)
4859 return false;
4860
4861 // Calculate start position and width of the extract
4862 const int64_t Pos = ShrAmt - ShlAmt;
4863 const int64_t Width = Size - ShrAmt;
4864
4865 MatchInfo = [=](MachineIRBuilder &B) {
4866 auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width);
4867 auto PosCst = B.buildConstant(Res: ExtractTy, Val: Pos);
4868 B.buildInstr(Opc: ExtrOpcode, DstOps: {Dst}, SrcOps: {ShlSrc, PosCst, WidthCst});
4869 };
4870 return true;
4871}
4872
4873bool CombinerHelper::matchBitfieldExtractFromShrAnd(
4874 MachineInstr &MI,
4875 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4876 const unsigned Opcode = MI.getOpcode();
4877 assert(Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_ASHR);
4878
4879 const Register Dst = MI.getOperand(i: 0).getReg();
4880 LLT Ty = MRI.getType(Reg: Dst);
4881 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4882 if (LI && !LI->isLegalOrCustom(Query: {TargetOpcode::G_UBFX, {Ty, ExtractTy}}))
4883 return false;
4884
4885 // Try to match shr (and x, c1), c2
4886 Register AndSrc;
4887 int64_t ShrAmt;
4888 int64_t SMask;
4889 if (!mi_match(R: Dst, MRI,
4890 P: m_BinOp(Opcode,
4891 L: m_OneNonDBGUse(SP: m_GAnd(L: m_Reg(R&: AndSrc), R: m_ICst(Cst&: SMask))),
4892 R: m_ICst(Cst&: ShrAmt))))
4893 return false;
4894
4895 const unsigned Size = Ty.getScalarSizeInBits();
4896 if (ShrAmt < 0 || ShrAmt >= Size)
4897 return false;
4898
4899 // If the shift subsumes the mask, emit the 0 directly.
4900 if (0 == (SMask >> ShrAmt)) {
4901 MatchInfo = [=](MachineIRBuilder &B) {
4902 B.buildConstant(Res: Dst, Val: 0);
4903 };
4904 return true;
4905 }
4906
4907 // Check that ubfx can do the extraction, with no holes in the mask.
4908 uint64_t UMask = SMask;
4909 UMask |= maskTrailingOnes<uint64_t>(N: ShrAmt);
4910 UMask &= maskTrailingOnes<uint64_t>(N: Size);
4911 if (!isMask_64(Value: UMask))
4912 return false;
4913
4914 // Calculate start position and width of the extract.
4915 const int64_t Pos = ShrAmt;
4916 const int64_t Width = llvm::countr_one(Value: UMask) - ShrAmt;
4917
4918 // It's preferable to keep the shift, rather than form G_SBFX.
4919 // TODO: remove the G_AND via demanded bits analysis.
4920 if (Opcode == TargetOpcode::G_ASHR && Width + ShrAmt == Size)
4921 return false;
4922
4923 MatchInfo = [=](MachineIRBuilder &B) {
4924 auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width);
4925 auto PosCst = B.buildConstant(Res: ExtractTy, Val: Pos);
4926 B.buildInstr(Opc: TargetOpcode::G_UBFX, DstOps: {Dst}, SrcOps: {AndSrc, PosCst, WidthCst});
4927 };
4928 return true;
4929}
4930
4931bool CombinerHelper::reassociationCanBreakAddressingModePattern(
4932 MachineInstr &MI) const {
4933 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
4934
4935 Register Src1Reg = PtrAdd.getBaseReg();
4936 auto *Src1Def = getOpcodeDef<GPtrAdd>(Reg: Src1Reg, MRI);
4937 if (!Src1Def)
4938 return false;
4939
4940 Register Src2Reg = PtrAdd.getOffsetReg();
4941
4942 if (MRI.hasOneNonDBGUse(RegNo: Src1Reg))
4943 return false;
4944
4945 auto C1 = getIConstantVRegVal(VReg: Src1Def->getOffsetReg(), MRI);
4946 if (!C1)
4947 return false;
4948 auto C2 = getIConstantVRegVal(VReg: Src2Reg, MRI);
4949 if (!C2)
4950 return false;
4951
4952 const APInt &C1APIntVal = *C1;
4953 const APInt &C2APIntVal = *C2;
4954 const int64_t CombinedValue = (C1APIntVal + C2APIntVal).getSExtValue();
4955
4956 for (auto &UseMI : MRI.use_nodbg_instructions(Reg: PtrAdd.getReg(Idx: 0))) {
4957 // This combine may end up running before ptrtoint/inttoptr combines
4958 // manage to eliminate redundant conversions, so try to look through them.
4959 MachineInstr *ConvUseMI = &UseMI;
4960 unsigned ConvUseOpc = ConvUseMI->getOpcode();
4961 while (ConvUseOpc == TargetOpcode::G_INTTOPTR ||
4962 ConvUseOpc == TargetOpcode::G_PTRTOINT) {
4963 Register DefReg = ConvUseMI->getOperand(i: 0).getReg();
4964 if (!MRI.hasOneNonDBGUse(RegNo: DefReg))
4965 break;
4966 ConvUseMI = &*MRI.use_instr_nodbg_begin(RegNo: DefReg);
4967 ConvUseOpc = ConvUseMI->getOpcode();
4968 }
4969 auto *LdStMI = dyn_cast<GLoadStore>(Val: ConvUseMI);
4970 if (!LdStMI)
4971 continue;
4972 // Is x[offset2] already not a legal addressing mode? If so then
4973 // reassociating the constants breaks nothing (we test offset2 because
4974 // that's the one we hope to fold into the load or store).
4975 TargetLoweringBase::AddrMode AM;
4976 AM.HasBaseReg = true;
4977 AM.BaseOffs = C2APIntVal.getSExtValue();
4978 unsigned AS = MRI.getType(Reg: LdStMI->getPointerReg()).getAddressSpace();
4979 Type *AccessTy = getTypeForLLT(Ty: LdStMI->getMMO().getMemoryType(),
4980 C&: PtrAdd.getMF()->getFunction().getContext());
4981 const auto &TLI = *PtrAdd.getMF()->getSubtarget().getTargetLowering();
4982 if (!TLI.isLegalAddressingMode(DL: PtrAdd.getMF()->getDataLayout(), AM,
4983 Ty: AccessTy, AddrSpace: AS))
4984 continue;
4985
4986 // Would x[offset1+offset2] still be a legal addressing mode?
4987 AM.BaseOffs = CombinedValue;
4988 if (!TLI.isLegalAddressingMode(DL: PtrAdd.getMF()->getDataLayout(), AM,
4989 Ty: AccessTy, AddrSpace: AS))
4990 return true;
4991 }
4992
4993 return false;
4994}
4995
4996bool CombinerHelper::matchReassocConstantInnerRHS(GPtrAdd &MI,
4997 MachineInstr *RHS,
4998 BuildFnTy &MatchInfo) const {
4999 // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C)
5000 Register Src1Reg = MI.getOperand(i: 1).getReg();
5001 if (RHS->getOpcode() != TargetOpcode::G_ADD)
5002 return false;
5003 auto C2 = getIConstantVRegVal(VReg: RHS->getOperand(i: 2).getReg(), MRI);
5004 if (!C2)
5005 return false;
5006
5007 // If both additions are nuw, the reassociated additions are also nuw.
5008 // If the original G_PTR_ADD is additionally nusw, X and C are both not
5009 // negative, so BASE+X is between BASE and BASE+(X+C). The new G_PTR_ADDs are
5010 // therefore also nusw.
5011 // If the original G_PTR_ADD is additionally inbounds (which implies nusw),
5012 // the new G_PTR_ADDs are then also inbounds.
5013 unsigned PtrAddFlags = MI.getFlags();
5014 unsigned AddFlags = RHS->getFlags();
5015 bool IsNoUWrap = PtrAddFlags & AddFlags & MachineInstr::MIFlag::NoUWrap;
5016 bool IsNoUSWrap = IsNoUWrap && (PtrAddFlags & MachineInstr::MIFlag::NoUSWrap);
5017 bool IsInBounds = IsNoUWrap && (PtrAddFlags & MachineInstr::MIFlag::InBounds);
5018 unsigned Flags = 0;
5019 if (IsNoUWrap)
5020 Flags |= MachineInstr::MIFlag::NoUWrap;
5021 if (IsNoUSWrap)
5022 Flags |= MachineInstr::MIFlag::NoUSWrap;
5023 if (IsInBounds)
5024 Flags |= MachineInstr::MIFlag::InBounds;
5025
5026 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5027 LLT PtrTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
5028
5029 auto NewBase =
5030 Builder.buildPtrAdd(Res: PtrTy, Op0: Src1Reg, Op1: RHS->getOperand(i: 1).getReg(), Flags);
5031 Observer.changingInstr(MI);
5032 MI.getOperand(i: 1).setReg(NewBase.getReg(Idx: 0));
5033 MI.getOperand(i: 2).setReg(RHS->getOperand(i: 2).getReg());
5034 MI.setFlags(Flags);
5035 Observer.changedInstr(MI);
5036 };
5037 return !reassociationCanBreakAddressingModePattern(MI);
5038}
5039
bool CombinerHelper::matchReassocConstantInnerLHS(GPtrAdd &MI,
                                                  MachineInstr *LHS,
                                                  MachineInstr *RHS,
                                                  BuildFnTy &MatchInfo) const {
  // Move a constant offset from an inner G_PTR_ADD on the base operand out to
  // the outer G_PTR_ADD:
  // G_PTR_ADD (G_PTR_ADD X, C), Y) -> (G_PTR_ADD (G_PTR_ADD(X, Y), C)
  // if and only if (G_PTR_ADD X, C) has one use.
  //
  // LHS is cast to GPtrAdd below and is treated as the def of MI's base
  // register (matchReassocPtrAdd passes exactly that). RHS is not used here.
  Register LHSBase;
  std::optional<ValueAndVReg> LHSCstOff;
  if (!mi_match(R: MI.getBaseReg(), MRI,
                P: m_OneNonDBGUse(SP: m_GPtrAdd(L: m_Reg(R&: LHSBase), R: m_GCst(ValReg&: LHSCstOff)))))
    return false;

  auto *LHSPtrAdd = cast<GPtrAdd>(Val: LHS);

  // Reassociating nuw additions preserves nuw. If both original G_PTR_ADDs are
  // nuw and inbounds (which implies nusw), the offsets are both non-negative,
  // so the new G_PTR_ADDs are also inbounds.
  unsigned PtrAddFlags = MI.getFlags();
  unsigned LHSPtrAddFlags = LHSPtrAdd->getFlags();
  bool IsNoUWrap = PtrAddFlags & LHSPtrAddFlags & MachineInstr::MIFlag::NoUWrap;
  bool IsNoUSWrap = IsNoUWrap && (PtrAddFlags & LHSPtrAddFlags &
                                  MachineInstr::MIFlag::NoUSWrap);
  bool IsInBounds = IsNoUWrap && (PtrAddFlags & LHSPtrAddFlags &
                                  MachineInstr::MIFlag::InBounds);
  unsigned Flags = 0;
  if (IsNoUWrap)
    Flags |= MachineInstr::MIFlag::NoUWrap;
  if (IsNoUSWrap)
    Flags |= MachineInstr::MIFlag::NoUSWrap;
  if (IsInBounds)
    Flags |= MachineInstr::MIFlag::InBounds;

  MatchInfo = [=, &MI](MachineIRBuilder &B) {
    // When we change LHSPtrAdd's offset register to Y we might cause it to use
    // a reg before its def. Sink the instruction to just before the outer
    // G_PTR_ADD (which already uses Y) to ensure this doesn't happen.
    LHSPtrAdd->moveBefore(MovePos: &MI);
    Register RHSReg = MI.getOffsetReg();
    // Rebuild C in the type of MI's offset operand rather than reusing
    // LHSCstOff's vreg: setting that vreg directly would cause a type
    // mismatch if it comes from an extend/trunc.
    auto NewCst = B.buildConstant(Res: MRI.getType(Reg: RHSReg), Val: LHSCstOff->Value);
    // Outer G_PTR_ADD now adds the constant C.
    Observer.changingInstr(MI);
    MI.getOperand(i: 2).setReg(NewCst.getReg(Idx: 0));
    MI.setFlags(Flags);
    Observer.changedInstr(MI);
    // Inner G_PTR_ADD now adds Y (the outer's former offset).
    Observer.changingInstr(MI&: *LHSPtrAdd);
    LHSPtrAdd->getOperand(i: 2).setReg(RHSReg);
    LHSPtrAdd->setFlags(Flags);
    Observer.changedInstr(MI&: *LHSPtrAdd);
  };
  return !reassociationCanBreakAddressingModePattern(MI);
}
5091
5092bool CombinerHelper::matchReassocFoldConstantsInSubTree(
5093 GPtrAdd &MI, MachineInstr *LHS, MachineInstr *RHS,
5094 BuildFnTy &MatchInfo) const {
5095 // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2)
5096 auto *LHSPtrAdd = dyn_cast<GPtrAdd>(Val: LHS);
5097 if (!LHSPtrAdd)
5098 return false;
5099
5100 Register Src2Reg = MI.getOperand(i: 2).getReg();
5101 Register LHSSrc1 = LHSPtrAdd->getBaseReg();
5102 Register LHSSrc2 = LHSPtrAdd->getOffsetReg();
5103 auto C1 = getIConstantVRegVal(VReg: LHSSrc2, MRI);
5104 if (!C1)
5105 return false;
5106 auto C2 = getIConstantVRegVal(VReg: Src2Reg, MRI);
5107 if (!C2)
5108 return false;
5109
5110 // Reassociating nuw additions preserves nuw. If both original G_PTR_ADDs are
5111 // inbounds, reaching the same result in one G_PTR_ADD is also inbounds.
5112 // The nusw constraints are satisfied because imm1+imm2 cannot exceed the
5113 // largest signed integer that fits into the index type, which is the maximum
5114 // size of allocated objects according to the IR Language Reference.
5115 unsigned PtrAddFlags = MI.getFlags();
5116 unsigned LHSPtrAddFlags = LHSPtrAdd->getFlags();
5117 bool IsNoUWrap = PtrAddFlags & LHSPtrAddFlags & MachineInstr::MIFlag::NoUWrap;
5118 bool IsInBounds =
5119 PtrAddFlags & LHSPtrAddFlags & MachineInstr::MIFlag::InBounds;
5120 unsigned Flags = 0;
5121 if (IsNoUWrap)
5122 Flags |= MachineInstr::MIFlag::NoUWrap;
5123 if (IsInBounds) {
5124 Flags |= MachineInstr::MIFlag::InBounds;
5125 Flags |= MachineInstr::MIFlag::NoUSWrap;
5126 }
5127
5128 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5129 auto NewCst = B.buildConstant(Res: MRI.getType(Reg: Src2Reg), Val: *C1 + *C2);
5130 Observer.changingInstr(MI);
5131 MI.getOperand(i: 1).setReg(LHSSrc1);
5132 MI.getOperand(i: 2).setReg(NewCst.getReg(Idx: 0));
5133 MI.setFlags(Flags);
5134 Observer.changedInstr(MI);
5135 };
5136 return !reassociationCanBreakAddressingModePattern(MI);
5137}
5138
5139bool CombinerHelper::matchReassocPtrAdd(MachineInstr &MI,
5140 BuildFnTy &MatchInfo) const {
5141 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
5142 // We're trying to match a few pointer computation patterns here for
5143 // re-association opportunities.
5144 // 1) Isolating a constant operand to be on the RHS, e.g.:
5145 // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C)
5146 //
5147 // 2) Folding two constants in each sub-tree as long as such folding
5148 // doesn't break a legal addressing mode.
5149 // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2)
5150 //
5151 // 3) Move a constant from the LHS of an inner op to the RHS of the outer.
5152 // G_PTR_ADD (G_PTR_ADD X, C), Y) -> G_PTR_ADD (G_PTR_ADD(X, Y), C)
5153 // iif (G_PTR_ADD X, C) has one use.
5154 MachineInstr *LHS = MRI.getVRegDef(Reg: PtrAdd.getBaseReg());
5155 MachineInstr *RHS = MRI.getVRegDef(Reg: PtrAdd.getOffsetReg());
5156
5157 // Try to match example 2.
5158 if (matchReassocFoldConstantsInSubTree(MI&: PtrAdd, LHS, RHS, MatchInfo))
5159 return true;
5160
5161 // Try to match example 3.
5162 if (matchReassocConstantInnerLHS(MI&: PtrAdd, LHS, RHS, MatchInfo))
5163 return true;
5164
5165 // Try to match example 1.
5166 if (matchReassocConstantInnerRHS(MI&: PtrAdd, RHS, MatchInfo))
5167 return true;
5168
5169 return false;
5170}
5171bool CombinerHelper::tryReassocBinOp(unsigned Opc, Register DstReg,
5172 Register OpLHS, Register OpRHS,
5173 BuildFnTy &MatchInfo) const {
5174 LLT OpRHSTy = MRI.getType(Reg: OpRHS);
5175 MachineInstr *OpLHSDef = MRI.getVRegDef(Reg: OpLHS);
5176
5177 if (OpLHSDef->getOpcode() != Opc)
5178 return false;
5179
5180 MachineInstr *OpRHSDef = MRI.getVRegDef(Reg: OpRHS);
5181 Register OpLHSLHS = OpLHSDef->getOperand(i: 1).getReg();
5182 Register OpLHSRHS = OpLHSDef->getOperand(i: 2).getReg();
5183
5184 // If the inner op is (X op C), pull the constant out so it can be folded with
5185 // other constants in the expression tree. Folding is not guaranteed so we
5186 // might have (C1 op C2). In that case do not pull a constant out because it
5187 // won't help and can lead to infinite loops.
5188 if (isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: OpLHSRHS), MRI) &&
5189 !isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: OpLHSLHS), MRI)) {
5190 if (isConstantOrConstantSplatVector(MI&: *OpRHSDef, MRI)) {
5191 // (Opc (Opc X, C1), C2) -> (Opc X, (Opc C1, C2))
5192 MatchInfo = [=](MachineIRBuilder &B) {
5193 auto NewCst = B.buildInstr(Opc, DstOps: {OpRHSTy}, SrcOps: {OpLHSRHS, OpRHS});
5194 B.buildInstr(Opc, DstOps: {DstReg}, SrcOps: {OpLHSLHS, NewCst});
5195 };
5196 return true;
5197 }
5198 if (getTargetLowering().isReassocProfitable(MRI, N0: OpLHS, N1: OpRHS)) {
5199 // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1)
5200 // iff (op x, c1) has one use
5201 MatchInfo = [=](MachineIRBuilder &B) {
5202 auto NewLHSLHS = B.buildInstr(Opc, DstOps: {OpRHSTy}, SrcOps: {OpLHSLHS, OpRHS});
5203 B.buildInstr(Opc, DstOps: {DstReg}, SrcOps: {NewLHSLHS, OpLHSRHS});
5204 };
5205 return true;
5206 }
5207 }
5208
5209 return false;
5210}
5211
5212bool CombinerHelper::matchReassocCommBinOp(MachineInstr &MI,
5213 BuildFnTy &MatchInfo) const {
5214 // We don't check if the reassociation will break a legal addressing mode
5215 // here since pointer arithmetic is handled by G_PTR_ADD.
5216 unsigned Opc = MI.getOpcode();
5217 Register DstReg = MI.getOperand(i: 0).getReg();
5218 Register LHSReg = MI.getOperand(i: 1).getReg();
5219 Register RHSReg = MI.getOperand(i: 2).getReg();
5220
5221 if (tryReassocBinOp(Opc, DstReg, OpLHS: LHSReg, OpRHS: RHSReg, MatchInfo))
5222 return true;
5223 if (tryReassocBinOp(Opc, DstReg, OpLHS: RHSReg, OpRHS: LHSReg, MatchInfo))
5224 return true;
5225 return false;
5226}
5227
5228bool CombinerHelper::matchConstantFoldCastOp(MachineInstr &MI,
5229 APInt &MatchInfo) const {
5230 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
5231 Register SrcOp = MI.getOperand(i: 1).getReg();
5232
5233 if (auto MaybeCst = ConstantFoldCastOp(Opcode: MI.getOpcode(), DstTy, Op0: SrcOp, MRI)) {
5234 MatchInfo = *MaybeCst;
5235 return true;
5236 }
5237
5238 return false;
5239}
5240
5241bool CombinerHelper::matchConstantFoldBinOp(MachineInstr &MI,
5242 APInt &MatchInfo) const {
5243 Register Op1 = MI.getOperand(i: 1).getReg();
5244 Register Op2 = MI.getOperand(i: 2).getReg();
5245 auto MaybeCst = ConstantFoldBinOp(Opcode: MI.getOpcode(), Op1, Op2, MRI);
5246 if (!MaybeCst)
5247 return false;
5248 MatchInfo = *MaybeCst;
5249 return true;
5250}
5251
5252bool CombinerHelper::matchConstantFoldFPBinOp(MachineInstr &MI,
5253 ConstantFP *&MatchInfo) const {
5254 Register Op1 = MI.getOperand(i: 1).getReg();
5255 Register Op2 = MI.getOperand(i: 2).getReg();
5256 auto MaybeCst = ConstantFoldFPBinOp(Opcode: MI.getOpcode(), Op1, Op2, MRI);
5257 if (!MaybeCst)
5258 return false;
5259 MatchInfo =
5260 ConstantFP::get(Context&: MI.getMF()->getFunction().getContext(), V: *MaybeCst);
5261 return true;
5262}
5263
5264bool CombinerHelper::matchConstantFoldFMA(MachineInstr &MI,
5265 ConstantFP *&MatchInfo) const {
5266 assert(MI.getOpcode() == TargetOpcode::G_FMA ||
5267 MI.getOpcode() == TargetOpcode::G_FMAD);
5268 auto [_, Op1, Op2, Op3] = MI.getFirst4Regs();
5269
5270 const ConstantFP *Op3Cst = getConstantFPVRegVal(VReg: Op3, MRI);
5271 if (!Op3Cst)
5272 return false;
5273
5274 const ConstantFP *Op2Cst = getConstantFPVRegVal(VReg: Op2, MRI);
5275 if (!Op2Cst)
5276 return false;
5277
5278 const ConstantFP *Op1Cst = getConstantFPVRegVal(VReg: Op1, MRI);
5279 if (!Op1Cst)
5280 return false;
5281
5282 APFloat Op1F = Op1Cst->getValueAPF();
5283 Op1F.fusedMultiplyAdd(Multiplicand: Op2Cst->getValueAPF(), Addend: Op3Cst->getValueAPF(),
5284 RM: APFloat::rmNearestTiesToEven);
5285 MatchInfo = ConstantFP::get(Context&: MI.getMF()->getFunction().getContext(), V: Op1F);
5286 return true;
5287}
5288
5289bool CombinerHelper::matchNarrowBinopFeedingAnd(
5290 MachineInstr &MI,
5291 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
5292 // Look for a binop feeding into an AND with a mask:
5293 //
5294 // %add = G_ADD %lhs, %rhs
5295 // %and = G_AND %add, 000...11111111
5296 //
5297 // Check if it's possible to perform the binop at a narrower width and zext
5298 // back to the original width like so:
5299 //
5300 // %narrow_lhs = G_TRUNC %lhs
5301 // %narrow_rhs = G_TRUNC %rhs
5302 // %narrow_add = G_ADD %narrow_lhs, %narrow_rhs
5303 // %new_add = G_ZEXT %narrow_add
5304 // %and = G_AND %new_add, 000...11111111
5305 //
5306 // This can allow later combines to eliminate the G_AND if it turns out
5307 // that the mask is irrelevant.
5308 assert(MI.getOpcode() == TargetOpcode::G_AND);
5309 Register Dst = MI.getOperand(i: 0).getReg();
5310 Register AndLHS = MI.getOperand(i: 1).getReg();
5311 Register AndRHS = MI.getOperand(i: 2).getReg();
5312 LLT WideTy = MRI.getType(Reg: Dst);
5313
5314 // If the potential binop has more than one use, then it's possible that one
5315 // of those uses will need its full width.
5316 if (!WideTy.isScalar() || !MRI.hasOneNonDBGUse(RegNo: AndLHS))
5317 return false;
5318
5319 // Check if the LHS feeding the AND is impacted by the high bits that we're
5320 // masking out.
5321 //
5322 // e.g. for 64-bit x, y:
5323 //
5324 // add_64(x, y) & 65535 == zext(add_16(trunc(x), trunc(y))) & 65535
5325 MachineInstr *LHSInst = getDefIgnoringCopies(Reg: AndLHS, MRI);
5326 if (!LHSInst)
5327 return false;
5328 unsigned LHSOpc = LHSInst->getOpcode();
5329 switch (LHSOpc) {
5330 default:
5331 return false;
5332 case TargetOpcode::G_ADD:
5333 case TargetOpcode::G_SUB:
5334 case TargetOpcode::G_MUL:
5335 case TargetOpcode::G_AND:
5336 case TargetOpcode::G_OR:
5337 case TargetOpcode::G_XOR:
5338 break;
5339 }
5340
5341 // Find the mask on the RHS.
5342 auto Cst = getIConstantVRegValWithLookThrough(VReg: AndRHS, MRI);
5343 if (!Cst)
5344 return false;
5345 auto Mask = Cst->Value;
5346 if (!Mask.isMask())
5347 return false;
5348
5349 // No point in combining if there's nothing to truncate.
5350 unsigned NarrowWidth = Mask.countr_one();
5351 if (NarrowWidth == WideTy.getSizeInBits())
5352 return false;
5353 LLT NarrowTy = LLT::scalar(SizeInBits: NarrowWidth);
5354
5355 // Check if adding the zext + truncates could be harmful.
5356 auto &MF = *MI.getMF();
5357 const auto &TLI = getTargetLowering();
5358 LLVMContext &Ctx = MF.getFunction().getContext();
5359 if (!TLI.isTruncateFree(FromTy: WideTy, ToTy: NarrowTy, Ctx) ||
5360 !TLI.isZExtFree(FromTy: NarrowTy, ToTy: WideTy, Ctx))
5361 return false;
5362 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_TRUNC, {NarrowTy, WideTy}}) ||
5363 !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ZEXT, {WideTy, NarrowTy}}))
5364 return false;
5365 Register BinOpLHS = LHSInst->getOperand(i: 1).getReg();
5366 Register BinOpRHS = LHSInst->getOperand(i: 2).getReg();
5367 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5368 auto NarrowLHS = Builder.buildTrunc(Res: NarrowTy, Op: BinOpLHS);
5369 auto NarrowRHS = Builder.buildTrunc(Res: NarrowTy, Op: BinOpRHS);
5370 auto NarrowBinOp =
5371 Builder.buildInstr(Opc: LHSOpc, DstOps: {NarrowTy}, SrcOps: {NarrowLHS, NarrowRHS});
5372 auto Ext = Builder.buildZExt(Res: WideTy, Op: NarrowBinOp);
5373 Observer.changingInstr(MI);
5374 MI.getOperand(i: 1).setReg(Ext.getReg(Idx: 0));
5375 Observer.changedInstr(MI);
5376 };
5377 return true;
5378}
5379
5380bool CombinerHelper::matchMulOBy2(MachineInstr &MI,
5381 BuildFnTy &MatchInfo) const {
5382 unsigned Opc = MI.getOpcode();
5383 assert(Opc == TargetOpcode::G_UMULO || Opc == TargetOpcode::G_SMULO);
5384
5385 if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 2)))
5386 return false;
5387
5388 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5389 Observer.changingInstr(MI);
5390 unsigned NewOpc = Opc == TargetOpcode::G_UMULO ? TargetOpcode::G_UADDO
5391 : TargetOpcode::G_SADDO;
5392 MI.setDesc(Builder.getTII().get(Opcode: NewOpc));
5393 MI.getOperand(i: 3).setReg(MI.getOperand(i: 2).getReg());
5394 Observer.changedInstr(MI);
5395 };
5396 return true;
5397}
5398
5399bool CombinerHelper::matchMulOBy0(MachineInstr &MI,
5400 BuildFnTy &MatchInfo) const {
5401 // (G_*MULO x, 0) -> 0 + no carry out
5402 assert(MI.getOpcode() == TargetOpcode::G_UMULO ||
5403 MI.getOpcode() == TargetOpcode::G_SMULO);
5404 if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 0)))
5405 return false;
5406 Register Dst = MI.getOperand(i: 0).getReg();
5407 Register Carry = MI.getOperand(i: 1).getReg();
5408 if (!isConstantLegalOrBeforeLegalizer(Ty: MRI.getType(Reg: Dst)) ||
5409 !isConstantLegalOrBeforeLegalizer(Ty: MRI.getType(Reg: Carry)))
5410 return false;
5411 MatchInfo = [=](MachineIRBuilder &B) {
5412 B.buildConstant(Res: Dst, Val: 0);
5413 B.buildConstant(Res: Carry, Val: 0);
5414 };
5415 return true;
5416}
5417
5418bool CombinerHelper::matchAddEToAddO(MachineInstr &MI,
5419 BuildFnTy &MatchInfo) const {
5420 // (G_*ADDE x, y, 0) -> (G_*ADDO x, y)
5421 // (G_*SUBE x, y, 0) -> (G_*SUBO x, y)
5422 assert(MI.getOpcode() == TargetOpcode::G_UADDE ||
5423 MI.getOpcode() == TargetOpcode::G_SADDE ||
5424 MI.getOpcode() == TargetOpcode::G_USUBE ||
5425 MI.getOpcode() == TargetOpcode::G_SSUBE);
5426 if (!mi_match(R: MI.getOperand(i: 4).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 0)))
5427 return false;
5428 MatchInfo = [&](MachineIRBuilder &B) {
5429 unsigned NewOpcode;
5430 switch (MI.getOpcode()) {
5431 case TargetOpcode::G_UADDE:
5432 NewOpcode = TargetOpcode::G_UADDO;
5433 break;
5434 case TargetOpcode::G_SADDE:
5435 NewOpcode = TargetOpcode::G_SADDO;
5436 break;
5437 case TargetOpcode::G_USUBE:
5438 NewOpcode = TargetOpcode::G_USUBO;
5439 break;
5440 case TargetOpcode::G_SSUBE:
5441 NewOpcode = TargetOpcode::G_SSUBO;
5442 break;
5443 }
5444 Observer.changingInstr(MI);
5445 MI.setDesc(B.getTII().get(Opcode: NewOpcode));
5446 MI.removeOperand(OpNo: 4);
5447 Observer.changedInstr(MI);
5448 };
5449 return true;
5450}
5451
5452bool CombinerHelper::matchSubAddSameReg(MachineInstr &MI,
5453 BuildFnTy &MatchInfo) const {
5454 assert(MI.getOpcode() == TargetOpcode::G_SUB);
5455 Register Dst = MI.getOperand(i: 0).getReg();
5456 // (x + y) - z -> x (if y == z)
5457 // (x + y) - z -> y (if x == z)
5458 Register X, Y, Z;
5459 if (mi_match(R: Dst, MRI, P: m_GSub(L: m_GAdd(L: m_Reg(R&: X), R: m_Reg(R&: Y)), R: m_Reg(R&: Z)))) {
5460 Register ReplaceReg;
5461 int64_t CstX, CstY;
5462 if (Y == Z || (mi_match(R: Y, MRI, P: m_ICstOrSplat(Cst&: CstY)) &&
5463 mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstY))))
5464 ReplaceReg = X;
5465 else if (X == Z || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) &&
5466 mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX))))
5467 ReplaceReg = Y;
5468 if (ReplaceReg) {
5469 MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Res: Dst, Op: ReplaceReg); };
5470 return true;
5471 }
5472 }
5473
5474 // x - (y + z) -> 0 - y (if x == z)
5475 // x - (y + z) -> 0 - z (if x == y)
5476 if (mi_match(R: Dst, MRI, P: m_GSub(L: m_Reg(R&: X), R: m_GAdd(L: m_Reg(R&: Y), R: m_Reg(R&: Z))))) {
5477 Register ReplaceReg;
5478 int64_t CstX;
5479 if (X == Z || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) &&
5480 mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX))))
5481 ReplaceReg = Y;
5482 else if (X == Y || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) &&
5483 mi_match(R: Y, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX))))
5484 ReplaceReg = Z;
5485 if (ReplaceReg) {
5486 MatchInfo = [=](MachineIRBuilder &B) {
5487 auto Zero = B.buildConstant(Res: MRI.getType(Reg: Dst), Val: 0);
5488 B.buildSub(Dst, Src0: Zero, Src1: ReplaceReg);
5489 };
5490 return true;
5491 }
5492 }
5493 return false;
5494}
5495
/// Expand a G_UDIV / G_UREM whose divisor (RHS) is a constant or a vector of
/// constants into a multiply-based sequence, built with the combiner's
/// Builder. Returns the instruction defining the quotient (G_UDIV) or the
/// remainder (G_UREM). MI itself is not modified by this function.
MachineInstr *CombinerHelper::buildUDivOrURemUsingMul(MachineInstr &MI) const {
  unsigned Opcode = MI.getOpcode();
  assert(Opcode == TargetOpcode::G_UDIV || Opcode == TargetOpcode::G_UREM);
  auto &UDivorRem = cast<GenericMachineInstr>(Val&: MI);
  Register Dst = UDivorRem.getReg(Idx: 0);
  Register LHS = UDivorRem.getReg(Idx: 1);
  Register RHS = UDivorRem.getReg(Idx: 2);
  LLT Ty = MRI.getType(Reg: Dst);
  LLT ScalarTy = Ty.getScalarType();
  const unsigned EltBits = ScalarTy.getScalarSizeInBits();
  LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
  LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType();

  auto &MIB = Builder;

  // Per-element constants for the exact-division path.
  bool UseSRL = false;
  SmallVector<Register, 16> Shifts, Factors;
  auto *RHSDefInstr = cast<GenericMachineInstr>(Val: getDefIgnoringCopies(Reg: RHS, MRI));
  bool IsSplat = getIConstantSplatVal(MI: *RHSDefInstr, MRI).has_value();

  // Exact division, per divisor element: shift out the divisor's trailing
  // zero bits, then multiply by the multiplicative inverse of the remaining
  // (odd) divisor.
  auto BuildExactUDIVPattern = [&](const Constant *C) {
    // Don't recompute inverses for each splat element.
    if (IsSplat && !Factors.empty()) {
      Shifts.push_back(Elt: Shifts[0]);
      Factors.push_back(Elt: Factors[0]);
      return true;
    }

    auto *CI = cast<ConstantInt>(Val: C);
    APInt Divisor = CI->getValue();
    unsigned Shift = Divisor.countr_zero();
    if (Shift) {
      Divisor.lshrInPlace(ShiftAmt: Shift);
      UseSRL = true;
    }

    // Calculate the multiplicative inverse modulo BW.
    APInt Factor = Divisor.multiplicativeInverse();
    Shifts.push_back(Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: Shift).getReg(Idx: 0));
    Factors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Factor).getReg(Idx: 0));
    return true;
  };

  if (MI.getFlag(Flag: MachineInstr::MIFlag::IsExact)) {
    // Collect all magic values from the build vector.
    if (!matchUnaryPredicate(MRI, Reg: RHS, Match: BuildExactUDIVPattern))
      llvm_unreachable("Expected unary predicate match to succeed");

    Register Shift, Factor;
    if (Ty.isVector()) {
      Shift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: Shifts).getReg(Idx: 0);
      Factor = MIB.buildBuildVector(Res: Ty, Ops: Factors).getReg(Idx: 0);
    } else {
      Shift = Shifts[0];
      Factor = Factors[0];
    }

    Register Res = LHS;

    // A single exact G_LSHR is emitted if any element had trailing zeros;
    // elements without trailing zeros shift by 0.
    if (UseSRL)
      Res = MIB.buildLShr(Dst: Ty, Src0: Res, Src1: Shift, Flags: MachineInstr::IsExact).getReg(Idx: 0);

    return MIB.buildMul(Dst: Ty, Src0: Res, Src1: Factor);
  }

  // Known leading zeros of the dividend let the magic-number computation
  // choose cheaper constants. With no value tracking (VT), assume none.
  unsigned KnownLeadingZeros =
      VT ? VT->getKnownBits(R: LHS).countMinLeadingZeros() : 0;

  // Per-element constants for the general (non-exact) magic-number path.
  bool UseNPQ = false;
  SmallVector<Register, 16> PreShifts, PostShifts, MagicFactors, NPQFactors;
  auto BuildUDIVPattern = [&](const Constant *C) {
    auto *CI = cast<ConstantInt>(Val: C);
    const APInt &Divisor = CI->getValue();

    bool SelNPQ = false;
    APInt Magic(Divisor.getBitWidth(), 0);
    unsigned PreShift = 0, PostShift = 0;

    // Magic algorithm doesn't work for division by 1. We need to emit a select
    // at the end.
    // TODO: Use undef values for divisor of 1.
    if (!Divisor.isOne()) {

      // UnsignedDivisionByConstantInfo doesn't work correctly if leading zeros
      // in the dividend exceeds the leading zeros for the divisor.
      UnsignedDivisionByConstantInfo magics =
          UnsignedDivisionByConstantInfo::get(
              D: Divisor, LeadingZeros: std::min(a: KnownLeadingZeros, b: Divisor.countl_zero()));

      Magic = std::move(magics.Magic);

      assert(magics.PreShift < Divisor.getBitWidth() &&
             "We shouldn't generate an undefined shift!");
      assert(magics.PostShift < Divisor.getBitWidth() &&
             "We shouldn't generate an undefined shift!");
      assert((!magics.IsAdd || magics.PreShift == 0) && "Unexpected pre-shift");
      PreShift = magics.PreShift;
      PostShift = magics.PostShift;
      SelNPQ = magics.IsAdd;
    }

    PreShifts.push_back(
        Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: PreShift).getReg(Idx: 0));
    MagicFactors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Magic).getReg(Idx: 0));
    // NPQ factor is 2^(EltBits-1) for elements needing the NPQ fixup (so a
    // G_UMULH by it acts as a shift right by 1) and 0 otherwise.
    NPQFactors.push_back(
        Elt: MIB.buildConstant(Res: ScalarTy,
                          Val: SelNPQ ? APInt::getOneBitSet(numBits: EltBits, BitNo: EltBits - 1)
                                 : APInt::getZero(numBits: EltBits))
            .getReg(Idx: 0));
    PostShifts.push_back(
        Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: PostShift).getReg(Idx: 0));
    UseNPQ |= SelNPQ;
    return true;
  };

  // Collect the shifts/magic values from each element.
  bool Matched = matchUnaryPredicate(MRI, Reg: RHS, Match: BuildUDIVPattern);
  (void)Matched;
  assert(Matched && "Expected unary predicate match to succeed");

  // Vector divisors defined by G_BUILD_VECTOR get per-element constant
  // vectors; anything else is asserted to be a scalar.
  Register PreShift, PostShift, MagicFactor, NPQFactor;
  auto *RHSDef = getOpcodeDef<GBuildVector>(Reg: RHS, MRI);
  if (RHSDef) {
    PreShift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: PreShifts).getReg(Idx: 0);
    MagicFactor = MIB.buildBuildVector(Res: Ty, Ops: MagicFactors).getReg(Idx: 0);
    NPQFactor = MIB.buildBuildVector(Res: Ty, Ops: NPQFactors).getReg(Idx: 0);
    PostShift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: PostShifts).getReg(Idx: 0);
  } else {
    assert(MRI.getType(RHS).isScalar() &&
           "Non-build_vector operation should have been a scalar");
    PreShift = PreShifts[0];
    MagicFactor = MagicFactors[0];
    PostShift = PostShifts[0];
  }

  Register Q = LHS;
  Q = MIB.buildLShr(Dst: Ty, Src0: Q, Src1: PreShift).getReg(Idx: 0);

  // Multiply the numerator (operand 0) by the magic value.
  Q = MIB.buildUMulH(Dst: Ty, Src0: Q, Src1: MagicFactor).getReg(Idx: 0);

  // NPQ fixup: Q = ((LHS - Q) >> 1) + Q, applied to the elements that need it.
  if (UseNPQ) {
    Register NPQ = MIB.buildSub(Dst: Ty, Src0: LHS, Src1: Q).getReg(Idx: 0);

    // For vectors we might have a mix of non-NPQ/NPQ paths, so use
    // G_UMULH to act as a SRL-by-1 for NPQ, else multiply by zero.
    if (Ty.isVector())
      NPQ = MIB.buildUMulH(Dst: Ty, Src0: NPQ, Src1: NPQFactor).getReg(Idx: 0);
    else
      NPQ = MIB.buildLShr(Dst: Ty, Src0: NPQ, Src1: MIB.buildConstant(Res: ShiftAmtTy, Val: 1)).getReg(Idx: 0);

    Q = MIB.buildAdd(Dst: Ty, Src0: NPQ, Src1: Q).getReg(Idx: 0);
  }

  Q = MIB.buildLShr(Dst: Ty, Src0: Q, Src1: PostShift).getReg(Idx: 0);
  // The magic sequence does not handle a divisor of 1 (those elements were
  // given zero magic/shift constants above), so select LHS for them.
  auto One = MIB.buildConstant(Res: Ty, Val: 1);
  auto IsOne = MIB.buildICmp(
      Pred: CmpInst::Predicate::ICMP_EQ,
      Res: Ty.isScalar() ? LLT::scalar(SizeInBits: 1) : Ty.changeElementSize(NewEltSize: 1), Op0: RHS, Op1: One);
  auto ret = MIB.buildSelect(Res: Ty, Tst: IsOne, Op0: LHS, Op1: Q);

  // Remainder: LHS - (LHS / RHS) * RHS.
  if (Opcode == TargetOpcode::G_UREM) {
    auto Prod = MIB.buildMul(Dst: Ty, Src0: ret, Src1: RHS);
    return MIB.buildSub(Dst: Ty, Src0: LHS, Src1: Prod);
  }
  return ret;
}
5663
5664bool CombinerHelper::matchUDivOrURemByConst(MachineInstr &MI) const {
5665 unsigned Opcode = MI.getOpcode();
5666 assert(Opcode == TargetOpcode::G_UDIV || Opcode == TargetOpcode::G_UREM);
5667 Register Dst = MI.getOperand(i: 0).getReg();
5668 Register RHS = MI.getOperand(i: 2).getReg();
5669 LLT DstTy = MRI.getType(Reg: Dst);
5670
5671 auto &MF = *MI.getMF();
5672 AttributeList Attr = MF.getFunction().getAttributes();
5673 const auto &TLI = getTargetLowering();
5674 LLVMContext &Ctx = MF.getFunction().getContext();
5675 if (TLI.isIntDivCheap(VT: getApproximateEVTForLLT(Ty: DstTy, Ctx), Attr))
5676 return false;
5677
5678 // Don't do this for minsize because the instruction sequence is usually
5679 // larger.
5680 if (MF.getFunction().hasMinSize())
5681 return false;
5682
5683 if (Opcode == TargetOpcode::G_UDIV &&
5684 MI.getFlag(Flag: MachineInstr::MIFlag::IsExact)) {
5685 return matchUnaryPredicate(
5686 MRI, Reg: RHS, Match: [](const Constant *C) { return C && !C->isNullValue(); });
5687 }
5688
5689 auto *RHSDef = MRI.getVRegDef(Reg: RHS);
5690 if (!isConstantOrConstantVector(MI&: *RHSDef, MRI))
5691 return false;
5692
5693 // Don't do this if the types are not going to be legal.
5694 if (LI) {
5695 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_MUL, {DstTy, DstTy}}))
5696 return false;
5697 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_UMULH, {DstTy}}))
5698 return false;
5699 if (!isLegalOrBeforeLegalizer(
5700 Query: {TargetOpcode::G_ICMP,
5701 {DstTy.isVector() ? DstTy.changeElementSize(NewEltSize: 1) : LLT::scalar(SizeInBits: 1),
5702 DstTy}}))
5703 return false;
5704 if (Opcode == TargetOpcode::G_UREM &&
5705 !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SUB, {DstTy, DstTy}}))
5706 return false;
5707 }
5708
5709 return matchUnaryPredicate(
5710 MRI, Reg: RHS, Match: [](const Constant *C) { return C && !C->isNullValue(); });
5711}
5712
5713void CombinerHelper::applyUDivOrURemByConst(MachineInstr &MI) const {
5714 auto *NewMI = buildUDivOrURemUsingMul(MI);
5715 replaceSingleDefInstWithReg(MI, Replacement: NewMI->getOperand(i: 0).getReg());
5716}
5717
5718bool CombinerHelper::matchSDivOrSRemByConst(MachineInstr &MI) const {
5719 unsigned Opcode = MI.getOpcode();
5720 assert(Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_SREM);
5721 Register Dst = MI.getOperand(i: 0).getReg();
5722 Register RHS = MI.getOperand(i: 2).getReg();
5723 LLT DstTy = MRI.getType(Reg: Dst);
5724 auto SizeInBits = DstTy.getScalarSizeInBits();
5725 LLT WideTy = DstTy.changeElementSize(NewEltSize: SizeInBits * 2);
5726
5727 auto &MF = *MI.getMF();
5728 AttributeList Attr = MF.getFunction().getAttributes();
5729 const auto &TLI = getTargetLowering();
5730 LLVMContext &Ctx = MF.getFunction().getContext();
5731 if (TLI.isIntDivCheap(VT: getApproximateEVTForLLT(Ty: DstTy, Ctx), Attr))
5732 return false;
5733
5734 // Don't do this for minsize because the instruction sequence is usually
5735 // larger.
5736 if (MF.getFunction().hasMinSize())
5737 return false;
5738
5739 // If the sdiv has an 'exact' flag we can use a simpler lowering.
5740 if (Opcode == TargetOpcode::G_SDIV &&
5741 MI.getFlag(Flag: MachineInstr::MIFlag::IsExact)) {
5742 return matchUnaryPredicate(
5743 MRI, Reg: RHS, Match: [](const Constant *C) { return C && !C->isNullValue(); });
5744 }
5745
5746 auto *RHSDef = MRI.getVRegDef(Reg: RHS);
5747 if (!isConstantOrConstantVector(MI&: *RHSDef, MRI))
5748 return false;
5749
5750 // Don't do this if the types are not going to be legal.
5751 if (LI) {
5752 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_MUL, {DstTy, DstTy}}))
5753 return false;
5754 if (!isLegal(Query: {TargetOpcode::G_SMULH, {DstTy}}) &&
5755 !isLegalOrHasWidenScalar(Query: {TargetOpcode::G_MUL, {WideTy, WideTy}}))
5756 return false;
5757 if (Opcode == TargetOpcode::G_SREM &&
5758 !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SUB, {DstTy, DstTy}}))
5759 return false;
5760 }
5761
5762 return matchUnaryPredicate(
5763 MRI, Reg: RHS, Match: [](const Constant *C) { return C && !C->isNullValue(); });
5764}
5765
5766void CombinerHelper::applySDivOrSRemByConst(MachineInstr &MI) const {
5767 auto *NewMI = buildSDivOrSRemUsingMul(MI);
5768 replaceSingleDefInstWithReg(MI, Replacement: NewMI->getOperand(i: 0).getReg());
5769}
5770
5771MachineInstr *CombinerHelper::buildSDivOrSRemUsingMul(MachineInstr &MI) const {
5772 unsigned Opcode = MI.getOpcode();
5773 assert(MI.getOpcode() == TargetOpcode::G_SDIV ||
5774 Opcode == TargetOpcode::G_SREM);
5775 auto &SDivorRem = cast<GenericMachineInstr>(Val&: MI);
5776 Register Dst = SDivorRem.getReg(Idx: 0);
5777 Register LHS = SDivorRem.getReg(Idx: 1);
5778 Register RHS = SDivorRem.getReg(Idx: 2);
5779 LLT Ty = MRI.getType(Reg: Dst);
5780 LLT ScalarTy = Ty.getScalarType();
5781 const unsigned EltBits = ScalarTy.getScalarSizeInBits();
5782 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
5783 LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType();
5784 auto &MIB = Builder;
5785
5786 bool UseSRA = false;
5787 SmallVector<Register, 16> ExactShifts, ExactFactors;
5788
5789 auto *RHSDefInstr = cast<GenericMachineInstr>(Val: getDefIgnoringCopies(Reg: RHS, MRI));
5790 bool IsSplat = getIConstantSplatVal(MI: *RHSDefInstr, MRI).has_value();
5791
5792 auto BuildExactSDIVPattern = [&](const Constant *C) {
5793 // Don't recompute inverses for each splat element.
5794 if (IsSplat && !ExactFactors.empty()) {
5795 ExactShifts.push_back(Elt: ExactShifts[0]);
5796 ExactFactors.push_back(Elt: ExactFactors[0]);
5797 return true;
5798 }
5799
5800 auto *CI = cast<ConstantInt>(Val: C);
5801 APInt Divisor = CI->getValue();
5802 unsigned Shift = Divisor.countr_zero();
5803 if (Shift) {
5804 Divisor.ashrInPlace(ShiftAmt: Shift);
5805 UseSRA = true;
5806 }
5807
5808 // Calculate the multiplicative inverse modulo BW.
5809 // 2^W requires W + 1 bits, so we have to extend and then truncate.
5810 APInt Factor = Divisor.multiplicativeInverse();
5811 ExactShifts.push_back(Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: Shift).getReg(Idx: 0));
5812 ExactFactors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Factor).getReg(Idx: 0));
5813 return true;
5814 };
5815
5816 if (MI.getFlag(Flag: MachineInstr::MIFlag::IsExact)) {
5817 // Collect all magic values from the build vector.
5818 bool Matched = matchUnaryPredicate(MRI, Reg: RHS, Match: BuildExactSDIVPattern);
5819 (void)Matched;
5820 assert(Matched && "Expected unary predicate match to succeed");
5821
5822 Register Shift, Factor;
5823 if (Ty.isVector()) {
5824 Shift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: ExactShifts).getReg(Idx: 0);
5825 Factor = MIB.buildBuildVector(Res: Ty, Ops: ExactFactors).getReg(Idx: 0);
5826 } else {
5827 Shift = ExactShifts[0];
5828 Factor = ExactFactors[0];
5829 }
5830
5831 Register Res = LHS;
5832
5833 if (UseSRA)
5834 Res = MIB.buildAShr(Dst: Ty, Src0: Res, Src1: Shift, Flags: MachineInstr::IsExact).getReg(Idx: 0);
5835
5836 return MIB.buildMul(Dst: Ty, Src0: Res, Src1: Factor);
5837 }
5838
5839 SmallVector<Register, 16> MagicFactors, Factors, Shifts, ShiftMasks;
5840
5841 auto BuildSDIVPattern = [&](const Constant *C) {
5842 auto *CI = cast<ConstantInt>(Val: C);
5843 const APInt &Divisor = CI->getValue();
5844
5845 SignedDivisionByConstantInfo Magics =
5846 SignedDivisionByConstantInfo::get(D: Divisor);
5847 int NumeratorFactor = 0;
5848 int ShiftMask = -1;
5849
5850 if (Divisor.isOne() || Divisor.isAllOnes()) {
5851 // If d is +1/-1, we just multiply the numerator by +1/-1.
5852 NumeratorFactor = Divisor.getSExtValue();
5853 Magics.Magic = 0;
5854 Magics.ShiftAmount = 0;
5855 ShiftMask = 0;
5856 } else if (Divisor.isStrictlyPositive() && Magics.Magic.isNegative()) {
5857 // If d > 0 and m < 0, add the numerator.
5858 NumeratorFactor = 1;
5859 } else if (Divisor.isNegative() && Magics.Magic.isStrictlyPositive()) {
5860 // If d < 0 and m > 0, subtract the numerator.
5861 NumeratorFactor = -1;
5862 }
5863
5864 MagicFactors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Magics.Magic).getReg(Idx: 0));
5865 Factors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: NumeratorFactor).getReg(Idx: 0));
5866 Shifts.push_back(
5867 Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: Magics.ShiftAmount).getReg(Idx: 0));
5868 ShiftMasks.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: ShiftMask).getReg(Idx: 0));
5869
5870 return true;
5871 };
5872
5873 // Collect the shifts/magic values from each element.
5874 bool Matched = matchUnaryPredicate(MRI, Reg: RHS, Match: BuildSDIVPattern);
5875 (void)Matched;
5876 assert(Matched && "Expected unary predicate match to succeed");
5877
5878 Register MagicFactor, Factor, Shift, ShiftMask;
5879 auto *RHSDef = getOpcodeDef<GBuildVector>(Reg: RHS, MRI);
5880 if (RHSDef) {
5881 MagicFactor = MIB.buildBuildVector(Res: Ty, Ops: MagicFactors).getReg(Idx: 0);
5882 Factor = MIB.buildBuildVector(Res: Ty, Ops: Factors).getReg(Idx: 0);
5883 Shift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: Shifts).getReg(Idx: 0);
5884 ShiftMask = MIB.buildBuildVector(Res: Ty, Ops: ShiftMasks).getReg(Idx: 0);
5885 } else {
5886 assert(MRI.getType(RHS).isScalar() &&
5887 "Non-build_vector operation should have been a scalar");
5888 MagicFactor = MagicFactors[0];
5889 Factor = Factors[0];
5890 Shift = Shifts[0];
5891 ShiftMask = ShiftMasks[0];
5892 }
5893
5894 Register Q = LHS;
5895 Q = MIB.buildSMulH(Dst: Ty, Src0: LHS, Src1: MagicFactor).getReg(Idx: 0);
5896
5897 // (Optionally) Add/subtract the numerator using Factor.
5898 Factor = MIB.buildMul(Dst: Ty, Src0: LHS, Src1: Factor).getReg(Idx: 0);
5899 Q = MIB.buildAdd(Dst: Ty, Src0: Q, Src1: Factor).getReg(Idx: 0);
5900
5901 // Shift right algebraic by shift value.
5902 Q = MIB.buildAShr(Dst: Ty, Src0: Q, Src1: Shift).getReg(Idx: 0);
5903
5904 // Extract the sign bit, mask it and add it to the quotient.
5905 auto SignShift = MIB.buildConstant(Res: ShiftAmtTy, Val: EltBits - 1);
5906 auto T = MIB.buildLShr(Dst: Ty, Src0: Q, Src1: SignShift);
5907 T = MIB.buildAnd(Dst: Ty, Src0: T, Src1: ShiftMask);
5908 auto ret = MIB.buildAdd(Dst: Ty, Src0: Q, Src1: T);
5909
5910 if (Opcode == TargetOpcode::G_SREM) {
5911 auto Prod = MIB.buildMul(Dst: Ty, Src0: ret, Src1: RHS);
5912 return MIB.buildSub(Dst: Ty, Src0: LHS, Src1: Prod);
5913 }
5914 return ret;
5915}
5916
5917bool CombinerHelper::matchDivByPow2(MachineInstr &MI, bool IsSigned) const {
5918 assert((MI.getOpcode() == TargetOpcode::G_SDIV ||
5919 MI.getOpcode() == TargetOpcode::G_UDIV) &&
5920 "Expected SDIV or UDIV");
5921 auto &Div = cast<GenericMachineInstr>(Val&: MI);
5922 Register RHS = Div.getReg(Idx: 2);
5923 auto MatchPow2 = [&](const Constant *C) {
5924 auto *CI = dyn_cast<ConstantInt>(Val: C);
5925 return CI && (CI->getValue().isPowerOf2() ||
5926 (IsSigned && CI->getValue().isNegatedPowerOf2()));
5927 };
5928 return matchUnaryPredicate(MRI, Reg: RHS, Match: MatchPow2, /*AllowUndefs=*/false);
5929}
5930
5931void CombinerHelper::applySDivByPow2(MachineInstr &MI) const {
5932 assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV");
5933 auto &SDiv = cast<GenericMachineInstr>(Val&: MI);
5934 Register Dst = SDiv.getReg(Idx: 0);
5935 Register LHS = SDiv.getReg(Idx: 1);
5936 Register RHS = SDiv.getReg(Idx: 2);
5937 LLT Ty = MRI.getType(Reg: Dst);
5938 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
5939 LLT CCVT =
5940 Ty.isVector() ? LLT::vector(EC: Ty.getElementCount(), ScalarSizeInBits: 1) : LLT::scalar(SizeInBits: 1);
5941
5942 // Effectively we want to lower G_SDIV %lhs, %rhs, where %rhs is a power of 2,
5943 // to the following version:
5944 //
5945 // %c1 = G_CTTZ %rhs
5946 // %inexact = G_SUB $bitwidth, %c1
5947 // %sign = %G_ASHR %lhs, $(bitwidth - 1)
5948 // %lshr = G_LSHR %sign, %inexact
5949 // %add = G_ADD %lhs, %lshr
5950 // %ashr = G_ASHR %add, %c1
5951 // %ashr = G_SELECT, %isoneorallones, %lhs, %ashr
5952 // %zero = G_CONSTANT $0
5953 // %neg = G_NEG %ashr
5954 // %isneg = G_ICMP SLT %rhs, %zero
5955 // %res = G_SELECT %isneg, %neg, %ashr
5956
5957 unsigned BitWidth = Ty.getScalarSizeInBits();
5958 auto Zero = Builder.buildConstant(Res: Ty, Val: 0);
5959
5960 auto Bits = Builder.buildConstant(Res: ShiftAmtTy, Val: BitWidth);
5961 auto C1 = Builder.buildCTTZ(Dst: ShiftAmtTy, Src0: RHS);
5962 auto Inexact = Builder.buildSub(Dst: ShiftAmtTy, Src0: Bits, Src1: C1);
5963 // Splat the sign bit into the register
5964 auto Sign = Builder.buildAShr(
5965 Dst: Ty, Src0: LHS, Src1: Builder.buildConstant(Res: ShiftAmtTy, Val: BitWidth - 1));
5966
5967 // Add (LHS < 0) ? abs2 - 1 : 0;
5968 auto LSrl = Builder.buildLShr(Dst: Ty, Src0: Sign, Src1: Inexact);
5969 auto Add = Builder.buildAdd(Dst: Ty, Src0: LHS, Src1: LSrl);
5970 auto AShr = Builder.buildAShr(Dst: Ty, Src0: Add, Src1: C1);
5971
5972 // Special case: (sdiv X, 1) -> X
5973 // Special Case: (sdiv X, -1) -> 0-X
5974 auto One = Builder.buildConstant(Res: Ty, Val: 1);
5975 auto MinusOne = Builder.buildConstant(Res: Ty, Val: -1);
5976 auto IsOne = Builder.buildICmp(Pred: CmpInst::Predicate::ICMP_EQ, Res: CCVT, Op0: RHS, Op1: One);
5977 auto IsMinusOne =
5978 Builder.buildICmp(Pred: CmpInst::Predicate::ICMP_EQ, Res: CCVT, Op0: RHS, Op1: MinusOne);
5979 auto IsOneOrMinusOne = Builder.buildOr(Dst: CCVT, Src0: IsOne, Src1: IsMinusOne);
5980 AShr = Builder.buildSelect(Res: Ty, Tst: IsOneOrMinusOne, Op0: LHS, Op1: AShr);
5981
5982 // If divided by a positive value, we're done. Otherwise, the result must be
5983 // negated.
5984 auto Neg = Builder.buildNeg(Dst: Ty, Src0: AShr);
5985 auto IsNeg = Builder.buildICmp(Pred: CmpInst::Predicate::ICMP_SLT, Res: CCVT, Op0: RHS, Op1: Zero);
5986 Builder.buildSelect(Res: MI.getOperand(i: 0).getReg(), Tst: IsNeg, Op0: Neg, Op1: AShr);
5987 MI.eraseFromParent();
5988}
5989
5990void CombinerHelper::applyUDivByPow2(MachineInstr &MI) const {
5991 assert(MI.getOpcode() == TargetOpcode::G_UDIV && "Expected UDIV");
5992 auto &UDiv = cast<GenericMachineInstr>(Val&: MI);
5993 Register Dst = UDiv.getReg(Idx: 0);
5994 Register LHS = UDiv.getReg(Idx: 1);
5995 Register RHS = UDiv.getReg(Idx: 2);
5996 LLT Ty = MRI.getType(Reg: Dst);
5997 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
5998
5999 auto C1 = Builder.buildCTTZ(Dst: ShiftAmtTy, Src0: RHS);
6000 Builder.buildLShr(Dst: MI.getOperand(i: 0).getReg(), Src0: LHS, Src1: C1);
6001 MI.eraseFromParent();
6002}
6003
6004bool CombinerHelper::matchUMulHToLShr(MachineInstr &MI) const {
6005 assert(MI.getOpcode() == TargetOpcode::G_UMULH);
6006 Register RHS = MI.getOperand(i: 2).getReg();
6007 Register Dst = MI.getOperand(i: 0).getReg();
6008 LLT Ty = MRI.getType(Reg: Dst);
6009 LLT RHSTy = MRI.getType(Reg: RHS);
6010 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
6011 auto MatchPow2ExceptOne = [&](const Constant *C) {
6012 if (auto *CI = dyn_cast<ConstantInt>(Val: C))
6013 return CI->getValue().isPowerOf2() && !CI->getValue().isOne();
6014 return false;
6015 };
6016 if (!matchUnaryPredicate(MRI, Reg: RHS, Match: MatchPow2ExceptOne, AllowUndefs: false))
6017 return false;
6018 // We need to check both G_LSHR and G_CTLZ because the combine uses G_CTLZ to
6019 // get log base 2, and it is not always legal for on a target.
6020 return isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_LSHR, {Ty, ShiftAmtTy}}) &&
6021 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_CTLZ, {RHSTy, RHSTy}});
6022}
6023
6024void CombinerHelper::applyUMulHToLShr(MachineInstr &MI) const {
6025 Register LHS = MI.getOperand(i: 1).getReg();
6026 Register RHS = MI.getOperand(i: 2).getReg();
6027 Register Dst = MI.getOperand(i: 0).getReg();
6028 LLT Ty = MRI.getType(Reg: Dst);
6029 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
6030 unsigned NumEltBits = Ty.getScalarSizeInBits();
6031
6032 auto LogBase2 = buildLogBase2(V: RHS, MIB&: Builder);
6033 auto ShiftAmt =
6034 Builder.buildSub(Dst: Ty, Src0: Builder.buildConstant(Res: Ty, Val: NumEltBits), Src1: LogBase2);
6035 auto Trunc = Builder.buildZExtOrTrunc(Res: ShiftAmtTy, Op: ShiftAmt);
6036 Builder.buildLShr(Dst, Src0: LHS, Src1: Trunc);
6037 MI.eraseFromParent();
6038}
6039
6040bool CombinerHelper::matchTruncSSatS(MachineInstr &MI,
6041 Register &MatchInfo) const {
6042 Register Dst = MI.getOperand(i: 0).getReg();
6043 Register Src = MI.getOperand(i: 1).getReg();
6044 LLT DstTy = MRI.getType(Reg: Dst);
6045 LLT SrcTy = MRI.getType(Reg: Src);
6046 unsigned NumDstBits = DstTy.getScalarSizeInBits();
6047 unsigned NumSrcBits = SrcTy.getScalarSizeInBits();
6048 assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
6049
6050 if (!LI || !isLegal(Query: {TargetOpcode::G_TRUNC_SSAT_S, {DstTy, SrcTy}}))
6051 return false;
6052
6053 APInt SignedMax = APInt::getSignedMaxValue(numBits: NumDstBits).sext(width: NumSrcBits);
6054 APInt SignedMin = APInt::getSignedMinValue(numBits: NumDstBits).sext(width: NumSrcBits);
6055 return mi_match(R: Src, MRI,
6056 P: m_GSMin(L: m_GSMax(L: m_Reg(R&: MatchInfo),
6057 R: m_SpecificICstOrSplat(RequestedValue: SignedMin)),
6058 R: m_SpecificICstOrSplat(RequestedValue: SignedMax))) ||
6059 mi_match(R: Src, MRI,
6060 P: m_GSMax(L: m_GSMin(L: m_Reg(R&: MatchInfo),
6061 R: m_SpecificICstOrSplat(RequestedValue: SignedMax)),
6062 R: m_SpecificICstOrSplat(RequestedValue: SignedMin)));
6063}
6064
6065void CombinerHelper::applyTruncSSatS(MachineInstr &MI,
6066 Register &MatchInfo) const {
6067 Register Dst = MI.getOperand(i: 0).getReg();
6068 Builder.buildTruncSSatS(Res: Dst, Op: MatchInfo);
6069 MI.eraseFromParent();
6070}
6071
6072bool CombinerHelper::matchTruncSSatU(MachineInstr &MI,
6073 Register &MatchInfo) const {
6074 Register Dst = MI.getOperand(i: 0).getReg();
6075 Register Src = MI.getOperand(i: 1).getReg();
6076 LLT DstTy = MRI.getType(Reg: Dst);
6077 LLT SrcTy = MRI.getType(Reg: Src);
6078 unsigned NumDstBits = DstTy.getScalarSizeInBits();
6079 unsigned NumSrcBits = SrcTy.getScalarSizeInBits();
6080 assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
6081
6082 if (!LI || !isLegal(Query: {TargetOpcode::G_TRUNC_SSAT_U, {DstTy, SrcTy}}))
6083 return false;
6084 APInt UnsignedMax = APInt::getMaxValue(numBits: NumDstBits).zext(width: NumSrcBits);
6085 return mi_match(R: Src, MRI,
6086 P: m_GSMin(L: m_GSMax(L: m_Reg(R&: MatchInfo), R: m_SpecificICstOrSplat(RequestedValue: 0)),
6087 R: m_SpecificICstOrSplat(RequestedValue: UnsignedMax))) ||
6088 mi_match(R: Src, MRI,
6089 P: m_GSMax(L: m_GSMin(L: m_Reg(R&: MatchInfo),
6090 R: m_SpecificICstOrSplat(RequestedValue: UnsignedMax)),
6091 R: m_SpecificICstOrSplat(RequestedValue: 0))) ||
6092 mi_match(R: Src, MRI,
6093 P: m_GUMin(L: m_GSMax(L: m_Reg(R&: MatchInfo), R: m_SpecificICstOrSplat(RequestedValue: 0)),
6094 R: m_SpecificICstOrSplat(RequestedValue: UnsignedMax)));
6095}
6096
6097void CombinerHelper::applyTruncSSatU(MachineInstr &MI,
6098 Register &MatchInfo) const {
6099 Register Dst = MI.getOperand(i: 0).getReg();
6100 Builder.buildTruncSSatU(Res: Dst, Op: MatchInfo);
6101 MI.eraseFromParent();
6102}
6103
6104bool CombinerHelper::matchTruncUSatU(MachineInstr &MI,
6105 MachineInstr &MinMI) const {
6106 Register Min = MinMI.getOperand(i: 2).getReg();
6107 Register Val = MinMI.getOperand(i: 1).getReg();
6108 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6109 LLT SrcTy = MRI.getType(Reg: Val);
6110 unsigned NumDstBits = DstTy.getScalarSizeInBits();
6111 unsigned NumSrcBits = SrcTy.getScalarSizeInBits();
6112 assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
6113
6114 if (!LI || !isLegal(Query: {TargetOpcode::G_TRUNC_SSAT_U, {DstTy, SrcTy}}))
6115 return false;
6116 APInt UnsignedMax = APInt::getMaxValue(numBits: NumDstBits).zext(width: NumSrcBits);
6117 return mi_match(R: Min, MRI, P: m_SpecificICstOrSplat(RequestedValue: UnsignedMax)) &&
6118 !mi_match(R: Val, MRI, P: m_GSMax(L: m_Reg(), R: m_Reg()));
6119}
6120
6121bool CombinerHelper::matchTruncUSatUToFPTOUISat(MachineInstr &MI,
6122 MachineInstr &SrcMI) const {
6123 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6124 LLT SrcTy = MRI.getType(Reg: SrcMI.getOperand(i: 1).getReg());
6125
6126 return LI &&
6127 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FPTOUI_SAT, {DstTy, SrcTy}});
6128}
6129
6130bool CombinerHelper::matchRedundantNegOperands(MachineInstr &MI,
6131 BuildFnTy &MatchInfo) const {
6132 unsigned Opc = MI.getOpcode();
6133 assert(Opc == TargetOpcode::G_FADD || Opc == TargetOpcode::G_FSUB ||
6134 Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV ||
6135 Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA);
6136
6137 Register Dst = MI.getOperand(i: 0).getReg();
6138 Register X = MI.getOperand(i: 1).getReg();
6139 Register Y = MI.getOperand(i: 2).getReg();
6140 LLT Type = MRI.getType(Reg: Dst);
6141
6142 // fold (fadd x, fneg(y)) -> (fsub x, y)
6143 // fold (fadd fneg(y), x) -> (fsub x, y)
6144 // G_ADD is commutative so both cases are checked by m_GFAdd
6145 if (mi_match(R: Dst, MRI, P: m_GFAdd(L: m_Reg(R&: X), R: m_GFNeg(Src: m_Reg(R&: Y)))) &&
6146 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FSUB, {Type}})) {
6147 Opc = TargetOpcode::G_FSUB;
6148 }
6149 /// fold (fsub x, fneg(y)) -> (fadd x, y)
6150 else if (mi_match(R: Dst, MRI, P: m_GFSub(L: m_Reg(R&: X), R: m_GFNeg(Src: m_Reg(R&: Y)))) &&
6151 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FADD, {Type}})) {
6152 Opc = TargetOpcode::G_FADD;
6153 }
6154 // fold (fmul fneg(x), fneg(y)) -> (fmul x, y)
6155 // fold (fdiv fneg(x), fneg(y)) -> (fdiv x, y)
6156 // fold (fmad fneg(x), fneg(y), z) -> (fmad x, y, z)
6157 // fold (fma fneg(x), fneg(y), z) -> (fma x, y, z)
6158 else if ((Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV ||
6159 Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA) &&
6160 mi_match(R: X, MRI, P: m_GFNeg(Src: m_Reg(R&: X))) &&
6161 mi_match(R: Y, MRI, P: m_GFNeg(Src: m_Reg(R&: Y)))) {
6162 // no opcode change
6163 } else
6164 return false;
6165
6166 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6167 Observer.changingInstr(MI);
6168 MI.setDesc(B.getTII().get(Opcode: Opc));
6169 MI.getOperand(i: 1).setReg(X);
6170 MI.getOperand(i: 2).setReg(Y);
6171 Observer.changedInstr(MI);
6172 };
6173 return true;
6174}
6175
6176bool CombinerHelper::matchFsubToFneg(MachineInstr &MI,
6177 Register &MatchInfo) const {
6178 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
6179
6180 Register LHS = MI.getOperand(i: 1).getReg();
6181 MatchInfo = MI.getOperand(i: 2).getReg();
6182 LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6183
6184 const auto LHSCst = Ty.isVector()
6185 ? getFConstantSplat(VReg: LHS, MRI, /* allowUndef */ AllowUndef: true)
6186 : getFConstantVRegValWithLookThrough(VReg: LHS, MRI);
6187 if (!LHSCst)
6188 return false;
6189
6190 // -0.0 is always allowed
6191 if (LHSCst->Value.isNegZero())
6192 return true;
6193
6194 // +0.0 is only allowed if nsz is set.
6195 if (LHSCst->Value.isPosZero())
6196 return MI.getFlag(Flag: MachineInstr::FmNsz);
6197
6198 return false;
6199}
6200
6201void CombinerHelper::applyFsubToFneg(MachineInstr &MI,
6202 Register &MatchInfo) const {
6203 Register Dst = MI.getOperand(i: 0).getReg();
6204 Builder.buildFNeg(
6205 Dst, Src0: Builder.buildFCanonicalize(Dst: MRI.getType(Reg: Dst), Src0: MatchInfo).getReg(Idx: 0));
6206 eraseInst(MI);
6207}
6208
6209/// Checks if \p MI is TargetOpcode::G_FMUL and contractable either
6210/// due to global flags or MachineInstr flags.
6211static bool isContractableFMul(MachineInstr &MI, bool AllowFusionGlobally) {
6212 if (MI.getOpcode() != TargetOpcode::G_FMUL)
6213 return false;
6214 return AllowFusionGlobally || MI.getFlag(Flag: MachineInstr::MIFlag::FmContract);
6215}
6216
6217static bool hasMoreUses(const MachineInstr &MI0, const MachineInstr &MI1,
6218 const MachineRegisterInfo &MRI) {
6219 return std::distance(first: MRI.use_instr_nodbg_begin(RegNo: MI0.getOperand(i: 0).getReg()),
6220 last: MRI.use_instr_nodbg_end()) >
6221 std::distance(first: MRI.use_instr_nodbg_begin(RegNo: MI1.getOperand(i: 0).getReg()),
6222 last: MRI.use_instr_nodbg_end());
6223}
6224
6225bool CombinerHelper::canCombineFMadOrFMA(MachineInstr &MI,
6226 bool &AllowFusionGlobally,
6227 bool &HasFMAD, bool &Aggressive,
6228 bool CanReassociate) const {
6229
6230 auto *MF = MI.getMF();
6231 const auto &TLI = *MF->getSubtarget().getTargetLowering();
6232 const TargetOptions &Options = MF->getTarget().Options;
6233 LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6234
6235 if (CanReassociate && !MI.getFlag(Flag: MachineInstr::MIFlag::FmReassoc))
6236 return false;
6237
6238 // Floating-point multiply-add with intermediate rounding.
6239 HasFMAD = (!isPreLegalize() && TLI.isFMADLegal(MI, Ty: DstType));
6240 // Floating-point multiply-add without intermediate rounding.
6241 bool HasFMA = TLI.isFMAFasterThanFMulAndFAdd(MF: *MF, DstType) &&
6242 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FMA, {DstType}});
6243 // No valid opcode, do not combine.
6244 if (!HasFMAD && !HasFMA)
6245 return false;
6246
6247 AllowFusionGlobally = Options.AllowFPOpFusion == FPOpFusion::Fast || HasFMAD;
6248 // If the addition is not contractable, do not combine.
6249 if (!AllowFusionGlobally && !MI.getFlag(Flag: MachineInstr::MIFlag::FmContract))
6250 return false;
6251
6252 Aggressive = TLI.enableAggressiveFMAFusion(Ty: DstType);
6253 return true;
6254}
6255
6256bool CombinerHelper::matchCombineFAddFMulToFMadOrFMA(
6257 MachineInstr &MI,
6258 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6259 assert(MI.getOpcode() == TargetOpcode::G_FADD);
6260
6261 bool AllowFusionGlobally, HasFMAD, Aggressive;
6262 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6263 return false;
6264
6265 Register Op1 = MI.getOperand(i: 1).getReg();
6266 Register Op2 = MI.getOperand(i: 2).getReg();
6267 DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1};
6268 DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2};
6269 unsigned PreferredFusedOpcode =
6270 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6271
6272 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
6273 // prefer to fold the multiply with fewer uses.
6274 if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6275 isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) {
6276 if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI))
6277 std::swap(a&: LHS, b&: RHS);
6278 }
6279
6280 // fold (fadd (fmul x, y), z) -> (fma x, y, z)
6281 if (isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6282 (Aggressive || MRI.hasOneNonDBGUse(RegNo: LHS.Reg))) {
6283 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6284 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6285 SrcOps: {LHS.MI->getOperand(i: 1).getReg(),
6286 LHS.MI->getOperand(i: 2).getReg(), RHS.Reg});
6287 };
6288 return true;
6289 }
6290
6291 // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
6292 if (isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally) &&
6293 (Aggressive || MRI.hasOneNonDBGUse(RegNo: RHS.Reg))) {
6294 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6295 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6296 SrcOps: {RHS.MI->getOperand(i: 1).getReg(),
6297 RHS.MI->getOperand(i: 2).getReg(), LHS.Reg});
6298 };
6299 return true;
6300 }
6301
6302 return false;
6303}
6304
6305bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMA(
6306 MachineInstr &MI,
6307 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6308 assert(MI.getOpcode() == TargetOpcode::G_FADD);
6309
6310 bool AllowFusionGlobally, HasFMAD, Aggressive;
6311 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6312 return false;
6313
6314 const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
6315 Register Op1 = MI.getOperand(i: 1).getReg();
6316 Register Op2 = MI.getOperand(i: 2).getReg();
6317 DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1};
6318 DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2};
6319 LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6320
6321 unsigned PreferredFusedOpcode =
6322 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6323
6324 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
6325 // prefer to fold the multiply with fewer uses.
6326 if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6327 isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) {
6328 if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI))
6329 std::swap(a&: LHS, b&: RHS);
6330 }
6331
6332 // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
6333 MachineInstr *FpExtSrc;
6334 if (mi_match(R: LHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FpExtSrc))) &&
6335 isContractableFMul(MI&: *FpExtSrc, AllowFusionGlobally) &&
6336 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
6337 SrcTy: MRI.getType(Reg: FpExtSrc->getOperand(i: 1).getReg()))) {
6338 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6339 auto FpExtX = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 1).getReg());
6340 auto FpExtY = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 2).getReg());
6341 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6342 SrcOps: {FpExtX.getReg(Idx: 0), FpExtY.getReg(Idx: 0), RHS.Reg});
6343 };
6344 return true;
6345 }
6346
6347 // fold (fadd z, (fpext (fmul x, y))) -> (fma (fpext x), (fpext y), z)
6348 // Note: Commutes FADD operands.
6349 if (mi_match(R: RHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FpExtSrc))) &&
6350 isContractableFMul(MI&: *FpExtSrc, AllowFusionGlobally) &&
6351 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
6352 SrcTy: MRI.getType(Reg: FpExtSrc->getOperand(i: 1).getReg()))) {
6353 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6354 auto FpExtX = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 1).getReg());
6355 auto FpExtY = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 2).getReg());
6356 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6357 SrcOps: {FpExtX.getReg(Idx: 0), FpExtY.getReg(Idx: 0), LHS.Reg});
6358 };
6359 return true;
6360 }
6361
6362 return false;
6363}
6364
6365bool CombinerHelper::matchCombineFAddFMAFMulToFMadOrFMA(
6366 MachineInstr &MI,
6367 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6368 assert(MI.getOpcode() == TargetOpcode::G_FADD);
6369
6370 bool AllowFusionGlobally, HasFMAD, Aggressive;
6371 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive, CanReassociate: true))
6372 return false;
6373
6374 Register Op1 = MI.getOperand(i: 1).getReg();
6375 Register Op2 = MI.getOperand(i: 2).getReg();
6376 DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1};
6377 DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2};
6378 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6379
6380 unsigned PreferredFusedOpcode =
6381 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6382
6383 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
6384 // prefer to fold the multiply with fewer uses.
6385 if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6386 isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) {
6387 if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI))
6388 std::swap(a&: LHS, b&: RHS);
6389 }
6390
6391 MachineInstr *FMA = nullptr;
6392 Register Z;
6393 // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y, (fma u, v, z))
6394 if (LHS.MI->getOpcode() == PreferredFusedOpcode &&
6395 (MRI.getVRegDef(Reg: LHS.MI->getOperand(i: 3).getReg())->getOpcode() ==
6396 TargetOpcode::G_FMUL) &&
6397 MRI.hasOneNonDBGUse(RegNo: LHS.MI->getOperand(i: 0).getReg()) &&
6398 MRI.hasOneNonDBGUse(RegNo: LHS.MI->getOperand(i: 3).getReg())) {
6399 FMA = LHS.MI;
6400 Z = RHS.Reg;
6401 }
6402 // fold (fadd z, (fma x, y, (fmul u, v))) -> (fma x, y, (fma u, v, z))
6403 else if (RHS.MI->getOpcode() == PreferredFusedOpcode &&
6404 (MRI.getVRegDef(Reg: RHS.MI->getOperand(i: 3).getReg())->getOpcode() ==
6405 TargetOpcode::G_FMUL) &&
6406 MRI.hasOneNonDBGUse(RegNo: RHS.MI->getOperand(i: 0).getReg()) &&
6407 MRI.hasOneNonDBGUse(RegNo: RHS.MI->getOperand(i: 3).getReg())) {
6408 Z = LHS.Reg;
6409 FMA = RHS.MI;
6410 }
6411
6412 if (FMA) {
6413 MachineInstr *FMulMI = MRI.getVRegDef(Reg: FMA->getOperand(i: 3).getReg());
6414 Register X = FMA->getOperand(i: 1).getReg();
6415 Register Y = FMA->getOperand(i: 2).getReg();
6416 Register U = FMulMI->getOperand(i: 1).getReg();
6417 Register V = FMulMI->getOperand(i: 2).getReg();
6418
6419 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6420 Register InnerFMA = MRI.createGenericVirtualRegister(Ty: DstTy);
6421 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {InnerFMA}, SrcOps: {U, V, Z});
6422 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6423 SrcOps: {X, Y, InnerFMA});
6424 };
6425 return true;
6426 }
6427
6428 return false;
6429}
6430
6431bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMAAggressive(
6432 MachineInstr &MI,
6433 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6434 assert(MI.getOpcode() == TargetOpcode::G_FADD);
6435
6436 bool AllowFusionGlobally, HasFMAD, Aggressive;
6437 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6438 return false;
6439
6440 if (!Aggressive)
6441 return false;
6442
6443 const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
6444 LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6445 Register Op1 = MI.getOperand(i: 1).getReg();
6446 Register Op2 = MI.getOperand(i: 2).getReg();
6447 DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1};
6448 DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2};
6449
6450 unsigned PreferredFusedOpcode =
6451 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6452
6453 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
6454 // prefer to fold the multiply with fewer uses.
6455 if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6456 isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) {
6457 if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI))
6458 std::swap(a&: LHS, b&: RHS);
6459 }
6460
6461 // Builds: (fma x, y, (fma (fpext u), (fpext v), z))
6462 auto buildMatchInfo = [=, &MI](Register U, Register V, Register Z, Register X,
6463 Register Y, MachineIRBuilder &B) {
6464 Register FpExtU = B.buildFPExt(Res: DstType, Op: U).getReg(Idx: 0);
6465 Register FpExtV = B.buildFPExt(Res: DstType, Op: V).getReg(Idx: 0);
6466 Register InnerFMA =
6467 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {DstType}, SrcOps: {FpExtU, FpExtV, Z})
6468 .getReg(Idx: 0);
6469 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6470 SrcOps: {X, Y, InnerFMA});
6471 };
6472
6473 MachineInstr *FMulMI, *FMAMI;
6474 // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
6475 // -> (fma x, y, (fma (fpext u), (fpext v), z))
6476 if (LHS.MI->getOpcode() == PreferredFusedOpcode &&
6477 mi_match(R: LHS.MI->getOperand(i: 3).getReg(), MRI,
6478 P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) &&
6479 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6480 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
6481 SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) {
6482 MatchInfo = [=](MachineIRBuilder &B) {
6483 buildMatchInfo(FMulMI->getOperand(i: 1).getReg(),
6484 FMulMI->getOperand(i: 2).getReg(), RHS.Reg,
6485 LHS.MI->getOperand(i: 1).getReg(),
6486 LHS.MI->getOperand(i: 2).getReg(), B);
6487 };
6488 return true;
6489 }
6490
6491 // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
6492 // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
6493 // FIXME: This turns two single-precision and one double-precision
6494 // operation into two double-precision operations, which might not be
6495 // interesting for all targets, especially GPUs.
6496 if (mi_match(R: LHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMAMI))) &&
6497 FMAMI->getOpcode() == PreferredFusedOpcode) {
6498 MachineInstr *FMulMI = MRI.getVRegDef(Reg: FMAMI->getOperand(i: 3).getReg());
6499 if (isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6500 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
6501 SrcTy: MRI.getType(Reg: FMAMI->getOperand(i: 0).getReg()))) {
6502 MatchInfo = [=](MachineIRBuilder &B) {
6503 Register X = FMAMI->getOperand(i: 1).getReg();
6504 Register Y = FMAMI->getOperand(i: 2).getReg();
6505 X = B.buildFPExt(Res: DstType, Op: X).getReg(Idx: 0);
6506 Y = B.buildFPExt(Res: DstType, Op: Y).getReg(Idx: 0);
6507 buildMatchInfo(FMulMI->getOperand(i: 1).getReg(),
6508 FMulMI->getOperand(i: 2).getReg(), RHS.Reg, X, Y, B);
6509 };
6510
6511 return true;
6512 }
6513 }
6514
6515 // fold (fadd z, (fma x, y, (fpext (fmul u, v)))
6516 // -> (fma x, y, (fma (fpext u), (fpext v), z))
6517 if (RHS.MI->getOpcode() == PreferredFusedOpcode &&
6518 mi_match(R: RHS.MI->getOperand(i: 3).getReg(), MRI,
6519 P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) &&
6520 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6521 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
6522 SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) {
6523 MatchInfo = [=](MachineIRBuilder &B) {
6524 buildMatchInfo(FMulMI->getOperand(i: 1).getReg(),
6525 FMulMI->getOperand(i: 2).getReg(), LHS.Reg,
6526 RHS.MI->getOperand(i: 1).getReg(),
6527 RHS.MI->getOperand(i: 2).getReg(), B);
6528 };
6529 return true;
6530 }
6531
6532 // fold (fadd z, (fpext (fma x, y, (fmul u, v)))
6533 // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
6534 // FIXME: This turns two single-precision and one double-precision
6535 // operation into two double-precision operations, which might not be
6536 // interesting for all targets, especially GPUs.
6537 if (mi_match(R: RHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMAMI))) &&
6538 FMAMI->getOpcode() == PreferredFusedOpcode) {
6539 MachineInstr *FMulMI = MRI.getVRegDef(Reg: FMAMI->getOperand(i: 3).getReg());
6540 if (isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6541 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
6542 SrcTy: MRI.getType(Reg: FMAMI->getOperand(i: 0).getReg()))) {
6543 MatchInfo = [=](MachineIRBuilder &B) {
6544 Register X = FMAMI->getOperand(i: 1).getReg();
6545 Register Y = FMAMI->getOperand(i: 2).getReg();
6546 X = B.buildFPExt(Res: DstType, Op: X).getReg(Idx: 0);
6547 Y = B.buildFPExt(Res: DstType, Op: Y).getReg(Idx: 0);
6548 buildMatchInfo(FMulMI->getOperand(i: 1).getReg(),
6549 FMulMI->getOperand(i: 2).getReg(), LHS.Reg, X, Y, B);
6550 };
6551 return true;
6552 }
6553 }
6554
6555 return false;
6556}
6557
6558bool CombinerHelper::matchCombineFSubFMulToFMadOrFMA(
6559 MachineInstr &MI,
6560 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6561 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
6562
6563 bool AllowFusionGlobally, HasFMAD, Aggressive;
6564 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6565 return false;
6566
6567 Register Op1 = MI.getOperand(i: 1).getReg();
6568 Register Op2 = MI.getOperand(i: 2).getReg();
6569 DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1};
6570 DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2};
6571 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6572
6573 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
6574 // prefer to fold the multiply with fewer uses.
6575 int FirstMulHasFewerUses = true;
6576 if (isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6577 isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally) &&
6578 hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI))
6579 FirstMulHasFewerUses = false;
6580
6581 unsigned PreferredFusedOpcode =
6582 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6583
6584 // fold (fsub (fmul x, y), z) -> (fma x, y, -z)
6585 if (FirstMulHasFewerUses &&
6586 (isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6587 (Aggressive || MRI.hasOneNonDBGUse(RegNo: LHS.Reg)))) {
6588 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6589 Register NegZ = B.buildFNeg(Dst: DstTy, Src0: RHS.Reg).getReg(Idx: 0);
6590 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6591 SrcOps: {LHS.MI->getOperand(i: 1).getReg(),
6592 LHS.MI->getOperand(i: 2).getReg(), NegZ});
6593 };
6594 return true;
6595 }
6596 // fold (fsub x, (fmul y, z)) -> (fma -y, z, x)
6597 else if ((isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally) &&
6598 (Aggressive || MRI.hasOneNonDBGUse(RegNo: RHS.Reg)))) {
6599 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6600 Register NegY =
6601 B.buildFNeg(Dst: DstTy, Src0: RHS.MI->getOperand(i: 1).getReg()).getReg(Idx: 0);
6602 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6603 SrcOps: {NegY, RHS.MI->getOperand(i: 2).getReg(), LHS.Reg});
6604 };
6605 return true;
6606 }
6607
6608 return false;
6609}
6610
6611bool CombinerHelper::matchCombineFSubFNegFMulToFMadOrFMA(
6612 MachineInstr &MI,
6613 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6614 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
6615
6616 bool AllowFusionGlobally, HasFMAD, Aggressive;
6617 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6618 return false;
6619
6620 Register LHSReg = MI.getOperand(i: 1).getReg();
6621 Register RHSReg = MI.getOperand(i: 2).getReg();
6622 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6623
6624 unsigned PreferredFusedOpcode =
6625 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6626
6627 MachineInstr *FMulMI;
6628 // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z))
6629 if (mi_match(R: LHSReg, MRI, P: m_GFNeg(Src: m_MInstr(MI&: FMulMI))) &&
6630 (Aggressive || (MRI.hasOneNonDBGUse(RegNo: LHSReg) &&
6631 MRI.hasOneNonDBGUse(RegNo: FMulMI->getOperand(i: 0).getReg()))) &&
6632 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally)) {
6633 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6634 Register NegX =
6635 B.buildFNeg(Dst: DstTy, Src0: FMulMI->getOperand(i: 1).getReg()).getReg(Idx: 0);
6636 Register NegZ = B.buildFNeg(Dst: DstTy, Src0: RHSReg).getReg(Idx: 0);
6637 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6638 SrcOps: {NegX, FMulMI->getOperand(i: 2).getReg(), NegZ});
6639 };
6640 return true;
6641 }
6642
6643 // fold (fsub x, (fneg (fmul, y, z))) -> (fma y, z, x)
6644 if (mi_match(R: RHSReg, MRI, P: m_GFNeg(Src: m_MInstr(MI&: FMulMI))) &&
6645 (Aggressive || (MRI.hasOneNonDBGUse(RegNo: RHSReg) &&
6646 MRI.hasOneNonDBGUse(RegNo: FMulMI->getOperand(i: 0).getReg()))) &&
6647 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally)) {
6648 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6649 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6650 SrcOps: {FMulMI->getOperand(i: 1).getReg(),
6651 FMulMI->getOperand(i: 2).getReg(), LHSReg});
6652 };
6653 return true;
6654 }
6655
6656 return false;
6657}
6658
6659bool CombinerHelper::matchCombineFSubFpExtFMulToFMadOrFMA(
6660 MachineInstr &MI,
6661 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6662 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
6663
6664 bool AllowFusionGlobally, HasFMAD, Aggressive;
6665 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6666 return false;
6667
6668 Register LHSReg = MI.getOperand(i: 1).getReg();
6669 Register RHSReg = MI.getOperand(i: 2).getReg();
6670 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6671
6672 unsigned PreferredFusedOpcode =
6673 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6674
6675 MachineInstr *FMulMI;
6676 // fold (fsub (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), (fneg z))
6677 if (mi_match(R: LHSReg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) &&
6678 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6679 (Aggressive || MRI.hasOneNonDBGUse(RegNo: LHSReg))) {
6680 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6681 Register FpExtX =
6682 B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 1).getReg()).getReg(Idx: 0);
6683 Register FpExtY =
6684 B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 2).getReg()).getReg(Idx: 0);
6685 Register NegZ = B.buildFNeg(Dst: DstTy, Src0: RHSReg).getReg(Idx: 0);
6686 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6687 SrcOps: {FpExtX, FpExtY, NegZ});
6688 };
6689 return true;
6690 }
6691
6692 // fold (fsub x, (fpext (fmul y, z))) -> (fma (fneg (fpext y)), (fpext z), x)
6693 if (mi_match(R: RHSReg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) &&
6694 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6695 (Aggressive || MRI.hasOneNonDBGUse(RegNo: RHSReg))) {
6696 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6697 Register FpExtY =
6698 B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 1).getReg()).getReg(Idx: 0);
6699 Register NegY = B.buildFNeg(Dst: DstTy, Src0: FpExtY).getReg(Idx: 0);
6700 Register FpExtZ =
6701 B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 2).getReg()).getReg(Idx: 0);
6702 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6703 SrcOps: {NegY, FpExtZ, LHSReg});
6704 };
6705 return true;
6706 }
6707
6708 return false;
6709}
6710
6711bool CombinerHelper::matchCombineFSubFpExtFNegFMulToFMadOrFMA(
6712 MachineInstr &MI,
6713 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6714 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
6715
6716 bool AllowFusionGlobally, HasFMAD, Aggressive;
6717 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6718 return false;
6719
6720 const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
6721 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6722 Register LHSReg = MI.getOperand(i: 1).getReg();
6723 Register RHSReg = MI.getOperand(i: 2).getReg();
6724
6725 unsigned PreferredFusedOpcode =
6726 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6727
6728 auto buildMatchInfo = [=](Register Dst, Register X, Register Y, Register Z,
6729 MachineIRBuilder &B) {
6730 Register FpExtX = B.buildFPExt(Res: DstTy, Op: X).getReg(Idx: 0);
6731 Register FpExtY = B.buildFPExt(Res: DstTy, Op: Y).getReg(Idx: 0);
6732 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {Dst}, SrcOps: {FpExtX, FpExtY, Z});
6733 };
6734
6735 MachineInstr *FMulMI;
6736 // fold (fsub (fpext (fneg (fmul x, y))), z) ->
6737 // (fneg (fma (fpext x), (fpext y), z))
6738 // fold (fsub (fneg (fpext (fmul x, y))), z) ->
6739 // (fneg (fma (fpext x), (fpext y), z))
6740 if ((mi_match(R: LHSReg, MRI, P: m_GFPExt(Src: m_GFNeg(Src: m_MInstr(MI&: FMulMI)))) ||
6741 mi_match(R: LHSReg, MRI, P: m_GFNeg(Src: m_GFPExt(Src: m_MInstr(MI&: FMulMI))))) &&
6742 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6743 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstTy,
6744 SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) {
6745 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6746 Register FMAReg = MRI.createGenericVirtualRegister(Ty: DstTy);
6747 buildMatchInfo(FMAReg, FMulMI->getOperand(i: 1).getReg(),
6748 FMulMI->getOperand(i: 2).getReg(), RHSReg, B);
6749 B.buildFNeg(Dst: MI.getOperand(i: 0).getReg(), Src0: FMAReg);
6750 };
6751 return true;
6752 }
6753
6754 // fold (fsub x, (fpext (fneg (fmul y, z)))) -> (fma (fpext y), (fpext z), x)
6755 // fold (fsub x, (fneg (fpext (fmul y, z)))) -> (fma (fpext y), (fpext z), x)
6756 if ((mi_match(R: RHSReg, MRI, P: m_GFPExt(Src: m_GFNeg(Src: m_MInstr(MI&: FMulMI)))) ||
6757 mi_match(R: RHSReg, MRI, P: m_GFNeg(Src: m_GFPExt(Src: m_MInstr(MI&: FMulMI))))) &&
6758 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6759 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstTy,
6760 SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) {
6761 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6762 buildMatchInfo(MI.getOperand(i: 0).getReg(), FMulMI->getOperand(i: 1).getReg(),
6763 FMulMI->getOperand(i: 2).getReg(), LHSReg, B);
6764 };
6765 return true;
6766 }
6767
6768 return false;
6769}
6770
6771bool CombinerHelper::matchCombineFMinMaxNaN(MachineInstr &MI,
6772 unsigned &IdxToPropagate) const {
6773 bool PropagateNaN;
6774 switch (MI.getOpcode()) {
6775 default:
6776 return false;
6777 case TargetOpcode::G_FMINNUM:
6778 case TargetOpcode::G_FMAXNUM:
6779 PropagateNaN = false;
6780 break;
6781 case TargetOpcode::G_FMINIMUM:
6782 case TargetOpcode::G_FMAXIMUM:
6783 PropagateNaN = true;
6784 break;
6785 }
6786
6787 auto MatchNaN = [&](unsigned Idx) {
6788 Register MaybeNaNReg = MI.getOperand(i: Idx).getReg();
6789 const ConstantFP *MaybeCst = getConstantFPVRegVal(VReg: MaybeNaNReg, MRI);
6790 if (!MaybeCst || !MaybeCst->getValueAPF().isNaN())
6791 return false;
6792 IdxToPropagate = PropagateNaN ? Idx : (Idx == 1 ? 2 : 1);
6793 return true;
6794 };
6795
6796 return MatchNaN(1) || MatchNaN(2);
6797}
6798
6799// Combine multiple FDIVs with the same divisor into multiple FMULs by the
6800// reciprocal.
6801// E.g., (a / Y; b / Y;) -> (recip = 1.0 / Y; a * recip; b * recip)
6802bool CombinerHelper::matchRepeatedFPDivisor(
6803 MachineInstr &MI, SmallVector<MachineInstr *> &MatchInfo) const {
6804 assert(MI.getOpcode() == TargetOpcode::G_FDIV);
6805
6806 Register X = MI.getOperand(i: 1).getReg();
6807 Register Y = MI.getOperand(i: 2).getReg();
6808
6809 if (!MI.getFlag(Flag: MachineInstr::MIFlag::FmArcp))
6810 return false;
6811
6812 // Skip if current node is a reciprocal/fneg-reciprocal.
6813 auto N0CFP = isConstantOrConstantSplatVectorFP(MI&: *MRI.getVRegDef(Reg: X), MRI);
6814 if (N0CFP && (N0CFP->isExactlyValue(V: 1.0) || N0CFP->isExactlyValue(V: -1.0)))
6815 return false;
6816
6817 // Exit early if the target does not want this transform or if there can't
6818 // possibly be enough uses of the divisor to make the transform worthwhile.
6819 unsigned MinUses = getTargetLowering().combineRepeatedFPDivisors();
6820 if (!MinUses)
6821 return false;
6822
6823 // Find all FDIV users of the same divisor. For the moment we limit all
6824 // instructions to a single BB and use the first Instr in MatchInfo as the
6825 // dominating position.
6826 MatchInfo.push_back(Elt: &MI);
6827 for (auto &U : MRI.use_nodbg_instructions(Reg: Y)) {
6828 if (&U == &MI || U.getParent() != MI.getParent())
6829 continue;
6830 if (U.getOpcode() == TargetOpcode::G_FDIV &&
6831 U.getOperand(i: 2).getReg() == Y && U.getOperand(i: 1).getReg() != Y) {
6832 // This division is eligible for optimization only if global unsafe math
6833 // is enabled or if this division allows reciprocal formation.
6834 if (U.getFlag(Flag: MachineInstr::MIFlag::FmArcp)) {
6835 MatchInfo.push_back(Elt: &U);
6836 if (dominates(DefMI: U, UseMI: *MatchInfo[0]))
6837 std::swap(a&: MatchInfo[0], b&: MatchInfo.back());
6838 }
6839 }
6840 }
6841
6842 // Now that we have the actual number of divisor uses, make sure it meets
6843 // the minimum threshold specified by the target.
6844 return MatchInfo.size() >= MinUses;
6845}
6846
6847void CombinerHelper::applyRepeatedFPDivisor(
6848 SmallVector<MachineInstr *> &MatchInfo) const {
6849 // Generate the new div at the position of the first instruction, that we have
6850 // ensured will dominate all other instructions.
6851 Builder.setInsertPt(MBB&: *MatchInfo[0]->getParent(), II: MatchInfo[0]);
6852 LLT Ty = MRI.getType(Reg: MatchInfo[0]->getOperand(i: 0).getReg());
6853 auto Div = Builder.buildFDiv(Dst: Ty, Src0: Builder.buildFConstant(Res: Ty, Val: 1.0),
6854 Src1: MatchInfo[0]->getOperand(i: 2).getReg(),
6855 Flags: MatchInfo[0]->getFlags());
6856
6857 // Replace all found div's with fmul instructions.
6858 for (MachineInstr *MI : MatchInfo) {
6859 Builder.setInsertPt(MBB&: *MI->getParent(), II: MI);
6860 Builder.buildFMul(Dst: MI->getOperand(i: 0).getReg(), Src0: MI->getOperand(i: 1).getReg(),
6861 Src1: Div->getOperand(i: 0).getReg(), Flags: MI->getFlags());
6862 MI->eraseFromParent();
6863 }
6864}
6865
6866bool CombinerHelper::matchAddSubSameReg(MachineInstr &MI, Register &Src) const {
6867 assert(MI.getOpcode() == TargetOpcode::G_ADD && "Expected a G_ADD");
6868 Register LHS = MI.getOperand(i: 1).getReg();
6869 Register RHS = MI.getOperand(i: 2).getReg();
6870
6871 // Helper lambda to check for opportunities for
6872 // A + (B - A) -> B
6873 // (B - A) + A -> B
6874 auto CheckFold = [&](Register MaybeSub, Register MaybeSameReg) {
6875 Register Reg;
6876 return mi_match(R: MaybeSub, MRI, P: m_GSub(L: m_Reg(R&: Src), R: m_Reg(R&: Reg))) &&
6877 Reg == MaybeSameReg;
6878 };
6879 return CheckFold(LHS, RHS) || CheckFold(RHS, LHS);
6880}
6881
6882bool CombinerHelper::matchBuildVectorIdentityFold(MachineInstr &MI,
6883 Register &MatchInfo) const {
6884 // This combine folds the following patterns:
6885 //
6886 // G_BUILD_VECTOR_TRUNC (G_BITCAST(x), G_LSHR(G_BITCAST(x), k))
6887 // G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), G_TRUNC(G_LSHR(G_BITCAST(x), k)))
6888 // into
6889 // x
6890 // if
6891 // k == sizeof(VecEltTy)/2
6892 // type(x) == type(dst)
6893 //
6894 // G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), undef)
6895 // into
6896 // x
6897 // if
6898 // type(x) == type(dst)
6899
6900 LLT DstVecTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6901 LLT DstEltTy = DstVecTy.getElementType();
6902
6903 Register Lo, Hi;
6904
6905 if (mi_match(
6906 MI, MRI,
6907 P: m_GBuildVector(L: m_GTrunc(Src: m_GBitcast(Src: m_Reg(R&: Lo))), R: m_GImplicitDef()))) {
6908 MatchInfo = Lo;
6909 return MRI.getType(Reg: MatchInfo) == DstVecTy;
6910 }
6911
6912 std::optional<ValueAndVReg> ShiftAmount;
6913 const auto LoPattern = m_GBitcast(Src: m_Reg(R&: Lo));
6914 const auto HiPattern = m_GLShr(L: m_GBitcast(Src: m_Reg(R&: Hi)), R: m_GCst(ValReg&: ShiftAmount));
6915 if (mi_match(
6916 MI, MRI,
6917 P: m_any_of(preds: m_GBuildVectorTrunc(L: LoPattern, R: HiPattern),
6918 preds: m_GBuildVector(L: m_GTrunc(Src: LoPattern), R: m_GTrunc(Src: HiPattern))))) {
6919 if (Lo == Hi && ShiftAmount->Value == DstEltTy.getSizeInBits()) {
6920 MatchInfo = Lo;
6921 return MRI.getType(Reg: MatchInfo) == DstVecTy;
6922 }
6923 }
6924
6925 return false;
6926}
6927
6928bool CombinerHelper::matchTruncBuildVectorFold(MachineInstr &MI,
6929 Register &MatchInfo) const {
6930 // Replace (G_TRUNC (G_BITCAST (G_BUILD_VECTOR x, y)) with just x
6931 // if type(x) == type(G_TRUNC)
6932 if (!mi_match(R: MI.getOperand(i: 1).getReg(), MRI,
6933 P: m_GBitcast(Src: m_GBuildVector(L: m_Reg(R&: MatchInfo), R: m_Reg()))))
6934 return false;
6935
6936 return MRI.getType(Reg: MatchInfo) == MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6937}
6938
6939bool CombinerHelper::matchTruncLshrBuildVectorFold(MachineInstr &MI,
6940 Register &MatchInfo) const {
6941 // Replace (G_TRUNC (G_LSHR (G_BITCAST (G_BUILD_VECTOR x, y)), K)) with
6942 // y if K == size of vector element type
6943 std::optional<ValueAndVReg> ShiftAmt;
6944 if (!mi_match(R: MI.getOperand(i: 1).getReg(), MRI,
6945 P: m_GLShr(L: m_GBitcast(Src: m_GBuildVector(L: m_Reg(), R: m_Reg(R&: MatchInfo))),
6946 R: m_GCst(ValReg&: ShiftAmt))))
6947 return false;
6948
6949 LLT MatchTy = MRI.getType(Reg: MatchInfo);
6950 return ShiftAmt->Value.getZExtValue() == MatchTy.getSizeInBits() &&
6951 MatchTy == MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6952}
6953
6954unsigned CombinerHelper::getFPMinMaxOpcForSelect(
6955 CmpInst::Predicate Pred, LLT DstTy,
6956 SelectPatternNaNBehaviour VsNaNRetVal) const {
6957 assert(VsNaNRetVal != SelectPatternNaNBehaviour::NOT_APPLICABLE &&
6958 "Expected a NaN behaviour?");
6959 // Choose an opcode based off of legality or the behaviour when one of the
6960 // LHS/RHS may be NaN.
6961 switch (Pred) {
6962 default:
6963 return 0;
6964 case CmpInst::FCMP_UGT:
6965 case CmpInst::FCMP_UGE:
6966 case CmpInst::FCMP_OGT:
6967 case CmpInst::FCMP_OGE:
6968 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER)
6969 return TargetOpcode::G_FMAXNUM;
6970 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN)
6971 return TargetOpcode::G_FMAXIMUM;
6972 if (isLegal(Query: {TargetOpcode::G_FMAXNUM, {DstTy}}))
6973 return TargetOpcode::G_FMAXNUM;
6974 if (isLegal(Query: {TargetOpcode::G_FMAXIMUM, {DstTy}}))
6975 return TargetOpcode::G_FMAXIMUM;
6976 return 0;
6977 case CmpInst::FCMP_ULT:
6978 case CmpInst::FCMP_ULE:
6979 case CmpInst::FCMP_OLT:
6980 case CmpInst::FCMP_OLE:
6981 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER)
6982 return TargetOpcode::G_FMINNUM;
6983 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN)
6984 return TargetOpcode::G_FMINIMUM;
6985 if (isLegal(Query: {TargetOpcode::G_FMINNUM, {DstTy}}))
6986 return TargetOpcode::G_FMINNUM;
6987 if (!isLegal(Query: {TargetOpcode::G_FMINIMUM, {DstTy}}))
6988 return 0;
6989 return TargetOpcode::G_FMINIMUM;
6990 }
6991}
6992
6993CombinerHelper::SelectPatternNaNBehaviour
6994CombinerHelper::computeRetValAgainstNaN(Register LHS, Register RHS,
6995 bool IsOrderedComparison) const {
6996 bool LHSSafe = isKnownNeverNaN(Val: LHS, MRI);
6997 bool RHSSafe = isKnownNeverNaN(Val: RHS, MRI);
6998 // Completely unsafe.
6999 if (!LHSSafe && !RHSSafe)
7000 return SelectPatternNaNBehaviour::NOT_APPLICABLE;
7001 if (LHSSafe && RHSSafe)
7002 return SelectPatternNaNBehaviour::RETURNS_ANY;
7003 // An ordered comparison will return false when given a NaN, so it
7004 // returns the RHS.
7005 if (IsOrderedComparison)
7006 return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_NAN
7007 : SelectPatternNaNBehaviour::RETURNS_OTHER;
7008 // An unordered comparison will return true when given a NaN, so it
7009 // returns the LHS.
7010 return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_OTHER
7011 : SelectPatternNaNBehaviour::RETURNS_NAN;
7012}
7013
7014bool CombinerHelper::matchFPSelectToMinMax(Register Dst, Register Cond,
7015 Register TrueVal, Register FalseVal,
7016 BuildFnTy &MatchInfo) const {
7017 // Match: select (fcmp cond x, y) x, y
7018 // select (fcmp cond x, y) y, x
7019 // And turn it into fminnum/fmaxnum or fmin/fmax based off of the condition.
7020 LLT DstTy = MRI.getType(Reg: Dst);
7021 // Bail out early on pointers, since we'll never want to fold to a min/max.
7022 if (DstTy.isPointer())
7023 return false;
7024 // Match a floating point compare with a less-than/greater-than predicate.
7025 // TODO: Allow multiple users of the compare if they are all selects.
7026 CmpInst::Predicate Pred;
7027 Register CmpLHS, CmpRHS;
7028 if (!mi_match(R: Cond, MRI,
7029 P: m_OneNonDBGUse(
7030 SP: m_GFCmp(P: m_Pred(P&: Pred), L: m_Reg(R&: CmpLHS), R: m_Reg(R&: CmpRHS)))) ||
7031 CmpInst::isEquality(pred: Pred))
7032 return false;
7033 SelectPatternNaNBehaviour ResWithKnownNaNInfo =
7034 computeRetValAgainstNaN(LHS: CmpLHS, RHS: CmpRHS, IsOrderedComparison: CmpInst::isOrdered(predicate: Pred));
7035 if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::NOT_APPLICABLE)
7036 return false;
7037 if (TrueVal == CmpRHS && FalseVal == CmpLHS) {
7038 std::swap(a&: CmpLHS, b&: CmpRHS);
7039 Pred = CmpInst::getSwappedPredicate(pred: Pred);
7040 if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_NAN)
7041 ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_OTHER;
7042 else if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_OTHER)
7043 ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_NAN;
7044 }
7045 if (TrueVal != CmpLHS || FalseVal != CmpRHS)
7046 return false;
7047 // Decide what type of max/min this should be based off of the predicate.
7048 unsigned Opc = getFPMinMaxOpcForSelect(Pred, DstTy, VsNaNRetVal: ResWithKnownNaNInfo);
7049 if (!Opc || !isLegal(Query: {Opc, {DstTy}}))
7050 return false;
7051 // Comparisons between signed zero and zero may have different results...
7052 // unless we have fmaximum/fminimum. In that case, we know -0 < 0.
7053 if (Opc != TargetOpcode::G_FMAXIMUM && Opc != TargetOpcode::G_FMINIMUM) {
7054 // We don't know if a comparison between two 0s will give us a consistent
7055 // result. Be conservative and only proceed if at least one side is
7056 // non-zero.
7057 auto KnownNonZeroSide = getFConstantVRegValWithLookThrough(VReg: CmpLHS, MRI);
7058 if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero()) {
7059 KnownNonZeroSide = getFConstantVRegValWithLookThrough(VReg: CmpRHS, MRI);
7060 if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero())
7061 return false;
7062 }
7063 }
7064 MatchInfo = [=](MachineIRBuilder &B) {
7065 B.buildInstr(Opc, DstOps: {Dst}, SrcOps: {CmpLHS, CmpRHS});
7066 };
7067 return true;
7068}
7069
7070bool CombinerHelper::matchSimplifySelectToMinMax(MachineInstr &MI,
7071 BuildFnTy &MatchInfo) const {
7072 // TODO: Handle integer cases.
7073 assert(MI.getOpcode() == TargetOpcode::G_SELECT);
7074 // Condition may be fed by a truncated compare.
7075 Register Cond = MI.getOperand(i: 1).getReg();
7076 Register MaybeTrunc;
7077 if (mi_match(R: Cond, MRI, P: m_OneNonDBGUse(SP: m_GTrunc(Src: m_Reg(R&: MaybeTrunc)))))
7078 Cond = MaybeTrunc;
7079 Register Dst = MI.getOperand(i: 0).getReg();
7080 Register TrueVal = MI.getOperand(i: 2).getReg();
7081 Register FalseVal = MI.getOperand(i: 3).getReg();
7082 return matchFPSelectToMinMax(Dst, Cond, TrueVal, FalseVal, MatchInfo);
7083}
7084
7085bool CombinerHelper::matchRedundantBinOpInEquality(MachineInstr &MI,
7086 BuildFnTy &MatchInfo) const {
7087 assert(MI.getOpcode() == TargetOpcode::G_ICMP);
7088 // (X + Y) == X --> Y == 0
7089 // (X + Y) != X --> Y != 0
7090 // (X - Y) == X --> Y == 0
7091 // (X - Y) != X --> Y != 0
7092 // (X ^ Y) == X --> Y == 0
7093 // (X ^ Y) != X --> Y != 0
7094 Register Dst = MI.getOperand(i: 0).getReg();
7095 CmpInst::Predicate Pred;
7096 Register X, Y, OpLHS, OpRHS;
7097 bool MatchedSub = mi_match(
7098 R: Dst, MRI,
7099 P: m_c_GICmp(P: m_Pred(P&: Pred), L: m_Reg(R&: X), R: m_GSub(L: m_Reg(R&: OpLHS), R: m_Reg(R&: Y))));
7100 if (MatchedSub && X != OpLHS)
7101 return false;
7102 if (!MatchedSub) {
7103 if (!mi_match(R: Dst, MRI,
7104 P: m_c_GICmp(P: m_Pred(P&: Pred), L: m_Reg(R&: X),
7105 R: m_any_of(preds: m_GAdd(L: m_Reg(R&: OpLHS), R: m_Reg(R&: OpRHS)),
7106 preds: m_GXor(L: m_Reg(R&: OpLHS), R: m_Reg(R&: OpRHS))))))
7107 return false;
7108 Y = X == OpLHS ? OpRHS : X == OpRHS ? OpLHS : Register();
7109 }
7110 MatchInfo = [=](MachineIRBuilder &B) {
7111 auto Zero = B.buildConstant(Res: MRI.getType(Reg: Y), Val: 0);
7112 B.buildICmp(Pred, Res: Dst, Op0: Y, Op1: Zero);
7113 };
7114 return CmpInst::isEquality(pred: Pred) && Y.isValid();
7115}
7116
7117/// Return the minimum useless shift amount that results in complete loss of the
7118/// source value. Return std::nullopt when it cannot determine a value.
7119static std::optional<unsigned>
7120getMinUselessShift(KnownBits ValueKB, unsigned Opcode,
7121 std::optional<int64_t> &Result) {
7122 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_LSHR ||
7123 Opcode == TargetOpcode::G_ASHR) &&
7124 "Expect G_SHL, G_LSHR or G_ASHR.");
7125 auto SignificantBits = 0;
7126 switch (Opcode) {
7127 case TargetOpcode::G_SHL:
7128 SignificantBits = ValueKB.countMinTrailingZeros();
7129 Result = 0;
7130 break;
7131 case TargetOpcode::G_LSHR:
7132 Result = 0;
7133 SignificantBits = ValueKB.countMinLeadingZeros();
7134 break;
7135 case TargetOpcode::G_ASHR:
7136 if (ValueKB.isNonNegative()) {
7137 SignificantBits = ValueKB.countMinLeadingZeros();
7138 Result = 0;
7139 } else if (ValueKB.isNegative()) {
7140 SignificantBits = ValueKB.countMinLeadingOnes();
7141 Result = -1;
7142 } else {
7143 // Cannot determine shift result.
7144 Result = std::nullopt;
7145 }
7146 break;
7147 default:
7148 break;
7149 }
7150 return ValueKB.getBitWidth() - SignificantBits;
7151}
7152
7153bool CombinerHelper::matchShiftsTooBig(
7154 MachineInstr &MI, std::optional<int64_t> &MatchInfo) const {
7155 Register ShiftVal = MI.getOperand(i: 1).getReg();
7156 Register ShiftReg = MI.getOperand(i: 2).getReg();
7157 LLT ResTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
7158 auto IsShiftTooBig = [&](const Constant *C) {
7159 auto *CI = dyn_cast<ConstantInt>(Val: C);
7160 if (!CI)
7161 return false;
7162 if (CI->uge(Num: ResTy.getScalarSizeInBits())) {
7163 MatchInfo = std::nullopt;
7164 return true;
7165 }
7166 auto OptMaxUsefulShift = getMinUselessShift(ValueKB: VT->getKnownBits(R: ShiftVal),
7167 Opcode: MI.getOpcode(), Result&: MatchInfo);
7168 return OptMaxUsefulShift && CI->uge(Num: *OptMaxUsefulShift);
7169 };
7170 return matchUnaryPredicate(MRI, Reg: ShiftReg, Match: IsShiftTooBig);
7171}
7172
7173bool CombinerHelper::matchCommuteConstantToRHS(MachineInstr &MI) const {
7174 unsigned LHSOpndIdx = 1;
7175 unsigned RHSOpndIdx = 2;
7176 switch (MI.getOpcode()) {
7177 case TargetOpcode::G_UADDO:
7178 case TargetOpcode::G_SADDO:
7179 case TargetOpcode::G_UMULO:
7180 case TargetOpcode::G_SMULO:
7181 LHSOpndIdx = 2;
7182 RHSOpndIdx = 3;
7183 break;
7184 default:
7185 break;
7186 }
7187 Register LHS = MI.getOperand(i: LHSOpndIdx).getReg();
7188 Register RHS = MI.getOperand(i: RHSOpndIdx).getReg();
7189 if (!getIConstantVRegVal(VReg: LHS, MRI)) {
7190 // Skip commuting if LHS is not a constant. But, LHS may be a
7191 // G_CONSTANT_FOLD_BARRIER. If so we commute as long as we don't already
7192 // have a constant on the RHS.
7193 if (MRI.getVRegDef(Reg: LHS)->getOpcode() !=
7194 TargetOpcode::G_CONSTANT_FOLD_BARRIER)
7195 return false;
7196 }
7197 // Commute as long as RHS is not a constant or G_CONSTANT_FOLD_BARRIER.
7198 return MRI.getVRegDef(Reg: RHS)->getOpcode() !=
7199 TargetOpcode::G_CONSTANT_FOLD_BARRIER &&
7200 !getIConstantVRegVal(VReg: RHS, MRI);
7201}
7202
7203bool CombinerHelper::matchCommuteFPConstantToRHS(MachineInstr &MI) const {
7204 Register LHS = MI.getOperand(i: 1).getReg();
7205 Register RHS = MI.getOperand(i: 2).getReg();
7206 std::optional<FPValueAndVReg> ValAndVReg;
7207 if (!mi_match(R: LHS, MRI, P: m_GFCstOrSplat(FPValReg&: ValAndVReg)))
7208 return false;
7209 return !mi_match(R: RHS, MRI, P: m_GFCstOrSplat(FPValReg&: ValAndVReg));
7210}
7211
7212void CombinerHelper::applyCommuteBinOpOperands(MachineInstr &MI) const {
7213 Observer.changingInstr(MI);
7214 unsigned LHSOpndIdx = 1;
7215 unsigned RHSOpndIdx = 2;
7216 switch (MI.getOpcode()) {
7217 case TargetOpcode::G_UADDO:
7218 case TargetOpcode::G_SADDO:
7219 case TargetOpcode::G_UMULO:
7220 case TargetOpcode::G_SMULO:
7221 LHSOpndIdx = 2;
7222 RHSOpndIdx = 3;
7223 break;
7224 default:
7225 break;
7226 }
7227 Register LHSReg = MI.getOperand(i: LHSOpndIdx).getReg();
7228 Register RHSReg = MI.getOperand(i: RHSOpndIdx).getReg();
7229 MI.getOperand(i: LHSOpndIdx).setReg(RHSReg);
7230 MI.getOperand(i: RHSOpndIdx).setReg(LHSReg);
7231 Observer.changedInstr(MI);
7232}
7233
7234bool CombinerHelper::isOneOrOneSplat(Register Src, bool AllowUndefs) const {
7235 LLT SrcTy = MRI.getType(Reg: Src);
7236 if (SrcTy.isFixedVector())
7237 return isConstantSplatVector(Src, SplatValue: 1, AllowUndefs);
7238 if (SrcTy.isScalar()) {
7239 if (AllowUndefs && getOpcodeDef<GImplicitDef>(Reg: Src, MRI) != nullptr)
7240 return true;
7241 auto IConstant = getIConstantVRegValWithLookThrough(VReg: Src, MRI);
7242 return IConstant && IConstant->Value == 1;
7243 }
7244 return false; // scalable vector
7245}
7246
7247bool CombinerHelper::isZeroOrZeroSplat(Register Src, bool AllowUndefs) const {
7248 LLT SrcTy = MRI.getType(Reg: Src);
7249 if (SrcTy.isFixedVector())
7250 return isConstantSplatVector(Src, SplatValue: 0, AllowUndefs);
7251 if (SrcTy.isScalar()) {
7252 if (AllowUndefs && getOpcodeDef<GImplicitDef>(Reg: Src, MRI) != nullptr)
7253 return true;
7254 auto IConstant = getIConstantVRegValWithLookThrough(VReg: Src, MRI);
7255 return IConstant && IConstant->Value == 0;
7256 }
7257 return false; // scalable vector
7258}
7259
7260// Ignores COPYs during conformance checks.
7261// FIXME scalable vectors.
7262bool CombinerHelper::isConstantSplatVector(Register Src, int64_t SplatValue,
7263 bool AllowUndefs) const {
7264 GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Reg: Src, MRI);
7265 if (!BuildVector)
7266 return false;
7267 unsigned NumSources = BuildVector->getNumSources();
7268
7269 for (unsigned I = 0; I < NumSources; ++I) {
7270 GImplicitDef *ImplicitDef =
7271 getOpcodeDef<GImplicitDef>(Reg: BuildVector->getSourceReg(I), MRI);
7272 if (ImplicitDef && AllowUndefs)
7273 continue;
7274 if (ImplicitDef && !AllowUndefs)
7275 return false;
7276 std::optional<ValueAndVReg> IConstant =
7277 getIConstantVRegValWithLookThrough(VReg: BuildVector->getSourceReg(I), MRI);
7278 if (IConstant && IConstant->Value == SplatValue)
7279 continue;
7280 return false;
7281 }
7282 return true;
7283}
7284
7285// Ignores COPYs during lookups.
7286// FIXME scalable vectors
7287std::optional<APInt>
7288CombinerHelper::getConstantOrConstantSplatVector(Register Src) const {
7289 auto IConstant = getIConstantVRegValWithLookThrough(VReg: Src, MRI);
7290 if (IConstant)
7291 return IConstant->Value;
7292
7293 GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Reg: Src, MRI);
7294 if (!BuildVector)
7295 return std::nullopt;
7296 unsigned NumSources = BuildVector->getNumSources();
7297
7298 std::optional<APInt> Value = std::nullopt;
7299 for (unsigned I = 0; I < NumSources; ++I) {
7300 std::optional<ValueAndVReg> IConstant =
7301 getIConstantVRegValWithLookThrough(VReg: BuildVector->getSourceReg(I), MRI);
7302 if (!IConstant)
7303 return std::nullopt;
7304 if (!Value)
7305 Value = IConstant->Value;
7306 else if (*Value != IConstant->Value)
7307 return std::nullopt;
7308 }
7309 return Value;
7310}
7311
7312// FIXME G_SPLAT_VECTOR
7313bool CombinerHelper::isConstantOrConstantVectorI(Register Src) const {
7314 auto IConstant = getIConstantVRegValWithLookThrough(VReg: Src, MRI);
7315 if (IConstant)
7316 return true;
7317
7318 GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Reg: Src, MRI);
7319 if (!BuildVector)
7320 return false;
7321
7322 unsigned NumSources = BuildVector->getNumSources();
7323 for (unsigned I = 0; I < NumSources; ++I) {
7324 std::optional<ValueAndVReg> IConstant =
7325 getIConstantVRegValWithLookThrough(VReg: BuildVector->getSourceReg(I), MRI);
7326 if (!IConstant)
7327 return false;
7328 }
7329 return true;
7330}
7331
7332// TODO: use knownbits to determine zeros
7333bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
7334 BuildFnTy &MatchInfo) const {
7335 uint32_t Flags = Select->getFlags();
7336 Register Dest = Select->getReg(Idx: 0);
7337 Register Cond = Select->getCondReg();
7338 Register True = Select->getTrueReg();
7339 Register False = Select->getFalseReg();
7340 LLT CondTy = MRI.getType(Reg: Select->getCondReg());
7341 LLT TrueTy = MRI.getType(Reg: Select->getTrueReg());
7342
7343 // We only do this combine for scalar boolean conditions.
7344 if (CondTy != LLT::scalar(SizeInBits: 1))
7345 return false;
7346
7347 if (TrueTy.isPointer())
7348 return false;
7349
7350 // Both are scalars.
7351 std::optional<ValueAndVReg> TrueOpt =
7352 getIConstantVRegValWithLookThrough(VReg: True, MRI);
7353 std::optional<ValueAndVReg> FalseOpt =
7354 getIConstantVRegValWithLookThrough(VReg: False, MRI);
7355
7356 if (!TrueOpt || !FalseOpt)
7357 return false;
7358
7359 APInt TrueValue = TrueOpt->Value;
7360 APInt FalseValue = FalseOpt->Value;
7361
7362 // select Cond, 1, 0 --> zext (Cond)
7363 if (TrueValue.isOne() && FalseValue.isZero()) {
7364 MatchInfo = [=](MachineIRBuilder &B) {
7365 B.setInstrAndDebugLoc(*Select);
7366 B.buildZExtOrTrunc(Res: Dest, Op: Cond);
7367 };
7368 return true;
7369 }
7370
7371 // select Cond, -1, 0 --> sext (Cond)
7372 if (TrueValue.isAllOnes() && FalseValue.isZero()) {
7373 MatchInfo = [=](MachineIRBuilder &B) {
7374 B.setInstrAndDebugLoc(*Select);
7375 B.buildSExtOrTrunc(Res: Dest, Op: Cond);
7376 };
7377 return true;
7378 }
7379
7380 // select Cond, 0, 1 --> zext (!Cond)
7381 if (TrueValue.isZero() && FalseValue.isOne()) {
7382 MatchInfo = [=](MachineIRBuilder &B) {
7383 B.setInstrAndDebugLoc(*Select);
7384 Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy);
7385 B.buildNot(Dst: Inner, Src0: Cond);
7386 B.buildZExtOrTrunc(Res: Dest, Op: Inner);
7387 };
7388 return true;
7389 }
7390
7391 // select Cond, 0, -1 --> sext (!Cond)
7392 if (TrueValue.isZero() && FalseValue.isAllOnes()) {
7393 MatchInfo = [=](MachineIRBuilder &B) {
7394 B.setInstrAndDebugLoc(*Select);
7395 Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy);
7396 B.buildNot(Dst: Inner, Src0: Cond);
7397 B.buildSExtOrTrunc(Res: Dest, Op: Inner);
7398 };
7399 return true;
7400 }
7401
7402 // select Cond, C1, C1-1 --> add (zext Cond), C1-1
7403 if (TrueValue - 1 == FalseValue) {
7404 MatchInfo = [=](MachineIRBuilder &B) {
7405 B.setInstrAndDebugLoc(*Select);
7406 Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy);
7407 B.buildZExtOrTrunc(Res: Inner, Op: Cond);
7408 B.buildAdd(Dst: Dest, Src0: Inner, Src1: False);
7409 };
7410 return true;
7411 }
7412
7413 // select Cond, C1, C1+1 --> add (sext Cond), C1+1
7414 if (TrueValue + 1 == FalseValue) {
7415 MatchInfo = [=](MachineIRBuilder &B) {
7416 B.setInstrAndDebugLoc(*Select);
7417 Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy);
7418 B.buildSExtOrTrunc(Res: Inner, Op: Cond);
7419 B.buildAdd(Dst: Dest, Src0: Inner, Src1: False);
7420 };
7421 return true;
7422 }
7423
7424 // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
7425 if (TrueValue.isPowerOf2() && FalseValue.isZero()) {
7426 MatchInfo = [=](MachineIRBuilder &B) {
7427 B.setInstrAndDebugLoc(*Select);
7428 Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy);
7429 B.buildZExtOrTrunc(Res: Inner, Op: Cond);
7430 // The shift amount must be scalar.
7431 LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy;
7432 auto ShAmtC = B.buildConstant(Res: ShiftTy, Val: TrueValue.exactLogBase2());
7433 B.buildShl(Dst: Dest, Src0: Inner, Src1: ShAmtC, Flags);
7434 };
7435 return true;
7436 }
7437
7438 // select Cond, 0, Pow2 --> (zext (!Cond)) << log2(Pow2)
7439 if (FalseValue.isPowerOf2() && TrueValue.isZero()) {
7440 MatchInfo = [=](MachineIRBuilder &B) {
7441 B.setInstrAndDebugLoc(*Select);
7442 Register Not = MRI.createGenericVirtualRegister(Ty: CondTy);
7443 B.buildNot(Dst: Not, Src0: Cond);
7444 Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy);
7445 B.buildZExtOrTrunc(Res: Inner, Op: Not);
7446 // The shift amount must be scalar.
7447 LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy;
7448 auto ShAmtC = B.buildConstant(Res: ShiftTy, Val: FalseValue.exactLogBase2());
7449 B.buildShl(Dst: Dest, Src0: Inner, Src1: ShAmtC, Flags);
7450 };
7451 return true;
7452 }
7453
7454 // select Cond, -1, C --> or (sext Cond), C
7455 if (TrueValue.isAllOnes()) {
7456 MatchInfo = [=](MachineIRBuilder &B) {
7457 B.setInstrAndDebugLoc(*Select);
7458 Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy);
7459 B.buildSExtOrTrunc(Res: Inner, Op: Cond);
7460 B.buildOr(Dst: Dest, Src0: Inner, Src1: False, Flags);
7461 };
7462 return true;
7463 }
7464
7465 // select Cond, C, -1 --> or (sext (not Cond)), C
7466 if (FalseValue.isAllOnes()) {
7467 MatchInfo = [=](MachineIRBuilder &B) {
7468 B.setInstrAndDebugLoc(*Select);
7469 Register Not = MRI.createGenericVirtualRegister(Ty: CondTy);
7470 B.buildNot(Dst: Not, Src0: Cond);
7471 Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy);
7472 B.buildSExtOrTrunc(Res: Inner, Op: Not);
7473 B.buildOr(Dst: Dest, Src0: Inner, Src1: True, Flags);
7474 };
7475 return true;
7476 }
7477
7478 return false;
7479}
7480
7481// TODO: use knownbits to determine zeros
7482bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
7483 BuildFnTy &MatchInfo) const {
7484 uint32_t Flags = Select->getFlags();
7485 Register DstReg = Select->getReg(Idx: 0);
7486 Register Cond = Select->getCondReg();
7487 Register True = Select->getTrueReg();
7488 Register False = Select->getFalseReg();
7489 LLT CondTy = MRI.getType(Reg: Select->getCondReg());
7490 LLT TrueTy = MRI.getType(Reg: Select->getTrueReg());
7491
7492 // Boolean or fixed vector of booleans.
7493 if (CondTy.isScalableVector() ||
7494 (CondTy.isFixedVector() &&
7495 CondTy.getElementType().getScalarSizeInBits() != 1) ||
7496 CondTy.getScalarSizeInBits() != 1)
7497 return false;
7498
7499 if (CondTy != TrueTy)
7500 return false;
7501
7502 // select Cond, Cond, F --> or Cond, F
7503 // select Cond, 1, F --> or Cond, F
7504 if ((Cond == True) || isOneOrOneSplat(Src: True, /* AllowUndefs */ true)) {
7505 MatchInfo = [=](MachineIRBuilder &B) {
7506 B.setInstrAndDebugLoc(*Select);
7507 Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy);
7508 B.buildZExtOrTrunc(Res: Ext, Op: Cond);
7509 auto FreezeFalse = B.buildFreeze(Dst: TrueTy, Src: False);
7510 B.buildOr(Dst: DstReg, Src0: Ext, Src1: FreezeFalse, Flags);
7511 };
7512 return true;
7513 }
7514
7515 // select Cond, T, Cond --> and Cond, T
7516 // select Cond, T, 0 --> and Cond, T
7517 if ((Cond == False) || isZeroOrZeroSplat(Src: False, /* AllowUndefs */ true)) {
7518 MatchInfo = [=](MachineIRBuilder &B) {
7519 B.setInstrAndDebugLoc(*Select);
7520 Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy);
7521 B.buildZExtOrTrunc(Res: Ext, Op: Cond);
7522 auto FreezeTrue = B.buildFreeze(Dst: TrueTy, Src: True);
7523 B.buildAnd(Dst: DstReg, Src0: Ext, Src1: FreezeTrue);
7524 };
7525 return true;
7526 }
7527
7528 // select Cond, T, 1 --> or (not Cond), T
7529 if (isOneOrOneSplat(Src: False, /* AllowUndefs */ true)) {
7530 MatchInfo = [=](MachineIRBuilder &B) {
7531 B.setInstrAndDebugLoc(*Select);
7532 // First the not.
7533 Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy);
7534 B.buildNot(Dst: Inner, Src0: Cond);
7535 // Then an ext to match the destination register.
7536 Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy);
7537 B.buildZExtOrTrunc(Res: Ext, Op: Inner);
7538 auto FreezeTrue = B.buildFreeze(Dst: TrueTy, Src: True);
7539 B.buildOr(Dst: DstReg, Src0: Ext, Src1: FreezeTrue, Flags);
7540 };
7541 return true;
7542 }
7543
7544 // select Cond, 0, F --> and (not Cond), F
7545 if (isZeroOrZeroSplat(Src: True, /* AllowUndefs */ true)) {
7546 MatchInfo = [=](MachineIRBuilder &B) {
7547 B.setInstrAndDebugLoc(*Select);
7548 // First the not.
7549 Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy);
7550 B.buildNot(Dst: Inner, Src0: Cond);
7551 // Then an ext to match the destination register.
7552 Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy);
7553 B.buildZExtOrTrunc(Res: Ext, Op: Inner);
7554 auto FreezeFalse = B.buildFreeze(Dst: TrueTy, Src: False);
7555 B.buildAnd(Dst: DstReg, Src0: Ext, Src1: FreezeFalse);
7556 };
7557 return true;
7558 }
7559
7560 return false;
7561}
7562
7563bool CombinerHelper::matchSelectIMinMax(const MachineOperand &MO,
7564 BuildFnTy &MatchInfo) const {
7565 GSelect *Select = cast<GSelect>(Val: MRI.getVRegDef(Reg: MO.getReg()));
7566 GICmp *Cmp = cast<GICmp>(Val: MRI.getVRegDef(Reg: Select->getCondReg()));
7567
7568 Register DstReg = Select->getReg(Idx: 0);
7569 Register True = Select->getTrueReg();
7570 Register False = Select->getFalseReg();
7571 LLT DstTy = MRI.getType(Reg: DstReg);
7572
7573 if (DstTy.isPointerOrPointerVector())
7574 return false;
7575
7576 // We want to fold the icmp and replace the select.
7577 if (!MRI.hasOneNonDBGUse(RegNo: Cmp->getReg(Idx: 0)))
7578 return false;
7579
7580 CmpInst::Predicate Pred = Cmp->getCond();
7581 // We need a larger or smaller predicate for
7582 // canonicalization.
7583 if (CmpInst::isEquality(pred: Pred))
7584 return false;
7585
7586 Register CmpLHS = Cmp->getLHSReg();
7587 Register CmpRHS = Cmp->getRHSReg();
7588
7589 // We can swap CmpLHS and CmpRHS for higher hitrate.
7590 if (True == CmpRHS && False == CmpLHS) {
7591 std::swap(a&: CmpLHS, b&: CmpRHS);
7592 Pred = CmpInst::getSwappedPredicate(pred: Pred);
7593 }
7594
7595 // (icmp X, Y) ? X : Y -> integer minmax.
7596 // see matchSelectPattern in ValueTracking.
7597 // Legality between G_SELECT and integer minmax can differ.
7598 if (True != CmpLHS || False != CmpRHS)
7599 return false;
7600
7601 switch (Pred) {
7602 case ICmpInst::ICMP_UGT:
7603 case ICmpInst::ICMP_UGE: {
7604 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_UMAX, DstTy}))
7605 return false;
7606 MatchInfo = [=](MachineIRBuilder &B) { B.buildUMax(Dst: DstReg, Src0: True, Src1: False); };
7607 return true;
7608 }
7609 case ICmpInst::ICMP_SGT:
7610 case ICmpInst::ICMP_SGE: {
7611 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SMAX, DstTy}))
7612 return false;
7613 MatchInfo = [=](MachineIRBuilder &B) { B.buildSMax(Dst: DstReg, Src0: True, Src1: False); };
7614 return true;
7615 }
7616 case ICmpInst::ICMP_ULT:
7617 case ICmpInst::ICMP_ULE: {
7618 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_UMIN, DstTy}))
7619 return false;
7620 MatchInfo = [=](MachineIRBuilder &B) { B.buildUMin(Dst: DstReg, Src0: True, Src1: False); };
7621 return true;
7622 }
7623 case ICmpInst::ICMP_SLT:
7624 case ICmpInst::ICMP_SLE: {
7625 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SMIN, DstTy}))
7626 return false;
7627 MatchInfo = [=](MachineIRBuilder &B) { B.buildSMin(Dst: DstReg, Src0: True, Src1: False); };
7628 return true;
7629 }
7630 default:
7631 return false;
7632 }
7633}
7634
7635// (neg (min/max x, (neg x))) --> (max/min x, (neg x))
7636bool CombinerHelper::matchSimplifyNegMinMax(MachineInstr &MI,
7637 BuildFnTy &MatchInfo) const {
7638 assert(MI.getOpcode() == TargetOpcode::G_SUB);
7639 Register DestReg = MI.getOperand(i: 0).getReg();
7640 LLT DestTy = MRI.getType(Reg: DestReg);
7641
7642 Register X;
7643 Register Sub0;
7644 auto NegPattern = m_all_of(preds: m_Neg(Src: m_DeferredReg(R&: X)), preds: m_Reg(R&: Sub0));
7645 if (mi_match(R: DestReg, MRI,
7646 P: m_Neg(Src: m_OneUse(SP: m_any_of(preds: m_GSMin(L: m_Reg(R&: X), R: NegPattern),
7647 preds: m_GSMax(L: m_Reg(R&: X), R: NegPattern),
7648 preds: m_GUMin(L: m_Reg(R&: X), R: NegPattern),
7649 preds: m_GUMax(L: m_Reg(R&: X), R: NegPattern)))))) {
7650 MachineInstr *MinMaxMI = MRI.getVRegDef(Reg: MI.getOperand(i: 2).getReg());
7651 unsigned NewOpc = getInverseGMinMaxOpcode(MinMaxOpc: MinMaxMI->getOpcode());
7652 if (isLegal(Query: {NewOpc, {DestTy}})) {
7653 MatchInfo = [=](MachineIRBuilder &B) {
7654 B.buildInstr(Opc: NewOpc, DstOps: {DestReg}, SrcOps: {X, Sub0});
7655 };
7656 return true;
7657 }
7658 }
7659
7660 return false;
7661}
7662
7663bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) const {
7664 GSelect *Select = cast<GSelect>(Val: &MI);
7665
7666 if (tryFoldSelectOfConstants(Select, MatchInfo))
7667 return true;
7668
7669 if (tryFoldBoolSelectToLogic(Select, MatchInfo))
7670 return true;
7671
7672 return false;
7673}
7674
7675/// Fold (icmp Pred1 V1, C1) && (icmp Pred2 V2, C2)
7676/// or (icmp Pred1 V1, C1) || (icmp Pred2 V2, C2)
7677/// into a single comparison using range-based reasoning.
7678/// see InstCombinerImpl::foldAndOrOfICmpsUsingRanges.
7679bool CombinerHelper::tryFoldAndOrOrICmpsUsingRanges(
7680 GLogicalBinOp *Logic, BuildFnTy &MatchInfo) const {
7681 assert(Logic->getOpcode() != TargetOpcode::G_XOR && "unexpected xor");
7682 bool IsAnd = Logic->getOpcode() == TargetOpcode::G_AND;
7683 Register DstReg = Logic->getReg(Idx: 0);
7684 Register LHS = Logic->getLHSReg();
7685 Register RHS = Logic->getRHSReg();
7686 unsigned Flags = Logic->getFlags();
7687
7688 // We need an G_ICMP on the LHS register.
7689 GICmp *Cmp1 = getOpcodeDef<GICmp>(Reg: LHS, MRI);
7690 if (!Cmp1)
7691 return false;
7692
7693 // We need an G_ICMP on the RHS register.
7694 GICmp *Cmp2 = getOpcodeDef<GICmp>(Reg: RHS, MRI);
7695 if (!Cmp2)
7696 return false;
7697
7698 // We want to fold the icmps.
7699 if (!MRI.hasOneNonDBGUse(RegNo: Cmp1->getReg(Idx: 0)) ||
7700 !MRI.hasOneNonDBGUse(RegNo: Cmp2->getReg(Idx: 0)))
7701 return false;
7702
7703 APInt C1;
7704 APInt C2;
7705 std::optional<ValueAndVReg> MaybeC1 =
7706 getIConstantVRegValWithLookThrough(VReg: Cmp1->getRHSReg(), MRI);
7707 if (!MaybeC1)
7708 return false;
7709 C1 = MaybeC1->Value;
7710
7711 std::optional<ValueAndVReg> MaybeC2 =
7712 getIConstantVRegValWithLookThrough(VReg: Cmp2->getRHSReg(), MRI);
7713 if (!MaybeC2)
7714 return false;
7715 C2 = MaybeC2->Value;
7716
7717 Register R1 = Cmp1->getLHSReg();
7718 Register R2 = Cmp2->getLHSReg();
7719 CmpInst::Predicate Pred1 = Cmp1->getCond();
7720 CmpInst::Predicate Pred2 = Cmp2->getCond();
7721 LLT CmpTy = MRI.getType(Reg: Cmp1->getReg(Idx: 0));
7722 LLT CmpOperandTy = MRI.getType(Reg: R1);
7723
7724 if (CmpOperandTy.isPointer())
7725 return false;
7726
7727 // We build ands, adds, and constants of type CmpOperandTy.
7728 // They must be legal to build.
7729 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_AND, CmpOperandTy}) ||
7730 !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ADD, CmpOperandTy}) ||
7731 !isConstantLegalOrBeforeLegalizer(Ty: CmpOperandTy))
7732 return false;
7733
7734 // Look through add of a constant offset on R1, R2, or both operands. This
7735 // allows us to interpret the R + C' < C'' range idiom into a proper range.
7736 std::optional<APInt> Offset1;
7737 std::optional<APInt> Offset2;
7738 if (R1 != R2) {
7739 if (GAdd *Add = getOpcodeDef<GAdd>(Reg: R1, MRI)) {
7740 std::optional<ValueAndVReg> MaybeOffset1 =
7741 getIConstantVRegValWithLookThrough(VReg: Add->getRHSReg(), MRI);
7742 if (MaybeOffset1) {
7743 R1 = Add->getLHSReg();
7744 Offset1 = MaybeOffset1->Value;
7745 }
7746 }
7747 if (GAdd *Add = getOpcodeDef<GAdd>(Reg: R2, MRI)) {
7748 std::optional<ValueAndVReg> MaybeOffset2 =
7749 getIConstantVRegValWithLookThrough(VReg: Add->getRHSReg(), MRI);
7750 if (MaybeOffset2) {
7751 R2 = Add->getLHSReg();
7752 Offset2 = MaybeOffset2->Value;
7753 }
7754 }
7755 }
7756
7757 if (R1 != R2)
7758 return false;
7759
7760 // We calculate the icmp ranges including maybe offsets.
7761 ConstantRange CR1 = ConstantRange::makeExactICmpRegion(
7762 Pred: IsAnd ? ICmpInst::getInversePredicate(pred: Pred1) : Pred1, Other: C1);
7763 if (Offset1)
7764 CR1 = CR1.subtract(CI: *Offset1);
7765
7766 ConstantRange CR2 = ConstantRange::makeExactICmpRegion(
7767 Pred: IsAnd ? ICmpInst::getInversePredicate(pred: Pred2) : Pred2, Other: C2);
7768 if (Offset2)
7769 CR2 = CR2.subtract(CI: *Offset2);
7770
7771 bool CreateMask = false;
7772 APInt LowerDiff;
7773 std::optional<ConstantRange> CR = CR1.exactUnionWith(CR: CR2);
7774 if (!CR) {
7775 // We need non-wrapping ranges.
7776 if (CR1.isWrappedSet() || CR2.isWrappedSet())
7777 return false;
7778
7779 // Check whether we have equal-size ranges that only differ by one bit.
7780 // In that case we can apply a mask to map one range onto the other.
7781 LowerDiff = CR1.getLower() ^ CR2.getLower();
7782 APInt UpperDiff = (CR1.getUpper() - 1) ^ (CR2.getUpper() - 1);
7783 APInt CR1Size = CR1.getUpper() - CR1.getLower();
7784 if (!LowerDiff.isPowerOf2() || LowerDiff != UpperDiff ||
7785 CR1Size != CR2.getUpper() - CR2.getLower())
7786 return false;
7787
7788 CR = CR1.getLower().ult(RHS: CR2.getLower()) ? CR1 : CR2;
7789 CreateMask = true;
7790 }
7791
7792 if (IsAnd)
7793 CR = CR->inverse();
7794
7795 CmpInst::Predicate NewPred;
7796 APInt NewC, Offset;
7797 CR->getEquivalentICmp(Pred&: NewPred, RHS&: NewC, Offset);
7798
7799 // We take the result type of one of the original icmps, CmpTy, for
7800 // the to be build icmp. The operand type, CmpOperandTy, is used for
7801 // the other instructions and constants to be build. The types of
7802 // the parameters and output are the same for add and and. CmpTy
7803 // and the type of DstReg might differ. That is why we zext or trunc
7804 // the icmp into the destination register.
7805
7806 MatchInfo = [=](MachineIRBuilder &B) {
7807 if (CreateMask && Offset != 0) {
7808 auto TildeLowerDiff = B.buildConstant(Res: CmpOperandTy, Val: ~LowerDiff);
7809 auto And = B.buildAnd(Dst: CmpOperandTy, Src0: R1, Src1: TildeLowerDiff); // the mask.
7810 auto OffsetC = B.buildConstant(Res: CmpOperandTy, Val: Offset);
7811 auto Add = B.buildAdd(Dst: CmpOperandTy, Src0: And, Src1: OffsetC, Flags);
7812 auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC);
7813 auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: Add, Op1: NewCon);
7814 B.buildZExtOrTrunc(Res: DstReg, Op: ICmp);
7815 } else if (CreateMask && Offset == 0) {
7816 auto TildeLowerDiff = B.buildConstant(Res: CmpOperandTy, Val: ~LowerDiff);
7817 auto And = B.buildAnd(Dst: CmpOperandTy, Src0: R1, Src1: TildeLowerDiff); // the mask.
7818 auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC);
7819 auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: And, Op1: NewCon);
7820 B.buildZExtOrTrunc(Res: DstReg, Op: ICmp);
7821 } else if (!CreateMask && Offset != 0) {
7822 auto OffsetC = B.buildConstant(Res: CmpOperandTy, Val: Offset);
7823 auto Add = B.buildAdd(Dst: CmpOperandTy, Src0: R1, Src1: OffsetC, Flags);
7824 auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC);
7825 auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: Add, Op1: NewCon);
7826 B.buildZExtOrTrunc(Res: DstReg, Op: ICmp);
7827 } else if (!CreateMask && Offset == 0) {
7828 auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC);
7829 auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: R1, Op1: NewCon);
7830 B.buildZExtOrTrunc(Res: DstReg, Op: ICmp);
7831 } else {
7832 llvm_unreachable("unexpected configuration of CreateMask and Offset");
7833 }
7834 };
7835 return true;
7836}
7837
7838bool CombinerHelper::tryFoldLogicOfFCmps(GLogicalBinOp *Logic,
7839 BuildFnTy &MatchInfo) const {
7840 assert(Logic->getOpcode() != TargetOpcode::G_XOR && "unexpecte xor");
7841 Register DestReg = Logic->getReg(Idx: 0);
7842 Register LHS = Logic->getLHSReg();
7843 Register RHS = Logic->getRHSReg();
7844 bool IsAnd = Logic->getOpcode() == TargetOpcode::G_AND;
7845
7846 // We need a compare on the LHS register.
7847 GFCmp *Cmp1 = getOpcodeDef<GFCmp>(Reg: LHS, MRI);
7848 if (!Cmp1)
7849 return false;
7850
7851 // We need a compare on the RHS register.
7852 GFCmp *Cmp2 = getOpcodeDef<GFCmp>(Reg: RHS, MRI);
7853 if (!Cmp2)
7854 return false;
7855
7856 LLT CmpTy = MRI.getType(Reg: Cmp1->getReg(Idx: 0));
7857 LLT CmpOperandTy = MRI.getType(Reg: Cmp1->getLHSReg());
7858
7859 // We build one fcmp, want to fold the fcmps, replace the logic op,
7860 // and the fcmps must have the same shape.
7861 if (!isLegalOrBeforeLegalizer(
7862 Query: {TargetOpcode::G_FCMP, {CmpTy, CmpOperandTy}}) ||
7863 !MRI.hasOneNonDBGUse(RegNo: Logic->getReg(Idx: 0)) ||
7864 !MRI.hasOneNonDBGUse(RegNo: Cmp1->getReg(Idx: 0)) ||
7865 !MRI.hasOneNonDBGUse(RegNo: Cmp2->getReg(Idx: 0)) ||
7866 MRI.getType(Reg: Cmp1->getLHSReg()) != MRI.getType(Reg: Cmp2->getLHSReg()))
7867 return false;
7868
7869 CmpInst::Predicate PredL = Cmp1->getCond();
7870 CmpInst::Predicate PredR = Cmp2->getCond();
7871 Register LHS0 = Cmp1->getLHSReg();
7872 Register LHS1 = Cmp1->getRHSReg();
7873 Register RHS0 = Cmp2->getLHSReg();
7874 Register RHS1 = Cmp2->getRHSReg();
7875
7876 if (LHS0 == RHS1 && LHS1 == RHS0) {
7877 // Swap RHS operands to match LHS.
7878 PredR = CmpInst::getSwappedPredicate(pred: PredR);
7879 std::swap(a&: RHS0, b&: RHS1);
7880 }
7881
7882 if (LHS0 == RHS0 && LHS1 == RHS1) {
7883 // We determine the new predicate.
7884 unsigned CmpCodeL = getFCmpCode(CC: PredL);
7885 unsigned CmpCodeR = getFCmpCode(CC: PredR);
7886 unsigned NewPred = IsAnd ? CmpCodeL & CmpCodeR : CmpCodeL | CmpCodeR;
7887 unsigned Flags = Cmp1->getFlags() | Cmp2->getFlags();
7888 MatchInfo = [=](MachineIRBuilder &B) {
7889 // The fcmp predicates fill the lower part of the enum.
7890 FCmpInst::Predicate Pred = static_cast<FCmpInst::Predicate>(NewPred);
7891 if (Pred == FCmpInst::FCMP_FALSE &&
7892 isConstantLegalOrBeforeLegalizer(Ty: CmpTy)) {
7893 auto False = B.buildConstant(Res: CmpTy, Val: 0);
7894 B.buildZExtOrTrunc(Res: DestReg, Op: False);
7895 } else if (Pred == FCmpInst::FCMP_TRUE &&
7896 isConstantLegalOrBeforeLegalizer(Ty: CmpTy)) {
7897 auto True =
7898 B.buildConstant(Res: CmpTy, Val: getICmpTrueVal(TLI: getTargetLowering(),
7899 IsVector: CmpTy.isVector() /*isVector*/,
7900 IsFP: true /*isFP*/));
7901 B.buildZExtOrTrunc(Res: DestReg, Op: True);
7902 } else { // We take the predicate without predicate optimizations.
7903 auto Cmp = B.buildFCmp(Pred, Res: CmpTy, Op0: LHS0, Op1: LHS1, Flags);
7904 B.buildZExtOrTrunc(Res: DestReg, Op: Cmp);
7905 }
7906 };
7907 return true;
7908 }
7909
7910 return false;
7911}
7912
7913bool CombinerHelper::matchAnd(MachineInstr &MI, BuildFnTy &MatchInfo) const {
7914 GAnd *And = cast<GAnd>(Val: &MI);
7915
7916 if (tryFoldAndOrOrICmpsUsingRanges(Logic: And, MatchInfo))
7917 return true;
7918
7919 if (tryFoldLogicOfFCmps(Logic: And, MatchInfo))
7920 return true;
7921
7922 return false;
7923}
7924
7925bool CombinerHelper::matchOr(MachineInstr &MI, BuildFnTy &MatchInfo) const {
7926 GOr *Or = cast<GOr>(Val: &MI);
7927
7928 if (tryFoldAndOrOrICmpsUsingRanges(Logic: Or, MatchInfo))
7929 return true;
7930
7931 if (tryFoldLogicOfFCmps(Logic: Or, MatchInfo))
7932 return true;
7933
7934 return false;
7935}
7936
7937bool CombinerHelper::matchAddOverflow(MachineInstr &MI,
7938 BuildFnTy &MatchInfo) const {
7939 GAddCarryOut *Add = cast<GAddCarryOut>(Val: &MI);
7940
7941 // Addo has no flags
7942 Register Dst = Add->getReg(Idx: 0);
7943 Register Carry = Add->getReg(Idx: 1);
7944 Register LHS = Add->getLHSReg();
7945 Register RHS = Add->getRHSReg();
7946 bool IsSigned = Add->isSigned();
7947 LLT DstTy = MRI.getType(Reg: Dst);
7948 LLT CarryTy = MRI.getType(Reg: Carry);
7949
7950 // Fold addo, if the carry is dead -> add, undef.
7951 if (MRI.use_nodbg_empty(RegNo: Carry) &&
7952 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ADD, {DstTy}})) {
7953 MatchInfo = [=](MachineIRBuilder &B) {
7954 B.buildAdd(Dst, Src0: LHS, Src1: RHS);
7955 B.buildUndef(Res: Carry);
7956 };
7957 return true;
7958 }
7959
7960 // Canonicalize constant to RHS.
7961 if (isConstantOrConstantVectorI(Src: LHS) && !isConstantOrConstantVectorI(Src: RHS)) {
7962 if (IsSigned) {
7963 MatchInfo = [=](MachineIRBuilder &B) {
7964 B.buildSAddo(Res: Dst, CarryOut: Carry, Op0: RHS, Op1: LHS);
7965 };
7966 return true;
7967 }
7968 // !IsSigned
7969 MatchInfo = [=](MachineIRBuilder &B) {
7970 B.buildUAddo(Res: Dst, CarryOut: Carry, Op0: RHS, Op1: LHS);
7971 };
7972 return true;
7973 }
7974
7975 std::optional<APInt> MaybeLHS = getConstantOrConstantSplatVector(Src: LHS);
7976 std::optional<APInt> MaybeRHS = getConstantOrConstantSplatVector(Src: RHS);
7977
7978 // Fold addo(c1, c2) -> c3, carry.
7979 if (MaybeLHS && MaybeRHS && isConstantLegalOrBeforeLegalizer(Ty: DstTy) &&
7980 isConstantLegalOrBeforeLegalizer(Ty: CarryTy)) {
7981 bool Overflow;
7982 APInt Result = IsSigned ? MaybeLHS->sadd_ov(RHS: *MaybeRHS, Overflow)
7983 : MaybeLHS->uadd_ov(RHS: *MaybeRHS, Overflow);
7984 MatchInfo = [=](MachineIRBuilder &B) {
7985 B.buildConstant(Res: Dst, Val: Result);
7986 B.buildConstant(Res: Carry, Val: Overflow);
7987 };
7988 return true;
7989 }
7990
7991 // Fold (addo x, 0) -> x, no carry
7992 if (MaybeRHS && *MaybeRHS == 0 && isConstantLegalOrBeforeLegalizer(Ty: CarryTy)) {
7993 MatchInfo = [=](MachineIRBuilder &B) {
7994 B.buildCopy(Res: Dst, Op: LHS);
7995 B.buildConstant(Res: Carry, Val: 0);
7996 };
7997 return true;
7998 }
7999
8000 // Given 2 constant operands whose sum does not overflow:
8001 // uaddo (X +nuw C0), C1 -> uaddo X, C0 + C1
8002 // saddo (X +nsw C0), C1 -> saddo X, C0 + C1
8003 GAdd *AddLHS = getOpcodeDef<GAdd>(Reg: LHS, MRI);
8004 if (MaybeRHS && AddLHS && MRI.hasOneNonDBGUse(RegNo: Add->getReg(Idx: 0)) &&
8005 ((IsSigned && AddLHS->getFlag(Flag: MachineInstr::MIFlag::NoSWrap)) ||
8006 (!IsSigned && AddLHS->getFlag(Flag: MachineInstr::MIFlag::NoUWrap)))) {
8007 std::optional<APInt> MaybeAddRHS =
8008 getConstantOrConstantSplatVector(Src: AddLHS->getRHSReg());
8009 if (MaybeAddRHS) {
8010 bool Overflow;
8011 APInt NewC = IsSigned ? MaybeAddRHS->sadd_ov(RHS: *MaybeRHS, Overflow)
8012 : MaybeAddRHS->uadd_ov(RHS: *MaybeRHS, Overflow);
8013 if (!Overflow && isConstantLegalOrBeforeLegalizer(Ty: DstTy)) {
8014 if (IsSigned) {
8015 MatchInfo = [=](MachineIRBuilder &B) {
8016 auto ConstRHS = B.buildConstant(Res: DstTy, Val: NewC);
8017 B.buildSAddo(Res: Dst, CarryOut: Carry, Op0: AddLHS->getLHSReg(), Op1: ConstRHS);
8018 };
8019 return true;
8020 }
8021 // !IsSigned
8022 MatchInfo = [=](MachineIRBuilder &B) {
8023 auto ConstRHS = B.buildConstant(Res: DstTy, Val: NewC);
8024 B.buildUAddo(Res: Dst, CarryOut: Carry, Op0: AddLHS->getLHSReg(), Op1: ConstRHS);
8025 };
8026 return true;
8027 }
8028 }
8029 };
8030
8031 // We try to combine addo to non-overflowing add.
8032 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ADD, {DstTy}}) ||
8033 !isConstantLegalOrBeforeLegalizer(Ty: CarryTy))
8034 return false;
8035
8036 // We try to combine uaddo to non-overflowing add.
8037 if (!IsSigned) {
8038 ConstantRange CRLHS =
8039 ConstantRange::fromKnownBits(Known: VT->getKnownBits(R: LHS), /*IsSigned=*/false);
8040 ConstantRange CRRHS =
8041 ConstantRange::fromKnownBits(Known: VT->getKnownBits(R: RHS), /*IsSigned=*/false);
8042
8043 switch (CRLHS.unsignedAddMayOverflow(Other: CRRHS)) {
8044 case ConstantRange::OverflowResult::MayOverflow:
8045 return false;
8046 case ConstantRange::OverflowResult::NeverOverflows: {
8047 MatchInfo = [=](MachineIRBuilder &B) {
8048 B.buildAdd(Dst, Src0: LHS, Src1: RHS, Flags: MachineInstr::MIFlag::NoUWrap);
8049 B.buildConstant(Res: Carry, Val: 0);
8050 };
8051 return true;
8052 }
8053 case ConstantRange::OverflowResult::AlwaysOverflowsLow:
8054 case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
8055 MatchInfo = [=](MachineIRBuilder &B) {
8056 B.buildAdd(Dst, Src0: LHS, Src1: RHS);
8057 B.buildConstant(Res: Carry, Val: 1);
8058 };
8059 return true;
8060 }
8061 }
8062 return false;
8063 }
8064
8065 // We try to combine saddo to non-overflowing add.
8066
8067 // If LHS and RHS each have at least two sign bits, then there is no signed
8068 // overflow.
8069 if (VT->computeNumSignBits(R: RHS) > 1 && VT->computeNumSignBits(R: LHS) > 1) {
8070 MatchInfo = [=](MachineIRBuilder &B) {
8071 B.buildAdd(Dst, Src0: LHS, Src1: RHS, Flags: MachineInstr::MIFlag::NoSWrap);
8072 B.buildConstant(Res: Carry, Val: 0);
8073 };
8074 return true;
8075 }
8076
8077 ConstantRange CRLHS =
8078 ConstantRange::fromKnownBits(Known: VT->getKnownBits(R: LHS), /*IsSigned=*/true);
8079 ConstantRange CRRHS =
8080 ConstantRange::fromKnownBits(Known: VT->getKnownBits(R: RHS), /*IsSigned=*/true);
8081
8082 switch (CRLHS.signedAddMayOverflow(Other: CRRHS)) {
8083 case ConstantRange::OverflowResult::MayOverflow:
8084 return false;
8085 case ConstantRange::OverflowResult::NeverOverflows: {
8086 MatchInfo = [=](MachineIRBuilder &B) {
8087 B.buildAdd(Dst, Src0: LHS, Src1: RHS, Flags: MachineInstr::MIFlag::NoSWrap);
8088 B.buildConstant(Res: Carry, Val: 0);
8089 };
8090 return true;
8091 }
8092 case ConstantRange::OverflowResult::AlwaysOverflowsLow:
8093 case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
8094 MatchInfo = [=](MachineIRBuilder &B) {
8095 B.buildAdd(Dst, Src0: LHS, Src1: RHS);
8096 B.buildConstant(Res: Carry, Val: 1);
8097 };
8098 return true;
8099 }
8100 }
8101
8102 return false;
8103}
8104
8105void CombinerHelper::applyBuildFnMO(const MachineOperand &MO,
8106 BuildFnTy &MatchInfo) const {
8107 MachineInstr *Root = getDefIgnoringCopies(Reg: MO.getReg(), MRI);
8108 MatchInfo(Builder);
8109 Root->eraseFromParent();
8110}
8111
8112bool CombinerHelper::matchFPowIExpansion(MachineInstr &MI,
8113 int64_t Exponent) const {
8114 bool OptForSize = MI.getMF()->getFunction().hasOptSize();
8115 return getTargetLowering().isBeneficialToExpandPowI(Exponent, OptForSize);
8116}
8117
8118void CombinerHelper::applyExpandFPowI(MachineInstr &MI,
8119 int64_t Exponent) const {
8120 auto [Dst, Base] = MI.getFirst2Regs();
8121 LLT Ty = MRI.getType(Reg: Dst);
8122 int64_t ExpVal = Exponent;
8123
8124 if (ExpVal == 0) {
8125 Builder.buildFConstant(Res: Dst, Val: 1.0);
8126 MI.removeFromParent();
8127 return;
8128 }
8129
8130 if (ExpVal < 0)
8131 ExpVal = -ExpVal;
8132
8133 // We use the simple binary decomposition method from SelectionDAG ExpandPowI
8134 // to generate the multiply sequence. There are more optimal ways to do this
8135 // (for example, powi(x,15) generates one more multiply than it should), but
8136 // this has the benefit of being both really simple and much better than a
8137 // libcall.
8138 std::optional<SrcOp> Res;
8139 SrcOp CurSquare = Base;
8140 while (ExpVal > 0) {
8141 if (ExpVal & 1) {
8142 if (!Res)
8143 Res = CurSquare;
8144 else
8145 Res = Builder.buildFMul(Dst: Ty, Src0: *Res, Src1: CurSquare);
8146 }
8147
8148 CurSquare = Builder.buildFMul(Dst: Ty, Src0: CurSquare, Src1: CurSquare);
8149 ExpVal >>= 1;
8150 }
8151
8152 // If the original exponent was negative, invert the result, producing
8153 // 1/(x*x*x).
8154 if (Exponent < 0)
8155 Res = Builder.buildFDiv(Dst: Ty, Src0: Builder.buildFConstant(Res: Ty, Val: 1.0), Src1: *Res,
8156 Flags: MI.getFlags());
8157
8158 Builder.buildCopy(Res: Dst, Op: *Res);
8159 MI.eraseFromParent();
8160}
8161
8162bool CombinerHelper::matchFoldAPlusC1MinusC2(const MachineInstr &MI,
8163 BuildFnTy &MatchInfo) const {
8164 // fold (A+C1)-C2 -> A+(C1-C2)
8165 const GSub *Sub = cast<GSub>(Val: &MI);
8166 GAdd *Add = cast<GAdd>(Val: MRI.getVRegDef(Reg: Sub->getLHSReg()));
8167
8168 if (!MRI.hasOneNonDBGUse(RegNo: Add->getReg(Idx: 0)))
8169 return false;
8170
8171 APInt C2 = getIConstantFromReg(VReg: Sub->getRHSReg(), MRI);
8172 APInt C1 = getIConstantFromReg(VReg: Add->getRHSReg(), MRI);
8173
8174 Register Dst = Sub->getReg(Idx: 0);
8175 LLT DstTy = MRI.getType(Reg: Dst);
8176
8177 MatchInfo = [=](MachineIRBuilder &B) {
8178 auto Const = B.buildConstant(Res: DstTy, Val: C1 - C2);
8179 B.buildAdd(Dst, Src0: Add->getLHSReg(), Src1: Const);
8180 };
8181
8182 return true;
8183}
8184
8185bool CombinerHelper::matchFoldC2MinusAPlusC1(const MachineInstr &MI,
8186 BuildFnTy &MatchInfo) const {
8187 // fold C2-(A+C1) -> (C2-C1)-A
8188 const GSub *Sub = cast<GSub>(Val: &MI);
8189 GAdd *Add = cast<GAdd>(Val: MRI.getVRegDef(Reg: Sub->getRHSReg()));
8190
8191 if (!MRI.hasOneNonDBGUse(RegNo: Add->getReg(Idx: 0)))
8192 return false;
8193
8194 APInt C2 = getIConstantFromReg(VReg: Sub->getLHSReg(), MRI);
8195 APInt C1 = getIConstantFromReg(VReg: Add->getRHSReg(), MRI);
8196
8197 Register Dst = Sub->getReg(Idx: 0);
8198 LLT DstTy = MRI.getType(Reg: Dst);
8199
8200 MatchInfo = [=](MachineIRBuilder &B) {
8201 auto Const = B.buildConstant(Res: DstTy, Val: C2 - C1);
8202 B.buildSub(Dst, Src0: Const, Src1: Add->getLHSReg());
8203 };
8204
8205 return true;
8206}
8207
8208bool CombinerHelper::matchFoldAMinusC1MinusC2(const MachineInstr &MI,
8209 BuildFnTy &MatchInfo) const {
8210 // fold (A-C1)-C2 -> A-(C1+C2)
8211 const GSub *Sub1 = cast<GSub>(Val: &MI);
8212 GSub *Sub2 = cast<GSub>(Val: MRI.getVRegDef(Reg: Sub1->getLHSReg()));
8213
8214 if (!MRI.hasOneNonDBGUse(RegNo: Sub2->getReg(Idx: 0)))
8215 return false;
8216
8217 APInt C2 = getIConstantFromReg(VReg: Sub1->getRHSReg(), MRI);
8218 APInt C1 = getIConstantFromReg(VReg: Sub2->getRHSReg(), MRI);
8219
8220 Register Dst = Sub1->getReg(Idx: 0);
8221 LLT DstTy = MRI.getType(Reg: Dst);
8222
8223 MatchInfo = [=](MachineIRBuilder &B) {
8224 auto Const = B.buildConstant(Res: DstTy, Val: C1 + C2);
8225 B.buildSub(Dst, Src0: Sub2->getLHSReg(), Src1: Const);
8226 };
8227
8228 return true;
8229}
8230
8231bool CombinerHelper::matchFoldC1Minus2MinusC2(const MachineInstr &MI,
8232 BuildFnTy &MatchInfo) const {
8233 // fold (C1-A)-C2 -> (C1-C2)-A
8234 const GSub *Sub1 = cast<GSub>(Val: &MI);
8235 GSub *Sub2 = cast<GSub>(Val: MRI.getVRegDef(Reg: Sub1->getLHSReg()));
8236
8237 if (!MRI.hasOneNonDBGUse(RegNo: Sub2->getReg(Idx: 0)))
8238 return false;
8239
8240 APInt C2 = getIConstantFromReg(VReg: Sub1->getRHSReg(), MRI);
8241 APInt C1 = getIConstantFromReg(VReg: Sub2->getLHSReg(), MRI);
8242
8243 Register Dst = Sub1->getReg(Idx: 0);
8244 LLT DstTy = MRI.getType(Reg: Dst);
8245
8246 MatchInfo = [=](MachineIRBuilder &B) {
8247 auto Const = B.buildConstant(Res: DstTy, Val: C1 - C2);
8248 B.buildSub(Dst, Src0: Const, Src1: Sub2->getRHSReg());
8249 };
8250
8251 return true;
8252}
8253
8254bool CombinerHelper::matchFoldAMinusC1PlusC2(const MachineInstr &MI,
8255 BuildFnTy &MatchInfo) const {
8256 // fold ((A-C1)+C2) -> (A+(C2-C1))
8257 const GAdd *Add = cast<GAdd>(Val: &MI);
8258 GSub *Sub = cast<GSub>(Val: MRI.getVRegDef(Reg: Add->getLHSReg()));
8259
8260 if (!MRI.hasOneNonDBGUse(RegNo: Sub->getReg(Idx: 0)))
8261 return false;
8262
8263 APInt C2 = getIConstantFromReg(VReg: Add->getRHSReg(), MRI);
8264 APInt C1 = getIConstantFromReg(VReg: Sub->getRHSReg(), MRI);
8265
8266 Register Dst = Add->getReg(Idx: 0);
8267 LLT DstTy = MRI.getType(Reg: Dst);
8268
8269 MatchInfo = [=](MachineIRBuilder &B) {
8270 auto Const = B.buildConstant(Res: DstTy, Val: C2 - C1);
8271 B.buildAdd(Dst, Src0: Sub->getLHSReg(), Src1: Const);
8272 };
8273
8274 return true;
8275}
8276
8277bool CombinerHelper::matchUnmergeValuesAnyExtBuildVector(
8278 const MachineInstr &MI, BuildFnTy &MatchInfo) const {
8279 const GUnmerge *Unmerge = cast<GUnmerge>(Val: &MI);
8280
8281 if (!MRI.hasOneNonDBGUse(RegNo: Unmerge->getSourceReg()))
8282 return false;
8283
8284 const MachineInstr *Source = MRI.getVRegDef(Reg: Unmerge->getSourceReg());
8285
8286 LLT DstTy = MRI.getType(Reg: Unmerge->getReg(Idx: 0));
8287
8288 // $bv:_(<8 x s8>) = G_BUILD_VECTOR ....
8289 // $any:_(<8 x s16>) = G_ANYEXT $bv
8290 // $uv:_(<4 x s16>), $uv1:_(<4 x s16>) = G_UNMERGE_VALUES $any
8291 //
8292 // ->
8293 //
8294 // $any:_(s16) = G_ANYEXT $bv[0]
8295 // $any1:_(s16) = G_ANYEXT $bv[1]
8296 // $any2:_(s16) = G_ANYEXT $bv[2]
8297 // $any3:_(s16) = G_ANYEXT $bv[3]
8298 // $any4:_(s16) = G_ANYEXT $bv[4]
8299 // $any5:_(s16) = G_ANYEXT $bv[5]
8300 // $any6:_(s16) = G_ANYEXT $bv[6]
8301 // $any7:_(s16) = G_ANYEXT $bv[7]
8302 // $uv:_(<4 x s16>) = G_BUILD_VECTOR $any, $any1, $any2, $any3
8303 // $uv1:_(<4 x s16>) = G_BUILD_VECTOR $any4, $any5, $any6, $any7
8304
8305 // We want to unmerge into vectors.
8306 if (!DstTy.isFixedVector())
8307 return false;
8308
8309 const GAnyExt *Any = dyn_cast<GAnyExt>(Val: Source);
8310 if (!Any)
8311 return false;
8312
8313 const MachineInstr *NextSource = MRI.getVRegDef(Reg: Any->getSrcReg());
8314
8315 if (const GBuildVector *BV = dyn_cast<GBuildVector>(Val: NextSource)) {
8316 // G_UNMERGE_VALUES G_ANYEXT G_BUILD_VECTOR
8317
8318 if (!MRI.hasOneNonDBGUse(RegNo: BV->getReg(Idx: 0)))
8319 return false;
8320
8321 // FIXME: check element types?
8322 if (BV->getNumSources() % Unmerge->getNumDefs() != 0)
8323 return false;
8324
8325 LLT BigBvTy = MRI.getType(Reg: BV->getReg(Idx: 0));
8326 LLT SmallBvTy = DstTy;
8327 LLT SmallBvElemenTy = SmallBvTy.getElementType();
8328
8329 if (!isLegalOrBeforeLegalizer(
8330 Query: {TargetOpcode::G_BUILD_VECTOR, {SmallBvTy, SmallBvElemenTy}}))
8331 return false;
8332
8333 // We check the legality of scalar anyext.
8334 if (!isLegalOrBeforeLegalizer(
8335 Query: {TargetOpcode::G_ANYEXT,
8336 {SmallBvElemenTy, BigBvTy.getElementType()}}))
8337 return false;
8338
8339 MatchInfo = [=](MachineIRBuilder &B) {
8340 // Build into each G_UNMERGE_VALUES def
8341 // a small build vector with anyext from the source build vector.
8342 for (unsigned I = 0; I < Unmerge->getNumDefs(); ++I) {
8343 SmallVector<Register> Ops;
8344 for (unsigned J = 0; J < SmallBvTy.getNumElements(); ++J) {
8345 Register SourceArray =
8346 BV->getSourceReg(I: I * SmallBvTy.getNumElements() + J);
8347 auto AnyExt = B.buildAnyExt(Res: SmallBvElemenTy, Op: SourceArray);
8348 Ops.push_back(Elt: AnyExt.getReg(Idx: 0));
8349 }
8350 B.buildBuildVector(Res: Unmerge->getOperand(i: I).getReg(), Ops);
8351 };
8352 };
8353 return true;
8354 };
8355
8356 return false;
8357}
8358
8359bool CombinerHelper::matchShuffleUndefRHS(MachineInstr &MI,
8360 BuildFnTy &MatchInfo) const {
8361
8362 bool Changed = false;
8363 auto &Shuffle = cast<GShuffleVector>(Val&: MI);
8364 ArrayRef<int> OrigMask = Shuffle.getMask();
8365 SmallVector<int, 16> NewMask;
8366 const LLT SrcTy = MRI.getType(Reg: Shuffle.getSrc1Reg());
8367 const unsigned NumSrcElems = SrcTy.isVector() ? SrcTy.getNumElements() : 1;
8368 const unsigned NumDstElts = OrigMask.size();
8369 for (unsigned i = 0; i != NumDstElts; ++i) {
8370 int Idx = OrigMask[i];
8371 if (Idx >= (int)NumSrcElems) {
8372 Idx = -1;
8373 Changed = true;
8374 }
8375 NewMask.push_back(Elt: Idx);
8376 }
8377
8378 if (!Changed)
8379 return false;
8380
8381 MatchInfo = [&, NewMask = std::move(NewMask)](MachineIRBuilder &B) {
8382 B.buildShuffleVector(Res: MI.getOperand(i: 0), Src1: MI.getOperand(i: 1), Src2: MI.getOperand(i: 2),
8383 Mask: std::move(NewMask));
8384 };
8385
8386 return true;
8387}
8388
8389static void commuteMask(MutableArrayRef<int> Mask, const unsigned NumElems) {
8390 const unsigned MaskSize = Mask.size();
8391 for (unsigned I = 0; I < MaskSize; ++I) {
8392 int Idx = Mask[I];
8393 if (Idx < 0)
8394 continue;
8395
8396 if (Idx < (int)NumElems)
8397 Mask[I] = Idx + NumElems;
8398 else
8399 Mask[I] = Idx - NumElems;
8400 }
8401}
8402
8403bool CombinerHelper::matchShuffleDisjointMask(MachineInstr &MI,
8404 BuildFnTy &MatchInfo) const {
8405
8406 auto &Shuffle = cast<GShuffleVector>(Val&: MI);
8407 // If any of the two inputs is already undef, don't check the mask again to
8408 // prevent infinite loop
8409 if (getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: Shuffle.getSrc1Reg(), MRI))
8410 return false;
8411
8412 if (getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: Shuffle.getSrc2Reg(), MRI))
8413 return false;
8414
8415 const LLT DstTy = MRI.getType(Reg: Shuffle.getReg(Idx: 0));
8416 const LLT Src1Ty = MRI.getType(Reg: Shuffle.getSrc1Reg());
8417 if (!isLegalOrBeforeLegalizer(
8418 Query: {TargetOpcode::G_SHUFFLE_VECTOR, {DstTy, Src1Ty}}))
8419 return false;
8420
8421 ArrayRef<int> Mask = Shuffle.getMask();
8422 const unsigned NumSrcElems = Src1Ty.getNumElements();
8423
8424 bool TouchesSrc1 = false;
8425 bool TouchesSrc2 = false;
8426 const unsigned NumElems = Mask.size();
8427 for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
8428 if (Mask[Idx] < 0)
8429 continue;
8430
8431 if (Mask[Idx] < (int)NumSrcElems)
8432 TouchesSrc1 = true;
8433 else
8434 TouchesSrc2 = true;
8435 }
8436
8437 if (TouchesSrc1 == TouchesSrc2)
8438 return false;
8439
8440 Register NewSrc1 = Shuffle.getSrc1Reg();
8441 SmallVector<int, 16> NewMask(Mask);
8442 if (TouchesSrc2) {
8443 NewSrc1 = Shuffle.getSrc2Reg();
8444 commuteMask(Mask: NewMask, NumElems: NumSrcElems);
8445 }
8446
8447 MatchInfo = [=, &Shuffle](MachineIRBuilder &B) {
8448 auto Undef = B.buildUndef(Res: Src1Ty);
8449 B.buildShuffleVector(Res: Shuffle.getReg(Idx: 0), Src1: NewSrc1, Src2: Undef, Mask: NewMask);
8450 };
8451
8452 return true;
8453}
8454
8455bool CombinerHelper::matchSuboCarryOut(const MachineInstr &MI,
8456 BuildFnTy &MatchInfo) const {
8457 const GSubCarryOut *Subo = cast<GSubCarryOut>(Val: &MI);
8458
8459 Register Dst = Subo->getReg(Idx: 0);
8460 Register LHS = Subo->getLHSReg();
8461 Register RHS = Subo->getRHSReg();
8462 Register Carry = Subo->getCarryOutReg();
8463 LLT DstTy = MRI.getType(Reg: Dst);
8464 LLT CarryTy = MRI.getType(Reg: Carry);
8465
8466 // Check legality before known bits.
8467 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SUB, {DstTy}}) ||
8468 !isConstantLegalOrBeforeLegalizer(Ty: CarryTy))
8469 return false;
8470
8471 ConstantRange KBLHS =
8472 ConstantRange::fromKnownBits(Known: VT->getKnownBits(R: LHS),
8473 /* IsSigned=*/Subo->isSigned());
8474 ConstantRange KBRHS =
8475 ConstantRange::fromKnownBits(Known: VT->getKnownBits(R: RHS),
8476 /* IsSigned=*/Subo->isSigned());
8477
8478 if (Subo->isSigned()) {
8479 // G_SSUBO
8480 switch (KBLHS.signedSubMayOverflow(Other: KBRHS)) {
8481 case ConstantRange::OverflowResult::MayOverflow:
8482 return false;
8483 case ConstantRange::OverflowResult::NeverOverflows: {
8484 MatchInfo = [=](MachineIRBuilder &B) {
8485 B.buildSub(Dst, Src0: LHS, Src1: RHS, Flags: MachineInstr::MIFlag::NoSWrap);
8486 B.buildConstant(Res: Carry, Val: 0);
8487 };
8488 return true;
8489 }
8490 case ConstantRange::OverflowResult::AlwaysOverflowsLow:
8491 case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
8492 MatchInfo = [=](MachineIRBuilder &B) {
8493 B.buildSub(Dst, Src0: LHS, Src1: RHS);
8494 B.buildConstant(Res: Carry, Val: getICmpTrueVal(TLI: getTargetLowering(),
8495 /*isVector=*/IsVector: CarryTy.isVector(),
8496 /*isFP=*/IsFP: false));
8497 };
8498 return true;
8499 }
8500 }
8501 return false;
8502 }
8503
8504 // G_USUBO
8505 switch (KBLHS.unsignedSubMayOverflow(Other: KBRHS)) {
8506 case ConstantRange::OverflowResult::MayOverflow:
8507 return false;
8508 case ConstantRange::OverflowResult::NeverOverflows: {
8509 MatchInfo = [=](MachineIRBuilder &B) {
8510 B.buildSub(Dst, Src0: LHS, Src1: RHS, Flags: MachineInstr::MIFlag::NoUWrap);
8511 B.buildConstant(Res: Carry, Val: 0);
8512 };
8513 return true;
8514 }
8515 case ConstantRange::OverflowResult::AlwaysOverflowsLow:
8516 case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
8517 MatchInfo = [=](MachineIRBuilder &B) {
8518 B.buildSub(Dst, Src0: LHS, Src1: RHS);
8519 B.buildConstant(Res: Carry, Val: getICmpTrueVal(TLI: getTargetLowering(),
8520 /*isVector=*/IsVector: CarryTy.isVector(),
8521 /*isFP=*/IsFP: false));
8522 };
8523 return true;
8524 }
8525 }
8526
8527 return false;
8528}
8529
8530// Fold (ctlz (xor x, (sra x, bitwidth-1))) -> (add (ctls x), 1).
8531// Fold (ctlz (or (shl (xor x, (sra x, bitwidth-1)), 1), 1) -> (ctls x)
8532bool CombinerHelper::matchCtls(MachineInstr &CtlzMI,
8533 BuildFnTy &MatchInfo) const {
8534 assert((CtlzMI.getOpcode() == TargetOpcode::G_CTLZ ||
8535 CtlzMI.getOpcode() == TargetOpcode::G_CTLZ_ZERO_UNDEF) &&
8536 "Expected G_CTLZ variant");
8537
8538 const Register Dst = CtlzMI.getOperand(i: 0).getReg();
8539 Register Src = CtlzMI.getOperand(i: 1).getReg();
8540
8541 LLT Ty = MRI.getType(Reg: Dst);
8542 LLT SrcTy = MRI.getType(Reg: Src);
8543
8544 if (!(Ty.isValid() && Ty.isScalar()))
8545 return false;
8546
8547 if (!LI)
8548 return false;
8549
8550 SmallVector<LLT, 2> QueryTypes = {Ty, SrcTy};
8551 LegalityQuery Query(TargetOpcode::G_CTLS, QueryTypes);
8552
8553 switch (LI->getAction(Query).Action) {
8554 default:
8555 return false;
8556 case LegalizeActions::Legal:
8557 case LegalizeActions::Custom:
8558 case LegalizeActions::WidenScalar:
8559 break;
8560 }
8561
8562 // Src = or(shl(V, 1), 1) -> Src=V; NeedAdd = False
8563 Register V;
8564 bool NeedAdd = true;
8565 if (mi_match(R: Src, MRI,
8566 P: m_OneUse(SP: m_GOr(L: m_OneUse(SP: m_GShl(L: m_Reg(R&: V), R: m_SpecificICst(RequestedValue: 1))),
8567 R: m_SpecificICst(RequestedValue: 1))))) {
8568 NeedAdd = false;
8569 Src = V;
8570 }
8571
8572 unsigned BitWidth = Ty.getScalarSizeInBits();
8573
8574 Register X;
8575 if (!mi_match(R: Src, MRI,
8576 P: m_OneUse(SP: m_GXor(L: m_Reg(R&: X), R: m_OneUse(SP: m_GAShr(
8577 L: m_DeferredReg(R&: X),
8578 R: m_SpecificICst(RequestedValue: BitWidth - 1)))))))
8579 return false;
8580
8581 MatchInfo = [=](MachineIRBuilder &B) {
8582 if (!NeedAdd) {
8583 B.buildCTLS(Dst, Src0: X);
8584 return;
8585 }
8586
8587 auto Ctls = B.buildCTLS(Dst: Ty, Src0: X);
8588 auto One = B.buildConstant(Res: Ty, Val: 1);
8589
8590 B.buildAdd(Dst, Src0: Ctls, Src1: One);
8591 };
8592
8593 return true;
8594}
8595