1//===-- lib/CodeGen/GlobalISel/GICombinerHelper.cpp -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
9#include "llvm/ADT/APFloat.h"
10#include "llvm/ADT/STLExtras.h"
11#include "llvm/ADT/SetVector.h"
12#include "llvm/ADT/SmallBitVector.h"
13#include "llvm/Analysis/CmpInstAnalysis.h"
14#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
15#include "llvm/CodeGen/GlobalISel/GISelValueTracking.h"
16#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
17#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
18#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
19#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
20#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
21#include "llvm/CodeGen/GlobalISel/Utils.h"
22#include "llvm/CodeGen/LowLevelTypeUtils.h"
23#include "llvm/CodeGen/MachineBasicBlock.h"
24#include "llvm/CodeGen/MachineDominators.h"
25#include "llvm/CodeGen/MachineInstr.h"
26#include "llvm/CodeGen/MachineMemOperand.h"
27#include "llvm/CodeGen/MachineRegisterInfo.h"
28#include "llvm/CodeGen/Register.h"
29#include "llvm/CodeGen/RegisterBankInfo.h"
30#include "llvm/CodeGen/TargetInstrInfo.h"
31#include "llvm/CodeGen/TargetLowering.h"
32#include "llvm/CodeGen/TargetOpcodes.h"
33#include "llvm/IR/ConstantRange.h"
34#include "llvm/IR/DataLayout.h"
35#include "llvm/IR/InstrTypes.h"
36#include "llvm/IR/PatternMatch.h"
37#include "llvm/Support/Casting.h"
38#include "llvm/Support/DivisionByConstantInfo.h"
39#include "llvm/Support/ErrorHandling.h"
40#include "llvm/Support/MathExtras.h"
41#include "llvm/Target/TargetMachine.h"
42#include <cmath>
43#include <optional>
44#include <tuple>
45
46#define DEBUG_TYPE "gi-combiner"
47
48using namespace llvm;
49using namespace MIPatternMatch;
50
51// Option to allow testing of the combiner while no targets know about indexed
52// addressing.
53static cl::opt<bool>
54 ForceLegalIndexing("force-legal-indexing", cl::Hidden, cl::init(Val: false),
55 cl::desc("Force all indexed operations to be "
56 "legal for the GlobalISel combiner"));
57
/// Construct a combiner helper that rewrites instructions through \p B and
/// notifies \p Observer of the changes it makes.
///
/// \param Observer change observer informed of instruction/register edits.
/// \param B builder used to create new instructions; MRI, RBI and TRI are all
///        taken from the function/subtarget this builder is attached to.
/// \param IsPreLegalize true when running before the legalizer; in that case
///        isLegalOrBeforeLegalizer() accepts every query.
/// \param VT value-tracking analysis (presumably optional / may be null —
///        TODO confirm against callers).
/// \param MDT machine dominator tree (presumably optional — TODO confirm).
/// \param LI legalizer info; isLegal() asserts that it is non-null.
CombinerHelper::CombinerHelper(GISelChangeObserver &Observer,
                               MachineIRBuilder &B, bool IsPreLegalize,
                               GISelValueTracking *VT,
                               MachineDominatorTree *MDT,
                               const LegalizerInfo *LI)
    : Builder(B), MRI(Builder.getMF().getRegInfo()), Observer(Observer), VT(VT),
      MDT(MDT), IsPreLegalize(IsPreLegalize), LI(LI),
      RBI(Builder.getMF().getSubtarget().getRegBankInfo()),
      TRI(Builder.getMF().getSubtarget().getRegisterInfo()) {
  // Touch the VT member (not the same-named parameter); presumably this
  // silences an unused-field warning in builds that never read it.
  (void)this->VT;
}
69
70const TargetLowering &CombinerHelper::getTargetLowering() const {
71 return *Builder.getMF().getSubtarget().getTargetLowering();
72}
73
74const MachineFunction &CombinerHelper::getMachineFunction() const {
75 return Builder.getMF();
76}
77
78const DataLayout &CombinerHelper::getDataLayout() const {
79 return getMachineFunction().getDataLayout();
80}
81
82LLVMContext &CombinerHelper::getContext() const { return Builder.getContext(); }
83
/// Map logical byte \p I of a \p ByteWidth-byte value to its in-memory
/// position under a little-endian layout. Little endian stores the
/// least-significant byte first, so the mapping is the identity.
///
/// E.g. Given a 4-byte type x, x[0] -> byte 0
static unsigned littleEndianByteAt(const unsigned ByteWidth, const unsigned I) {
  assert(I < ByteWidth && "I must be in [0, ByteWidth)");
  const unsigned MemPos = I;
  return MemPos;
}
92
93/// Determines the LogBase2 value for a non-null input value using the
94/// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
95static Register buildLogBase2(Register V, MachineIRBuilder &MIB) {
96 auto &MRI = *MIB.getMRI();
97 LLT Ty = MRI.getType(Reg: V);
98 auto Ctlz = MIB.buildCTLZ(Dst: Ty, Src0: V);
99 auto Base = MIB.buildConstant(Res: Ty, Val: Ty.getScalarSizeInBits() - 1);
100 return MIB.buildSub(Dst: Ty, Src0: Base, Src1: Ctlz).getReg(Idx: 0);
101}
102
/// Map logical byte \p I of a \p ByteWidth-byte value to its in-memory
/// position under a big-endian layout, where the most-significant byte comes
/// first, i.e. positions are mirrored.
///
/// E.g. Given a 4-byte type x, x[0] -> byte 3
static unsigned bigEndianByteAt(const unsigned ByteWidth, const unsigned I) {
  assert(I < ByteWidth && "I must be in [0, ByteWidth)");
  const unsigned LastByte = ByteWidth - 1;
  return LastByte - I;
}
111
112/// Given a map from byte offsets in memory to indices in a load/store,
113/// determine if that map corresponds to a little or big endian byte pattern.
114///
115/// \param MemOffset2Idx maps memory offsets to address offsets.
116/// \param LowestIdx is the lowest index in \p MemOffset2Idx.
117///
118/// \returns true if the map corresponds to a big endian byte pattern, false if
119/// it corresponds to a little endian byte pattern, and std::nullopt otherwise.
120///
121/// E.g. given a 32-bit type x, and x[AddrOffset], the in-memory byte patterns
122/// are as follows:
123///
124/// AddrOffset Little endian Big endian
125/// 0 0 3
126/// 1 1 2
127/// 2 2 1
128/// 3 3 0
129static std::optional<bool>
130isBigEndian(const SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx,
131 int64_t LowestIdx) {
132 // Need at least two byte positions to decide on endianness.
133 unsigned Width = MemOffset2Idx.size();
134 if (Width < 2)
135 return std::nullopt;
136 bool BigEndian = true, LittleEndian = true;
137 for (unsigned MemOffset = 0; MemOffset < Width; ++ MemOffset) {
138 auto MemOffsetAndIdx = MemOffset2Idx.find(Val: MemOffset);
139 if (MemOffsetAndIdx == MemOffset2Idx.end())
140 return std::nullopt;
141 const int64_t Idx = MemOffsetAndIdx->second - LowestIdx;
142 assert(Idx >= 0 && "Expected non-negative byte offset?");
143 LittleEndian &= Idx == littleEndianByteAt(ByteWidth: Width, I: MemOffset);
144 BigEndian &= Idx == bigEndianByteAt(ByteWidth: Width, I: MemOffset);
145 if (!BigEndian && !LittleEndian)
146 return std::nullopt;
147 }
148
149 assert((BigEndian != LittleEndian) &&
150 "Pattern cannot be both big and little endian!");
151 return BigEndian;
152}
153
154bool CombinerHelper::isPreLegalize() const { return IsPreLegalize; }
155
156bool CombinerHelper::isLegal(const LegalityQuery &Query) const {
157 assert(LI && "Must have LegalizerInfo to query isLegal!");
158 return LI->getAction(Query).Action == LegalizeActions::Legal;
159}
160
161bool CombinerHelper::isLegalOrBeforeLegalizer(
162 const LegalityQuery &Query) const {
163 return isPreLegalize() || isLegal(Query);
164}
165
166bool CombinerHelper::isLegalOrHasWidenScalar(const LegalityQuery &Query) const {
167 return isLegal(Query) ||
168 LI->getAction(Query).Action == LegalizeActions::WidenScalar;
169}
170
171bool CombinerHelper::isLegalOrHasFewerElements(
172 const LegalityQuery &Query) const {
173 LegalizeAction Action = LI->getAction(Query).Action;
174 return Action == LegalizeActions::Legal ||
175 Action == LegalizeActions::FewerElements;
176}
177
178bool CombinerHelper::isConstantLegalOrBeforeLegalizer(const LLT Ty) const {
179 if (!Ty.isVector())
180 return isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_CONSTANT, {Ty}});
181 // Vector constants are represented as a G_BUILD_VECTOR of scalar G_CONSTANTs.
182 if (isPreLegalize())
183 return true;
184 LLT EltTy = Ty.getElementType();
185 return isLegal(Query: {TargetOpcode::G_BUILD_VECTOR, {Ty, EltTy}}) &&
186 isLegal(Query: {TargetOpcode::G_CONSTANT, {EltTy}});
187}
188
189void CombinerHelper::replaceRegWith(MachineRegisterInfo &MRI, Register FromReg,
190 Register ToReg) const {
191 Observer.changingAllUsesOfReg(MRI, Reg: FromReg);
192
193 if (MRI.constrainRegAttrs(Reg: ToReg, ConstrainingReg: FromReg))
194 MRI.replaceRegWith(FromReg, ToReg);
195 else
196 Builder.buildCopy(Res: FromReg, Op: ToReg);
197
198 Observer.finishedChangingAllUsesOfReg();
199}
200
201void CombinerHelper::replaceRegOpWith(MachineRegisterInfo &MRI,
202 MachineOperand &FromRegOp,
203 Register ToReg) const {
204 assert(FromRegOp.getParent() && "Expected an operand in an MI");
205 Observer.changingInstr(MI&: *FromRegOp.getParent());
206
207 FromRegOp.setReg(ToReg);
208
209 Observer.changedInstr(MI&: *FromRegOp.getParent());
210}
211
212void CombinerHelper::replaceOpcodeWith(MachineInstr &FromMI,
213 unsigned ToOpcode) const {
214 Observer.changingInstr(MI&: FromMI);
215
216 FromMI.setDesc(Builder.getTII().get(Opcode: ToOpcode));
217
218 Observer.changedInstr(MI&: FromMI);
219}
220
221const RegisterBank *CombinerHelper::getRegBank(Register Reg) const {
222 return RBI->getRegBank(Reg, MRI, TRI: *TRI);
223}
224
225void CombinerHelper::setRegBank(Register Reg,
226 const RegisterBank *RegBank) const {
227 if (RegBank)
228 MRI.setRegBank(Reg, RegBank: *RegBank);
229}
230
231bool CombinerHelper::tryCombineCopy(MachineInstr &MI) const {
232 if (matchCombineCopy(MI)) {
233 applyCombineCopy(MI);
234 return true;
235 }
236 return false;
237}
238bool CombinerHelper::matchCombineCopy(MachineInstr &MI) const {
239 if (MI.getOpcode() != TargetOpcode::COPY)
240 return false;
241 Register DstReg = MI.getOperand(i: 0).getReg();
242 Register SrcReg = MI.getOperand(i: 1).getReg();
243 return canReplaceReg(DstReg, SrcReg, MRI);
244}
245void CombinerHelper::applyCombineCopy(MachineInstr &MI) const {
246 Register DstReg = MI.getOperand(i: 0).getReg();
247 Register SrcReg = MI.getOperand(i: 1).getReg();
248 replaceRegWith(MRI, FromReg: DstReg, ToReg: SrcReg);
249 MI.eraseFromParent();
250}
251
252bool CombinerHelper::matchFreezeOfSingleMaybePoisonOperand(
253 MachineInstr &MI, BuildFnTy &MatchInfo) const {
254 // Ported from InstCombinerImpl::pushFreezeToPreventPoisonFromPropagating.
255 Register DstOp = MI.getOperand(i: 0).getReg();
256 Register OrigOp = MI.getOperand(i: 1).getReg();
257
258 if (!MRI.hasOneNonDBGUse(RegNo: OrigOp))
259 return false;
260
261 MachineInstr *OrigDef = MRI.getUniqueVRegDef(Reg: OrigOp);
262 // Even if only a single operand of the PHI is not guaranteed non-poison,
263 // moving freeze() backwards across a PHI can cause optimization issues for
264 // other users of that operand.
265 //
266 // Moving freeze() from one of the output registers of a G_UNMERGE_VALUES to
267 // the source register is unprofitable because it makes the freeze() more
268 // strict than is necessary (it would affect the whole register instead of
269 // just the subreg being frozen).
270 if (OrigDef->isPHI() || isa<GUnmerge>(Val: OrigDef))
271 return false;
272
273 if (canCreateUndefOrPoison(Reg: OrigOp, MRI,
274 /*ConsiderFlagsAndMetadata=*/false))
275 return false;
276
277 std::optional<MachineOperand> MaybePoisonOperand;
278 for (MachineOperand &Operand : OrigDef->uses()) {
279 if (!Operand.isReg())
280 return false;
281
282 if (isGuaranteedNotToBeUndefOrPoison(Reg: Operand.getReg(), MRI))
283 continue;
284
285 if (!MaybePoisonOperand)
286 MaybePoisonOperand = Operand;
287 else {
288 // We have more than one maybe-poison operand. Moving the freeze is
289 // unsafe.
290 return false;
291 }
292 }
293
294 // Eliminate freeze if all operands are guaranteed non-poison.
295 if (!MaybePoisonOperand) {
296 MatchInfo = [=](MachineIRBuilder &B) {
297 Observer.changingInstr(MI&: *OrigDef);
298 cast<GenericMachineInstr>(Val: OrigDef)->dropPoisonGeneratingFlags();
299 Observer.changedInstr(MI&: *OrigDef);
300 B.buildCopy(Res: DstOp, Op: OrigOp);
301 };
302 return true;
303 }
304
305 Register MaybePoisonOperandReg = MaybePoisonOperand->getReg();
306 LLT MaybePoisonOperandRegTy = MRI.getType(Reg: MaybePoisonOperandReg);
307
308 MatchInfo = [=](MachineIRBuilder &B) mutable {
309 Observer.changingInstr(MI&: *OrigDef);
310 cast<GenericMachineInstr>(Val: OrigDef)->dropPoisonGeneratingFlags();
311 Observer.changedInstr(MI&: *OrigDef);
312 B.setInsertPt(MBB&: *OrigDef->getParent(), II: OrigDef->getIterator());
313 auto Freeze = B.buildFreeze(Dst: MaybePoisonOperandRegTy, Src: MaybePoisonOperandReg);
314 replaceRegOpWith(
315 MRI, FromRegOp&: *OrigDef->findRegisterUseOperand(Reg: MaybePoisonOperandReg, TRI),
316 ToReg: Freeze.getReg(Idx: 0));
317 replaceRegWith(MRI, FromReg: DstOp, ToReg: OrigOp);
318 };
319 return true;
320}
321
322bool CombinerHelper::matchCombineConcatVectors(
323 MachineInstr &MI, SmallVector<Register> &Ops) const {
324 assert(MI.getOpcode() == TargetOpcode::G_CONCAT_VECTORS &&
325 "Invalid instruction");
326 bool IsUndef = true;
327 MachineInstr *Undef = nullptr;
328
329 // Walk over all the operands of concat vectors and check if they are
330 // build_vector themselves or undef.
331 // Then collect their operands in Ops.
332 for (const MachineOperand &MO : MI.uses()) {
333 Register Reg = MO.getReg();
334 MachineInstr *Def = MRI.getVRegDef(Reg);
335 assert(Def && "Operand not defined");
336 if (!MRI.hasOneNonDBGUse(RegNo: Reg))
337 return false;
338 switch (Def->getOpcode()) {
339 case TargetOpcode::G_BUILD_VECTOR:
340 IsUndef = false;
341 // Remember the operands of the build_vector to fold
342 // them into the yet-to-build flattened concat vectors.
343 for (const MachineOperand &BuildVecMO : Def->uses())
344 Ops.push_back(Elt: BuildVecMO.getReg());
345 break;
346 case TargetOpcode::G_IMPLICIT_DEF: {
347 LLT OpType = MRI.getType(Reg);
348 // Keep one undef value for all the undef operands.
349 if (!Undef) {
350 Builder.setInsertPt(MBB&: *MI.getParent(), II: MI);
351 Undef = Builder.buildUndef(Res: OpType.getScalarType());
352 }
353 assert(MRI.getType(Undef->getOperand(0).getReg()) ==
354 OpType.getScalarType() &&
355 "All undefs should have the same type");
356 // Break the undef vector in as many scalar elements as needed
357 // for the flattening.
358 for (unsigned EltIdx = 0, EltEnd = OpType.getNumElements();
359 EltIdx != EltEnd; ++EltIdx)
360 Ops.push_back(Elt: Undef->getOperand(i: 0).getReg());
361 break;
362 }
363 default:
364 return false;
365 }
366 }
367
368 // Check if the combine is illegal
369 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
370 if (!isLegalOrBeforeLegalizer(
371 Query: {TargetOpcode::G_BUILD_VECTOR, {DstTy, MRI.getType(Reg: Ops[0])}})) {
372 return false;
373 }
374
375 if (IsUndef)
376 Ops.clear();
377
378 return true;
379}
380void CombinerHelper::applyCombineConcatVectors(
381 MachineInstr &MI, SmallVector<Register> &Ops) const {
382 // We determined that the concat_vectors can be flatten.
383 // Generate the flattened build_vector.
384 Register DstReg = MI.getOperand(i: 0).getReg();
385 Builder.setInsertPt(MBB&: *MI.getParent(), II: MI);
386 Register NewDstReg = MRI.cloneVirtualRegister(VReg: DstReg);
387
388 // Note: IsUndef is sort of redundant. We could have determine it by
389 // checking that at all Ops are undef. Alternatively, we could have
390 // generate a build_vector of undefs and rely on another combine to
391 // clean that up. For now, given we already gather this information
392 // in matchCombineConcatVectors, just save compile time and issue the
393 // right thing.
394 if (Ops.empty())
395 Builder.buildUndef(Res: NewDstReg);
396 else
397 Builder.buildBuildVector(Res: NewDstReg, Ops);
398 replaceRegWith(MRI, FromReg: DstReg, ToReg: NewDstReg);
399 MI.eraseFromParent();
400}
401
402void CombinerHelper::applyCombineShuffleToBuildVector(MachineInstr &MI) const {
403 auto &Shuffle = cast<GShuffleVector>(Val&: MI);
404
405 Register SrcVec1 = Shuffle.getSrc1Reg();
406 Register SrcVec2 = Shuffle.getSrc2Reg();
407 LLT EltTy = MRI.getType(Reg: SrcVec1).getElementType();
408 int Width = MRI.getType(Reg: SrcVec1).getNumElements();
409
410 auto Unmerge1 = Builder.buildUnmerge(Res: EltTy, Op: SrcVec1);
411 auto Unmerge2 = Builder.buildUnmerge(Res: EltTy, Op: SrcVec2);
412
413 SmallVector<Register> Extracts;
414 // Select only applicable elements from unmerged values.
415 for (int Val : Shuffle.getMask()) {
416 if (Val == -1)
417 Extracts.push_back(Elt: Builder.buildUndef(Res: EltTy).getReg(Idx: 0));
418 else if (Val < Width)
419 Extracts.push_back(Elt: Unmerge1.getReg(Idx: Val));
420 else
421 Extracts.push_back(Elt: Unmerge2.getReg(Idx: Val - Width));
422 }
423 assert(Extracts.size() > 0 && "Expected at least one element in the shuffle");
424 if (Extracts.size() == 1)
425 Builder.buildCopy(Res: MI.getOperand(i: 0).getReg(), Op: Extracts[0]);
426 else
427 Builder.buildBuildVector(Res: MI.getOperand(i: 0).getReg(), Ops: Extracts);
428 MI.eraseFromParent();
429}
430
431bool CombinerHelper::matchCombineShuffleConcat(
432 MachineInstr &MI, SmallVector<Register> &Ops) const {
433 ArrayRef<int> Mask = MI.getOperand(i: 3).getShuffleMask();
434 auto ConcatMI1 =
435 dyn_cast<GConcatVectors>(Val: MRI.getVRegDef(Reg: MI.getOperand(i: 1).getReg()));
436 auto ConcatMI2 =
437 dyn_cast<GConcatVectors>(Val: MRI.getVRegDef(Reg: MI.getOperand(i: 2).getReg()));
438 if (!ConcatMI1 || !ConcatMI2)
439 return false;
440
441 // Check that the sources of the Concat instructions have the same type
442 if (MRI.getType(Reg: ConcatMI1->getSourceReg(I: 0)) !=
443 MRI.getType(Reg: ConcatMI2->getSourceReg(I: 0)))
444 return false;
445
446 LLT ConcatSrcTy = MRI.getType(Reg: ConcatMI1->getReg(Idx: 1));
447 LLT ShuffleSrcTy1 = MRI.getType(Reg: MI.getOperand(i: 1).getReg());
448 unsigned ConcatSrcNumElt = ConcatSrcTy.getNumElements();
449 for (unsigned i = 0; i < Mask.size(); i += ConcatSrcNumElt) {
450 // Check if the index takes a whole source register from G_CONCAT_VECTORS
451 // Assumes that all Sources of G_CONCAT_VECTORS are the same type
452 if (Mask[i] == -1) {
453 for (unsigned j = 1; j < ConcatSrcNumElt; j++) {
454 if (i + j >= Mask.size())
455 return false;
456 if (Mask[i + j] != -1)
457 return false;
458 }
459 if (!isLegalOrBeforeLegalizer(
460 Query: {TargetOpcode::G_IMPLICIT_DEF, {ConcatSrcTy}}))
461 return false;
462 Ops.push_back(Elt: 0);
463 } else if (Mask[i] % ConcatSrcNumElt == 0) {
464 for (unsigned j = 1; j < ConcatSrcNumElt; j++) {
465 if (i + j >= Mask.size())
466 return false;
467 if (Mask[i + j] != Mask[i] + static_cast<int>(j))
468 return false;
469 }
470 // Retrieve the source register from its respective G_CONCAT_VECTORS
471 // instruction
472 if (Mask[i] < ShuffleSrcTy1.getNumElements()) {
473 Ops.push_back(Elt: ConcatMI1->getSourceReg(I: Mask[i] / ConcatSrcNumElt));
474 } else {
475 Ops.push_back(Elt: ConcatMI2->getSourceReg(I: Mask[i] / ConcatSrcNumElt -
476 ConcatMI1->getNumSources()));
477 }
478 } else {
479 return false;
480 }
481 }
482
483 if (!isLegalOrBeforeLegalizer(
484 Query: {TargetOpcode::G_CONCAT_VECTORS,
485 {MRI.getType(Reg: MI.getOperand(i: 0).getReg()), ConcatSrcTy}}))
486 return false;
487
488 return !Ops.empty();
489}
490
491void CombinerHelper::applyCombineShuffleConcat(
492 MachineInstr &MI, SmallVector<Register> &Ops) const {
493 LLT SrcTy;
494 for (Register &Reg : Ops) {
495 if (Reg != 0)
496 SrcTy = MRI.getType(Reg);
497 }
498 assert(SrcTy.isValid() && "Unexpected full undef vector in concat combine");
499
500 Register UndefReg = 0;
501
502 for (Register &Reg : Ops) {
503 if (Reg == 0) {
504 if (UndefReg == 0)
505 UndefReg = Builder.buildUndef(Res: SrcTy).getReg(Idx: 0);
506 Reg = UndefReg;
507 }
508 }
509
510 if (Ops.size() > 1)
511 Builder.buildConcatVectors(Res: MI.getOperand(i: 0).getReg(), Ops);
512 else
513 Builder.buildCopy(Res: MI.getOperand(i: 0).getReg(), Op: Ops[0]);
514 MI.eraseFromParent();
515}
516
517bool CombinerHelper::tryCombineShuffleVector(MachineInstr &MI) const {
518 SmallVector<Register, 4> Ops;
519 if (matchCombineShuffleVector(MI, Ops)) {
520 applyCombineShuffleVector(MI, Ops);
521 return true;
522 }
523 return false;
524}
525
526bool CombinerHelper::matchCombineShuffleVector(
527 MachineInstr &MI, SmallVectorImpl<Register> &Ops) const {
528 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR &&
529 "Invalid instruction kind");
530 LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
531 Register Src1 = MI.getOperand(i: 1).getReg();
532 LLT SrcType = MRI.getType(Reg: Src1);
533
534 unsigned DstNumElts = DstType.getNumElements();
535 unsigned SrcNumElts = SrcType.getNumElements();
536
537 // If the resulting vector is smaller than the size of the source
538 // vectors being concatenated, we won't be able to replace the
539 // shuffle vector into a concat_vectors.
540 //
541 // Note: We may still be able to produce a concat_vectors fed by
542 // extract_vector_elt and so on. It is less clear that would
543 // be better though, so don't bother for now.
544 //
545 // If the destination is a scalar, the size of the sources doesn't
546 // matter. we will lower the shuffle to a plain copy. This will
547 // work only if the source and destination have the same size. But
548 // that's covered by the next condition.
549 //
550 // TODO: If the size between the source and destination don't match
551 // we could still emit an extract vector element in that case.
552 if (DstNumElts < 2 * SrcNumElts)
553 return false;
554
555 // Check that the shuffle mask can be broken evenly between the
556 // different sources.
557 if (DstNumElts % SrcNumElts != 0)
558 return false;
559
560 // Mask length is a multiple of the source vector length.
561 // Check if the shuffle is some kind of concatenation of the input
562 // vectors.
563 unsigned NumConcat = DstNumElts / SrcNumElts;
564 SmallVector<int, 8> ConcatSrcs(NumConcat, -1);
565 ArrayRef<int> Mask = MI.getOperand(i: 3).getShuffleMask();
566 for (unsigned i = 0; i != DstNumElts; ++i) {
567 int Idx = Mask[i];
568 // Undef value.
569 if (Idx < 0)
570 continue;
571 // Ensure the indices in each SrcType sized piece are sequential and that
572 // the same source is used for the whole piece.
573 if ((Idx % SrcNumElts != (i % SrcNumElts)) ||
574 (ConcatSrcs[i / SrcNumElts] >= 0 &&
575 ConcatSrcs[i / SrcNumElts] != (int)(Idx / SrcNumElts)))
576 return false;
577 // Remember which source this index came from.
578 ConcatSrcs[i / SrcNumElts] = Idx / SrcNumElts;
579 }
580
581 // The shuffle is concatenating multiple vectors together.
582 // Collect the different operands for that.
583 Register UndefReg;
584 Register Src2 = MI.getOperand(i: 2).getReg();
585 for (auto Src : ConcatSrcs) {
586 if (Src < 0) {
587 if (!UndefReg) {
588 Builder.setInsertPt(MBB&: *MI.getParent(), II: MI);
589 UndefReg = Builder.buildUndef(Res: SrcType).getReg(Idx: 0);
590 }
591 Ops.push_back(Elt: UndefReg);
592 } else if (Src == 0)
593 Ops.push_back(Elt: Src1);
594 else
595 Ops.push_back(Elt: Src2);
596 }
597 return true;
598}
599
600void CombinerHelper::applyCombineShuffleVector(MachineInstr &MI,
601 ArrayRef<Register> Ops) const {
602 Register DstReg = MI.getOperand(i: 0).getReg();
603 Builder.setInsertPt(MBB&: *MI.getParent(), II: MI);
604 Register NewDstReg = MRI.cloneVirtualRegister(VReg: DstReg);
605
606 if (Ops.size() == 1)
607 Builder.buildCopy(Res: NewDstReg, Op: Ops[0]);
608 else
609 Builder.buildMergeLikeInstr(Res: NewDstReg, Ops);
610
611 replaceRegWith(MRI, FromReg: DstReg, ToReg: NewDstReg);
612 MI.eraseFromParent();
613}
614
615namespace {
616
617/// Select a preference between two uses. CurrentUse is the current preference
618/// while *ForCandidate is attributes of the candidate under consideration.
619PreferredTuple ChoosePreferredUse(MachineInstr &LoadMI,
620 PreferredTuple &CurrentUse,
621 const LLT TyForCandidate,
622 unsigned OpcodeForCandidate,
623 MachineInstr *MIForCandidate) {
624 if (!CurrentUse.Ty.isValid()) {
625 if (CurrentUse.ExtendOpcode == OpcodeForCandidate ||
626 CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT)
627 return {.Ty: TyForCandidate, .ExtendOpcode: OpcodeForCandidate, .MI: MIForCandidate};
628 return CurrentUse;
629 }
630
631 // We permit the extend to hoist through basic blocks but this is only
632 // sensible if the target has extending loads. If you end up lowering back
633 // into a load and extend during the legalizer then the end result is
634 // hoisting the extend up to the load.
635
636 // Prefer defined extensions to undefined extensions as these are more
637 // likely to reduce the number of instructions.
638 if (OpcodeForCandidate == TargetOpcode::G_ANYEXT &&
639 CurrentUse.ExtendOpcode != TargetOpcode::G_ANYEXT)
640 return CurrentUse;
641 else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT &&
642 OpcodeForCandidate != TargetOpcode::G_ANYEXT)
643 return {.Ty: TyForCandidate, .ExtendOpcode: OpcodeForCandidate, .MI: MIForCandidate};
644
645 // Prefer sign extensions to zero extensions as sign-extensions tend to be
646 // more expensive. Don't do this if the load is already a zero-extend load
647 // though, otherwise we'll rewrite a zero-extend load into a sign-extend
648 // later.
649 if (!isa<GZExtLoad>(Val: LoadMI) && CurrentUse.Ty == TyForCandidate) {
650 if (CurrentUse.ExtendOpcode == TargetOpcode::G_SEXT &&
651 OpcodeForCandidate == TargetOpcode::G_ZEXT)
652 return CurrentUse;
653 else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ZEXT &&
654 OpcodeForCandidate == TargetOpcode::G_SEXT)
655 return {.Ty: TyForCandidate, .ExtendOpcode: OpcodeForCandidate, .MI: MIForCandidate};
656 }
657
658 // This is potentially target specific. We've chosen the largest type
659 // because G_TRUNC is usually free. One potential catch with this is that
660 // some targets have a reduced number of larger registers than smaller
661 // registers and this choice potentially increases the live-range for the
662 // larger value.
663 if (TyForCandidate.getSizeInBits() > CurrentUse.Ty.getSizeInBits()) {
664 return {.Ty: TyForCandidate, .ExtendOpcode: OpcodeForCandidate, .MI: MIForCandidate};
665 }
666 return CurrentUse;
667}
668
669/// Find a suitable place to insert some instructions and insert them. This
670/// function accounts for special cases like inserting before a PHI node.
671/// The current strategy for inserting before PHI's is to duplicate the
672/// instructions for each predecessor. However, while that's ok for G_TRUNC
673/// on most targets since it generally requires no code, other targets/cases may
674/// want to try harder to find a dominating block.
675static void InsertInsnsWithoutSideEffectsBeforeUse(
676 MachineIRBuilder &Builder, MachineInstr &DefMI, MachineOperand &UseMO,
677 std::function<void(MachineBasicBlock *, MachineBasicBlock::iterator,
678 MachineOperand &UseMO)>
679 Inserter) {
680 MachineInstr &UseMI = *UseMO.getParent();
681
682 MachineBasicBlock *InsertBB = UseMI.getParent();
683
684 // If the use is a PHI then we want the predecessor block instead.
685 if (UseMI.isPHI()) {
686 MachineOperand *PredBB = std::next(x: &UseMO);
687 InsertBB = PredBB->getMBB();
688 }
689
690 // If the block is the same block as the def then we want to insert just after
691 // the def instead of at the start of the block.
692 if (InsertBB == DefMI.getParent()) {
693 MachineBasicBlock::iterator InsertPt = &DefMI;
694 Inserter(InsertBB, std::next(x: InsertPt), UseMO);
695 return;
696 }
697
698 // Otherwise we want the start of the BB
699 Inserter(InsertBB, InsertBB->getFirstNonPHI(), UseMO);
700}
701} // end anonymous namespace
702
703bool CombinerHelper::tryCombineExtendingLoads(MachineInstr &MI) const {
704 PreferredTuple Preferred;
705 if (matchCombineExtendingLoads(MI, MatchInfo&: Preferred)) {
706 applyCombineExtendingLoads(MI, MatchInfo&: Preferred);
707 return true;
708 }
709 return false;
710}
711
712static unsigned getExtLoadOpcForExtend(unsigned ExtOpc) {
713 unsigned CandidateLoadOpc;
714 switch (ExtOpc) {
715 case TargetOpcode::G_ANYEXT:
716 CandidateLoadOpc = TargetOpcode::G_LOAD;
717 break;
718 case TargetOpcode::G_SEXT:
719 CandidateLoadOpc = TargetOpcode::G_SEXTLOAD;
720 break;
721 case TargetOpcode::G_ZEXT:
722 CandidateLoadOpc = TargetOpcode::G_ZEXTLOAD;
723 break;
724 default:
725 llvm_unreachable("Unexpected extend opc");
726 }
727 return CandidateLoadOpc;
728}
729
730bool CombinerHelper::matchCombineExtendingLoads(
731 MachineInstr &MI, PreferredTuple &Preferred) const {
732 // We match the loads and follow the uses to the extend instead of matching
733 // the extends and following the def to the load. This is because the load
734 // must remain in the same position for correctness (unless we also add code
735 // to find a safe place to sink it) whereas the extend is freely movable.
736 // It also prevents us from duplicating the load for the volatile case or just
737 // for performance.
738 GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(Val: &MI);
739 if (!LoadMI)
740 return false;
741
742 Register LoadReg = LoadMI->getDstReg();
743
744 LLT LoadValueTy = MRI.getType(Reg: LoadReg);
745 if (!LoadValueTy.isScalar())
746 return false;
747
748 // Most architectures are going to legalize <s8 loads into at least a 1 byte
749 // load, and the MMOs can only describe memory accesses in multiples of bytes.
750 // If we try to perform extload combining on those, we can end up with
751 // %a(s8) = extload %ptr (load 1 byte from %ptr)
752 // ... which is an illegal extload instruction.
753 if (LoadValueTy.getSizeInBits() < 8)
754 return false;
755
756 // For non power-of-2 types, they will very likely be legalized into multiple
757 // loads. Don't bother trying to match them into extending loads.
758 if (!llvm::has_single_bit<uint32_t>(Value: LoadValueTy.getSizeInBits()))
759 return false;
760
761 // Find the preferred type aside from the any-extends (unless it's the only
762 // one) and non-extending ops. We'll emit an extending load to that type and
763 // and emit a variant of (extend (trunc X)) for the others according to the
764 // relative type sizes. At the same time, pick an extend to use based on the
765 // extend involved in the chosen type.
766 unsigned PreferredOpcode =
767 isa<GLoad>(Val: &MI)
768 ? TargetOpcode::G_ANYEXT
769 : isa<GSExtLoad>(Val: &MI) ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT;
770 Preferred = {.Ty: LLT(), .ExtendOpcode: PreferredOpcode, .MI: nullptr};
771 for (auto &UseMI : MRI.use_nodbg_instructions(Reg: LoadReg)) {
772 if (UseMI.getOpcode() == TargetOpcode::G_SEXT ||
773 UseMI.getOpcode() == TargetOpcode::G_ZEXT ||
774 (UseMI.getOpcode() == TargetOpcode::G_ANYEXT)) {
775 const auto &MMO = LoadMI->getMMO();
776 // Don't do anything for atomics.
777 if (MMO.isAtomic())
778 continue;
779 // Check for legality.
780 if (!isPreLegalize()) {
781 LegalityQuery::MemDesc MMDesc(MMO);
782 unsigned CandidateLoadOpc = getExtLoadOpcForExtend(ExtOpc: UseMI.getOpcode());
783 LLT UseTy = MRI.getType(Reg: UseMI.getOperand(i: 0).getReg());
784 LLT SrcTy = MRI.getType(Reg: LoadMI->getPointerReg());
785 if (LI->getAction(Query: {CandidateLoadOpc, {UseTy, SrcTy}, {MMDesc}})
786 .Action != LegalizeActions::Legal)
787 continue;
788 }
789 Preferred = ChoosePreferredUse(LoadMI&: MI, CurrentUse&: Preferred,
790 TyForCandidate: MRI.getType(Reg: UseMI.getOperand(i: 0).getReg()),
791 OpcodeForCandidate: UseMI.getOpcode(), MIForCandidate: &UseMI);
792 }
793 }
794
795 // There were no extends
796 if (!Preferred.MI)
797 return false;
798 // It should be impossible to chose an extend without selecting a different
799 // type since by definition the result of an extend is larger.
800 assert(Preferred.Ty != LoadValueTy && "Extending to same type?");
801
802 LLVM_DEBUG(dbgs() << "Preferred use is: " << *Preferred.MI);
803 return true;
804}
805
/// Rewrite load \p MI in place into the extending-load opcode that matches
/// Preferred.ExtendOpcode, make it define the preferred extend's result
/// register (Preferred.MI operand 0), and fix up every other use of the value
/// the load originally produced so the types stay consistent:
///  - extends with Preferred.ExtendOpcode or G_ANYEXT are merged, re-pointed,
///    or fed through a truncate depending on their result width;
///  - the preferred extend itself is erased;
///  - any other use gets a G_TRUNC back to the originally loaded type.
/// NOTE(review): Preferred is presumably the PreferredTuple filled in by the
/// matching step (matchCombineExtendingLoads); confirm against the caller.
void CombinerHelper::applyCombineExtendingLoads(
    MachineInstr &MI, PreferredTuple &Preferred) const {
  // Rewrite the load to the chosen extending load.
  Register ChosenDstReg = Preferred.MI->getOperand(i: 0).getReg();

  // Inserter to insert a truncate back to the original type at a given point
  // with some basic CSE to limit truncate duplication to one per BB.
  DenseMap<MachineBasicBlock *, MachineInstr *> EmittedInsns;
  auto InsertTruncAt = [&](MachineBasicBlock *InsertIntoBB,
                           MachineBasicBlock::iterator InsertBefore,
                           MachineOperand &UseMO) {
    // Reuse a truncate already emitted in this block instead of building a
    // second one.
    MachineInstr *PreviouslyEmitted = EmittedInsns.lookup(Val: InsertIntoBB);
    if (PreviouslyEmitted) {
      Observer.changingInstr(MI&: *UseMO.getParent());
      UseMO.setReg(PreviouslyEmitted->getOperand(i: 0).getReg());
      Observer.changedInstr(MI&: *UseMO.getParent());
      return;
    }

    // The truncate's result gets a fresh vreg cloned from the load's original
    // def, i.e. the originally loaded (narrow) type.
    Builder.setInsertPt(MBB&: *InsertIntoBB, II: InsertBefore);
    Register NewDstReg = MRI.cloneVirtualRegister(VReg: MI.getOperand(i: 0).getReg());
    MachineInstr *NewMI = Builder.buildTrunc(Res: NewDstReg, Op: ChosenDstReg);
    EmittedInsns[InsertIntoBB] = NewMI;
    replaceRegOpWith(MRI, FromRegOp&: UseMO, ToReg: NewDstReg);
  };

  // The load is modified in place; changedInstr(MI) is reported at the very
  // end, after its def has been switched to ChosenDstReg.
  Observer.changingInstr(MI);
  unsigned LoadOpc = getExtLoadOpcForExtend(ExtOpc: Preferred.ExtendOpcode);
  MI.setDesc(Builder.getTII().get(Opcode: LoadOpc));

  // Rewrite all the uses to fix up the types.
  // Snapshot the use operands first: the loop below erases and rewrites
  // users, which would otherwise disturb the use-list iteration.
  auto &LoadValue = MI.getOperand(i: 0);
  SmallVector<MachineOperand *, 4> Uses(
      llvm::make_pointer_range(Range: MRI.use_operands(Reg: LoadValue.getReg())));

  for (auto *UseMO : Uses) {
    MachineInstr *UseMI = UseMO->getParent();

    // If the extend is compatible with the preferred extend then we should fix
    // up the type and extend so that it uses the preferred use.
    if (UseMI->getOpcode() == Preferred.ExtendOpcode ||
        UseMI->getOpcode() == TargetOpcode::G_ANYEXT) {
      Register UseDstReg = UseMI->getOperand(i: 0).getReg();
      MachineOperand &UseSrcMO = UseMI->getOperand(i: 1);
      const LLT UseDstTy = MRI.getType(Reg: UseDstReg);
      if (UseDstReg != ChosenDstReg) {
        if (Preferred.Ty == UseDstTy) {
          // If the use has the same type as the preferred use, then merge
          // the vregs and erase the extend. For example:
          //    %1:_(s8) = G_LOAD ...
          //    %2:_(s32) = G_SEXT %1(s8)
          //    %3:_(s32) = G_ANYEXT %1(s8)
          //    ... = ... %3(s32)
          // rewrites to:
          //    %2:_(s32) = G_SEXTLOAD ...
          //    ... = ... %2(s32)
          replaceRegWith(MRI, FromReg: UseDstReg, ToReg: ChosenDstReg);
          Observer.erasingInstr(MI&: *UseMO->getParent());
          UseMO->getParent()->eraseFromParent();
        } else if (Preferred.Ty.getSizeInBits() < UseDstTy.getSizeInBits()) {
          // If the preferred size is smaller, then keep the extend but extend
          // from the result of the extending load. For example:
          //    %1:_(s8) = G_LOAD ...
          //    %2:_(s32) = G_SEXT %1(s8)
          //    %3:_(s64) = G_ANYEXT %1(s8)
          //    ... = ... %3(s64)
          // rewrites to:
          //    %2:_(s32) = G_SEXTLOAD ...
          //    %3:_(s64) = G_ANYEXT %2:_(s32)
          //    ... = ... %3(s64)
          replaceRegOpWith(MRI, FromRegOp&: UseSrcMO, ToReg: ChosenDstReg);
        } else {
          // If the preferred size is larger, then insert a truncate. For
          // example:
          //    %1:_(s8) = G_LOAD ...
          //    %2:_(s64) = G_SEXT %1(s8)
          //    %3:_(s32) = G_ZEXT %1(s8)
          //    ... = ... %3(s32)
          // rewrites to:
          //    %2:_(s64) = G_SEXTLOAD ...
          //    %4:_(s8) = G_TRUNC %2:_(s64)
          //    %3:_(s32) = G_ZEXT %4:_(s8)
          //    ... = ... %3(s32)
          InsertInsnsWithoutSideEffectsBeforeUse(Builder, DefMI&: MI, UseMO&: *UseMO,
                                                 Inserter: InsertTruncAt);
        }
        continue;
      }
      // The use is (one of) the uses of the preferred use we chose earlier.
      // We're going to update the load to def this value later so just erase
      // the old extend.
      Observer.erasingInstr(MI&: *UseMO->getParent());
      UseMO->getParent()->eraseFromParent();
      continue;
    }

    // The use isn't an extend compatible with the preferred one (it is either
    // not an extend at all, or an extend of a different kind). Truncate back
    // to the type we originally loaded. This is free on many targets.
    InsertInsnsWithoutSideEffectsBeforeUse(Builder, DefMI&: MI, UseMO&: *UseMO, Inserter: InsertTruncAt);
  }

  // Finally make the (now extending) load define the preferred register.
  MI.getOperand(i: 0).setReg(ChosenDstReg);
  Observer.changedInstr(MI);
}
910
911bool CombinerHelper::matchCombineLoadWithAndMask(MachineInstr &MI,
912 BuildFnTy &MatchInfo) const {
913 assert(MI.getOpcode() == TargetOpcode::G_AND);
914
915 // If we have the following code:
916 // %mask = G_CONSTANT 255
917 // %ld = G_LOAD %ptr, (load s16)
918 // %and = G_AND %ld, %mask
919 //
920 // Try to fold it into
921 // %ld = G_ZEXTLOAD %ptr, (load s8)
922
923 Register Dst = MI.getOperand(i: 0).getReg();
924 if (MRI.getType(Reg: Dst).isVector())
925 return false;
926
927 auto MaybeMask =
928 getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI);
929 if (!MaybeMask)
930 return false;
931
932 APInt MaskVal = MaybeMask->Value;
933
934 if (!MaskVal.isMask())
935 return false;
936
937 Register SrcReg = MI.getOperand(i: 1).getReg();
938 // Don't use getOpcodeDef() here since intermediate instructions may have
939 // multiple users.
940 GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(Val: MRI.getVRegDef(Reg: SrcReg));
941 if (!LoadMI || !MRI.hasOneNonDBGUse(RegNo: LoadMI->getDstReg()))
942 return false;
943
944 Register LoadReg = LoadMI->getDstReg();
945 LLT RegTy = MRI.getType(Reg: LoadReg);
946 Register PtrReg = LoadMI->getPointerReg();
947 unsigned RegSize = RegTy.getSizeInBits();
948 LocationSize LoadSizeBits = LoadMI->getMemSizeInBits();
949 unsigned MaskSizeBits = MaskVal.countr_one();
950
951 // The mask may not be larger than the in-memory type, as it might cover sign
952 // extended bits
953 if (MaskSizeBits > LoadSizeBits.getValue())
954 return false;
955
956 // If the mask covers the whole destination register, there's nothing to
957 // extend
958 if (MaskSizeBits >= RegSize)
959 return false;
960
961 // Most targets cannot deal with loads of size < 8 and need to re-legalize to
962 // at least byte loads. Avoid creating such loads here
963 if (MaskSizeBits < 8 || !isPowerOf2_32(Value: MaskSizeBits))
964 return false;
965
966 const MachineMemOperand &MMO = LoadMI->getMMO();
967 LegalityQuery::MemDesc MemDesc(MMO);
968
969 // Don't modify the memory access size if this is atomic/volatile, but we can
970 // still adjust the opcode to indicate the high bit behavior.
971 if (LoadMI->isSimple())
972 MemDesc.MemoryTy = LLT::scalar(SizeInBits: MaskSizeBits);
973 else if (LoadSizeBits.getValue() > MaskSizeBits ||
974 LoadSizeBits.getValue() == RegSize)
975 return false;
976
977 // TODO: Could check if it's legal with the reduced or original memory size.
978 if (!isLegalOrBeforeLegalizer(
979 Query: {TargetOpcode::G_ZEXTLOAD, {RegTy, MRI.getType(Reg: PtrReg)}, {MemDesc}}))
980 return false;
981
982 MatchInfo = [=](MachineIRBuilder &B) {
983 B.setInstrAndDebugLoc(*LoadMI);
984 auto &MF = B.getMF();
985 auto PtrInfo = MMO.getPointerInfo();
986 auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Ty: MemDesc.MemoryTy);
987 B.buildLoadInstr(Opcode: TargetOpcode::G_ZEXTLOAD, Res: Dst, Addr: PtrReg, MMO&: *NewMMO);
988 LoadMI->eraseFromParent();
989 };
990 return true;
991}
992
993bool CombinerHelper::isPredecessor(const MachineInstr &DefMI,
994 const MachineInstr &UseMI) const {
995 assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() &&
996 "shouldn't consider debug uses");
997 assert(DefMI.getParent() == UseMI.getParent());
998 if (&DefMI == &UseMI)
999 return true;
1000 const MachineBasicBlock &MBB = *DefMI.getParent();
1001 auto DefOrUse = find_if(Range: MBB, P: [&DefMI, &UseMI](const MachineInstr &MI) {
1002 return &MI == &DefMI || &MI == &UseMI;
1003 });
1004 if (DefOrUse == MBB.end())
1005 llvm_unreachable("Block must contain both DefMI and UseMI!");
1006 return &*DefOrUse == &DefMI;
1007}
1008
1009bool CombinerHelper::dominates(const MachineInstr &DefMI,
1010 const MachineInstr &UseMI) const {
1011 assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() &&
1012 "shouldn't consider debug uses");
1013 if (MDT)
1014 return MDT->dominates(A: &DefMI, B: &UseMI);
1015 else if (DefMI.getParent() != UseMI.getParent())
1016 return false;
1017
1018 return isPredecessor(DefMI, UseMI);
1019}
1020
1021bool CombinerHelper::matchSextTruncSextLoad(MachineInstr &MI) const {
1022 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
1023 Register SrcReg = MI.getOperand(i: 1).getReg();
1024 Register LoadUser = SrcReg;
1025
1026 if (MRI.getType(Reg: SrcReg).isVector())
1027 return false;
1028
1029 Register TruncSrc;
1030 if (mi_match(R: SrcReg, MRI, P: m_GTrunc(Src: m_Reg(R&: TruncSrc))))
1031 LoadUser = TruncSrc;
1032
1033 uint64_t SizeInBits = MI.getOperand(i: 2).getImm();
1034 // If the source is a G_SEXTLOAD from the same bit width, then we don't
1035 // need any extend at all, just a truncate.
1036 if (auto *LoadMI = getOpcodeDef<GSExtLoad>(Reg: LoadUser, MRI)) {
1037 // If truncating more than the original extended value, abort.
1038 auto LoadSizeBits = LoadMI->getMemSizeInBits();
1039 if (TruncSrc &&
1040 MRI.getType(Reg: TruncSrc).getSizeInBits() < LoadSizeBits.getValue())
1041 return false;
1042 if (LoadSizeBits == SizeInBits)
1043 return true;
1044 }
1045 return false;
1046}
1047
1048void CombinerHelper::applySextTruncSextLoad(MachineInstr &MI) const {
1049 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
1050 Builder.buildCopy(Res: MI.getOperand(i: 0).getReg(), Op: MI.getOperand(i: 1).getReg());
1051 MI.eraseFromParent();
1052}
1053
1054bool CombinerHelper::matchSextInRegOfLoad(
1055 MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) const {
1056 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
1057
1058 Register DstReg = MI.getOperand(i: 0).getReg();
1059 LLT RegTy = MRI.getType(Reg: DstReg);
1060
1061 // Only supports scalars for now.
1062 if (RegTy.isVector())
1063 return false;
1064
1065 Register SrcReg = MI.getOperand(i: 1).getReg();
1066 auto *LoadDef = getOpcodeDef<GLoad>(Reg: SrcReg, MRI);
1067 if (!LoadDef || !MRI.hasOneNonDBGUse(RegNo: SrcReg))
1068 return false;
1069
1070 uint64_t MemBits = LoadDef->getMemSizeInBits().getValue();
1071
1072 // If the sign extend extends from a narrower width than the load's width,
1073 // then we can narrow the load width when we combine to a G_SEXTLOAD.
1074 // Avoid widening the load at all.
1075 unsigned NewSizeBits = std::min(a: (uint64_t)MI.getOperand(i: 2).getImm(), b: MemBits);
1076
1077 // Don't generate G_SEXTLOADs with a < 1 byte width.
1078 if (NewSizeBits < 8)
1079 return false;
1080 // Don't bother creating a non-power-2 sextload, it will likely be broken up
1081 // anyway for most targets.
1082 if (!isPowerOf2_32(Value: NewSizeBits))
1083 return false;
1084
1085 const MachineMemOperand &MMO = LoadDef->getMMO();
1086 LegalityQuery::MemDesc MMDesc(MMO);
1087
1088 // Don't modify the memory access size if this is atomic/volatile, but we can
1089 // still adjust the opcode to indicate the high bit behavior.
1090 if (LoadDef->isSimple())
1091 MMDesc.MemoryTy = LLT::scalar(SizeInBits: NewSizeBits);
1092 else if (MemBits > NewSizeBits || MemBits == RegTy.getSizeInBits())
1093 return false;
1094
1095 // TODO: Could check if it's legal with the reduced or original memory size.
1096 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SEXTLOAD,
1097 {MRI.getType(Reg: LoadDef->getDstReg()),
1098 MRI.getType(Reg: LoadDef->getPointerReg())},
1099 {MMDesc}}))
1100 return false;
1101
1102 MatchInfo = std::make_tuple(args: LoadDef->getDstReg(), args&: NewSizeBits);
1103 return true;
1104}
1105
1106void CombinerHelper::applySextInRegOfLoad(
1107 MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) const {
1108 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
1109 Register LoadReg;
1110 unsigned ScalarSizeBits;
1111 std::tie(args&: LoadReg, args&: ScalarSizeBits) = MatchInfo;
1112 GLoad *LoadDef = cast<GLoad>(Val: MRI.getVRegDef(Reg: LoadReg));
1113
1114 // If we have the following:
1115 // %ld = G_LOAD %ptr, (load 2)
1116 // %ext = G_SEXT_INREG %ld, 8
1117 // ==>
1118 // %ld = G_SEXTLOAD %ptr (load 1)
1119
1120 auto &MMO = LoadDef->getMMO();
1121 Builder.setInstrAndDebugLoc(*LoadDef);
1122 auto &MF = Builder.getMF();
1123 auto PtrInfo = MMO.getPointerInfo();
1124 auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Size: ScalarSizeBits / 8);
1125 Builder.buildLoadInstr(Opcode: TargetOpcode::G_SEXTLOAD, Res: MI.getOperand(i: 0).getReg(),
1126 Addr: LoadDef->getPointerReg(), MMO&: *NewMMO);
1127 MI.eraseFromParent();
1128
1129 // Not all loads can be deleted, so make sure the old one is removed.
1130 LoadDef->eraseFromParent();
1131}
1132
1133/// Return true if 'MI' is a load or a store that may be fold it's address
1134/// operand into the load / store addressing mode.
1135static bool canFoldInAddressingMode(GLoadStore *MI, const TargetLowering &TLI,
1136 MachineRegisterInfo &MRI) {
1137 TargetLowering::AddrMode AM;
1138 auto *MF = MI->getMF();
1139 auto *Addr = getOpcodeDef<GPtrAdd>(Reg: MI->getPointerReg(), MRI);
1140 if (!Addr)
1141 return false;
1142
1143 AM.HasBaseReg = true;
1144 if (auto CstOff = getIConstantVRegVal(VReg: Addr->getOffsetReg(), MRI))
1145 AM.BaseOffs = CstOff->getSExtValue(); // [reg +/- imm]
1146 else
1147 AM.Scale = 1; // [reg +/- reg]
1148
1149 return TLI.isLegalAddressingMode(
1150 DL: MF->getDataLayout(), AM,
1151 Ty: getTypeForLLT(Ty: MI->getMMO().getMemoryType(),
1152 C&: MF->getFunction().getContext()),
1153 AddrSpace: MI->getMMO().getAddrSpace());
1154}
1155
1156static unsigned getIndexedOpc(unsigned LdStOpc) {
1157 switch (LdStOpc) {
1158 case TargetOpcode::G_LOAD:
1159 return TargetOpcode::G_INDEXED_LOAD;
1160 case TargetOpcode::G_STORE:
1161 return TargetOpcode::G_INDEXED_STORE;
1162 case TargetOpcode::G_ZEXTLOAD:
1163 return TargetOpcode::G_INDEXED_ZEXTLOAD;
1164 case TargetOpcode::G_SEXTLOAD:
1165 return TargetOpcode::G_INDEXED_SEXTLOAD;
1166 default:
1167 llvm_unreachable("Unexpected opcode");
1168 }
1169}
1170
1171bool CombinerHelper::isIndexedLoadStoreLegal(GLoadStore &LdSt) const {
1172 // Check for legality.
1173 LLT PtrTy = MRI.getType(Reg: LdSt.getPointerReg());
1174 LLT Ty = MRI.getType(Reg: LdSt.getReg(Idx: 0));
1175 LLT MemTy = LdSt.getMMO().getMemoryType();
1176 SmallVector<LegalityQuery::MemDesc, 2> MemDescrs(
1177 {{MemTy, MemTy.getSizeInBits().getKnownMinValue(),
1178 AtomicOrdering::NotAtomic, AtomicOrdering::NotAtomic}});
1179 unsigned IndexedOpc = getIndexedOpc(LdStOpc: LdSt.getOpcode());
1180 SmallVector<LLT> OpTys;
1181 if (IndexedOpc == TargetOpcode::G_INDEXED_STORE)
1182 OpTys = {PtrTy, Ty, Ty};
1183 else
1184 OpTys = {Ty, PtrTy}; // For G_INDEXED_LOAD, G_INDEXED_[SZ]EXTLOAD
1185
1186 LegalityQuery Q(IndexedOpc, OpTys, MemDescrs);
1187 return isLegal(Query: Q);
1188}
1189
// Hidden tuning knob (default 32): the maximum number of non-debug users of a
// load/store's pointer that findPostIndexCandidate inspects before giving up
// on post-indexing. This bounds compile time for heavily used pointers.
static cl::opt<unsigned> PostIndexUseThreshold(
    "post-index-use-threshold", cl::Hidden, cl::init(Val: 32),
    cl::desc("Number of uses of a base pointer to check before it is no longer "
             "considered for post-indexing."));
1194
/// Look for a G_PTR_ADD of \p LdSt's pointer that could be folded into a
/// post-indexed form of \p LdSt.
///
/// On success returns true and sets:
///  - \p Addr:   the G_PTR_ADD's result (the written-back address),
///  - \p Base:   the G_PTR_ADD's base register,
///  - \p Offset: the G_PTR_ADD's offset register,
///  - \p RematOffset: true if the offset is a G_CONSTANT that does not
///    dominate \p LdSt and must be rematerialized before it.
/// \p Offset and \p RematOffset may also be written when false is returned.
///
/// Gives up early if the pointer has only one non-debug use, if the indexed
/// form is not legal, if the pointer comes from a G_FRAME_INDEX, or after
/// examining more than PostIndexUseThreshold users.
bool CombinerHelper::findPostIndexCandidate(GLoadStore &LdSt, Register &Addr,
                                            Register &Base, Register &Offset,
                                            bool &RematOffset) const {
  // We're looking for the following pattern, for either load or store:
  // %baseptr:_(p0) = ...
  // G_STORE %val(s64), %baseptr(p0)
  // %offset:_(s64) = G_CONSTANT i64 -256
  // %new_addr:_(p0) = G_PTR_ADD %baseptr, %offset(s64)
  const auto &TLI = getTargetLowering();

  Register Ptr = LdSt.getPointerReg();
  // If the store is the only use, don't bother.
  if (MRI.hasOneNonDBGUse(RegNo: Ptr))
    return false;

  if (!isIndexedLoadStoreLegal(LdSt))
    return false;

  if (getOpcodeDef(Opcode: TargetOpcode::G_FRAME_INDEX, Reg: Ptr, MRI))
    return false;

  // For a store this is the def of the stored value; used below to reject a
  // G_PTR_ADD that produces the very value being stored.
  MachineInstr *StoredValDef = getDefIgnoringCopies(Reg: LdSt.getReg(Idx: 0), MRI);
  auto *PtrDef = MRI.getVRegDef(Reg: Ptr);

  unsigned NumUsesChecked = 0;
  for (auto &Use : MRI.use_nodbg_instructions(Reg: Ptr)) {
    if (++NumUsesChecked > PostIndexUseThreshold)
      return false; // Try to avoid exploding compile time.

    auto *PtrAdd = dyn_cast<GPtrAdd>(Val: &Use);
    // The use itself might be dead. This can happen during combines if DCE
    // hasn't had a chance to run yet. Don't allow it to form an indexed op.
    if (!PtrAdd || MRI.use_nodbg_empty(RegNo: PtrAdd->getReg(Idx: 0)))
      continue;

    // Check the user of this isn't the store, otherwise we'd generate an
    // indexed store defining its own use.
    if (StoredValDef == &Use)
      continue;

    Offset = PtrAdd->getOffsetReg();
    if (!ForceLegalIndexing &&
        !TLI.isIndexingLegal(MI&: LdSt, Base: PtrAdd->getBaseReg(), Offset,
                             /*IsPre*/ false, MRI))
      continue;

    // Make sure the offset calculation is before the potentially indexed op.
    MachineInstr *OffsetDef = MRI.getVRegDef(Reg: Offset);
    RematOffset = false;
    if (!dominates(DefMI: *OffsetDef, UseMI: LdSt)) {
      // If the offset however is just a G_CONSTANT, we can always just
      // rematerialize it where we need it.
      if (OffsetDef->getOpcode() != TargetOpcode::G_CONSTANT)
        continue;
      RematOffset = true;
    }

    // Vet every other user of the base pointer before committing.
    for (auto &BasePtrUse : MRI.use_nodbg_instructions(Reg: PtrAdd->getBaseReg())) {
      // Skip the instruction that defines Ptr itself (presumably relevant
      // when the def also reads the pointer, e.g. a loop PHI).
      if (&BasePtrUse == PtrDef)
        continue;

      // If the user is a later load/store that can be post-indexed, then don't
      // combine this one.
      auto *BasePtrLdSt = dyn_cast<GLoadStore>(Val: &BasePtrUse);
      if (BasePtrLdSt && BasePtrLdSt != &LdSt &&
          dominates(DefMI: LdSt, UseMI: *BasePtrLdSt) &&
          isIndexedLoadStoreLegal(LdSt&: *BasePtrLdSt))
        return false;

      // Now we're looking for the key G_PTR_ADD instruction, which contains
      // the offset add that we want to fold.
      if (auto *BasePtrUseDef = dyn_cast<GPtrAdd>(Val: &BasePtrUse)) {
        Register PtrAddDefReg = BasePtrUseDef->getReg(Idx: 0);
        for (auto &BaseUseUse : MRI.use_nodbg_instructions(Reg: PtrAddDefReg)) {
          // If the use is in a different block, then we may produce worse code
          // due to the extra register pressure.
          if (BaseUseUse.getParent() != LdSt.getParent())
            return false;

          // A user that could fold this G_PTR_ADD into its own addressing
          // mode makes the combine unprofitable.
          if (auto *UseUseLdSt = dyn_cast<GLoadStore>(Val: &BaseUseUse))
            if (canFoldInAddressingMode(MI: UseUseLdSt, TLI, MRI))
              return false;
        }
        if (!dominates(DefMI: LdSt, UseMI: BasePtrUse))
          return false; // All use must be dominated by the load/store.
      }
    }

    Addr = PtrAdd->getReg(Idx: 0);
    Base = PtrAdd->getBaseReg();
    return true;
  }

  return false;
}
1290
1291bool CombinerHelper::findPreIndexCandidate(GLoadStore &LdSt, Register &Addr,
1292 Register &Base,
1293 Register &Offset) const {
1294 auto &MF = *LdSt.getParent()->getParent();
1295 const auto &TLI = *MF.getSubtarget().getTargetLowering();
1296
1297 Addr = LdSt.getPointerReg();
1298 if (!mi_match(R: Addr, MRI, P: m_GPtrAdd(L: m_Reg(R&: Base), R: m_Reg(R&: Offset))) ||
1299 MRI.hasOneNonDBGUse(RegNo: Addr))
1300 return false;
1301
1302 if (!ForceLegalIndexing &&
1303 !TLI.isIndexingLegal(MI&: LdSt, Base, Offset, /*IsPre*/ true, MRI))
1304 return false;
1305
1306 if (!isIndexedLoadStoreLegal(LdSt))
1307 return false;
1308
1309 MachineInstr *BaseDef = getDefIgnoringCopies(Reg: Base, MRI);
1310 if (BaseDef->getOpcode() == TargetOpcode::G_FRAME_INDEX)
1311 return false;
1312
1313 if (auto *St = dyn_cast<GStore>(Val: &LdSt)) {
1314 // Would require a copy.
1315 if (Base == St->getValueReg())
1316 return false;
1317
1318 // We're expecting one use of Addr in MI, but it could also be the
1319 // value stored, which isn't actually dominated by the instruction.
1320 if (St->getValueReg() == Addr)
1321 return false;
1322 }
1323
1324 // Avoid increasing cross-block register pressure.
1325 for (auto &AddrUse : MRI.use_nodbg_instructions(Reg: Addr))
1326 if (AddrUse.getParent() != LdSt.getParent())
1327 return false;
1328
1329 // FIXME: check whether all uses of the base pointer are constant PtrAdds.
1330 // That might allow us to end base's liveness here by adjusting the constant.
1331 bool RealUse = false;
1332 for (auto &AddrUse : MRI.use_nodbg_instructions(Reg: Addr)) {
1333 if (!dominates(DefMI: LdSt, UseMI: AddrUse))
1334 return false; // All use must be dominated by the load/store.
1335
1336 // If Ptr may be folded in addressing mode of other use, then it's
1337 // not profitable to do this transformation.
1338 if (auto *UseLdSt = dyn_cast<GLoadStore>(Val: &AddrUse)) {
1339 if (!canFoldInAddressingMode(MI: UseLdSt, TLI, MRI))
1340 RealUse = true;
1341 } else {
1342 RealUse = true;
1343 }
1344 }
1345 return RealUse;
1346}
1347
1348bool CombinerHelper::matchCombineExtractedVectorLoad(
1349 MachineInstr &MI, BuildFnTy &MatchInfo) const {
1350 assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT);
1351
1352 // Check if there is a load that defines the vector being extracted from.
1353 auto *LoadMI = getOpcodeDef<GLoad>(Reg: MI.getOperand(i: 1).getReg(), MRI);
1354 if (!LoadMI)
1355 return false;
1356
1357 Register Vector = MI.getOperand(i: 1).getReg();
1358 LLT VecEltTy = MRI.getType(Reg: Vector).getElementType();
1359
1360 assert(MRI.getType(MI.getOperand(0).getReg()) == VecEltTy);
1361
1362 // Checking whether we should reduce the load width.
1363 if (!MRI.hasOneNonDBGUse(RegNo: Vector))
1364 return false;
1365
1366 // Check if the defining load is simple.
1367 if (!LoadMI->isSimple())
1368 return false;
1369
1370 // If the vector element type is not a multiple of a byte then we are unable
1371 // to correctly compute an address to load only the extracted element as a
1372 // scalar.
1373 if (!VecEltTy.isByteSized())
1374 return false;
1375
1376 // Check for load fold barriers between the extraction and the load.
1377 if (MI.getParent() != LoadMI->getParent())
1378 return false;
1379 const unsigned MaxIter = 20;
1380 unsigned Iter = 0;
1381 for (auto II = LoadMI->getIterator(), IE = MI.getIterator(); II != IE; ++II) {
1382 if (II->isLoadFoldBarrier())
1383 return false;
1384 if (Iter++ == MaxIter)
1385 return false;
1386 }
1387
1388 // Check if the new load that we are going to create is legal
1389 // if we are in the post-legalization phase.
1390 MachineMemOperand MMO = LoadMI->getMMO();
1391 Align Alignment = MMO.getAlign();
1392 MachinePointerInfo PtrInfo;
1393 uint64_t Offset;
1394
1395 // Finding the appropriate PtrInfo if offset is a known constant.
1396 // This is required to create the memory operand for the narrowed load.
1397 // This machine memory operand object helps us infer about legality
1398 // before we proceed to combine the instruction.
1399 if (auto CVal = getIConstantVRegVal(VReg: Vector, MRI)) {
1400 int Elt = CVal->getZExtValue();
1401 // FIXME: should be (ABI size)*Elt.
1402 Offset = VecEltTy.getSizeInBits() * Elt / 8;
1403 PtrInfo = MMO.getPointerInfo().getWithOffset(O: Offset);
1404 } else {
1405 // Discard the pointer info except the address space because the memory
1406 // operand can't represent this new access since the offset is variable.
1407 Offset = VecEltTy.getSizeInBits() / 8;
1408 PtrInfo = MachinePointerInfo(MMO.getPointerInfo().getAddrSpace());
1409 }
1410
1411 Alignment = commonAlignment(A: Alignment, Offset);
1412
1413 Register VecPtr = LoadMI->getPointerReg();
1414 LLT PtrTy = MRI.getType(Reg: VecPtr);
1415
1416 MachineFunction &MF = *MI.getMF();
1417 auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Ty: VecEltTy);
1418
1419 LegalityQuery::MemDesc MMDesc(*NewMMO);
1420
1421 if (!isLegalOrBeforeLegalizer(
1422 Query: {TargetOpcode::G_LOAD, {VecEltTy, PtrTy}, {MMDesc}}))
1423 return false;
1424
1425 // Load must be allowed and fast on the target.
1426 LLVMContext &C = MF.getFunction().getContext();
1427 auto &DL = MF.getDataLayout();
1428 unsigned Fast = 0;
1429 if (!getTargetLowering().allowsMemoryAccess(Context&: C, DL, Ty: VecEltTy, MMO: *NewMMO,
1430 Fast: &Fast) ||
1431 !Fast)
1432 return false;
1433
1434 Register Result = MI.getOperand(i: 0).getReg();
1435 Register Index = MI.getOperand(i: 2).getReg();
1436
1437 MatchInfo = [=](MachineIRBuilder &B) {
1438 GISelObserverWrapper DummyObserver;
1439 LegalizerHelper Helper(B.getMF(), DummyObserver, B);
1440 //// Get pointer to the vector element.
1441 Register finalPtr = Helper.getVectorElementPointer(
1442 VecPtr: LoadMI->getPointerReg(), VecTy: MRI.getType(Reg: LoadMI->getOperand(i: 0).getReg()),
1443 Index);
1444 // New G_LOAD instruction.
1445 B.buildLoad(Res: Result, Addr: finalPtr, PtrInfo, Alignment);
1446 // Remove original GLOAD instruction.
1447 LoadMI->eraseFromParent();
1448 };
1449
1450 return true;
1451}
1452
1453bool CombinerHelper::matchCombineIndexedLoadStore(
1454 MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) const {
1455 auto &LdSt = cast<GLoadStore>(Val&: MI);
1456
1457 if (LdSt.isAtomic())
1458 return false;
1459
1460 MatchInfo.IsPre = findPreIndexCandidate(LdSt, Addr&: MatchInfo.Addr, Base&: MatchInfo.Base,
1461 Offset&: MatchInfo.Offset);
1462 if (!MatchInfo.IsPre &&
1463 !findPostIndexCandidate(LdSt, Addr&: MatchInfo.Addr, Base&: MatchInfo.Base,
1464 Offset&: MatchInfo.Offset, RematOffset&: MatchInfo.RematOffset))
1465 return false;
1466
1467 return true;
1468}
1469
1470void CombinerHelper::applyCombineIndexedLoadStore(
1471 MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) const {
1472 MachineInstr &AddrDef = *MRI.getUniqueVRegDef(Reg: MatchInfo.Addr);
1473 unsigned Opcode = MI.getOpcode();
1474 bool IsStore = Opcode == TargetOpcode::G_STORE;
1475 unsigned NewOpcode = getIndexedOpc(LdStOpc: Opcode);
1476
1477 // If the offset constant didn't happen to dominate the load/store, we can
1478 // just clone it as needed.
1479 if (MatchInfo.RematOffset) {
1480 auto *OldCst = MRI.getVRegDef(Reg: MatchInfo.Offset);
1481 auto NewCst = Builder.buildConstant(Res: MRI.getType(Reg: MatchInfo.Offset),
1482 Val: *OldCst->getOperand(i: 1).getCImm());
1483 MatchInfo.Offset = NewCst.getReg(Idx: 0);
1484 }
1485
1486 auto MIB = Builder.buildInstr(Opcode: NewOpcode);
1487 if (IsStore) {
1488 MIB.addDef(RegNo: MatchInfo.Addr);
1489 MIB.addUse(RegNo: MI.getOperand(i: 0).getReg());
1490 } else {
1491 MIB.addDef(RegNo: MI.getOperand(i: 0).getReg());
1492 MIB.addDef(RegNo: MatchInfo.Addr);
1493 }
1494
1495 MIB.addUse(RegNo: MatchInfo.Base);
1496 MIB.addUse(RegNo: MatchInfo.Offset);
1497 MIB.addImm(Val: MatchInfo.IsPre);
1498 MIB->cloneMemRefs(MF&: *MI.getMF(), MI);
1499 MI.eraseFromParent();
1500 AddrDef.eraseFromParent();
1501
1502 LLVM_DEBUG(dbgs() << " Combinined to indexed operation");
1503}
1504
1505bool CombinerHelper::matchCombineDivRem(MachineInstr &MI,
1506 MachineInstr *&OtherMI) const {
1507 unsigned Opcode = MI.getOpcode();
1508 bool IsDiv, IsSigned;
1509
1510 switch (Opcode) {
1511 default:
1512 llvm_unreachable("Unexpected opcode!");
1513 case TargetOpcode::G_SDIV:
1514 case TargetOpcode::G_UDIV: {
1515 IsDiv = true;
1516 IsSigned = Opcode == TargetOpcode::G_SDIV;
1517 break;
1518 }
1519 case TargetOpcode::G_SREM:
1520 case TargetOpcode::G_UREM: {
1521 IsDiv = false;
1522 IsSigned = Opcode == TargetOpcode::G_SREM;
1523 break;
1524 }
1525 }
1526
1527 Register Src1 = MI.getOperand(i: 1).getReg();
1528 unsigned DivOpcode, RemOpcode, DivremOpcode;
1529 if (IsSigned) {
1530 DivOpcode = TargetOpcode::G_SDIV;
1531 RemOpcode = TargetOpcode::G_SREM;
1532 DivremOpcode = TargetOpcode::G_SDIVREM;
1533 } else {
1534 DivOpcode = TargetOpcode::G_UDIV;
1535 RemOpcode = TargetOpcode::G_UREM;
1536 DivremOpcode = TargetOpcode::G_UDIVREM;
1537 }
1538
1539 if (!isLegalOrBeforeLegalizer(Query: {DivremOpcode, {MRI.getType(Reg: Src1)}}))
1540 return false;
1541
1542 // Combine:
1543 // %div:_ = G_[SU]DIV %src1:_, %src2:_
1544 // %rem:_ = G_[SU]REM %src1:_, %src2:_
1545 // into:
1546 // %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_
1547
1548 // Combine:
1549 // %rem:_ = G_[SU]REM %src1:_, %src2:_
1550 // %div:_ = G_[SU]DIV %src1:_, %src2:_
1551 // into:
1552 // %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_
1553
1554 for (auto &UseMI : MRI.use_nodbg_instructions(Reg: Src1)) {
1555 if (MI.getParent() == UseMI.getParent() &&
1556 ((IsDiv && UseMI.getOpcode() == RemOpcode) ||
1557 (!IsDiv && UseMI.getOpcode() == DivOpcode)) &&
1558 matchEqualDefs(MOP1: MI.getOperand(i: 2), MOP2: UseMI.getOperand(i: 2)) &&
1559 matchEqualDefs(MOP1: MI.getOperand(i: 1), MOP2: UseMI.getOperand(i: 1))) {
1560 OtherMI = &UseMI;
1561 return true;
1562 }
1563 }
1564
1565 return false;
1566}
1567
1568void CombinerHelper::applyCombineDivRem(MachineInstr &MI,
1569 MachineInstr *&OtherMI) const {
1570 unsigned Opcode = MI.getOpcode();
1571 assert(OtherMI && "OtherMI shouldn't be empty.");
1572
1573 Register DestDivReg, DestRemReg;
1574 if (Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_UDIV) {
1575 DestDivReg = MI.getOperand(i: 0).getReg();
1576 DestRemReg = OtherMI->getOperand(i: 0).getReg();
1577 } else {
1578 DestDivReg = OtherMI->getOperand(i: 0).getReg();
1579 DestRemReg = MI.getOperand(i: 0).getReg();
1580 }
1581
1582 bool IsSigned =
1583 Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_SREM;
1584
1585 // Check which instruction is first in the block so we don't break def-use
1586 // deps by "moving" the instruction incorrectly. Also keep track of which
1587 // instruction is first so we pick it's operands, avoiding use-before-def
1588 // bugs.
1589 MachineInstr *FirstInst = dominates(DefMI: MI, UseMI: *OtherMI) ? &MI : OtherMI;
1590 Builder.setInstrAndDebugLoc(*FirstInst);
1591
1592 Builder.buildInstr(Opc: IsSigned ? TargetOpcode::G_SDIVREM
1593 : TargetOpcode::G_UDIVREM,
1594 DstOps: {DestDivReg, DestRemReg},
1595 SrcOps: { FirstInst->getOperand(i: 1), FirstInst->getOperand(i: 2) });
1596 MI.eraseFromParent();
1597 OtherMI->eraseFromParent();
1598}
1599
1600bool CombinerHelper::matchOptBrCondByInvertingCond(
1601 MachineInstr &MI, MachineInstr *&BrCond) const {
1602 assert(MI.getOpcode() == TargetOpcode::G_BR);
1603
1604 // Try to match the following:
1605 // bb1:
1606 // G_BRCOND %c1, %bb2
1607 // G_BR %bb3
1608 // bb2:
1609 // ...
1610 // bb3:
1611
1612 // The above pattern does not have a fall through to the successor bb2, always
1613 // resulting in a branch no matter which path is taken. Here we try to find
1614 // and replace that pattern with conditional branch to bb3 and otherwise
1615 // fallthrough to bb2. This is generally better for branch predictors.
1616
1617 MachineBasicBlock *MBB = MI.getParent();
1618 MachineBasicBlock::iterator BrIt(MI);
1619 if (BrIt == MBB->begin())
1620 return false;
1621 assert(std::next(BrIt) == MBB->end() && "expected G_BR to be a terminator");
1622
1623 BrCond = &*std::prev(x: BrIt);
1624 if (BrCond->getOpcode() != TargetOpcode::G_BRCOND)
1625 return false;
1626
1627 // Check that the next block is the conditional branch target. Also make sure
1628 // that it isn't the same as the G_BR's target (otherwise, this will loop.)
1629 MachineBasicBlock *BrCondTarget = BrCond->getOperand(i: 1).getMBB();
1630 return BrCondTarget != MI.getOperand(i: 0).getMBB() &&
1631 MBB->isLayoutSuccessor(MBB: BrCondTarget);
1632}
1633
1634void CombinerHelper::applyOptBrCondByInvertingCond(
1635 MachineInstr &MI, MachineInstr *&BrCond) const {
1636 MachineBasicBlock *BrTarget = MI.getOperand(i: 0).getMBB();
1637 Builder.setInstrAndDebugLoc(*BrCond);
1638 LLT Ty = MRI.getType(Reg: BrCond->getOperand(i: 0).getReg());
1639 // FIXME: Does int/fp matter for this? If so, we might need to restrict
1640 // this to i1 only since we might not know for sure what kind of
1641 // compare generated the condition value.
1642 auto True = Builder.buildConstant(
1643 Res: Ty, Val: getICmpTrueVal(TLI: getTargetLowering(), IsVector: false, IsFP: false));
1644 auto Xor = Builder.buildXor(Dst: Ty, Src0: BrCond->getOperand(i: 0), Src1: True);
1645
1646 auto *FallthroughBB = BrCond->getOperand(i: 1).getMBB();
1647 Observer.changingInstr(MI);
1648 MI.getOperand(i: 0).setMBB(FallthroughBB);
1649 Observer.changedInstr(MI);
1650
1651 // Change the conditional branch to use the inverted condition and
1652 // new target block.
1653 Observer.changingInstr(MI&: *BrCond);
1654 BrCond->getOperand(i: 0).setReg(Xor.getReg(Idx: 0));
1655 BrCond->getOperand(i: 1).setMBB(BrTarget);
1656 Observer.changedInstr(MI&: *BrCond);
1657}
1658
1659bool CombinerHelper::tryEmitMemcpyInline(MachineInstr &MI) const {
1660 MachineIRBuilder HelperBuilder(MI);
1661 GISelObserverWrapper DummyObserver;
1662 LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder);
1663 return Helper.lowerMemcpyInline(MI) ==
1664 LegalizerHelper::LegalizeResult::Legalized;
1665}
1666
1667bool CombinerHelper::tryCombineMemCpyFamily(MachineInstr &MI,
1668 unsigned MaxLen) const {
1669 MachineIRBuilder HelperBuilder(MI);
1670 GISelObserverWrapper DummyObserver;
1671 LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder);
1672 return Helper.lowerMemCpyFamily(MI, MaxLen) ==
1673 LegalizerHelper::LegalizeResult::Legalized;
1674}
1675
1676static APFloat constantFoldFpUnary(const MachineInstr &MI,
1677 const MachineRegisterInfo &MRI,
1678 const APFloat &Val) {
1679 APFloat Result(Val);
1680 switch (MI.getOpcode()) {
1681 default:
1682 llvm_unreachable("Unexpected opcode!");
1683 case TargetOpcode::G_FNEG: {
1684 Result.changeSign();
1685 return Result;
1686 }
1687 case TargetOpcode::G_FABS: {
1688 Result.clearSign();
1689 return Result;
1690 }
1691 case TargetOpcode::G_FCEIL:
1692 Result.roundToIntegral(RM: APFloat::rmTowardPositive);
1693 return Result;
1694 case TargetOpcode::G_FFLOOR:
1695 Result.roundToIntegral(RM: APFloat::rmTowardNegative);
1696 return Result;
1697 case TargetOpcode::G_INTRINSIC_TRUNC:
1698 Result.roundToIntegral(RM: APFloat::rmTowardZero);
1699 return Result;
1700 case TargetOpcode::G_INTRINSIC_ROUND:
1701 Result.roundToIntegral(RM: APFloat::rmNearestTiesToAway);
1702 return Result;
1703 case TargetOpcode::G_INTRINSIC_ROUNDEVEN:
1704 Result.roundToIntegral(RM: APFloat::rmNearestTiesToEven);
1705 return Result;
1706 case TargetOpcode::G_FRINT:
1707 case TargetOpcode::G_FNEARBYINT:
1708 // Use default rounding mode (round to nearest, ties to even)
1709 Result.roundToIntegral(RM: APFloat::rmNearestTiesToEven);
1710 return Result;
1711 case TargetOpcode::G_FPEXT:
1712 case TargetOpcode::G_FPTRUNC: {
1713 bool Unused;
1714 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
1715 Result.convert(ToSemantics: getFltSemanticForLLT(Ty: DstTy), RM: APFloat::rmNearestTiesToEven,
1716 losesInfo: &Unused);
1717 return Result;
1718 }
1719 case TargetOpcode::G_FSQRT: {
1720 bool Unused;
1721 Result.convert(ToSemantics: APFloat::IEEEdouble(), RM: APFloat::rmNearestTiesToEven,
1722 losesInfo: &Unused);
1723 Result = APFloat(sqrt(x: Result.convertToDouble()));
1724 break;
1725 }
1726 case TargetOpcode::G_FLOG2: {
1727 bool Unused;
1728 Result.convert(ToSemantics: APFloat::IEEEdouble(), RM: APFloat::rmNearestTiesToEven,
1729 losesInfo: &Unused);
1730 Result = APFloat(log2(x: Result.convertToDouble()));
1731 break;
1732 }
1733 }
1734 // Convert `APFloat` to appropriate IEEE type depending on `DstTy`. Otherwise,
1735 // `buildFConstant` will assert on size mismatch. Only `G_FSQRT`, and
1736 // `G_FLOG2` reach here.
1737 bool Unused;
1738 Result.convert(ToSemantics: Val.getSemantics(), RM: APFloat::rmNearestTiesToEven, losesInfo: &Unused);
1739 return Result;
1740}
1741
1742void CombinerHelper::applyCombineConstantFoldFpUnary(
1743 MachineInstr &MI, const ConstantFP *Cst) const {
1744 APFloat Folded = constantFoldFpUnary(MI, MRI, Val: Cst->getValue());
1745 const ConstantFP *NewCst = ConstantFP::get(Context&: Builder.getContext(), V: Folded);
1746 Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: *NewCst);
1747 MI.eraseFromParent();
1748}
1749
1750bool CombinerHelper::matchPtrAddImmedChain(MachineInstr &MI,
1751 PtrAddChain &MatchInfo) const {
1752 // We're trying to match the following pattern:
1753 // %t1 = G_PTR_ADD %base, G_CONSTANT imm1
1754 // %root = G_PTR_ADD %t1, G_CONSTANT imm2
1755 // -->
1756 // %root = G_PTR_ADD %base, G_CONSTANT (imm1 + imm2)
1757
1758 if (MI.getOpcode() != TargetOpcode::G_PTR_ADD)
1759 return false;
1760
1761 Register Add2 = MI.getOperand(i: 1).getReg();
1762 Register Imm1 = MI.getOperand(i: 2).getReg();
1763 auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: Imm1, MRI);
1764 if (!MaybeImmVal)
1765 return false;
1766
1767 MachineInstr *Add2Def = MRI.getVRegDef(Reg: Add2);
1768 if (!Add2Def || Add2Def->getOpcode() != TargetOpcode::G_PTR_ADD)
1769 return false;
1770
1771 Register Base = Add2Def->getOperand(i: 1).getReg();
1772 Register Imm2 = Add2Def->getOperand(i: 2).getReg();
1773 auto MaybeImm2Val = getIConstantVRegValWithLookThrough(VReg: Imm2, MRI);
1774 if (!MaybeImm2Val)
1775 return false;
1776
1777 // Check if the new combined immediate forms an illegal addressing mode.
1778 // Do not combine if it was legal before but would get illegal.
1779 // To do so, we need to find a load/store user of the pointer to get
1780 // the access type.
1781 Type *AccessTy = nullptr;
1782 auto &MF = *MI.getMF();
1783 for (auto &UseMI : MRI.use_nodbg_instructions(Reg: MI.getOperand(i: 0).getReg())) {
1784 if (auto *LdSt = dyn_cast<GLoadStore>(Val: &UseMI)) {
1785 AccessTy = getTypeForLLT(Ty: MRI.getType(Reg: LdSt->getReg(Idx: 0)),
1786 C&: MF.getFunction().getContext());
1787 break;
1788 }
1789 }
1790 TargetLoweringBase::AddrMode AMNew;
1791 APInt CombinedImm = MaybeImmVal->Value + MaybeImm2Val->Value;
1792 AMNew.BaseOffs = CombinedImm.getSExtValue();
1793 if (AccessTy) {
1794 AMNew.HasBaseReg = true;
1795 TargetLoweringBase::AddrMode AMOld;
1796 AMOld.BaseOffs = MaybeImmVal->Value.getSExtValue();
1797 AMOld.HasBaseReg = true;
1798 unsigned AS = MRI.getType(Reg: Add2).getAddressSpace();
1799 const auto &TLI = *MF.getSubtarget().getTargetLowering();
1800 if (TLI.isLegalAddressingMode(DL: MF.getDataLayout(), AM: AMOld, Ty: AccessTy, AddrSpace: AS) &&
1801 !TLI.isLegalAddressingMode(DL: MF.getDataLayout(), AM: AMNew, Ty: AccessTy, AddrSpace: AS))
1802 return false;
1803 }
1804
1805 // Reassociating nuw additions preserves nuw. If both original G_PTR_ADDs are
1806 // inbounds, reaching the same result in one G_PTR_ADD is also inbounds.
1807 // The nusw constraints are satisfied because imm1+imm2 cannot exceed the
1808 // largest signed integer that fits into the index type, which is the maximum
1809 // size of allocated objects according to the IR Language Reference.
1810 unsigned PtrAddFlags = MI.getFlags();
1811 unsigned LHSPtrAddFlags = Add2Def->getFlags();
1812 bool IsNoUWrap = PtrAddFlags & LHSPtrAddFlags & MachineInstr::MIFlag::NoUWrap;
1813 bool IsInBounds =
1814 PtrAddFlags & LHSPtrAddFlags & MachineInstr::MIFlag::InBounds;
1815 unsigned Flags = 0;
1816 if (IsNoUWrap)
1817 Flags |= MachineInstr::MIFlag::NoUWrap;
1818 if (IsInBounds) {
1819 Flags |= MachineInstr::MIFlag::InBounds;
1820 Flags |= MachineInstr::MIFlag::NoUSWrap;
1821 }
1822
1823 // Pass the combined immediate to the apply function.
1824 MatchInfo.Imm = AMNew.BaseOffs;
1825 MatchInfo.Base = Base;
1826 MatchInfo.Bank = getRegBank(Reg: Imm2);
1827 MatchInfo.Flags = Flags;
1828 return true;
1829}
1830
1831void CombinerHelper::applyPtrAddImmedChain(MachineInstr &MI,
1832 PtrAddChain &MatchInfo) const {
1833 assert(MI.getOpcode() == TargetOpcode::G_PTR_ADD && "Expected G_PTR_ADD");
1834 MachineIRBuilder MIB(MI);
1835 LLT OffsetTy = MRI.getType(Reg: MI.getOperand(i: 2).getReg());
1836 auto NewOffset = MIB.buildConstant(Res: OffsetTy, Val: MatchInfo.Imm);
1837 setRegBank(Reg: NewOffset.getReg(Idx: 0), RegBank: MatchInfo.Bank);
1838 Observer.changingInstr(MI);
1839 MI.getOperand(i: 1).setReg(MatchInfo.Base);
1840 MI.getOperand(i: 2).setReg(NewOffset.getReg(Idx: 0));
1841 MI.setFlags(MatchInfo.Flags);
1842 Observer.changedInstr(MI);
1843}
1844
1845bool CombinerHelper::matchShiftImmedChain(MachineInstr &MI,
1846 RegisterImmPair &MatchInfo) const {
1847 // We're trying to match the following pattern with any of
1848 // G_SHL/G_ASHR/G_LSHR/G_SSHLSAT/G_USHLSAT shift instructions:
1849 // %t1 = SHIFT %base, G_CONSTANT imm1
1850 // %root = SHIFT %t1, G_CONSTANT imm2
1851 // -->
1852 // %root = SHIFT %base, G_CONSTANT (imm1 + imm2)
1853
1854 unsigned Opcode = MI.getOpcode();
1855 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR ||
1856 Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT ||
1857 Opcode == TargetOpcode::G_USHLSAT) &&
1858 "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT");
1859
1860 Register Shl2 = MI.getOperand(i: 1).getReg();
1861 Register Imm1 = MI.getOperand(i: 2).getReg();
1862 auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: Imm1, MRI);
1863 if (!MaybeImmVal)
1864 return false;
1865
1866 MachineInstr *Shl2Def = MRI.getUniqueVRegDef(Reg: Shl2);
1867 if (Shl2Def->getOpcode() != Opcode)
1868 return false;
1869
1870 Register Base = Shl2Def->getOperand(i: 1).getReg();
1871 Register Imm2 = Shl2Def->getOperand(i: 2).getReg();
1872 auto MaybeImm2Val = getIConstantVRegValWithLookThrough(VReg: Imm2, MRI);
1873 if (!MaybeImm2Val)
1874 return false;
1875
1876 // Pass the combined immediate to the apply function.
1877 MatchInfo.Imm =
1878 (MaybeImmVal->Value.getZExtValue() + MaybeImm2Val->Value).getZExtValue();
1879 MatchInfo.Reg = Base;
1880
1881 // There is no simple replacement for a saturating unsigned left shift that
1882 // exceeds the scalar size.
1883 if (Opcode == TargetOpcode::G_USHLSAT &&
1884 MatchInfo.Imm >= MRI.getType(Reg: Shl2).getScalarSizeInBits())
1885 return false;
1886
1887 return true;
1888}
1889
1890void CombinerHelper::applyShiftImmedChain(MachineInstr &MI,
1891 RegisterImmPair &MatchInfo) const {
1892 unsigned Opcode = MI.getOpcode();
1893 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR ||
1894 Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT ||
1895 Opcode == TargetOpcode::G_USHLSAT) &&
1896 "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT");
1897
1898 LLT Ty = MRI.getType(Reg: MI.getOperand(i: 1).getReg());
1899 unsigned const ScalarSizeInBits = Ty.getScalarSizeInBits();
1900 auto Imm = MatchInfo.Imm;
1901
1902 if (Imm >= ScalarSizeInBits) {
1903 // Any logical shift that exceeds scalar size will produce zero.
1904 if (Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_LSHR) {
1905 Builder.buildConstant(Res: MI.getOperand(i: 0), Val: 0);
1906 MI.eraseFromParent();
1907 return;
1908 }
1909 // Arithmetic shift and saturating signed left shift have no effect beyond
1910 // scalar size.
1911 Imm = ScalarSizeInBits - 1;
1912 }
1913
1914 LLT ImmTy = MRI.getType(Reg: MI.getOperand(i: 2).getReg());
1915 Register NewImm = Builder.buildConstant(Res: ImmTy, Val: Imm).getReg(Idx: 0);
1916 Observer.changingInstr(MI);
1917 MI.getOperand(i: 1).setReg(MatchInfo.Reg);
1918 MI.getOperand(i: 2).setReg(NewImm);
1919 Observer.changedInstr(MI);
1920}
1921
1922bool CombinerHelper::matchShiftOfShiftedLogic(
1923 MachineInstr &MI, ShiftOfShiftedLogic &MatchInfo) const {
1924 // We're trying to match the following pattern with any of
1925 // G_SHL/G_ASHR/G_LSHR/G_USHLSAT/G_SSHLSAT shift instructions in combination
1926 // with any of G_AND/G_OR/G_XOR logic instructions.
1927 // %t1 = SHIFT %X, G_CONSTANT C0
1928 // %t2 = LOGIC %t1, %Y
1929 // %root = SHIFT %t2, G_CONSTANT C1
1930 // -->
1931 // %t3 = SHIFT %X, G_CONSTANT (C0+C1)
1932 // %t4 = SHIFT %Y, G_CONSTANT C1
1933 // %root = LOGIC %t3, %t4
1934 unsigned ShiftOpcode = MI.getOpcode();
1935 assert((ShiftOpcode == TargetOpcode::G_SHL ||
1936 ShiftOpcode == TargetOpcode::G_ASHR ||
1937 ShiftOpcode == TargetOpcode::G_LSHR ||
1938 ShiftOpcode == TargetOpcode::G_USHLSAT ||
1939 ShiftOpcode == TargetOpcode::G_SSHLSAT) &&
1940 "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT and G_SSHLSAT");
1941
1942 // Match a one-use bitwise logic op.
1943 Register LogicDest = MI.getOperand(i: 1).getReg();
1944 if (!MRI.hasOneNonDBGUse(RegNo: LogicDest))
1945 return false;
1946
1947 MachineInstr *LogicMI = MRI.getUniqueVRegDef(Reg: LogicDest);
1948 unsigned LogicOpcode = LogicMI->getOpcode();
1949 if (LogicOpcode != TargetOpcode::G_AND && LogicOpcode != TargetOpcode::G_OR &&
1950 LogicOpcode != TargetOpcode::G_XOR)
1951 return false;
1952
1953 // Find a matching one-use shift by constant.
1954 const Register C1 = MI.getOperand(i: 2).getReg();
1955 auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: C1, MRI);
1956 if (!MaybeImmVal || MaybeImmVal->Value == 0)
1957 return false;
1958
1959 const uint64_t C1Val = MaybeImmVal->Value.getZExtValue();
1960
1961 auto matchFirstShift = [&](const MachineInstr *MI, uint64_t &ShiftVal) {
1962 // Shift should match previous one and should be a one-use.
1963 if (MI->getOpcode() != ShiftOpcode ||
1964 !MRI.hasOneNonDBGUse(RegNo: MI->getOperand(i: 0).getReg()))
1965 return false;
1966
1967 // Must be a constant.
1968 auto MaybeImmVal =
1969 getIConstantVRegValWithLookThrough(VReg: MI->getOperand(i: 2).getReg(), MRI);
1970 if (!MaybeImmVal)
1971 return false;
1972
1973 ShiftVal = MaybeImmVal->Value.getSExtValue();
1974 return true;
1975 };
1976
1977 // Logic ops are commutative, so check each operand for a match.
1978 Register LogicMIReg1 = LogicMI->getOperand(i: 1).getReg();
1979 MachineInstr *LogicMIOp1 = MRI.getUniqueVRegDef(Reg: LogicMIReg1);
1980 Register LogicMIReg2 = LogicMI->getOperand(i: 2).getReg();
1981 MachineInstr *LogicMIOp2 = MRI.getUniqueVRegDef(Reg: LogicMIReg2);
1982 uint64_t C0Val;
1983
1984 if (matchFirstShift(LogicMIOp1, C0Val)) {
1985 MatchInfo.LogicNonShiftReg = LogicMIReg2;
1986 MatchInfo.Shift2 = LogicMIOp1;
1987 } else if (matchFirstShift(LogicMIOp2, C0Val)) {
1988 MatchInfo.LogicNonShiftReg = LogicMIReg1;
1989 MatchInfo.Shift2 = LogicMIOp2;
1990 } else
1991 return false;
1992
1993 MatchInfo.ValSum = C0Val + C1Val;
1994
1995 // The fold is not valid if the sum of the shift values exceeds bitwidth.
1996 if (MatchInfo.ValSum >= MRI.getType(Reg: LogicDest).getScalarSizeInBits())
1997 return false;
1998
1999 MatchInfo.Logic = LogicMI;
2000 return true;
2001}
2002
2003void CombinerHelper::applyShiftOfShiftedLogic(
2004 MachineInstr &MI, ShiftOfShiftedLogic &MatchInfo) const {
2005 unsigned Opcode = MI.getOpcode();
2006 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR ||
2007 Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_USHLSAT ||
2008 Opcode == TargetOpcode::G_SSHLSAT) &&
2009 "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT and G_SSHLSAT");
2010
2011 LLT ShlType = MRI.getType(Reg: MI.getOperand(i: 2).getReg());
2012 LLT DestType = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
2013
2014 Register Const = Builder.buildConstant(Res: ShlType, Val: MatchInfo.ValSum).getReg(Idx: 0);
2015
2016 Register Shift1Base = MatchInfo.Shift2->getOperand(i: 1).getReg();
2017 Register Shift1 =
2018 Builder.buildInstr(Opc: Opcode, DstOps: {DestType}, SrcOps: {Shift1Base, Const}).getReg(Idx: 0);
2019
2020 // If LogicNonShiftReg is the same to Shift1Base, and shift1 const is the same
2021 // to MatchInfo.Shift2 const, CSEMIRBuilder will reuse the old shift1 when
2022 // build shift2. So, if we erase MatchInfo.Shift2 at the end, actually we
2023 // remove old shift1. And it will cause crash later. So erase it earlier to
2024 // avoid the crash.
2025 MatchInfo.Shift2->eraseFromParent();
2026
2027 Register Shift2Const = MI.getOperand(i: 2).getReg();
2028 Register Shift2 = Builder
2029 .buildInstr(Opc: Opcode, DstOps: {DestType},
2030 SrcOps: {MatchInfo.LogicNonShiftReg, Shift2Const})
2031 .getReg(Idx: 0);
2032
2033 Register Dest = MI.getOperand(i: 0).getReg();
2034 Builder.buildInstr(Opc: MatchInfo.Logic->getOpcode(), DstOps: {Dest}, SrcOps: {Shift1, Shift2});
2035
2036 // This was one use so it's safe to remove it.
2037 MatchInfo.Logic->eraseFromParent();
2038
2039 MI.eraseFromParent();
2040}
2041
2042bool CombinerHelper::matchCommuteShift(MachineInstr &MI,
2043 BuildFnTy &MatchInfo) const {
2044 assert(MI.getOpcode() == TargetOpcode::G_SHL && "Expected G_SHL");
2045 // Combine (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
2046 // Combine (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
2047 auto &Shl = cast<GenericMachineInstr>(Val&: MI);
2048 Register DstReg = Shl.getReg(Idx: 0);
2049 Register SrcReg = Shl.getReg(Idx: 1);
2050 Register ShiftReg = Shl.getReg(Idx: 2);
2051 Register X, C1;
2052
2053 if (!getTargetLowering().isDesirableToCommuteWithShift(MI, IsAfterLegal: !isPreLegalize()))
2054 return false;
2055
2056 if (!mi_match(R: SrcReg, MRI,
2057 P: m_OneNonDBGUse(SP: m_any_of(preds: m_GAdd(L: m_Reg(R&: X), R: m_Reg(R&: C1)),
2058 preds: m_GOr(L: m_Reg(R&: X), R: m_Reg(R&: C1))))))
2059 return false;
2060
2061 APInt C1Val, C2Val;
2062 if (!mi_match(R: C1, MRI, P: m_ICstOrSplat(Cst&: C1Val)) ||
2063 !mi_match(R: ShiftReg, MRI, P: m_ICstOrSplat(Cst&: C2Val)))
2064 return false;
2065
2066 auto *SrcDef = MRI.getVRegDef(Reg: SrcReg);
2067 assert((SrcDef->getOpcode() == TargetOpcode::G_ADD ||
2068 SrcDef->getOpcode() == TargetOpcode::G_OR) && "Unexpected op");
2069 LLT SrcTy = MRI.getType(Reg: SrcReg);
2070 MatchInfo = [=](MachineIRBuilder &B) {
2071 auto S1 = B.buildShl(Dst: SrcTy, Src0: X, Src1: ShiftReg);
2072 auto S2 = B.buildShl(Dst: SrcTy, Src0: C1, Src1: ShiftReg);
2073 B.buildInstr(Opc: SrcDef->getOpcode(), DstOps: {DstReg}, SrcOps: {S1, S2});
2074 };
2075 return true;
2076}
2077
2078bool CombinerHelper::matchLshrOfTruncOfLshr(MachineInstr &MI,
2079 LshrOfTruncOfLshr &MatchInfo,
2080 MachineInstr &ShiftMI) const {
2081 assert(MI.getOpcode() == TargetOpcode::G_LSHR && "Expected a G_LSHR");
2082
2083 Register N0 = MI.getOperand(i: 1).getReg();
2084 Register N1 = MI.getOperand(i: 2).getReg();
2085 unsigned OpSizeInBits = MRI.getType(Reg: N0).getScalarSizeInBits();
2086
2087 APInt N1C, N001C;
2088 if (!mi_match(R: N1, MRI, P: m_ICstOrSplat(Cst&: N1C)))
2089 return false;
2090 auto N001 = ShiftMI.getOperand(i: 2).getReg();
2091 if (!mi_match(R: N001, MRI, P: m_ICstOrSplat(Cst&: N001C)))
2092 return false;
2093
2094 if (N001C.getBitWidth() > N1C.getBitWidth())
2095 N1C = N1C.zext(width: N001C.getBitWidth());
2096 else
2097 N001C = N001C.zext(width: N1C.getBitWidth());
2098
2099 Register InnerShift = ShiftMI.getOperand(i: 0).getReg();
2100 LLT InnerShiftTy = MRI.getType(Reg: InnerShift);
2101 uint64_t InnerShiftSize = InnerShiftTy.getScalarSizeInBits();
2102 if ((N1C + N001C).ult(RHS: InnerShiftSize)) {
2103 MatchInfo.Src = ShiftMI.getOperand(i: 1).getReg();
2104 MatchInfo.ShiftAmt = N1C + N001C;
2105 MatchInfo.ShiftAmtTy = MRI.getType(Reg: N001);
2106 MatchInfo.InnerShiftTy = InnerShiftTy;
2107
2108 if ((N001C + OpSizeInBits) == InnerShiftSize)
2109 return true;
2110 if (MRI.hasOneUse(RegNo: N0) && MRI.hasOneUse(RegNo: InnerShift)) {
2111 MatchInfo.Mask = true;
2112 MatchInfo.MaskVal = APInt(N1C.getBitWidth(), OpSizeInBits) - N1C;
2113 return true;
2114 }
2115 }
2116 return false;
2117}
2118
2119void CombinerHelper::applyLshrOfTruncOfLshr(
2120 MachineInstr &MI, LshrOfTruncOfLshr &MatchInfo) const {
2121 assert(MI.getOpcode() == TargetOpcode::G_LSHR && "Expected a G_LSHR");
2122
2123 Register Dst = MI.getOperand(i: 0).getReg();
2124 auto ShiftAmt =
2125 Builder.buildConstant(Res: MatchInfo.ShiftAmtTy, Val: MatchInfo.ShiftAmt);
2126 auto Shift =
2127 Builder.buildLShr(Dst: MatchInfo.InnerShiftTy, Src0: MatchInfo.Src, Src1: ShiftAmt);
2128 if (MatchInfo.Mask == true) {
2129 APInt MaskVal =
2130 APInt::getLowBitsSet(numBits: MatchInfo.InnerShiftTy.getScalarSizeInBits(),
2131 loBitsSet: MatchInfo.MaskVal.getZExtValue());
2132 auto Mask = Builder.buildConstant(Res: MatchInfo.InnerShiftTy, Val: MaskVal);
2133 auto And = Builder.buildAnd(Dst: MatchInfo.InnerShiftTy, Src0: Shift, Src1: Mask);
2134 Builder.buildTrunc(Res: Dst, Op: And);
2135 } else
2136 Builder.buildTrunc(Res: Dst, Op: Shift);
2137 MI.eraseFromParent();
2138}
2139
2140bool CombinerHelper::matchCombineMulToShl(MachineInstr &MI,
2141 unsigned &ShiftVal) const {
2142 assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL");
2143 auto MaybeImmVal =
2144 getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI);
2145 if (!MaybeImmVal)
2146 return false;
2147
2148 ShiftVal = MaybeImmVal->Value.exactLogBase2();
2149 return (static_cast<int32_t>(ShiftVal) != -1);
2150}
2151
2152void CombinerHelper::applyCombineMulToShl(MachineInstr &MI,
2153 unsigned &ShiftVal) const {
2154 assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL");
2155 MachineIRBuilder MIB(MI);
2156 LLT ShiftTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
2157 auto ShiftCst = MIB.buildConstant(Res: ShiftTy, Val: ShiftVal);
2158 Observer.changingInstr(MI);
2159 MI.setDesc(MIB.getTII().get(Opcode: TargetOpcode::G_SHL));
2160 MI.getOperand(i: 2).setReg(ShiftCst.getReg(Idx: 0));
2161 if (ShiftVal == ShiftTy.getScalarSizeInBits() - 1)
2162 MI.clearFlag(Flag: MachineInstr::MIFlag::NoSWrap);
2163 Observer.changedInstr(MI);
2164}
2165
2166bool CombinerHelper::matchCombineSubToAdd(MachineInstr &MI,
2167 BuildFnTy &MatchInfo) const {
2168 GSub &Sub = cast<GSub>(Val&: MI);
2169
2170 LLT Ty = MRI.getType(Reg: Sub.getReg(Idx: 0));
2171
2172 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ADD, {Ty}}))
2173 return false;
2174
2175 if (!isConstantLegalOrBeforeLegalizer(Ty))
2176 return false;
2177
2178 APInt Imm = getIConstantFromReg(VReg: Sub.getRHSReg(), MRI);
2179
2180 MatchInfo = [=, &MI](MachineIRBuilder &B) {
2181 auto NegCst = B.buildConstant(Res: Ty, Val: -Imm);
2182 Observer.changingInstr(MI);
2183 MI.setDesc(B.getTII().get(Opcode: TargetOpcode::G_ADD));
2184 MI.getOperand(i: 2).setReg(NegCst.getReg(Idx: 0));
2185 MI.clearFlag(Flag: MachineInstr::MIFlag::NoUWrap);
2186 if (Imm.isMinSignedValue())
2187 MI.clearFlags(flags: MachineInstr::MIFlag::NoSWrap);
2188 Observer.changedInstr(MI);
2189 };
2190 return true;
2191}
2192
2193// shl ([sza]ext x), y => zext (shl x, y), if shift does not overflow source
2194bool CombinerHelper::matchCombineShlOfExtend(MachineInstr &MI,
2195 RegisterImmPair &MatchData) const {
2196 assert(MI.getOpcode() == TargetOpcode::G_SHL && VT);
2197 if (!getTargetLowering().isDesirableToPullExtFromShl(MI))
2198 return false;
2199
2200 Register LHS = MI.getOperand(i: 1).getReg();
2201
2202 Register ExtSrc;
2203 if (!mi_match(R: LHS, MRI, P: m_GAnyExt(Src: m_Reg(R&: ExtSrc))) &&
2204 !mi_match(R: LHS, MRI, P: m_GZExt(Src: m_Reg(R&: ExtSrc))) &&
2205 !mi_match(R: LHS, MRI, P: m_GSExt(Src: m_Reg(R&: ExtSrc))))
2206 return false;
2207
2208 Register RHS = MI.getOperand(i: 2).getReg();
2209 MachineInstr *MIShiftAmt = MRI.getVRegDef(Reg: RHS);
2210 auto MaybeShiftAmtVal = isConstantOrConstantSplatVector(MI&: *MIShiftAmt, MRI);
2211 if (!MaybeShiftAmtVal)
2212 return false;
2213
2214 if (LI) {
2215 LLT SrcTy = MRI.getType(Reg: ExtSrc);
2216
2217 // We only really care about the legality with the shifted value. We can
2218 // pick any type the constant shift amount, so ask the target what to
2219 // use. Otherwise we would have to guess and hope it is reported as legal.
2220 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: SrcTy);
2221 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SHL, {SrcTy, ShiftAmtTy}}))
2222 return false;
2223 }
2224
2225 int64_t ShiftAmt = MaybeShiftAmtVal->getSExtValue();
2226 MatchData.Reg = ExtSrc;
2227 MatchData.Imm = ShiftAmt;
2228
2229 unsigned MinLeadingZeros = VT->getKnownZeroes(R: ExtSrc).countl_one();
2230 unsigned SrcTySize = MRI.getType(Reg: ExtSrc).getScalarSizeInBits();
2231 return MinLeadingZeros >= ShiftAmt && ShiftAmt < SrcTySize;
2232}
2233
2234void CombinerHelper::applyCombineShlOfExtend(
2235 MachineInstr &MI, const RegisterImmPair &MatchData) const {
2236 Register ExtSrcReg = MatchData.Reg;
2237 int64_t ShiftAmtVal = MatchData.Imm;
2238
2239 LLT ExtSrcTy = MRI.getType(Reg: ExtSrcReg);
2240 auto ShiftAmt = Builder.buildConstant(Res: ExtSrcTy, Val: ShiftAmtVal);
2241 auto NarrowShift =
2242 Builder.buildShl(Dst: ExtSrcTy, Src0: ExtSrcReg, Src1: ShiftAmt, Flags: MI.getFlags());
2243 Builder.buildZExt(Res: MI.getOperand(i: 0), Op: NarrowShift);
2244 MI.eraseFromParent();
2245}
2246
2247bool CombinerHelper::matchCombineMergeUnmerge(MachineInstr &MI,
2248 Register &MatchInfo) const {
2249 GMerge &Merge = cast<GMerge>(Val&: MI);
2250 SmallVector<Register, 16> MergedValues;
2251 for (unsigned I = 0; I < Merge.getNumSources(); ++I)
2252 MergedValues.emplace_back(Args: Merge.getSourceReg(I));
2253
2254 auto *Unmerge = getOpcodeDef<GUnmerge>(Reg: MergedValues[0], MRI);
2255 if (!Unmerge || Unmerge->getNumDefs() != Merge.getNumSources())
2256 return false;
2257
2258 for (unsigned I = 0; I < MergedValues.size(); ++I)
2259 if (MergedValues[I] != Unmerge->getReg(Idx: I))
2260 return false;
2261
2262 MatchInfo = Unmerge->getSourceReg();
2263 return true;
2264}
2265
2266static Register peekThroughBitcast(Register Reg,
2267 const MachineRegisterInfo &MRI) {
2268 while (mi_match(R: Reg, MRI, P: m_GBitcast(Src: m_Reg(R&: Reg))))
2269 ;
2270
2271 return Reg;
2272}
2273
2274bool CombinerHelper::matchCombineUnmergeMergeToPlainValues(
2275 MachineInstr &MI, SmallVectorImpl<Register> &Operands) const {
2276 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2277 "Expected an unmerge");
2278 auto &Unmerge = cast<GUnmerge>(Val&: MI);
2279 Register SrcReg = peekThroughBitcast(Reg: Unmerge.getSourceReg(), MRI);
2280
2281 auto *SrcInstr = getOpcodeDef<GMergeLikeInstr>(Reg: SrcReg, MRI);
2282 if (!SrcInstr)
2283 return false;
2284
2285 // Check the source type of the merge.
2286 LLT SrcMergeTy = MRI.getType(Reg: SrcInstr->getSourceReg(I: 0));
2287 LLT Dst0Ty = MRI.getType(Reg: Unmerge.getReg(Idx: 0));
2288 bool SameSize = Dst0Ty.getSizeInBits() == SrcMergeTy.getSizeInBits();
2289 if (SrcMergeTy != Dst0Ty && !SameSize)
2290 return false;
2291 // They are the same now (modulo a bitcast).
2292 // We can collect all the src registers.
2293 for (unsigned Idx = 0; Idx < SrcInstr->getNumSources(); ++Idx)
2294 Operands.push_back(Elt: SrcInstr->getSourceReg(I: Idx));
2295 return true;
2296}
2297
2298void CombinerHelper::applyCombineUnmergeMergeToPlainValues(
2299 MachineInstr &MI, SmallVectorImpl<Register> &Operands) const {
2300 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2301 "Expected an unmerge");
2302 assert((MI.getNumOperands() - 1 == Operands.size()) &&
2303 "Not enough operands to replace all defs");
2304 unsigned NumElems = MI.getNumOperands() - 1;
2305
2306 LLT SrcTy = MRI.getType(Reg: Operands[0]);
2307 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
2308 bool CanReuseInputDirectly = DstTy == SrcTy;
2309 for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
2310 Register DstReg = MI.getOperand(i: Idx).getReg();
2311 Register SrcReg = Operands[Idx];
2312
2313 // This combine may run after RegBankSelect, so we need to be aware of
2314 // register banks.
2315 const auto &DstCB = MRI.getRegClassOrRegBank(Reg: DstReg);
2316 if (!DstCB.isNull() && DstCB != MRI.getRegClassOrRegBank(Reg: SrcReg)) {
2317 SrcReg = Builder.buildCopy(Res: MRI.getType(Reg: SrcReg), Op: SrcReg).getReg(Idx: 0);
2318 MRI.setRegClassOrRegBank(Reg: SrcReg, RCOrRB: DstCB);
2319 }
2320
2321 if (CanReuseInputDirectly)
2322 replaceRegWith(MRI, FromReg: DstReg, ToReg: SrcReg);
2323 else
2324 Builder.buildCast(Dst: DstReg, Src: SrcReg);
2325 }
2326 MI.eraseFromParent();
2327}
2328
2329bool CombinerHelper::matchCombineUnmergeConstant(
2330 MachineInstr &MI, SmallVectorImpl<APInt> &Csts) const {
2331 unsigned SrcIdx = MI.getNumOperands() - 1;
2332 Register SrcReg = MI.getOperand(i: SrcIdx).getReg();
2333 MachineInstr *SrcInstr = MRI.getVRegDef(Reg: SrcReg);
2334 if (SrcInstr->getOpcode() != TargetOpcode::G_CONSTANT &&
2335 SrcInstr->getOpcode() != TargetOpcode::G_FCONSTANT)
2336 return false;
2337 // Break down the big constant in smaller ones.
2338 const MachineOperand &CstVal = SrcInstr->getOperand(i: 1);
2339 APInt Val = SrcInstr->getOpcode() == TargetOpcode::G_CONSTANT
2340 ? CstVal.getCImm()->getValue()
2341 : CstVal.getFPImm()->getValueAPF().bitcastToAPInt();
2342
2343 LLT Dst0Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
2344 unsigned ShiftAmt = Dst0Ty.getSizeInBits();
2345 // Unmerge a constant.
2346 for (unsigned Idx = 0; Idx != SrcIdx; ++Idx) {
2347 Csts.emplace_back(Args: Val.trunc(width: ShiftAmt));
2348 Val = Val.lshr(shiftAmt: ShiftAmt);
2349 }
2350
2351 return true;
2352}
2353
2354void CombinerHelper::applyCombineUnmergeConstant(
2355 MachineInstr &MI, SmallVectorImpl<APInt> &Csts) const {
2356 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2357 "Expected an unmerge");
2358 assert((MI.getNumOperands() - 1 == Csts.size()) &&
2359 "Not enough operands to replace all defs");
2360 unsigned NumElems = MI.getNumOperands() - 1;
2361 for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
2362 Register DstReg = MI.getOperand(i: Idx).getReg();
2363 Builder.buildConstant(Res: DstReg, Val: Csts[Idx]);
2364 }
2365
2366 MI.eraseFromParent();
2367}
2368
2369bool CombinerHelper::matchCombineUnmergeUndef(
2370 MachineInstr &MI,
2371 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
2372 unsigned SrcIdx = MI.getNumOperands() - 1;
2373 Register SrcReg = MI.getOperand(i: SrcIdx).getReg();
2374 MatchInfo = [&MI](MachineIRBuilder &B) {
2375 unsigned NumElems = MI.getNumOperands() - 1;
2376 for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
2377 Register DstReg = MI.getOperand(i: Idx).getReg();
2378 B.buildUndef(Res: DstReg);
2379 }
2380 };
2381 return isa<GImplicitDef>(Val: MRI.getVRegDef(Reg: SrcReg));
2382}
2383
2384bool CombinerHelper::matchCombineUnmergeWithDeadLanesToTrunc(
2385 MachineInstr &MI) const {
2386 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2387 "Expected an unmerge");
2388 if (MRI.getType(Reg: MI.getOperand(i: 0).getReg()).isVector() ||
2389 MRI.getType(Reg: MI.getOperand(i: MI.getNumDefs()).getReg()).isVector())
2390 return false;
2391 // Check that all the lanes are dead except the first one.
2392 for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) {
2393 if (!MRI.use_nodbg_empty(RegNo: MI.getOperand(i: Idx).getReg()))
2394 return false;
2395 }
2396 return true;
2397}
2398
2399void CombinerHelper::applyCombineUnmergeWithDeadLanesToTrunc(
2400 MachineInstr &MI) const {
2401 Register SrcReg = MI.getOperand(i: MI.getNumDefs()).getReg();
2402 Register Dst0Reg = MI.getOperand(i: 0).getReg();
2403 Builder.buildTrunc(Res: Dst0Reg, Op: SrcReg);
2404 MI.eraseFromParent();
2405}
2406
2407bool CombinerHelper::matchCombineUnmergeZExtToZExt(MachineInstr &MI) const {
2408 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2409 "Expected an unmerge");
2410 Register Dst0Reg = MI.getOperand(i: 0).getReg();
2411 LLT Dst0Ty = MRI.getType(Reg: Dst0Reg);
2412 // G_ZEXT on vector applies to each lane, so it will
2413 // affect all destinations. Therefore we won't be able
2414 // to simplify the unmerge to just the first definition.
2415 if (Dst0Ty.isVector())
2416 return false;
2417 Register SrcReg = MI.getOperand(i: MI.getNumDefs()).getReg();
2418 LLT SrcTy = MRI.getType(Reg: SrcReg);
2419 if (SrcTy.isVector())
2420 return false;
2421
2422 Register ZExtSrcReg;
2423 if (!mi_match(R: SrcReg, MRI, P: m_GZExt(Src: m_Reg(R&: ZExtSrcReg))))
2424 return false;
2425
2426 // Finally we can replace the first definition with
2427 // a zext of the source if the definition is big enough to hold
2428 // all of ZExtSrc bits.
2429 LLT ZExtSrcTy = MRI.getType(Reg: ZExtSrcReg);
2430 return ZExtSrcTy.getSizeInBits() <= Dst0Ty.getSizeInBits();
2431}
2432
2433void CombinerHelper::applyCombineUnmergeZExtToZExt(MachineInstr &MI) const {
2434 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2435 "Expected an unmerge");
2436
2437 Register Dst0Reg = MI.getOperand(i: 0).getReg();
2438
2439 MachineInstr *ZExtInstr =
2440 MRI.getVRegDef(Reg: MI.getOperand(i: MI.getNumDefs()).getReg());
2441 assert(ZExtInstr && ZExtInstr->getOpcode() == TargetOpcode::G_ZEXT &&
2442 "Expecting a G_ZEXT");
2443
2444 Register ZExtSrcReg = ZExtInstr->getOperand(i: 1).getReg();
2445 LLT Dst0Ty = MRI.getType(Reg: Dst0Reg);
2446 LLT ZExtSrcTy = MRI.getType(Reg: ZExtSrcReg);
2447
2448 if (Dst0Ty.getSizeInBits() > ZExtSrcTy.getSizeInBits()) {
2449 Builder.buildZExt(Res: Dst0Reg, Op: ZExtSrcReg);
2450 } else {
2451 assert(Dst0Ty.getSizeInBits() == ZExtSrcTy.getSizeInBits() &&
2452 "ZExt src doesn't fit in destination");
2453 replaceRegWith(MRI, FromReg: Dst0Reg, ToReg: ZExtSrcReg);
2454 }
2455
2456 Register ZeroReg;
2457 for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) {
2458 if (!ZeroReg)
2459 ZeroReg = Builder.buildConstant(Res: Dst0Ty, Val: 0).getReg(Idx: 0);
2460 replaceRegWith(MRI, FromReg: MI.getOperand(i: Idx).getReg(), ToReg: ZeroReg);
2461 }
2462 MI.eraseFromParent();
2463}
2464
2465bool CombinerHelper::matchCombineShiftToUnmerge(MachineInstr &MI,
2466 unsigned TargetShiftSize,
2467 unsigned &ShiftVal) const {
2468 assert((MI.getOpcode() == TargetOpcode::G_SHL ||
2469 MI.getOpcode() == TargetOpcode::G_LSHR ||
2470 MI.getOpcode() == TargetOpcode::G_ASHR) && "Expected a shift");
2471
2472 LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
2473 if (Ty.isVector()) // TODO:
2474 return false;
2475
2476 // Don't narrow further than the requested size.
2477 unsigned Size = Ty.getSizeInBits();
2478 if (Size <= TargetShiftSize)
2479 return false;
2480
2481 auto MaybeImmVal =
2482 getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI);
2483 if (!MaybeImmVal)
2484 return false;
2485
2486 ShiftVal = MaybeImmVal->Value.getSExtValue();
2487 return ShiftVal >= Size / 2 && ShiftVal < Size;
2488}
2489
2490void CombinerHelper::applyCombineShiftToUnmerge(
2491 MachineInstr &MI, const unsigned &ShiftVal) const {
2492 Register DstReg = MI.getOperand(i: 0).getReg();
2493 Register SrcReg = MI.getOperand(i: 1).getReg();
2494 LLT Ty = MRI.getType(Reg: SrcReg);
2495 unsigned Size = Ty.getSizeInBits();
2496 unsigned HalfSize = Size / 2;
2497 assert(ShiftVal >= HalfSize);
2498
2499 LLT HalfTy = LLT::scalar(SizeInBits: HalfSize);
2500
2501 auto Unmerge = Builder.buildUnmerge(Res: HalfTy, Op: SrcReg);
2502 unsigned NarrowShiftAmt = ShiftVal - HalfSize;
2503
2504 if (MI.getOpcode() == TargetOpcode::G_LSHR) {
2505 Register Narrowed = Unmerge.getReg(Idx: 1);
2506
2507 // dst = G_LSHR s64:x, C for C >= 32
2508 // =>
2509 // lo, hi = G_UNMERGE_VALUES x
2510 // dst = G_MERGE_VALUES (G_LSHR hi, C - 32), 0
2511
2512 if (NarrowShiftAmt != 0) {
2513 Narrowed = Builder.buildLShr(Dst: HalfTy, Src0: Narrowed,
2514 Src1: Builder.buildConstant(Res: HalfTy, Val: NarrowShiftAmt)).getReg(Idx: 0);
2515 }
2516
2517 auto Zero = Builder.buildConstant(Res: HalfTy, Val: 0);
2518 Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Narrowed, Zero});
2519 } else if (MI.getOpcode() == TargetOpcode::G_SHL) {
2520 Register Narrowed = Unmerge.getReg(Idx: 0);
2521 // dst = G_SHL s64:x, C for C >= 32
2522 // =>
2523 // lo, hi = G_UNMERGE_VALUES x
2524 // dst = G_MERGE_VALUES 0, (G_SHL hi, C - 32)
2525 if (NarrowShiftAmt != 0) {
2526 Narrowed = Builder.buildShl(Dst: HalfTy, Src0: Narrowed,
2527 Src1: Builder.buildConstant(Res: HalfTy, Val: NarrowShiftAmt)).getReg(Idx: 0);
2528 }
2529
2530 auto Zero = Builder.buildConstant(Res: HalfTy, Val: 0);
2531 Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Zero, Narrowed});
2532 } else {
2533 assert(MI.getOpcode() == TargetOpcode::G_ASHR);
2534 auto Hi = Builder.buildAShr(
2535 Dst: HalfTy, Src0: Unmerge.getReg(Idx: 1),
2536 Src1: Builder.buildConstant(Res: HalfTy, Val: HalfSize - 1));
2537
2538 if (ShiftVal == HalfSize) {
2539 // (G_ASHR i64:x, 32) ->
2540 // G_MERGE_VALUES hi_32(x), (G_ASHR hi_32(x), 31)
2541 Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Unmerge.getReg(Idx: 1), Hi});
2542 } else if (ShiftVal == Size - 1) {
2543 // Don't need a second shift.
2544 // (G_ASHR i64:x, 63) ->
2545 // %narrowed = (G_ASHR hi_32(x), 31)
2546 // G_MERGE_VALUES %narrowed, %narrowed
2547 Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Hi, Hi});
2548 } else {
2549 auto Lo = Builder.buildAShr(
2550 Dst: HalfTy, Src0: Unmerge.getReg(Idx: 1),
2551 Src1: Builder.buildConstant(Res: HalfTy, Val: ShiftVal - HalfSize));
2552
2553 // (G_ASHR i64:x, C) ->, for C >= 32
2554 // G_MERGE_VALUES (G_ASHR hi_32(x), C - 32), (G_ASHR hi_32(x), 31)
2555 Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Lo, Hi});
2556 }
2557 }
2558
2559 MI.eraseFromParent();
2560}
2561
2562bool CombinerHelper::tryCombineShiftToUnmerge(
2563 MachineInstr &MI, unsigned TargetShiftAmount) const {
2564 unsigned ShiftAmt;
2565 if (matchCombineShiftToUnmerge(MI, TargetShiftSize: TargetShiftAmount, ShiftVal&: ShiftAmt)) {
2566 applyCombineShiftToUnmerge(MI, ShiftVal: ShiftAmt);
2567 return true;
2568 }
2569
2570 return false;
2571}
2572
2573bool CombinerHelper::matchCombineI2PToP2I(MachineInstr &MI,
2574 Register &Reg) const {
2575 assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR");
2576 Register DstReg = MI.getOperand(i: 0).getReg();
2577 LLT DstTy = MRI.getType(Reg: DstReg);
2578 Register SrcReg = MI.getOperand(i: 1).getReg();
2579 return mi_match(R: SrcReg, MRI,
2580 P: m_GPtrToInt(Src: m_all_of(preds: m_SpecificType(Ty: DstTy), preds: m_Reg(R&: Reg))));
2581}
2582
2583void CombinerHelper::applyCombineI2PToP2I(MachineInstr &MI,
2584 Register &Reg) const {
2585 assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR");
2586 Register DstReg = MI.getOperand(i: 0).getReg();
2587 Builder.buildCopy(Res: DstReg, Op: Reg);
2588 MI.eraseFromParent();
2589}
2590
2591void CombinerHelper::applyCombineP2IToI2P(MachineInstr &MI,
2592 Register &Reg) const {
2593 assert(MI.getOpcode() == TargetOpcode::G_PTRTOINT && "Expected a G_PTRTOINT");
2594 Register DstReg = MI.getOperand(i: 0).getReg();
2595 Builder.buildZExtOrTrunc(Res: DstReg, Op: Reg);
2596 MI.eraseFromParent();
2597}
2598
2599bool CombinerHelper::matchCombineAddP2IToPtrAdd(
2600 MachineInstr &MI, std::pair<Register, bool> &PtrReg) const {
2601 assert(MI.getOpcode() == TargetOpcode::G_ADD);
2602 Register LHS = MI.getOperand(i: 1).getReg();
2603 Register RHS = MI.getOperand(i: 2).getReg();
2604 LLT IntTy = MRI.getType(Reg: LHS);
2605
2606 // G_PTR_ADD always has the pointer in the LHS, so we may need to commute the
2607 // instruction.
2608 PtrReg.second = false;
2609 for (Register SrcReg : {LHS, RHS}) {
2610 if (mi_match(R: SrcReg, MRI, P: m_GPtrToInt(Src: m_Reg(R&: PtrReg.first)))) {
2611 // Don't handle cases where the integer is implicitly converted to the
2612 // pointer width.
2613 LLT PtrTy = MRI.getType(Reg: PtrReg.first);
2614 if (PtrTy.getScalarSizeInBits() == IntTy.getScalarSizeInBits())
2615 return true;
2616 }
2617
2618 PtrReg.second = true;
2619 }
2620
2621 return false;
2622}
2623
2624void CombinerHelper::applyCombineAddP2IToPtrAdd(
2625 MachineInstr &MI, std::pair<Register, bool> &PtrReg) const {
2626 Register Dst = MI.getOperand(i: 0).getReg();
2627 Register LHS = MI.getOperand(i: 1).getReg();
2628 Register RHS = MI.getOperand(i: 2).getReg();
2629
2630 const bool DoCommute = PtrReg.second;
2631 if (DoCommute)
2632 std::swap(a&: LHS, b&: RHS);
2633 LHS = PtrReg.first;
2634
2635 LLT PtrTy = MRI.getType(Reg: LHS);
2636
2637 auto PtrAdd = Builder.buildPtrAdd(Res: PtrTy, Op0: LHS, Op1: RHS);
2638 Builder.buildPtrToInt(Dst, Src: PtrAdd);
2639 MI.eraseFromParent();
2640}
2641
2642bool CombinerHelper::matchCombineConstPtrAddToI2P(MachineInstr &MI,
2643 APInt &NewCst) const {
2644 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
2645 Register LHS = PtrAdd.getBaseReg();
2646 Register RHS = PtrAdd.getOffsetReg();
2647 MachineRegisterInfo &MRI = Builder.getMF().getRegInfo();
2648
2649 if (auto RHSCst = getIConstantVRegVal(VReg: RHS, MRI)) {
2650 APInt Cst;
2651 if (mi_match(R: LHS, MRI, P: m_GIntToPtr(Src: m_ICst(Cst)))) {
2652 auto DstTy = MRI.getType(Reg: PtrAdd.getReg(Idx: 0));
2653 // G_INTTOPTR uses zero-extension
2654 NewCst = Cst.zextOrTrunc(width: DstTy.getSizeInBits());
2655 NewCst += RHSCst->sextOrTrunc(width: DstTy.getSizeInBits());
2656 return true;
2657 }
2658 }
2659
2660 return false;
2661}
2662
2663void CombinerHelper::applyCombineConstPtrAddToI2P(MachineInstr &MI,
2664 APInt &NewCst) const {
2665 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
2666 Register Dst = PtrAdd.getReg(Idx: 0);
2667
2668 Builder.buildConstant(Res: Dst, Val: NewCst);
2669 PtrAdd.eraseFromParent();
2670}
2671
2672bool CombinerHelper::matchCombineAnyExtTrunc(MachineInstr &MI,
2673 Register &Reg) const {
2674 assert(MI.getOpcode() == TargetOpcode::G_ANYEXT && "Expected a G_ANYEXT");
2675 Register DstReg = MI.getOperand(i: 0).getReg();
2676 Register SrcReg = MI.getOperand(i: 1).getReg();
2677 Register OriginalSrcReg = getSrcRegIgnoringCopies(Reg: SrcReg, MRI);
2678 if (OriginalSrcReg.isValid())
2679 SrcReg = OriginalSrcReg;
2680 LLT DstTy = MRI.getType(Reg: DstReg);
2681 return mi_match(R: SrcReg, MRI,
2682 P: m_GTrunc(Src: m_all_of(preds: m_Reg(R&: Reg), preds: m_SpecificType(Ty: DstTy)))) &&
2683 canReplaceReg(DstReg, SrcReg: Reg, MRI);
2684}
2685
2686bool CombinerHelper::matchCombineZextTrunc(MachineInstr &MI,
2687 Register &Reg) const {
2688 assert(MI.getOpcode() == TargetOpcode::G_ZEXT && "Expected a G_ZEXT");
2689 Register DstReg = MI.getOperand(i: 0).getReg();
2690 Register SrcReg = MI.getOperand(i: 1).getReg();
2691 LLT DstTy = MRI.getType(Reg: DstReg);
2692 if (mi_match(R: SrcReg, MRI,
2693 P: m_GTrunc(Src: m_all_of(preds: m_Reg(R&: Reg), preds: m_SpecificType(Ty: DstTy)))) &&
2694 canReplaceReg(DstReg, SrcReg: Reg, MRI)) {
2695 unsigned DstSize = DstTy.getScalarSizeInBits();
2696 unsigned SrcSize = MRI.getType(Reg: SrcReg).getScalarSizeInBits();
2697 return VT->getKnownBits(R: Reg).countMinLeadingZeros() >= DstSize - SrcSize;
2698 }
2699 return false;
2700}
2701
2702static LLT getMidVTForTruncRightShiftCombine(LLT ShiftTy, LLT TruncTy) {
2703 const unsigned ShiftSize = ShiftTy.getScalarSizeInBits();
2704 const unsigned TruncSize = TruncTy.getScalarSizeInBits();
2705
2706 // ShiftTy > 32 > TruncTy -> 32
2707 if (ShiftSize > 32 && TruncSize < 32)
2708 return ShiftTy.changeElementSize(NewEltSize: 32);
2709
2710 // TODO: We could also reduce to 16 bits, but that's more target-dependent.
2711 // Some targets like it, some don't, some only like it under certain
2712 // conditions/processor versions, etc.
2713 // A TL hook might be needed for this.
2714
2715 // Don't combine
2716 return ShiftTy;
2717}
2718
2719bool CombinerHelper::matchCombineTruncOfShift(
2720 MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) const {
2721 assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC");
2722 Register DstReg = MI.getOperand(i: 0).getReg();
2723 Register SrcReg = MI.getOperand(i: 1).getReg();
2724
2725 if (!MRI.hasOneNonDBGUse(RegNo: SrcReg))
2726 return false;
2727
2728 LLT SrcTy = MRI.getType(Reg: SrcReg);
2729 LLT DstTy = MRI.getType(Reg: DstReg);
2730
2731 MachineInstr *SrcMI = getDefIgnoringCopies(Reg: SrcReg, MRI);
2732 const auto &TL = getTargetLowering();
2733
2734 LLT NewShiftTy;
2735 switch (SrcMI->getOpcode()) {
2736 default:
2737 return false;
2738 case TargetOpcode::G_SHL: {
2739 NewShiftTy = DstTy;
2740
2741 // Make sure new shift amount is legal.
2742 KnownBits Known = VT->getKnownBits(R: SrcMI->getOperand(i: 2).getReg());
2743 if (Known.getMaxValue().uge(RHS: NewShiftTy.getScalarSizeInBits()))
2744 return false;
2745 break;
2746 }
2747 case TargetOpcode::G_LSHR:
2748 case TargetOpcode::G_ASHR: {
2749 // For right shifts, we conservatively do not do the transform if the TRUNC
2750 // has any STORE users. The reason is that if we change the type of the
2751 // shift, we may break the truncstore combine.
2752 //
2753 // TODO: Fix truncstore combine to handle (trunc(lshr (trunc x), k)).
2754 for (auto &User : MRI.use_instructions(Reg: DstReg))
2755 if (User.getOpcode() == TargetOpcode::G_STORE)
2756 return false;
2757
2758 NewShiftTy = getMidVTForTruncRightShiftCombine(ShiftTy: SrcTy, TruncTy: DstTy);
2759 if (NewShiftTy == SrcTy)
2760 return false;
2761
2762 // Make sure we won't lose information by truncating the high bits.
2763 KnownBits Known = VT->getKnownBits(R: SrcMI->getOperand(i: 2).getReg());
2764 if (Known.getMaxValue().ugt(RHS: NewShiftTy.getScalarSizeInBits() -
2765 DstTy.getScalarSizeInBits()))
2766 return false;
2767 break;
2768 }
2769 }
2770
2771 if (!isLegalOrBeforeLegalizer(
2772 Query: {SrcMI->getOpcode(),
2773 {NewShiftTy, TL.getPreferredShiftAmountTy(ShiftValueTy: NewShiftTy)}}))
2774 return false;
2775
2776 MatchInfo = std::make_pair(x&: SrcMI, y&: NewShiftTy);
2777 return true;
2778}
2779
2780void CombinerHelper::applyCombineTruncOfShift(
2781 MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) const {
2782 MachineInstr *ShiftMI = MatchInfo.first;
2783 LLT NewShiftTy = MatchInfo.second;
2784
2785 Register Dst = MI.getOperand(i: 0).getReg();
2786 LLT DstTy = MRI.getType(Reg: Dst);
2787
2788 Register ShiftAmt = ShiftMI->getOperand(i: 2).getReg();
2789 Register ShiftSrc = ShiftMI->getOperand(i: 1).getReg();
2790 ShiftSrc = Builder.buildTrunc(Res: NewShiftTy, Op: ShiftSrc).getReg(Idx: 0);
2791
2792 Register NewShift =
2793 Builder
2794 .buildInstr(Opc: ShiftMI->getOpcode(), DstOps: {NewShiftTy}, SrcOps: {ShiftSrc, ShiftAmt})
2795 .getReg(Idx: 0);
2796
2797 if (NewShiftTy == DstTy)
2798 replaceRegWith(MRI, FromReg: Dst, ToReg: NewShift);
2799 else
2800 Builder.buildTrunc(Res: Dst, Op: NewShift);
2801
2802 eraseInst(MI);
2803}
2804
2805bool CombinerHelper::matchAnyExplicitUseIsUndef(MachineInstr &MI) const {
2806 return any_of(Range: MI.explicit_uses(), P: [this](const MachineOperand &MO) {
2807 return MO.isReg() &&
2808 getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI);
2809 });
2810}
2811
2812bool CombinerHelper::matchAllExplicitUsesAreUndef(MachineInstr &MI) const {
2813 return all_of(Range: MI.explicit_uses(), P: [this](const MachineOperand &MO) {
2814 return !MO.isReg() ||
2815 getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI);
2816 });
2817}
2818
2819bool CombinerHelper::matchUndefShuffleVectorMask(MachineInstr &MI) const {
2820 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
2821 ArrayRef<int> Mask = MI.getOperand(i: 3).getShuffleMask();
2822 return all_of(Range&: Mask, P: [](int Elt) { return Elt < 0; });
2823}
2824
2825bool CombinerHelper::matchUndefStore(MachineInstr &MI) const {
2826 assert(MI.getOpcode() == TargetOpcode::G_STORE);
2827 return getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MI.getOperand(i: 0).getReg(),
2828 MRI);
2829}
2830
2831bool CombinerHelper::matchUndefSelectCmp(MachineInstr &MI) const {
2832 assert(MI.getOpcode() == TargetOpcode::G_SELECT);
2833 return getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MI.getOperand(i: 1).getReg(),
2834 MRI);
2835}
2836
2837bool CombinerHelper::matchInsertExtractVecEltOutOfBounds(
2838 MachineInstr &MI) const {
2839 assert((MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT ||
2840 MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) &&
2841 "Expected an insert/extract element op");
2842 LLT VecTy = MRI.getType(Reg: MI.getOperand(i: 1).getReg());
2843 if (VecTy.isScalableVector())
2844 return false;
2845
2846 unsigned IdxIdx =
2847 MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT ? 2 : 3;
2848 auto Idx = getIConstantVRegVal(VReg: MI.getOperand(i: IdxIdx).getReg(), MRI);
2849 if (!Idx)
2850 return false;
2851 return Idx->getZExtValue() >= VecTy.getNumElements();
2852}
2853
2854bool CombinerHelper::matchConstantSelectCmp(MachineInstr &MI,
2855 unsigned &OpIdx) const {
2856 GSelect &SelMI = cast<GSelect>(Val&: MI);
2857 auto Cst =
2858 isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: SelMI.getCondReg()), MRI);
2859 if (!Cst)
2860 return false;
2861 OpIdx = Cst->isZero() ? 3 : 2;
2862 return true;
2863}
2864
2865void CombinerHelper::eraseInst(MachineInstr &MI) const { MI.eraseFromParent(); }
2866
2867bool CombinerHelper::matchEqualDefs(const MachineOperand &MOP1,
2868 const MachineOperand &MOP2) const {
2869 if (!MOP1.isReg() || !MOP2.isReg())
2870 return false;
2871 auto InstAndDef1 = getDefSrcRegIgnoringCopies(Reg: MOP1.getReg(), MRI);
2872 if (!InstAndDef1)
2873 return false;
2874 auto InstAndDef2 = getDefSrcRegIgnoringCopies(Reg: MOP2.getReg(), MRI);
2875 if (!InstAndDef2)
2876 return false;
2877 MachineInstr *I1 = InstAndDef1->MI;
2878 MachineInstr *I2 = InstAndDef2->MI;
2879
2880 // Handle a case like this:
2881 //
2882 // %0:_(s64), %1:_(s64) = G_UNMERGE_VALUES %2:_(<2 x s64>)
2883 //
2884 // Even though %0 and %1 are produced by the same instruction they are not
2885 // the same values.
2886 if (I1 == I2)
2887 return MOP1.getReg() == MOP2.getReg();
2888
2889 // If we have an instruction which loads or stores, we can't guarantee that
2890 // it is identical.
2891 //
2892 // For example, we may have
2893 //
2894 // %x1 = G_LOAD %addr (load N from @somewhere)
2895 // ...
2896 // call @foo
2897 // ...
2898 // %x2 = G_LOAD %addr (load N from @somewhere)
2899 // ...
2900 // %or = G_OR %x1, %x2
2901 //
2902 // It's possible that @foo will modify whatever lives at the address we're
2903 // loading from. To be safe, let's just assume that all loads and stores
2904 // are different (unless we have something which is guaranteed to not
2905 // change.)
2906 if (I1->mayLoadOrStore() && !I1->isDereferenceableInvariantLoad())
2907 return false;
2908
2909 // If both instructions are loads or stores, they are equal only if both
2910 // are dereferenceable invariant loads with the same number of bits.
2911 if (I1->mayLoadOrStore() && I2->mayLoadOrStore()) {
2912 GLoadStore *LS1 = dyn_cast<GLoadStore>(Val: I1);
2913 GLoadStore *LS2 = dyn_cast<GLoadStore>(Val: I2);
2914 if (!LS1 || !LS2)
2915 return false;
2916
2917 if (!I2->isDereferenceableInvariantLoad() ||
2918 (LS1->getMemSizeInBits() != LS2->getMemSizeInBits()))
2919 return false;
2920 }
2921
2922 // Check for physical registers on the instructions first to avoid cases
2923 // like this:
2924 //
2925 // %a = COPY $physreg
2926 // ...
2927 // SOMETHING implicit-def $physreg
2928 // ...
2929 // %b = COPY $physreg
2930 //
2931 // These copies are not equivalent.
2932 if (any_of(Range: I1->uses(), P: [](const MachineOperand &MO) {
2933 return MO.isReg() && MO.getReg().isPhysical();
2934 })) {
2935 // Check if we have a case like this:
2936 //
2937 // %a = COPY $physreg
2938 // %b = COPY %a
2939 //
2940 // In this case, I1 and I2 will both be equal to %a = COPY $physreg.
2941 // From that, we know that they must have the same value, since they must
2942 // have come from the same COPY.
2943 return I1->isIdenticalTo(Other: *I2);
2944 }
2945
2946 // We don't have any physical registers, so we don't necessarily need the
2947 // same vreg defs.
2948 //
2949 // On the off-chance that there's some target instruction feeding into the
2950 // instruction, let's use produceSameValue instead of isIdenticalTo.
2951 if (Builder.getTII().produceSameValue(MI0: *I1, MI1: *I2, MRI: &MRI)) {
2952 // Handle instructions with multiple defs that produce same values. Values
2953 // are same for operands with same index.
2954 // %0:_(s8), %1:_(s8), %2:_(s8), %3:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>)
2955 // %5:_(s8), %6:_(s8), %7:_(s8), %8:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>)
2956 // I1 and I2 are different instructions but produce same values,
2957 // %1 and %6 are same, %1 and %7 are not the same value.
2958 return I1->findRegisterDefOperandIdx(Reg: InstAndDef1->Reg, /*TRI=*/nullptr) ==
2959 I2->findRegisterDefOperandIdx(Reg: InstAndDef2->Reg, /*TRI=*/nullptr);
2960 }
2961 return false;
2962}
2963
2964bool CombinerHelper::matchConstantOp(const MachineOperand &MOP,
2965 int64_t C) const {
2966 if (!MOP.isReg())
2967 return false;
2968 auto *MI = MRI.getVRegDef(Reg: MOP.getReg());
2969 auto MaybeCst = isConstantOrConstantSplatVector(MI&: *MI, MRI);
2970 return MaybeCst && MaybeCst->getBitWidth() <= 64 &&
2971 MaybeCst->getSExtValue() == C;
2972}
2973
2974bool CombinerHelper::matchConstantFPOp(const MachineOperand &MOP,
2975 double C) const {
2976 if (!MOP.isReg())
2977 return false;
2978 std::optional<FPValueAndVReg> MaybeCst;
2979 if (!mi_match(R: MOP.getReg(), MRI, P: m_GFCstOrSplat(FPValReg&: MaybeCst)))
2980 return false;
2981
2982 return MaybeCst->Value.isExactlyValue(V: C);
2983}
2984
2985void CombinerHelper::replaceSingleDefInstWithOperand(MachineInstr &MI,
2986 unsigned OpIdx) const {
2987 assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?");
2988 Register OldReg = MI.getOperand(i: 0).getReg();
2989 Register Replacement = MI.getOperand(i: OpIdx).getReg();
2990 assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?");
2991 replaceRegWith(MRI, FromReg: OldReg, ToReg: Replacement);
2992 MI.eraseFromParent();
2993}
2994
2995void CombinerHelper::replaceSingleDefInstWithReg(MachineInstr &MI,
2996 Register Replacement) const {
2997 assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?");
2998 Register OldReg = MI.getOperand(i: 0).getReg();
2999 assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?");
3000 replaceRegWith(MRI, FromReg: OldReg, ToReg: Replacement);
3001 MI.eraseFromParent();
3002}
3003
3004bool CombinerHelper::matchConstantLargerBitWidth(MachineInstr &MI,
3005 unsigned ConstIdx) const {
3006 Register ConstReg = MI.getOperand(i: ConstIdx).getReg();
3007 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
3008
3009 // Get the shift amount
3010 auto VRegAndVal = getIConstantVRegValWithLookThrough(VReg: ConstReg, MRI);
3011 if (!VRegAndVal)
3012 return false;
3013
3014 // Return true of shift amount >= Bitwidth
3015 return (VRegAndVal->Value.uge(RHS: DstTy.getSizeInBits()));
3016}
3017
3018void CombinerHelper::applyFunnelShiftConstantModulo(MachineInstr &MI) const {
3019 assert((MI.getOpcode() == TargetOpcode::G_FSHL ||
3020 MI.getOpcode() == TargetOpcode::G_FSHR) &&
3021 "This is not a funnel shift operation");
3022
3023 Register ConstReg = MI.getOperand(i: 3).getReg();
3024 LLT ConstTy = MRI.getType(Reg: ConstReg);
3025 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
3026
3027 auto VRegAndVal = getIConstantVRegValWithLookThrough(VReg: ConstReg, MRI);
3028 assert((VRegAndVal) && "Value is not a constant");
3029
3030 // Calculate the new Shift Amount = Old Shift Amount % BitWidth
3031 APInt NewConst = VRegAndVal->Value.urem(
3032 RHS: APInt(ConstTy.getSizeInBits(), DstTy.getScalarSizeInBits()));
3033
3034 auto NewConstInstr = Builder.buildConstant(Res: ConstTy, Val: NewConst.getZExtValue());
3035 Builder.buildInstr(
3036 Opc: MI.getOpcode(), DstOps: {MI.getOperand(i: 0)},
3037 SrcOps: {MI.getOperand(i: 1), MI.getOperand(i: 2), NewConstInstr.getReg(Idx: 0)});
3038
3039 MI.eraseFromParent();
3040}
3041
3042bool CombinerHelper::matchSelectSameVal(MachineInstr &MI) const {
3043 assert(MI.getOpcode() == TargetOpcode::G_SELECT);
3044 // Match (cond ? x : x)
3045 return matchEqualDefs(MOP1: MI.getOperand(i: 2), MOP2: MI.getOperand(i: 3)) &&
3046 canReplaceReg(DstReg: MI.getOperand(i: 0).getReg(), SrcReg: MI.getOperand(i: 2).getReg(),
3047 MRI);
3048}
3049
3050bool CombinerHelper::matchBinOpSameVal(MachineInstr &MI) const {
3051 return matchEqualDefs(MOP1: MI.getOperand(i: 1), MOP2: MI.getOperand(i: 2)) &&
3052 canReplaceReg(DstReg: MI.getOperand(i: 0).getReg(), SrcReg: MI.getOperand(i: 1).getReg(),
3053 MRI);
3054}
3055
3056bool CombinerHelper::matchOperandIsUndef(MachineInstr &MI,
3057 unsigned OpIdx) const {
3058 MachineOperand &MO = MI.getOperand(i: OpIdx);
3059 return MO.isReg() &&
3060 getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI);
3061}
3062
3063bool CombinerHelper::matchOperandIsKnownToBeAPowerOfTwo(MachineInstr &MI,
3064 unsigned OpIdx) const {
3065 MachineOperand &MO = MI.getOperand(i: OpIdx);
3066 return isKnownToBeAPowerOfTwo(Val: MO.getReg(), MRI, ValueTracking: VT);
3067}
3068
3069void CombinerHelper::replaceInstWithFConstant(MachineInstr &MI,
3070 double C) const {
3071 assert(MI.getNumDefs() == 1 && "Expected only one def?");
3072 Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: C);
3073 MI.eraseFromParent();
3074}
3075
3076void CombinerHelper::replaceInstWithConstant(MachineInstr &MI,
3077 int64_t C) const {
3078 assert(MI.getNumDefs() == 1 && "Expected only one def?");
3079 Builder.buildConstant(Res: MI.getOperand(i: 0), Val: C);
3080 MI.eraseFromParent();
3081}
3082
3083void CombinerHelper::replaceInstWithConstant(MachineInstr &MI, APInt C) const {
3084 assert(MI.getNumDefs() == 1 && "Expected only one def?");
3085 Builder.buildConstant(Res: MI.getOperand(i: 0), Val: C);
3086 MI.eraseFromParent();
3087}
3088
3089void CombinerHelper::replaceInstWithFConstant(MachineInstr &MI,
3090 ConstantFP *CFP) const {
3091 assert(MI.getNumDefs() == 1 && "Expected only one def?");
3092 Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: CFP->getValueAPF());
3093 MI.eraseFromParent();
3094}
3095
3096void CombinerHelper::replaceInstWithUndef(MachineInstr &MI) const {
3097 assert(MI.getNumDefs() == 1 && "Expected only one def?");
3098 Builder.buildUndef(Res: MI.getOperand(i: 0));
3099 MI.eraseFromParent();
3100}
3101
3102bool CombinerHelper::matchSimplifyAddToSub(
3103 MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) const {
3104 Register LHS = MI.getOperand(i: 1).getReg();
3105 Register RHS = MI.getOperand(i: 2).getReg();
3106 Register &NewLHS = std::get<0>(t&: MatchInfo);
3107 Register &NewRHS = std::get<1>(t&: MatchInfo);
3108
3109 // Helper lambda to check for opportunities for
3110 // ((0-A) + B) -> B - A
3111 // (A + (0-B)) -> A - B
3112 auto CheckFold = [&](Register &MaybeSub, Register &MaybeNewLHS) {
3113 if (!mi_match(R: MaybeSub, MRI, P: m_Neg(Src: m_Reg(R&: NewRHS))))
3114 return false;
3115 NewLHS = MaybeNewLHS;
3116 return true;
3117 };
3118
3119 return CheckFold(LHS, RHS) || CheckFold(RHS, LHS);
3120}
3121
3122bool CombinerHelper::matchCombineInsertVecElts(
3123 MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) const {
3124 assert(MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT &&
3125 "Invalid opcode");
3126 Register DstReg = MI.getOperand(i: 0).getReg();
3127 LLT DstTy = MRI.getType(Reg: DstReg);
3128 assert(DstTy.isVector() && "Invalid G_INSERT_VECTOR_ELT?");
3129
3130 if (DstTy.isScalableVector())
3131 return false;
3132
3133 unsigned NumElts = DstTy.getNumElements();
3134 // If this MI is part of a sequence of insert_vec_elts, then
3135 // don't do the combine in the middle of the sequence.
3136 if (MRI.hasOneUse(RegNo: DstReg) && MRI.use_instr_begin(RegNo: DstReg)->getOpcode() ==
3137 TargetOpcode::G_INSERT_VECTOR_ELT)
3138 return false;
3139 MachineInstr *CurrInst = &MI;
3140 MachineInstr *TmpInst;
3141 int64_t IntImm;
3142 Register TmpReg;
3143 MatchInfo.resize(N: NumElts);
3144 while (mi_match(
3145 R: CurrInst->getOperand(i: 0).getReg(), MRI,
3146 P: m_GInsertVecElt(Src0: m_MInstr(MI&: TmpInst), Src1: m_Reg(R&: TmpReg), Src2: m_ICst(Cst&: IntImm)))) {
3147 if (IntImm >= NumElts || IntImm < 0)
3148 return false;
3149 if (!MatchInfo[IntImm])
3150 MatchInfo[IntImm] = TmpReg;
3151 CurrInst = TmpInst;
3152 }
3153 // Variable index.
3154 if (CurrInst->getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT)
3155 return false;
3156 if (TmpInst->getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
3157 for (unsigned I = 1; I < TmpInst->getNumOperands(); ++I) {
3158 if (!MatchInfo[I - 1].isValid())
3159 MatchInfo[I - 1] = TmpInst->getOperand(i: I).getReg();
3160 }
3161 return true;
3162 }
3163 // If we didn't end in a G_IMPLICIT_DEF and the source is not fully
3164 // overwritten, bail out.
3165 return TmpInst->getOpcode() == TargetOpcode::G_IMPLICIT_DEF ||
3166 all_of(Range&: MatchInfo, P: [](Register Reg) { return !!Reg; });
3167}
3168
3169void CombinerHelper::applyCombineInsertVecElts(
3170 MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) const {
3171 Register UndefReg;
3172 auto GetUndef = [&]() {
3173 if (UndefReg)
3174 return UndefReg;
3175 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
3176 UndefReg = Builder.buildUndef(Res: DstTy.getScalarType()).getReg(Idx: 0);
3177 return UndefReg;
3178 };
3179 for (Register &Reg : MatchInfo) {
3180 if (!Reg)
3181 Reg = GetUndef();
3182 }
3183 Builder.buildBuildVector(Res: MI.getOperand(i: 0).getReg(), Ops: MatchInfo);
3184 MI.eraseFromParent();
3185}
3186
3187void CombinerHelper::applySimplifyAddToSub(
3188 MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) const {
3189 Register SubLHS, SubRHS;
3190 std::tie(args&: SubLHS, args&: SubRHS) = MatchInfo;
3191 Builder.buildSub(Dst: MI.getOperand(i: 0).getReg(), Src0: SubLHS, Src1: SubRHS);
3192 MI.eraseFromParent();
3193}
3194
3195bool CombinerHelper::matchHoistLogicOpWithSameOpcodeHands(
3196 MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) const {
3197 // Matches: logic (hand x, ...), (hand y, ...) -> hand (logic x, y), ...
3198 //
3199 // Creates the new hand + logic instruction (but does not insert them.)
3200 //
3201 // On success, MatchInfo is populated with the new instructions. These are
3202 // inserted in applyHoistLogicOpWithSameOpcodeHands.
3203 unsigned LogicOpcode = MI.getOpcode();
3204 assert(LogicOpcode == TargetOpcode::G_AND ||
3205 LogicOpcode == TargetOpcode::G_OR ||
3206 LogicOpcode == TargetOpcode::G_XOR);
3207 MachineIRBuilder MIB(MI);
3208 Register Dst = MI.getOperand(i: 0).getReg();
3209 Register LHSReg = MI.getOperand(i: 1).getReg();
3210 Register RHSReg = MI.getOperand(i: 2).getReg();
3211
3212 // Don't recompute anything.
3213 if (!MRI.hasOneNonDBGUse(RegNo: LHSReg) || !MRI.hasOneNonDBGUse(RegNo: RHSReg))
3214 return false;
3215
3216 // Make sure we have (hand x, ...), (hand y, ...)
3217 MachineInstr *LeftHandInst = getDefIgnoringCopies(Reg: LHSReg, MRI);
3218 MachineInstr *RightHandInst = getDefIgnoringCopies(Reg: RHSReg, MRI);
3219 if (!LeftHandInst || !RightHandInst)
3220 return false;
3221 unsigned HandOpcode = LeftHandInst->getOpcode();
3222 if (HandOpcode != RightHandInst->getOpcode())
3223 return false;
3224 if (LeftHandInst->getNumOperands() < 2 ||
3225 !LeftHandInst->getOperand(i: 1).isReg() ||
3226 RightHandInst->getNumOperands() < 2 ||
3227 !RightHandInst->getOperand(i: 1).isReg())
3228 return false;
3229
3230 // Make sure the types match up, and if we're doing this post-legalization,
3231 // we end up with legal types.
3232 Register X = LeftHandInst->getOperand(i: 1).getReg();
3233 Register Y = RightHandInst->getOperand(i: 1).getReg();
3234 LLT XTy = MRI.getType(Reg: X);
3235 LLT YTy = MRI.getType(Reg: Y);
3236 if (!XTy.isValid() || XTy != YTy)
3237 return false;
3238
3239 // Optional extra source register.
3240 Register ExtraHandOpSrcReg;
3241 switch (HandOpcode) {
3242 default:
3243 return false;
3244 case TargetOpcode::G_ANYEXT:
3245 case TargetOpcode::G_SEXT:
3246 case TargetOpcode::G_ZEXT: {
3247 // Match: logic (ext X), (ext Y) --> ext (logic X, Y)
3248 break;
3249 }
3250 case TargetOpcode::G_TRUNC: {
3251 // Match: logic (trunc X), (trunc Y) -> trunc (logic X, Y)
3252 const MachineFunction *MF = MI.getMF();
3253 LLVMContext &Ctx = MF->getFunction().getContext();
3254
3255 LLT DstTy = MRI.getType(Reg: Dst);
3256 const TargetLowering &TLI = getTargetLowering();
3257
3258 // Be extra careful sinking truncate. If it's free, there's no benefit in
3259 // widening a binop.
3260 if (TLI.isZExtFree(FromTy: DstTy, ToTy: XTy, Ctx) && TLI.isTruncateFree(FromTy: XTy, ToTy: DstTy, Ctx))
3261 return false;
3262 break;
3263 }
3264 case TargetOpcode::G_AND:
3265 case TargetOpcode::G_ASHR:
3266 case TargetOpcode::G_LSHR:
3267 case TargetOpcode::G_SHL: {
3268 // Match: logic (binop x, z), (binop y, z) -> binop (logic x, y), z
3269 MachineOperand &ZOp = LeftHandInst->getOperand(i: 2);
3270 if (!matchEqualDefs(MOP1: ZOp, MOP2: RightHandInst->getOperand(i: 2)))
3271 return false;
3272 ExtraHandOpSrcReg = ZOp.getReg();
3273 break;
3274 }
3275 }
3276
3277 if (!isLegalOrBeforeLegalizer(Query: {LogicOpcode, {XTy, YTy}}))
3278 return false;
3279
3280 // Record the steps to build the new instructions.
3281 //
3282 // Steps to build (logic x, y)
3283 auto NewLogicDst = MRI.createGenericVirtualRegister(Ty: XTy);
3284 OperandBuildSteps LogicBuildSteps = {
3285 [=](MachineInstrBuilder &MIB) { MIB.addDef(RegNo: NewLogicDst); },
3286 [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: X); },
3287 [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: Y); }};
3288 InstructionBuildSteps LogicSteps(LogicOpcode, LogicBuildSteps);
3289
3290 // Steps to build hand (logic x, y), ...z
3291 OperandBuildSteps HandBuildSteps = {
3292 [=](MachineInstrBuilder &MIB) { MIB.addDef(RegNo: Dst); },
3293 [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: NewLogicDst); }};
3294 if (ExtraHandOpSrcReg.isValid())
3295 HandBuildSteps.push_back(
3296 Elt: [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: ExtraHandOpSrcReg); });
3297 InstructionBuildSteps HandSteps(HandOpcode, HandBuildSteps);
3298
3299 MatchInfo = InstructionStepsMatchInfo({LogicSteps, HandSteps});
3300 return true;
3301}
3302
3303void CombinerHelper::applyBuildInstructionSteps(
3304 MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) const {
3305 assert(MatchInfo.InstrsToBuild.size() &&
3306 "Expected at least one instr to build?");
3307 for (auto &InstrToBuild : MatchInfo.InstrsToBuild) {
3308 assert(InstrToBuild.Opcode && "Expected a valid opcode?");
3309 assert(InstrToBuild.OperandFns.size() && "Expected at least one operand?");
3310 MachineInstrBuilder Instr = Builder.buildInstr(Opcode: InstrToBuild.Opcode);
3311 for (auto &OperandFn : InstrToBuild.OperandFns)
3312 OperandFn(Instr);
3313 }
3314 MI.eraseFromParent();
3315}
3316
3317bool CombinerHelper::matchAshrShlToSextInreg(
3318 MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) const {
3319 assert(MI.getOpcode() == TargetOpcode::G_ASHR);
3320 int64_t ShlCst, AshrCst;
3321 Register Src;
3322 if (!mi_match(R: MI.getOperand(i: 0).getReg(), MRI,
3323 P: m_GAShr(L: m_GShl(L: m_Reg(R&: Src), R: m_ICstOrSplat(Cst&: ShlCst)),
3324 R: m_ICstOrSplat(Cst&: AshrCst))))
3325 return false;
3326 if (ShlCst != AshrCst)
3327 return false;
3328 if (!isLegalOrBeforeLegalizer(
3329 Query: {TargetOpcode::G_SEXT_INREG, {MRI.getType(Reg: Src)}}))
3330 return false;
3331 MatchInfo = std::make_tuple(args&: Src, args&: ShlCst);
3332 return true;
3333}
3334
3335void CombinerHelper::applyAshShlToSextInreg(
3336 MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) const {
3337 assert(MI.getOpcode() == TargetOpcode::G_ASHR);
3338 Register Src;
3339 int64_t ShiftAmt;
3340 std::tie(args&: Src, args&: ShiftAmt) = MatchInfo;
3341 unsigned Size = MRI.getType(Reg: Src).getScalarSizeInBits();
3342 Builder.buildSExtInReg(Res: MI.getOperand(i: 0).getReg(), Op: Src, ImmOp: Size - ShiftAmt);
3343 MI.eraseFromParent();
3344}
3345
3346/// and(and(x, C1), C2) -> C1&C2 ? and(x, C1&C2) : 0
3347bool CombinerHelper::matchOverlappingAnd(
3348 MachineInstr &MI,
3349 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
3350 assert(MI.getOpcode() == TargetOpcode::G_AND);
3351
3352 Register Dst = MI.getOperand(i: 0).getReg();
3353 LLT Ty = MRI.getType(Reg: Dst);
3354
3355 Register R;
3356 int64_t C1;
3357 int64_t C2;
3358 if (!mi_match(
3359 R: Dst, MRI,
3360 P: m_GAnd(L: m_GAnd(L: m_Reg(R), R: m_ICst(Cst&: C1)), R: m_ICst(Cst&: C2))))
3361 return false;
3362
3363 MatchInfo = [=](MachineIRBuilder &B) {
3364 if (C1 & C2) {
3365 B.buildAnd(Dst, Src0: R, Src1: B.buildConstant(Res: Ty, Val: C1 & C2));
3366 return;
3367 }
3368 auto Zero = B.buildConstant(Res: Ty, Val: 0);
3369 replaceRegWith(MRI, FromReg: Dst, ToReg: Zero->getOperand(i: 0).getReg());
3370 };
3371 return true;
3372}
3373
3374bool CombinerHelper::matchRedundantAnd(MachineInstr &MI,
3375 Register &Replacement) const {
3376 // Given
3377 //
3378 // %y:_(sN) = G_SOMETHING
3379 // %x:_(sN) = G_SOMETHING
3380 // %res:_(sN) = G_AND %x, %y
3381 //
3382 // Eliminate the G_AND when it is known that x & y == x or x & y == y.
3383 //
3384 // Patterns like this can appear as a result of legalization. E.g.
3385 //
3386 // %cmp:_(s32) = G_ICMP intpred(pred), %x(s32), %y
3387 // %one:_(s32) = G_CONSTANT i32 1
3388 // %and:_(s32) = G_AND %cmp, %one
3389 //
3390 // In this case, G_ICMP only produces a single bit, so x & 1 == x.
3391 assert(MI.getOpcode() == TargetOpcode::G_AND);
3392 if (!VT)
3393 return false;
3394
3395 Register AndDst = MI.getOperand(i: 0).getReg();
3396 Register LHS = MI.getOperand(i: 1).getReg();
3397 Register RHS = MI.getOperand(i: 2).getReg();
3398
3399 // Check the RHS (maybe a constant) first, and if we have no KnownBits there,
3400 // we can't do anything. If we do, then it depends on whether we have
3401 // KnownBits on the LHS.
3402 KnownBits RHSBits = VT->getKnownBits(R: RHS);
3403 if (RHSBits.isUnknown())
3404 return false;
3405
3406 KnownBits LHSBits = VT->getKnownBits(R: LHS);
3407
3408 // Check that x & Mask == x.
3409 // x & 1 == x, always
3410 // x & 0 == x, only if x is also 0
3411 // Meaning Mask has no effect if every bit is either one in Mask or zero in x.
3412 //
3413 // Check if we can replace AndDst with the LHS of the G_AND
3414 if (canReplaceReg(DstReg: AndDst, SrcReg: LHS, MRI) &&
3415 (LHSBits.Zero | RHSBits.One).isAllOnes()) {
3416 Replacement = LHS;
3417 return true;
3418 }
3419
3420 // Check if we can replace AndDst with the RHS of the G_AND
3421 if (canReplaceReg(DstReg: AndDst, SrcReg: RHS, MRI) &&
3422 (LHSBits.One | RHSBits.Zero).isAllOnes()) {
3423 Replacement = RHS;
3424 return true;
3425 }
3426
3427 return false;
3428}
3429
3430bool CombinerHelper::matchRedundantOr(MachineInstr &MI,
3431 Register &Replacement) const {
3432 // Given
3433 //
3434 // %y:_(sN) = G_SOMETHING
3435 // %x:_(sN) = G_SOMETHING
3436 // %res:_(sN) = G_OR %x, %y
3437 //
3438 // Eliminate the G_OR when it is known that x | y == x or x | y == y.
3439 assert(MI.getOpcode() == TargetOpcode::G_OR);
3440 if (!VT)
3441 return false;
3442
3443 Register OrDst = MI.getOperand(i: 0).getReg();
3444 Register LHS = MI.getOperand(i: 1).getReg();
3445 Register RHS = MI.getOperand(i: 2).getReg();
3446
3447 KnownBits LHSBits = VT->getKnownBits(R: LHS);
3448 KnownBits RHSBits = VT->getKnownBits(R: RHS);
3449
3450 // Check that x | Mask == x.
3451 // x | 0 == x, always
3452 // x | 1 == x, only if x is also 1
3453 // Meaning Mask has no effect if every bit is either zero in Mask or one in x.
3454 //
3455 // Check if we can replace OrDst with the LHS of the G_OR
3456 if (canReplaceReg(DstReg: OrDst, SrcReg: LHS, MRI) &&
3457 (LHSBits.One | RHSBits.Zero).isAllOnes()) {
3458 Replacement = LHS;
3459 return true;
3460 }
3461
3462 // Check if we can replace OrDst with the RHS of the G_OR
3463 if (canReplaceReg(DstReg: OrDst, SrcReg: RHS, MRI) &&
3464 (LHSBits.Zero | RHSBits.One).isAllOnes()) {
3465 Replacement = RHS;
3466 return true;
3467 }
3468
3469 return false;
3470}
3471
3472bool CombinerHelper::matchRedundantSExtInReg(MachineInstr &MI) const {
3473 // If the input is already sign extended, just drop the extension.
3474 Register Src = MI.getOperand(i: 1).getReg();
3475 unsigned ExtBits = MI.getOperand(i: 2).getImm();
3476 unsigned TypeSize = MRI.getType(Reg: Src).getScalarSizeInBits();
3477 return VT->computeNumSignBits(R: Src) >= (TypeSize - ExtBits + 1);
3478}
3479
3480static bool isConstValidTrue(const TargetLowering &TLI, unsigned ScalarSizeBits,
3481 int64_t Cst, bool IsVector, bool IsFP) {
3482 // For i1, Cst will always be -1 regardless of boolean contents.
3483 return (ScalarSizeBits == 1 && Cst == -1) ||
3484 isConstTrueVal(TLI, Val: Cst, IsVector, IsFP);
3485}
3486
3487// This pattern aims to match the following shape to avoid extra mov
3488// instructions
3489// G_BUILD_VECTOR(
3490// G_UNMERGE_VALUES(src, 0)
3491// G_UNMERGE_VALUES(src, 1)
3492// G_IMPLICIT_DEF
3493// G_IMPLICIT_DEF
3494// )
3495// ->
3496// G_CONCAT_VECTORS(
3497// src,
3498// undef
3499// )
3500bool CombinerHelper::matchCombineBuildUnmerge(MachineInstr &MI,
3501 MachineRegisterInfo &MRI,
3502 Register &UnmergeSrc) const {
3503 auto &BV = cast<GBuildVector>(Val&: MI);
3504
3505 unsigned BuildUseCount = BV.getNumSources();
3506 if (BuildUseCount % 2 != 0)
3507 return false;
3508
3509 unsigned NumUnmerge = BuildUseCount / 2;
3510
3511 auto *Unmerge = getOpcodeDef<GUnmerge>(Reg: BV.getSourceReg(I: 0), MRI);
3512
3513 // Check the first operand is an unmerge and has the correct number of
3514 // operands
3515 if (!Unmerge || Unmerge->getNumDefs() != NumUnmerge)
3516 return false;
3517
3518 UnmergeSrc = Unmerge->getSourceReg();
3519
3520 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
3521 LLT UnmergeSrcTy = MRI.getType(Reg: UnmergeSrc);
3522
3523 if (!UnmergeSrcTy.isVector())
3524 return false;
3525
3526 // Ensure we only generate legal instructions post-legalizer
3527 if (!IsPreLegalize &&
3528 !isLegal(Query: {TargetOpcode::G_CONCAT_VECTORS, {DstTy, UnmergeSrcTy}}))
3529 return false;
3530
3531 // Check that all of the operands before the midpoint come from the same
3532 // unmerge and are in the same order as they are used in the build_vector
3533 for (unsigned I = 0; I < NumUnmerge; ++I) {
3534 auto MaybeUnmergeReg = BV.getSourceReg(I);
3535 auto *LoopUnmerge = getOpcodeDef<GUnmerge>(Reg: MaybeUnmergeReg, MRI);
3536
3537 if (!LoopUnmerge || LoopUnmerge != Unmerge)
3538 return false;
3539
3540 if (LoopUnmerge->getOperand(i: I).getReg() != MaybeUnmergeReg)
3541 return false;
3542 }
3543
3544 // Check that all of the unmerged values are used
3545 if (Unmerge->getNumDefs() != NumUnmerge)
3546 return false;
3547
3548 // Check that all of the operands after the mid point are undefs.
3549 for (unsigned I = NumUnmerge; I < BuildUseCount; ++I) {
3550 auto *Undef = getDefIgnoringCopies(Reg: BV.getSourceReg(I), MRI);
3551
3552 if (Undef->getOpcode() != TargetOpcode::G_IMPLICIT_DEF)
3553 return false;
3554 }
3555
3556 return true;
3557}
3558
3559void CombinerHelper::applyCombineBuildUnmerge(MachineInstr &MI,
3560 MachineRegisterInfo &MRI,
3561 MachineIRBuilder &B,
3562 Register &UnmergeSrc) const {
3563 assert(UnmergeSrc && "Expected there to be one matching G_UNMERGE_VALUES");
3564 B.setInstrAndDebugLoc(MI);
3565
3566 Register UndefVec = B.buildUndef(Res: MRI.getType(Reg: UnmergeSrc)).getReg(Idx: 0);
3567 B.buildConcatVectors(Res: MI.getOperand(i: 0), Ops: {UnmergeSrc, UndefVec});
3568
3569 MI.eraseFromParent();
3570}
3571
3572// This combine tries to reduce the number of scalarised G_TRUNC instructions by
3573// using vector truncates instead
3574//
3575// EXAMPLE:
3576// %a(i32), %b(i32) = G_UNMERGE_VALUES %src(<2 x i32>)
3577// %T_a(i16) = G_TRUNC %a(i32)
3578// %T_b(i16) = G_TRUNC %b(i32)
3579// %Undef(i16) = G_IMPLICIT_DEF(i16)
3580// %dst(v4i16) = G_BUILD_VECTORS %T_a(i16), %T_b(i16), %Undef(i16), %Undef(i16)
3581//
3582// ===>
3583// %Undef(<2 x i32>) = G_IMPLICIT_DEF(<2 x i32>)
3584// %Mid(<4 x s32>) = G_CONCAT_VECTORS %src(<2 x i32>), %Undef(<2 x i32>)
3585// %dst(<4 x s16>) = G_TRUNC %Mid(<4 x s32>)
3586//
3587// Only matches sources made up of G_TRUNCs followed by G_IMPLICIT_DEFs
3588bool CombinerHelper::matchUseVectorTruncate(MachineInstr &MI,
3589 Register &MatchInfo) const {
3590 auto BuildMI = cast<GBuildVector>(Val: &MI);
3591 unsigned NumOperands = BuildMI->getNumSources();
3592 LLT DstTy = MRI.getType(Reg: BuildMI->getReg(Idx: 0));
3593
3594 // Check the G_BUILD_VECTOR sources
3595 unsigned I;
3596 MachineInstr *UnmergeMI = nullptr;
3597
3598 // Check all source TRUNCs come from the same UNMERGE instruction
3599 for (I = 0; I < NumOperands; ++I) {
3600 auto SrcMI = MRI.getVRegDef(Reg: BuildMI->getSourceReg(I));
3601 auto SrcMIOpc = SrcMI->getOpcode();
3602
3603 // Check if the G_TRUNC instructions all come from the same MI
3604 if (SrcMIOpc == TargetOpcode::G_TRUNC) {
3605 if (!UnmergeMI) {
3606 UnmergeMI = MRI.getVRegDef(Reg: SrcMI->getOperand(i: 1).getReg());
3607 if (UnmergeMI->getOpcode() != TargetOpcode::G_UNMERGE_VALUES)
3608 return false;
3609 } else {
3610 auto UnmergeSrcMI = MRI.getVRegDef(Reg: SrcMI->getOperand(i: 1).getReg());
3611 if (UnmergeMI != UnmergeSrcMI)
3612 return false;
3613 }
3614 } else {
3615 break;
3616 }
3617 }
3618 if (I < 2)
3619 return false;
3620
3621 // Check the remaining source elements are only G_IMPLICIT_DEF
3622 for (; I < NumOperands; ++I) {
3623 auto SrcMI = MRI.getVRegDef(Reg: BuildMI->getSourceReg(I));
3624 auto SrcMIOpc = SrcMI->getOpcode();
3625
3626 if (SrcMIOpc != TargetOpcode::G_IMPLICIT_DEF)
3627 return false;
3628 }
3629
3630 // Check the size of unmerge source
3631 MatchInfo = cast<GUnmerge>(Val: UnmergeMI)->getSourceReg();
3632 LLT UnmergeSrcTy = MRI.getType(Reg: MatchInfo);
3633 if (!DstTy.getElementCount().isKnownMultipleOf(RHS: UnmergeSrcTy.getNumElements()))
3634 return false;
3635
3636 // Check the unmerge source and destination element types match
3637 LLT UnmergeSrcEltTy = UnmergeSrcTy.getElementType();
3638 Register UnmergeDstReg = UnmergeMI->getOperand(i: 0).getReg();
3639 LLT UnmergeDstEltTy = MRI.getType(Reg: UnmergeDstReg);
3640 if (UnmergeSrcEltTy != UnmergeDstEltTy)
3641 return false;
3642
3643 // Only generate legal instructions post-legalizer
3644 if (!IsPreLegalize) {
3645 LLT MidTy = DstTy.changeElementType(NewEltTy: UnmergeSrcTy.getScalarType());
3646
3647 if (DstTy.getElementCount() != UnmergeSrcTy.getElementCount() &&
3648 !isLegal(Query: {TargetOpcode::G_CONCAT_VECTORS, {MidTy, UnmergeSrcTy}}))
3649 return false;
3650
3651 if (!isLegal(Query: {TargetOpcode::G_TRUNC, {DstTy, MidTy}}))
3652 return false;
3653 }
3654
3655 return true;
3656}
3657
3658void CombinerHelper::applyUseVectorTruncate(MachineInstr &MI,
3659 Register &MatchInfo) const {
3660 Register MidReg;
3661 auto BuildMI = cast<GBuildVector>(Val: &MI);
3662 Register DstReg = BuildMI->getReg(Idx: 0);
3663 LLT DstTy = MRI.getType(Reg: DstReg);
3664 LLT UnmergeSrcTy = MRI.getType(Reg: MatchInfo);
3665 unsigned DstTyNumElt = DstTy.getNumElements();
3666 unsigned UnmergeSrcTyNumElt = UnmergeSrcTy.getNumElements();
3667
3668 // No need to pad vector if only G_TRUNC is needed
3669 if (DstTyNumElt / UnmergeSrcTyNumElt == 1) {
3670 MidReg = MatchInfo;
3671 } else {
3672 Register UndefReg = Builder.buildUndef(Res: UnmergeSrcTy).getReg(Idx: 0);
3673 SmallVector<Register> ConcatRegs = {MatchInfo};
3674 for (unsigned I = 1; I < DstTyNumElt / UnmergeSrcTyNumElt; ++I)
3675 ConcatRegs.push_back(Elt: UndefReg);
3676
3677 auto MidTy = DstTy.changeElementType(NewEltTy: UnmergeSrcTy.getScalarType());
3678 MidReg = Builder.buildConcatVectors(Res: MidTy, Ops: ConcatRegs).getReg(Idx: 0);
3679 }
3680
3681 Builder.buildTrunc(Res: DstReg, Op: MidReg);
3682 MI.eraseFromParent();
3683}
3684
3685bool CombinerHelper::matchNotCmp(
3686 MachineInstr &MI, SmallVectorImpl<Register> &RegsToNegate) const {
3687 assert(MI.getOpcode() == TargetOpcode::G_XOR);
3688 LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
3689 const auto &TLI = *Builder.getMF().getSubtarget().getTargetLowering();
3690 Register XorSrc;
3691 Register CstReg;
3692 // We match xor(src, true) here.
3693 if (!mi_match(R: MI.getOperand(i: 0).getReg(), MRI,
3694 P: m_GXor(L: m_Reg(R&: XorSrc), R: m_Reg(R&: CstReg))))
3695 return false;
3696
3697 if (!MRI.hasOneNonDBGUse(RegNo: XorSrc))
3698 return false;
3699
3700 // Check that XorSrc is the root of a tree of comparisons combined with ANDs
3701 // and ORs. The suffix of RegsToNegate starting from index I is used a work
3702 // list of tree nodes to visit.
3703 RegsToNegate.push_back(Elt: XorSrc);
3704 // Remember whether the comparisons are all integer or all floating point.
3705 bool IsInt = false;
3706 bool IsFP = false;
3707 for (unsigned I = 0; I < RegsToNegate.size(); ++I) {
3708 Register Reg = RegsToNegate[I];
3709 if (!MRI.hasOneNonDBGUse(RegNo: Reg))
3710 return false;
3711 MachineInstr *Def = MRI.getVRegDef(Reg);
3712 switch (Def->getOpcode()) {
3713 default:
3714 // Don't match if the tree contains anything other than ANDs, ORs and
3715 // comparisons.
3716 return false;
3717 case TargetOpcode::G_ICMP:
3718 if (IsFP)
3719 return false;
3720 IsInt = true;
3721 // When we apply the combine we will invert the predicate.
3722 break;
3723 case TargetOpcode::G_FCMP:
3724 if (IsInt)
3725 return false;
3726 IsFP = true;
3727 // When we apply the combine we will invert the predicate.
3728 break;
3729 case TargetOpcode::G_AND:
3730 case TargetOpcode::G_OR:
3731 // Implement De Morgan's laws:
3732 // ~(x & y) -> ~x | ~y
3733 // ~(x | y) -> ~x & ~y
3734 // When we apply the combine we will change the opcode and recursively
3735 // negate the operands.
3736 RegsToNegate.push_back(Elt: Def->getOperand(i: 1).getReg());
3737 RegsToNegate.push_back(Elt: Def->getOperand(i: 2).getReg());
3738 break;
3739 }
3740 }
3741
3742 // Now we know whether the comparisons are integer or floating point, check
3743 // the constant in the xor.
3744 int64_t Cst;
3745 if (Ty.isVector()) {
3746 MachineInstr *CstDef = MRI.getVRegDef(Reg: CstReg);
3747 auto MaybeCst = getIConstantSplatSExtVal(MI: *CstDef, MRI);
3748 if (!MaybeCst)
3749 return false;
3750 if (!isConstValidTrue(TLI, ScalarSizeBits: Ty.getScalarSizeInBits(), Cst: *MaybeCst, IsVector: true, IsFP))
3751 return false;
3752 } else {
3753 if (!mi_match(R: CstReg, MRI, P: m_ICst(Cst)))
3754 return false;
3755 if (!isConstValidTrue(TLI, ScalarSizeBits: Ty.getSizeInBits(), Cst, IsVector: false, IsFP))
3756 return false;
3757 }
3758
3759 return true;
3760}
3761
3762void CombinerHelper::applyNotCmp(
3763 MachineInstr &MI, SmallVectorImpl<Register> &RegsToNegate) const {
3764 for (Register Reg : RegsToNegate) {
3765 MachineInstr *Def = MRI.getVRegDef(Reg);
3766 Observer.changingInstr(MI&: *Def);
3767 // For each comparison, invert the opcode. For each AND and OR, change the
3768 // opcode.
3769 switch (Def->getOpcode()) {
3770 default:
3771 llvm_unreachable("Unexpected opcode");
3772 case TargetOpcode::G_ICMP:
3773 case TargetOpcode::G_FCMP: {
3774 MachineOperand &PredOp = Def->getOperand(i: 1);
3775 CmpInst::Predicate NewP = CmpInst::getInversePredicate(
3776 pred: (CmpInst::Predicate)PredOp.getPredicate());
3777 PredOp.setPredicate(NewP);
3778 break;
3779 }
3780 case TargetOpcode::G_AND:
3781 Def->setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_OR));
3782 break;
3783 case TargetOpcode::G_OR:
3784 Def->setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_AND));
3785 break;
3786 }
3787 Observer.changedInstr(MI&: *Def);
3788 }
3789
3790 replaceRegWith(MRI, FromReg: MI.getOperand(i: 0).getReg(), ToReg: MI.getOperand(i: 1).getReg());
3791 MI.eraseFromParent();
3792}
3793
3794bool CombinerHelper::matchXorOfAndWithSameReg(
3795 MachineInstr &MI, std::pair<Register, Register> &MatchInfo) const {
3796 // Match (xor (and x, y), y) (or any of its commuted cases)
3797 assert(MI.getOpcode() == TargetOpcode::G_XOR);
3798 Register &X = MatchInfo.first;
3799 Register &Y = MatchInfo.second;
3800 Register AndReg = MI.getOperand(i: 1).getReg();
3801 Register SharedReg = MI.getOperand(i: 2).getReg();
3802
3803 // Find a G_AND on either side of the G_XOR.
3804 // Look for one of
3805 //
3806 // (xor (and x, y), SharedReg)
3807 // (xor SharedReg, (and x, y))
3808 if (!mi_match(R: AndReg, MRI, P: m_GAnd(L: m_Reg(R&: X), R: m_Reg(R&: Y)))) {
3809 std::swap(a&: AndReg, b&: SharedReg);
3810 if (!mi_match(R: AndReg, MRI, P: m_GAnd(L: m_Reg(R&: X), R: m_Reg(R&: Y))))
3811 return false;
3812 }
3813
3814 // Only do this if we'll eliminate the G_AND.
3815 if (!MRI.hasOneNonDBGUse(RegNo: AndReg))
3816 return false;
3817
3818 // We can combine if SharedReg is the same as either the LHS or RHS of the
3819 // G_AND.
3820 if (Y != SharedReg)
3821 std::swap(a&: X, b&: Y);
3822 return Y == SharedReg;
3823}
3824
3825void CombinerHelper::applyXorOfAndWithSameReg(
3826 MachineInstr &MI, std::pair<Register, Register> &MatchInfo) const {
3827 // Fold (xor (and x, y), y) -> (and (not x), y)
3828 Register X, Y;
3829 std::tie(args&: X, args&: Y) = MatchInfo;
3830 auto Not = Builder.buildNot(Dst: MRI.getType(Reg: X), Src0: X);
3831 Observer.changingInstr(MI);
3832 MI.setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_AND));
3833 MI.getOperand(i: 1).setReg(Not->getOperand(i: 0).getReg());
3834 MI.getOperand(i: 2).setReg(Y);
3835 Observer.changedInstr(MI);
3836}
3837
3838bool CombinerHelper::matchPtrAddZero(MachineInstr &MI) const {
3839 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
3840 Register DstReg = PtrAdd.getReg(Idx: 0);
3841 LLT Ty = MRI.getType(Reg: DstReg);
3842 const DataLayout &DL = Builder.getMF().getDataLayout();
3843
3844 if (DL.isNonIntegralAddressSpace(AddrSpace: Ty.getScalarType().getAddressSpace()))
3845 return false;
3846
3847 if (Ty.isPointer()) {
3848 auto ConstVal = getIConstantVRegVal(VReg: PtrAdd.getBaseReg(), MRI);
3849 return ConstVal && *ConstVal == 0;
3850 }
3851
3852 assert(Ty.isVector() && "Expecting a vector type");
3853 const MachineInstr *VecMI = MRI.getVRegDef(Reg: PtrAdd.getBaseReg());
3854 return isBuildVectorAllZeros(MI: *VecMI, MRI);
3855}
3856
3857void CombinerHelper::applyPtrAddZero(MachineInstr &MI) const {
3858 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
3859 Builder.buildIntToPtr(Dst: PtrAdd.getReg(Idx: 0), Src: PtrAdd.getOffsetReg());
3860 PtrAdd.eraseFromParent();
3861}
3862
3863/// The second source operand is known to be a power of 2.
3864void CombinerHelper::applySimplifyURemByPow2(MachineInstr &MI) const {
3865 Register DstReg = MI.getOperand(i: 0).getReg();
3866 Register Src0 = MI.getOperand(i: 1).getReg();
3867 Register Pow2Src1 = MI.getOperand(i: 2).getReg();
3868 LLT Ty = MRI.getType(Reg: DstReg);
3869
3870 // Fold (urem x, pow2) -> (and x, pow2-1)
3871 auto NegOne = Builder.buildConstant(Res: Ty, Val: -1);
3872 auto Add = Builder.buildAdd(Dst: Ty, Src0: Pow2Src1, Src1: NegOne);
3873 Builder.buildAnd(Dst: DstReg, Src0, Src1: Add);
3874 MI.eraseFromParent();
3875}
3876
3877bool CombinerHelper::matchFoldBinOpIntoSelect(MachineInstr &MI,
3878 unsigned &SelectOpNo) const {
3879 Register LHS = MI.getOperand(i: 1).getReg();
3880 Register RHS = MI.getOperand(i: 2).getReg();
3881
3882 Register OtherOperandReg = RHS;
3883 SelectOpNo = 1;
3884 MachineInstr *Select = MRI.getVRegDef(Reg: LHS);
3885
3886 // Don't do this unless the old select is going away. We want to eliminate the
3887 // binary operator, not replace a binop with a select.
3888 if (Select->getOpcode() != TargetOpcode::G_SELECT ||
3889 !MRI.hasOneNonDBGUse(RegNo: LHS)) {
3890 OtherOperandReg = LHS;
3891 SelectOpNo = 2;
3892 Select = MRI.getVRegDef(Reg: RHS);
3893 if (Select->getOpcode() != TargetOpcode::G_SELECT ||
3894 !MRI.hasOneNonDBGUse(RegNo: RHS))
3895 return false;
3896 }
3897
3898 MachineInstr *SelectLHS = MRI.getVRegDef(Reg: Select->getOperand(i: 2).getReg());
3899 MachineInstr *SelectRHS = MRI.getVRegDef(Reg: Select->getOperand(i: 3).getReg());
3900
3901 if (!isConstantOrConstantVector(MI: *SelectLHS, MRI,
3902 /*AllowFP*/ true,
3903 /*AllowOpaqueConstants*/ false))
3904 return false;
3905 if (!isConstantOrConstantVector(MI: *SelectRHS, MRI,
3906 /*AllowFP*/ true,
3907 /*AllowOpaqueConstants*/ false))
3908 return false;
3909
3910 unsigned BinOpcode = MI.getOpcode();
3911
3912 // We know that one of the operands is a select of constants. Now verify that
3913 // the other binary operator operand is either a constant, or we can handle a
3914 // variable.
3915 bool CanFoldNonConst =
3916 (BinOpcode == TargetOpcode::G_AND || BinOpcode == TargetOpcode::G_OR) &&
3917 (isNullOrNullSplat(MI: *SelectLHS, MRI) ||
3918 isAllOnesOrAllOnesSplat(MI: *SelectLHS, MRI)) &&
3919 (isNullOrNullSplat(MI: *SelectRHS, MRI) ||
3920 isAllOnesOrAllOnesSplat(MI: *SelectRHS, MRI));
3921 if (CanFoldNonConst)
3922 return true;
3923
3924 return isConstantOrConstantVector(MI: *MRI.getVRegDef(Reg: OtherOperandReg), MRI,
3925 /*AllowFP*/ true,
3926 /*AllowOpaqueConstants*/ false);
3927}
3928
3929/// \p SelectOperand is the operand in binary operator \p MI that is the select
3930/// to fold.
3931void CombinerHelper::applyFoldBinOpIntoSelect(
3932 MachineInstr &MI, const unsigned &SelectOperand) const {
3933 Register Dst = MI.getOperand(i: 0).getReg();
3934 Register LHS = MI.getOperand(i: 1).getReg();
3935 Register RHS = MI.getOperand(i: 2).getReg();
3936 MachineInstr *Select = MRI.getVRegDef(Reg: MI.getOperand(i: SelectOperand).getReg());
3937
3938 Register SelectCond = Select->getOperand(i: 1).getReg();
3939 Register SelectTrue = Select->getOperand(i: 2).getReg();
3940 Register SelectFalse = Select->getOperand(i: 3).getReg();
3941
3942 LLT Ty = MRI.getType(Reg: Dst);
3943 unsigned BinOpcode = MI.getOpcode();
3944
3945 Register FoldTrue, FoldFalse;
3946
3947 // We have a select-of-constants followed by a binary operator with a
3948 // constant. Eliminate the binop by pulling the constant math into the select.
3949 // Example: add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO
3950 if (SelectOperand == 1) {
3951 // TODO: SelectionDAG verifies this actually constant folds before
3952 // committing to the combine.
3953
3954 FoldTrue = Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {SelectTrue, RHS}).getReg(Idx: 0);
3955 FoldFalse =
3956 Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {SelectFalse, RHS}).getReg(Idx: 0);
3957 } else {
3958 FoldTrue = Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {LHS, SelectTrue}).getReg(Idx: 0);
3959 FoldFalse =
3960 Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {LHS, SelectFalse}).getReg(Idx: 0);
3961 }
3962
3963 Builder.buildSelect(Res: Dst, Tst: SelectCond, Op0: FoldTrue, Op1: FoldFalse, Flags: MI.getFlags());
3964 MI.eraseFromParent();
3965}
3966
3967std::optional<SmallVector<Register, 8>>
3968CombinerHelper::findCandidatesForLoadOrCombine(const MachineInstr *Root) const {
3969 assert(Root->getOpcode() == TargetOpcode::G_OR && "Expected G_OR only!");
3970 // We want to detect if Root is part of a tree which represents a bunch
3971 // of loads being merged into a larger load. We'll try to recognize patterns
3972 // like, for example:
3973 //
3974 // Reg Reg
3975 // \ /
3976 // OR_1 Reg
3977 // \ /
3978 // OR_2
3979 // \ Reg
3980 // .. /
3981 // Root
3982 //
3983 // Reg Reg Reg Reg
3984 // \ / \ /
3985 // OR_1 OR_2
3986 // \ /
3987 // \ /
3988 // ...
3989 // Root
3990 //
3991 // Each "Reg" may have been produced by a load + some arithmetic. This
3992 // function will save each of them.
3993 SmallVector<Register, 8> RegsToVisit;
3994 SmallVector<const MachineInstr *, 7> Ors = {Root};
3995
3996 // In the "worst" case, we're dealing with a load for each byte. So, there
3997 // are at most #bytes - 1 ORs.
3998 const unsigned MaxIter =
3999 MRI.getType(Reg: Root->getOperand(i: 0).getReg()).getSizeInBytes() - 1;
4000 for (unsigned Iter = 0; Iter < MaxIter; ++Iter) {
4001 if (Ors.empty())
4002 break;
4003 const MachineInstr *Curr = Ors.pop_back_val();
4004 Register OrLHS = Curr->getOperand(i: 1).getReg();
4005 Register OrRHS = Curr->getOperand(i: 2).getReg();
4006
4007 // In the combine, we want to elimate the entire tree.
4008 if (!MRI.hasOneNonDBGUse(RegNo: OrLHS) || !MRI.hasOneNonDBGUse(RegNo: OrRHS))
4009 return std::nullopt;
4010
4011 // If it's a G_OR, save it and continue to walk. If it's not, then it's
4012 // something that may be a load + arithmetic.
4013 if (const MachineInstr *Or = getOpcodeDef(Opcode: TargetOpcode::G_OR, Reg: OrLHS, MRI))
4014 Ors.push_back(Elt: Or);
4015 else
4016 RegsToVisit.push_back(Elt: OrLHS);
4017 if (const MachineInstr *Or = getOpcodeDef(Opcode: TargetOpcode::G_OR, Reg: OrRHS, MRI))
4018 Ors.push_back(Elt: Or);
4019 else
4020 RegsToVisit.push_back(Elt: OrRHS);
4021 }
4022
4023 // We're going to try and merge each register into a wider power-of-2 type,
4024 // so we ought to have an even number of registers.
4025 if (RegsToVisit.empty() || RegsToVisit.size() % 2 != 0)
4026 return std::nullopt;
4027 return RegsToVisit;
4028}
4029
4030/// Helper function for findLoadOffsetsForLoadOrCombine.
4031///
4032/// Check if \p Reg is the result of loading a \p MemSizeInBits wide value,
4033/// and then moving that value into a specific byte offset.
4034///
4035/// e.g. x[i] << 24
4036///
4037/// \returns The load instruction and the byte offset it is moved into.
4038static std::optional<std::pair<GZExtLoad *, int64_t>>
4039matchLoadAndBytePosition(Register Reg, unsigned MemSizeInBits,
4040 const MachineRegisterInfo &MRI) {
4041 assert(MRI.hasOneNonDBGUse(Reg) &&
4042 "Expected Reg to only have one non-debug use?");
4043 Register MaybeLoad;
4044 int64_t Shift;
4045 if (!mi_match(R: Reg, MRI,
4046 P: m_OneNonDBGUse(SP: m_GShl(L: m_Reg(R&: MaybeLoad), R: m_ICst(Cst&: Shift))))) {
4047 Shift = 0;
4048 MaybeLoad = Reg;
4049 }
4050
4051 if (Shift % MemSizeInBits != 0)
4052 return std::nullopt;
4053
4054 // TODO: Handle other types of loads.
4055 auto *Load = getOpcodeDef<GZExtLoad>(Reg: MaybeLoad, MRI);
4056 if (!Load)
4057 return std::nullopt;
4058
4059 if (!Load->isUnordered() || Load->getMemSizeInBits() != MemSizeInBits)
4060 return std::nullopt;
4061
4062 return std::make_pair(x&: Load, y: Shift / MemSizeInBits);
4063}
4064
4065std::optional<std::tuple<GZExtLoad *, int64_t, GZExtLoad *>>
4066CombinerHelper::findLoadOffsetsForLoadOrCombine(
4067 SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx,
4068 const SmallVector<Register, 8> &RegsToVisit,
4069 const unsigned MemSizeInBits) const {
4070
4071 // Each load found for the pattern. There should be one for each RegsToVisit.
4072 SmallSetVector<const MachineInstr *, 8> Loads;
4073
4074 // The lowest index used in any load. (The lowest "i" for each x[i].)
4075 int64_t LowestIdx = INT64_MAX;
4076
4077 // The load which uses the lowest index.
4078 GZExtLoad *LowestIdxLoad = nullptr;
4079
4080 // Keeps track of the load indices we see. We shouldn't see any indices twice.
4081 SmallSet<int64_t, 8> SeenIdx;
4082
4083 // Ensure each load is in the same MBB.
4084 // TODO: Support multiple MachineBasicBlocks.
4085 MachineBasicBlock *MBB = nullptr;
4086 const MachineMemOperand *MMO = nullptr;
4087
4088 // Earliest instruction-order load in the pattern.
4089 GZExtLoad *EarliestLoad = nullptr;
4090
4091 // Latest instruction-order load in the pattern.
4092 GZExtLoad *LatestLoad = nullptr;
4093
4094 // Base pointer which every load should share.
4095 Register BasePtr;
4096
4097 // We want to find a load for each register. Each load should have some
4098 // appropriate bit twiddling arithmetic. During this loop, we will also keep
4099 // track of the load which uses the lowest index. Later, we will check if we
4100 // can use its pointer in the final, combined load.
4101 for (auto Reg : RegsToVisit) {
4102 // Find the load, and find the position that it will end up in (e.g. a
4103 // shifted) value.
4104 auto LoadAndPos = matchLoadAndBytePosition(Reg, MemSizeInBits, MRI);
4105 if (!LoadAndPos)
4106 return std::nullopt;
4107 GZExtLoad *Load;
4108 int64_t DstPos;
4109 std::tie(args&: Load, args&: DstPos) = *LoadAndPos;
4110
4111 // TODO: Handle multiple MachineBasicBlocks. Currently not handled because
4112 // it is difficult to check for stores/calls/etc between loads.
4113 MachineBasicBlock *LoadMBB = Load->getParent();
4114 if (!MBB)
4115 MBB = LoadMBB;
4116 if (LoadMBB != MBB)
4117 return std::nullopt;
4118
4119 // Make sure that the MachineMemOperands of every seen load are compatible.
4120 auto &LoadMMO = Load->getMMO();
4121 if (!MMO)
4122 MMO = &LoadMMO;
4123 if (MMO->getAddrSpace() != LoadMMO.getAddrSpace())
4124 return std::nullopt;
4125
4126 // Find out what the base pointer and index for the load is.
4127 Register LoadPtr;
4128 int64_t Idx;
4129 if (!mi_match(R: Load->getOperand(i: 1).getReg(), MRI,
4130 P: m_GPtrAdd(L: m_Reg(R&: LoadPtr), R: m_ICst(Cst&: Idx)))) {
4131 LoadPtr = Load->getOperand(i: 1).getReg();
4132 Idx = 0;
4133 }
4134
4135 // Don't combine things like a[i], a[i] -> a bigger load.
4136 if (!SeenIdx.insert(V: Idx).second)
4137 return std::nullopt;
4138
4139 // Every load must share the same base pointer; don't combine things like:
4140 //
4141 // a[i], b[i + 1] -> a bigger load.
4142 if (!BasePtr.isValid())
4143 BasePtr = LoadPtr;
4144 if (BasePtr != LoadPtr)
4145 return std::nullopt;
4146
4147 if (Idx < LowestIdx) {
4148 LowestIdx = Idx;
4149 LowestIdxLoad = Load;
4150 }
4151
4152 // Keep track of the byte offset that this load ends up at. If we have seen
4153 // the byte offset, then stop here. We do not want to combine:
4154 //
4155 // a[i] << 16, a[i + k] << 16 -> a bigger load.
4156 if (!MemOffset2Idx.try_emplace(Key: DstPos, Args&: Idx).second)
4157 return std::nullopt;
4158 Loads.insert(X: Load);
4159
4160 // Keep track of the position of the earliest/latest loads in the pattern.
4161 // We will check that there are no load fold barriers between them later
4162 // on.
4163 //
4164 // FIXME: Is there a better way to check for load fold barriers?
4165 if (!EarliestLoad || dominates(DefMI: *Load, UseMI: *EarliestLoad))
4166 EarliestLoad = Load;
4167 if (!LatestLoad || dominates(DefMI: *LatestLoad, UseMI: *Load))
4168 LatestLoad = Load;
4169 }
4170
4171 // We found a load for each register. Let's check if each load satisfies the
4172 // pattern.
4173 assert(Loads.size() == RegsToVisit.size() &&
4174 "Expected to find a load for each register?");
4175 assert(EarliestLoad != LatestLoad && EarliestLoad &&
4176 LatestLoad && "Expected at least two loads?");
4177
4178 // Check if there are any stores, calls, etc. between any of the loads. If
4179 // there are, then we can't safely perform the combine.
4180 //
4181 // MaxIter is chosen based off the (worst case) number of iterations it
4182 // typically takes to succeed in the LLVM test suite plus some padding.
4183 //
4184 // FIXME: Is there a better way to check for load fold barriers?
4185 const unsigned MaxIter = 20;
4186 unsigned Iter = 0;
4187 for (const auto &MI : instructionsWithoutDebug(It: EarliestLoad->getIterator(),
4188 End: LatestLoad->getIterator())) {
4189 if (Loads.count(key: &MI))
4190 continue;
4191 if (MI.isLoadFoldBarrier())
4192 return std::nullopt;
4193 if (Iter++ == MaxIter)
4194 return std::nullopt;
4195 }
4196
4197 return std::make_tuple(args&: LowestIdxLoad, args&: LowestIdx, args&: LatestLoad);
4198}
4199
4200bool CombinerHelper::matchLoadOrCombine(
4201 MachineInstr &MI,
4202 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4203 assert(MI.getOpcode() == TargetOpcode::G_OR);
4204 MachineFunction &MF = *MI.getMF();
4205 // Assuming a little-endian target, transform:
4206 // s8 *a = ...
4207 // s32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
4208 // =>
4209 // s32 val = *((i32)a)
4210 //
4211 // s8 *a = ...
4212 // s32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
4213 // =>
4214 // s32 val = BSWAP(*((s32)a))
4215 Register Dst = MI.getOperand(i: 0).getReg();
4216 LLT Ty = MRI.getType(Reg: Dst);
4217 if (Ty.isVector())
4218 return false;
4219
4220 // We need to combine at least two loads into this type. Since the smallest
4221 // possible load is into a byte, we need at least a 16-bit wide type.
4222 const unsigned WideMemSizeInBits = Ty.getSizeInBits();
4223 if (WideMemSizeInBits < 16 || WideMemSizeInBits % 8 != 0)
4224 return false;
4225
4226 // Match a collection of non-OR instructions in the pattern.
4227 auto RegsToVisit = findCandidatesForLoadOrCombine(Root: &MI);
4228 if (!RegsToVisit)
4229 return false;
4230
4231 // We have a collection of non-OR instructions. Figure out how wide each of
4232 // the small loads should be based off of the number of potential loads we
4233 // found.
4234 const unsigned NarrowMemSizeInBits = WideMemSizeInBits / RegsToVisit->size();
4235 if (NarrowMemSizeInBits % 8 != 0)
4236 return false;
4237
4238 // Check if each register feeding into each OR is a load from the same
4239 // base pointer + some arithmetic.
4240 //
4241 // e.g. a[0], a[1] << 8, a[2] << 16, etc.
4242 //
4243 // Also verify that each of these ends up putting a[i] into the same memory
4244 // offset as a load into a wide type would.
4245 SmallDenseMap<int64_t, int64_t, 8> MemOffset2Idx;
4246 GZExtLoad *LowestIdxLoad, *LatestLoad;
4247 int64_t LowestIdx;
4248 auto MaybeLoadInfo = findLoadOffsetsForLoadOrCombine(
4249 MemOffset2Idx, RegsToVisit: *RegsToVisit, MemSizeInBits: NarrowMemSizeInBits);
4250 if (!MaybeLoadInfo)
4251 return false;
4252 std::tie(args&: LowestIdxLoad, args&: LowestIdx, args&: LatestLoad) = *MaybeLoadInfo;
4253
4254 // We have a bunch of loads being OR'd together. Using the addresses + offsets
4255 // we found before, check if this corresponds to a big or little endian byte
4256 // pattern. If it does, then we can represent it using a load + possibly a
4257 // BSWAP.
4258 bool IsBigEndianTarget = MF.getDataLayout().isBigEndian();
4259 std::optional<bool> IsBigEndian = isBigEndian(MemOffset2Idx, LowestIdx);
4260 if (!IsBigEndian)
4261 return false;
4262 bool NeedsBSwap = IsBigEndianTarget != *IsBigEndian;
4263 if (NeedsBSwap && !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_BSWAP, {Ty}}))
4264 return false;
4265
4266 // Make sure that the load from the lowest index produces offset 0 in the
4267 // final value.
4268 //
4269 // This ensures that we won't combine something like this:
4270 //
4271 // load x[i] -> byte 2
4272 // load x[i+1] -> byte 0 ---> wide_load x[i]
4273 // load x[i+2] -> byte 1
4274 const unsigned NumLoadsInTy = WideMemSizeInBits / NarrowMemSizeInBits;
4275 const unsigned ZeroByteOffset =
4276 *IsBigEndian
4277 ? bigEndianByteAt(ByteWidth: NumLoadsInTy, I: 0)
4278 : littleEndianByteAt(ByteWidth: NumLoadsInTy, I: 0);
4279 auto ZeroOffsetIdx = MemOffset2Idx.find(Val: ZeroByteOffset);
4280 if (ZeroOffsetIdx == MemOffset2Idx.end() ||
4281 ZeroOffsetIdx->second != LowestIdx)
4282 return false;
4283
4284 // We wil reuse the pointer from the load which ends up at byte offset 0. It
4285 // may not use index 0.
4286 Register Ptr = LowestIdxLoad->getPointerReg();
4287 const MachineMemOperand &MMO = LowestIdxLoad->getMMO();
4288 LegalityQuery::MemDesc MMDesc(MMO);
4289 MMDesc.MemoryTy = Ty;
4290 if (!isLegalOrBeforeLegalizer(
4291 Query: {TargetOpcode::G_LOAD, {Ty, MRI.getType(Reg: Ptr)}, {MMDesc}}))
4292 return false;
4293 auto PtrInfo = MMO.getPointerInfo();
4294 auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Size: WideMemSizeInBits / 8);
4295
4296 // Load must be allowed and fast on the target.
4297 LLVMContext &C = MF.getFunction().getContext();
4298 auto &DL = MF.getDataLayout();
4299 unsigned Fast = 0;
4300 if (!getTargetLowering().allowsMemoryAccess(Context&: C, DL, Ty, MMO: *NewMMO, Fast: &Fast) ||
4301 !Fast)
4302 return false;
4303
4304 MatchInfo = [=](MachineIRBuilder &MIB) {
4305 MIB.setInstrAndDebugLoc(*LatestLoad);
4306 Register LoadDst = NeedsBSwap ? MRI.cloneVirtualRegister(VReg: Dst) : Dst;
4307 MIB.buildLoad(Res: LoadDst, Addr: Ptr, MMO&: *NewMMO);
4308 if (NeedsBSwap)
4309 MIB.buildBSwap(Dst, Src0: LoadDst);
4310 };
4311 return true;
4312}
4313
4314bool CombinerHelper::matchExtendThroughPhis(MachineInstr &MI,
4315 MachineInstr *&ExtMI) const {
4316 auto &PHI = cast<GPhi>(Val&: MI);
4317 Register DstReg = PHI.getReg(Idx: 0);
4318
4319 // TODO: Extending a vector may be expensive, don't do this until heuristics
4320 // are better.
4321 if (MRI.getType(Reg: DstReg).isVector())
4322 return false;
4323
4324 // Try to match a phi, whose only use is an extend.
4325 if (!MRI.hasOneNonDBGUse(RegNo: DstReg))
4326 return false;
4327 ExtMI = &*MRI.use_instr_nodbg_begin(RegNo: DstReg);
4328 switch (ExtMI->getOpcode()) {
4329 case TargetOpcode::G_ANYEXT:
4330 return true; // G_ANYEXT is usually free.
4331 case TargetOpcode::G_ZEXT:
4332 case TargetOpcode::G_SEXT:
4333 break;
4334 default:
4335 return false;
4336 }
4337
4338 // If the target is likely to fold this extend away, don't propagate.
4339 if (Builder.getTII().isExtendLikelyToBeFolded(ExtMI&: *ExtMI, MRI))
4340 return false;
4341
4342 // We don't want to propagate the extends unless there's a good chance that
4343 // they'll be optimized in some way.
4344 // Collect the unique incoming values.
4345 SmallPtrSet<MachineInstr *, 4> InSrcs;
4346 for (unsigned I = 0; I < PHI.getNumIncomingValues(); ++I) {
4347 auto *DefMI = getDefIgnoringCopies(Reg: PHI.getIncomingValue(I), MRI);
4348 switch (DefMI->getOpcode()) {
4349 case TargetOpcode::G_LOAD:
4350 case TargetOpcode::G_TRUNC:
4351 case TargetOpcode::G_SEXT:
4352 case TargetOpcode::G_ZEXT:
4353 case TargetOpcode::G_ANYEXT:
4354 case TargetOpcode::G_CONSTANT:
4355 InSrcs.insert(Ptr: DefMI);
4356 // Don't try to propagate if there are too many places to create new
4357 // extends, chances are it'll increase code size.
4358 if (InSrcs.size() > 2)
4359 return false;
4360 break;
4361 default:
4362 return false;
4363 }
4364 }
4365 return true;
4366}
4367
4368void CombinerHelper::applyExtendThroughPhis(MachineInstr &MI,
4369 MachineInstr *&ExtMI) const {
4370 auto &PHI = cast<GPhi>(Val&: MI);
4371 Register DstReg = ExtMI->getOperand(i: 0).getReg();
4372 LLT ExtTy = MRI.getType(Reg: DstReg);
4373
4374 // Propagate the extension into the block of each incoming reg's block.
4375 // Use a SetVector here because PHIs can have duplicate edges, and we want
4376 // deterministic iteration order.
4377 SmallSetVector<MachineInstr *, 8> SrcMIs;
4378 SmallDenseMap<MachineInstr *, MachineInstr *, 8> OldToNewSrcMap;
4379 for (unsigned I = 0; I < PHI.getNumIncomingValues(); ++I) {
4380 auto SrcReg = PHI.getIncomingValue(I);
4381 auto *SrcMI = MRI.getVRegDef(Reg: SrcReg);
4382 if (!SrcMIs.insert(X: SrcMI))
4383 continue;
4384
4385 // Build an extend after each src inst.
4386 auto *MBB = SrcMI->getParent();
4387 MachineBasicBlock::iterator InsertPt = ++SrcMI->getIterator();
4388 if (InsertPt != MBB->end() && InsertPt->isPHI())
4389 InsertPt = MBB->getFirstNonPHI();
4390
4391 Builder.setInsertPt(MBB&: *SrcMI->getParent(), II: InsertPt);
4392 Builder.setDebugLoc(MI.getDebugLoc());
4393 auto NewExt = Builder.buildExtOrTrunc(ExtOpc: ExtMI->getOpcode(), Res: ExtTy, Op: SrcReg);
4394 OldToNewSrcMap[SrcMI] = NewExt;
4395 }
4396
4397 // Create a new phi with the extended inputs.
4398 Builder.setInstrAndDebugLoc(MI);
4399 auto NewPhi = Builder.buildInstrNoInsert(Opcode: TargetOpcode::G_PHI);
4400 NewPhi.addDef(RegNo: DstReg);
4401 for (const MachineOperand &MO : llvm::drop_begin(RangeOrContainer: MI.operands())) {
4402 if (!MO.isReg()) {
4403 NewPhi.addMBB(MBB: MO.getMBB());
4404 continue;
4405 }
4406 auto *NewSrc = OldToNewSrcMap[MRI.getVRegDef(Reg: MO.getReg())];
4407 NewPhi.addUse(RegNo: NewSrc->getOperand(i: 0).getReg());
4408 }
4409 Builder.insertInstr(MIB: NewPhi);
4410 ExtMI->eraseFromParent();
4411}
4412
4413bool CombinerHelper::matchExtractVecEltBuildVec(MachineInstr &MI,
4414 Register &Reg) const {
4415 assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT);
4416 // If we have a constant index, look for a G_BUILD_VECTOR source
4417 // and find the source register that the index maps to.
4418 Register SrcVec = MI.getOperand(i: 1).getReg();
4419 LLT SrcTy = MRI.getType(Reg: SrcVec);
4420 if (SrcTy.isScalableVector())
4421 return false;
4422
4423 auto Cst = getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI);
4424 if (!Cst || Cst->Value.getZExtValue() >= SrcTy.getNumElements())
4425 return false;
4426
4427 unsigned VecIdx = Cst->Value.getZExtValue();
4428
4429 // Check if we have a build_vector or build_vector_trunc with an optional
4430 // trunc in front.
4431 MachineInstr *SrcVecMI = MRI.getVRegDef(Reg: SrcVec);
4432 if (SrcVecMI->getOpcode() == TargetOpcode::G_TRUNC) {
4433 SrcVecMI = MRI.getVRegDef(Reg: SrcVecMI->getOperand(i: 1).getReg());
4434 }
4435
4436 if (SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR &&
4437 SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR_TRUNC)
4438 return false;
4439
4440 EVT Ty(getMVTForLLT(Ty: SrcTy));
4441 if (!MRI.hasOneNonDBGUse(RegNo: SrcVec) &&
4442 !getTargetLowering().aggressivelyPreferBuildVectorSources(VecVT: Ty))
4443 return false;
4444
4445 Reg = SrcVecMI->getOperand(i: VecIdx + 1).getReg();
4446 return true;
4447}
4448
4449void CombinerHelper::applyExtractVecEltBuildVec(MachineInstr &MI,
4450 Register &Reg) const {
4451 // Check the type of the register, since it may have come from a
4452 // G_BUILD_VECTOR_TRUNC.
4453 LLT ScalarTy = MRI.getType(Reg);
4454 Register DstReg = MI.getOperand(i: 0).getReg();
4455 LLT DstTy = MRI.getType(Reg: DstReg);
4456
4457 if (ScalarTy != DstTy) {
4458 assert(ScalarTy.getSizeInBits() > DstTy.getSizeInBits());
4459 Builder.buildTrunc(Res: DstReg, Op: Reg);
4460 MI.eraseFromParent();
4461 return;
4462 }
4463 replaceSingleDefInstWithReg(MI, Replacement: Reg);
4464}
4465
4466bool CombinerHelper::matchExtractAllEltsFromBuildVector(
4467 MachineInstr &MI,
4468 SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) const {
4469 assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
4470 // This combine tries to find build_vector's which have every source element
4471 // extracted using G_EXTRACT_VECTOR_ELT. This can happen when transforms like
4472 // the masked load scalarization is run late in the pipeline. There's already
4473 // a combine for a similar pattern starting from the extract, but that
4474 // doesn't attempt to do it if there are multiple uses of the build_vector,
4475 // which in this case is true. Starting the combine from the build_vector
4476 // feels more natural than trying to find sibling nodes of extracts.
4477 // E.g.
4478 // %vec(<4 x s32>) = G_BUILD_VECTOR %s1(s32), %s2, %s3, %s4
4479 // %ext1 = G_EXTRACT_VECTOR_ELT %vec, 0
4480 // %ext2 = G_EXTRACT_VECTOR_ELT %vec, 1
4481 // %ext3 = G_EXTRACT_VECTOR_ELT %vec, 2
4482 // %ext4 = G_EXTRACT_VECTOR_ELT %vec, 3
4483 // ==>
4484 // replace ext{1,2,3,4} with %s{1,2,3,4}
4485
4486 Register DstReg = MI.getOperand(i: 0).getReg();
4487 LLT DstTy = MRI.getType(Reg: DstReg);
4488 unsigned NumElts = DstTy.getNumElements();
4489
4490 SmallBitVector ExtractedElts(NumElts);
4491 for (MachineInstr &II : MRI.use_nodbg_instructions(Reg: DstReg)) {
4492 if (II.getOpcode() != TargetOpcode::G_EXTRACT_VECTOR_ELT)
4493 return false;
4494 auto Cst = getIConstantVRegVal(VReg: II.getOperand(i: 2).getReg(), MRI);
4495 if (!Cst)
4496 return false;
4497 unsigned Idx = Cst->getZExtValue();
4498 if (Idx >= NumElts)
4499 return false; // Out of range.
4500 ExtractedElts.set(Idx);
4501 SrcDstPairs.emplace_back(
4502 Args: std::make_pair(x: MI.getOperand(i: Idx + 1).getReg(), y: &II));
4503 }
4504 // Match if every element was extracted.
4505 return ExtractedElts.all();
4506}
4507
4508void CombinerHelper::applyExtractAllEltsFromBuildVector(
4509 MachineInstr &MI,
4510 SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) const {
4511 assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
4512 for (auto &Pair : SrcDstPairs) {
4513 auto *ExtMI = Pair.second;
4514 replaceRegWith(MRI, FromReg: ExtMI->getOperand(i: 0).getReg(), ToReg: Pair.first);
4515 ExtMI->eraseFromParent();
4516 }
4517 MI.eraseFromParent();
4518}
4519
4520void CombinerHelper::applyBuildFn(
4521 MachineInstr &MI,
4522 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4523 applyBuildFnNoErase(MI, MatchInfo);
4524 MI.eraseFromParent();
4525}
4526
4527void CombinerHelper::applyBuildFnNoErase(
4528 MachineInstr &MI,
4529 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4530 MatchInfo(Builder);
4531}
4532
4533bool CombinerHelper::matchOrShiftToFunnelShift(MachineInstr &MI,
4534 bool AllowScalarConstants,
4535 BuildFnTy &MatchInfo) const {
4536 assert(MI.getOpcode() == TargetOpcode::G_OR);
4537
4538 Register Dst = MI.getOperand(i: 0).getReg();
4539 LLT Ty = MRI.getType(Reg: Dst);
4540 unsigned BitWidth = Ty.getScalarSizeInBits();
4541
4542 Register ShlSrc, ShlAmt, LShrSrc, LShrAmt, Amt;
4543 unsigned FshOpc = 0;
4544
4545 // Match (or (shl ...), (lshr ...)).
4546 if (!mi_match(R: Dst, MRI,
4547 // m_GOr() handles the commuted version as well.
4548 P: m_GOr(L: m_GShl(L: m_Reg(R&: ShlSrc), R: m_Reg(R&: ShlAmt)),
4549 R: m_GLShr(L: m_Reg(R&: LShrSrc), R: m_Reg(R&: LShrAmt)))))
4550 return false;
4551
4552 // Given constants C0 and C1 such that C0 + C1 is bit-width:
4553 // (or (shl x, C0), (lshr y, C1)) -> (fshl x, y, C0) or (fshr x, y, C1)
4554 int64_t CstShlAmt = 0, CstLShrAmt;
4555 if (mi_match(R: ShlAmt, MRI, P: m_ICstOrSplat(Cst&: CstShlAmt)) &&
4556 mi_match(R: LShrAmt, MRI, P: m_ICstOrSplat(Cst&: CstLShrAmt)) &&
4557 CstShlAmt + CstLShrAmt == BitWidth) {
4558 FshOpc = TargetOpcode::G_FSHR;
4559 Amt = LShrAmt;
4560 } else if (mi_match(R: LShrAmt, MRI,
4561 P: m_GSub(L: m_SpecificICstOrSplat(RequestedValue: BitWidth), R: m_Reg(R&: Amt))) &&
4562 ShlAmt == Amt) {
4563 // (or (shl x, amt), (lshr y, (sub bw, amt))) -> (fshl x, y, amt)
4564 FshOpc = TargetOpcode::G_FSHL;
4565 } else if (mi_match(R: ShlAmt, MRI,
4566 P: m_GSub(L: m_SpecificICstOrSplat(RequestedValue: BitWidth), R: m_Reg(R&: Amt))) &&
4567 LShrAmt == Amt) {
4568 // (or (shl x, (sub bw, amt)), (lshr y, amt)) -> (fshr x, y, amt)
4569 FshOpc = TargetOpcode::G_FSHR;
4570 } else {
4571 return false;
4572 }
4573
4574 LLT AmtTy = MRI.getType(Reg: Amt);
4575 if (!isLegalOrBeforeLegalizer(Query: {FshOpc, {Ty, AmtTy}}) &&
4576 (!AllowScalarConstants || CstShlAmt == 0 || !Ty.isScalar()))
4577 return false;
4578
4579 MatchInfo = [=](MachineIRBuilder &B) {
4580 B.buildInstr(Opc: FshOpc, DstOps: {Dst}, SrcOps: {ShlSrc, LShrSrc, Amt});
4581 };
4582 return true;
4583}
4584
4585/// Match an FSHL or FSHR that can be combined to a ROTR or ROTL rotate.
4586bool CombinerHelper::matchFunnelShiftToRotate(MachineInstr &MI) const {
4587 unsigned Opc = MI.getOpcode();
4588 assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR);
4589 Register X = MI.getOperand(i: 1).getReg();
4590 Register Y = MI.getOperand(i: 2).getReg();
4591 if (X != Y)
4592 return false;
4593 unsigned RotateOpc =
4594 Opc == TargetOpcode::G_FSHL ? TargetOpcode::G_ROTL : TargetOpcode::G_ROTR;
4595 return isLegalOrBeforeLegalizer(Query: {RotateOpc, {MRI.getType(Reg: X), MRI.getType(Reg: Y)}});
4596}
4597
4598void CombinerHelper::applyFunnelShiftToRotate(MachineInstr &MI) const {
4599 unsigned Opc = MI.getOpcode();
4600 assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR);
4601 bool IsFSHL = Opc == TargetOpcode::G_FSHL;
4602 Observer.changingInstr(MI);
4603 MI.setDesc(Builder.getTII().get(Opcode: IsFSHL ? TargetOpcode::G_ROTL
4604 : TargetOpcode::G_ROTR));
4605 MI.removeOperand(OpNo: 2);
4606 Observer.changedInstr(MI);
4607}
4608
4609// Fold (rot x, c) -> (rot x, c % BitSize)
4610bool CombinerHelper::matchRotateOutOfRange(MachineInstr &MI) const {
4611 assert(MI.getOpcode() == TargetOpcode::G_ROTL ||
4612 MI.getOpcode() == TargetOpcode::G_ROTR);
4613 unsigned Bitsize =
4614 MRI.getType(Reg: MI.getOperand(i: 0).getReg()).getScalarSizeInBits();
4615 Register AmtReg = MI.getOperand(i: 2).getReg();
4616 bool OutOfRange = false;
4617 auto MatchOutOfRange = [Bitsize, &OutOfRange](const Constant *C) {
4618 if (auto *CI = dyn_cast<ConstantInt>(Val: C))
4619 OutOfRange |= CI->getValue().uge(RHS: Bitsize);
4620 return true;
4621 };
4622 return matchUnaryPredicate(MRI, Reg: AmtReg, Match: MatchOutOfRange) && OutOfRange;
4623}
4624
4625void CombinerHelper::applyRotateOutOfRange(MachineInstr &MI) const {
4626 assert(MI.getOpcode() == TargetOpcode::G_ROTL ||
4627 MI.getOpcode() == TargetOpcode::G_ROTR);
4628 unsigned Bitsize =
4629 MRI.getType(Reg: MI.getOperand(i: 0).getReg()).getScalarSizeInBits();
4630 Register Amt = MI.getOperand(i: 2).getReg();
4631 LLT AmtTy = MRI.getType(Reg: Amt);
4632 auto Bits = Builder.buildConstant(Res: AmtTy, Val: Bitsize);
4633 Amt = Builder.buildURem(Dst: AmtTy, Src0: MI.getOperand(i: 2).getReg(), Src1: Bits).getReg(Idx: 0);
4634 Observer.changingInstr(MI);
4635 MI.getOperand(i: 2).setReg(Amt);
4636 Observer.changedInstr(MI);
4637}
4638
4639bool CombinerHelper::matchICmpToTrueFalseKnownBits(MachineInstr &MI,
4640 int64_t &MatchInfo) const {
4641 assert(MI.getOpcode() == TargetOpcode::G_ICMP);
4642 auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(i: 1).getPredicate());
4643
4644 // We want to avoid calling KnownBits on the LHS if possible, as this combine
4645 // has no filter and runs on every G_ICMP instruction. We can avoid calling
4646 // KnownBits on the LHS in two cases:
4647 //
4648 // - The RHS is unknown: Constants are always on RHS. If the RHS is unknown
4649 // we cannot do any transforms so we can safely bail out early.
4650 // - The RHS is zero: we don't need to know the LHS to do unsigned <0 and
4651 // >=0.
4652 auto KnownRHS = VT->getKnownBits(R: MI.getOperand(i: 3).getReg());
4653 if (KnownRHS.isUnknown())
4654 return false;
4655
4656 std::optional<bool> KnownVal;
4657 if (KnownRHS.isZero()) {
4658 // ? uge 0 -> always true
4659 // ? ult 0 -> always false
4660 if (Pred == CmpInst::ICMP_UGE)
4661 KnownVal = true;
4662 else if (Pred == CmpInst::ICMP_ULT)
4663 KnownVal = false;
4664 }
4665
4666 if (!KnownVal) {
4667 auto KnownLHS = VT->getKnownBits(R: MI.getOperand(i: 2).getReg());
4668 KnownVal = ICmpInst::compare(LHS: KnownLHS, RHS: KnownRHS, Pred);
4669 }
4670
4671 if (!KnownVal)
4672 return false;
4673 MatchInfo =
4674 *KnownVal
4675 ? getICmpTrueVal(TLI: getTargetLowering(),
4676 /*IsVector = */
4677 MRI.getType(Reg: MI.getOperand(i: 0).getReg()).isVector(),
4678 /* IsFP = */ false)
4679 : 0;
4680 return true;
4681}
4682
4683bool CombinerHelper::matchICmpToLHSKnownBits(
4684 MachineInstr &MI,
4685 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4686 assert(MI.getOpcode() == TargetOpcode::G_ICMP);
4687 // Given:
4688 //
4689 // %x = G_WHATEVER (... x is known to be 0 or 1 ...)
4690 // %cmp = G_ICMP ne %x, 0
4691 //
4692 // Or:
4693 //
4694 // %x = G_WHATEVER (... x is known to be 0 or 1 ...)
4695 // %cmp = G_ICMP eq %x, 1
4696 //
4697 // We can replace %cmp with %x assuming true is 1 on the target.
4698 auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(i: 1).getPredicate());
4699 if (!CmpInst::isEquality(pred: Pred))
4700 return false;
4701 Register Dst = MI.getOperand(i: 0).getReg();
4702 LLT DstTy = MRI.getType(Reg: Dst);
4703 if (getICmpTrueVal(TLI: getTargetLowering(), IsVector: DstTy.isVector(),
4704 /* IsFP = */ false) != 1)
4705 return false;
4706 int64_t OneOrZero = Pred == CmpInst::ICMP_EQ;
4707 if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICst(RequestedValue: OneOrZero)))
4708 return false;
4709 Register LHS = MI.getOperand(i: 2).getReg();
4710 auto KnownLHS = VT->getKnownBits(R: LHS);
4711 if (KnownLHS.getMinValue() != 0 || KnownLHS.getMaxValue() != 1)
4712 return false;
4713 // Make sure replacing Dst with the LHS is a legal operation.
4714 LLT LHSTy = MRI.getType(Reg: LHS);
4715 unsigned LHSSize = LHSTy.getSizeInBits();
4716 unsigned DstSize = DstTy.getSizeInBits();
4717 unsigned Op = TargetOpcode::COPY;
4718 if (DstSize != LHSSize)
4719 Op = DstSize < LHSSize ? TargetOpcode::G_TRUNC : TargetOpcode::G_ZEXT;
4720 if (!isLegalOrBeforeLegalizer(Query: {Op, {DstTy, LHSTy}}))
4721 return false;
4722 MatchInfo = [=](MachineIRBuilder &B) { B.buildInstr(Opc: Op, DstOps: {Dst}, SrcOps: {LHS}); };
4723 return true;
4724}
4725
4726// Replace (and (or x, c1), c2) with (and x, c2) iff c1 & c2 == 0
4727bool CombinerHelper::matchAndOrDisjointMask(
4728 MachineInstr &MI,
4729 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4730 assert(MI.getOpcode() == TargetOpcode::G_AND);
4731
4732 // Ignore vector types to simplify matching the two constants.
4733 // TODO: do this for vectors and scalars via a demanded bits analysis.
4734 LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
4735 if (Ty.isVector())
4736 return false;
4737
4738 Register Src;
4739 Register AndMaskReg;
4740 int64_t AndMaskBits;
4741 int64_t OrMaskBits;
4742 if (!mi_match(MI, MRI,
4743 P: m_GAnd(L: m_GOr(L: m_Reg(R&: Src), R: m_ICst(Cst&: OrMaskBits)),
4744 R: m_all_of(preds: m_ICst(Cst&: AndMaskBits), preds: m_Reg(R&: AndMaskReg)))))
4745 return false;
4746
4747 // Check if OrMask could turn on any bits in Src.
4748 if (AndMaskBits & OrMaskBits)
4749 return false;
4750
4751 MatchInfo = [=, &MI](MachineIRBuilder &B) {
4752 Observer.changingInstr(MI);
4753 // Canonicalize the result to have the constant on the RHS.
4754 if (MI.getOperand(i: 1).getReg() == AndMaskReg)
4755 MI.getOperand(i: 2).setReg(AndMaskReg);
4756 MI.getOperand(i: 1).setReg(Src);
4757 Observer.changedInstr(MI);
4758 };
4759 return true;
4760}
4761
4762/// Form a G_SBFX from a G_SEXT_INREG fed by a right shift.
4763bool CombinerHelper::matchBitfieldExtractFromSExtInReg(
4764 MachineInstr &MI,
4765 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4766 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
4767 Register Dst = MI.getOperand(i: 0).getReg();
4768 Register Src = MI.getOperand(i: 1).getReg();
4769 LLT Ty = MRI.getType(Reg: Src);
4770 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4771 if (!LI || !LI->isLegalOrCustom(Query: {TargetOpcode::G_SBFX, {Ty, ExtractTy}}))
4772 return false;
4773 int64_t Width = MI.getOperand(i: 2).getImm();
4774 Register ShiftSrc;
4775 int64_t ShiftImm;
4776 if (!mi_match(
4777 R: Src, MRI,
4778 P: m_OneNonDBGUse(SP: m_any_of(preds: m_GAShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: ShiftImm)),
4779 preds: m_GLShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: ShiftImm))))))
4780 return false;
4781 if (ShiftImm < 0 || ShiftImm + Width > Ty.getScalarSizeInBits())
4782 return false;
4783
4784 MatchInfo = [=](MachineIRBuilder &B) {
4785 auto Cst1 = B.buildConstant(Res: ExtractTy, Val: ShiftImm);
4786 auto Cst2 = B.buildConstant(Res: ExtractTy, Val: Width);
4787 B.buildSbfx(Dst, Src: ShiftSrc, LSB: Cst1, Width: Cst2);
4788 };
4789 return true;
4790}
4791
4792/// Form a G_UBFX from "(a srl b) & mask", where b and mask are constants.
4793bool CombinerHelper::matchBitfieldExtractFromAnd(MachineInstr &MI,
4794 BuildFnTy &MatchInfo) const {
4795 GAnd *And = cast<GAnd>(Val: &MI);
4796 Register Dst = And->getReg(Idx: 0);
4797 LLT Ty = MRI.getType(Reg: Dst);
4798 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4799 // Note that isLegalOrBeforeLegalizer is stricter and does not take custom
4800 // into account.
4801 if (LI && !LI->isLegalOrCustom(Query: {TargetOpcode::G_UBFX, {Ty, ExtractTy}}))
4802 return false;
4803
4804 int64_t AndImm, LSBImm;
4805 Register ShiftSrc;
4806 const unsigned Size = Ty.getScalarSizeInBits();
4807 if (!mi_match(R: And->getReg(Idx: 0), MRI,
4808 P: m_GAnd(L: m_OneNonDBGUse(SP: m_GLShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: LSBImm))),
4809 R: m_ICst(Cst&: AndImm))))
4810 return false;
4811
4812 // The mask is a mask of the low bits iff imm & (imm+1) == 0.
4813 auto MaybeMask = static_cast<uint64_t>(AndImm);
4814 if (MaybeMask & (MaybeMask + 1))
4815 return false;
4816
4817 // LSB must fit within the register.
4818 if (static_cast<uint64_t>(LSBImm) >= Size)
4819 return false;
4820
4821 uint64_t Width = APInt(Size, AndImm).countr_one();
4822 MatchInfo = [=](MachineIRBuilder &B) {
4823 auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width);
4824 auto LSBCst = B.buildConstant(Res: ExtractTy, Val: LSBImm);
4825 B.buildInstr(Opc: TargetOpcode::G_UBFX, DstOps: {Dst}, SrcOps: {ShiftSrc, LSBCst, WidthCst});
4826 };
4827 return true;
4828}
4829
4830bool CombinerHelper::matchBitfieldExtractFromShr(
4831 MachineInstr &MI,
4832 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4833 const unsigned Opcode = MI.getOpcode();
4834 assert(Opcode == TargetOpcode::G_ASHR || Opcode == TargetOpcode::G_LSHR);
4835
4836 const Register Dst = MI.getOperand(i: 0).getReg();
4837
4838 const unsigned ExtrOpcode = Opcode == TargetOpcode::G_ASHR
4839 ? TargetOpcode::G_SBFX
4840 : TargetOpcode::G_UBFX;
4841
4842 // Check if the type we would use for the extract is legal
4843 LLT Ty = MRI.getType(Reg: Dst);
4844 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4845 if (!LI || !LI->isLegalOrCustom(Query: {ExtrOpcode, {Ty, ExtractTy}}))
4846 return false;
4847
4848 Register ShlSrc;
4849 int64_t ShrAmt;
4850 int64_t ShlAmt;
4851 const unsigned Size = Ty.getScalarSizeInBits();
4852
4853 // Try to match shr (shl x, c1), c2
4854 if (!mi_match(R: Dst, MRI,
4855 P: m_BinOp(Opcode,
4856 L: m_OneNonDBGUse(SP: m_GShl(L: m_Reg(R&: ShlSrc), R: m_ICst(Cst&: ShlAmt))),
4857 R: m_ICst(Cst&: ShrAmt))))
4858 return false;
4859
4860 // Make sure that the shift sizes can fit a bitfield extract
4861 if (ShlAmt < 0 || ShlAmt > ShrAmt || ShrAmt >= Size)
4862 return false;
4863
4864 // Skip this combine if the G_SEXT_INREG combine could handle it
4865 if (Opcode == TargetOpcode::G_ASHR && ShlAmt == ShrAmt)
4866 return false;
4867
4868 // Calculate start position and width of the extract
4869 const int64_t Pos = ShrAmt - ShlAmt;
4870 const int64_t Width = Size - ShrAmt;
4871
4872 MatchInfo = [=](MachineIRBuilder &B) {
4873 auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width);
4874 auto PosCst = B.buildConstant(Res: ExtractTy, Val: Pos);
4875 B.buildInstr(Opc: ExtrOpcode, DstOps: {Dst}, SrcOps: {ShlSrc, PosCst, WidthCst});
4876 };
4877 return true;
4878}
4879
4880bool CombinerHelper::matchBitfieldExtractFromShrAnd(
4881 MachineInstr &MI,
4882 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4883 const unsigned Opcode = MI.getOpcode();
4884 assert(Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_ASHR);
4885
4886 const Register Dst = MI.getOperand(i: 0).getReg();
4887 LLT Ty = MRI.getType(Reg: Dst);
4888 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4889 if (LI && !LI->isLegalOrCustom(Query: {TargetOpcode::G_UBFX, {Ty, ExtractTy}}))
4890 return false;
4891
4892 // Try to match shr (and x, c1), c2
4893 Register AndSrc;
4894 int64_t ShrAmt;
4895 int64_t SMask;
4896 if (!mi_match(R: Dst, MRI,
4897 P: m_BinOp(Opcode,
4898 L: m_OneNonDBGUse(SP: m_GAnd(L: m_Reg(R&: AndSrc), R: m_ICst(Cst&: SMask))),
4899 R: m_ICst(Cst&: ShrAmt))))
4900 return false;
4901
4902 const unsigned Size = Ty.getScalarSizeInBits();
4903 if (ShrAmt < 0 || ShrAmt >= Size)
4904 return false;
4905
4906 // If the shift subsumes the mask, emit the 0 directly.
4907 if (0 == (SMask >> ShrAmt)) {
4908 MatchInfo = [=](MachineIRBuilder &B) {
4909 B.buildConstant(Res: Dst, Val: 0);
4910 };
4911 return true;
4912 }
4913
4914 // Check that ubfx can do the extraction, with no holes in the mask.
4915 uint64_t UMask = SMask;
4916 UMask |= maskTrailingOnes<uint64_t>(N: ShrAmt);
4917 UMask &= maskTrailingOnes<uint64_t>(N: Size);
4918 if (!isMask_64(Value: UMask))
4919 return false;
4920
4921 // Calculate start position and width of the extract.
4922 const int64_t Pos = ShrAmt;
4923 const int64_t Width = llvm::countr_one(Value: UMask) - ShrAmt;
4924
4925 // It's preferable to keep the shift, rather than form G_SBFX.
4926 // TODO: remove the G_AND via demanded bits analysis.
4927 if (Opcode == TargetOpcode::G_ASHR && Width + ShrAmt == Size)
4928 return false;
4929
4930 MatchInfo = [=](MachineIRBuilder &B) {
4931 auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width);
4932 auto PosCst = B.buildConstant(Res: ExtractTy, Val: Pos);
4933 B.buildInstr(Opc: TargetOpcode::G_UBFX, DstOps: {Dst}, SrcOps: {AndSrc, PosCst, WidthCst});
4934 };
4935 return true;
4936}
4937
4938bool CombinerHelper::reassociationCanBreakAddressingModePattern(
4939 MachineInstr &MI) const {
4940 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
4941
4942 Register Src1Reg = PtrAdd.getBaseReg();
4943 auto *Src1Def = getOpcodeDef<GPtrAdd>(Reg: Src1Reg, MRI);
4944 if (!Src1Def)
4945 return false;
4946
4947 Register Src2Reg = PtrAdd.getOffsetReg();
4948
4949 if (MRI.hasOneNonDBGUse(RegNo: Src1Reg))
4950 return false;
4951
4952 auto C1 = getIConstantVRegVal(VReg: Src1Def->getOffsetReg(), MRI);
4953 if (!C1)
4954 return false;
4955 auto C2 = getIConstantVRegVal(VReg: Src2Reg, MRI);
4956 if (!C2)
4957 return false;
4958
4959 const APInt &C1APIntVal = *C1;
4960 const APInt &C2APIntVal = *C2;
4961 const int64_t CombinedValue = (C1APIntVal + C2APIntVal).getSExtValue();
4962
4963 for (auto &UseMI : MRI.use_nodbg_instructions(Reg: PtrAdd.getReg(Idx: 0))) {
4964 // This combine may end up running before ptrtoint/inttoptr combines
4965 // manage to eliminate redundant conversions, so try to look through them.
4966 MachineInstr *ConvUseMI = &UseMI;
4967 unsigned ConvUseOpc = ConvUseMI->getOpcode();
4968 while (ConvUseOpc == TargetOpcode::G_INTTOPTR ||
4969 ConvUseOpc == TargetOpcode::G_PTRTOINT) {
4970 Register DefReg = ConvUseMI->getOperand(i: 0).getReg();
4971 if (!MRI.hasOneNonDBGUse(RegNo: DefReg))
4972 break;
4973 ConvUseMI = &*MRI.use_instr_nodbg_begin(RegNo: DefReg);
4974 ConvUseOpc = ConvUseMI->getOpcode();
4975 }
4976 auto *LdStMI = dyn_cast<GLoadStore>(Val: ConvUseMI);
4977 if (!LdStMI)
4978 continue;
4979 // Is x[offset2] already not a legal addressing mode? If so then
4980 // reassociating the constants breaks nothing (we test offset2 because
4981 // that's the one we hope to fold into the load or store).
4982 TargetLoweringBase::AddrMode AM;
4983 AM.HasBaseReg = true;
4984 AM.BaseOffs = C2APIntVal.getSExtValue();
4985 unsigned AS = MRI.getType(Reg: LdStMI->getPointerReg()).getAddressSpace();
4986 Type *AccessTy = getTypeForLLT(Ty: LdStMI->getMMO().getMemoryType(),
4987 C&: PtrAdd.getMF()->getFunction().getContext());
4988 const auto &TLI = *PtrAdd.getMF()->getSubtarget().getTargetLowering();
4989 if (!TLI.isLegalAddressingMode(DL: PtrAdd.getMF()->getDataLayout(), AM,
4990 Ty: AccessTy, AddrSpace: AS))
4991 continue;
4992
4993 // Would x[offset1+offset2] still be a legal addressing mode?
4994 AM.BaseOffs = CombinedValue;
4995 if (!TLI.isLegalAddressingMode(DL: PtrAdd.getMF()->getDataLayout(), AM,
4996 Ty: AccessTy, AddrSpace: AS))
4997 return true;
4998 }
4999
5000 return false;
5001}
5002
5003bool CombinerHelper::matchReassocConstantInnerRHS(GPtrAdd &MI,
5004 MachineInstr *RHS,
5005 BuildFnTy &MatchInfo) const {
5006 // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C)
5007 Register Src1Reg = MI.getOperand(i: 1).getReg();
5008 if (RHS->getOpcode() != TargetOpcode::G_ADD)
5009 return false;
5010 auto C2 = getIConstantVRegVal(VReg: RHS->getOperand(i: 2).getReg(), MRI);
5011 if (!C2)
5012 return false;
5013
5014 // If both additions are nuw, the reassociated additions are also nuw.
5015 // If the original G_PTR_ADD is additionally nusw, X and C are both not
5016 // negative, so BASE+X is between BASE and BASE+(X+C). The new G_PTR_ADDs are
5017 // therefore also nusw.
5018 // If the original G_PTR_ADD is additionally inbounds (which implies nusw),
5019 // the new G_PTR_ADDs are then also inbounds.
5020 unsigned PtrAddFlags = MI.getFlags();
5021 unsigned AddFlags = RHS->getFlags();
5022 bool IsNoUWrap = PtrAddFlags & AddFlags & MachineInstr::MIFlag::NoUWrap;
5023 bool IsNoUSWrap = IsNoUWrap && (PtrAddFlags & MachineInstr::MIFlag::NoUSWrap);
5024 bool IsInBounds = IsNoUWrap && (PtrAddFlags & MachineInstr::MIFlag::InBounds);
5025 unsigned Flags = 0;
5026 if (IsNoUWrap)
5027 Flags |= MachineInstr::MIFlag::NoUWrap;
5028 if (IsNoUSWrap)
5029 Flags |= MachineInstr::MIFlag::NoUSWrap;
5030 if (IsInBounds)
5031 Flags |= MachineInstr::MIFlag::InBounds;
5032
5033 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5034 LLT PtrTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
5035
5036 auto NewBase =
5037 Builder.buildPtrAdd(Res: PtrTy, Op0: Src1Reg, Op1: RHS->getOperand(i: 1).getReg(), Flags);
5038 Observer.changingInstr(MI);
5039 MI.getOperand(i: 1).setReg(NewBase.getReg(Idx: 0));
5040 MI.getOperand(i: 2).setReg(RHS->getOperand(i: 2).getReg());
5041 MI.setFlags(Flags);
5042 Observer.changedInstr(MI);
5043 };
5044 return !reassociationCanBreakAddressingModePattern(MI);
5045}
5046
5047bool CombinerHelper::matchReassocConstantInnerLHS(GPtrAdd &MI,
5048 MachineInstr *LHS,
5049 MachineInstr *RHS,
5050 BuildFnTy &MatchInfo) const {
5051 // G_PTR_ADD (G_PTR_ADD X, C), Y) -> (G_PTR_ADD (G_PTR_ADD(X, Y), C)
5052 // if and only if (G_PTR_ADD X, C) has one use.
5053 Register LHSBase;
5054 std::optional<ValueAndVReg> LHSCstOff;
5055 if (!mi_match(R: MI.getBaseReg(), MRI,
5056 P: m_OneNonDBGUse(SP: m_GPtrAdd(L: m_Reg(R&: LHSBase), R: m_GCst(ValReg&: LHSCstOff)))))
5057 return false;
5058
5059 auto *LHSPtrAdd = cast<GPtrAdd>(Val: LHS);
5060
5061 // Reassociating nuw additions preserves nuw. If both original G_PTR_ADDs are
5062 // nuw and inbounds (which implies nusw), the offsets are both non-negative,
5063 // so the new G_PTR_ADDs are also inbounds.
5064 unsigned PtrAddFlags = MI.getFlags();
5065 unsigned LHSPtrAddFlags = LHSPtrAdd->getFlags();
5066 bool IsNoUWrap = PtrAddFlags & LHSPtrAddFlags & MachineInstr::MIFlag::NoUWrap;
5067 bool IsNoUSWrap = IsNoUWrap && (PtrAddFlags & LHSPtrAddFlags &
5068 MachineInstr::MIFlag::NoUSWrap);
5069 bool IsInBounds = IsNoUWrap && (PtrAddFlags & LHSPtrAddFlags &
5070 MachineInstr::MIFlag::InBounds);
5071 unsigned Flags = 0;
5072 if (IsNoUWrap)
5073 Flags |= MachineInstr::MIFlag::NoUWrap;
5074 if (IsNoUSWrap)
5075 Flags |= MachineInstr::MIFlag::NoUSWrap;
5076 if (IsInBounds)
5077 Flags |= MachineInstr::MIFlag::InBounds;
5078
5079 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5080 // When we change LHSPtrAdd's offset register we might cause it to use a reg
5081 // before its def. Sink the instruction so the outer PTR_ADD to ensure this
5082 // doesn't happen.
5083 LHSPtrAdd->moveBefore(MovePos: &MI);
5084 Register RHSReg = MI.getOffsetReg();
5085 // set VReg will cause type mismatch if it comes from extend/trunc
5086 auto NewCst = B.buildConstant(Res: MRI.getType(Reg: RHSReg), Val: LHSCstOff->Value);
5087 Observer.changingInstr(MI);
5088 MI.getOperand(i: 2).setReg(NewCst.getReg(Idx: 0));
5089 MI.setFlags(Flags);
5090 Observer.changedInstr(MI);
5091 Observer.changingInstr(MI&: *LHSPtrAdd);
5092 LHSPtrAdd->getOperand(i: 2).setReg(RHSReg);
5093 LHSPtrAdd->setFlags(Flags);
5094 Observer.changedInstr(MI&: *LHSPtrAdd);
5095 };
5096 return !reassociationCanBreakAddressingModePattern(MI);
5097}
5098
5099bool CombinerHelper::matchReassocFoldConstantsInSubTree(
5100 GPtrAdd &MI, MachineInstr *LHS, MachineInstr *RHS,
5101 BuildFnTy &MatchInfo) const {
5102 // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2)
5103 auto *LHSPtrAdd = dyn_cast<GPtrAdd>(Val: LHS);
5104 if (!LHSPtrAdd)
5105 return false;
5106
5107 Register Src2Reg = MI.getOperand(i: 2).getReg();
5108 Register LHSSrc1 = LHSPtrAdd->getBaseReg();
5109 Register LHSSrc2 = LHSPtrAdd->getOffsetReg();
5110 auto C1 = getIConstantVRegVal(VReg: LHSSrc2, MRI);
5111 if (!C1)
5112 return false;
5113 auto C2 = getIConstantVRegVal(VReg: Src2Reg, MRI);
5114 if (!C2)
5115 return false;
5116
5117 // Reassociating nuw additions preserves nuw. If both original G_PTR_ADDs are
5118 // inbounds, reaching the same result in one G_PTR_ADD is also inbounds.
5119 // The nusw constraints are satisfied because imm1+imm2 cannot exceed the
5120 // largest signed integer that fits into the index type, which is the maximum
5121 // size of allocated objects according to the IR Language Reference.
5122 unsigned PtrAddFlags = MI.getFlags();
5123 unsigned LHSPtrAddFlags = LHSPtrAdd->getFlags();
5124 bool IsNoUWrap = PtrAddFlags & LHSPtrAddFlags & MachineInstr::MIFlag::NoUWrap;
5125 bool IsInBounds =
5126 PtrAddFlags & LHSPtrAddFlags & MachineInstr::MIFlag::InBounds;
5127 unsigned Flags = 0;
5128 if (IsNoUWrap)
5129 Flags |= MachineInstr::MIFlag::NoUWrap;
5130 if (IsInBounds) {
5131 Flags |= MachineInstr::MIFlag::InBounds;
5132 Flags |= MachineInstr::MIFlag::NoUSWrap;
5133 }
5134
5135 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5136 auto NewCst = B.buildConstant(Res: MRI.getType(Reg: Src2Reg), Val: *C1 + *C2);
5137 Observer.changingInstr(MI);
5138 MI.getOperand(i: 1).setReg(LHSSrc1);
5139 MI.getOperand(i: 2).setReg(NewCst.getReg(Idx: 0));
5140 MI.setFlags(Flags);
5141 Observer.changedInstr(MI);
5142 };
5143 return !reassociationCanBreakAddressingModePattern(MI);
5144}
5145
5146bool CombinerHelper::matchReassocPtrAdd(MachineInstr &MI,
5147 BuildFnTy &MatchInfo) const {
5148 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
5149 // We're trying to match a few pointer computation patterns here for
5150 // re-association opportunities.
5151 // 1) Isolating a constant operand to be on the RHS, e.g.:
5152 // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C)
5153 //
5154 // 2) Folding two constants in each sub-tree as long as such folding
5155 // doesn't break a legal addressing mode.
5156 // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2)
5157 //
5158 // 3) Move a constant from the LHS of an inner op to the RHS of the outer.
5159 // G_PTR_ADD (G_PTR_ADD X, C), Y) -> G_PTR_ADD (G_PTR_ADD(X, Y), C)
5160 // iif (G_PTR_ADD X, C) has one use.
5161 MachineInstr *LHS = MRI.getVRegDef(Reg: PtrAdd.getBaseReg());
5162 MachineInstr *RHS = MRI.getVRegDef(Reg: PtrAdd.getOffsetReg());
5163
5164 // Try to match example 2.
5165 if (matchReassocFoldConstantsInSubTree(MI&: PtrAdd, LHS, RHS, MatchInfo))
5166 return true;
5167
5168 // Try to match example 3.
5169 if (matchReassocConstantInnerLHS(MI&: PtrAdd, LHS, RHS, MatchInfo))
5170 return true;
5171
5172 // Try to match example 1.
5173 if (matchReassocConstantInnerRHS(MI&: PtrAdd, RHS, MatchInfo))
5174 return true;
5175
5176 return false;
5177}
5178bool CombinerHelper::tryReassocBinOp(unsigned Opc, Register DstReg,
5179 Register OpLHS, Register OpRHS,
5180 BuildFnTy &MatchInfo) const {
5181 LLT OpRHSTy = MRI.getType(Reg: OpRHS);
5182 MachineInstr *OpLHSDef = MRI.getVRegDef(Reg: OpLHS);
5183
5184 if (OpLHSDef->getOpcode() != Opc)
5185 return false;
5186
5187 MachineInstr *OpRHSDef = MRI.getVRegDef(Reg: OpRHS);
5188 Register OpLHSLHS = OpLHSDef->getOperand(i: 1).getReg();
5189 Register OpLHSRHS = OpLHSDef->getOperand(i: 2).getReg();
5190
5191 // If the inner op is (X op C), pull the constant out so it can be folded with
5192 // other constants in the expression tree. Folding is not guaranteed so we
5193 // might have (C1 op C2). In that case do not pull a constant out because it
5194 // won't help and can lead to infinite loops.
5195 if (isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: OpLHSRHS), MRI) &&
5196 !isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: OpLHSLHS), MRI)) {
5197 if (isConstantOrConstantSplatVector(MI&: *OpRHSDef, MRI)) {
5198 // (Opc (Opc X, C1), C2) -> (Opc X, (Opc C1, C2))
5199 MatchInfo = [=](MachineIRBuilder &B) {
5200 auto NewCst = B.buildInstr(Opc, DstOps: {OpRHSTy}, SrcOps: {OpLHSRHS, OpRHS});
5201 B.buildInstr(Opc, DstOps: {DstReg}, SrcOps: {OpLHSLHS, NewCst});
5202 };
5203 return true;
5204 }
5205 if (getTargetLowering().isReassocProfitable(MRI, N0: OpLHS, N1: OpRHS)) {
5206 // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1)
5207 // iff (op x, c1) has one use
5208 MatchInfo = [=](MachineIRBuilder &B) {
5209 auto NewLHSLHS = B.buildInstr(Opc, DstOps: {OpRHSTy}, SrcOps: {OpLHSLHS, OpRHS});
5210 B.buildInstr(Opc, DstOps: {DstReg}, SrcOps: {NewLHSLHS, OpLHSRHS});
5211 };
5212 return true;
5213 }
5214 }
5215
5216 return false;
5217}
5218
5219bool CombinerHelper::matchReassocCommBinOp(MachineInstr &MI,
5220 BuildFnTy &MatchInfo) const {
5221 // We don't check if the reassociation will break a legal addressing mode
5222 // here since pointer arithmetic is handled by G_PTR_ADD.
5223 unsigned Opc = MI.getOpcode();
5224 Register DstReg = MI.getOperand(i: 0).getReg();
5225 Register LHSReg = MI.getOperand(i: 1).getReg();
5226 Register RHSReg = MI.getOperand(i: 2).getReg();
5227
5228 if (tryReassocBinOp(Opc, DstReg, OpLHS: LHSReg, OpRHS: RHSReg, MatchInfo))
5229 return true;
5230 if (tryReassocBinOp(Opc, DstReg, OpLHS: RHSReg, OpRHS: LHSReg, MatchInfo))
5231 return true;
5232 return false;
5233}
5234
5235bool CombinerHelper::matchConstantFoldCastOp(MachineInstr &MI,
5236 APInt &MatchInfo) const {
5237 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
5238 Register SrcOp = MI.getOperand(i: 1).getReg();
5239
5240 if (auto MaybeCst = ConstantFoldCastOp(Opcode: MI.getOpcode(), DstTy, Op0: SrcOp, MRI)) {
5241 MatchInfo = *MaybeCst;
5242 return true;
5243 }
5244
5245 return false;
5246}
5247
5248bool CombinerHelper::matchConstantFoldBinOp(MachineInstr &MI,
5249 APInt &MatchInfo) const {
5250 Register Op1 = MI.getOperand(i: 1).getReg();
5251 Register Op2 = MI.getOperand(i: 2).getReg();
5252 auto MaybeCst = ConstantFoldBinOp(Opcode: MI.getOpcode(), Op1, Op2, MRI);
5253 if (!MaybeCst)
5254 return false;
5255 MatchInfo = *MaybeCst;
5256 return true;
5257}
5258
5259bool CombinerHelper::matchConstantFoldFPBinOp(MachineInstr &MI,
5260 ConstantFP *&MatchInfo) const {
5261 Register Op1 = MI.getOperand(i: 1).getReg();
5262 Register Op2 = MI.getOperand(i: 2).getReg();
5263 auto MaybeCst = ConstantFoldFPBinOp(Opcode: MI.getOpcode(), Op1, Op2, MRI);
5264 if (!MaybeCst)
5265 return false;
5266 MatchInfo =
5267 ConstantFP::get(Context&: MI.getMF()->getFunction().getContext(), V: *MaybeCst);
5268 return true;
5269}
5270
5271bool CombinerHelper::matchConstantFoldFMA(MachineInstr &MI,
5272 ConstantFP *&MatchInfo) const {
5273 assert(MI.getOpcode() == TargetOpcode::G_FMA ||
5274 MI.getOpcode() == TargetOpcode::G_FMAD);
5275 auto [_, Op1, Op2, Op3] = MI.getFirst4Regs();
5276
5277 const ConstantFP *Op3Cst = getConstantFPVRegVal(VReg: Op3, MRI);
5278 if (!Op3Cst)
5279 return false;
5280
5281 const ConstantFP *Op2Cst = getConstantFPVRegVal(VReg: Op2, MRI);
5282 if (!Op2Cst)
5283 return false;
5284
5285 const ConstantFP *Op1Cst = getConstantFPVRegVal(VReg: Op1, MRI);
5286 if (!Op1Cst)
5287 return false;
5288
5289 APFloat Op1F = Op1Cst->getValueAPF();
5290 Op1F.fusedMultiplyAdd(Multiplicand: Op2Cst->getValueAPF(), Addend: Op3Cst->getValueAPF(),
5291 RM: APFloat::rmNearestTiesToEven);
5292 MatchInfo = ConstantFP::get(Context&: MI.getMF()->getFunction().getContext(), V: Op1F);
5293 return true;
5294}
5295
5296bool CombinerHelper::matchNarrowBinopFeedingAnd(
5297 MachineInstr &MI,
5298 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
5299 // Look for a binop feeding into an AND with a mask:
5300 //
5301 // %add = G_ADD %lhs, %rhs
5302 // %and = G_AND %add, 000...11111111
5303 //
5304 // Check if it's possible to perform the binop at a narrower width and zext
5305 // back to the original width like so:
5306 //
5307 // %narrow_lhs = G_TRUNC %lhs
5308 // %narrow_rhs = G_TRUNC %rhs
5309 // %narrow_add = G_ADD %narrow_lhs, %narrow_rhs
5310 // %new_add = G_ZEXT %narrow_add
5311 // %and = G_AND %new_add, 000...11111111
5312 //
5313 // This can allow later combines to eliminate the G_AND if it turns out
5314 // that the mask is irrelevant.
5315 assert(MI.getOpcode() == TargetOpcode::G_AND);
5316 Register Dst = MI.getOperand(i: 0).getReg();
5317 Register AndLHS = MI.getOperand(i: 1).getReg();
5318 Register AndRHS = MI.getOperand(i: 2).getReg();
5319 LLT WideTy = MRI.getType(Reg: Dst);
5320
5321 // If the potential binop has more than one use, then it's possible that one
5322 // of those uses will need its full width.
5323 if (!WideTy.isScalar() || !MRI.hasOneNonDBGUse(RegNo: AndLHS))
5324 return false;
5325
5326 // Check if the LHS feeding the AND is impacted by the high bits that we're
5327 // masking out.
5328 //
5329 // e.g. for 64-bit x, y:
5330 //
5331 // add_64(x, y) & 65535 == zext(add_16(trunc(x), trunc(y))) & 65535
5332 MachineInstr *LHSInst = getDefIgnoringCopies(Reg: AndLHS, MRI);
5333 if (!LHSInst)
5334 return false;
5335 unsigned LHSOpc = LHSInst->getOpcode();
5336 switch (LHSOpc) {
5337 default:
5338 return false;
5339 case TargetOpcode::G_ADD:
5340 case TargetOpcode::G_SUB:
5341 case TargetOpcode::G_MUL:
5342 case TargetOpcode::G_AND:
5343 case TargetOpcode::G_OR:
5344 case TargetOpcode::G_XOR:
5345 break;
5346 }
5347
5348 // Find the mask on the RHS.
5349 auto Cst = getIConstantVRegValWithLookThrough(VReg: AndRHS, MRI);
5350 if (!Cst)
5351 return false;
5352 auto Mask = Cst->Value;
5353 if (!Mask.isMask())
5354 return false;
5355
5356 // No point in combining if there's nothing to truncate.
5357 unsigned NarrowWidth = Mask.countr_one();
5358 if (NarrowWidth == WideTy.getSizeInBits())
5359 return false;
5360 LLT NarrowTy = LLT::scalar(SizeInBits: NarrowWidth);
5361
5362 // Check if adding the zext + truncates could be harmful.
5363 auto &MF = *MI.getMF();
5364 const auto &TLI = getTargetLowering();
5365 LLVMContext &Ctx = MF.getFunction().getContext();
5366 if (!TLI.isTruncateFree(FromTy: WideTy, ToTy: NarrowTy, Ctx) ||
5367 !TLI.isZExtFree(FromTy: NarrowTy, ToTy: WideTy, Ctx))
5368 return false;
5369 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_TRUNC, {NarrowTy, WideTy}}) ||
5370 !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ZEXT, {WideTy, NarrowTy}}))
5371 return false;
5372 Register BinOpLHS = LHSInst->getOperand(i: 1).getReg();
5373 Register BinOpRHS = LHSInst->getOperand(i: 2).getReg();
5374 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5375 auto NarrowLHS = Builder.buildTrunc(Res: NarrowTy, Op: BinOpLHS);
5376 auto NarrowRHS = Builder.buildTrunc(Res: NarrowTy, Op: BinOpRHS);
5377 auto NarrowBinOp =
5378 Builder.buildInstr(Opc: LHSOpc, DstOps: {NarrowTy}, SrcOps: {NarrowLHS, NarrowRHS});
5379 auto Ext = Builder.buildZExt(Res: WideTy, Op: NarrowBinOp);
5380 Observer.changingInstr(MI);
5381 MI.getOperand(i: 1).setReg(Ext.getReg(Idx: 0));
5382 Observer.changedInstr(MI);
5383 };
5384 return true;
5385}
5386
5387bool CombinerHelper::matchMulOBy2(MachineInstr &MI,
5388 BuildFnTy &MatchInfo) const {
5389 unsigned Opc = MI.getOpcode();
5390 assert(Opc == TargetOpcode::G_UMULO || Opc == TargetOpcode::G_SMULO);
5391
5392 if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 2)))
5393 return false;
5394
5395 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5396 Observer.changingInstr(MI);
5397 unsigned NewOpc = Opc == TargetOpcode::G_UMULO ? TargetOpcode::G_UADDO
5398 : TargetOpcode::G_SADDO;
5399 MI.setDesc(Builder.getTII().get(Opcode: NewOpc));
5400 MI.getOperand(i: 3).setReg(MI.getOperand(i: 2).getReg());
5401 Observer.changedInstr(MI);
5402 };
5403 return true;
5404}
5405
5406bool CombinerHelper::matchMulOBy0(MachineInstr &MI,
5407 BuildFnTy &MatchInfo) const {
5408 // (G_*MULO x, 0) -> 0 + no carry out
5409 assert(MI.getOpcode() == TargetOpcode::G_UMULO ||
5410 MI.getOpcode() == TargetOpcode::G_SMULO);
5411 if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 0)))
5412 return false;
5413 Register Dst = MI.getOperand(i: 0).getReg();
5414 Register Carry = MI.getOperand(i: 1).getReg();
5415 if (!isConstantLegalOrBeforeLegalizer(Ty: MRI.getType(Reg: Dst)) ||
5416 !isConstantLegalOrBeforeLegalizer(Ty: MRI.getType(Reg: Carry)))
5417 return false;
5418 MatchInfo = [=](MachineIRBuilder &B) {
5419 B.buildConstant(Res: Dst, Val: 0);
5420 B.buildConstant(Res: Carry, Val: 0);
5421 };
5422 return true;
5423}
5424
5425bool CombinerHelper::matchAddEToAddO(MachineInstr &MI,
5426 BuildFnTy &MatchInfo) const {
5427 // (G_*ADDE x, y, 0) -> (G_*ADDO x, y)
5428 // (G_*SUBE x, y, 0) -> (G_*SUBO x, y)
5429 assert(MI.getOpcode() == TargetOpcode::G_UADDE ||
5430 MI.getOpcode() == TargetOpcode::G_SADDE ||
5431 MI.getOpcode() == TargetOpcode::G_USUBE ||
5432 MI.getOpcode() == TargetOpcode::G_SSUBE);
5433 if (!mi_match(R: MI.getOperand(i: 4).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 0)))
5434 return false;
5435 MatchInfo = [&](MachineIRBuilder &B) {
5436 unsigned NewOpcode;
5437 switch (MI.getOpcode()) {
5438 case TargetOpcode::G_UADDE:
5439 NewOpcode = TargetOpcode::G_UADDO;
5440 break;
5441 case TargetOpcode::G_SADDE:
5442 NewOpcode = TargetOpcode::G_SADDO;
5443 break;
5444 case TargetOpcode::G_USUBE:
5445 NewOpcode = TargetOpcode::G_USUBO;
5446 break;
5447 case TargetOpcode::G_SSUBE:
5448 NewOpcode = TargetOpcode::G_SSUBO;
5449 break;
5450 }
5451 Observer.changingInstr(MI);
5452 MI.setDesc(B.getTII().get(Opcode: NewOpcode));
5453 MI.removeOperand(OpNo: 4);
5454 Observer.changedInstr(MI);
5455 };
5456 return true;
5457}
5458
5459bool CombinerHelper::matchSubAddSameReg(MachineInstr &MI,
5460 BuildFnTy &MatchInfo) const {
5461 assert(MI.getOpcode() == TargetOpcode::G_SUB);
5462 Register Dst = MI.getOperand(i: 0).getReg();
5463 // (x + y) - z -> x (if y == z)
5464 // (x + y) - z -> y (if x == z)
5465 Register X, Y, Z;
5466 if (mi_match(R: Dst, MRI, P: m_GSub(L: m_GAdd(L: m_Reg(R&: X), R: m_Reg(R&: Y)), R: m_Reg(R&: Z)))) {
5467 Register ReplaceReg;
5468 int64_t CstX, CstY;
5469 if (Y == Z || (mi_match(R: Y, MRI, P: m_ICstOrSplat(Cst&: CstY)) &&
5470 mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstY))))
5471 ReplaceReg = X;
5472 else if (X == Z || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) &&
5473 mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX))))
5474 ReplaceReg = Y;
5475 if (ReplaceReg) {
5476 MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Res: Dst, Op: ReplaceReg); };
5477 return true;
5478 }
5479 }
5480
5481 // x - (y + z) -> 0 - y (if x == z)
5482 // x - (y + z) -> 0 - z (if x == y)
5483 if (mi_match(R: Dst, MRI, P: m_GSub(L: m_Reg(R&: X), R: m_GAdd(L: m_Reg(R&: Y), R: m_Reg(R&: Z))))) {
5484 Register ReplaceReg;
5485 int64_t CstX;
5486 if (X == Z || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) &&
5487 mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX))))
5488 ReplaceReg = Y;
5489 else if (X == Y || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) &&
5490 mi_match(R: Y, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX))))
5491 ReplaceReg = Z;
5492 if (ReplaceReg) {
5493 MatchInfo = [=](MachineIRBuilder &B) {
5494 auto Zero = B.buildConstant(Res: MRI.getType(Reg: Dst), Val: 0);
5495 B.buildSub(Dst, Src0: Zero, Src1: ReplaceReg);
5496 };
5497 return true;
5498 }
5499 }
5500 return false;
5501}
5502
5503MachineInstr *CombinerHelper::buildUDivOrURemUsingMul(MachineInstr &MI) const {
5504 unsigned Opcode = MI.getOpcode();
5505 assert(Opcode == TargetOpcode::G_UDIV || Opcode == TargetOpcode::G_UREM);
5506 auto &UDivorRem = cast<GenericMachineInstr>(Val&: MI);
5507 Register Dst = UDivorRem.getReg(Idx: 0);
5508 Register LHS = UDivorRem.getReg(Idx: 1);
5509 Register RHS = UDivorRem.getReg(Idx: 2);
5510 LLT Ty = MRI.getType(Reg: Dst);
5511 LLT ScalarTy = Ty.getScalarType();
5512 const unsigned EltBits = ScalarTy.getScalarSizeInBits();
5513 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
5514 LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType();
5515
5516 auto &MIB = Builder;
5517
5518 bool UseSRL = false;
5519 SmallVector<Register, 16> Shifts, Factors;
5520 auto *RHSDefInstr = cast<GenericMachineInstr>(Val: getDefIgnoringCopies(Reg: RHS, MRI));
5521 bool IsSplat = getIConstantSplatVal(MI: *RHSDefInstr, MRI).has_value();
5522
5523 auto BuildExactUDIVPattern = [&](const Constant *C) {
5524 // Don't recompute inverses for each splat element.
5525 if (IsSplat && !Factors.empty()) {
5526 Shifts.push_back(Elt: Shifts[0]);
5527 Factors.push_back(Elt: Factors[0]);
5528 return true;
5529 }
5530
5531 auto *CI = cast<ConstantInt>(Val: C);
5532 APInt Divisor = CI->getValue();
5533 unsigned Shift = Divisor.countr_zero();
5534 if (Shift) {
5535 Divisor.lshrInPlace(ShiftAmt: Shift);
5536 UseSRL = true;
5537 }
5538
5539 // Calculate the multiplicative inverse modulo BW.
5540 APInt Factor = Divisor.multiplicativeInverse();
5541 Shifts.push_back(Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: Shift).getReg(Idx: 0));
5542 Factors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Factor).getReg(Idx: 0));
5543 return true;
5544 };
5545
5546 if (MI.getFlag(Flag: MachineInstr::MIFlag::IsExact)) {
5547 // Collect all magic values from the build vector.
5548 if (!matchUnaryPredicate(MRI, Reg: RHS, Match: BuildExactUDIVPattern))
5549 llvm_unreachable("Expected unary predicate match to succeed");
5550
5551 Register Shift, Factor;
5552 if (Ty.isVector()) {
5553 Shift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: Shifts).getReg(Idx: 0);
5554 Factor = MIB.buildBuildVector(Res: Ty, Ops: Factors).getReg(Idx: 0);
5555 } else {
5556 Shift = Shifts[0];
5557 Factor = Factors[0];
5558 }
5559
5560 Register Res = LHS;
5561
5562 if (UseSRL)
5563 Res = MIB.buildLShr(Dst: Ty, Src0: Res, Src1: Shift, Flags: MachineInstr::IsExact).getReg(Idx: 0);
5564
5565 return MIB.buildMul(Dst: Ty, Src0: Res, Src1: Factor);
5566 }
5567
5568 unsigned KnownLeadingZeros =
5569 VT ? VT->getKnownBits(R: LHS).countMinLeadingZeros() : 0;
5570
5571 bool UseNPQ = false;
5572 SmallVector<Register, 16> PreShifts, PostShifts, MagicFactors, NPQFactors;
5573 auto BuildUDIVPattern = [&](const Constant *C) {
5574 auto *CI = cast<ConstantInt>(Val: C);
5575 const APInt &Divisor = CI->getValue();
5576
5577 bool SelNPQ = false;
5578 APInt Magic(Divisor.getBitWidth(), 0);
5579 unsigned PreShift = 0, PostShift = 0;
5580
5581 // Magic algorithm doesn't work for division by 1. We need to emit a select
5582 // at the end.
5583 // TODO: Use undef values for divisor of 1.
5584 if (!Divisor.isOne()) {
5585
5586 // UnsignedDivisionByConstantInfo doesn't work correctly if leading zeros
5587 // in the dividend exceeds the leading zeros for the divisor.
5588 UnsignedDivisionByConstantInfo magics =
5589 UnsignedDivisionByConstantInfo::get(
5590 D: Divisor, LeadingZeros: std::min(a: KnownLeadingZeros, b: Divisor.countl_zero()));
5591
5592 Magic = std::move(magics.Magic);
5593
5594 assert(magics.PreShift < Divisor.getBitWidth() &&
5595 "We shouldn't generate an undefined shift!");
5596 assert(magics.PostShift < Divisor.getBitWidth() &&
5597 "We shouldn't generate an undefined shift!");
5598 assert((!magics.IsAdd || magics.PreShift == 0) && "Unexpected pre-shift");
5599 PreShift = magics.PreShift;
5600 PostShift = magics.PostShift;
5601 SelNPQ = magics.IsAdd;
5602 }
5603
5604 PreShifts.push_back(
5605 Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: PreShift).getReg(Idx: 0));
5606 MagicFactors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Magic).getReg(Idx: 0));
5607 NPQFactors.push_back(
5608 Elt: MIB.buildConstant(Res: ScalarTy,
5609 Val: SelNPQ ? APInt::getOneBitSet(numBits: EltBits, BitNo: EltBits - 1)
5610 : APInt::getZero(numBits: EltBits))
5611 .getReg(Idx: 0));
5612 PostShifts.push_back(
5613 Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: PostShift).getReg(Idx: 0));
5614 UseNPQ |= SelNPQ;
5615 return true;
5616 };
5617
5618 // Collect the shifts/magic values from each element.
5619 bool Matched = matchUnaryPredicate(MRI, Reg: RHS, Match: BuildUDIVPattern);
5620 (void)Matched;
5621 assert(Matched && "Expected unary predicate match to succeed");
5622
5623 Register PreShift, PostShift, MagicFactor, NPQFactor;
5624 auto *RHSDef = getOpcodeDef<GBuildVector>(Reg: RHS, MRI);
5625 if (RHSDef) {
5626 PreShift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: PreShifts).getReg(Idx: 0);
5627 MagicFactor = MIB.buildBuildVector(Res: Ty, Ops: MagicFactors).getReg(Idx: 0);
5628 NPQFactor = MIB.buildBuildVector(Res: Ty, Ops: NPQFactors).getReg(Idx: 0);
5629 PostShift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: PostShifts).getReg(Idx: 0);
5630 } else {
5631 assert(MRI.getType(RHS).isScalar() &&
5632 "Non-build_vector operation should have been a scalar");
5633 PreShift = PreShifts[0];
5634 MagicFactor = MagicFactors[0];
5635 PostShift = PostShifts[0];
5636 }
5637
5638 Register Q = LHS;
5639 Q = MIB.buildLShr(Dst: Ty, Src0: Q, Src1: PreShift).getReg(Idx: 0);
5640
5641 // Multiply the numerator (operand 0) by the magic value.
5642 Q = MIB.buildUMulH(Dst: Ty, Src0: Q, Src1: MagicFactor).getReg(Idx: 0);
5643
5644 if (UseNPQ) {
5645 Register NPQ = MIB.buildSub(Dst: Ty, Src0: LHS, Src1: Q).getReg(Idx: 0);
5646
5647 // For vectors we might have a mix of non-NPQ/NPQ paths, so use
5648 // G_UMULH to act as a SRL-by-1 for NPQ, else multiply by zero.
5649 if (Ty.isVector())
5650 NPQ = MIB.buildUMulH(Dst: Ty, Src0: NPQ, Src1: NPQFactor).getReg(Idx: 0);
5651 else
5652 NPQ = MIB.buildLShr(Dst: Ty, Src0: NPQ, Src1: MIB.buildConstant(Res: ShiftAmtTy, Val: 1)).getReg(Idx: 0);
5653
5654 Q = MIB.buildAdd(Dst: Ty, Src0: NPQ, Src1: Q).getReg(Idx: 0);
5655 }
5656
5657 Q = MIB.buildLShr(Dst: Ty, Src0: Q, Src1: PostShift).getReg(Idx: 0);
5658 auto One = MIB.buildConstant(Res: Ty, Val: 1);
5659 auto IsOne = MIB.buildICmp(
5660 Pred: CmpInst::Predicate::ICMP_EQ,
5661 Res: Ty.isScalar() ? LLT::scalar(SizeInBits: 1) : Ty.changeElementSize(NewEltSize: 1), Op0: RHS, Op1: One);
5662 auto ret = MIB.buildSelect(Res: Ty, Tst: IsOne, Op0: LHS, Op1: Q);
5663
5664 if (Opcode == TargetOpcode::G_UREM) {
5665 auto Prod = MIB.buildMul(Dst: Ty, Src0: ret, Src1: RHS);
5666 return MIB.buildSub(Dst: Ty, Src0: LHS, Src1: Prod);
5667 }
5668 return ret;
5669}
5670
5671bool CombinerHelper::matchUDivOrURemByConst(MachineInstr &MI) const {
5672 unsigned Opcode = MI.getOpcode();
5673 assert(Opcode == TargetOpcode::G_UDIV || Opcode == TargetOpcode::G_UREM);
5674 Register Dst = MI.getOperand(i: 0).getReg();
5675 Register RHS = MI.getOperand(i: 2).getReg();
5676 LLT DstTy = MRI.getType(Reg: Dst);
5677
5678 auto &MF = *MI.getMF();
5679 AttributeList Attr = MF.getFunction().getAttributes();
5680 const auto &TLI = getTargetLowering();
5681 LLVMContext &Ctx = MF.getFunction().getContext();
5682 if (TLI.isIntDivCheap(VT: getApproximateEVTForLLT(Ty: DstTy, Ctx), Attr))
5683 return false;
5684
5685 // Don't do this for minsize because the instruction sequence is usually
5686 // larger.
5687 if (MF.getFunction().hasMinSize())
5688 return false;
5689
5690 if (Opcode == TargetOpcode::G_UDIV &&
5691 MI.getFlag(Flag: MachineInstr::MIFlag::IsExact)) {
5692 return matchUnaryPredicate(
5693 MRI, Reg: RHS, Match: [](const Constant *C) { return C && !C->isNullValue(); });
5694 }
5695
5696 auto *RHSDef = MRI.getVRegDef(Reg: RHS);
5697 if (!isConstantOrConstantVector(MI&: *RHSDef, MRI))
5698 return false;
5699
5700 // Don't do this if the types are not going to be legal.
5701 if (LI) {
5702 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_MUL, {DstTy, DstTy}}))
5703 return false;
5704 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_UMULH, {DstTy}}))
5705 return false;
5706 if (!isLegalOrBeforeLegalizer(
5707 Query: {TargetOpcode::G_ICMP,
5708 {DstTy.isVector() ? DstTy.changeElementSize(NewEltSize: 1) : LLT::scalar(SizeInBits: 1),
5709 DstTy}}))
5710 return false;
5711 if (Opcode == TargetOpcode::G_UREM &&
5712 !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SUB, {DstTy, DstTy}}))
5713 return false;
5714 }
5715
5716 return matchUnaryPredicate(
5717 MRI, Reg: RHS, Match: [](const Constant *C) { return C && !C->isNullValue(); });
5718}
5719
5720void CombinerHelper::applyUDivOrURemByConst(MachineInstr &MI) const {
5721 auto *NewMI = buildUDivOrURemUsingMul(MI);
5722 replaceSingleDefInstWithReg(MI, Replacement: NewMI->getOperand(i: 0).getReg());
5723}
5724
5725bool CombinerHelper::matchSDivOrSRemByConst(MachineInstr &MI) const {
5726 unsigned Opcode = MI.getOpcode();
5727 assert(Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_SREM);
5728 Register Dst = MI.getOperand(i: 0).getReg();
5729 Register RHS = MI.getOperand(i: 2).getReg();
5730 LLT DstTy = MRI.getType(Reg: Dst);
5731 auto SizeInBits = DstTy.getScalarSizeInBits();
5732 LLT WideTy = DstTy.changeElementSize(NewEltSize: SizeInBits * 2);
5733
5734 auto &MF = *MI.getMF();
5735 AttributeList Attr = MF.getFunction().getAttributes();
5736 const auto &TLI = getTargetLowering();
5737 LLVMContext &Ctx = MF.getFunction().getContext();
5738 if (TLI.isIntDivCheap(VT: getApproximateEVTForLLT(Ty: DstTy, Ctx), Attr))
5739 return false;
5740
5741 // Don't do this for minsize because the instruction sequence is usually
5742 // larger.
5743 if (MF.getFunction().hasMinSize())
5744 return false;
5745
5746 // If the sdiv has an 'exact' flag we can use a simpler lowering.
5747 if (Opcode == TargetOpcode::G_SDIV &&
5748 MI.getFlag(Flag: MachineInstr::MIFlag::IsExact)) {
5749 return matchUnaryPredicate(
5750 MRI, Reg: RHS, Match: [](const Constant *C) { return C && !C->isNullValue(); });
5751 }
5752
5753 auto *RHSDef = MRI.getVRegDef(Reg: RHS);
5754 if (!isConstantOrConstantVector(MI&: *RHSDef, MRI))
5755 return false;
5756
5757 // Don't do this if the types are not going to be legal.
5758 if (LI) {
5759 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_MUL, {DstTy, DstTy}}))
5760 return false;
5761 if (!isLegal(Query: {TargetOpcode::G_SMULH, {DstTy}}) &&
5762 !isLegalOrHasWidenScalar(Query: {TargetOpcode::G_MUL, {WideTy, WideTy}}))
5763 return false;
5764 if (Opcode == TargetOpcode::G_SREM &&
5765 !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SUB, {DstTy, DstTy}}))
5766 return false;
5767 }
5768
5769 return matchUnaryPredicate(
5770 MRI, Reg: RHS, Match: [](const Constant *C) { return C && !C->isNullValue(); });
5771}
5772
5773void CombinerHelper::applySDivOrSRemByConst(MachineInstr &MI) const {
5774 auto *NewMI = buildSDivOrSRemUsingMul(MI);
5775 replaceSingleDefInstWithReg(MI, Replacement: NewMI->getOperand(i: 0).getReg());
5776}
5777
5778MachineInstr *CombinerHelper::buildSDivOrSRemUsingMul(MachineInstr &MI) const {
5779 unsigned Opcode = MI.getOpcode();
5780 assert(MI.getOpcode() == TargetOpcode::G_SDIV ||
5781 Opcode == TargetOpcode::G_SREM);
5782 auto &SDivorRem = cast<GenericMachineInstr>(Val&: MI);
5783 Register Dst = SDivorRem.getReg(Idx: 0);
5784 Register LHS = SDivorRem.getReg(Idx: 1);
5785 Register RHS = SDivorRem.getReg(Idx: 2);
5786 LLT Ty = MRI.getType(Reg: Dst);
5787 LLT ScalarTy = Ty.getScalarType();
5788 const unsigned EltBits = ScalarTy.getScalarSizeInBits();
5789 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
5790 LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType();
5791 auto &MIB = Builder;
5792
5793 bool UseSRA = false;
5794 SmallVector<Register, 16> ExactShifts, ExactFactors;
5795
5796 auto *RHSDefInstr = cast<GenericMachineInstr>(Val: getDefIgnoringCopies(Reg: RHS, MRI));
5797 bool IsSplat = getIConstantSplatVal(MI: *RHSDefInstr, MRI).has_value();
5798
5799 auto BuildExactSDIVPattern = [&](const Constant *C) {
5800 // Don't recompute inverses for each splat element.
5801 if (IsSplat && !ExactFactors.empty()) {
5802 ExactShifts.push_back(Elt: ExactShifts[0]);
5803 ExactFactors.push_back(Elt: ExactFactors[0]);
5804 return true;
5805 }
5806
5807 auto *CI = cast<ConstantInt>(Val: C);
5808 APInt Divisor = CI->getValue();
5809 unsigned Shift = Divisor.countr_zero();
5810 if (Shift) {
5811 Divisor.ashrInPlace(ShiftAmt: Shift);
5812 UseSRA = true;
5813 }
5814
5815 // Calculate the multiplicative inverse modulo BW.
5816 // 2^W requires W + 1 bits, so we have to extend and then truncate.
5817 APInt Factor = Divisor.multiplicativeInverse();
5818 ExactShifts.push_back(Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: Shift).getReg(Idx: 0));
5819 ExactFactors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Factor).getReg(Idx: 0));
5820 return true;
5821 };
5822
5823 if (MI.getFlag(Flag: MachineInstr::MIFlag::IsExact)) {
5824 // Collect all magic values from the build vector.
5825 bool Matched = matchUnaryPredicate(MRI, Reg: RHS, Match: BuildExactSDIVPattern);
5826 (void)Matched;
5827 assert(Matched && "Expected unary predicate match to succeed");
5828
5829 Register Shift, Factor;
5830 if (Ty.isVector()) {
5831 Shift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: ExactShifts).getReg(Idx: 0);
5832 Factor = MIB.buildBuildVector(Res: Ty, Ops: ExactFactors).getReg(Idx: 0);
5833 } else {
5834 Shift = ExactShifts[0];
5835 Factor = ExactFactors[0];
5836 }
5837
5838 Register Res = LHS;
5839
5840 if (UseSRA)
5841 Res = MIB.buildAShr(Dst: Ty, Src0: Res, Src1: Shift, Flags: MachineInstr::IsExact).getReg(Idx: 0);
5842
5843 return MIB.buildMul(Dst: Ty, Src0: Res, Src1: Factor);
5844 }
5845
5846 SmallVector<Register, 16> MagicFactors, Factors, Shifts, ShiftMasks;
5847
5848 auto BuildSDIVPattern = [&](const Constant *C) {
5849 auto *CI = cast<ConstantInt>(Val: C);
5850 const APInt &Divisor = CI->getValue();
5851
5852 SignedDivisionByConstantInfo Magics =
5853 SignedDivisionByConstantInfo::get(D: Divisor);
5854 int NumeratorFactor = 0;
5855 int ShiftMask = -1;
5856
5857 if (Divisor.isOne() || Divisor.isAllOnes()) {
5858 // If d is +1/-1, we just multiply the numerator by +1/-1.
5859 NumeratorFactor = Divisor.getSExtValue();
5860 Magics.Magic = 0;
5861 Magics.ShiftAmount = 0;
5862 ShiftMask = 0;
5863 } else if (Divisor.isStrictlyPositive() && Magics.Magic.isNegative()) {
5864 // If d > 0 and m < 0, add the numerator.
5865 NumeratorFactor = 1;
5866 } else if (Divisor.isNegative() && Magics.Magic.isStrictlyPositive()) {
5867 // If d < 0 and m > 0, subtract the numerator.
5868 NumeratorFactor = -1;
5869 }
5870
5871 MagicFactors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Magics.Magic).getReg(Idx: 0));
5872 Factors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: NumeratorFactor).getReg(Idx: 0));
5873 Shifts.push_back(
5874 Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: Magics.ShiftAmount).getReg(Idx: 0));
5875 ShiftMasks.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: ShiftMask).getReg(Idx: 0));
5876
5877 return true;
5878 };
5879
5880 // Collect the shifts/magic values from each element.
5881 bool Matched = matchUnaryPredicate(MRI, Reg: RHS, Match: BuildSDIVPattern);
5882 (void)Matched;
5883 assert(Matched && "Expected unary predicate match to succeed");
5884
5885 Register MagicFactor, Factor, Shift, ShiftMask;
5886 auto *RHSDef = getOpcodeDef<GBuildVector>(Reg: RHS, MRI);
5887 if (RHSDef) {
5888 MagicFactor = MIB.buildBuildVector(Res: Ty, Ops: MagicFactors).getReg(Idx: 0);
5889 Factor = MIB.buildBuildVector(Res: Ty, Ops: Factors).getReg(Idx: 0);
5890 Shift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: Shifts).getReg(Idx: 0);
5891 ShiftMask = MIB.buildBuildVector(Res: Ty, Ops: ShiftMasks).getReg(Idx: 0);
5892 } else {
5893 assert(MRI.getType(RHS).isScalar() &&
5894 "Non-build_vector operation should have been a scalar");
5895 MagicFactor = MagicFactors[0];
5896 Factor = Factors[0];
5897 Shift = Shifts[0];
5898 ShiftMask = ShiftMasks[0];
5899 }
5900
5901 Register Q = LHS;
5902 Q = MIB.buildSMulH(Dst: Ty, Src0: LHS, Src1: MagicFactor).getReg(Idx: 0);
5903
5904 // (Optionally) Add/subtract the numerator using Factor.
5905 Factor = MIB.buildMul(Dst: Ty, Src0: LHS, Src1: Factor).getReg(Idx: 0);
5906 Q = MIB.buildAdd(Dst: Ty, Src0: Q, Src1: Factor).getReg(Idx: 0);
5907
5908 // Shift right algebraic by shift value.
5909 Q = MIB.buildAShr(Dst: Ty, Src0: Q, Src1: Shift).getReg(Idx: 0);
5910
5911 // Extract the sign bit, mask it and add it to the quotient.
5912 auto SignShift = MIB.buildConstant(Res: ShiftAmtTy, Val: EltBits - 1);
5913 auto T = MIB.buildLShr(Dst: Ty, Src0: Q, Src1: SignShift);
5914 T = MIB.buildAnd(Dst: Ty, Src0: T, Src1: ShiftMask);
5915 auto ret = MIB.buildAdd(Dst: Ty, Src0: Q, Src1: T);
5916
5917 if (Opcode == TargetOpcode::G_SREM) {
5918 auto Prod = MIB.buildMul(Dst: Ty, Src0: ret, Src1: RHS);
5919 return MIB.buildSub(Dst: Ty, Src0: LHS, Src1: Prod);
5920 }
5921 return ret;
5922}
5923
5924bool CombinerHelper::matchDivByPow2(MachineInstr &MI, bool IsSigned) const {
5925 assert((MI.getOpcode() == TargetOpcode::G_SDIV ||
5926 MI.getOpcode() == TargetOpcode::G_UDIV) &&
5927 "Expected SDIV or UDIV");
5928 auto &Div = cast<GenericMachineInstr>(Val&: MI);
5929 Register RHS = Div.getReg(Idx: 2);
5930 auto MatchPow2 = [&](const Constant *C) {
5931 auto *CI = dyn_cast<ConstantInt>(Val: C);
5932 return CI && (CI->getValue().isPowerOf2() ||
5933 (IsSigned && CI->getValue().isNegatedPowerOf2()));
5934 };
5935 return matchUnaryPredicate(MRI, Reg: RHS, Match: MatchPow2, /*AllowUndefs=*/false);
5936}
5937
5938void CombinerHelper::applySDivByPow2(MachineInstr &MI) const {
5939 assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV");
5940 auto &SDiv = cast<GenericMachineInstr>(Val&: MI);
5941 Register Dst = SDiv.getReg(Idx: 0);
5942 Register LHS = SDiv.getReg(Idx: 1);
5943 Register RHS = SDiv.getReg(Idx: 2);
5944 LLT Ty = MRI.getType(Reg: Dst);
5945 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
5946 LLT CCVT =
5947 Ty.isVector() ? LLT::vector(EC: Ty.getElementCount(), ScalarSizeInBits: 1) : LLT::scalar(SizeInBits: 1);
5948
5949 // Effectively we want to lower G_SDIV %lhs, %rhs, where %rhs is a power of 2,
5950 // to the following version:
5951 //
5952 // %c1 = G_CTTZ %rhs
5953 // %inexact = G_SUB $bitwidth, %c1
5954 // %sign = %G_ASHR %lhs, $(bitwidth - 1)
5955 // %lshr = G_LSHR %sign, %inexact
5956 // %add = G_ADD %lhs, %lshr
5957 // %ashr = G_ASHR %add, %c1
5958 // %ashr = G_SELECT, %isoneorallones, %lhs, %ashr
5959 // %zero = G_CONSTANT $0
5960 // %neg = G_NEG %ashr
5961 // %isneg = G_ICMP SLT %rhs, %zero
5962 // %res = G_SELECT %isneg, %neg, %ashr
5963
5964 unsigned BitWidth = Ty.getScalarSizeInBits();
5965 auto Zero = Builder.buildConstant(Res: Ty, Val: 0);
5966
5967 auto Bits = Builder.buildConstant(Res: ShiftAmtTy, Val: BitWidth);
5968 auto C1 = Builder.buildCTTZ(Dst: ShiftAmtTy, Src0: RHS);
5969 auto Inexact = Builder.buildSub(Dst: ShiftAmtTy, Src0: Bits, Src1: C1);
5970 // Splat the sign bit into the register
5971 auto Sign = Builder.buildAShr(
5972 Dst: Ty, Src0: LHS, Src1: Builder.buildConstant(Res: ShiftAmtTy, Val: BitWidth - 1));
5973
5974 // Add (LHS < 0) ? abs2 - 1 : 0;
5975 auto LSrl = Builder.buildLShr(Dst: Ty, Src0: Sign, Src1: Inexact);
5976 auto Add = Builder.buildAdd(Dst: Ty, Src0: LHS, Src1: LSrl);
5977 auto AShr = Builder.buildAShr(Dst: Ty, Src0: Add, Src1: C1);
5978
5979 // Special case: (sdiv X, 1) -> X
5980 // Special Case: (sdiv X, -1) -> 0-X
5981 auto One = Builder.buildConstant(Res: Ty, Val: 1);
5982 auto MinusOne = Builder.buildConstant(Res: Ty, Val: -1);
5983 auto IsOne = Builder.buildICmp(Pred: CmpInst::Predicate::ICMP_EQ, Res: CCVT, Op0: RHS, Op1: One);
5984 auto IsMinusOne =
5985 Builder.buildICmp(Pred: CmpInst::Predicate::ICMP_EQ, Res: CCVT, Op0: RHS, Op1: MinusOne);
5986 auto IsOneOrMinusOne = Builder.buildOr(Dst: CCVT, Src0: IsOne, Src1: IsMinusOne);
5987 AShr = Builder.buildSelect(Res: Ty, Tst: IsOneOrMinusOne, Op0: LHS, Op1: AShr);
5988
5989 // If divided by a positive value, we're done. Otherwise, the result must be
5990 // negated.
5991 auto Neg = Builder.buildNeg(Dst: Ty, Src0: AShr);
5992 auto IsNeg = Builder.buildICmp(Pred: CmpInst::Predicate::ICMP_SLT, Res: CCVT, Op0: RHS, Op1: Zero);
5993 Builder.buildSelect(Res: MI.getOperand(i: 0).getReg(), Tst: IsNeg, Op0: Neg, Op1: AShr);
5994 MI.eraseFromParent();
5995}
5996
5997void CombinerHelper::applyUDivByPow2(MachineInstr &MI) const {
5998 assert(MI.getOpcode() == TargetOpcode::G_UDIV && "Expected UDIV");
5999 auto &UDiv = cast<GenericMachineInstr>(Val&: MI);
6000 Register Dst = UDiv.getReg(Idx: 0);
6001 Register LHS = UDiv.getReg(Idx: 1);
6002 Register RHS = UDiv.getReg(Idx: 2);
6003 LLT Ty = MRI.getType(Reg: Dst);
6004 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
6005
6006 auto C1 = Builder.buildCTTZ(Dst: ShiftAmtTy, Src0: RHS);
6007 Builder.buildLShr(Dst: MI.getOperand(i: 0).getReg(), Src0: LHS, Src1: C1);
6008 MI.eraseFromParent();
6009}
6010
6011bool CombinerHelper::matchUMulHToLShr(MachineInstr &MI) const {
6012 assert(MI.getOpcode() == TargetOpcode::G_UMULH);
6013 Register RHS = MI.getOperand(i: 2).getReg();
6014 Register Dst = MI.getOperand(i: 0).getReg();
6015 LLT Ty = MRI.getType(Reg: Dst);
6016 LLT RHSTy = MRI.getType(Reg: RHS);
6017 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
6018 auto MatchPow2ExceptOne = [&](const Constant *C) {
6019 if (auto *CI = dyn_cast<ConstantInt>(Val: C))
6020 return CI->getValue().isPowerOf2() && !CI->getValue().isOne();
6021 return false;
6022 };
6023 if (!matchUnaryPredicate(MRI, Reg: RHS, Match: MatchPow2ExceptOne, AllowUndefs: false))
6024 return false;
6025 // We need to check both G_LSHR and G_CTLZ because the combine uses G_CTLZ to
6026 // get log base 2, and it is not always legal for on a target.
6027 return isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_LSHR, {Ty, ShiftAmtTy}}) &&
6028 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_CTLZ, {RHSTy, RHSTy}});
6029}
6030
6031void CombinerHelper::applyUMulHToLShr(MachineInstr &MI) const {
6032 Register LHS = MI.getOperand(i: 1).getReg();
6033 Register RHS = MI.getOperand(i: 2).getReg();
6034 Register Dst = MI.getOperand(i: 0).getReg();
6035 LLT Ty = MRI.getType(Reg: Dst);
6036 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
6037 unsigned NumEltBits = Ty.getScalarSizeInBits();
6038
6039 auto LogBase2 = buildLogBase2(V: RHS, MIB&: Builder);
6040 auto ShiftAmt =
6041 Builder.buildSub(Dst: Ty, Src0: Builder.buildConstant(Res: Ty, Val: NumEltBits), Src1: LogBase2);
6042 auto Trunc = Builder.buildZExtOrTrunc(Res: ShiftAmtTy, Op: ShiftAmt);
6043 Builder.buildLShr(Dst, Src0: LHS, Src1: Trunc);
6044 MI.eraseFromParent();
6045}
6046
6047bool CombinerHelper::matchTruncSSatS(MachineInstr &MI,
6048 Register &MatchInfo) const {
6049 Register Dst = MI.getOperand(i: 0).getReg();
6050 Register Src = MI.getOperand(i: 1).getReg();
6051 LLT DstTy = MRI.getType(Reg: Dst);
6052 LLT SrcTy = MRI.getType(Reg: Src);
6053 unsigned NumDstBits = DstTy.getScalarSizeInBits();
6054 unsigned NumSrcBits = SrcTy.getScalarSizeInBits();
6055 assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
6056
6057 if (!LI || !isLegalOrHasFewerElements(
6058 Query: {TargetOpcode::G_TRUNC_SSAT_S, {DstTy, SrcTy}}))
6059 return false;
6060
6061 APInt SignedMax = APInt::getSignedMaxValue(numBits: NumDstBits).sext(width: NumSrcBits);
6062 APInt SignedMin = APInt::getSignedMinValue(numBits: NumDstBits).sext(width: NumSrcBits);
6063 return mi_match(R: Src, MRI,
6064 P: m_GSMin(L: m_GSMax(L: m_Reg(R&: MatchInfo),
6065 R: m_SpecificICstOrSplat(RequestedValue: SignedMin)),
6066 R: m_SpecificICstOrSplat(RequestedValue: SignedMax))) ||
6067 mi_match(R: Src, MRI,
6068 P: m_GSMax(L: m_GSMin(L: m_Reg(R&: MatchInfo),
6069 R: m_SpecificICstOrSplat(RequestedValue: SignedMax)),
6070 R: m_SpecificICstOrSplat(RequestedValue: SignedMin)));
6071}
6072
6073void CombinerHelper::applyTruncSSatS(MachineInstr &MI,
6074 Register &MatchInfo) const {
6075 Register Dst = MI.getOperand(i: 0).getReg();
6076 Builder.buildTruncSSatS(Res: Dst, Op: MatchInfo);
6077 MI.eraseFromParent();
6078}
6079
6080bool CombinerHelper::matchTruncSSatU(MachineInstr &MI,
6081 Register &MatchInfo) const {
6082 Register Dst = MI.getOperand(i: 0).getReg();
6083 Register Src = MI.getOperand(i: 1).getReg();
6084 LLT DstTy = MRI.getType(Reg: Dst);
6085 LLT SrcTy = MRI.getType(Reg: Src);
6086 unsigned NumDstBits = DstTy.getScalarSizeInBits();
6087 unsigned NumSrcBits = SrcTy.getScalarSizeInBits();
6088 assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
6089
6090 if (!LI || !isLegalOrHasFewerElements(
6091 Query: {TargetOpcode::G_TRUNC_SSAT_U, {DstTy, SrcTy}}))
6092 return false;
6093 APInt UnsignedMax = APInt::getMaxValue(numBits: NumDstBits).zext(width: NumSrcBits);
6094 return mi_match(R: Src, MRI,
6095 P: m_GSMin(L: m_GSMax(L: m_Reg(R&: MatchInfo), R: m_SpecificICstOrSplat(RequestedValue: 0)),
6096 R: m_SpecificICstOrSplat(RequestedValue: UnsignedMax))) ||
6097 mi_match(R: Src, MRI,
6098 P: m_GSMax(L: m_GSMin(L: m_Reg(R&: MatchInfo),
6099 R: m_SpecificICstOrSplat(RequestedValue: UnsignedMax)),
6100 R: m_SpecificICstOrSplat(RequestedValue: 0))) ||
6101 mi_match(R: Src, MRI,
6102 P: m_GUMin(L: m_GSMax(L: m_Reg(R&: MatchInfo), R: m_SpecificICstOrSplat(RequestedValue: 0)),
6103 R: m_SpecificICstOrSplat(RequestedValue: UnsignedMax)));
6104}
6105
6106void CombinerHelper::applyTruncSSatU(MachineInstr &MI,
6107 Register &MatchInfo) const {
6108 Register Dst = MI.getOperand(i: 0).getReg();
6109 Builder.buildTruncSSatU(Res: Dst, Op: MatchInfo);
6110 MI.eraseFromParent();
6111}
6112
6113bool CombinerHelper::matchTruncUSatU(MachineInstr &MI,
6114 MachineInstr &MinMI) const {
6115 Register Min = MinMI.getOperand(i: 2).getReg();
6116 Register Val = MinMI.getOperand(i: 1).getReg();
6117 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6118 LLT SrcTy = MRI.getType(Reg: Val);
6119 unsigned NumDstBits = DstTy.getScalarSizeInBits();
6120 unsigned NumSrcBits = SrcTy.getScalarSizeInBits();
6121 assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
6122
6123 if (!LI || !isLegalOrHasFewerElements(
6124 Query: {TargetOpcode::G_TRUNC_SSAT_U, {DstTy, SrcTy}}))
6125 return false;
6126 APInt UnsignedMax = APInt::getMaxValue(numBits: NumDstBits).zext(width: NumSrcBits);
6127 return mi_match(R: Min, MRI, P: m_SpecificICstOrSplat(RequestedValue: UnsignedMax)) &&
6128 !mi_match(R: Val, MRI, P: m_GSMax(L: m_Reg(), R: m_Reg()));
6129}
6130
6131bool CombinerHelper::matchTruncUSatUToFPTOUISat(MachineInstr &MI,
6132 MachineInstr &SrcMI) const {
6133 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6134 LLT SrcTy = MRI.getType(Reg: SrcMI.getOperand(i: 1).getReg());
6135
6136 return LI &&
6137 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FPTOUI_SAT, {DstTy, SrcTy}});
6138}
6139
6140bool CombinerHelper::matchRedundantNegOperands(MachineInstr &MI,
6141 BuildFnTy &MatchInfo) const {
6142 unsigned Opc = MI.getOpcode();
6143 assert(Opc == TargetOpcode::G_FADD || Opc == TargetOpcode::G_FSUB ||
6144 Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV ||
6145 Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA);
6146
6147 Register Dst = MI.getOperand(i: 0).getReg();
6148 Register X = MI.getOperand(i: 1).getReg();
6149 Register Y = MI.getOperand(i: 2).getReg();
6150 LLT Type = MRI.getType(Reg: Dst);
6151
6152 // fold (fadd x, fneg(y)) -> (fsub x, y)
6153 // fold (fadd fneg(y), x) -> (fsub x, y)
6154 // G_ADD is commutative so both cases are checked by m_GFAdd
6155 if (mi_match(R: Dst, MRI, P: m_GFAdd(L: m_Reg(R&: X), R: m_GFNeg(Src: m_Reg(R&: Y)))) &&
6156 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FSUB, {Type}})) {
6157 Opc = TargetOpcode::G_FSUB;
6158 }
6159 /// fold (fsub x, fneg(y)) -> (fadd x, y)
6160 else if (mi_match(R: Dst, MRI, P: m_GFSub(L: m_Reg(R&: X), R: m_GFNeg(Src: m_Reg(R&: Y)))) &&
6161 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FADD, {Type}})) {
6162 Opc = TargetOpcode::G_FADD;
6163 }
6164 // fold (fmul fneg(x), fneg(y)) -> (fmul x, y)
6165 // fold (fdiv fneg(x), fneg(y)) -> (fdiv x, y)
6166 // fold (fmad fneg(x), fneg(y), z) -> (fmad x, y, z)
6167 // fold (fma fneg(x), fneg(y), z) -> (fma x, y, z)
6168 else if ((Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV ||
6169 Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA) &&
6170 mi_match(R: X, MRI, P: m_GFNeg(Src: m_Reg(R&: X))) &&
6171 mi_match(R: Y, MRI, P: m_GFNeg(Src: m_Reg(R&: Y)))) {
6172 // no opcode change
6173 } else
6174 return false;
6175
6176 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6177 Observer.changingInstr(MI);
6178 MI.setDesc(B.getTII().get(Opcode: Opc));
6179 MI.getOperand(i: 1).setReg(X);
6180 MI.getOperand(i: 2).setReg(Y);
6181 Observer.changedInstr(MI);
6182 };
6183 return true;
6184}
6185
6186bool CombinerHelper::matchFsubToFneg(MachineInstr &MI,
6187 Register &MatchInfo) const {
6188 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
6189
6190 Register LHS = MI.getOperand(i: 1).getReg();
6191 MatchInfo = MI.getOperand(i: 2).getReg();
6192 LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6193
6194 const auto LHSCst = Ty.isVector()
6195 ? getFConstantSplat(VReg: LHS, MRI, /* allowUndef */ AllowUndef: true)
6196 : getFConstantVRegValWithLookThrough(VReg: LHS, MRI);
6197 if (!LHSCst)
6198 return false;
6199
6200 // -0.0 is always allowed
6201 if (LHSCst->Value.isNegZero())
6202 return true;
6203
6204 // +0.0 is only allowed if nsz is set.
6205 if (LHSCst->Value.isPosZero())
6206 return MI.getFlag(Flag: MachineInstr::FmNsz);
6207
6208 return false;
6209}
6210
6211void CombinerHelper::applyFsubToFneg(MachineInstr &MI,
6212 Register &MatchInfo) const {
6213 Register Dst = MI.getOperand(i: 0).getReg();
6214 Builder.buildFNeg(
6215 Dst, Src0: Builder.buildFCanonicalize(Dst: MRI.getType(Reg: Dst), Src0: MatchInfo).getReg(Idx: 0));
6216 eraseInst(MI);
6217}
6218
6219/// Checks if \p MI is TargetOpcode::G_FMUL and contractable either
6220/// due to global flags or MachineInstr flags.
6221static bool isContractableFMul(MachineInstr &MI, bool AllowFusionGlobally) {
6222 if (MI.getOpcode() != TargetOpcode::G_FMUL)
6223 return false;
6224 return AllowFusionGlobally || MI.getFlag(Flag: MachineInstr::MIFlag::FmContract);
6225}
6226
6227static bool hasMoreUses(const MachineInstr &MI0, const MachineInstr &MI1,
6228 const MachineRegisterInfo &MRI) {
6229 return std::distance(first: MRI.use_instr_nodbg_begin(RegNo: MI0.getOperand(i: 0).getReg()),
6230 last: MRI.use_instr_nodbg_end()) >
6231 std::distance(first: MRI.use_instr_nodbg_begin(RegNo: MI1.getOperand(i: 0).getReg()),
6232 last: MRI.use_instr_nodbg_end());
6233}
6234
6235bool CombinerHelper::canCombineFMadOrFMA(MachineInstr &MI,
6236 bool &AllowFusionGlobally,
6237 bool &HasFMAD, bool &Aggressive,
6238 bool CanReassociate) const {
6239
6240 auto *MF = MI.getMF();
6241 const auto &TLI = *MF->getSubtarget().getTargetLowering();
6242 const TargetOptions &Options = MF->getTarget().Options;
6243 LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6244
6245 if (CanReassociate && !MI.getFlag(Flag: MachineInstr::MIFlag::FmReassoc))
6246 return false;
6247
6248 // Floating-point multiply-add with intermediate rounding.
6249 HasFMAD = (!isPreLegalize() && TLI.isFMADLegal(MI, Ty: DstType));
6250 // Floating-point multiply-add without intermediate rounding.
6251 bool HasFMA = TLI.isFMAFasterThanFMulAndFAdd(MF: *MF, DstType) &&
6252 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FMA, {DstType}});
6253 // No valid opcode, do not combine.
6254 if (!HasFMAD && !HasFMA)
6255 return false;
6256
6257 AllowFusionGlobally = Options.AllowFPOpFusion == FPOpFusion::Fast || HasFMAD;
6258 // If the addition is not contractable, do not combine.
6259 if (!AllowFusionGlobally && !MI.getFlag(Flag: MachineInstr::MIFlag::FmContract))
6260 return false;
6261
6262 Aggressive = TLI.enableAggressiveFMAFusion(Ty: DstType);
6263 return true;
6264}
6265
6266bool CombinerHelper::matchCombineFAddFMulToFMadOrFMA(
6267 MachineInstr &MI,
6268 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6269 assert(MI.getOpcode() == TargetOpcode::G_FADD);
6270
6271 bool AllowFusionGlobally, HasFMAD, Aggressive;
6272 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6273 return false;
6274
6275 Register Op1 = MI.getOperand(i: 1).getReg();
6276 Register Op2 = MI.getOperand(i: 2).getReg();
6277 DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1};
6278 DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2};
6279 unsigned PreferredFusedOpcode =
6280 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6281
6282 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
6283 // prefer to fold the multiply with fewer uses.
6284 if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6285 isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) {
6286 if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI))
6287 std::swap(a&: LHS, b&: RHS);
6288 }
6289
6290 // fold (fadd (fmul x, y), z) -> (fma x, y, z)
6291 if (isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6292 (Aggressive || MRI.hasOneNonDBGUse(RegNo: LHS.Reg))) {
6293 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6294 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6295 SrcOps: {LHS.MI->getOperand(i: 1).getReg(),
6296 LHS.MI->getOperand(i: 2).getReg(), RHS.Reg});
6297 };
6298 return true;
6299 }
6300
6301 // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
6302 if (isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally) &&
6303 (Aggressive || MRI.hasOneNonDBGUse(RegNo: RHS.Reg))) {
6304 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6305 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6306 SrcOps: {RHS.MI->getOperand(i: 1).getReg(),
6307 RHS.MI->getOperand(i: 2).getReg(), LHS.Reg});
6308 };
6309 return true;
6310 }
6311
6312 return false;
6313}
6314
6315bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMA(
6316 MachineInstr &MI,
6317 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6318 assert(MI.getOpcode() == TargetOpcode::G_FADD);
6319
6320 bool AllowFusionGlobally, HasFMAD, Aggressive;
6321 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6322 return false;
6323
6324 const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
6325 Register Op1 = MI.getOperand(i: 1).getReg();
6326 Register Op2 = MI.getOperand(i: 2).getReg();
6327 DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1};
6328 DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2};
6329 LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6330
6331 unsigned PreferredFusedOpcode =
6332 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6333
6334 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
6335 // prefer to fold the multiply with fewer uses.
6336 if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6337 isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) {
6338 if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI))
6339 std::swap(a&: LHS, b&: RHS);
6340 }
6341
6342 // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
6343 MachineInstr *FpExtSrc;
6344 if (mi_match(R: LHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FpExtSrc))) &&
6345 isContractableFMul(MI&: *FpExtSrc, AllowFusionGlobally) &&
6346 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
6347 SrcTy: MRI.getType(Reg: FpExtSrc->getOperand(i: 1).getReg()))) {
6348 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6349 auto FpExtX = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 1).getReg());
6350 auto FpExtY = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 2).getReg());
6351 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6352 SrcOps: {FpExtX.getReg(Idx: 0), FpExtY.getReg(Idx: 0), RHS.Reg});
6353 };
6354 return true;
6355 }
6356
6357 // fold (fadd z, (fpext (fmul x, y))) -> (fma (fpext x), (fpext y), z)
6358 // Note: Commutes FADD operands.
6359 if (mi_match(R: RHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FpExtSrc))) &&
6360 isContractableFMul(MI&: *FpExtSrc, AllowFusionGlobally) &&
6361 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
6362 SrcTy: MRI.getType(Reg: FpExtSrc->getOperand(i: 1).getReg()))) {
6363 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6364 auto FpExtX = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 1).getReg());
6365 auto FpExtY = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 2).getReg());
6366 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6367 SrcOps: {FpExtX.getReg(Idx: 0), FpExtY.getReg(Idx: 0), LHS.Reg});
6368 };
6369 return true;
6370 }
6371
6372 return false;
6373}
6374
6375bool CombinerHelper::matchCombineFAddFMAFMulToFMadOrFMA(
6376 MachineInstr &MI,
6377 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6378 assert(MI.getOpcode() == TargetOpcode::G_FADD);
6379
6380 bool AllowFusionGlobally, HasFMAD, Aggressive;
6381 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive, CanReassociate: true))
6382 return false;
6383
6384 Register Op1 = MI.getOperand(i: 1).getReg();
6385 Register Op2 = MI.getOperand(i: 2).getReg();
6386 DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1};
6387 DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2};
6388 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6389
6390 unsigned PreferredFusedOpcode =
6391 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6392
6393 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
6394 // prefer to fold the multiply with fewer uses.
6395 if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6396 isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) {
6397 if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI))
6398 std::swap(a&: LHS, b&: RHS);
6399 }
6400
6401 MachineInstr *FMA = nullptr;
6402 Register Z;
6403 // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y, (fma u, v, z))
6404 if (LHS.MI->getOpcode() == PreferredFusedOpcode &&
6405 (MRI.getVRegDef(Reg: LHS.MI->getOperand(i: 3).getReg())->getOpcode() ==
6406 TargetOpcode::G_FMUL) &&
6407 MRI.hasOneNonDBGUse(RegNo: LHS.MI->getOperand(i: 0).getReg()) &&
6408 MRI.hasOneNonDBGUse(RegNo: LHS.MI->getOperand(i: 3).getReg())) {
6409 FMA = LHS.MI;
6410 Z = RHS.Reg;
6411 }
6412 // fold (fadd z, (fma x, y, (fmul u, v))) -> (fma x, y, (fma u, v, z))
6413 else if (RHS.MI->getOpcode() == PreferredFusedOpcode &&
6414 (MRI.getVRegDef(Reg: RHS.MI->getOperand(i: 3).getReg())->getOpcode() ==
6415 TargetOpcode::G_FMUL) &&
6416 MRI.hasOneNonDBGUse(RegNo: RHS.MI->getOperand(i: 0).getReg()) &&
6417 MRI.hasOneNonDBGUse(RegNo: RHS.MI->getOperand(i: 3).getReg())) {
6418 Z = LHS.Reg;
6419 FMA = RHS.MI;
6420 }
6421
6422 if (FMA) {
6423 MachineInstr *FMulMI = MRI.getVRegDef(Reg: FMA->getOperand(i: 3).getReg());
6424 Register X = FMA->getOperand(i: 1).getReg();
6425 Register Y = FMA->getOperand(i: 2).getReg();
6426 Register U = FMulMI->getOperand(i: 1).getReg();
6427 Register V = FMulMI->getOperand(i: 2).getReg();
6428
6429 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6430 Register InnerFMA = MRI.createGenericVirtualRegister(Ty: DstTy);
6431 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {InnerFMA}, SrcOps: {U, V, Z});
6432 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6433 SrcOps: {X, Y, InnerFMA});
6434 };
6435 return true;
6436 }
6437
6438 return false;
6439}
6440
6441bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMAAggressive(
6442 MachineInstr &MI,
6443 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6444 assert(MI.getOpcode() == TargetOpcode::G_FADD);
6445
6446 bool AllowFusionGlobally, HasFMAD, Aggressive;
6447 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6448 return false;
6449
6450 if (!Aggressive)
6451 return false;
6452
6453 const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
6454 LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6455 Register Op1 = MI.getOperand(i: 1).getReg();
6456 Register Op2 = MI.getOperand(i: 2).getReg();
6457 DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1};
6458 DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2};
6459
6460 unsigned PreferredFusedOpcode =
6461 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6462
6463 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
6464 // prefer to fold the multiply with fewer uses.
6465 if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6466 isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) {
6467 if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI))
6468 std::swap(a&: LHS, b&: RHS);
6469 }
6470
6471 // Builds: (fma x, y, (fma (fpext u), (fpext v), z))
6472 auto buildMatchInfo = [=, &MI](Register U, Register V, Register Z, Register X,
6473 Register Y, MachineIRBuilder &B) {
6474 Register FpExtU = B.buildFPExt(Res: DstType, Op: U).getReg(Idx: 0);
6475 Register FpExtV = B.buildFPExt(Res: DstType, Op: V).getReg(Idx: 0);
6476 Register InnerFMA =
6477 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {DstType}, SrcOps: {FpExtU, FpExtV, Z})
6478 .getReg(Idx: 0);
6479 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6480 SrcOps: {X, Y, InnerFMA});
6481 };
6482
6483 MachineInstr *FMulMI, *FMAMI;
6484 // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
6485 // -> (fma x, y, (fma (fpext u), (fpext v), z))
6486 if (LHS.MI->getOpcode() == PreferredFusedOpcode &&
6487 mi_match(R: LHS.MI->getOperand(i: 3).getReg(), MRI,
6488 P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) &&
6489 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6490 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
6491 SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) {
6492 MatchInfo = [=](MachineIRBuilder &B) {
6493 buildMatchInfo(FMulMI->getOperand(i: 1).getReg(),
6494 FMulMI->getOperand(i: 2).getReg(), RHS.Reg,
6495 LHS.MI->getOperand(i: 1).getReg(),
6496 LHS.MI->getOperand(i: 2).getReg(), B);
6497 };
6498 return true;
6499 }
6500
6501 // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
6502 // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
6503 // FIXME: This turns two single-precision and one double-precision
6504 // operation into two double-precision operations, which might not be
6505 // interesting for all targets, especially GPUs.
6506 if (mi_match(R: LHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMAMI))) &&
6507 FMAMI->getOpcode() == PreferredFusedOpcode) {
6508 MachineInstr *FMulMI = MRI.getVRegDef(Reg: FMAMI->getOperand(i: 3).getReg());
6509 if (isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6510 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
6511 SrcTy: MRI.getType(Reg: FMAMI->getOperand(i: 0).getReg()))) {
6512 MatchInfo = [=](MachineIRBuilder &B) {
6513 Register X = FMAMI->getOperand(i: 1).getReg();
6514 Register Y = FMAMI->getOperand(i: 2).getReg();
6515 X = B.buildFPExt(Res: DstType, Op: X).getReg(Idx: 0);
6516 Y = B.buildFPExt(Res: DstType, Op: Y).getReg(Idx: 0);
6517 buildMatchInfo(FMulMI->getOperand(i: 1).getReg(),
6518 FMulMI->getOperand(i: 2).getReg(), RHS.Reg, X, Y, B);
6519 };
6520
6521 return true;
6522 }
6523 }
6524
6525 // fold (fadd z, (fma x, y, (fpext (fmul u, v)))
6526 // -> (fma x, y, (fma (fpext u), (fpext v), z))
6527 if (RHS.MI->getOpcode() == PreferredFusedOpcode &&
6528 mi_match(R: RHS.MI->getOperand(i: 3).getReg(), MRI,
6529 P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) &&
6530 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6531 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
6532 SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) {
6533 MatchInfo = [=](MachineIRBuilder &B) {
6534 buildMatchInfo(FMulMI->getOperand(i: 1).getReg(),
6535 FMulMI->getOperand(i: 2).getReg(), LHS.Reg,
6536 RHS.MI->getOperand(i: 1).getReg(),
6537 RHS.MI->getOperand(i: 2).getReg(), B);
6538 };
6539 return true;
6540 }
6541
6542 // fold (fadd z, (fpext (fma x, y, (fmul u, v)))
6543 // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
6544 // FIXME: This turns two single-precision and one double-precision
6545 // operation into two double-precision operations, which might not be
6546 // interesting for all targets, especially GPUs.
6547 if (mi_match(R: RHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMAMI))) &&
6548 FMAMI->getOpcode() == PreferredFusedOpcode) {
6549 MachineInstr *FMulMI = MRI.getVRegDef(Reg: FMAMI->getOperand(i: 3).getReg());
6550 if (isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6551 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
6552 SrcTy: MRI.getType(Reg: FMAMI->getOperand(i: 0).getReg()))) {
6553 MatchInfo = [=](MachineIRBuilder &B) {
6554 Register X = FMAMI->getOperand(i: 1).getReg();
6555 Register Y = FMAMI->getOperand(i: 2).getReg();
6556 X = B.buildFPExt(Res: DstType, Op: X).getReg(Idx: 0);
6557 Y = B.buildFPExt(Res: DstType, Op: Y).getReg(Idx: 0);
6558 buildMatchInfo(FMulMI->getOperand(i: 1).getReg(),
6559 FMulMI->getOperand(i: 2).getReg(), LHS.Reg, X, Y, B);
6560 };
6561 return true;
6562 }
6563 }
6564
6565 return false;
6566}
6567
6568bool CombinerHelper::matchCombineFSubFMulToFMadOrFMA(
6569 MachineInstr &MI,
6570 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6571 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
6572
6573 bool AllowFusionGlobally, HasFMAD, Aggressive;
6574 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6575 return false;
6576
6577 Register Op1 = MI.getOperand(i: 1).getReg();
6578 Register Op2 = MI.getOperand(i: 2).getReg();
6579 DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1};
6580 DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2};
6581 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6582
6583 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
6584 // prefer to fold the multiply with fewer uses.
6585 int FirstMulHasFewerUses = true;
6586 if (isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6587 isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally) &&
6588 hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI))
6589 FirstMulHasFewerUses = false;
6590
6591 unsigned PreferredFusedOpcode =
6592 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6593
6594 // fold (fsub (fmul x, y), z) -> (fma x, y, -z)
6595 if (FirstMulHasFewerUses &&
6596 (isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6597 (Aggressive || MRI.hasOneNonDBGUse(RegNo: LHS.Reg)))) {
6598 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6599 Register NegZ = B.buildFNeg(Dst: DstTy, Src0: RHS.Reg).getReg(Idx: 0);
6600 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6601 SrcOps: {LHS.MI->getOperand(i: 1).getReg(),
6602 LHS.MI->getOperand(i: 2).getReg(), NegZ});
6603 };
6604 return true;
6605 }
6606 // fold (fsub x, (fmul y, z)) -> (fma -y, z, x)
6607 else if ((isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally) &&
6608 (Aggressive || MRI.hasOneNonDBGUse(RegNo: RHS.Reg)))) {
6609 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6610 Register NegY =
6611 B.buildFNeg(Dst: DstTy, Src0: RHS.MI->getOperand(i: 1).getReg()).getReg(Idx: 0);
6612 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6613 SrcOps: {NegY, RHS.MI->getOperand(i: 2).getReg(), LHS.Reg});
6614 };
6615 return true;
6616 }
6617
6618 return false;
6619}
6620
6621bool CombinerHelper::matchCombineFSubFNegFMulToFMadOrFMA(
6622 MachineInstr &MI,
6623 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6624 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
6625
6626 bool AllowFusionGlobally, HasFMAD, Aggressive;
6627 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6628 return false;
6629
6630 Register LHSReg = MI.getOperand(i: 1).getReg();
6631 Register RHSReg = MI.getOperand(i: 2).getReg();
6632 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6633
6634 unsigned PreferredFusedOpcode =
6635 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6636
6637 MachineInstr *FMulMI;
6638 // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z))
6639 if (mi_match(R: LHSReg, MRI, P: m_GFNeg(Src: m_MInstr(MI&: FMulMI))) &&
6640 (Aggressive || (MRI.hasOneNonDBGUse(RegNo: LHSReg) &&
6641 MRI.hasOneNonDBGUse(RegNo: FMulMI->getOperand(i: 0).getReg()))) &&
6642 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally)) {
6643 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6644 Register NegX =
6645 B.buildFNeg(Dst: DstTy, Src0: FMulMI->getOperand(i: 1).getReg()).getReg(Idx: 0);
6646 Register NegZ = B.buildFNeg(Dst: DstTy, Src0: RHSReg).getReg(Idx: 0);
6647 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6648 SrcOps: {NegX, FMulMI->getOperand(i: 2).getReg(), NegZ});
6649 };
6650 return true;
6651 }
6652
6653 // fold (fsub x, (fneg (fmul, y, z))) -> (fma y, z, x)
6654 if (mi_match(R: RHSReg, MRI, P: m_GFNeg(Src: m_MInstr(MI&: FMulMI))) &&
6655 (Aggressive || (MRI.hasOneNonDBGUse(RegNo: RHSReg) &&
6656 MRI.hasOneNonDBGUse(RegNo: FMulMI->getOperand(i: 0).getReg()))) &&
6657 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally)) {
6658 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6659 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6660 SrcOps: {FMulMI->getOperand(i: 1).getReg(),
6661 FMulMI->getOperand(i: 2).getReg(), LHSReg});
6662 };
6663 return true;
6664 }
6665
6666 return false;
6667}
6668
6669bool CombinerHelper::matchCombineFSubFpExtFMulToFMadOrFMA(
6670 MachineInstr &MI,
6671 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6672 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
6673
6674 bool AllowFusionGlobally, HasFMAD, Aggressive;
6675 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6676 return false;
6677
6678 Register LHSReg = MI.getOperand(i: 1).getReg();
6679 Register RHSReg = MI.getOperand(i: 2).getReg();
6680 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6681
6682 unsigned PreferredFusedOpcode =
6683 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6684
6685 MachineInstr *FMulMI;
6686 // fold (fsub (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), (fneg z))
6687 if (mi_match(R: LHSReg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) &&
6688 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6689 (Aggressive || MRI.hasOneNonDBGUse(RegNo: LHSReg))) {
6690 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6691 Register FpExtX =
6692 B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 1).getReg()).getReg(Idx: 0);
6693 Register FpExtY =
6694 B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 2).getReg()).getReg(Idx: 0);
6695 Register NegZ = B.buildFNeg(Dst: DstTy, Src0: RHSReg).getReg(Idx: 0);
6696 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6697 SrcOps: {FpExtX, FpExtY, NegZ});
6698 };
6699 return true;
6700 }
6701
6702 // fold (fsub x, (fpext (fmul y, z))) -> (fma (fneg (fpext y)), (fpext z), x)
6703 if (mi_match(R: RHSReg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) &&
6704 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6705 (Aggressive || MRI.hasOneNonDBGUse(RegNo: RHSReg))) {
6706 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6707 Register FpExtY =
6708 B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 1).getReg()).getReg(Idx: 0);
6709 Register NegY = B.buildFNeg(Dst: DstTy, Src0: FpExtY).getReg(Idx: 0);
6710 Register FpExtZ =
6711 B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 2).getReg()).getReg(Idx: 0);
6712 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6713 SrcOps: {NegY, FpExtZ, LHSReg});
6714 };
6715 return true;
6716 }
6717
6718 return false;
6719}
6720
6721bool CombinerHelper::matchCombineFSubFpExtFNegFMulToFMadOrFMA(
6722 MachineInstr &MI,
6723 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6724 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
6725
6726 bool AllowFusionGlobally, HasFMAD, Aggressive;
6727 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6728 return false;
6729
6730 const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
6731 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6732 Register LHSReg = MI.getOperand(i: 1).getReg();
6733 Register RHSReg = MI.getOperand(i: 2).getReg();
6734
6735 unsigned PreferredFusedOpcode =
6736 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6737
6738 auto buildMatchInfo = [=](Register Dst, Register X, Register Y, Register Z,
6739 MachineIRBuilder &B) {
6740 Register FpExtX = B.buildFPExt(Res: DstTy, Op: X).getReg(Idx: 0);
6741 Register FpExtY = B.buildFPExt(Res: DstTy, Op: Y).getReg(Idx: 0);
6742 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {Dst}, SrcOps: {FpExtX, FpExtY, Z});
6743 };
6744
6745 MachineInstr *FMulMI;
6746 // fold (fsub (fpext (fneg (fmul x, y))), z) ->
6747 // (fneg (fma (fpext x), (fpext y), z))
6748 // fold (fsub (fneg (fpext (fmul x, y))), z) ->
6749 // (fneg (fma (fpext x), (fpext y), z))
6750 if ((mi_match(R: LHSReg, MRI, P: m_GFPExt(Src: m_GFNeg(Src: m_MInstr(MI&: FMulMI)))) ||
6751 mi_match(R: LHSReg, MRI, P: m_GFNeg(Src: m_GFPExt(Src: m_MInstr(MI&: FMulMI))))) &&
6752 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6753 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstTy,
6754 SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) {
6755 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6756 Register FMAReg = MRI.createGenericVirtualRegister(Ty: DstTy);
6757 buildMatchInfo(FMAReg, FMulMI->getOperand(i: 1).getReg(),
6758 FMulMI->getOperand(i: 2).getReg(), RHSReg, B);
6759 B.buildFNeg(Dst: MI.getOperand(i: 0).getReg(), Src0: FMAReg);
6760 };
6761 return true;
6762 }
6763
6764 // fold (fsub x, (fpext (fneg (fmul y, z)))) -> (fma (fpext y), (fpext z), x)
6765 // fold (fsub x, (fneg (fpext (fmul y, z)))) -> (fma (fpext y), (fpext z), x)
6766 if ((mi_match(R: RHSReg, MRI, P: m_GFPExt(Src: m_GFNeg(Src: m_MInstr(MI&: FMulMI)))) ||
6767 mi_match(R: RHSReg, MRI, P: m_GFNeg(Src: m_GFPExt(Src: m_MInstr(MI&: FMulMI))))) &&
6768 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6769 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstTy,
6770 SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) {
6771 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6772 buildMatchInfo(MI.getOperand(i: 0).getReg(), FMulMI->getOperand(i: 1).getReg(),
6773 FMulMI->getOperand(i: 2).getReg(), LHSReg, B);
6774 };
6775 return true;
6776 }
6777
6778 return false;
6779}
6780
6781bool CombinerHelper::matchCombineFMinMaxNaN(MachineInstr &MI,
6782 unsigned &IdxToPropagate) const {
6783 bool PropagateNaN;
6784 switch (MI.getOpcode()) {
6785 default:
6786 return false;
6787 case TargetOpcode::G_FMINNUM:
6788 case TargetOpcode::G_FMAXNUM:
6789 PropagateNaN = false;
6790 break;
6791 case TargetOpcode::G_FMINIMUM:
6792 case TargetOpcode::G_FMAXIMUM:
6793 PropagateNaN = true;
6794 break;
6795 }
6796
6797 auto MatchNaN = [&](unsigned Idx) {
6798 Register MaybeNaNReg = MI.getOperand(i: Idx).getReg();
6799 const ConstantFP *MaybeCst = getConstantFPVRegVal(VReg: MaybeNaNReg, MRI);
6800 if (!MaybeCst || !MaybeCst->getValueAPF().isNaN())
6801 return false;
6802 IdxToPropagate = PropagateNaN ? Idx : (Idx == 1 ? 2 : 1);
6803 return true;
6804 };
6805
6806 return MatchNaN(1) || MatchNaN(2);
6807}
6808
6809// Combine multiple FDIVs with the same divisor into multiple FMULs by the
6810// reciprocal.
6811// E.g., (a / Y; b / Y;) -> (recip = 1.0 / Y; a * recip; b * recip)
6812bool CombinerHelper::matchRepeatedFPDivisor(
6813 MachineInstr &MI, SmallVector<MachineInstr *> &MatchInfo) const {
6814 assert(MI.getOpcode() == TargetOpcode::G_FDIV);
6815
6816 Register X = MI.getOperand(i: 1).getReg();
6817 Register Y = MI.getOperand(i: 2).getReg();
6818
6819 if (!MI.getFlag(Flag: MachineInstr::MIFlag::FmArcp))
6820 return false;
6821
6822 // Skip if current node is a reciprocal/fneg-reciprocal.
6823 auto N0CFP = isConstantOrConstantSplatVectorFP(MI&: *MRI.getVRegDef(Reg: X), MRI);
6824 if (N0CFP && (N0CFP->isExactlyValue(V: 1.0) || N0CFP->isExactlyValue(V: -1.0)))
6825 return false;
6826
6827 // Exit early if the target does not want this transform or if there can't
6828 // possibly be enough uses of the divisor to make the transform worthwhile.
6829 unsigned MinUses = getTargetLowering().combineRepeatedFPDivisors();
6830 if (!MinUses)
6831 return false;
6832
6833 // Find all FDIV users of the same divisor. For the moment we limit all
6834 // instructions to a single BB and use the first Instr in MatchInfo as the
6835 // dominating position.
6836 MatchInfo.push_back(Elt: &MI);
6837 for (auto &U : MRI.use_nodbg_instructions(Reg: Y)) {
6838 if (&U == &MI || U.getParent() != MI.getParent())
6839 continue;
6840 if (U.getOpcode() == TargetOpcode::G_FDIV &&
6841 U.getOperand(i: 2).getReg() == Y && U.getOperand(i: 1).getReg() != Y) {
6842 // This division is eligible for optimization only if global unsafe math
6843 // is enabled or if this division allows reciprocal formation.
6844 if (U.getFlag(Flag: MachineInstr::MIFlag::FmArcp)) {
6845 MatchInfo.push_back(Elt: &U);
6846 if (dominates(DefMI: U, UseMI: *MatchInfo[0]))
6847 std::swap(a&: MatchInfo[0], b&: MatchInfo.back());
6848 }
6849 }
6850 }
6851
6852 // Now that we have the actual number of divisor uses, make sure it meets
6853 // the minimum threshold specified by the target.
6854 return MatchInfo.size() >= MinUses;
6855}
6856
6857void CombinerHelper::applyRepeatedFPDivisor(
6858 SmallVector<MachineInstr *> &MatchInfo) const {
6859 // Generate the new div at the position of the first instruction, that we have
6860 // ensured will dominate all other instructions.
6861 Builder.setInsertPt(MBB&: *MatchInfo[0]->getParent(), II: MatchInfo[0]);
6862 LLT Ty = MRI.getType(Reg: MatchInfo[0]->getOperand(i: 0).getReg());
6863 auto Div = Builder.buildFDiv(Dst: Ty, Src0: Builder.buildFConstant(Res: Ty, Val: 1.0),
6864 Src1: MatchInfo[0]->getOperand(i: 2).getReg(),
6865 Flags: MatchInfo[0]->getFlags());
6866
6867 // Replace all found div's with fmul instructions.
6868 for (MachineInstr *MI : MatchInfo) {
6869 Builder.setInsertPt(MBB&: *MI->getParent(), II: MI);
6870 Builder.buildFMul(Dst: MI->getOperand(i: 0).getReg(), Src0: MI->getOperand(i: 1).getReg(),
6871 Src1: Div->getOperand(i: 0).getReg(), Flags: MI->getFlags());
6872 MI->eraseFromParent();
6873 }
6874}
6875
6876bool CombinerHelper::matchAddSubSameReg(MachineInstr &MI, Register &Src) const {
6877 assert(MI.getOpcode() == TargetOpcode::G_ADD && "Expected a G_ADD");
6878 Register LHS = MI.getOperand(i: 1).getReg();
6879 Register RHS = MI.getOperand(i: 2).getReg();
6880
6881 // Helper lambda to check for opportunities for
6882 // A + (B - A) -> B
6883 // (B - A) + A -> B
6884 auto CheckFold = [&](Register MaybeSub, Register MaybeSameReg) {
6885 Register Reg;
6886 return mi_match(R: MaybeSub, MRI, P: m_GSub(L: m_Reg(R&: Src), R: m_Reg(R&: Reg))) &&
6887 Reg == MaybeSameReg;
6888 };
6889 return CheckFold(LHS, RHS) || CheckFold(RHS, LHS);
6890}
6891
6892bool CombinerHelper::matchBuildVectorIdentityFold(MachineInstr &MI,
6893 Register &MatchInfo) const {
6894 // This combine folds the following patterns:
6895 //
6896 // G_BUILD_VECTOR_TRUNC (G_BITCAST(x), G_LSHR(G_BITCAST(x), k))
6897 // G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), G_TRUNC(G_LSHR(G_BITCAST(x), k)))
6898 // into
6899 // x
6900 // if
6901 // k == sizeof(VecEltTy)/2
6902 // type(x) == type(dst)
6903 //
6904 // G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), undef)
6905 // into
6906 // x
6907 // if
6908 // type(x) == type(dst)
6909
6910 LLT DstVecTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6911 LLT DstEltTy = DstVecTy.getElementType();
6912
6913 Register Lo, Hi;
6914
6915 if (mi_match(
6916 MI, MRI,
6917 P: m_GBuildVector(L: m_GTrunc(Src: m_GBitcast(Src: m_Reg(R&: Lo))), R: m_GImplicitDef()))) {
6918 MatchInfo = Lo;
6919 return MRI.getType(Reg: MatchInfo) == DstVecTy;
6920 }
6921
6922 std::optional<ValueAndVReg> ShiftAmount;
6923 const auto LoPattern = m_GBitcast(Src: m_Reg(R&: Lo));
6924 const auto HiPattern = m_GLShr(L: m_GBitcast(Src: m_Reg(R&: Hi)), R: m_GCst(ValReg&: ShiftAmount));
6925 if (mi_match(
6926 MI, MRI,
6927 P: m_any_of(preds: m_GBuildVectorTrunc(L: LoPattern, R: HiPattern),
6928 preds: m_GBuildVector(L: m_GTrunc(Src: LoPattern), R: m_GTrunc(Src: HiPattern))))) {
6929 if (Lo == Hi && ShiftAmount->Value == DstEltTy.getSizeInBits()) {
6930 MatchInfo = Lo;
6931 return MRI.getType(Reg: MatchInfo) == DstVecTy;
6932 }
6933 }
6934
6935 return false;
6936}
6937
6938bool CombinerHelper::matchTruncBuildVectorFold(MachineInstr &MI,
6939 Register &MatchInfo) const {
6940 // Replace (G_TRUNC (G_BITCAST (G_BUILD_VECTOR x, y)) with just x
6941 // if type(x) == type(G_TRUNC)
6942 if (!mi_match(R: MI.getOperand(i: 1).getReg(), MRI,
6943 P: m_GBitcast(Src: m_GBuildVector(L: m_Reg(R&: MatchInfo), R: m_Reg()))))
6944 return false;
6945
6946 return MRI.getType(Reg: MatchInfo) == MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6947}
6948
6949bool CombinerHelper::matchTruncLshrBuildVectorFold(MachineInstr &MI,
6950 Register &MatchInfo) const {
6951 // Replace (G_TRUNC (G_LSHR (G_BITCAST (G_BUILD_VECTOR x, y)), K)) with
6952 // y if K == size of vector element type
6953 std::optional<ValueAndVReg> ShiftAmt;
6954 if (!mi_match(R: MI.getOperand(i: 1).getReg(), MRI,
6955 P: m_GLShr(L: m_GBitcast(Src: m_GBuildVector(L: m_Reg(), R: m_Reg(R&: MatchInfo))),
6956 R: m_GCst(ValReg&: ShiftAmt))))
6957 return false;
6958
6959 LLT MatchTy = MRI.getType(Reg: MatchInfo);
6960 return ShiftAmt->Value.getZExtValue() == MatchTy.getSizeInBits() &&
6961 MatchTy == MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6962}
6963
6964unsigned CombinerHelper::getFPMinMaxOpcForSelect(
6965 CmpInst::Predicate Pred, LLT DstTy,
6966 SelectPatternNaNBehaviour VsNaNRetVal) const {
6967 assert(VsNaNRetVal != SelectPatternNaNBehaviour::NOT_APPLICABLE &&
6968 "Expected a NaN behaviour?");
6969 // Choose an opcode based off of legality or the behaviour when one of the
6970 // LHS/RHS may be NaN.
6971 switch (Pred) {
6972 default:
6973 return 0;
6974 case CmpInst::FCMP_UGT:
6975 case CmpInst::FCMP_UGE:
6976 case CmpInst::FCMP_OGT:
6977 case CmpInst::FCMP_OGE:
6978 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER)
6979 return TargetOpcode::G_FMAXNUM;
6980 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN)
6981 return TargetOpcode::G_FMAXIMUM;
6982 if (isLegal(Query: {TargetOpcode::G_FMAXNUM, {DstTy}}))
6983 return TargetOpcode::G_FMAXNUM;
6984 if (isLegal(Query: {TargetOpcode::G_FMAXIMUM, {DstTy}}))
6985 return TargetOpcode::G_FMAXIMUM;
6986 return 0;
6987 case CmpInst::FCMP_ULT:
6988 case CmpInst::FCMP_ULE:
6989 case CmpInst::FCMP_OLT:
6990 case CmpInst::FCMP_OLE:
6991 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER)
6992 return TargetOpcode::G_FMINNUM;
6993 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN)
6994 return TargetOpcode::G_FMINIMUM;
6995 if (isLegal(Query: {TargetOpcode::G_FMINNUM, {DstTy}}))
6996 return TargetOpcode::G_FMINNUM;
6997 if (!isLegal(Query: {TargetOpcode::G_FMINIMUM, {DstTy}}))
6998 return 0;
6999 return TargetOpcode::G_FMINIMUM;
7000 }
7001}
7002
7003CombinerHelper::SelectPatternNaNBehaviour
7004CombinerHelper::computeRetValAgainstNaN(Register LHS, Register RHS,
7005 bool IsOrderedComparison) const {
7006 bool LHSSafe = isKnownNeverNaN(Val: LHS, MRI);
7007 bool RHSSafe = isKnownNeverNaN(Val: RHS, MRI);
7008 // Completely unsafe.
7009 if (!LHSSafe && !RHSSafe)
7010 return SelectPatternNaNBehaviour::NOT_APPLICABLE;
7011 if (LHSSafe && RHSSafe)
7012 return SelectPatternNaNBehaviour::RETURNS_ANY;
7013 // An ordered comparison will return false when given a NaN, so it
7014 // returns the RHS.
7015 if (IsOrderedComparison)
7016 return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_NAN
7017 : SelectPatternNaNBehaviour::RETURNS_OTHER;
7018 // An unordered comparison will return true when given a NaN, so it
7019 // returns the LHS.
7020 return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_OTHER
7021 : SelectPatternNaNBehaviour::RETURNS_NAN;
7022}
7023
7024bool CombinerHelper::matchFPSelectToMinMax(Register Dst, Register Cond,
7025 Register TrueVal, Register FalseVal,
7026 BuildFnTy &MatchInfo) const {
7027 // Match: select (fcmp cond x, y) x, y
7028 // select (fcmp cond x, y) y, x
7029 // And turn it into fminnum/fmaxnum or fmin/fmax based off of the condition.
7030 LLT DstTy = MRI.getType(Reg: Dst);
7031 // Bail out early on pointers, since we'll never want to fold to a min/max.
7032 if (DstTy.isPointer())
7033 return false;
7034 // Match a floating point compare with a less-than/greater-than predicate.
7035 // TODO: Allow multiple users of the compare if they are all selects.
7036 CmpInst::Predicate Pred;
7037 Register CmpLHS, CmpRHS;
7038 if (!mi_match(R: Cond, MRI,
7039 P: m_OneNonDBGUse(
7040 SP: m_GFCmp(P: m_Pred(P&: Pred), L: m_Reg(R&: CmpLHS), R: m_Reg(R&: CmpRHS)))) ||
7041 CmpInst::isEquality(pred: Pred))
7042 return false;
7043 SelectPatternNaNBehaviour ResWithKnownNaNInfo =
7044 computeRetValAgainstNaN(LHS: CmpLHS, RHS: CmpRHS, IsOrderedComparison: CmpInst::isOrdered(predicate: Pred));
7045 if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::NOT_APPLICABLE)
7046 return false;
7047 if (TrueVal == CmpRHS && FalseVal == CmpLHS) {
7048 std::swap(a&: CmpLHS, b&: CmpRHS);
7049 Pred = CmpInst::getSwappedPredicate(pred: Pred);
7050 if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_NAN)
7051 ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_OTHER;
7052 else if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_OTHER)
7053 ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_NAN;
7054 }
7055 if (TrueVal != CmpLHS || FalseVal != CmpRHS)
7056 return false;
7057 // Decide what type of max/min this should be based off of the predicate.
7058 unsigned Opc = getFPMinMaxOpcForSelect(Pred, DstTy, VsNaNRetVal: ResWithKnownNaNInfo);
7059 if (!Opc || !isLegal(Query: {Opc, {DstTy}}))
7060 return false;
7061 // Comparisons between signed zero and zero may have different results...
7062 // unless we have fmaximum/fminimum. In that case, we know -0 < 0.
7063 if (Opc != TargetOpcode::G_FMAXIMUM && Opc != TargetOpcode::G_FMINIMUM) {
7064 // We don't know if a comparison between two 0s will give us a consistent
7065 // result. Be conservative and only proceed if at least one side is
7066 // non-zero.
7067 auto KnownNonZeroSide = getFConstantVRegValWithLookThrough(VReg: CmpLHS, MRI);
7068 if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero()) {
7069 KnownNonZeroSide = getFConstantVRegValWithLookThrough(VReg: CmpRHS, MRI);
7070 if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero())
7071 return false;
7072 }
7073 }
7074 MatchInfo = [=](MachineIRBuilder &B) {
7075 B.buildInstr(Opc, DstOps: {Dst}, SrcOps: {CmpLHS, CmpRHS});
7076 };
7077 return true;
7078}
7079
7080bool CombinerHelper::matchSimplifySelectToMinMax(MachineInstr &MI,
7081 BuildFnTy &MatchInfo) const {
7082 // TODO: Handle integer cases.
7083 assert(MI.getOpcode() == TargetOpcode::G_SELECT);
7084 // Condition may be fed by a truncated compare.
7085 Register Cond = MI.getOperand(i: 1).getReg();
7086 Register MaybeTrunc;
7087 if (mi_match(R: Cond, MRI, P: m_OneNonDBGUse(SP: m_GTrunc(Src: m_Reg(R&: MaybeTrunc)))))
7088 Cond = MaybeTrunc;
7089 Register Dst = MI.getOperand(i: 0).getReg();
7090 Register TrueVal = MI.getOperand(i: 2).getReg();
7091 Register FalseVal = MI.getOperand(i: 3).getReg();
7092 return matchFPSelectToMinMax(Dst, Cond, TrueVal, FalseVal, MatchInfo);
7093}
7094
7095bool CombinerHelper::matchRedundantBinOpInEquality(MachineInstr &MI,
7096 BuildFnTy &MatchInfo) const {
7097 assert(MI.getOpcode() == TargetOpcode::G_ICMP);
7098 // (X + Y) == X --> Y == 0
7099 // (X + Y) != X --> Y != 0
7100 // (X - Y) == X --> Y == 0
7101 // (X - Y) != X --> Y != 0
7102 // (X ^ Y) == X --> Y == 0
7103 // (X ^ Y) != X --> Y != 0
7104 Register Dst = MI.getOperand(i: 0).getReg();
7105 CmpInst::Predicate Pred;
7106 Register X, Y, OpLHS, OpRHS;
7107 bool MatchedSub = mi_match(
7108 R: Dst, MRI,
7109 P: m_c_GICmp(P: m_Pred(P&: Pred), L: m_Reg(R&: X), R: m_GSub(L: m_Reg(R&: OpLHS), R: m_Reg(R&: Y))));
7110 if (MatchedSub && X != OpLHS)
7111 return false;
7112 if (!MatchedSub) {
7113 if (!mi_match(R: Dst, MRI,
7114 P: m_c_GICmp(P: m_Pred(P&: Pred), L: m_Reg(R&: X),
7115 R: m_any_of(preds: m_GAdd(L: m_Reg(R&: OpLHS), R: m_Reg(R&: OpRHS)),
7116 preds: m_GXor(L: m_Reg(R&: OpLHS), R: m_Reg(R&: OpRHS))))))
7117 return false;
7118 Y = X == OpLHS ? OpRHS : X == OpRHS ? OpLHS : Register();
7119 }
7120 MatchInfo = [=](MachineIRBuilder &B) {
7121 auto Zero = B.buildConstant(Res: MRI.getType(Reg: Y), Val: 0);
7122 B.buildICmp(Pred, Res: Dst, Op0: Y, Op1: Zero);
7123 };
7124 return CmpInst::isEquality(pred: Pred) && Y.isValid();
7125}
7126
7127/// Return the minimum useless shift amount that results in complete loss of the
7128/// source value. Return std::nullopt when it cannot determine a value.
7129static std::optional<unsigned>
7130getMinUselessShift(KnownBits ValueKB, unsigned Opcode,
7131 std::optional<int64_t> &Result) {
7132 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_LSHR ||
7133 Opcode == TargetOpcode::G_ASHR) &&
7134 "Expect G_SHL, G_LSHR or G_ASHR.");
7135 auto SignificantBits = 0;
7136 switch (Opcode) {
7137 case TargetOpcode::G_SHL:
7138 SignificantBits = ValueKB.countMinTrailingZeros();
7139 Result = 0;
7140 break;
7141 case TargetOpcode::G_LSHR:
7142 Result = 0;
7143 SignificantBits = ValueKB.countMinLeadingZeros();
7144 break;
7145 case TargetOpcode::G_ASHR:
7146 if (ValueKB.isNonNegative()) {
7147 SignificantBits = ValueKB.countMinLeadingZeros();
7148 Result = 0;
7149 } else if (ValueKB.isNegative()) {
7150 SignificantBits = ValueKB.countMinLeadingOnes();
7151 Result = -1;
7152 } else {
7153 // Cannot determine shift result.
7154 Result = std::nullopt;
7155 }
7156 break;
7157 default:
7158 break;
7159 }
7160 return ValueKB.getBitWidth() - SignificantBits;
7161}
7162
7163bool CombinerHelper::matchShiftsTooBig(
7164 MachineInstr &MI, std::optional<int64_t> &MatchInfo) const {
7165 Register ShiftVal = MI.getOperand(i: 1).getReg();
7166 Register ShiftReg = MI.getOperand(i: 2).getReg();
7167 LLT ResTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
7168 auto IsShiftTooBig = [&](const Constant *C) {
7169 auto *CI = dyn_cast<ConstantInt>(Val: C);
7170 if (!CI)
7171 return false;
7172 if (CI->uge(Num: ResTy.getScalarSizeInBits())) {
7173 MatchInfo = std::nullopt;
7174 return true;
7175 }
7176 auto OptMaxUsefulShift = getMinUselessShift(ValueKB: VT->getKnownBits(R: ShiftVal),
7177 Opcode: MI.getOpcode(), Result&: MatchInfo);
7178 return OptMaxUsefulShift && CI->uge(Num: *OptMaxUsefulShift);
7179 };
7180 return matchUnaryPredicate(MRI, Reg: ShiftReg, Match: IsShiftTooBig);
7181}
7182
7183bool CombinerHelper::matchCommuteConstantToRHS(MachineInstr &MI) const {
7184 unsigned LHSOpndIdx = 1;
7185 unsigned RHSOpndIdx = 2;
7186 switch (MI.getOpcode()) {
7187 case TargetOpcode::G_UADDO:
7188 case TargetOpcode::G_SADDO:
7189 case TargetOpcode::G_UMULO:
7190 case TargetOpcode::G_SMULO:
7191 LHSOpndIdx = 2;
7192 RHSOpndIdx = 3;
7193 break;
7194 default:
7195 break;
7196 }
7197 Register LHS = MI.getOperand(i: LHSOpndIdx).getReg();
7198 Register RHS = MI.getOperand(i: RHSOpndIdx).getReg();
7199 if (!getIConstantVRegVal(VReg: LHS, MRI)) {
7200 // Skip commuting if LHS is not a constant. But, LHS may be a
7201 // G_CONSTANT_FOLD_BARRIER. If so we commute as long as we don't already
7202 // have a constant on the RHS.
7203 if (MRI.getVRegDef(Reg: LHS)->getOpcode() !=
7204 TargetOpcode::G_CONSTANT_FOLD_BARRIER)
7205 return false;
7206 }
7207 // Commute as long as RHS is not a constant or G_CONSTANT_FOLD_BARRIER.
7208 return MRI.getVRegDef(Reg: RHS)->getOpcode() !=
7209 TargetOpcode::G_CONSTANT_FOLD_BARRIER &&
7210 !getIConstantVRegVal(VReg: RHS, MRI);
7211}
7212
7213bool CombinerHelper::matchCommuteFPConstantToRHS(MachineInstr &MI) const {
7214 Register LHS = MI.getOperand(i: 1).getReg();
7215 Register RHS = MI.getOperand(i: 2).getReg();
7216 std::optional<FPValueAndVReg> ValAndVReg;
7217 if (!mi_match(R: LHS, MRI, P: m_GFCstOrSplat(FPValReg&: ValAndVReg)))
7218 return false;
7219 return !mi_match(R: RHS, MRI, P: m_GFCstOrSplat(FPValReg&: ValAndVReg));
7220}
7221
7222void CombinerHelper::applyCommuteBinOpOperands(MachineInstr &MI) const {
7223 Observer.changingInstr(MI);
7224 unsigned LHSOpndIdx = 1;
7225 unsigned RHSOpndIdx = 2;
7226 switch (MI.getOpcode()) {
7227 case TargetOpcode::G_UADDO:
7228 case TargetOpcode::G_SADDO:
7229 case TargetOpcode::G_UMULO:
7230 case TargetOpcode::G_SMULO:
7231 LHSOpndIdx = 2;
7232 RHSOpndIdx = 3;
7233 break;
7234 default:
7235 break;
7236 }
7237 Register LHSReg = MI.getOperand(i: LHSOpndIdx).getReg();
7238 Register RHSReg = MI.getOperand(i: RHSOpndIdx).getReg();
7239 MI.getOperand(i: LHSOpndIdx).setReg(RHSReg);
7240 MI.getOperand(i: RHSOpndIdx).setReg(LHSReg);
7241 Observer.changedInstr(MI);
7242}
7243
7244bool CombinerHelper::isOneOrOneSplat(Register Src, bool AllowUndefs) const {
7245 LLT SrcTy = MRI.getType(Reg: Src);
7246 if (SrcTy.isFixedVector())
7247 return isConstantSplatVector(Src, SplatValue: 1, AllowUndefs);
7248 if (SrcTy.isScalar()) {
7249 if (AllowUndefs && getOpcodeDef<GImplicitDef>(Reg: Src, MRI) != nullptr)
7250 return true;
7251 auto IConstant = getIConstantVRegValWithLookThrough(VReg: Src, MRI);
7252 return IConstant && IConstant->Value == 1;
7253 }
7254 return false; // scalable vector
7255}
7256
7257bool CombinerHelper::isZeroOrZeroSplat(Register Src, bool AllowUndefs) const {
7258 LLT SrcTy = MRI.getType(Reg: Src);
7259 if (SrcTy.isFixedVector())
7260 return isConstantSplatVector(Src, SplatValue: 0, AllowUndefs);
7261 if (SrcTy.isScalar()) {
7262 if (AllowUndefs && getOpcodeDef<GImplicitDef>(Reg: Src, MRI) != nullptr)
7263 return true;
7264 auto IConstant = getIConstantVRegValWithLookThrough(VReg: Src, MRI);
7265 return IConstant && IConstant->Value == 0;
7266 }
7267 return false; // scalable vector
7268}
7269
7270// Ignores COPYs during conformance checks.
7271// FIXME scalable vectors.
7272bool CombinerHelper::isConstantSplatVector(Register Src, int64_t SplatValue,
7273 bool AllowUndefs) const {
7274 GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Reg: Src, MRI);
7275 if (!BuildVector)
7276 return false;
7277 unsigned NumSources = BuildVector->getNumSources();
7278
7279 for (unsigned I = 0; I < NumSources; ++I) {
7280 GImplicitDef *ImplicitDef =
7281 getOpcodeDef<GImplicitDef>(Reg: BuildVector->getSourceReg(I), MRI);
7282 if (ImplicitDef && AllowUndefs)
7283 continue;
7284 if (ImplicitDef && !AllowUndefs)
7285 return false;
7286 std::optional<ValueAndVReg> IConstant =
7287 getIConstantVRegValWithLookThrough(VReg: BuildVector->getSourceReg(I), MRI);
7288 if (IConstant && IConstant->Value == SplatValue)
7289 continue;
7290 return false;
7291 }
7292 return true;
7293}
7294
7295// Ignores COPYs during lookups.
7296// FIXME scalable vectors
7297std::optional<APInt>
7298CombinerHelper::getConstantOrConstantSplatVector(Register Src) const {
7299 auto IConstant = getIConstantVRegValWithLookThrough(VReg: Src, MRI);
7300 if (IConstant)
7301 return IConstant->Value;
7302
7303 GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Reg: Src, MRI);
7304 if (!BuildVector)
7305 return std::nullopt;
7306 unsigned NumSources = BuildVector->getNumSources();
7307
7308 std::optional<APInt> Value = std::nullopt;
7309 for (unsigned I = 0; I < NumSources; ++I) {
7310 std::optional<ValueAndVReg> IConstant =
7311 getIConstantVRegValWithLookThrough(VReg: BuildVector->getSourceReg(I), MRI);
7312 if (!IConstant)
7313 return std::nullopt;
7314 if (!Value)
7315 Value = IConstant->Value;
7316 else if (*Value != IConstant->Value)
7317 return std::nullopt;
7318 }
7319 return Value;
7320}
7321
7322// FIXME G_SPLAT_VECTOR
7323bool CombinerHelper::isConstantOrConstantVectorI(Register Src) const {
7324 auto IConstant = getIConstantVRegValWithLookThrough(VReg: Src, MRI);
7325 if (IConstant)
7326 return true;
7327
7328 GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Reg: Src, MRI);
7329 if (!BuildVector)
7330 return false;
7331
7332 unsigned NumSources = BuildVector->getNumSources();
7333 for (unsigned I = 0; I < NumSources; ++I) {
7334 std::optional<ValueAndVReg> IConstant =
7335 getIConstantVRegValWithLookThrough(VReg: BuildVector->getSourceReg(I), MRI);
7336 if (!IConstant)
7337 return false;
7338 }
7339 return true;
7340}
7341
7342// TODO: use knownbits to determine zeros
7343bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
7344 BuildFnTy &MatchInfo) const {
7345 uint32_t Flags = Select->getFlags();
7346 Register Dest = Select->getReg(Idx: 0);
7347 Register Cond = Select->getCondReg();
7348 Register True = Select->getTrueReg();
7349 Register False = Select->getFalseReg();
7350 LLT CondTy = MRI.getType(Reg: Select->getCondReg());
7351 LLT TrueTy = MRI.getType(Reg: Select->getTrueReg());
7352
7353 // We only do this combine for scalar boolean conditions.
7354 if (CondTy != LLT::scalar(SizeInBits: 1))
7355 return false;
7356
7357 if (TrueTy.isPointer())
7358 return false;
7359
7360 // Both are scalars.
7361 std::optional<ValueAndVReg> TrueOpt =
7362 getIConstantVRegValWithLookThrough(VReg: True, MRI);
7363 std::optional<ValueAndVReg> FalseOpt =
7364 getIConstantVRegValWithLookThrough(VReg: False, MRI);
7365
7366 if (!TrueOpt || !FalseOpt)
7367 return false;
7368
7369 APInt TrueValue = TrueOpt->Value;
7370 APInt FalseValue = FalseOpt->Value;
7371
7372 // select Cond, 1, 0 --> zext (Cond)
7373 if (TrueValue.isOne() && FalseValue.isZero()) {
7374 MatchInfo = [=](MachineIRBuilder &B) {
7375 B.setInstrAndDebugLoc(*Select);
7376 B.buildZExtOrTrunc(Res: Dest, Op: Cond);
7377 };
7378 return true;
7379 }
7380
7381 // select Cond, -1, 0 --> sext (Cond)
7382 if (TrueValue.isAllOnes() && FalseValue.isZero()) {
7383 MatchInfo = [=](MachineIRBuilder &B) {
7384 B.setInstrAndDebugLoc(*Select);
7385 B.buildSExtOrTrunc(Res: Dest, Op: Cond);
7386 };
7387 return true;
7388 }
7389
7390 // select Cond, 0, 1 --> zext (!Cond)
7391 if (TrueValue.isZero() && FalseValue.isOne()) {
7392 MatchInfo = [=](MachineIRBuilder &B) {
7393 B.setInstrAndDebugLoc(*Select);
7394 Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy);
7395 B.buildNot(Dst: Inner, Src0: Cond);
7396 B.buildZExtOrTrunc(Res: Dest, Op: Inner);
7397 };
7398 return true;
7399 }
7400
7401 // select Cond, 0, -1 --> sext (!Cond)
7402 if (TrueValue.isZero() && FalseValue.isAllOnes()) {
7403 MatchInfo = [=](MachineIRBuilder &B) {
7404 B.setInstrAndDebugLoc(*Select);
7405 Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy);
7406 B.buildNot(Dst: Inner, Src0: Cond);
7407 B.buildSExtOrTrunc(Res: Dest, Op: Inner);
7408 };
7409 return true;
7410 }
7411
7412 // select Cond, C1, C1-1 --> add (zext Cond), C1-1
7413 if (TrueValue - 1 == FalseValue) {
7414 MatchInfo = [=](MachineIRBuilder &B) {
7415 B.setInstrAndDebugLoc(*Select);
7416 Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy);
7417 B.buildZExtOrTrunc(Res: Inner, Op: Cond);
7418 B.buildAdd(Dst: Dest, Src0: Inner, Src1: False);
7419 };
7420 return true;
7421 }
7422
7423 // select Cond, C1, C1+1 --> add (sext Cond), C1+1
7424 if (TrueValue + 1 == FalseValue) {
7425 MatchInfo = [=](MachineIRBuilder &B) {
7426 B.setInstrAndDebugLoc(*Select);
7427 Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy);
7428 B.buildSExtOrTrunc(Res: Inner, Op: Cond);
7429 B.buildAdd(Dst: Dest, Src0: Inner, Src1: False);
7430 };
7431 return true;
7432 }
7433
7434 // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
7435 if (TrueValue.isPowerOf2() && FalseValue.isZero()) {
7436 MatchInfo = [=](MachineIRBuilder &B) {
7437 B.setInstrAndDebugLoc(*Select);
7438 Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy);
7439 B.buildZExtOrTrunc(Res: Inner, Op: Cond);
7440 // The shift amount must be scalar.
7441 LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy;
7442 auto ShAmtC = B.buildConstant(Res: ShiftTy, Val: TrueValue.exactLogBase2());
7443 B.buildShl(Dst: Dest, Src0: Inner, Src1: ShAmtC, Flags);
7444 };
7445 return true;
7446 }
7447
7448 // select Cond, 0, Pow2 --> (zext (!Cond)) << log2(Pow2)
7449 if (FalseValue.isPowerOf2() && TrueValue.isZero()) {
7450 MatchInfo = [=](MachineIRBuilder &B) {
7451 B.setInstrAndDebugLoc(*Select);
7452 Register Not = MRI.createGenericVirtualRegister(Ty: CondTy);
7453 B.buildNot(Dst: Not, Src0: Cond);
7454 Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy);
7455 B.buildZExtOrTrunc(Res: Inner, Op: Not);
7456 // The shift amount must be scalar.
7457 LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy;
7458 auto ShAmtC = B.buildConstant(Res: ShiftTy, Val: FalseValue.exactLogBase2());
7459 B.buildShl(Dst: Dest, Src0: Inner, Src1: ShAmtC, Flags);
7460 };
7461 return true;
7462 }
7463
7464 // select Cond, -1, C --> or (sext Cond), C
7465 if (TrueValue.isAllOnes()) {
7466 MatchInfo = [=](MachineIRBuilder &B) {
7467 B.setInstrAndDebugLoc(*Select);
7468 Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy);
7469 B.buildSExtOrTrunc(Res: Inner, Op: Cond);
7470 B.buildOr(Dst: Dest, Src0: Inner, Src1: False, Flags);
7471 };
7472 return true;
7473 }
7474
7475 // select Cond, C, -1 --> or (sext (not Cond)), C
7476 if (FalseValue.isAllOnes()) {
7477 MatchInfo = [=](MachineIRBuilder &B) {
7478 B.setInstrAndDebugLoc(*Select);
7479 Register Not = MRI.createGenericVirtualRegister(Ty: CondTy);
7480 B.buildNot(Dst: Not, Src0: Cond);
7481 Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy);
7482 B.buildSExtOrTrunc(Res: Inner, Op: Not);
7483 B.buildOr(Dst: Dest, Src0: Inner, Src1: True, Flags);
7484 };
7485 return true;
7486 }
7487
7488 return false;
7489}
7490
7491// TODO: use knownbits to determine zeros
7492bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
7493 BuildFnTy &MatchInfo) const {
7494 uint32_t Flags = Select->getFlags();
7495 Register DstReg = Select->getReg(Idx: 0);
7496 Register Cond = Select->getCondReg();
7497 Register True = Select->getTrueReg();
7498 Register False = Select->getFalseReg();
7499 LLT CondTy = MRI.getType(Reg: Select->getCondReg());
7500 LLT TrueTy = MRI.getType(Reg: Select->getTrueReg());
7501
7502 // Boolean or fixed vector of booleans.
7503 if (CondTy.isScalableVector() ||
7504 (CondTy.isFixedVector() &&
7505 CondTy.getElementType().getScalarSizeInBits() != 1) ||
7506 CondTy.getScalarSizeInBits() != 1)
7507 return false;
7508
7509 if (CondTy != TrueTy)
7510 return false;
7511
7512 // select Cond, Cond, F --> or Cond, F
7513 // select Cond, 1, F --> or Cond, F
7514 if ((Cond == True) || isOneOrOneSplat(Src: True, /* AllowUndefs */ true)) {
7515 MatchInfo = [=](MachineIRBuilder &B) {
7516 B.setInstrAndDebugLoc(*Select);
7517 Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy);
7518 B.buildZExtOrTrunc(Res: Ext, Op: Cond);
7519 auto FreezeFalse = B.buildFreeze(Dst: TrueTy, Src: False);
7520 B.buildOr(Dst: DstReg, Src0: Ext, Src1: FreezeFalse, Flags);
7521 };
7522 return true;
7523 }
7524
7525 // select Cond, T, Cond --> and Cond, T
7526 // select Cond, T, 0 --> and Cond, T
7527 if ((Cond == False) || isZeroOrZeroSplat(Src: False, /* AllowUndefs */ true)) {
7528 MatchInfo = [=](MachineIRBuilder &B) {
7529 B.setInstrAndDebugLoc(*Select);
7530 Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy);
7531 B.buildZExtOrTrunc(Res: Ext, Op: Cond);
7532 auto FreezeTrue = B.buildFreeze(Dst: TrueTy, Src: True);
7533 B.buildAnd(Dst: DstReg, Src0: Ext, Src1: FreezeTrue);
7534 };
7535 return true;
7536 }
7537
7538 // select Cond, T, 1 --> or (not Cond), T
7539 if (isOneOrOneSplat(Src: False, /* AllowUndefs */ true)) {
7540 MatchInfo = [=](MachineIRBuilder &B) {
7541 B.setInstrAndDebugLoc(*Select);
7542 // First the not.
7543 Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy);
7544 B.buildNot(Dst: Inner, Src0: Cond);
7545 // Then an ext to match the destination register.
7546 Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy);
7547 B.buildZExtOrTrunc(Res: Ext, Op: Inner);
7548 auto FreezeTrue = B.buildFreeze(Dst: TrueTy, Src: True);
7549 B.buildOr(Dst: DstReg, Src0: Ext, Src1: FreezeTrue, Flags);
7550 };
7551 return true;
7552 }
7553
7554 // select Cond, 0, F --> and (not Cond), F
7555 if (isZeroOrZeroSplat(Src: True, /* AllowUndefs */ true)) {
7556 MatchInfo = [=](MachineIRBuilder &B) {
7557 B.setInstrAndDebugLoc(*Select);
7558 // First the not.
7559 Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy);
7560 B.buildNot(Dst: Inner, Src0: Cond);
7561 // Then an ext to match the destination register.
7562 Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy);
7563 B.buildZExtOrTrunc(Res: Ext, Op: Inner);
7564 auto FreezeFalse = B.buildFreeze(Dst: TrueTy, Src: False);
7565 B.buildAnd(Dst: DstReg, Src0: Ext, Src1: FreezeFalse);
7566 };
7567 return true;
7568 }
7569
7570 return false;
7571}
7572
7573bool CombinerHelper::matchSelectIMinMax(const MachineOperand &MO,
7574 BuildFnTy &MatchInfo) const {
7575 GSelect *Select = cast<GSelect>(Val: MRI.getVRegDef(Reg: MO.getReg()));
7576 GICmp *Cmp = cast<GICmp>(Val: MRI.getVRegDef(Reg: Select->getCondReg()));
7577
7578 Register DstReg = Select->getReg(Idx: 0);
7579 Register True = Select->getTrueReg();
7580 Register False = Select->getFalseReg();
7581 LLT DstTy = MRI.getType(Reg: DstReg);
7582
7583 if (DstTy.isPointerOrPointerVector())
7584 return false;
7585
7586 // We want to fold the icmp and replace the select.
7587 if (!MRI.hasOneNonDBGUse(RegNo: Cmp->getReg(Idx: 0)))
7588 return false;
7589
7590 CmpInst::Predicate Pred = Cmp->getCond();
7591 // We need a larger or smaller predicate for
7592 // canonicalization.
7593 if (CmpInst::isEquality(pred: Pred))
7594 return false;
7595
7596 Register CmpLHS = Cmp->getLHSReg();
7597 Register CmpRHS = Cmp->getRHSReg();
7598
7599 // We can swap CmpLHS and CmpRHS for higher hitrate.
7600 if (True == CmpRHS && False == CmpLHS) {
7601 std::swap(a&: CmpLHS, b&: CmpRHS);
7602 Pred = CmpInst::getSwappedPredicate(pred: Pred);
7603 }
7604
7605 // (icmp X, Y) ? X : Y -> integer minmax.
7606 // see matchSelectPattern in ValueTracking.
7607 // Legality between G_SELECT and integer minmax can differ.
7608 if (True != CmpLHS || False != CmpRHS)
7609 return false;
7610
7611 switch (Pred) {
7612 case ICmpInst::ICMP_UGT:
7613 case ICmpInst::ICMP_UGE: {
7614 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_UMAX, DstTy}))
7615 return false;
7616 MatchInfo = [=](MachineIRBuilder &B) { B.buildUMax(Dst: DstReg, Src0: True, Src1: False); };
7617 return true;
7618 }
7619 case ICmpInst::ICMP_SGT:
7620 case ICmpInst::ICMP_SGE: {
7621 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SMAX, DstTy}))
7622 return false;
7623 MatchInfo = [=](MachineIRBuilder &B) { B.buildSMax(Dst: DstReg, Src0: True, Src1: False); };
7624 return true;
7625 }
7626 case ICmpInst::ICMP_ULT:
7627 case ICmpInst::ICMP_ULE: {
7628 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_UMIN, DstTy}))
7629 return false;
7630 MatchInfo = [=](MachineIRBuilder &B) { B.buildUMin(Dst: DstReg, Src0: True, Src1: False); };
7631 return true;
7632 }
7633 case ICmpInst::ICMP_SLT:
7634 case ICmpInst::ICMP_SLE: {
7635 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SMIN, DstTy}))
7636 return false;
7637 MatchInfo = [=](MachineIRBuilder &B) { B.buildSMin(Dst: DstReg, Src0: True, Src1: False); };
7638 return true;
7639 }
7640 default:
7641 return false;
7642 }
7643}
7644
7645// (neg (min/max x, (neg x))) --> (max/min x, (neg x))
7646bool CombinerHelper::matchSimplifyNegMinMax(MachineInstr &MI,
7647 BuildFnTy &MatchInfo) const {
7648 assert(MI.getOpcode() == TargetOpcode::G_SUB);
7649 Register DestReg = MI.getOperand(i: 0).getReg();
7650 LLT DestTy = MRI.getType(Reg: DestReg);
7651
7652 Register X;
7653 Register Sub0;
7654 auto NegPattern = m_all_of(preds: m_Neg(Src: m_DeferredReg(R&: X)), preds: m_Reg(R&: Sub0));
7655 if (mi_match(R: DestReg, MRI,
7656 P: m_Neg(Src: m_OneUse(SP: m_any_of(preds: m_GSMin(L: m_Reg(R&: X), R: NegPattern),
7657 preds: m_GSMax(L: m_Reg(R&: X), R: NegPattern),
7658 preds: m_GUMin(L: m_Reg(R&: X), R: NegPattern),
7659 preds: m_GUMax(L: m_Reg(R&: X), R: NegPattern)))))) {
7660 MachineInstr *MinMaxMI = MRI.getVRegDef(Reg: MI.getOperand(i: 2).getReg());
7661 unsigned NewOpc = getInverseGMinMaxOpcode(MinMaxOpc: MinMaxMI->getOpcode());
7662 if (isLegal(Query: {NewOpc, {DestTy}})) {
7663 MatchInfo = [=](MachineIRBuilder &B) {
7664 B.buildInstr(Opc: NewOpc, DstOps: {DestReg}, SrcOps: {X, Sub0});
7665 };
7666 return true;
7667 }
7668 }
7669
7670 return false;
7671}
7672
7673bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) const {
7674 GSelect *Select = cast<GSelect>(Val: &MI);
7675
7676 if (tryFoldSelectOfConstants(Select, MatchInfo))
7677 return true;
7678
7679 if (tryFoldBoolSelectToLogic(Select, MatchInfo))
7680 return true;
7681
7682 return false;
7683}
7684
7685/// Fold (icmp Pred1 V1, C1) && (icmp Pred2 V2, C2)
7686/// or (icmp Pred1 V1, C1) || (icmp Pred2 V2, C2)
7687/// into a single comparison using range-based reasoning.
7688/// see InstCombinerImpl::foldAndOrOfICmpsUsingRanges.
7689bool CombinerHelper::tryFoldAndOrOrICmpsUsingRanges(
7690 GLogicalBinOp *Logic, BuildFnTy &MatchInfo) const {
7691 assert(Logic->getOpcode() != TargetOpcode::G_XOR && "unexpected xor");
7692 bool IsAnd = Logic->getOpcode() == TargetOpcode::G_AND;
7693 Register DstReg = Logic->getReg(Idx: 0);
7694 Register LHS = Logic->getLHSReg();
7695 Register RHS = Logic->getRHSReg();
7696 unsigned Flags = Logic->getFlags();
7697
7698 // We need an G_ICMP on the LHS register.
7699 GICmp *Cmp1 = getOpcodeDef<GICmp>(Reg: LHS, MRI);
7700 if (!Cmp1)
7701 return false;
7702
7703 // We need an G_ICMP on the RHS register.
7704 GICmp *Cmp2 = getOpcodeDef<GICmp>(Reg: RHS, MRI);
7705 if (!Cmp2)
7706 return false;
7707
7708 // We want to fold the icmps.
7709 if (!MRI.hasOneNonDBGUse(RegNo: Cmp1->getReg(Idx: 0)) ||
7710 !MRI.hasOneNonDBGUse(RegNo: Cmp2->getReg(Idx: 0)))
7711 return false;
7712
7713 APInt C1;
7714 APInt C2;
7715 std::optional<ValueAndVReg> MaybeC1 =
7716 getIConstantVRegValWithLookThrough(VReg: Cmp1->getRHSReg(), MRI);
7717 if (!MaybeC1)
7718 return false;
7719 C1 = MaybeC1->Value;
7720
7721 std::optional<ValueAndVReg> MaybeC2 =
7722 getIConstantVRegValWithLookThrough(VReg: Cmp2->getRHSReg(), MRI);
7723 if (!MaybeC2)
7724 return false;
7725 C2 = MaybeC2->Value;
7726
7727 Register R1 = Cmp1->getLHSReg();
7728 Register R2 = Cmp2->getLHSReg();
7729 CmpInst::Predicate Pred1 = Cmp1->getCond();
7730 CmpInst::Predicate Pred2 = Cmp2->getCond();
7731 LLT CmpTy = MRI.getType(Reg: Cmp1->getReg(Idx: 0));
7732 LLT CmpOperandTy = MRI.getType(Reg: R1);
7733
7734 if (CmpOperandTy.isPointer())
7735 return false;
7736
7737 // We build ands, adds, and constants of type CmpOperandTy.
7738 // They must be legal to build.
7739 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_AND, CmpOperandTy}) ||
7740 !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ADD, CmpOperandTy}) ||
7741 !isConstantLegalOrBeforeLegalizer(Ty: CmpOperandTy))
7742 return false;
7743
7744 // Look through add of a constant offset on R1, R2, or both operands. This
7745 // allows us to interpret the R + C' < C'' range idiom into a proper range.
7746 std::optional<APInt> Offset1;
7747 std::optional<APInt> Offset2;
7748 if (R1 != R2) {
7749 if (GAdd *Add = getOpcodeDef<GAdd>(Reg: R1, MRI)) {
7750 std::optional<ValueAndVReg> MaybeOffset1 =
7751 getIConstantVRegValWithLookThrough(VReg: Add->getRHSReg(), MRI);
7752 if (MaybeOffset1) {
7753 R1 = Add->getLHSReg();
7754 Offset1 = MaybeOffset1->Value;
7755 }
7756 }
7757 if (GAdd *Add = getOpcodeDef<GAdd>(Reg: R2, MRI)) {
7758 std::optional<ValueAndVReg> MaybeOffset2 =
7759 getIConstantVRegValWithLookThrough(VReg: Add->getRHSReg(), MRI);
7760 if (MaybeOffset2) {
7761 R2 = Add->getLHSReg();
7762 Offset2 = MaybeOffset2->Value;
7763 }
7764 }
7765 }
7766
7767 if (R1 != R2)
7768 return false;
7769
7770 // We calculate the icmp ranges including maybe offsets.
7771 ConstantRange CR1 = ConstantRange::makeExactICmpRegion(
7772 Pred: IsAnd ? ICmpInst::getInversePredicate(pred: Pred1) : Pred1, Other: C1);
7773 if (Offset1)
7774 CR1 = CR1.subtract(CI: *Offset1);
7775
7776 ConstantRange CR2 = ConstantRange::makeExactICmpRegion(
7777 Pred: IsAnd ? ICmpInst::getInversePredicate(pred: Pred2) : Pred2, Other: C2);
7778 if (Offset2)
7779 CR2 = CR2.subtract(CI: *Offset2);
7780
7781 bool CreateMask = false;
7782 APInt LowerDiff;
7783 std::optional<ConstantRange> CR = CR1.exactUnionWith(CR: CR2);
7784 if (!CR) {
7785 // We need non-wrapping ranges.
7786 if (CR1.isWrappedSet() || CR2.isWrappedSet())
7787 return false;
7788
7789 // Check whether we have equal-size ranges that only differ by one bit.
7790 // In that case we can apply a mask to map one range onto the other.
7791 LowerDiff = CR1.getLower() ^ CR2.getLower();
7792 APInt UpperDiff = (CR1.getUpper() - 1) ^ (CR2.getUpper() - 1);
7793 APInt CR1Size = CR1.getUpper() - CR1.getLower();
7794 if (!LowerDiff.isPowerOf2() || LowerDiff != UpperDiff ||
7795 CR1Size != CR2.getUpper() - CR2.getLower())
7796 return false;
7797
7798 CR = CR1.getLower().ult(RHS: CR2.getLower()) ? CR1 : CR2;
7799 CreateMask = true;
7800 }
7801
7802 if (IsAnd)
7803 CR = CR->inverse();
7804
7805 CmpInst::Predicate NewPred;
7806 APInt NewC, Offset;
7807 CR->getEquivalentICmp(Pred&: NewPred, RHS&: NewC, Offset);
7808
7809 // We take the result type of one of the original icmps, CmpTy, for
7810 // the to be build icmp. The operand type, CmpOperandTy, is used for
7811 // the other instructions and constants to be build. The types of
7812 // the parameters and output are the same for add and and. CmpTy
7813 // and the type of DstReg might differ. That is why we zext or trunc
7814 // the icmp into the destination register.
7815
7816 MatchInfo = [=](MachineIRBuilder &B) {
7817 if (CreateMask && Offset != 0) {
7818 auto TildeLowerDiff = B.buildConstant(Res: CmpOperandTy, Val: ~LowerDiff);
7819 auto And = B.buildAnd(Dst: CmpOperandTy, Src0: R1, Src1: TildeLowerDiff); // the mask.
7820 auto OffsetC = B.buildConstant(Res: CmpOperandTy, Val: Offset);
7821 auto Add = B.buildAdd(Dst: CmpOperandTy, Src0: And, Src1: OffsetC, Flags);
7822 auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC);
7823 auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: Add, Op1: NewCon);
7824 B.buildZExtOrTrunc(Res: DstReg, Op: ICmp);
7825 } else if (CreateMask && Offset == 0) {
7826 auto TildeLowerDiff = B.buildConstant(Res: CmpOperandTy, Val: ~LowerDiff);
7827 auto And = B.buildAnd(Dst: CmpOperandTy, Src0: R1, Src1: TildeLowerDiff); // the mask.
7828 auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC);
7829 auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: And, Op1: NewCon);
7830 B.buildZExtOrTrunc(Res: DstReg, Op: ICmp);
7831 } else if (!CreateMask && Offset != 0) {
7832 auto OffsetC = B.buildConstant(Res: CmpOperandTy, Val: Offset);
7833 auto Add = B.buildAdd(Dst: CmpOperandTy, Src0: R1, Src1: OffsetC, Flags);
7834 auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC);
7835 auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: Add, Op1: NewCon);
7836 B.buildZExtOrTrunc(Res: DstReg, Op: ICmp);
7837 } else if (!CreateMask && Offset == 0) {
7838 auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC);
7839 auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: R1, Op1: NewCon);
7840 B.buildZExtOrTrunc(Res: DstReg, Op: ICmp);
7841 } else {
7842 llvm_unreachable("unexpected configuration of CreateMask and Offset");
7843 }
7844 };
7845 return true;
7846}
7847
7848bool CombinerHelper::tryFoldLogicOfFCmps(GLogicalBinOp *Logic,
7849 BuildFnTy &MatchInfo) const {
7850 assert(Logic->getOpcode() != TargetOpcode::G_XOR && "unexpecte xor");
7851 Register DestReg = Logic->getReg(Idx: 0);
7852 Register LHS = Logic->getLHSReg();
7853 Register RHS = Logic->getRHSReg();
7854 bool IsAnd = Logic->getOpcode() == TargetOpcode::G_AND;
7855
7856 // We need a compare on the LHS register.
7857 GFCmp *Cmp1 = getOpcodeDef<GFCmp>(Reg: LHS, MRI);
7858 if (!Cmp1)
7859 return false;
7860
7861 // We need a compare on the RHS register.
7862 GFCmp *Cmp2 = getOpcodeDef<GFCmp>(Reg: RHS, MRI);
7863 if (!Cmp2)
7864 return false;
7865
7866 LLT CmpTy = MRI.getType(Reg: Cmp1->getReg(Idx: 0));
7867 LLT CmpOperandTy = MRI.getType(Reg: Cmp1->getLHSReg());
7868
7869 // We build one fcmp, want to fold the fcmps, replace the logic op,
7870 // and the fcmps must have the same shape.
7871 if (!isLegalOrBeforeLegalizer(
7872 Query: {TargetOpcode::G_FCMP, {CmpTy, CmpOperandTy}}) ||
7873 !MRI.hasOneNonDBGUse(RegNo: Logic->getReg(Idx: 0)) ||
7874 !MRI.hasOneNonDBGUse(RegNo: Cmp1->getReg(Idx: 0)) ||
7875 !MRI.hasOneNonDBGUse(RegNo: Cmp2->getReg(Idx: 0)) ||
7876 MRI.getType(Reg: Cmp1->getLHSReg()) != MRI.getType(Reg: Cmp2->getLHSReg()))
7877 return false;
7878
7879 CmpInst::Predicate PredL = Cmp1->getCond();
7880 CmpInst::Predicate PredR = Cmp2->getCond();
7881 Register LHS0 = Cmp1->getLHSReg();
7882 Register LHS1 = Cmp1->getRHSReg();
7883 Register RHS0 = Cmp2->getLHSReg();
7884 Register RHS1 = Cmp2->getRHSReg();
7885
7886 if (LHS0 == RHS1 && LHS1 == RHS0) {
7887 // Swap RHS operands to match LHS.
7888 PredR = CmpInst::getSwappedPredicate(pred: PredR);
7889 std::swap(a&: RHS0, b&: RHS1);
7890 }
7891
7892 if (LHS0 == RHS0 && LHS1 == RHS1) {
7893 // We determine the new predicate.
7894 unsigned CmpCodeL = getFCmpCode(CC: PredL);
7895 unsigned CmpCodeR = getFCmpCode(CC: PredR);
7896 unsigned NewPred = IsAnd ? CmpCodeL & CmpCodeR : CmpCodeL | CmpCodeR;
7897 unsigned Flags = Cmp1->getFlags() | Cmp2->getFlags();
7898 MatchInfo = [=](MachineIRBuilder &B) {
7899 // The fcmp predicates fill the lower part of the enum.
7900 FCmpInst::Predicate Pred = static_cast<FCmpInst::Predicate>(NewPred);
7901 if (Pred == FCmpInst::FCMP_FALSE &&
7902 isConstantLegalOrBeforeLegalizer(Ty: CmpTy)) {
7903 auto False = B.buildConstant(Res: CmpTy, Val: 0);
7904 B.buildZExtOrTrunc(Res: DestReg, Op: False);
7905 } else if (Pred == FCmpInst::FCMP_TRUE &&
7906 isConstantLegalOrBeforeLegalizer(Ty: CmpTy)) {
7907 auto True =
7908 B.buildConstant(Res: CmpTy, Val: getICmpTrueVal(TLI: getTargetLowering(),
7909 IsVector: CmpTy.isVector() /*isVector*/,
7910 IsFP: true /*isFP*/));
7911 B.buildZExtOrTrunc(Res: DestReg, Op: True);
7912 } else { // We take the predicate without predicate optimizations.
7913 auto Cmp = B.buildFCmp(Pred, Res: CmpTy, Op0: LHS0, Op1: LHS1, Flags);
7914 B.buildZExtOrTrunc(Res: DestReg, Op: Cmp);
7915 }
7916 };
7917 return true;
7918 }
7919
7920 return false;
7921}
7922
7923bool CombinerHelper::matchAnd(MachineInstr &MI, BuildFnTy &MatchInfo) const {
7924 GAnd *And = cast<GAnd>(Val: &MI);
7925
7926 if (tryFoldAndOrOrICmpsUsingRanges(Logic: And, MatchInfo))
7927 return true;
7928
7929 if (tryFoldLogicOfFCmps(Logic: And, MatchInfo))
7930 return true;
7931
7932 return false;
7933}
7934
7935bool CombinerHelper::matchOr(MachineInstr &MI, BuildFnTy &MatchInfo) const {
7936 GOr *Or = cast<GOr>(Val: &MI);
7937
7938 if (tryFoldAndOrOrICmpsUsingRanges(Logic: Or, MatchInfo))
7939 return true;
7940
7941 if (tryFoldLogicOfFCmps(Logic: Or, MatchInfo))
7942 return true;
7943
7944 return false;
7945}
7946
7947bool CombinerHelper::matchAddOverflow(MachineInstr &MI,
7948 BuildFnTy &MatchInfo) const {
7949 GAddCarryOut *Add = cast<GAddCarryOut>(Val: &MI);
7950
7951 // Addo has no flags
7952 Register Dst = Add->getReg(Idx: 0);
7953 Register Carry = Add->getReg(Idx: 1);
7954 Register LHS = Add->getLHSReg();
7955 Register RHS = Add->getRHSReg();
7956 bool IsSigned = Add->isSigned();
7957 LLT DstTy = MRI.getType(Reg: Dst);
7958 LLT CarryTy = MRI.getType(Reg: Carry);
7959
7960 // Fold addo, if the carry is dead -> add, undef.
7961 if (MRI.use_nodbg_empty(RegNo: Carry) &&
7962 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ADD, {DstTy}})) {
7963 MatchInfo = [=](MachineIRBuilder &B) {
7964 B.buildAdd(Dst, Src0: LHS, Src1: RHS);
7965 B.buildUndef(Res: Carry);
7966 };
7967 return true;
7968 }
7969
7970 // Canonicalize constant to RHS.
7971 if (isConstantOrConstantVectorI(Src: LHS) && !isConstantOrConstantVectorI(Src: RHS)) {
7972 if (IsSigned) {
7973 MatchInfo = [=](MachineIRBuilder &B) {
7974 B.buildSAddo(Res: Dst, CarryOut: Carry, Op0: RHS, Op1: LHS);
7975 };
7976 return true;
7977 }
7978 // !IsSigned
7979 MatchInfo = [=](MachineIRBuilder &B) {
7980 B.buildUAddo(Res: Dst, CarryOut: Carry, Op0: RHS, Op1: LHS);
7981 };
7982 return true;
7983 }
7984
7985 std::optional<APInt> MaybeLHS = getConstantOrConstantSplatVector(Src: LHS);
7986 std::optional<APInt> MaybeRHS = getConstantOrConstantSplatVector(Src: RHS);
7987
7988 // Fold addo(c1, c2) -> c3, carry.
7989 if (MaybeLHS && MaybeRHS && isConstantLegalOrBeforeLegalizer(Ty: DstTy) &&
7990 isConstantLegalOrBeforeLegalizer(Ty: CarryTy)) {
7991 bool Overflow;
7992 APInt Result = IsSigned ? MaybeLHS->sadd_ov(RHS: *MaybeRHS, Overflow)
7993 : MaybeLHS->uadd_ov(RHS: *MaybeRHS, Overflow);
7994 MatchInfo = [=](MachineIRBuilder &B) {
7995 B.buildConstant(Res: Dst, Val: Result);
7996 B.buildConstant(Res: Carry, Val: Overflow);
7997 };
7998 return true;
7999 }
8000
8001 // Fold (addo x, 0) -> x, no carry
8002 if (MaybeRHS && *MaybeRHS == 0 && isConstantLegalOrBeforeLegalizer(Ty: CarryTy)) {
8003 MatchInfo = [=](MachineIRBuilder &B) {
8004 B.buildCopy(Res: Dst, Op: LHS);
8005 B.buildConstant(Res: Carry, Val: 0);
8006 };
8007 return true;
8008 }
8009
8010 // Given 2 constant operands whose sum does not overflow:
8011 // uaddo (X +nuw C0), C1 -> uaddo X, C0 + C1
8012 // saddo (X +nsw C0), C1 -> saddo X, C0 + C1
8013 GAdd *AddLHS = getOpcodeDef<GAdd>(Reg: LHS, MRI);
8014 if (MaybeRHS && AddLHS && MRI.hasOneNonDBGUse(RegNo: Add->getReg(Idx: 0)) &&
8015 ((IsSigned && AddLHS->getFlag(Flag: MachineInstr::MIFlag::NoSWrap)) ||
8016 (!IsSigned && AddLHS->getFlag(Flag: MachineInstr::MIFlag::NoUWrap)))) {
8017 std::optional<APInt> MaybeAddRHS =
8018 getConstantOrConstantSplatVector(Src: AddLHS->getRHSReg());
8019 if (MaybeAddRHS) {
8020 bool Overflow;
8021 APInt NewC = IsSigned ? MaybeAddRHS->sadd_ov(RHS: *MaybeRHS, Overflow)
8022 : MaybeAddRHS->uadd_ov(RHS: *MaybeRHS, Overflow);
8023 if (!Overflow && isConstantLegalOrBeforeLegalizer(Ty: DstTy)) {
8024 if (IsSigned) {
8025 MatchInfo = [=](MachineIRBuilder &B) {
8026 auto ConstRHS = B.buildConstant(Res: DstTy, Val: NewC);
8027 B.buildSAddo(Res: Dst, CarryOut: Carry, Op0: AddLHS->getLHSReg(), Op1: ConstRHS);
8028 };
8029 return true;
8030 }
8031 // !IsSigned
8032 MatchInfo = [=](MachineIRBuilder &B) {
8033 auto ConstRHS = B.buildConstant(Res: DstTy, Val: NewC);
8034 B.buildUAddo(Res: Dst, CarryOut: Carry, Op0: AddLHS->getLHSReg(), Op1: ConstRHS);
8035 };
8036 return true;
8037 }
8038 }
8039 };
8040
8041 // We try to combine addo to non-overflowing add.
8042 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ADD, {DstTy}}) ||
8043 !isConstantLegalOrBeforeLegalizer(Ty: CarryTy))
8044 return false;
8045
8046 // We try to combine uaddo to non-overflowing add.
8047 if (!IsSigned) {
8048 ConstantRange CRLHS =
8049 ConstantRange::fromKnownBits(Known: VT->getKnownBits(R: LHS), /*IsSigned=*/false);
8050 ConstantRange CRRHS =
8051 ConstantRange::fromKnownBits(Known: VT->getKnownBits(R: RHS), /*IsSigned=*/false);
8052
8053 switch (CRLHS.unsignedAddMayOverflow(Other: CRRHS)) {
8054 case ConstantRange::OverflowResult::MayOverflow:
8055 return false;
8056 case ConstantRange::OverflowResult::NeverOverflows: {
8057 MatchInfo = [=](MachineIRBuilder &B) {
8058 B.buildAdd(Dst, Src0: LHS, Src1: RHS, Flags: MachineInstr::MIFlag::NoUWrap);
8059 B.buildConstant(Res: Carry, Val: 0);
8060 };
8061 return true;
8062 }
8063 case ConstantRange::OverflowResult::AlwaysOverflowsLow:
8064 case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
8065 MatchInfo = [=](MachineIRBuilder &B) {
8066 B.buildAdd(Dst, Src0: LHS, Src1: RHS);
8067 B.buildConstant(Res: Carry, Val: 1);
8068 };
8069 return true;
8070 }
8071 }
8072 return false;
8073 }
8074
8075 // We try to combine saddo to non-overflowing add.
8076
8077 // If LHS and RHS each have at least two sign bits, then there is no signed
8078 // overflow.
8079 if (VT->computeNumSignBits(R: RHS) > 1 && VT->computeNumSignBits(R: LHS) > 1) {
8080 MatchInfo = [=](MachineIRBuilder &B) {
8081 B.buildAdd(Dst, Src0: LHS, Src1: RHS, Flags: MachineInstr::MIFlag::NoSWrap);
8082 B.buildConstant(Res: Carry, Val: 0);
8083 };
8084 return true;
8085 }
8086
8087 ConstantRange CRLHS =
8088 ConstantRange::fromKnownBits(Known: VT->getKnownBits(R: LHS), /*IsSigned=*/true);
8089 ConstantRange CRRHS =
8090 ConstantRange::fromKnownBits(Known: VT->getKnownBits(R: RHS), /*IsSigned=*/true);
8091
8092 switch (CRLHS.signedAddMayOverflow(Other: CRRHS)) {
8093 case ConstantRange::OverflowResult::MayOverflow:
8094 return false;
8095 case ConstantRange::OverflowResult::NeverOverflows: {
8096 MatchInfo = [=](MachineIRBuilder &B) {
8097 B.buildAdd(Dst, Src0: LHS, Src1: RHS, Flags: MachineInstr::MIFlag::NoSWrap);
8098 B.buildConstant(Res: Carry, Val: 0);
8099 };
8100 return true;
8101 }
8102 case ConstantRange::OverflowResult::AlwaysOverflowsLow:
8103 case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
8104 MatchInfo = [=](MachineIRBuilder &B) {
8105 B.buildAdd(Dst, Src0: LHS, Src1: RHS);
8106 B.buildConstant(Res: Carry, Val: 1);
8107 };
8108 return true;
8109 }
8110 }
8111
8112 return false;
8113}
8114
8115void CombinerHelper::applyBuildFnMO(const MachineOperand &MO,
8116 BuildFnTy &MatchInfo) const {
8117 MachineInstr *Root = getDefIgnoringCopies(Reg: MO.getReg(), MRI);
8118 MatchInfo(Builder);
8119 Root->eraseFromParent();
8120}
8121
8122bool CombinerHelper::matchFPowIExpansion(MachineInstr &MI,
8123 int64_t Exponent) const {
8124 bool OptForSize = MI.getMF()->getFunction().hasOptSize();
8125 return getTargetLowering().isBeneficialToExpandPowI(Exponent, OptForSize);
8126}
8127
8128void CombinerHelper::applyExpandFPowI(MachineInstr &MI,
8129 int64_t Exponent) const {
8130 auto [Dst, Base] = MI.getFirst2Regs();
8131 LLT Ty = MRI.getType(Reg: Dst);
8132 int64_t ExpVal = Exponent;
8133
8134 if (ExpVal == 0) {
8135 Builder.buildFConstant(Res: Dst, Val: 1.0);
8136 MI.removeFromParent();
8137 return;
8138 }
8139
8140 if (ExpVal < 0)
8141 ExpVal = -ExpVal;
8142
8143 // We use the simple binary decomposition method from SelectionDAG ExpandPowI
8144 // to generate the multiply sequence. There are more optimal ways to do this
8145 // (for example, powi(x,15) generates one more multiply than it should), but
8146 // this has the benefit of being both really simple and much better than a
8147 // libcall.
8148 std::optional<SrcOp> Res;
8149 SrcOp CurSquare = Base;
8150 while (ExpVal > 0) {
8151 if (ExpVal & 1) {
8152 if (!Res)
8153 Res = CurSquare;
8154 else
8155 Res = Builder.buildFMul(Dst: Ty, Src0: *Res, Src1: CurSquare);
8156 }
8157
8158 CurSquare = Builder.buildFMul(Dst: Ty, Src0: CurSquare, Src1: CurSquare);
8159 ExpVal >>= 1;
8160 }
8161
8162 // If the original exponent was negative, invert the result, producing
8163 // 1/(x*x*x).
8164 if (Exponent < 0)
8165 Res = Builder.buildFDiv(Dst: Ty, Src0: Builder.buildFConstant(Res: Ty, Val: 1.0), Src1: *Res,
8166 Flags: MI.getFlags());
8167
8168 Builder.buildCopy(Res: Dst, Op: *Res);
8169 MI.eraseFromParent();
8170}
8171
8172bool CombinerHelper::matchFoldAPlusC1MinusC2(const MachineInstr &MI,
8173 BuildFnTy &MatchInfo) const {
8174 // fold (A+C1)-C2 -> A+(C1-C2)
8175 const GSub *Sub = cast<GSub>(Val: &MI);
8176 GAdd *Add = cast<GAdd>(Val: MRI.getVRegDef(Reg: Sub->getLHSReg()));
8177
8178 if (!MRI.hasOneNonDBGUse(RegNo: Add->getReg(Idx: 0)))
8179 return false;
8180
8181 APInt C2 = getIConstantFromReg(VReg: Sub->getRHSReg(), MRI);
8182 APInt C1 = getIConstantFromReg(VReg: Add->getRHSReg(), MRI);
8183
8184 Register Dst = Sub->getReg(Idx: 0);
8185 LLT DstTy = MRI.getType(Reg: Dst);
8186
8187 MatchInfo = [=](MachineIRBuilder &B) {
8188 auto Const = B.buildConstant(Res: DstTy, Val: C1 - C2);
8189 B.buildAdd(Dst, Src0: Add->getLHSReg(), Src1: Const);
8190 };
8191
8192 return true;
8193}
8194
8195bool CombinerHelper::matchFoldC2MinusAPlusC1(const MachineInstr &MI,
8196 BuildFnTy &MatchInfo) const {
8197 // fold C2-(A+C1) -> (C2-C1)-A
8198 const GSub *Sub = cast<GSub>(Val: &MI);
8199 GAdd *Add = cast<GAdd>(Val: MRI.getVRegDef(Reg: Sub->getRHSReg()));
8200
8201 if (!MRI.hasOneNonDBGUse(RegNo: Add->getReg(Idx: 0)))
8202 return false;
8203
8204 APInt C2 = getIConstantFromReg(VReg: Sub->getLHSReg(), MRI);
8205 APInt C1 = getIConstantFromReg(VReg: Add->getRHSReg(), MRI);
8206
8207 Register Dst = Sub->getReg(Idx: 0);
8208 LLT DstTy = MRI.getType(Reg: Dst);
8209
8210 MatchInfo = [=](MachineIRBuilder &B) {
8211 auto Const = B.buildConstant(Res: DstTy, Val: C2 - C1);
8212 B.buildSub(Dst, Src0: Const, Src1: Add->getLHSReg());
8213 };
8214
8215 return true;
8216}
8217
8218bool CombinerHelper::matchFoldAMinusC1MinusC2(const MachineInstr &MI,
8219 BuildFnTy &MatchInfo) const {
8220 // fold (A-C1)-C2 -> A-(C1+C2)
8221 const GSub *Sub1 = cast<GSub>(Val: &MI);
8222 GSub *Sub2 = cast<GSub>(Val: MRI.getVRegDef(Reg: Sub1->getLHSReg()));
8223
8224 if (!MRI.hasOneNonDBGUse(RegNo: Sub2->getReg(Idx: 0)))
8225 return false;
8226
8227 APInt C2 = getIConstantFromReg(VReg: Sub1->getRHSReg(), MRI);
8228 APInt C1 = getIConstantFromReg(VReg: Sub2->getRHSReg(), MRI);
8229
8230 Register Dst = Sub1->getReg(Idx: 0);
8231 LLT DstTy = MRI.getType(Reg: Dst);
8232
8233 MatchInfo = [=](MachineIRBuilder &B) {
8234 auto Const = B.buildConstant(Res: DstTy, Val: C1 + C2);
8235 B.buildSub(Dst, Src0: Sub2->getLHSReg(), Src1: Const);
8236 };
8237
8238 return true;
8239}
8240
8241bool CombinerHelper::matchFoldC1Minus2MinusC2(const MachineInstr &MI,
8242 BuildFnTy &MatchInfo) const {
8243 // fold (C1-A)-C2 -> (C1-C2)-A
8244 const GSub *Sub1 = cast<GSub>(Val: &MI);
8245 GSub *Sub2 = cast<GSub>(Val: MRI.getVRegDef(Reg: Sub1->getLHSReg()));
8246
8247 if (!MRI.hasOneNonDBGUse(RegNo: Sub2->getReg(Idx: 0)))
8248 return false;
8249
8250 APInt C2 = getIConstantFromReg(VReg: Sub1->getRHSReg(), MRI);
8251 APInt C1 = getIConstantFromReg(VReg: Sub2->getLHSReg(), MRI);
8252
8253 Register Dst = Sub1->getReg(Idx: 0);
8254 LLT DstTy = MRI.getType(Reg: Dst);
8255
8256 MatchInfo = [=](MachineIRBuilder &B) {
8257 auto Const = B.buildConstant(Res: DstTy, Val: C1 - C2);
8258 B.buildSub(Dst, Src0: Const, Src1: Sub2->getRHSReg());
8259 };
8260
8261 return true;
8262}
8263
8264bool CombinerHelper::matchFoldAMinusC1PlusC2(const MachineInstr &MI,
8265 BuildFnTy &MatchInfo) const {
8266 // fold ((A-C1)+C2) -> (A+(C2-C1))
8267 const GAdd *Add = cast<GAdd>(Val: &MI);
8268 GSub *Sub = cast<GSub>(Val: MRI.getVRegDef(Reg: Add->getLHSReg()));
8269
8270 if (!MRI.hasOneNonDBGUse(RegNo: Sub->getReg(Idx: 0)))
8271 return false;
8272
8273 APInt C2 = getIConstantFromReg(VReg: Add->getRHSReg(), MRI);
8274 APInt C1 = getIConstantFromReg(VReg: Sub->getRHSReg(), MRI);
8275
8276 Register Dst = Add->getReg(Idx: 0);
8277 LLT DstTy = MRI.getType(Reg: Dst);
8278
8279 MatchInfo = [=](MachineIRBuilder &B) {
8280 auto Const = B.buildConstant(Res: DstTy, Val: C2 - C1);
8281 B.buildAdd(Dst, Src0: Sub->getLHSReg(), Src1: Const);
8282 };
8283
8284 return true;
8285}
8286
8287bool CombinerHelper::matchUnmergeValuesAnyExtBuildVector(
8288 const MachineInstr &MI, BuildFnTy &MatchInfo) const {
8289 const GUnmerge *Unmerge = cast<GUnmerge>(Val: &MI);
8290
8291 if (!MRI.hasOneNonDBGUse(RegNo: Unmerge->getSourceReg()))
8292 return false;
8293
8294 const MachineInstr *Source = MRI.getVRegDef(Reg: Unmerge->getSourceReg());
8295
8296 LLT DstTy = MRI.getType(Reg: Unmerge->getReg(Idx: 0));
8297
8298 // $bv:_(<8 x s8>) = G_BUILD_VECTOR ....
8299 // $any:_(<8 x s16>) = G_ANYEXT $bv
8300 // $uv:_(<4 x s16>), $uv1:_(<4 x s16>) = G_UNMERGE_VALUES $any
8301 //
8302 // ->
8303 //
8304 // $any:_(s16) = G_ANYEXT $bv[0]
8305 // $any1:_(s16) = G_ANYEXT $bv[1]
8306 // $any2:_(s16) = G_ANYEXT $bv[2]
8307 // $any3:_(s16) = G_ANYEXT $bv[3]
8308 // $any4:_(s16) = G_ANYEXT $bv[4]
8309 // $any5:_(s16) = G_ANYEXT $bv[5]
8310 // $any6:_(s16) = G_ANYEXT $bv[6]
8311 // $any7:_(s16) = G_ANYEXT $bv[7]
8312 // $uv:_(<4 x s16>) = G_BUILD_VECTOR $any, $any1, $any2, $any3
8313 // $uv1:_(<4 x s16>) = G_BUILD_VECTOR $any4, $any5, $any6, $any7
8314
8315 // We want to unmerge into vectors.
8316 if (!DstTy.isFixedVector())
8317 return false;
8318
8319 const GAnyExt *Any = dyn_cast<GAnyExt>(Val: Source);
8320 if (!Any)
8321 return false;
8322
8323 const MachineInstr *NextSource = MRI.getVRegDef(Reg: Any->getSrcReg());
8324
8325 if (const GBuildVector *BV = dyn_cast<GBuildVector>(Val: NextSource)) {
8326 // G_UNMERGE_VALUES G_ANYEXT G_BUILD_VECTOR
8327
8328 if (!MRI.hasOneNonDBGUse(RegNo: BV->getReg(Idx: 0)))
8329 return false;
8330
8331 // FIXME: check element types?
8332 if (BV->getNumSources() % Unmerge->getNumDefs() != 0)
8333 return false;
8334
8335 LLT BigBvTy = MRI.getType(Reg: BV->getReg(Idx: 0));
8336 LLT SmallBvTy = DstTy;
8337 LLT SmallBvElemenTy = SmallBvTy.getElementType();
8338
8339 if (!isLegalOrBeforeLegalizer(
8340 Query: {TargetOpcode::G_BUILD_VECTOR, {SmallBvTy, SmallBvElemenTy}}))
8341 return false;
8342
8343 // We check the legality of scalar anyext.
8344 if (!isLegalOrBeforeLegalizer(
8345 Query: {TargetOpcode::G_ANYEXT,
8346 {SmallBvElemenTy, BigBvTy.getElementType()}}))
8347 return false;
8348
8349 MatchInfo = [=](MachineIRBuilder &B) {
8350 // Build into each G_UNMERGE_VALUES def
8351 // a small build vector with anyext from the source build vector.
8352 for (unsigned I = 0; I < Unmerge->getNumDefs(); ++I) {
8353 SmallVector<Register> Ops;
8354 for (unsigned J = 0; J < SmallBvTy.getNumElements(); ++J) {
8355 Register SourceArray =
8356 BV->getSourceReg(I: I * SmallBvTy.getNumElements() + J);
8357 auto AnyExt = B.buildAnyExt(Res: SmallBvElemenTy, Op: SourceArray);
8358 Ops.push_back(Elt: AnyExt.getReg(Idx: 0));
8359 }
8360 B.buildBuildVector(Res: Unmerge->getOperand(i: I).getReg(), Ops);
8361 };
8362 };
8363 return true;
8364 };
8365
8366 return false;
8367}
8368
8369bool CombinerHelper::matchShuffleUndefRHS(MachineInstr &MI,
8370 BuildFnTy &MatchInfo) const {
8371
8372 bool Changed = false;
8373 auto &Shuffle = cast<GShuffleVector>(Val&: MI);
8374 ArrayRef<int> OrigMask = Shuffle.getMask();
8375 SmallVector<int, 16> NewMask;
8376 const LLT SrcTy = MRI.getType(Reg: Shuffle.getSrc1Reg());
8377 const unsigned NumSrcElems = SrcTy.isVector() ? SrcTy.getNumElements() : 1;
8378 const unsigned NumDstElts = OrigMask.size();
8379 for (unsigned i = 0; i != NumDstElts; ++i) {
8380 int Idx = OrigMask[i];
8381 if (Idx >= (int)NumSrcElems) {
8382 Idx = -1;
8383 Changed = true;
8384 }
8385 NewMask.push_back(Elt: Idx);
8386 }
8387
8388 if (!Changed)
8389 return false;
8390
8391 MatchInfo = [&, NewMask = std::move(NewMask)](MachineIRBuilder &B) {
8392 B.buildShuffleVector(Res: MI.getOperand(i: 0), Src1: MI.getOperand(i: 1), Src2: MI.getOperand(i: 2),
8393 Mask: std::move(NewMask));
8394 };
8395
8396 return true;
8397}
8398
8399static void commuteMask(MutableArrayRef<int> Mask, const unsigned NumElems) {
8400 const unsigned MaskSize = Mask.size();
8401 for (unsigned I = 0; I < MaskSize; ++I) {
8402 int Idx = Mask[I];
8403 if (Idx < 0)
8404 continue;
8405
8406 if (Idx < (int)NumElems)
8407 Mask[I] = Idx + NumElems;
8408 else
8409 Mask[I] = Idx - NumElems;
8410 }
8411}
8412
8413bool CombinerHelper::matchShuffleDisjointMask(MachineInstr &MI,
8414 BuildFnTy &MatchInfo) const {
8415
8416 auto &Shuffle = cast<GShuffleVector>(Val&: MI);
8417 // If any of the two inputs is already undef, don't check the mask again to
8418 // prevent infinite loop
8419 if (getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: Shuffle.getSrc1Reg(), MRI))
8420 return false;
8421
8422 if (getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: Shuffle.getSrc2Reg(), MRI))
8423 return false;
8424
8425 const LLT DstTy = MRI.getType(Reg: Shuffle.getReg(Idx: 0));
8426 const LLT Src1Ty = MRI.getType(Reg: Shuffle.getSrc1Reg());
8427 if (!isLegalOrBeforeLegalizer(
8428 Query: {TargetOpcode::G_SHUFFLE_VECTOR, {DstTy, Src1Ty}}))
8429 return false;
8430
8431 ArrayRef<int> Mask = Shuffle.getMask();
8432 const unsigned NumSrcElems = Src1Ty.getNumElements();
8433
8434 bool TouchesSrc1 = false;
8435 bool TouchesSrc2 = false;
8436 const unsigned NumElems = Mask.size();
8437 for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
8438 if (Mask[Idx] < 0)
8439 continue;
8440
8441 if (Mask[Idx] < (int)NumSrcElems)
8442 TouchesSrc1 = true;
8443 else
8444 TouchesSrc2 = true;
8445 }
8446
8447 if (TouchesSrc1 == TouchesSrc2)
8448 return false;
8449
8450 Register NewSrc1 = Shuffle.getSrc1Reg();
8451 SmallVector<int, 16> NewMask(Mask);
8452 if (TouchesSrc2) {
8453 NewSrc1 = Shuffle.getSrc2Reg();
8454 commuteMask(Mask: NewMask, NumElems: NumSrcElems);
8455 }
8456
8457 MatchInfo = [=, &Shuffle](MachineIRBuilder &B) {
8458 auto Undef = B.buildUndef(Res: Src1Ty);
8459 B.buildShuffleVector(Res: Shuffle.getReg(Idx: 0), Src1: NewSrc1, Src2: Undef, Mask: NewMask);
8460 };
8461
8462 return true;
8463}
8464
8465bool CombinerHelper::matchSuboCarryOut(const MachineInstr &MI,
8466 BuildFnTy &MatchInfo) const {
8467 const GSubCarryOut *Subo = cast<GSubCarryOut>(Val: &MI);
8468
8469 Register Dst = Subo->getReg(Idx: 0);
8470 Register LHS = Subo->getLHSReg();
8471 Register RHS = Subo->getRHSReg();
8472 Register Carry = Subo->getCarryOutReg();
8473 LLT DstTy = MRI.getType(Reg: Dst);
8474 LLT CarryTy = MRI.getType(Reg: Carry);
8475
8476 // Check legality before known bits.
8477 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SUB, {DstTy}}) ||
8478 !isConstantLegalOrBeforeLegalizer(Ty: CarryTy))
8479 return false;
8480
8481 ConstantRange KBLHS =
8482 ConstantRange::fromKnownBits(Known: VT->getKnownBits(R: LHS),
8483 /* IsSigned=*/Subo->isSigned());
8484 ConstantRange KBRHS =
8485 ConstantRange::fromKnownBits(Known: VT->getKnownBits(R: RHS),
8486 /* IsSigned=*/Subo->isSigned());
8487
8488 if (Subo->isSigned()) {
8489 // G_SSUBO
8490 switch (KBLHS.signedSubMayOverflow(Other: KBRHS)) {
8491 case ConstantRange::OverflowResult::MayOverflow:
8492 return false;
8493 case ConstantRange::OverflowResult::NeverOverflows: {
8494 MatchInfo = [=](MachineIRBuilder &B) {
8495 B.buildSub(Dst, Src0: LHS, Src1: RHS, Flags: MachineInstr::MIFlag::NoSWrap);
8496 B.buildConstant(Res: Carry, Val: 0);
8497 };
8498 return true;
8499 }
8500 case ConstantRange::OverflowResult::AlwaysOverflowsLow:
8501 case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
8502 MatchInfo = [=](MachineIRBuilder &B) {
8503 B.buildSub(Dst, Src0: LHS, Src1: RHS);
8504 B.buildConstant(Res: Carry, Val: getICmpTrueVal(TLI: getTargetLowering(),
8505 /*isVector=*/IsVector: CarryTy.isVector(),
8506 /*isFP=*/IsFP: false));
8507 };
8508 return true;
8509 }
8510 }
8511 return false;
8512 }
8513
8514 // G_USUBO
8515 switch (KBLHS.unsignedSubMayOverflow(Other: KBRHS)) {
8516 case ConstantRange::OverflowResult::MayOverflow:
8517 return false;
8518 case ConstantRange::OverflowResult::NeverOverflows: {
8519 MatchInfo = [=](MachineIRBuilder &B) {
8520 B.buildSub(Dst, Src0: LHS, Src1: RHS, Flags: MachineInstr::MIFlag::NoUWrap);
8521 B.buildConstant(Res: Carry, Val: 0);
8522 };
8523 return true;
8524 }
8525 case ConstantRange::OverflowResult::AlwaysOverflowsLow:
8526 case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
8527 MatchInfo = [=](MachineIRBuilder &B) {
8528 B.buildSub(Dst, Src0: LHS, Src1: RHS);
8529 B.buildConstant(Res: Carry, Val: getICmpTrueVal(TLI: getTargetLowering(),
8530 /*isVector=*/IsVector: CarryTy.isVector(),
8531 /*isFP=*/IsFP: false));
8532 };
8533 return true;
8534 }
8535 }
8536
8537 return false;
8538}
8539
8540// Fold (ctlz (xor x, (sra x, bitwidth-1))) -> (add (ctls x), 1).
8541// Fold (ctlz (or (shl (xor x, (sra x, bitwidth-1)), 1), 1) -> (ctls x)
8542bool CombinerHelper::matchCtls(MachineInstr &CtlzMI,
8543 BuildFnTy &MatchInfo) const {
8544 assert((CtlzMI.getOpcode() == TargetOpcode::G_CTLZ ||
8545 CtlzMI.getOpcode() == TargetOpcode::G_CTLZ_ZERO_UNDEF) &&
8546 "Expected G_CTLZ variant");
8547
8548 const Register Dst = CtlzMI.getOperand(i: 0).getReg();
8549 Register Src = CtlzMI.getOperand(i: 1).getReg();
8550
8551 LLT Ty = MRI.getType(Reg: Dst);
8552 LLT SrcTy = MRI.getType(Reg: Src);
8553
8554 if (!(Ty.isValid() && Ty.isScalar()))
8555 return false;
8556
8557 if (!LI)
8558 return false;
8559
8560 SmallVector<LLT, 2> QueryTypes = {Ty, SrcTy};
8561 LegalityQuery Query(TargetOpcode::G_CTLS, QueryTypes);
8562
8563 switch (LI->getAction(Query).Action) {
8564 default:
8565 return false;
8566 case LegalizeActions::Legal:
8567 case LegalizeActions::Custom:
8568 case LegalizeActions::WidenScalar:
8569 break;
8570 }
8571
8572 // Src = or(shl(V, 1), 1) -> Src=V; NeedAdd = False
8573 Register V;
8574 bool NeedAdd = true;
8575 if (mi_match(R: Src, MRI,
8576 P: m_OneUse(SP: m_GOr(L: m_OneUse(SP: m_GShl(L: m_Reg(R&: V), R: m_SpecificICst(RequestedValue: 1))),
8577 R: m_SpecificICst(RequestedValue: 1))))) {
8578 NeedAdd = false;
8579 Src = V;
8580 }
8581
8582 unsigned BitWidth = Ty.getScalarSizeInBits();
8583
8584 Register X;
8585 if (!mi_match(R: Src, MRI,
8586 P: m_OneUse(SP: m_GXor(L: m_Reg(R&: X), R: m_OneUse(SP: m_GAShr(
8587 L: m_DeferredReg(R&: X),
8588 R: m_SpecificICst(RequestedValue: BitWidth - 1)))))))
8589 return false;
8590
8591 MatchInfo = [=](MachineIRBuilder &B) {
8592 if (!NeedAdd) {
8593 B.buildCTLS(Dst, Src0: X);
8594 return;
8595 }
8596
8597 auto Ctls = B.buildCTLS(Dst: Ty, Src0: X);
8598 auto One = B.buildConstant(Res: Ty, Val: 1);
8599
8600 B.buildAdd(Dst, Src0: Ctls, Src1: One);
8601 };
8602
8603 return true;
8604}
8605