//===-- lib/CodeGen/GlobalISel/CombinerHelper.cpp -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Analysis/CmpInstAnalysis.h"
#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
#include "llvm/CodeGen/GlobalISel/GISelValueTracking.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/LowLevelTypeUtils.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Register.h"
#include "llvm/CodeGen/RegisterBankInfo.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/DivisionByConstantInfo.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Target/TargetMachine.h"
#include <cmath>
#include <optional>
#include <tuple>

#define DEBUG_TYPE "gi-combiner"

using namespace llvm;
using namespace MIPatternMatch;

// Option to allow testing of the combiner while no targets know about indexed
// addressing.
static cl::opt<bool>
    ForceLegalIndexing("force-legal-indexing", cl::Hidden, cl::init(false),
                       cl::desc("Force all indexed operations to be "
                                "legal for the GlobalISel combiner"));

CombinerHelper::CombinerHelper(GISelChangeObserver &Observer,
                               MachineIRBuilder &B, bool IsPreLegalize,
                               GISelValueTracking *VT,
                               MachineDominatorTree *MDT,
                               const LegalizerInfo *LI)
    : Builder(B), MRI(Builder.getMF().getRegInfo()), Observer(Observer), VT(VT),
      MDT(MDT), IsPreLegalize(IsPreLegalize), LI(LI),
      RBI(Builder.getMF().getSubtarget().getRegBankInfo()),
      TRI(Builder.getMF().getSubtarget().getRegisterInfo()) {
  (void)this->VT;
}

const TargetLowering &CombinerHelper::getTargetLowering() const {
  return *Builder.getMF().getSubtarget().getTargetLowering();
}

const MachineFunction &CombinerHelper::getMachineFunction() const {
  return Builder.getMF();
}

const DataLayout &CombinerHelper::getDataLayout() const {
  return getMachineFunction().getDataLayout();
}

LLVMContext &CombinerHelper::getContext() const { return Builder.getContext(); }

/// \returns The little endian in-memory byte position of byte \p I in a
/// \p ByteWidth bytes wide type.
///
/// E.g. Given a 4-byte type x, x[0] -> byte 0
static unsigned littleEndianByteAt(const unsigned ByteWidth, const unsigned I) {
  assert(I < ByteWidth && "I must be in [0, ByteWidth)");
  return I;
}

/// Determines the LogBase2 value for a non-null input value using the
/// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
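/// E.g. (illustrative): for a 32-bit V = 8, ctlz(8) = 28, so
/// LogBase2(8) = (32 - 1) - 28 = 3. For a non-power-of-two V this computes
/// floor(log2(V)).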
static Register buildLogBase2(Register V, MachineIRBuilder &MIB) {
  auto &MRI = *MIB.getMRI();
  LLT Ty = MRI.getType(V);
  auto Ctlz = MIB.buildCTLZ(Ty, V);
  auto Base = MIB.buildConstant(Ty, Ty.getScalarSizeInBits() - 1);
  return MIB.buildSub(Ty, Base, Ctlz).getReg(0);
}

/// \returns The big endian in-memory byte position of byte \p I in a
/// \p ByteWidth bytes wide type.
///
/// E.g. Given a 4-byte type x, x[0] -> byte 3
static unsigned bigEndianByteAt(const unsigned ByteWidth, const unsigned I) {
  assert(I < ByteWidth && "I must be in [0, ByteWidth)");
  return ByteWidth - I - 1;
}

/// Given a map from byte offsets in memory to indices in a load/store,
/// determine if that map corresponds to a little or big endian byte pattern.
///
/// \param MemOffset2Idx maps memory offsets to address offsets.
/// \param LowestIdx is the lowest index in \p MemOffset2Idx.
///
/// \returns true if the map corresponds to a big endian byte pattern, false if
/// it corresponds to a little endian byte pattern, and std::nullopt otherwise.
///
/// E.g. given a 32-bit type x, and x[AddrOffset], the in-memory byte patterns
/// are as follows:
///
/// AddrOffset   Little endian    Big endian
///     0              0               3
///     1              1               2
///     2              2               1
///     3              3               0
static std::optional<bool>
isBigEndian(const SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx,
            int64_t LowestIdx) {
  // Need at least two byte positions to decide on endianness.
  unsigned Width = MemOffset2Idx.size();
  if (Width < 2)
    return std::nullopt;
  bool BigEndian = true, LittleEndian = true;
  for (unsigned MemOffset = 0; MemOffset < Width; ++MemOffset) {
    auto MemOffsetAndIdx = MemOffset2Idx.find(MemOffset);
    if (MemOffsetAndIdx == MemOffset2Idx.end())
      return std::nullopt;
    const int64_t Idx = MemOffsetAndIdx->second - LowestIdx;
    assert(Idx >= 0 && "Expected non-negative byte offset?");
    LittleEndian &= Idx == littleEndianByteAt(Width, MemOffset);
    BigEndian &= Idx == bigEndianByteAt(Width, MemOffset);
    if (!BigEndian && !LittleEndian)
      return std::nullopt;
  }

  assert((BigEndian != LittleEndian) &&
         "Pattern cannot be both big and little endian!");
  return BigEndian;
}

bool CombinerHelper::isPreLegalize() const { return IsPreLegalize; }

bool CombinerHelper::isLegal(const LegalityQuery &Query) const {
  assert(LI && "Must have LegalizerInfo to query isLegal!");
  return LI->getAction(Query).Action == LegalizeActions::Legal;
}

bool CombinerHelper::isLegalOrBeforeLegalizer(
    const LegalityQuery &Query) const {
  return isPreLegalize() || isLegal(Query);
}

bool CombinerHelper::isLegalOrHasWidenScalar(const LegalityQuery &Query) const {
  return isLegal(Query) ||
         LI->getAction(Query).Action == LegalizeActions::WidenScalar;
}

bool CombinerHelper::isLegalOrHasFewerElements(
    const LegalityQuery &Query) const {
  LegalizeAction Action = LI->getAction(Query).Action;
  return Action == LegalizeActions::Legal ||
         Action == LegalizeActions::FewerElements;
}

bool CombinerHelper::isConstantLegalOrBeforeLegalizer(const LLT Ty) const {
  if (!Ty.isVector())
    return isLegalOrBeforeLegalizer({TargetOpcode::G_CONSTANT, {Ty}});
  // Vector constants are represented as a G_BUILD_VECTOR of scalar
  // G_CONSTANTs.
  if (isPreLegalize())
    return true;
  LLT EltTy = Ty.getElementType();
  return isLegal({TargetOpcode::G_BUILD_VECTOR, {Ty, EltTy}}) &&
         isLegal({TargetOpcode::G_CONSTANT, {EltTy}});
}

void CombinerHelper::replaceRegWith(MachineRegisterInfo &MRI, Register FromReg,
                                    Register ToReg) const {
  Observer.changingAllUsesOfReg(MRI, FromReg);

  if (MRI.constrainRegAttrs(ToReg, FromReg))
    MRI.replaceRegWith(FromReg, ToReg);
  else
    Builder.buildCopy(FromReg, ToReg);

  Observer.finishedChangingAllUsesOfReg();
}

void CombinerHelper::replaceRegOpWith(MachineRegisterInfo &MRI,
                                      MachineOperand &FromRegOp,
                                      Register ToReg) const {
  assert(FromRegOp.getParent() && "Expected an operand in an MI");
  Observer.changingInstr(*FromRegOp.getParent());

  FromRegOp.setReg(ToReg);

  Observer.changedInstr(*FromRegOp.getParent());
}

void CombinerHelper::replaceOpcodeWith(MachineInstr &FromMI,
                                       unsigned ToOpcode) const {
  Observer.changingInstr(FromMI);

  FromMI.setDesc(Builder.getTII().get(ToOpcode));

  Observer.changedInstr(FromMI);
}

const RegisterBank *CombinerHelper::getRegBank(Register Reg) const {
  return RBI->getRegBank(Reg, MRI, *TRI);
}

void CombinerHelper::setRegBank(Register Reg,
                                const RegisterBank *RegBank) const {
  if (RegBank)
    MRI.setRegBank(Reg, *RegBank);
}

bool CombinerHelper::tryCombineCopy(MachineInstr &MI) const {
  if (matchCombineCopy(MI)) {
    applyCombineCopy(MI);
    return true;
  }
  return false;
}

bool CombinerHelper::matchCombineCopy(MachineInstr &MI) const {
  if (MI.getOpcode() != TargetOpcode::COPY)
    return false;
  Register DstReg = MI.getOperand(0).getReg();
  Register SrcReg = MI.getOperand(1).getReg();
  return canReplaceReg(DstReg, SrcReg, MRI);
}

void CombinerHelper::applyCombineCopy(MachineInstr &MI) const {
  Register DstReg = MI.getOperand(0).getReg();
  Register SrcReg = MI.getOperand(1).getReg();
  replaceRegWith(MRI, DstReg, SrcReg);
  MI.eraseFromParent();
}

bool CombinerHelper::matchFreezeOfSingleMaybePoisonOperand(
    MachineInstr &MI, BuildFnTy &MatchInfo) const {
  // Ported from InstCombinerImpl::pushFreezeToPreventPoisonFromPropagating.
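  //
  // E.g. (illustrative sketch; register names are hypothetical): if %c is the
  // only maybe-poison operand of the def of %a,
  //   %a:_(s32) = G_ADD %b, %c
  //   %d:_(s32) = G_FREEZE %a
  // is rewritten so that only the maybe-poison operand is frozen:
  //   %fc:_(s32) = G_FREEZE %c
  //   %a:_(s32) = G_ADD %b, %fc
  // and %d is then replaced by %a.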
  Register DstOp = MI.getOperand(0).getReg();
  Register OrigOp = MI.getOperand(1).getReg();

  if (!MRI.hasOneNonDBGUse(OrigOp))
    return false;

  MachineInstr *OrigDef = MRI.getUniqueVRegDef(OrigOp);
  // Even if only a single operand of the PHI is not guaranteed non-poison,
  // moving freeze() backwards across a PHI can cause optimization issues for
  // other users of that operand.
  //
  // Moving freeze() from one of the output registers of a G_UNMERGE_VALUES to
  // the source register is unprofitable because it makes the freeze() more
  // strict than is necessary (it would affect the whole register instead of
  // just the subreg being frozen).
  if (OrigDef->isPHI() || isa<GUnmerge>(OrigDef))
    return false;

  if (canCreateUndefOrPoison(OrigOp, MRI,
                             /*ConsiderFlagsAndMetadata=*/false))
    return false;

  std::optional<MachineOperand> MaybePoisonOperand;
  for (MachineOperand &Operand : OrigDef->uses()) {
    if (!Operand.isReg())
      return false;

    if (isGuaranteedNotToBeUndefOrPoison(Operand.getReg(), MRI))
      continue;

    if (!MaybePoisonOperand)
      MaybePoisonOperand = Operand;
    else {
      // We have more than one maybe-poison operand. Moving the freeze is
      // unsafe.
      return false;
    }
  }

  // Eliminate freeze if all operands are guaranteed non-poison.
  if (!MaybePoisonOperand) {
    MatchInfo = [=](MachineIRBuilder &B) {
      Observer.changingInstr(*OrigDef);
      cast<GenericMachineInstr>(OrigDef)->dropPoisonGeneratingFlags();
      Observer.changedInstr(*OrigDef);
      B.buildCopy(DstOp, OrigOp);
    };
    return true;
  }

  Register MaybePoisonOperandReg = MaybePoisonOperand->getReg();
  LLT MaybePoisonOperandRegTy = MRI.getType(MaybePoisonOperandReg);

  MatchInfo = [=](MachineIRBuilder &B) mutable {
    Observer.changingInstr(*OrigDef);
    cast<GenericMachineInstr>(OrigDef)->dropPoisonGeneratingFlags();
    Observer.changedInstr(*OrigDef);
    B.setInsertPt(*OrigDef->getParent(), OrigDef->getIterator());
    auto Freeze = B.buildFreeze(MaybePoisonOperandRegTy, MaybePoisonOperandReg);
    replaceRegOpWith(
        MRI, *OrigDef->findRegisterUseOperand(MaybePoisonOperandReg, TRI),
        Freeze.getReg(0));
    replaceRegWith(MRI, DstOp, OrigOp);
  };
  return true;
}

bool CombinerHelper::matchCombineConcatVectors(
    MachineInstr &MI, SmallVector<Register> &Ops) const {
  assert(MI.getOpcode() == TargetOpcode::G_CONCAT_VECTORS &&
         "Invalid instruction");
  bool IsUndef = true;
  MachineInstr *Undef = nullptr;

  // Walk over all the operands of concat vectors and check if they are
  // build_vector themselves or undef.
  // Then collect their operands in Ops.
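  //
  // E.g. (illustrative sketch):
  //   %v1:_(<2 x s32>) = G_BUILD_VECTOR %a, %b
  //   %v2:_(<2 x s32>) = G_BUILD_VECTOR %c, %d
  //   %cat:_(<4 x s32>) = G_CONCAT_VECTORS %v1, %v2
  // flattens to:
  //   %cat:_(<4 x s32>) = G_BUILD_VECTOR %a, %b, %c, %d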
  for (const MachineOperand &MO : MI.uses()) {
    Register Reg = MO.getReg();
    MachineInstr *Def = MRI.getVRegDef(Reg);
    assert(Def && "Operand not defined");
    if (!MRI.hasOneNonDBGUse(Reg))
      return false;
    switch (Def->getOpcode()) {
    case TargetOpcode::G_BUILD_VECTOR:
      IsUndef = false;
      // Remember the operands of the build_vector to fold
      // them into the yet-to-build flattened concat vectors.
      for (const MachineOperand &BuildVecMO : Def->uses())
        Ops.push_back(BuildVecMO.getReg());
      break;
    case TargetOpcode::G_IMPLICIT_DEF: {
      LLT OpType = MRI.getType(Reg);
      // Keep one undef value for all the undef operands.
      if (!Undef) {
        Builder.setInsertPt(*MI.getParent(), MI);
        Undef = Builder.buildUndef(OpType.getScalarType());
      }
      assert(MRI.getType(Undef->getOperand(0).getReg()) ==
                 OpType.getScalarType() &&
             "All undefs should have the same type");
      // Break the undef vector into as many scalar elements as needed
      // for the flattening.
      for (unsigned EltIdx = 0, EltEnd = OpType.getNumElements();
           EltIdx != EltEnd; ++EltIdx)
        Ops.push_back(Undef->getOperand(0).getReg());
      break;
    }
    default:
      return false;
    }
  }

  // Check if the combine is illegal.
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
  if (!isLegalOrBeforeLegalizer(
          {TargetOpcode::G_BUILD_VECTOR, {DstTy, MRI.getType(Ops[0])}})) {
    return false;
  }

  if (IsUndef)
    Ops.clear();

  return true;
}

void CombinerHelper::applyCombineConcatVectors(
    MachineInstr &MI, SmallVector<Register> &Ops) const {
  // We determined that the concat_vectors can be flattened.
  // Generate the flattened build_vector.
  Register DstReg = MI.getOperand(0).getReg();
  Builder.setInsertPt(*MI.getParent(), MI);
  Register NewDstReg = MRI.cloneVirtualRegister(DstReg);

  // Note: IsUndef is sort of redundant. We could have determined it by
  // checking that all Ops are undef. Alternatively, we could have
  // generated a build_vector of undefs and relied on another combine to
  // clean that up. For now, given we already gather this information
  // in matchCombineConcatVectors, just save compile time and issue the
  // right thing.
  if (Ops.empty())
    Builder.buildUndef(NewDstReg);
  else
    Builder.buildBuildVector(NewDstReg, Ops);
  replaceRegWith(MRI, DstReg, NewDstReg);
  MI.eraseFromParent();
}

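/// Replace a G_SHUFFLE_VECTOR with a G_BUILD_VECTOR of the selected source
/// elements. E.g. (illustrative sketch):
///   %r:_(<2 x s32>) = G_SHUFFLE_VECTOR %x(<2 x s32>), %y(<2 x s32>),
///                                      shufflemask(1, 2)
/// becomes:
///   %x0:_(s32), %x1:_(s32) = G_UNMERGE_VALUES %x
///   %y0:_(s32), %y1:_(s32) = G_UNMERGE_VALUES %y
///   %r:_(<2 x s32>) = G_BUILD_VECTOR %x1, %y0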
void CombinerHelper::applyCombineShuffleToBuildVector(MachineInstr &MI) const {
  auto &Shuffle = cast<GShuffleVector>(MI);

  Register SrcVec1 = Shuffle.getSrc1Reg();
  Register SrcVec2 = Shuffle.getSrc2Reg();
  LLT EltTy = MRI.getType(SrcVec1).getElementType();
  int Width = MRI.getType(SrcVec1).getNumElements();

  auto Unmerge1 = Builder.buildUnmerge(EltTy, SrcVec1);
  auto Unmerge2 = Builder.buildUnmerge(EltTy, SrcVec2);

  SmallVector<Register> Extracts;
  // Select only applicable elements from unmerged values.
  for (int Val : Shuffle.getMask()) {
    if (Val == -1)
      Extracts.push_back(Builder.buildUndef(EltTy).getReg(0));
    else if (Val < Width)
      Extracts.push_back(Unmerge1.getReg(Val));
    else
      Extracts.push_back(Unmerge2.getReg(Val - Width));
  }
  assert(Extracts.size() > 0 && "Expected at least one element in the shuffle");
  if (Extracts.size() == 1)
    Builder.buildCopy(MI.getOperand(0).getReg(), Extracts[0]);
  else
    Builder.buildBuildVector(MI.getOperand(0).getReg(), Extracts);
  MI.eraseFromParent();
}

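/// Match a G_SHUFFLE_VECTOR of G_CONCAT_VECTORS whose mask picks out whole
/// concat sources, so the shuffle can be re-expressed as a concat. E.g.
/// (illustrative sketch):
///   %c1:_(<4 x s32>) = G_CONCAT_VECTORS %a(<2 x s32>), %b(<2 x s32>)
///   %c2:_(<4 x s32>) = G_CONCAT_VECTORS %c(<2 x s32>), %d(<2 x s32>)
///   %s:_(<4 x s32>) = G_SHUFFLE_VECTOR %c1, %c2, shufflemask(2, 3, 4, 5)
/// folds to:
///   %s:_(<4 x s32>) = G_CONCAT_VECTORS %b, %c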
bool CombinerHelper::matchCombineShuffleConcat(
    MachineInstr &MI, SmallVector<Register> &Ops) const {
  ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
  auto ConcatMI1 =
      dyn_cast<GConcatVectors>(MRI.getVRegDef(MI.getOperand(1).getReg()));
  auto ConcatMI2 =
      dyn_cast<GConcatVectors>(MRI.getVRegDef(MI.getOperand(2).getReg()));
  if (!ConcatMI1 || !ConcatMI2)
    return false;

  // Check that the sources of the Concat instructions have the same type.
  if (MRI.getType(ConcatMI1->getSourceReg(0)) !=
      MRI.getType(ConcatMI2->getSourceReg(0)))
    return false;

  LLT ConcatSrcTy = MRI.getType(ConcatMI1->getReg(1));
  LLT ShuffleSrcTy1 = MRI.getType(MI.getOperand(1).getReg());
  unsigned ConcatSrcNumElt = ConcatSrcTy.getNumElements();
  for (unsigned i = 0; i < Mask.size(); i += ConcatSrcNumElt) {
    // Check if the index takes a whole source register from G_CONCAT_VECTORS.
    // Assumes that all sources of G_CONCAT_VECTORS are the same type.
    if (Mask[i] == -1) {
      for (unsigned j = 1; j < ConcatSrcNumElt; j++) {
        if (i + j >= Mask.size())
          return false;
        if (Mask[i + j] != -1)
          return false;
      }
      if (!isLegalOrBeforeLegalizer(
              {TargetOpcode::G_IMPLICIT_DEF, {ConcatSrcTy}}))
        return false;
      Ops.push_back(0);
    } else if (Mask[i] % ConcatSrcNumElt == 0) {
      for (unsigned j = 1; j < ConcatSrcNumElt; j++) {
        if (i + j >= Mask.size())
          return false;
        if (Mask[i + j] != Mask[i] + static_cast<int>(j))
          return false;
      }
      // Retrieve the source register from its respective G_CONCAT_VECTORS
      // instruction.
      if (Mask[i] < ShuffleSrcTy1.getNumElements()) {
        Ops.push_back(ConcatMI1->getSourceReg(Mask[i] / ConcatSrcNumElt));
      } else {
        Ops.push_back(ConcatMI2->getSourceReg(Mask[i] / ConcatSrcNumElt -
                                              ConcatMI1->getNumSources()));
      }
    } else {
      return false;
    }
  }

  if (!isLegalOrBeforeLegalizer(
          {TargetOpcode::G_CONCAT_VECTORS,
           {MRI.getType(MI.getOperand(0).getReg()), ConcatSrcTy}}))
    return false;

  return !Ops.empty();
}

void CombinerHelper::applyCombineShuffleConcat(
    MachineInstr &MI, SmallVector<Register> &Ops) const {
  LLT SrcTy;
  for (Register &Reg : Ops) {
    if (Reg != 0)
      SrcTy = MRI.getType(Reg);
  }
  assert(SrcTy.isValid() && "Unexpected full undef vector in concat combine");

  Register UndefReg = 0;

  for (Register &Reg : Ops) {
    if (Reg == 0) {
      if (UndefReg == 0)
        UndefReg = Builder.buildUndef(SrcTy).getReg(0);
      Reg = UndefReg;
    }
  }

  if (Ops.size() > 1)
    Builder.buildConcatVectors(MI.getOperand(0).getReg(), Ops);
  else
    Builder.buildCopy(MI.getOperand(0).getReg(), Ops[0]);
  MI.eraseFromParent();
}

bool CombinerHelper::tryCombineShuffleVector(MachineInstr &MI) const {
  SmallVector<Register, 4> Ops;
  if (matchCombineShuffleVector(MI, Ops)) {
    applyCombineShuffleVector(MI, Ops);
    return true;
  }
  return false;
}

bool CombinerHelper::matchCombineShuffleVector(
    MachineInstr &MI, SmallVectorImpl<Register> &Ops) const {
  assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR &&
         "Invalid instruction kind");
  LLT DstType = MRI.getType(MI.getOperand(0).getReg());
  Register Src1 = MI.getOperand(1).getReg();
  LLT SrcType = MRI.getType(Src1);

  unsigned DstNumElts = DstType.getNumElements();
  unsigned SrcNumElts = SrcType.getNumElements();

  // If the resulting vector is smaller than the size of the source
  // vectors being concatenated, we won't be able to replace the
  // shuffle vector with a concat_vectors.
  //
  // Note: We may still be able to produce a concat_vectors fed by
  // extract_vector_elt and so on. It is less clear that would
  // be better though, so don't bother for now.
  //
  // If the destination is a scalar, the size of the sources doesn't
  // matter. We will lower the shuffle to a plain copy. This will
  // work only if the source and destination have the same size. But
  // that's covered by the next condition.
  //
  // TODO: If the sizes of the source and destination don't match
  // we could still emit an extract vector element in that case.
  if (DstNumElts < 2 * SrcNumElts)
    return false;

  // Check that the shuffle mask can be broken evenly between the
  // different sources.
  if (DstNumElts % SrcNumElts != 0)
    return false;

  // Mask length is a multiple of the source vector length.
  // Check if the shuffle is some kind of concatenation of the input
  // vectors.
  unsigned NumConcat = DstNumElts / SrcNumElts;
  SmallVector<int, 8> ConcatSrcs(NumConcat, -1);
  ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
  for (unsigned i = 0; i != DstNumElts; ++i) {
    int Idx = Mask[i];
    // Undef value.
    if (Idx < 0)
      continue;
    // Ensure the indices in each SrcType sized piece are sequential and that
    // the same source is used for the whole piece.
    if ((Idx % SrcNumElts != (i % SrcNumElts)) ||
        (ConcatSrcs[i / SrcNumElts] >= 0 &&
         ConcatSrcs[i / SrcNumElts] != (int)(Idx / SrcNumElts)))
      return false;
    // Remember which source this index came from.
    ConcatSrcs[i / SrcNumElts] = Idx / SrcNumElts;
  }

  // The shuffle is concatenating multiple vectors together.
  // Collect the different operands for that.
  Register UndefReg;
  Register Src2 = MI.getOperand(2).getReg();
  for (auto Src : ConcatSrcs) {
    if (Src < 0) {
      if (!UndefReg) {
        Builder.setInsertPt(*MI.getParent(), MI);
        UndefReg = Builder.buildUndef(SrcType).getReg(0);
      }
      Ops.push_back(UndefReg);
    } else if (Src == 0)
      Ops.push_back(Src1);
    else
      Ops.push_back(Src2);
  }
  return true;
}

void CombinerHelper::applyCombineShuffleVector(MachineInstr &MI,
                                               ArrayRef<Register> Ops) const {
  Register DstReg = MI.getOperand(0).getReg();
  Builder.setInsertPt(*MI.getParent(), MI);
  Register NewDstReg = MRI.cloneVirtualRegister(DstReg);

  if (Ops.size() == 1)
    Builder.buildCopy(NewDstReg, Ops[0]);
  else
    Builder.buildMergeLikeInstr(NewDstReg, Ops);

  replaceRegWith(MRI, DstReg, NewDstReg);
  MI.eraseFromParent();
}

namespace {

/// Select a preference between two uses. CurrentUse is the current preference
/// while the *ForCandidate values are the attributes of the candidate under
/// consideration.
PreferredTuple ChoosePreferredUse(MachineInstr &LoadMI,
                                  PreferredTuple &CurrentUse,
                                  const LLT TyForCandidate,
                                  unsigned OpcodeForCandidate,
                                  MachineInstr *MIForCandidate) {
  if (!CurrentUse.Ty.isValid()) {
    if (CurrentUse.ExtendOpcode == OpcodeForCandidate ||
        CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT)
      return {TyForCandidate, OpcodeForCandidate, MIForCandidate};
    return CurrentUse;
  }

  // We permit the extend to hoist through basic blocks but this is only
  // sensible if the target has extending loads. If you end up lowering back
  // into a load and extend during the legalizer then the end result is
  // hoisting the extend up to the load.

  // Prefer defined extensions to undefined extensions as these are more
  // likely to reduce the number of instructions.
  if (OpcodeForCandidate == TargetOpcode::G_ANYEXT &&
      CurrentUse.ExtendOpcode != TargetOpcode::G_ANYEXT)
    return CurrentUse;
  else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT &&
           OpcodeForCandidate != TargetOpcode::G_ANYEXT)
    return {TyForCandidate, OpcodeForCandidate, MIForCandidate};

  // Prefer sign extensions to zero extensions as sign-extensions tend to be
  // more expensive. Don't do this if the load is already a zero-extend load
  // though, otherwise we'll rewrite a zero-extend load into a sign-extend
  // later.
  if (!isa<GZExtLoad>(LoadMI) && CurrentUse.Ty == TyForCandidate) {
    if (CurrentUse.ExtendOpcode == TargetOpcode::G_SEXT &&
        OpcodeForCandidate == TargetOpcode::G_ZEXT)
      return CurrentUse;
    else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ZEXT &&
             OpcodeForCandidate == TargetOpcode::G_SEXT)
      return {TyForCandidate, OpcodeForCandidate, MIForCandidate};
  }

  // This is potentially target specific. We've chosen the largest type
  // because G_TRUNC is usually free. One potential catch with this is that
  // some targets have a reduced number of larger registers than smaller
  // registers and this choice potentially increases the live-range for the
  // larger value.
  if (TyForCandidate.getSizeInBits() > CurrentUse.Ty.getSizeInBits()) {
    return {TyForCandidate, OpcodeForCandidate, MIForCandidate};
  }
  return CurrentUse;
}

/// Find a suitable place to insert some instructions and insert them. This
/// function accounts for special cases like inserting before a PHI node.
/// The current strategy for inserting before PHI's is to duplicate the
/// instructions for each predecessor. However, while that's ok for G_TRUNC
/// on most targets since it generally requires no code, other targets/cases
/// may want to try harder to find a dominating block.
static void InsertInsnsWithoutSideEffectsBeforeUse(
    MachineIRBuilder &Builder, MachineInstr &DefMI, MachineOperand &UseMO,
    std::function<void(MachineBasicBlock *, MachineBasicBlock::iterator,
                       MachineOperand &UseMO)>
        Inserter) {
  MachineInstr &UseMI = *UseMO.getParent();

  MachineBasicBlock *InsertBB = UseMI.getParent();

  // If the use is a PHI then we want the predecessor block instead.
  if (UseMI.isPHI()) {
    MachineOperand *PredBB = std::next(&UseMO);
    InsertBB = PredBB->getMBB();
  }

  // If the block is the same block as the def then we want to insert just
  // after the def instead of at the start of the block.
  if (InsertBB == DefMI.getParent()) {
    MachineBasicBlock::iterator InsertPt = &DefMI;
    Inserter(InsertBB, std::next(InsertPt), UseMO);
    return;
  }

  // Otherwise we want the start of the BB.
  Inserter(InsertBB, InsertBB->getFirstNonPHI(), UseMO);
}
} // end anonymous namespace

bool CombinerHelper::tryCombineExtendingLoads(MachineInstr &MI) const {
  PreferredTuple Preferred;
  if (matchCombineExtendingLoads(MI, Preferred)) {
    applyCombineExtendingLoads(MI, Preferred);
    return true;
  }
  return false;
}

static unsigned getExtLoadOpcForExtend(unsigned ExtOpc) {
  unsigned CandidateLoadOpc;
  switch (ExtOpc) {
  case TargetOpcode::G_ANYEXT:
    CandidateLoadOpc = TargetOpcode::G_LOAD;
    break;
  case TargetOpcode::G_SEXT:
    CandidateLoadOpc = TargetOpcode::G_SEXTLOAD;
    break;
  case TargetOpcode::G_ZEXT:
    CandidateLoadOpc = TargetOpcode::G_ZEXTLOAD;
    break;
  default:
    llvm_unreachable("Unexpected extend opc");
  }
  return CandidateLoadOpc;
}

bool CombinerHelper::matchCombineExtendingLoads(
    MachineInstr &MI, PreferredTuple &Preferred) const {
  // We match the loads and follow the uses to the extend instead of matching
  // the extends and following the def to the load. This is because the load
  // must remain in the same position for correctness (unless we also add code
  // to find a safe place to sink it) whereas the extend is freely movable.
  // It also prevents us from duplicating the load for the volatile case or
  // just for performance.
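  //
  // E.g. (illustrative sketch):
  //   %ld:_(s8) = G_LOAD %ptr
  //   %ext:_(s32) = G_SEXT %ld(s8)
  // becomes, when the s32 G_SEXTLOAD is legal (or we are pre-legalization):
  //   %ext:_(s32) = G_SEXTLOAD %ptr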
  GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(&MI);
  if (!LoadMI)
    return false;

  Register LoadReg = LoadMI->getDstReg();

  LLT LoadValueTy = MRI.getType(LoadReg);
  if (!LoadValueTy.isScalar())
    return false;

  // Most architectures are going to legalize <s8 loads into at least a 1 byte
  // load, and the MMOs can only describe memory accesses in multiples of
  // bytes. If we try to perform extload combining on those, we can end up with
  //   %a(s8) = extload %ptr (load 1 byte from %ptr)
  // ... which is an illegal extload instruction.
  if (LoadValueTy.getSizeInBits() < 8)
    return false;

  // Non-power-of-2 types will very likely be legalized into multiple loads,
  // so don't bother trying to match them into extending loads.
  if (!llvm::has_single_bit<uint32_t>(LoadValueTy.getSizeInBits()))
    return false;

  // Find the preferred type aside from the any-extends (unless it's the only
  // one) and non-extending ops. We'll emit an extending load to that type and
  // emit a variant of (extend (trunc X)) for the others according to the
  // relative type sizes. At the same time, pick an extend to use based on the
  // extend involved in the chosen type.
  unsigned PreferredOpcode =
      isa<GLoad>(&MI)
          ? TargetOpcode::G_ANYEXT
          : isa<GSExtLoad>(&MI) ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT;
  Preferred = {LLT(), PreferredOpcode, nullptr};
  for (auto &UseMI : MRI.use_nodbg_instructions(LoadReg)) {
    if (UseMI.getOpcode() == TargetOpcode::G_SEXT ||
        UseMI.getOpcode() == TargetOpcode::G_ZEXT ||
        (UseMI.getOpcode() == TargetOpcode::G_ANYEXT)) {
      const auto &MMO = LoadMI->getMMO();
      // Don't do anything for atomics.
      if (MMO.isAtomic())
        continue;
      // Check for legality.
      if (!isPreLegalize()) {
        LegalityQuery::MemDesc MMDesc(MMO);
        unsigned CandidateLoadOpc = getExtLoadOpcForExtend(UseMI.getOpcode());
        LLT UseTy = MRI.getType(UseMI.getOperand(0).getReg());
        LLT SrcTy = MRI.getType(LoadMI->getPointerReg());
        if (LI->getAction({CandidateLoadOpc, {UseTy, SrcTy}, {MMDesc}})
                .Action != LegalizeActions::Legal)
          continue;
      }
      Preferred = ChoosePreferredUse(MI, Preferred,
                                     MRI.getType(UseMI.getOperand(0).getReg()),
                                     UseMI.getOpcode(), &UseMI);
    }
  }

  // There were no extends.
  if (!Preferred.MI)
    return false;
  // It should be impossible to choose an extend without selecting a different
  // type since by definition the result of an extend is larger.
  assert(Preferred.Ty != LoadValueTy && "Extending to same type?");

  LLVM_DEBUG(dbgs() << "Preferred use is: " << *Preferred.MI);
  return true;
}

void CombinerHelper::applyCombineExtendingLoads(
    MachineInstr &MI, PreferredTuple &Preferred) const {
  // Rewrite the load to the chosen extending load.
  Register ChosenDstReg = Preferred.MI->getOperand(0).getReg();

  // Inserter to insert a truncate back to the original type at a given point
  // with some basic CSE to limit truncate duplication to one per BB.
  DenseMap<MachineBasicBlock *, MachineInstr *> EmittedInsns;
  auto InsertTruncAt = [&](MachineBasicBlock *InsertIntoBB,
                           MachineBasicBlock::iterator InsertBefore,
                           MachineOperand &UseMO) {
    MachineInstr *PreviouslyEmitted = EmittedInsns.lookup(InsertIntoBB);
    if (PreviouslyEmitted) {
      Observer.changingInstr(*UseMO.getParent());
      UseMO.setReg(PreviouslyEmitted->getOperand(0).getReg());
      Observer.changedInstr(*UseMO.getParent());
      return;
    }

    Builder.setInsertPt(*InsertIntoBB, InsertBefore);
    Register NewDstReg = MRI.cloneVirtualRegister(MI.getOperand(0).getReg());
    MachineInstr *NewMI = Builder.buildTrunc(NewDstReg, ChosenDstReg);
    EmittedInsns[InsertIntoBB] = NewMI;
    replaceRegOpWith(MRI, UseMO, NewDstReg);
  };

  Observer.changingInstr(MI);
  unsigned LoadOpc = getExtLoadOpcForExtend(Preferred.ExtendOpcode);
  MI.setDesc(Builder.getTII().get(LoadOpc));

  // Rewrite all the uses to fix up the types.
  auto &LoadValue = MI.getOperand(0);
  SmallVector<MachineOperand *, 4> Uses(
      llvm::make_pointer_range(MRI.use_operands(LoadValue.getReg())));

  for (auto *UseMO : Uses) {
    MachineInstr *UseMI = UseMO->getParent();

    // If the extend is compatible with the preferred extend then we should fix
    // up the type and extend so that it uses the preferred use.
    if (UseMI->getOpcode() == Preferred.ExtendOpcode ||
        UseMI->getOpcode() == TargetOpcode::G_ANYEXT) {
      Register UseDstReg = UseMI->getOperand(0).getReg();
      MachineOperand &UseSrcMO = UseMI->getOperand(1);
      const LLT UseDstTy = MRI.getType(UseDstReg);
      if (UseDstReg != ChosenDstReg) {
        if (Preferred.Ty == UseDstTy) {
          // If the use has the same type as the preferred use, then merge
          // the vregs and erase the extend. For example:
          //   %1:_(s8) = G_LOAD ...
          //   %2:_(s32) = G_SEXT %1(s8)
          //   %3:_(s32) = G_ANYEXT %1(s8)
          //   ... = ... %3(s32)
          // rewrites to:
          //   %2:_(s32) = G_SEXTLOAD ...
          //   ... = ... %2(s32)
          replaceRegWith(MRI, UseDstReg, ChosenDstReg);
          Observer.erasingInstr(*UseMO->getParent());
          UseMO->getParent()->eraseFromParent();
        } else if (Preferred.Ty.getSizeInBits() < UseDstTy.getSizeInBits()) {
          // If the preferred size is smaller, then keep the extend but extend
          // from the result of the extending load. For example:
          //   %1:_(s8) = G_LOAD ...
          //   %2:_(s32) = G_SEXT %1(s8)
          //   %3:_(s64) = G_ANYEXT %1(s8)
          //   ... = ... %3(s64)
          // rewrites to:
          //   %2:_(s32) = G_SEXTLOAD ...
          //   %3:_(s64) = G_ANYEXT %2:_(s32)
          //   ... = ... %3(s64)
          replaceRegOpWith(MRI, UseSrcMO, ChosenDstReg);
        } else {
          // If the preferred size is larger, then insert a truncate. For
          // example:
          //   %1:_(s8) = G_LOAD ...
          //   %2:_(s64) = G_SEXT %1(s8)
          //   %3:_(s32) = G_ZEXT %1(s8)
          //   ... = ... %3(s32)
          // rewrites to:
          //   %2:_(s64) = G_SEXTLOAD ...
          //   %4:_(s8) = G_TRUNC %2:_(s64)
          //   %3:_(s32) = G_ZEXT %4:_(s8)
          //   ... = ... %3(s32)
          InsertInsnsWithoutSideEffectsBeforeUse(Builder, MI, *UseMO,
                                                 InsertTruncAt);
        }
        continue;
      }
      // The use is (one of) the uses of the preferred use we chose earlier.
      // We're going to update the load to def this value later so just erase
      // the old extend.
      Observer.erasingInstr(*UseMO->getParent());
      UseMO->getParent()->eraseFromParent();
      continue;
    }

    // The use isn't an extend. Truncate back to the type we originally loaded.
    // This is free on many targets.
    InsertInsnsWithoutSideEffectsBeforeUse(Builder, MI, *UseMO, InsertTruncAt);
  }

  MI.getOperand(0).setReg(ChosenDstReg);
  Observer.changedInstr(MI);
}

bool CombinerHelper::matchCombineLoadWithAndMask(MachineInstr &MI,
                                                 BuildFnTy &MatchInfo) const {
  assert(MI.getOpcode() == TargetOpcode::G_AND);

  // If we have the following code:
  //   %mask = G_CONSTANT 255
  //   %ld = G_LOAD %ptr, (load s16)
  //   %and = G_AND %ld, %mask
  //
  // Try to fold it into
  //   %ld = G_ZEXTLOAD %ptr, (load s8)

  Register Dst = MI.getOperand(0).getReg();
  if (MRI.getType(Dst).isVector())
    return false;

  auto MaybeMask =
      getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI);
  if (!MaybeMask)
    return false;

  APInt MaskVal = MaybeMask->Value;

  if (!MaskVal.isMask())
    return false;

  Register SrcReg = MI.getOperand(1).getReg();
  // Don't use getOpcodeDef() here since intermediate instructions may have
  // multiple users.
  GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(MRI.getVRegDef(SrcReg));
  if (!LoadMI || !MRI.hasOneNonDBGUse(LoadMI->getDstReg()))
    return false;

  Register LoadReg = LoadMI->getDstReg();
  LLT RegTy = MRI.getType(LoadReg);
  Register PtrReg = LoadMI->getPointerReg();
  unsigned RegSize = RegTy.getSizeInBits();
  LocationSize LoadSizeBits = LoadMI->getMemSizeInBits();
  unsigned MaskSizeBits = MaskVal.countr_one();

  // The mask may not be larger than the in-memory type, as it might cover
  // sign-extended bits.
  if (MaskSizeBits > LoadSizeBits.getValue())
    return false;

  // If the mask covers the whole destination register, there's nothing to
  // extend.
  if (MaskSizeBits >= RegSize)
    return false;

  // Most targets cannot deal with loads of size < 8 and need to re-legalize to
  // at least byte loads. Avoid creating such loads here.
  if (MaskSizeBits < 8 || !isPowerOf2_32(MaskSizeBits))
    return false;

  const MachineMemOperand &MMO = LoadMI->getMMO();
  LegalityQuery::MemDesc MemDesc(MMO);

  // Don't modify the memory access size if this is atomic/volatile, but we can
  // still adjust the opcode to indicate the high bit behavior.
  if (LoadMI->isSimple())
    MemDesc.MemoryTy = LLT::scalar(MaskSizeBits);
  else if (LoadSizeBits.getValue() > MaskSizeBits ||
           LoadSizeBits.getValue() == RegSize)
    return false;

  // TODO: Could check if it's legal with the reduced or original memory size.
  if (!isLegalOrBeforeLegalizer(
          {TargetOpcode::G_ZEXTLOAD, {RegTy, MRI.getType(PtrReg)}, {MemDesc}}))
    return false;

  MatchInfo = [=](MachineIRBuilder &B) {
    B.setInstrAndDebugLoc(*LoadMI);
    auto &MF = B.getMF();
    auto PtrInfo = MMO.getPointerInfo();
    auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, MemDesc.MemoryTy);
    B.buildLoadInstr(TargetOpcode::G_ZEXTLOAD, Dst, PtrReg, *NewMMO);
    LoadMI->eraseFromParent();
  };
  return true;
}

bool CombinerHelper::isPredecessor(const MachineInstr &DefMI,
                                   const MachineInstr &UseMI) const {
  assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() &&
         "shouldn't consider debug uses");
  assert(DefMI.getParent() == UseMI.getParent());
  if (&DefMI == &UseMI)
    return true;
  const MachineBasicBlock &MBB = *DefMI.getParent();
  auto DefOrUse = find_if(MBB, [&DefMI, &UseMI](const MachineInstr &MI) {
    return &MI == &DefMI || &MI == &UseMI;
  });
  if (DefOrUse == MBB.end())
    llvm_unreachable("Block must contain both DefMI and UseMI!");
  return &*DefOrUse == &DefMI;
}

bool CombinerHelper::dominates(const MachineInstr &DefMI,
                               const MachineInstr &UseMI) const {
  assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() &&
         "shouldn't consider debug uses");
  if (MDT)
    return MDT->dominates(&DefMI, &UseMI);
  else if (DefMI.getParent() != UseMI.getParent())
    return false;

  return isPredecessor(DefMI, UseMI);
}

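/// Match a G_SEXT_INREG that is made redundant by a sign-extending load.
/// E.g. (illustrative sketch):
///   %ld:_(s32) = G_SEXTLOAD %ptr(p0) :: (load (s8))
///   %ext:_(s32) = G_SEXT_INREG %ld, 8
/// The high bits are already sign-extended from bit 8, so %ext can simply be
/// replaced with a copy of %ld.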
bool CombinerHelper::matchSextTruncSextLoad(MachineInstr &MI) const {
  assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
  Register SrcReg = MI.getOperand(1).getReg();
  Register LoadUser = SrcReg;

  if (MRI.getType(SrcReg).isVector())
    return false;

  Register TruncSrc;
  if (mi_match(SrcReg, MRI, m_GTrunc(m_Reg(TruncSrc))))
    LoadUser = TruncSrc;

  uint64_t SizeInBits = MI.getOperand(2).getImm();
  // If the source is a G_SEXTLOAD from the same bit width, then we don't
  // need any extend at all, just a truncate.
  if (auto *LoadMI = getOpcodeDef<GSExtLoad>(LoadUser, MRI)) {
    // If truncating more than the original extended value, abort.
    auto LoadSizeBits = LoadMI->getMemSizeInBits();
    if (TruncSrc &&
        MRI.getType(TruncSrc).getSizeInBits() < LoadSizeBits.getValue())
      return false;
    if (LoadSizeBits == SizeInBits)
      return true;
  }
  return false;
}

void CombinerHelper::applySextTruncSextLoad(MachineInstr &MI) const {
  assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
  Builder.buildCopy(MI.getOperand(0).getReg(), MI.getOperand(1).getReg());
  MI.eraseFromParent();
}

bool CombinerHelper::matchSextInRegOfLoad(
    MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) const {
  assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);

  Register DstReg = MI.getOperand(0).getReg();
  LLT RegTy = MRI.getType(DstReg);

  // Only supports scalars for now.
  if (RegTy.isVector())
    return false;

  Register SrcReg = MI.getOperand(1).getReg();
  auto *LoadDef = getOpcodeDef<GLoad>(SrcReg, MRI);
  if (!LoadDef || !MRI.hasOneNonDBGUse(SrcReg))
    return false;

  uint64_t MemBits = LoadDef->getMemSizeInBits().getValue();

  // If the sign extend extends from a narrower width than the load's width,
  // then we can narrow the load width when we combine to a G_SEXTLOAD.
  // Avoid widening the load at all.
  unsigned NewSizeBits = std::min((uint64_t)MI.getOperand(2).getImm(), MemBits);

  // Don't generate G_SEXTLOADs with a < 1 byte width.
  if (NewSizeBits < 8)
    return false;
  // Don't bother creating a non-power-2 sextload, it will likely be broken up
  // anyway for most targets.
  if (!isPowerOf2_32(NewSizeBits))
    return false;

  const MachineMemOperand &MMO = LoadDef->getMMO();
  LegalityQuery::MemDesc MMDesc(MMO);

  // Don't modify the memory access size if this is atomic/volatile, but we can
  // still adjust the opcode to indicate the high bit behavior.
  if (LoadDef->isSimple())
    MMDesc.MemoryTy = LLT::scalar(NewSizeBits);
  else if (MemBits > NewSizeBits || MemBits == RegTy.getSizeInBits())
    return false;

  // TODO: Could check if it's legal with the reduced or original memory size.
  if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SEXTLOAD,
                                 {MRI.getType(LoadDef->getDstReg()),
                                  MRI.getType(LoadDef->getPointerReg())},
                                 {MMDesc}}))
    return false;

  MatchInfo = std::make_tuple(LoadDef->getDstReg(), NewSizeBits);
  return true;
}

void CombinerHelper::applySextInRegOfLoad(
    MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) const {
  assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
  Register LoadReg;
  unsigned ScalarSizeBits;
  std::tie(LoadReg, ScalarSizeBits) = MatchInfo;
  GLoad *LoadDef = cast<GLoad>(MRI.getVRegDef(LoadReg));

  // If we have the following:
  //   %ld = G_LOAD %ptr, (load 2)
  //   %ext = G_SEXT_INREG %ld, 8
  // ==>
  //   %ld = G_SEXTLOAD %ptr (load 1)

  auto &MMO = LoadDef->getMMO();
  Builder.setInstrAndDebugLoc(*LoadDef);
  auto &MF = Builder.getMF();
  auto PtrInfo = MMO.getPointerInfo();
  auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, ScalarSizeBits / 8);
  Builder.buildLoadInstr(TargetOpcode::G_SEXTLOAD, MI.getOperand(0).getReg(),
                         LoadDef->getPointerReg(), *NewMMO);
  MI.eraseFromParent();

  // Not all loads can be deleted, so make sure the old one is removed.
  LoadDef->eraseFromParent();
}

/// Return true if 'MI' is a load or a store that may fold its address
/// operand into the load/store addressing mode.
static bool canFoldInAddressingMode(GLoadStore *MI, const TargetLowering &TLI,
                                    MachineRegisterInfo &MRI) {
  TargetLowering::AddrMode AM;
  auto *MF = MI->getMF();
  auto *Addr = getOpcodeDef<GPtrAdd>(MI->getPointerReg(), MRI);
  if (!Addr)
    return false;

  AM.HasBaseReg = true;
  if (auto CstOff = getIConstantVRegVal(Addr->getOffsetReg(), MRI))
    AM.BaseOffs = CstOff->getSExtValue(); // [reg +/- imm]
  else
    AM.Scale = 1; // [reg +/- reg]

  return TLI.isLegalAddressingMode(
      MF->getDataLayout(), AM,
      getTypeForLLT(MI->getMMO().getMemoryType(),
                    MF->getFunction().getContext()),
      MI->getMMO().getAddrSpace());
}

static unsigned getIndexedOpc(unsigned LdStOpc) {
  switch (LdStOpc) {
  case TargetOpcode::G_LOAD:
    return TargetOpcode::G_INDEXED_LOAD;
  case TargetOpcode::G_STORE:
    return TargetOpcode::G_INDEXED_STORE;
  case TargetOpcode::G_ZEXTLOAD:
    return TargetOpcode::G_INDEXED_ZEXTLOAD;
  case TargetOpcode::G_SEXTLOAD:
    return TargetOpcode::G_INDEXED_SEXTLOAD;
  default:
    llvm_unreachable("Unexpected opcode");
  }
}

bool CombinerHelper::isIndexedLoadStoreLegal(GLoadStore &LdSt) const {
  // Check for legality.
  LLT PtrTy = MRI.getType(LdSt.getPointerReg());
  LLT Ty = MRI.getType(LdSt.getReg(0));
  LLT MemTy = LdSt.getMMO().getMemoryType();
  SmallVector<LegalityQuery::MemDesc, 2> MemDescrs(
      {{MemTy, MemTy.getSizeInBits().getKnownMinValue(),
        AtomicOrdering::NotAtomic, AtomicOrdering::NotAtomic}});
  unsigned IndexedOpc = getIndexedOpc(LdSt.getOpcode());
  SmallVector<LLT> OpTys;
  if (IndexedOpc == TargetOpcode::G_INDEXED_STORE)
    OpTys = {PtrTy, Ty, Ty};
  else
    OpTys = {Ty, PtrTy}; // For G_INDEXED_LOAD, G_INDEXED_[SZ]EXTLOAD

  LegalityQuery Q(IndexedOpc, OpTys, MemDescrs);
  return isLegal(Q);
}

static cl::opt<unsigned> PostIndexUseThreshold(
    "post-index-use-threshold", cl::Hidden, cl::init(32),
    cl::desc("Number of uses of a base pointer to check before it is no longer "
             "considered for post-indexing."));

bool CombinerHelper::findPostIndexCandidate(GLoadStore &LdSt, Register &Addr,
                                            Register &Base, Register &Offset,
                                            bool &RematOffset) const {
  // We're looking for the following pattern, for either load or store:
  //   %baseptr:_(p0) = ...
  //   G_STORE %val(s64), %baseptr(p0)
  //   %offset:_(s64) = G_CONSTANT i64 -256
  //   %new_addr:_(p0) = G_PTR_ADD %baseptr, %offset(s64)
  const auto &TLI = getTargetLowering();

  Register Ptr = LdSt.getPointerReg();
  // If the store is the only use, don't bother.
  if (MRI.hasOneNonDBGUse(Ptr))
    return false;

  if (!isIndexedLoadStoreLegal(LdSt))
    return false;

  if (getOpcodeDef(TargetOpcode::G_FRAME_INDEX, Ptr, MRI))
    return false;

  MachineInstr *StoredValDef = getDefIgnoringCopies(LdSt.getReg(0), MRI);
  auto *PtrDef = MRI.getVRegDef(Ptr);

  unsigned NumUsesChecked = 0;
  for (auto &Use : MRI.use_nodbg_instructions(Ptr)) {
    if (++NumUsesChecked > PostIndexUseThreshold)
      return false; // Try to avoid exploding compile time.

    auto *PtrAdd = dyn_cast<GPtrAdd>(&Use);
    // The use itself might be dead. This can happen during combines if DCE
    // hasn't had a chance to run yet. Don't allow it to form an indexed op.
    if (!PtrAdd || MRI.use_nodbg_empty(PtrAdd->getReg(0)))
      continue;

    // Check that the user of this isn't the store, otherwise we'd be
    // generating an indexed store defining its own use.
    if (StoredValDef == &Use)
      continue;

    Offset = PtrAdd->getOffsetReg();
    if (!ForceLegalIndexing &&
        !TLI.isIndexingLegal(LdSt, PtrAdd->getBaseReg(), Offset,
                             /*IsPre*/ false, MRI))
      continue;

    // Make sure the offset calculation is before the potentially indexed op.
    MachineInstr *OffsetDef = MRI.getVRegDef(Offset);
    RematOffset = false;
    if (!dominates(*OffsetDef, LdSt)) {
      // If the offset, however, is just a G_CONSTANT, we can always just
      // rematerialize it where we need it.
      if (OffsetDef->getOpcode() != TargetOpcode::G_CONSTANT)
        continue;
      RematOffset = true;
    }

    for (auto &BasePtrUse : MRI.use_nodbg_instructions(PtrAdd->getBaseReg())) {
      if (&BasePtrUse == PtrDef)
        continue;

      // If the user is a later load/store that can be post-indexed, then don't
      // combine this one.
      auto *BasePtrLdSt = dyn_cast<GLoadStore>(&BasePtrUse);
      if (BasePtrLdSt && BasePtrLdSt != &LdSt &&
          dominates(LdSt, *BasePtrLdSt) &&
          isIndexedLoadStoreLegal(*BasePtrLdSt))
        return false;

      // Now we're looking for the key G_PTR_ADD instruction, which contains
      // the offset add that we want to fold.
      if (auto *BasePtrUseDef = dyn_cast<GPtrAdd>(&BasePtrUse)) {
        Register PtrAddDefReg = BasePtrUseDef->getReg(0);
        for (auto &BaseUseUse : MRI.use_nodbg_instructions(PtrAddDefReg)) {
          // If the use is in a different block, then we may produce worse code
          // due to the extra register pressure.
          if (BaseUseUse.getParent() != LdSt.getParent())
            return false;

          if (auto *UseUseLdSt = dyn_cast<GLoadStore>(&BaseUseUse))
            if (canFoldInAddressingMode(UseUseLdSt, TLI, MRI))
              return false;
        }
        if (!dominates(LdSt, BasePtrUse))
          return false; // All uses must be dominated by the load/store.
      }
    }

    Addr = PtrAdd->getReg(0);
    Base = PtrAdd->getBaseReg();
    return true;
  }

  return false;
}

bool CombinerHelper::findPreIndexCandidate(GLoadStore &LdSt, Register &Addr,
                                           Register &Base,
                                           Register &Offset) const {
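  // We're looking for a pre-indexable pattern, roughly (register names are
  // illustrative):
  //   %offset:_(s64) = G_CONSTANT i64 16
  //   %addr:_(p0) = G_PTR_ADD %base, %offset(s64)
  //   G_STORE %val(s64), %addr(p0)
  // where the pointer add can be folded into an indexed store that also
  // produces %addr.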
  auto &MF = *LdSt.getParent()->getParent();
  const auto &TLI = *MF.getSubtarget().getTargetLowering();

  Addr = LdSt.getPointerReg();
  if (!mi_match(Addr, MRI, m_GPtrAdd(m_Reg(Base), m_Reg(Offset))) ||
      MRI.hasOneNonDBGUse(Addr))
    return false;

  if (!ForceLegalIndexing &&
      !TLI.isIndexingLegal(LdSt, Base, Offset, /*IsPre*/ true, MRI))
    return false;

  if (!isIndexedLoadStoreLegal(LdSt))
    return false;

  MachineInstr *BaseDef = getDefIgnoringCopies(Base, MRI);
  if (BaseDef->getOpcode() == TargetOpcode::G_FRAME_INDEX)
    return false;

  if (auto *St = dyn_cast<GStore>(&LdSt)) {
    // Would require a copy.
    if (Base == St->getValueReg())
      return false;

    // We're expecting one use of Addr in MI, but it could also be the
    // value stored, which isn't actually dominated by the instruction.
    if (St->getValueReg() == Addr)
      return false;
  }

  // Avoid increasing cross-block register pressure.
  for (auto &AddrUse : MRI.use_nodbg_instructions(Addr))
    if (AddrUse.getParent() != LdSt.getParent())
      return false;

  // FIXME: check whether all uses of the base pointer are constant PtrAdds.
  // That might allow us to end base's liveness here by adjusting the constant.
  bool RealUse = false;
  for (auto &AddrUse : MRI.use_nodbg_instructions(Addr)) {
    if (!dominates(LdSt, AddrUse))
      return false; // All uses must be dominated by the load/store.

    // If Ptr may be folded into the addressing mode of another use, then it's
    // not profitable to do this transformation.
    if (auto *UseLdSt = dyn_cast<GLoadStore>(&AddrUse)) {
      if (!canFoldInAddressingMode(UseLdSt, TLI, MRI))
        RealUse = true;
    } else {
      RealUse = true;
    }
  }
  return RealUse;
}

bool CombinerHelper::matchCombineExtractedVectorLoad(
    MachineInstr &MI, BuildFnTy &MatchInfo) const {
  assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT);
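
  // E.g. (illustrative sketch):
  //   %vec:_(<4 x s32>) = G_LOAD %ptr(p0) :: (load (<4 x s32>))
  //   %elt:_(s32) = G_EXTRACT_VECTOR_ELT %vec(<4 x s32>), %idx(s64)
  // can, when the vector has no other use, be narrowed to a scalar load from
  // the element's address:
  //   %eltptr:_(p0) = G_PTR_ADD %ptr, %byteoff(s64) ; %byteoff = %idx * 4
  //   %elt:_(s32) = G_LOAD %eltptr(p0) :: (load (s32))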

  // Check if there is a load that defines the vector being extracted from.
  auto *LoadMI = getOpcodeDef<GLoad>(MI.getOperand(1).getReg(), MRI);
  if (!LoadMI)
    return false;

  Register Vector = MI.getOperand(1).getReg();
  LLT VecEltTy = MRI.getType(Vector).getElementType();

  assert(MRI.getType(MI.getOperand(0).getReg()) == VecEltTy);

  // Checking whether we should reduce the load width.
  if (!MRI.hasOneNonDBGUse(Vector))
    return false;

  // Check if the defining load is simple.
  if (!LoadMI->isSimple())
    return false;

  // If the vector element type is not a multiple of a byte then we are unable
  // to correctly compute an address to load only the extracted element as a
  // scalar.
  if (!VecEltTy.isByteSized())
    return false;

  // Check for load fold barriers between the extraction and the load.
  if (MI.getParent() != LoadMI->getParent())
    return false;
  const unsigned MaxIter = 20;
  unsigned Iter = 0;
  for (auto II = LoadMI->getIterator(), IE = MI.getIterator(); II != IE; ++II) {
    if (II->isLoadFoldBarrier())
      return false;
    if (Iter++ == MaxIter)
      return false;
  }

  // Check if the new load that we are going to create is legal
  // if we are in the post-legalization phase.
  MachineMemOperand MMO = LoadMI->getMMO();
  Align Alignment = MMO.getAlign();
  MachinePointerInfo PtrInfo;
  uint64_t Offset;
  Register Result = MI.getOperand(0).getReg();
  Register Index = MI.getOperand(2).getReg();

  // Finding the appropriate PtrInfo if the offset is a known constant.
  // This is required to create the memory operand for the narrowed load.
  // This machine memory operand object helps us infer legality
  // before we proceed to combine the instruction.
  if (auto CVal = getIConstantVRegVal(Index, MRI)) {
    int Elt = CVal->getZExtValue();
    // FIXME: should be (ABI size)*Elt.
    Offset = VecEltTy.getSizeInBits() * Elt / 8;
    PtrInfo = MMO.getPointerInfo().getWithOffset(Offset);
  } else {
    // Discard the pointer info except the address space because the memory
    // operand can't represent this new access since the offset is variable.
    Offset = VecEltTy.getSizeInBits() / 8;
    PtrInfo = MachinePointerInfo(MMO.getPointerInfo().getAddrSpace());
  }

  Alignment = commonAlignment(Alignment, Offset);

  Register VecPtr = LoadMI->getPointerReg();
  LLT PtrTy = MRI.getType(VecPtr);

  MachineFunction &MF = *MI.getMF();
  auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, VecEltTy);

  LegalityQuery::MemDesc MMDesc(*NewMMO);

  if (!isLegalOrBeforeLegalizer(
          {TargetOpcode::G_LOAD, {VecEltTy, PtrTy}, {MMDesc}}))
    return false;

  // Load must be allowed and fast on the target.
  LLVMContext &C = MF.getFunction().getContext();
  auto &DL = MF.getDataLayout();
  unsigned Fast = 0;
  if (!getTargetLowering().allowsMemoryAccess(C, DL, VecEltTy, *NewMMO,
                                              &Fast) ||
      !Fast)
    return false;

  MatchInfo = [=](MachineIRBuilder &B) {
    GISelObserverWrapper DummyObserver;
    LegalizerHelper Helper(B.getMF(), DummyObserver, B);
    // Get a pointer to the vector element.
    Register finalPtr = Helper.getVectorElementPointer(
        LoadMI->getPointerReg(), MRI.getType(LoadMI->getOperand(0).getReg()),
        Index);
    // New G_LOAD instruction.
    B.buildLoad(Result, finalPtr, PtrInfo, Alignment);
    // Remove the original G_LOAD instruction.
    LoadMI->eraseFromParent();
  };

  return true;
}

bool CombinerHelper::matchCombineIndexedLoadStore(
    MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) const {
  auto &LdSt = cast<GLoadStore>(MI);

  if (LdSt.isAtomic())
    return false;

  MatchInfo.IsPre = findPreIndexCandidate(LdSt, MatchInfo.Addr, MatchInfo.Base,
                                          MatchInfo.Offset);
  if (!MatchInfo.IsPre &&
      !findPostIndexCandidate(LdSt, MatchInfo.Addr, MatchInfo.Base,
                              MatchInfo.Offset, MatchInfo.RematOffset))
    return false;

  return true;
}

void CombinerHelper::applyCombineIndexedLoadStore(
    MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) const {
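  // E.g. for a post-indexed load (illustrative sketch):
  //   %ld:_(s64) = G_LOAD %base(p0)
  //   %addr:_(p0) = G_PTR_ADD %base, %offset(s64)
  // becomes a single operation that loads and produces the incremented
  // address:
  //   %ld:_(s64), %addr:_(p0) = G_INDEXED_LOAD %base(p0), %offset(s64), 0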
1472 MachineInstr &AddrDef = *MRI.getUniqueVRegDef(Reg: MatchInfo.Addr);
1473 unsigned Opcode = MI.getOpcode();
1474 bool IsStore = Opcode == TargetOpcode::G_STORE;
1475 unsigned NewOpcode = getIndexedOpc(LdStOpc: Opcode);
1476
1477 // If the offset constant didn't happen to dominate the load/store, we can
1478 // just clone it as needed.
1479 if (MatchInfo.RematOffset) {
1480 auto *OldCst = MRI.getVRegDef(Reg: MatchInfo.Offset);
1481 auto NewCst = Builder.buildConstant(Res: MRI.getType(Reg: MatchInfo.Offset),
1482 Val: *OldCst->getOperand(i: 1).getCImm());
1483 MatchInfo.Offset = NewCst.getReg(Idx: 0);
1484 }
1485
1486 auto MIB = Builder.buildInstr(Opcode: NewOpcode);
1487 if (IsStore) {
1488 MIB.addDef(RegNo: MatchInfo.Addr);
1489 MIB.addUse(RegNo: MI.getOperand(i: 0).getReg());
1490 } else {
1491 MIB.addDef(RegNo: MI.getOperand(i: 0).getReg());
1492 MIB.addDef(RegNo: MatchInfo.Addr);
1493 }
1494
1495 MIB.addUse(RegNo: MatchInfo.Base);
1496 MIB.addUse(RegNo: MatchInfo.Offset);
1497 MIB.addImm(Val: MatchInfo.IsPre);
1498 MIB->cloneMemRefs(MF&: *MI.getMF(), MI);
1499 MI.eraseFromParent();
1500 AddrDef.eraseFromParent();
1501
  LLVM_DEBUG(dbgs() << " Combined to indexed operation");
1503}
1504
1505bool CombinerHelper::matchCombineDivRem(MachineInstr &MI,
1506 MachineInstr *&OtherMI) const {
1507 unsigned Opcode = MI.getOpcode();
1508 bool IsDiv, IsSigned;
1509
1510 switch (Opcode) {
1511 default:
1512 llvm_unreachable("Unexpected opcode!");
1513 case TargetOpcode::G_SDIV:
1514 case TargetOpcode::G_UDIV: {
1515 IsDiv = true;
1516 IsSigned = Opcode == TargetOpcode::G_SDIV;
1517 break;
1518 }
1519 case TargetOpcode::G_SREM:
1520 case TargetOpcode::G_UREM: {
1521 IsDiv = false;
1522 IsSigned = Opcode == TargetOpcode::G_SREM;
1523 break;
1524 }
1525 }
1526
1527 Register Src1 = MI.getOperand(i: 1).getReg();
1528 unsigned DivOpcode, RemOpcode, DivremOpcode;
1529 if (IsSigned) {
1530 DivOpcode = TargetOpcode::G_SDIV;
1531 RemOpcode = TargetOpcode::G_SREM;
1532 DivremOpcode = TargetOpcode::G_SDIVREM;
1533 } else {
1534 DivOpcode = TargetOpcode::G_UDIV;
1535 RemOpcode = TargetOpcode::G_UREM;
1536 DivremOpcode = TargetOpcode::G_UDIVREM;
1537 }
1538
1539 if (!isLegalOrBeforeLegalizer(Query: {DivremOpcode, {MRI.getType(Reg: Src1)}}))
1540 return false;
1541
1542 // Combine:
1543 // %div:_ = G_[SU]DIV %src1:_, %src2:_
1544 // %rem:_ = G_[SU]REM %src1:_, %src2:_
1545 // into:
1546 // %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_
1547
1548 // Combine:
1549 // %rem:_ = G_[SU]REM %src1:_, %src2:_
1550 // %div:_ = G_[SU]DIV %src1:_, %src2:_
1551 // into:
1552 // %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_
1553
1554 for (auto &UseMI : MRI.use_nodbg_instructions(Reg: Src1)) {
1555 if (MI.getParent() == UseMI.getParent() &&
1556 ((IsDiv && UseMI.getOpcode() == RemOpcode) ||
1557 (!IsDiv && UseMI.getOpcode() == DivOpcode)) &&
1558 matchEqualDefs(MOP1: MI.getOperand(i: 2), MOP2: UseMI.getOperand(i: 2)) &&
1559 matchEqualDefs(MOP1: MI.getOperand(i: 1), MOP2: UseMI.getOperand(i: 1))) {
1560 OtherMI = &UseMI;
1561 return true;
1562 }
1563 }
1564
1565 return false;
1566}
1567
1568void CombinerHelper::applyCombineDivRem(MachineInstr &MI,
1569 MachineInstr *&OtherMI) const {
1570 unsigned Opcode = MI.getOpcode();
1571 assert(OtherMI && "OtherMI shouldn't be empty.");
1572
1573 Register DestDivReg, DestRemReg;
1574 if (Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_UDIV) {
1575 DestDivReg = MI.getOperand(i: 0).getReg();
1576 DestRemReg = OtherMI->getOperand(i: 0).getReg();
1577 } else {
1578 DestDivReg = OtherMI->getOperand(i: 0).getReg();
1579 DestRemReg = MI.getOperand(i: 0).getReg();
1580 }
1581
1582 bool IsSigned =
1583 Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_SREM;
1584
  // Check which instruction is first in the block so we don't break def-use
  // deps by "moving" the instruction incorrectly. Also keep track of which
  // instruction is first so we pick its operands, avoiding use-before-def
  // bugs.
1589 MachineInstr *FirstInst = dominates(DefMI: MI, UseMI: *OtherMI) ? &MI : OtherMI;
1590 Builder.setInstrAndDebugLoc(*FirstInst);
1591
1592 Builder.buildInstr(Opc: IsSigned ? TargetOpcode::G_SDIVREM
1593 : TargetOpcode::G_UDIVREM,
1594 DstOps: {DestDivReg, DestRemReg},
1595 SrcOps: { FirstInst->getOperand(i: 1), FirstInst->getOperand(i: 2) });
1596 MI.eraseFromParent();
1597 OtherMI->eraseFromParent();
1598}
1599
1600bool CombinerHelper::matchOptBrCondByInvertingCond(
1601 MachineInstr &MI, MachineInstr *&BrCond) const {
1602 assert(MI.getOpcode() == TargetOpcode::G_BR);
1603
1604 // Try to match the following:
1605 // bb1:
1606 // G_BRCOND %c1, %bb2
1607 // G_BR %bb3
1608 // bb2:
1609 // ...
1610 // bb3:
1611
  // The above pattern does not have a fallthrough to the successor bb2, so a
  // branch is always taken no matter which path is followed. Here we try to
  // find and replace that pattern with a conditional branch to bb3 and a
  // fallthrough to bb2 otherwise. This is generally better for branch
  // predictors.
1616
1617 MachineBasicBlock *MBB = MI.getParent();
1618 MachineBasicBlock::iterator BrIt(MI);
1619 if (BrIt == MBB->begin())
1620 return false;
1621 assert(std::next(BrIt) == MBB->end() && "expected G_BR to be a terminator");
1622
1623 BrCond = &*std::prev(x: BrIt);
1624 if (BrCond->getOpcode() != TargetOpcode::G_BRCOND)
1625 return false;
1626
1627 // Check that the next block is the conditional branch target. Also make sure
1628 // that it isn't the same as the G_BR's target (otherwise, this will loop.)
1629 MachineBasicBlock *BrCondTarget = BrCond->getOperand(i: 1).getMBB();
1630 return BrCondTarget != MI.getOperand(i: 0).getMBB() &&
1631 MBB->isLayoutSuccessor(MBB: BrCondTarget);
1632}
1633
1634void CombinerHelper::applyOptBrCondByInvertingCond(
1635 MachineInstr &MI, MachineInstr *&BrCond) const {
1636 MachineBasicBlock *BrTarget = MI.getOperand(i: 0).getMBB();
1637 Builder.setInstrAndDebugLoc(*BrCond);
1638 LLT Ty = MRI.getType(Reg: BrCond->getOperand(i: 0).getReg());
1639 // FIXME: Does int/fp matter for this? If so, we might need to restrict
1640 // this to i1 only since we might not know for sure what kind of
1641 // compare generated the condition value.
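  // Invert the condition by XORing it with "true".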
1642 auto True = Builder.buildConstant(
1643 Res: Ty, Val: getICmpTrueVal(TLI: getTargetLowering(), IsVector: false, IsFP: false));
1644 auto Xor = Builder.buildXor(Dst: Ty, Src0: BrCond->getOperand(i: 0), Src1: True);
1645
1646 auto *FallthroughBB = BrCond->getOperand(i: 1).getMBB();
1647 Observer.changingInstr(MI);
1648 MI.getOperand(i: 0).setMBB(FallthroughBB);
1649 Observer.changedInstr(MI);
1650
1651 // Change the conditional branch to use the inverted condition and
1652 // new target block.
1653 Observer.changingInstr(MI&: *BrCond);
1654 BrCond->getOperand(i: 0).setReg(Xor.getReg(Idx: 0));
1655 BrCond->getOperand(i: 1).setMBB(BrTarget);
1656 Observer.changedInstr(MI&: *BrCond);
1657}
1658
1659bool CombinerHelper::tryEmitMemcpyInline(MachineInstr &MI) const {
1660 MachineIRBuilder HelperBuilder(MI);
1661 GISelObserverWrapper DummyObserver;
1662 LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder);
1663 return Helper.lowerMemcpyInline(MI) ==
1664 LegalizerHelper::LegalizeResult::Legalized;
1665}
1666
1667bool CombinerHelper::tryCombineMemCpyFamily(MachineInstr &MI,
1668 unsigned MaxLen) const {
1669 MachineIRBuilder HelperBuilder(MI);
1670 GISelObserverWrapper DummyObserver;
1671 LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder);
1672 return Helper.lowerMemCpyFamily(MI, MaxLen) ==
1673 LegalizerHelper::LegalizeResult::Legalized;
1674}
1675
1676static APFloat constantFoldFpUnary(const MachineInstr &MI,
1677 const MachineRegisterInfo &MRI,
1678 const APFloat &Val) {
1679 APFloat Result(Val);
1680 switch (MI.getOpcode()) {
1681 default:
1682 llvm_unreachable("Unexpected opcode!");
1683 case TargetOpcode::G_FNEG: {
1684 Result.changeSign();
1685 return Result;
1686 }
1687 case TargetOpcode::G_FABS: {
1688 Result.clearSign();
1689 return Result;
1690 }
1691 case TargetOpcode::G_FCEIL:
1692 Result.roundToIntegral(RM: APFloat::rmTowardPositive);
1693 return Result;
1694 case TargetOpcode::G_FFLOOR:
1695 Result.roundToIntegral(RM: APFloat::rmTowardNegative);
1696 return Result;
1697 case TargetOpcode::G_INTRINSIC_TRUNC:
1698 Result.roundToIntegral(RM: APFloat::rmTowardZero);
1699 return Result;
1700 case TargetOpcode::G_INTRINSIC_ROUND:
1701 Result.roundToIntegral(RM: APFloat::rmNearestTiesToAway);
1702 return Result;
1703 case TargetOpcode::G_INTRINSIC_ROUNDEVEN:
1704 Result.roundToIntegral(RM: APFloat::rmNearestTiesToEven);
1705 return Result;
1706 case TargetOpcode::G_FRINT:
1707 case TargetOpcode::G_FNEARBYINT:
1708 // Use default rounding mode (round to nearest, ties to even)
1709 Result.roundToIntegral(RM: APFloat::rmNearestTiesToEven);
1710 return Result;
1711 case TargetOpcode::G_FPEXT:
1712 case TargetOpcode::G_FPTRUNC: {
1713 bool Unused;
1714 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
1715 Result.convert(ToSemantics: getFltSemanticForLLT(Ty: DstTy), RM: APFloat::rmNearestTiesToEven,
1716 losesInfo: &Unused);
1717 return Result;
1718 }
1719 case TargetOpcode::G_FSQRT: {
1720 bool Unused;
1721 Result.convert(ToSemantics: APFloat::IEEEdouble(), RM: APFloat::rmNearestTiesToEven,
1722 losesInfo: &Unused);
1723 Result = APFloat(sqrt(x: Result.convertToDouble()));
1724 break;
1725 }
1726 case TargetOpcode::G_FLOG2: {
1727 bool Unused;
1728 Result.convert(ToSemantics: APFloat::IEEEdouble(), RM: APFloat::rmNearestTiesToEven,
1729 losesInfo: &Unused);
1730 Result = APFloat(log2(x: Result.convertToDouble()));
1731 break;
1732 }
1733 }
  // Convert the APFloat back to the semantics of the original value so that
  // `buildFConstant` does not assert on a size mismatch. Only `G_FSQRT` and
  // `G_FLOG2` reach here.
1737 bool Unused;
1738 Result.convert(ToSemantics: Val.getSemantics(), RM: APFloat::rmNearestTiesToEven, losesInfo: &Unused);
1739 return Result;
1740}
1741
1742void CombinerHelper::applyCombineConstantFoldFpUnary(
1743 MachineInstr &MI, const ConstantFP *Cst) const {
1744 APFloat Folded = constantFoldFpUnary(MI, MRI, Val: Cst->getValue());
1745 const ConstantFP *NewCst = ConstantFP::get(Context&: Builder.getContext(), V: Folded);
1746 Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: *NewCst);
1747 MI.eraseFromParent();
1748}
1749
1750bool CombinerHelper::matchPtrAddImmedChain(MachineInstr &MI,
1751 PtrAddChain &MatchInfo) const {
1752 // We're trying to match the following pattern:
1753 // %t1 = G_PTR_ADD %base, G_CONSTANT imm1
1754 // %root = G_PTR_ADD %t1, G_CONSTANT imm2
1755 // -->
1756 // %root = G_PTR_ADD %base, G_CONSTANT (imm1 + imm2)
1757
1758 if (MI.getOpcode() != TargetOpcode::G_PTR_ADD)
1759 return false;
1760
1761 Register Add2 = MI.getOperand(i: 1).getReg();
1762 Register Imm1 = MI.getOperand(i: 2).getReg();
1763 auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: Imm1, MRI);
1764 if (!MaybeImmVal)
1765 return false;
1766
1767 MachineInstr *Add2Def = MRI.getVRegDef(Reg: Add2);
1768 if (!Add2Def || Add2Def->getOpcode() != TargetOpcode::G_PTR_ADD)
1769 return false;
1770
1771 Register Base = Add2Def->getOperand(i: 1).getReg();
1772 Register Imm2 = Add2Def->getOperand(i: 2).getReg();
1773 auto MaybeImm2Val = getIConstantVRegValWithLookThrough(VReg: Imm2, MRI);
1774 if (!MaybeImm2Val)
1775 return false;
1776
1777 // Check if the new combined immediate forms an illegal addressing mode.
1778 // Do not combine if it was legal before but would get illegal.
1779 // To do so, we need to find a load/store user of the pointer to get
1780 // the access type.
1781 Type *AccessTy = nullptr;
1782 auto &MF = *MI.getMF();
1783 for (auto &UseMI : MRI.use_nodbg_instructions(Reg: MI.getOperand(i: 0).getReg())) {
1784 if (auto *LdSt = dyn_cast<GLoadStore>(Val: &UseMI)) {
1785 AccessTy = getTypeForLLT(Ty: MRI.getType(Reg: LdSt->getReg(Idx: 0)),
1786 C&: MF.getFunction().getContext());
1787 break;
1788 }
1789 }
1790 TargetLoweringBase::AddrMode AMNew;
1791 APInt CombinedImm = MaybeImmVal->Value + MaybeImm2Val->Value;
1792 AMNew.BaseOffs = CombinedImm.getSExtValue();
1793 if (AccessTy) {
1794 AMNew.HasBaseReg = true;
1795 TargetLoweringBase::AddrMode AMOld;
1796 AMOld.BaseOffs = MaybeImmVal->Value.getSExtValue();
1797 AMOld.HasBaseReg = true;
1798 unsigned AS = MRI.getType(Reg: Add2).getAddressSpace();
1799 const auto &TLI = *MF.getSubtarget().getTargetLowering();
1800 if (TLI.isLegalAddressingMode(DL: MF.getDataLayout(), AM: AMOld, Ty: AccessTy, AddrSpace: AS) &&
1801 !TLI.isLegalAddressingMode(DL: MF.getDataLayout(), AM: AMNew, Ty: AccessTy, AddrSpace: AS))
1802 return false;
1803 }
1804
1805 // Reassociating nuw additions preserves nuw. If both original G_PTR_ADDs are
1806 // inbounds, reaching the same result in one G_PTR_ADD is also inbounds.
1807 // The nusw constraints are satisfied because imm1+imm2 cannot exceed the
1808 // largest signed integer that fits into the index type, which is the maximum
1809 // size of allocated objects according to the IR Language Reference.
1810 unsigned PtrAddFlags = MI.getFlags();
1811 unsigned LHSPtrAddFlags = Add2Def->getFlags();
1812 bool IsNoUWrap = PtrAddFlags & LHSPtrAddFlags & MachineInstr::MIFlag::NoUWrap;
1813 bool IsInBounds =
1814 PtrAddFlags & LHSPtrAddFlags & MachineInstr::MIFlag::InBounds;
1815 unsigned Flags = 0;
1816 if (IsNoUWrap)
1817 Flags |= MachineInstr::MIFlag::NoUWrap;
1818 if (IsInBounds) {
1819 Flags |= MachineInstr::MIFlag::InBounds;
1820 Flags |= MachineInstr::MIFlag::NoUSWrap;
1821 }
1822
1823 // Pass the combined immediate to the apply function.
1824 MatchInfo.Imm = AMNew.BaseOffs;
1825 MatchInfo.Base = Base;
1826 MatchInfo.Bank = getRegBank(Reg: Imm2);
1827 MatchInfo.Flags = Flags;
1828 return true;
1829}
1830
1831void CombinerHelper::applyPtrAddImmedChain(MachineInstr &MI,
1832 PtrAddChain &MatchInfo) const {
1833 assert(MI.getOpcode() == TargetOpcode::G_PTR_ADD && "Expected G_PTR_ADD");
1834 MachineIRBuilder MIB(MI);
1835 LLT OffsetTy = MRI.getType(Reg: MI.getOperand(i: 2).getReg());
1836 auto NewOffset = MIB.buildConstant(Res: OffsetTy, Val: MatchInfo.Imm);
1837 setRegBank(Reg: NewOffset.getReg(Idx: 0), RegBank: MatchInfo.Bank);
1838 Observer.changingInstr(MI);
1839 MI.getOperand(i: 1).setReg(MatchInfo.Base);
1840 MI.getOperand(i: 2).setReg(NewOffset.getReg(Idx: 0));
1841 MI.setFlags(MatchInfo.Flags);
1842 Observer.changedInstr(MI);
1843}
1844
1845bool CombinerHelper::matchShiftImmedChain(MachineInstr &MI,
1846 RegisterImmPair &MatchInfo) const {
1847 // We're trying to match the following pattern with any of
1848 // G_SHL/G_ASHR/G_LSHR/G_SSHLSAT/G_USHLSAT shift instructions:
1849 // %t1 = SHIFT %base, G_CONSTANT imm1
1850 // %root = SHIFT %t1, G_CONSTANT imm2
1851 // -->
1852 // %root = SHIFT %base, G_CONSTANT (imm1 + imm2)
1853
1854 unsigned Opcode = MI.getOpcode();
1855 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR ||
1856 Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT ||
1857 Opcode == TargetOpcode::G_USHLSAT) &&
1858 "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT");
1859
1860 Register Shl2 = MI.getOperand(i: 1).getReg();
1861 Register Imm1 = MI.getOperand(i: 2).getReg();
1862 auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: Imm1, MRI);
1863 if (!MaybeImmVal)
1864 return false;
1865
1866 MachineInstr *Shl2Def = MRI.getUniqueVRegDef(Reg: Shl2);
1867 if (Shl2Def->getOpcode() != Opcode)
1868 return false;
1869
1870 Register Base = Shl2Def->getOperand(i: 1).getReg();
1871 Register Imm2 = Shl2Def->getOperand(i: 2).getReg();
1872 auto MaybeImm2Val = getIConstantVRegValWithLookThrough(VReg: Imm2, MRI);
1873 if (!MaybeImm2Val)
1874 return false;
1875
1876 // Pass the combined immediate to the apply function.
1877 MatchInfo.Imm =
1878 (MaybeImmVal->Value.getZExtValue() + MaybeImm2Val->Value).getZExtValue();
1879 MatchInfo.Reg = Base;
1880
1881 // There is no simple replacement for a saturating unsigned left shift that
1882 // exceeds the scalar size.
1883 if (Opcode == TargetOpcode::G_USHLSAT &&
1884 MatchInfo.Imm >= MRI.getType(Reg: Shl2).getScalarSizeInBits())
1885 return false;
1886
1887 return true;
1888}
1889
1890void CombinerHelper::applyShiftImmedChain(MachineInstr &MI,
1891 RegisterImmPair &MatchInfo) const {
1892 unsigned Opcode = MI.getOpcode();
1893 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR ||
1894 Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT ||
1895 Opcode == TargetOpcode::G_USHLSAT) &&
1896 "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT");
1897
1898 LLT Ty = MRI.getType(Reg: MI.getOperand(i: 1).getReg());
1899 unsigned const ScalarSizeInBits = Ty.getScalarSizeInBits();
1900 auto Imm = MatchInfo.Imm;
1901
1902 if (Imm >= ScalarSizeInBits) {
1903 // Any logical shift that exceeds scalar size will produce zero.
1904 if (Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_LSHR) {
1905 Builder.buildConstant(Res: MI.getOperand(i: 0), Val: 0);
1906 MI.eraseFromParent();
1907 return;
1908 }
1909 // Arithmetic shift and saturating signed left shift have no effect beyond
1910 // scalar size.
1911 Imm = ScalarSizeInBits - 1;
1912 }
1913
1914 LLT ImmTy = MRI.getType(Reg: MI.getOperand(i: 2).getReg());
1915 Register NewImm = Builder.buildConstant(Res: ImmTy, Val: Imm).getReg(Idx: 0);
1916 Observer.changingInstr(MI);
1917 MI.getOperand(i: 1).setReg(MatchInfo.Reg);
1918 MI.getOperand(i: 2).setReg(NewImm);
1919 Observer.changedInstr(MI);
1920}
1921
1922bool CombinerHelper::matchShiftOfShiftedLogic(
1923 MachineInstr &MI, ShiftOfShiftedLogic &MatchInfo) const {
1924 // We're trying to match the following pattern with any of
1925 // G_SHL/G_ASHR/G_LSHR/G_USHLSAT/G_SSHLSAT shift instructions in combination
1926 // with any of G_AND/G_OR/G_XOR logic instructions.
1927 // %t1 = SHIFT %X, G_CONSTANT C0
1928 // %t2 = LOGIC %t1, %Y
1929 // %root = SHIFT %t2, G_CONSTANT C1
1930 // -->
1931 // %t3 = SHIFT %X, G_CONSTANT (C0+C1)
1932 // %t4 = SHIFT %Y, G_CONSTANT C1
1933 // %root = LOGIC %t3, %t4
1934 unsigned ShiftOpcode = MI.getOpcode();
1935 assert((ShiftOpcode == TargetOpcode::G_SHL ||
1936 ShiftOpcode == TargetOpcode::G_ASHR ||
1937 ShiftOpcode == TargetOpcode::G_LSHR ||
1938 ShiftOpcode == TargetOpcode::G_USHLSAT ||
1939 ShiftOpcode == TargetOpcode::G_SSHLSAT) &&
1940 "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT and G_SSHLSAT");
1941
1942 // Match a one-use bitwise logic op.
1943 Register LogicDest = MI.getOperand(i: 1).getReg();
1944 if (!MRI.hasOneNonDBGUse(RegNo: LogicDest))
1945 return false;
1946
1947 MachineInstr *LogicMI = MRI.getUniqueVRegDef(Reg: LogicDest);
1948 unsigned LogicOpcode = LogicMI->getOpcode();
1949 if (LogicOpcode != TargetOpcode::G_AND && LogicOpcode != TargetOpcode::G_OR &&
1950 LogicOpcode != TargetOpcode::G_XOR)
1951 return false;
1952
1953 // Find a matching one-use shift by constant.
1954 const Register C1 = MI.getOperand(i: 2).getReg();
1955 auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: C1, MRI);
1956 if (!MaybeImmVal || MaybeImmVal->Value == 0)
1957 return false;
1958
1959 const uint64_t C1Val = MaybeImmVal->Value.getZExtValue();
1960
1961 auto matchFirstShift = [&](const MachineInstr *MI, uint64_t &ShiftVal) {
    // The shift should match the previous one and have only one use.
1963 if (MI->getOpcode() != ShiftOpcode ||
1964 !MRI.hasOneNonDBGUse(RegNo: MI->getOperand(i: 0).getReg()))
1965 return false;
1966
1967 // Must be a constant.
1968 auto MaybeImmVal =
1969 getIConstantVRegValWithLookThrough(VReg: MI->getOperand(i: 2).getReg(), MRI);
1970 if (!MaybeImmVal)
1971 return false;
1972
1973 ShiftVal = MaybeImmVal->Value.getSExtValue();
1974 return true;
1975 };
1976
1977 // Logic ops are commutative, so check each operand for a match.
1978 Register LogicMIReg1 = LogicMI->getOperand(i: 1).getReg();
1979 MachineInstr *LogicMIOp1 = MRI.getUniqueVRegDef(Reg: LogicMIReg1);
1980 Register LogicMIReg2 = LogicMI->getOperand(i: 2).getReg();
1981 MachineInstr *LogicMIOp2 = MRI.getUniqueVRegDef(Reg: LogicMIReg2);
1982 uint64_t C0Val;
1983
1984 if (matchFirstShift(LogicMIOp1, C0Val)) {
1985 MatchInfo.LogicNonShiftReg = LogicMIReg2;
1986 MatchInfo.Shift2 = LogicMIOp1;
1987 } else if (matchFirstShift(LogicMIOp2, C0Val)) {
1988 MatchInfo.LogicNonShiftReg = LogicMIReg1;
1989 MatchInfo.Shift2 = LogicMIOp2;
1990 } else
1991 return false;
1992
1993 MatchInfo.ValSum = C0Val + C1Val;
1994
1995 // The fold is not valid if the sum of the shift values exceeds bitwidth.
1996 if (MatchInfo.ValSum >= MRI.getType(Reg: LogicDest).getScalarSizeInBits())
1997 return false;
1998
1999 MatchInfo.Logic = LogicMI;
2000 return true;
2001}
2002
2003void CombinerHelper::applyShiftOfShiftedLogic(
2004 MachineInstr &MI, ShiftOfShiftedLogic &MatchInfo) const {
2005 unsigned Opcode = MI.getOpcode();
2006 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR ||
2007 Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_USHLSAT ||
2008 Opcode == TargetOpcode::G_SSHLSAT) &&
2009 "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT and G_SSHLSAT");
2010
2011 LLT ShlType = MRI.getType(Reg: MI.getOperand(i: 2).getReg());
2012 LLT DestType = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
2013
2014 Register Const = Builder.buildConstant(Res: ShlType, Val: MatchInfo.ValSum).getReg(Idx: 0);
2015
2016 Register Shift1Base = MatchInfo.Shift2->getOperand(i: 1).getReg();
2017 Register Shift1 =
2018 Builder.buildInstr(Opc: Opcode, DstOps: {DestType}, SrcOps: {Shift1Base, Const}).getReg(Idx: 0);
2019
  // If LogicNonShiftReg is the same as Shift1Base, and the shift1 constant is
  // the same as the MatchInfo.Shift2 constant, CSEMIRBuilder will reuse the
  // old shift1 when building shift2. In that case, erasing MatchInfo.Shift2
  // at the end would actually remove the old shift1 and cause a crash later.
  // Erase it earlier to avoid the crash.
2025 MatchInfo.Shift2->eraseFromParent();
2026
2027 Register Shift2Const = MI.getOperand(i: 2).getReg();
2028 Register Shift2 = Builder
2029 .buildInstr(Opc: Opcode, DstOps: {DestType},
2030 SrcOps: {MatchInfo.LogicNonShiftReg, Shift2Const})
2031 .getReg(Idx: 0);
2032
2033 Register Dest = MI.getOperand(i: 0).getReg();
2034 Builder.buildInstr(Opc: MatchInfo.Logic->getOpcode(), DstOps: {Dest}, SrcOps: {Shift1, Shift2});
2035
  // The logic op had only one use, so it's safe to remove it.
2037 MatchInfo.Logic->eraseFromParent();
2038
2039 MI.eraseFromParent();
2040}
2041
2042bool CombinerHelper::matchCommuteShift(MachineInstr &MI,
2043 BuildFnTy &MatchInfo) const {
2044 assert(MI.getOpcode() == TargetOpcode::G_SHL && "Expected G_SHL");
2045 // Combine (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
2046 // Combine (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
2047 auto &Shl = cast<GenericMachineInstr>(Val&: MI);
2048 Register DstReg = Shl.getReg(Idx: 0);
2049 Register SrcReg = Shl.getReg(Idx: 1);
2050 Register ShiftReg = Shl.getReg(Idx: 2);
2051 Register X, C1;
2052
2053 if (!getTargetLowering().isDesirableToCommuteWithShift(MI, IsAfterLegal: !isPreLegalize()))
2054 return false;
2055
2056 if (!mi_match(R: SrcReg, MRI,
2057 P: m_OneNonDBGUse(SP: m_any_of(preds: m_GAdd(L: m_Reg(R&: X), R: m_Reg(R&: C1)),
2058 preds: m_GOr(L: m_Reg(R&: X), R: m_Reg(R&: C1))))))
2059 return false;
2060
2061 APInt C1Val, C2Val;
2062 if (!mi_match(R: C1, MRI, P: m_ICstOrSplat(Cst&: C1Val)) ||
2063 !mi_match(R: ShiftReg, MRI, P: m_ICstOrSplat(Cst&: C2Val)))
2064 return false;
2065
2066 auto *SrcDef = MRI.getVRegDef(Reg: SrcReg);
2067 assert((SrcDef->getOpcode() == TargetOpcode::G_ADD ||
2068 SrcDef->getOpcode() == TargetOpcode::G_OR) && "Unexpected op");
2069 LLT SrcTy = MRI.getType(Reg: SrcReg);
2070 MatchInfo = [=](MachineIRBuilder &B) {
2071 auto S1 = B.buildShl(Dst: SrcTy, Src0: X, Src1: ShiftReg);
2072 auto S2 = B.buildShl(Dst: SrcTy, Src0: C1, Src1: ShiftReg);
2073 B.buildInstr(Opc: SrcDef->getOpcode(), DstOps: {DstReg}, SrcOps: {S1, S2});
2074 };
2075 return true;
2076}
2077
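// Fold a G_LSHR of a truncated G_LSHR into one wide shift, e.g.
// (illustrative types):
//   %inner:_(s64) = G_LSHR %x, C1
//   %trunc:_(s32) = G_TRUNC %inner
//   %root:_(s32) = G_LSHR %trunc, C2
// --->
//   %shift:_(s64) = G_LSHR %x, C1 + C2
//   %root:_(s32) = G_TRUNC %shift
// with an extra G_AND mask when the inner shift does not already clear the
// bits above the truncated width.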
2078bool CombinerHelper::matchLshrOfTruncOfLshr(MachineInstr &MI,
2079 LshrOfTruncOfLshr &MatchInfo,
2080 MachineInstr &ShiftMI) const {
2081 assert(MI.getOpcode() == TargetOpcode::G_LSHR && "Expected a G_LSHR");
2082
2083 Register N0 = MI.getOperand(i: 1).getReg();
2084 Register N1 = MI.getOperand(i: 2).getReg();
2085 unsigned OpSizeInBits = MRI.getType(Reg: N0).getScalarSizeInBits();
2086
2087 APInt N1C, N001C;
2088 if (!mi_match(R: N1, MRI, P: m_ICstOrSplat(Cst&: N1C)))
2089 return false;
2090 auto N001 = ShiftMI.getOperand(i: 2).getReg();
2091 if (!mi_match(R: N001, MRI, P: m_ICstOrSplat(Cst&: N001C)))
2092 return false;
2093
2094 if (N001C.getBitWidth() > N1C.getBitWidth())
2095 N1C = N1C.zext(width: N001C.getBitWidth());
2096 else
2097 N001C = N001C.zext(width: N1C.getBitWidth());
2098
2099 Register InnerShift = ShiftMI.getOperand(i: 0).getReg();
2100 LLT InnerShiftTy = MRI.getType(Reg: InnerShift);
2101 uint64_t InnerShiftSize = InnerShiftTy.getScalarSizeInBits();
2102 if ((N1C + N001C).ult(RHS: InnerShiftSize)) {
2103 MatchInfo.Src = ShiftMI.getOperand(i: 1).getReg();
2104 MatchInfo.ShiftAmt = N1C + N001C;
2105 MatchInfo.ShiftAmtTy = MRI.getType(Reg: N001);
2106 MatchInfo.InnerShiftTy = InnerShiftTy;
2107
2108 if ((N001C + OpSizeInBits) == InnerShiftSize)
2109 return true;
2110 if (MRI.hasOneUse(RegNo: N0) && MRI.hasOneUse(RegNo: InnerShift)) {
2111 MatchInfo.Mask = true;
2112 MatchInfo.MaskVal = APInt(N1C.getBitWidth(), OpSizeInBits) - N1C;
2113 return true;
2114 }
2115 }
2116 return false;
2117}
2118
2119void CombinerHelper::applyLshrOfTruncOfLshr(
2120 MachineInstr &MI, LshrOfTruncOfLshr &MatchInfo) const {
2121 assert(MI.getOpcode() == TargetOpcode::G_LSHR && "Expected a G_LSHR");
2122
2123 Register Dst = MI.getOperand(i: 0).getReg();
2124 auto ShiftAmt =
2125 Builder.buildConstant(Res: MatchInfo.ShiftAmtTy, Val: MatchInfo.ShiftAmt);
2126 auto Shift =
2127 Builder.buildLShr(Dst: MatchInfo.InnerShiftTy, Src0: MatchInfo.Src, Src1: ShiftAmt);
  if (MatchInfo.Mask) {
2129 APInt MaskVal =
2130 APInt::getLowBitsSet(numBits: MatchInfo.InnerShiftTy.getScalarSizeInBits(),
2131 loBitsSet: MatchInfo.MaskVal.getZExtValue());
2132 auto Mask = Builder.buildConstant(Res: MatchInfo.InnerShiftTy, Val: MaskVal);
2133 auto And = Builder.buildAnd(Dst: MatchInfo.InnerShiftTy, Src0: Shift, Src1: Mask);
2134 Builder.buildTrunc(Res: Dst, Op: And);
2135 } else
2136 Builder.buildTrunc(Res: Dst, Op: Shift);
2137 MI.eraseFromParent();
2138}
2139
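// Fold a multiply by a power of two into a left shift, e.g.:
//   %r:_(s32) = G_MUL %x, 8  --->  %r:_(s32) = G_SHL %x, 3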
2140bool CombinerHelper::matchCombineMulToShl(MachineInstr &MI,
2141 unsigned &ShiftVal) const {
2142 assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL");
2143 auto MaybeImmVal =
2144 getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI);
2145 if (!MaybeImmVal)
2146 return false;
2147
2148 ShiftVal = MaybeImmVal->Value.exactLogBase2();
2149 return (static_cast<int32_t>(ShiftVal) != -1);
2150}
2151
2152void CombinerHelper::applyCombineMulToShl(MachineInstr &MI,
2153 unsigned &ShiftVal) const {
2154 assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL");
2155 MachineIRBuilder MIB(MI);
2156 LLT ShiftTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
2157 auto ShiftCst = MIB.buildConstant(Res: ShiftTy, Val: ShiftVal);
2158 Observer.changingInstr(MI);
2159 MI.setDesc(MIB.getTII().get(Opcode: TargetOpcode::G_SHL));
2160 MI.getOperand(i: 2).setReg(ShiftCst.getReg(Idx: 0));
2161 if (ShiftVal == ShiftTy.getScalarSizeInBits() - 1)
2162 MI.clearFlag(Flag: MachineInstr::MIFlag::NoSWrap);
2163 Observer.changedInstr(MI);
2164}
2165
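// Fold a subtraction of a constant into an addition of its negation, e.g.:
//   %r:_(s32) = G_SUB %x, 5  --->  %r:_(s32) = G_ADD %x, -5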
2166bool CombinerHelper::matchCombineSubToAdd(MachineInstr &MI,
2167 BuildFnTy &MatchInfo) const {
2168 GSub &Sub = cast<GSub>(Val&: MI);
2169
2170 LLT Ty = MRI.getType(Reg: Sub.getReg(Idx: 0));
2171
2172 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ADD, {Ty}}))
2173 return false;
2174
2175 if (!isConstantLegalOrBeforeLegalizer(Ty))
2176 return false;
2177
2178 APInt Imm = getIConstantFromReg(VReg: Sub.getRHSReg(), MRI);
2179
2180 MatchInfo = [=, &MI](MachineIRBuilder &B) {
2181 auto NegCst = B.buildConstant(Res: Ty, Val: -Imm);
2182 Observer.changingInstr(MI);
2183 MI.setDesc(B.getTII().get(Opcode: TargetOpcode::G_ADD));
2184 MI.getOperand(i: 2).setReg(NegCst.getReg(Idx: 0));
2185 MI.clearFlag(Flag: MachineInstr::MIFlag::NoUWrap);
2186 if (Imm.isMinSignedValue())
2187 MI.clearFlags(flags: MachineInstr::MIFlag::NoSWrap);
2188 Observer.changedInstr(MI);
2189 };
2190 return true;
2191}
2192
2193// shl ([sza]ext x), y => zext (shl x, y), if shift does not overflow source
2194bool CombinerHelper::matchCombineShlOfExtend(MachineInstr &MI,
2195 RegisterImmPair &MatchData) const {
2196 assert(MI.getOpcode() == TargetOpcode::G_SHL && VT);
2197 if (!getTargetLowering().isDesirableToPullExtFromShl(MI))
2198 return false;
2199
2200 Register LHS = MI.getOperand(i: 1).getReg();
2201
2202 Register ExtSrc;
2203 if (!mi_match(R: LHS, MRI, P: m_GAnyExt(Src: m_Reg(R&: ExtSrc))) &&
2204 !mi_match(R: LHS, MRI, P: m_GZExt(Src: m_Reg(R&: ExtSrc))) &&
2205 !mi_match(R: LHS, MRI, P: m_GSExt(Src: m_Reg(R&: ExtSrc))))
2206 return false;
2207
2208 Register RHS = MI.getOperand(i: 2).getReg();
2209 MachineInstr *MIShiftAmt = MRI.getVRegDef(Reg: RHS);
2210 auto MaybeShiftAmtVal = isConstantOrConstantSplatVector(MI&: *MIShiftAmt, MRI);
2211 if (!MaybeShiftAmtVal)
2212 return false;
2213
2214 if (LI) {
2215 LLT SrcTy = MRI.getType(Reg: ExtSrc);
2216
    // We only really care about the legality of the shifted value. We can
    // pick any type for the constant shift amount, so ask the target what to
    // use. Otherwise we would have to guess and hope it is reported as legal.
2220 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: SrcTy);
2221 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SHL, {SrcTy, ShiftAmtTy}}))
2222 return false;
2223 }
2224
2225 int64_t ShiftAmt = MaybeShiftAmtVal->getSExtValue();
2226 MatchData.Reg = ExtSrc;
2227 MatchData.Imm = ShiftAmt;
2228
2229 unsigned MinLeadingZeros = VT->getKnownZeroes(R: ExtSrc).countl_one();
2230 unsigned SrcTySize = MRI.getType(Reg: ExtSrc).getScalarSizeInBits();
2231 return MinLeadingZeros >= ShiftAmt && ShiftAmt < SrcTySize;
2232}
2233
2234void CombinerHelper::applyCombineShlOfExtend(
2235 MachineInstr &MI, const RegisterImmPair &MatchData) const {
2236 Register ExtSrcReg = MatchData.Reg;
2237 int64_t ShiftAmtVal = MatchData.Imm;
2238
2239 LLT ExtSrcTy = MRI.getType(Reg: ExtSrcReg);
2240 auto ShiftAmt = Builder.buildConstant(Res: ExtSrcTy, Val: ShiftAmtVal);
2241 auto NarrowShift =
2242 Builder.buildShl(Dst: ExtSrcTy, Src0: ExtSrcReg, Src1: ShiftAmt, Flags: MI.getFlags());
2243 Builder.buildZExt(Res: MI.getOperand(i: 0), Op: NarrowShift);
2244 MI.eraseFromParent();
2245}
2246
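// Fold a merge of an unmerge's results back into the original register, e.g.:
//   %a:_(s32), %b:_(s32) = G_UNMERGE_VALUES %x:_(s64)
//   %y:_(s64) = G_MERGE_VALUES %a, %b
// ---> replace all uses of %y with %x.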
2247bool CombinerHelper::matchCombineMergeUnmerge(MachineInstr &MI,
2248 Register &MatchInfo) const {
2249 GMerge &Merge = cast<GMerge>(Val&: MI);
2250 SmallVector<Register, 16> MergedValues;
2251 for (unsigned I = 0; I < Merge.getNumSources(); ++I)
2252 MergedValues.emplace_back(Args: Merge.getSourceReg(I));
2253
2254 auto *Unmerge = getOpcodeDef<GUnmerge>(Reg: MergedValues[0], MRI);
2255 if (!Unmerge || Unmerge->getNumDefs() != Merge.getNumSources())
2256 return false;
2257
2258 for (unsigned I = 0; I < MergedValues.size(); ++I)
2259 if (MergedValues[I] != Unmerge->getReg(Idx: I))
2260 return false;
2261
2262 MatchInfo = Unmerge->getSourceReg();
2263 return true;
2264}
2265
2266static Register peekThroughBitcast(Register Reg,
2267 const MachineRegisterInfo &MRI) {
2268 while (mi_match(R: Reg, MRI, P: m_GBitcast(Src: m_Reg(R&: Reg))))
2269 ;
2270
2271 return Reg;
2272}
2273
2274bool CombinerHelper::matchCombineUnmergeMergeToPlainValues(
2275 MachineInstr &MI, SmallVectorImpl<Register> &Operands) const {
2276 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2277 "Expected an unmerge");
2278 auto &Unmerge = cast<GUnmerge>(Val&: MI);
2279 Register SrcReg = peekThroughBitcast(Reg: Unmerge.getSourceReg(), MRI);
2280
2281 auto *SrcInstr = getOpcodeDef<GMergeLikeInstr>(Reg: SrcReg, MRI);
2282 if (!SrcInstr)
2283 return false;
2284
2285 // Check the source type of the merge.
2286 LLT SrcMergeTy = MRI.getType(Reg: SrcInstr->getSourceReg(I: 0));
2287 LLT Dst0Ty = MRI.getType(Reg: Unmerge.getReg(Idx: 0));
2288 bool SameSize = Dst0Ty.getSizeInBits() == SrcMergeTy.getSizeInBits();
2289 if (SrcMergeTy != Dst0Ty && !SameSize)
2290 return false;
2291 // They are the same now (modulo a bitcast).
2292 // We can collect all the src registers.
2293 for (unsigned Idx = 0; Idx < SrcInstr->getNumSources(); ++Idx)
2294 Operands.push_back(Elt: SrcInstr->getSourceReg(I: Idx));
2295 return true;
2296}
2297
2298void CombinerHelper::applyCombineUnmergeMergeToPlainValues(
2299 MachineInstr &MI, SmallVectorImpl<Register> &Operands) const {
2300 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2301 "Expected an unmerge");
2302 assert((MI.getNumOperands() - 1 == Operands.size()) &&
2303 "Not enough operands to replace all defs");
2304 unsigned NumElems = MI.getNumOperands() - 1;
2305
2306 LLT SrcTy = MRI.getType(Reg: Operands[0]);
2307 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
2308 bool CanReuseInputDirectly = DstTy == SrcTy;
2309 for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
2310 Register DstReg = MI.getOperand(i: Idx).getReg();
2311 Register SrcReg = Operands[Idx];
2312
2313 // This combine may run after RegBankSelect, so we need to be aware of
2314 // register banks.
2315 const auto &DstCB = MRI.getRegClassOrRegBank(Reg: DstReg);
2316 if (!DstCB.isNull() && DstCB != MRI.getRegClassOrRegBank(Reg: SrcReg)) {
2317 SrcReg = Builder.buildCopy(Res: MRI.getType(Reg: SrcReg), Op: SrcReg).getReg(Idx: 0);
2318 MRI.setRegClassOrRegBank(Reg: SrcReg, RCOrRB: DstCB);
2319 }
2320
2321 if (CanReuseInputDirectly)
2322 replaceRegWith(MRI, FromReg: DstReg, ToReg: SrcReg);
2323 else
2324 Builder.buildCast(Dst: DstReg, Src: SrcReg);
2325 }
2326 MI.eraseFromParent();
2327}
2328
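// Fold an unmerge of a constant into one constant per destination. The defs
// are produced lowest bits first, e.g. (illustrative):
//   %a:_(s16), %b:_(s16) = G_UNMERGE_VALUES %c:_(s32)   ; %c = 0x12345678
// --->
//   %a:_(s16) = G_CONSTANT 0x5678
//   %b:_(s16) = G_CONSTANT 0x1234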
2329bool CombinerHelper::matchCombineUnmergeConstant(
2330 MachineInstr &MI, SmallVectorImpl<APInt> &Csts) const {
2331 unsigned SrcIdx = MI.getNumOperands() - 1;
2332 Register SrcReg = MI.getOperand(i: SrcIdx).getReg();
2333 MachineInstr *SrcInstr = MRI.getVRegDef(Reg: SrcReg);
2334 if (SrcInstr->getOpcode() != TargetOpcode::G_CONSTANT &&
2335 SrcInstr->getOpcode() != TargetOpcode::G_FCONSTANT)
2336 return false;
  // Break the big constant down into smaller ones.
2338 const MachineOperand &CstVal = SrcInstr->getOperand(i: 1);
2339 APInt Val = SrcInstr->getOpcode() == TargetOpcode::G_CONSTANT
2340 ? CstVal.getCImm()->getValue()
2341 : CstVal.getFPImm()->getValueAPF().bitcastToAPInt();
2342
2343 LLT Dst0Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
2344 unsigned ShiftAmt = Dst0Ty.getSizeInBits();
2345 // Unmerge a constant.
2346 for (unsigned Idx = 0; Idx != SrcIdx; ++Idx) {
2347 Csts.emplace_back(Args: Val.trunc(width: ShiftAmt));
2348 Val = Val.lshr(shiftAmt: ShiftAmt);
2349 }
2350
2351 return true;
2352}
2353
2354void CombinerHelper::applyCombineUnmergeConstant(
2355 MachineInstr &MI, SmallVectorImpl<APInt> &Csts) const {
2356 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2357 "Expected an unmerge");
2358 assert((MI.getNumOperands() - 1 == Csts.size()) &&
2359 "Not enough operands to replace all defs");
2360 unsigned NumElems = MI.getNumOperands() - 1;
2361 for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
2362 Register DstReg = MI.getOperand(i: Idx).getReg();
2363 Builder.buildConstant(Res: DstReg, Val: Csts[Idx]);
2364 }
2365
2366 MI.eraseFromParent();
2367}
2368
2369bool CombinerHelper::matchCombineUnmergeUndef(
2370 MachineInstr &MI,
2371 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
2372 unsigned SrcIdx = MI.getNumOperands() - 1;
2373 Register SrcReg = MI.getOperand(i: SrcIdx).getReg();
2374 MatchInfo = [&MI](MachineIRBuilder &B) {
2375 unsigned NumElems = MI.getNumOperands() - 1;
2376 for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
2377 Register DstReg = MI.getOperand(i: Idx).getReg();
2378 B.buildUndef(Res: DstReg);
2379 }
2380 };
2381 return isa<GImplicitDef>(Val: MRI.getVRegDef(Reg: SrcReg));
2382}
2383
2384bool CombinerHelper::matchCombineUnmergeWithDeadLanesToTrunc(
2385 MachineInstr &MI) const {
2386 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2387 "Expected an unmerge");
2388 if (MRI.getType(Reg: MI.getOperand(i: 0).getReg()).isVector() ||
2389 MRI.getType(Reg: MI.getOperand(i: MI.getNumDefs()).getReg()).isVector())
2390 return false;
2391 // Check that all the lanes are dead except the first one.
2392 for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) {
2393 if (!MRI.use_nodbg_empty(RegNo: MI.getOperand(i: Idx).getReg()))
2394 return false;
2395 }
2396 return true;
2397}
2398
2399void CombinerHelper::applyCombineUnmergeWithDeadLanesToTrunc(
2400 MachineInstr &MI) const {
2401 Register SrcReg = MI.getOperand(i: MI.getNumDefs()).getReg();
2402 Register Dst0Reg = MI.getOperand(i: 0).getReg();
2403 Builder.buildTrunc(Res: Dst0Reg, Op: SrcReg);
2404 MI.eraseFromParent();
2405}
2406
2407bool CombinerHelper::matchCombineUnmergeZExtToZExt(MachineInstr &MI) const {
2408 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2409 "Expected an unmerge");
2410 Register Dst0Reg = MI.getOperand(i: 0).getReg();
2411 LLT Dst0Ty = MRI.getType(Reg: Dst0Reg);
  // G_ZEXT on a vector applies to each lane, so it will affect all
  // destinations. Therefore we won't be able to simplify the unmerge to just
  // the first definition.
2415 if (Dst0Ty.isVector())
2416 return false;
2417 Register SrcReg = MI.getOperand(i: MI.getNumDefs()).getReg();
2418 LLT SrcTy = MRI.getType(Reg: SrcReg);
2419 if (SrcTy.isVector())
2420 return false;
2421
2422 Register ZExtSrcReg;
2423 if (!mi_match(R: SrcReg, MRI, P: m_GZExt(Src: m_Reg(R&: ZExtSrcReg))))
2424 return false;
2425
  // Finally, we can replace the first definition with a zext of the source if
  // the definition is big enough to hold all of ZExtSrc's bits.
2429 LLT ZExtSrcTy = MRI.getType(Reg: ZExtSrcReg);
2430 return ZExtSrcTy.getSizeInBits() <= Dst0Ty.getSizeInBits();
2431}
2432
2433void CombinerHelper::applyCombineUnmergeZExtToZExt(MachineInstr &MI) const {
2434 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2435 "Expected an unmerge");
2436
2437 Register Dst0Reg = MI.getOperand(i: 0).getReg();
2438
2439 MachineInstr *ZExtInstr =
2440 MRI.getVRegDef(Reg: MI.getOperand(i: MI.getNumDefs()).getReg());
2441 assert(ZExtInstr && ZExtInstr->getOpcode() == TargetOpcode::G_ZEXT &&
2442 "Expecting a G_ZEXT");
2443
2444 Register ZExtSrcReg = ZExtInstr->getOperand(i: 1).getReg();
2445 LLT Dst0Ty = MRI.getType(Reg: Dst0Reg);
2446 LLT ZExtSrcTy = MRI.getType(Reg: ZExtSrcReg);
2447
2448 if (Dst0Ty.getSizeInBits() > ZExtSrcTy.getSizeInBits()) {
2449 Builder.buildZExt(Res: Dst0Reg, Op: ZExtSrcReg);
2450 } else {
2451 assert(Dst0Ty.getSizeInBits() == ZExtSrcTy.getSizeInBits() &&
2452 "ZExt src doesn't fit in destination");
2453 replaceRegWith(MRI, FromReg: Dst0Reg, ToReg: ZExtSrcReg);
2454 }
2455
2456 Register ZeroReg;
2457 for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) {
2458 if (!ZeroReg)
2459 ZeroReg = Builder.buildConstant(Res: Dst0Ty, Val: 0).getReg(Idx: 0);
2460 replaceRegWith(MRI, FromReg: MI.getOperand(i: Idx).getReg(), ToReg: ZeroReg);
2461 }
2462 MI.eraseFromParent();
2463}
2464
2465bool CombinerHelper::matchCombineShiftToUnmerge(MachineInstr &MI,
2466 unsigned TargetShiftSize,
2467 unsigned &ShiftVal) const {
2468 assert((MI.getOpcode() == TargetOpcode::G_SHL ||
2469 MI.getOpcode() == TargetOpcode::G_LSHR ||
2470 MI.getOpcode() == TargetOpcode::G_ASHR) && "Expected a shift");
2471
2472 LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
  if (Ty.isVector()) // TODO: Handle vector types.
2474 return false;
2475
2476 // Don't narrow further than the requested size.
2477 unsigned Size = Ty.getSizeInBits();
2478 if (Size <= TargetShiftSize)
2479 return false;
2480
2481 auto MaybeImmVal =
2482 getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI);
2483 if (!MaybeImmVal)
2484 return false;
2485
2486 ShiftVal = MaybeImmVal->Value.getSExtValue();
2487 return ShiftVal >= Size / 2 && ShiftVal < Size;
2488}
2489
2490void CombinerHelper::applyCombineShiftToUnmerge(
2491 MachineInstr &MI, const unsigned &ShiftVal) const {
2492 Register DstReg = MI.getOperand(i: 0).getReg();
2493 Register SrcReg = MI.getOperand(i: 1).getReg();
2494 LLT Ty = MRI.getType(Reg: SrcReg);
2495 unsigned Size = Ty.getSizeInBits();
2496 unsigned HalfSize = Size / 2;
2497 assert(ShiftVal >= HalfSize);
2498
2499 LLT HalfTy = LLT::scalar(SizeInBits: HalfSize);
2500
2501 auto Unmerge = Builder.buildUnmerge(Res: HalfTy, Op: SrcReg);
2502 unsigned NarrowShiftAmt = ShiftVal - HalfSize;
2503
2504 if (MI.getOpcode() == TargetOpcode::G_LSHR) {
2505 Register Narrowed = Unmerge.getReg(Idx: 1);
2506
2507 // dst = G_LSHR s64:x, C for C >= 32
2508 // =>
2509 // lo, hi = G_UNMERGE_VALUES x
2510 // dst = G_MERGE_VALUES (G_LSHR hi, C - 32), 0
2511
2512 if (NarrowShiftAmt != 0) {
2513 Narrowed = Builder.buildLShr(Dst: HalfTy, Src0: Narrowed,
2514 Src1: Builder.buildConstant(Res: HalfTy, Val: NarrowShiftAmt)).getReg(Idx: 0);
2515 }
2516
2517 auto Zero = Builder.buildConstant(Res: HalfTy, Val: 0);
2518 Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Narrowed, Zero});
2519 } else if (MI.getOpcode() == TargetOpcode::G_SHL) {
2520 Register Narrowed = Unmerge.getReg(Idx: 0);
2521 // dst = G_SHL s64:x, C for C >= 32
2522 // =>
2523 // lo, hi = G_UNMERGE_VALUES x
2524 // dst = G_MERGE_VALUES 0, (G_SHL hi, C - 32)
2525 if (NarrowShiftAmt != 0) {
2526 Narrowed = Builder.buildShl(Dst: HalfTy, Src0: Narrowed,
2527 Src1: Builder.buildConstant(Res: HalfTy, Val: NarrowShiftAmt)).getReg(Idx: 0);
2528 }
2529
2530 auto Zero = Builder.buildConstant(Res: HalfTy, Val: 0);
2531 Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Zero, Narrowed});
2532 } else {
2533 assert(MI.getOpcode() == TargetOpcode::G_ASHR);
2534 auto Hi = Builder.buildAShr(
2535 Dst: HalfTy, Src0: Unmerge.getReg(Idx: 1),
2536 Src1: Builder.buildConstant(Res: HalfTy, Val: HalfSize - 1));
2537
2538 if (ShiftVal == HalfSize) {
2539 // (G_ASHR i64:x, 32) ->
2540 // G_MERGE_VALUES hi_32(x), (G_ASHR hi_32(x), 31)
2541 Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Unmerge.getReg(Idx: 1), Hi});
2542 } else if (ShiftVal == Size - 1) {
2543 // Don't need a second shift.
2544 // (G_ASHR i64:x, 63) ->
2545 // %narrowed = (G_ASHR hi_32(x), 31)
2546 // G_MERGE_VALUES %narrowed, %narrowed
2547 Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Hi, Hi});
2548 } else {
2549 auto Lo = Builder.buildAShr(
2550 Dst: HalfTy, Src0: Unmerge.getReg(Idx: 1),
2551 Src1: Builder.buildConstant(Res: HalfTy, Val: ShiftVal - HalfSize));
2552
      // (G_ASHR i64:x, C) -> for C >= 32
      // G_MERGE_VALUES (G_ASHR hi_32(x), C - 32), (G_ASHR hi_32(x), 31)
2555 Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Lo, Hi});
2556 }
2557 }
2558
2559 MI.eraseFromParent();
2560}
2561
2562bool CombinerHelper::tryCombineShiftToUnmerge(
2563 MachineInstr &MI, unsigned TargetShiftAmount) const {
2564 unsigned ShiftAmt;
2565 if (matchCombineShiftToUnmerge(MI, TargetShiftSize: TargetShiftAmount, ShiftVal&: ShiftAmt)) {
2566 applyCombineShiftToUnmerge(MI, ShiftVal: ShiftAmt);
2567 return true;
2568 }
2569
2570 return false;
2571}
2572
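// Fold a pointer round-tripped through an integer, e.g.:
//   %i:_(s64) = G_PTRTOINT %p:_(p0)
//   %q:_(p0) = G_INTTOPTR %i
// ---> replace all uses of %q with %p.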
2573bool CombinerHelper::matchCombineI2PToP2I(MachineInstr &MI,
2574 Register &Reg) const {
2575 assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR");
2576 Register DstReg = MI.getOperand(i: 0).getReg();
2577 LLT DstTy = MRI.getType(Reg: DstReg);
2578 Register SrcReg = MI.getOperand(i: 1).getReg();
2579 return mi_match(R: SrcReg, MRI,
2580 P: m_GPtrToInt(Src: m_all_of(preds: m_SpecificType(Ty: DstTy), preds: m_Reg(R&: Reg))));
2581}
2582
2583void CombinerHelper::applyCombineI2PToP2I(MachineInstr &MI,
2584 Register &Reg) const {
2585 assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR");
2586 Register DstReg = MI.getOperand(i: 0).getReg();
2587 Builder.buildCopy(Res: DstReg, Op: Reg);
2588 MI.eraseFromParent();
2589}
2590
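// Fold e.g.:
//   %p:_(p0) = G_INTTOPTR %x:_(s64)
//   %i:_(s32) = G_PTRTOINT %p
// --->
//   %i:_(s32) = G_TRUNC %x (or G_ZEXT, as chosen by buildZExtOrTrunc)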
2591void CombinerHelper::applyCombineP2IToI2P(MachineInstr &MI,
2592 Register &Reg) const {
2593 assert(MI.getOpcode() == TargetOpcode::G_PTRTOINT && "Expected a G_PTRTOINT");
2594 Register DstReg = MI.getOperand(i: 0).getReg();
2595 Builder.buildZExtOrTrunc(Res: DstReg, Op: Reg);
2596 MI.eraseFromParent();
2597}
2598
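// Reassociate pointer arithmetic performed in the integer domain, e.g.:
//   %i:_(s64) = G_PTRTOINT %p:_(p0)
//   %r:_(s64) = G_ADD %i, %off
// --->
//   %q:_(p0) = G_PTR_ADD %p, %off
//   %r:_(s64) = G_PTRTOINT %q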
2599bool CombinerHelper::matchCombineAddP2IToPtrAdd(
2600 MachineInstr &MI, std::pair<Register, bool> &PtrReg) const {
2601 assert(MI.getOpcode() == TargetOpcode::G_ADD);
2602 Register LHS = MI.getOperand(i: 1).getReg();
2603 Register RHS = MI.getOperand(i: 2).getReg();
2604 LLT IntTy = MRI.getType(Reg: LHS);
2605
2606 // G_PTR_ADD always has the pointer in the LHS, so we may need to commute the
2607 // instruction.
2608 PtrReg.second = false;
2609 for (Register SrcReg : {LHS, RHS}) {
2610 if (mi_match(R: SrcReg, MRI, P: m_GPtrToInt(Src: m_Reg(R&: PtrReg.first)))) {
2611 // Don't handle cases where the integer is implicitly converted to the
2612 // pointer width.
2613 LLT PtrTy = MRI.getType(Reg: PtrReg.first);
2614 if (PtrTy.getScalarSizeInBits() == IntTy.getScalarSizeInBits())
2615 return true;
2616 }
2617
2618 PtrReg.second = true;
2619 }
2620
2621 return false;
2622}
2623
2624void CombinerHelper::applyCombineAddP2IToPtrAdd(
2625 MachineInstr &MI, std::pair<Register, bool> &PtrReg) const {
2626 Register Dst = MI.getOperand(i: 0).getReg();
2627 Register LHS = MI.getOperand(i: 1).getReg();
2628 Register RHS = MI.getOperand(i: 2).getReg();
2629
2630 const bool DoCommute = PtrReg.second;
2631 if (DoCommute)
2632 std::swap(a&: LHS, b&: RHS);
2633 LHS = PtrReg.first;
2634
2635 LLT PtrTy = MRI.getType(Reg: LHS);
2636
2637 auto PtrAdd = Builder.buildPtrAdd(Res: PtrTy, Op0: LHS, Op1: RHS);
2638 Builder.buildPtrToInt(Dst, Src: PtrAdd);
2639 MI.eraseFromParent();
2640}
2641
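// Fold a G_PTR_ADD of two constants into one constant, e.g.:
//   %p:_(p0) = G_INTTOPTR C1
//   %q:_(p0) = G_PTR_ADD %p, C2
// --->
//   %q:_(p0) = constant zext(C1) + sext(C2)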
2642bool CombinerHelper::matchCombineConstPtrAddToI2P(MachineInstr &MI,
2643 APInt &NewCst) const {
2644 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
2645 Register LHS = PtrAdd.getBaseReg();
2646 Register RHS = PtrAdd.getOffsetReg();
2647 MachineRegisterInfo &MRI = Builder.getMF().getRegInfo();
2648
2649 if (auto RHSCst = getIConstantVRegVal(VReg: RHS, MRI)) {
2650 APInt Cst;
2651 if (mi_match(R: LHS, MRI, P: m_GIntToPtr(Src: m_ICst(Cst)))) {
2652 auto DstTy = MRI.getType(Reg: PtrAdd.getReg(Idx: 0));
2653 // G_INTTOPTR uses zero-extension
2654 NewCst = Cst.zextOrTrunc(width: DstTy.getSizeInBits());
2655 NewCst += RHSCst->sextOrTrunc(width: DstTy.getSizeInBits());
2656 return true;
2657 }
2658 }
2659
2660 return false;
2661}
2662
2663void CombinerHelper::applyCombineConstPtrAddToI2P(MachineInstr &MI,
2664 APInt &NewCst) const {
2665 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
2666 Register Dst = PtrAdd.getReg(Idx: 0);
2667
2668 Builder.buildConstant(Res: Dst, Val: NewCst);
2669 PtrAdd.eraseFromParent();
2670}
2671
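// Fold %y:_(sN) = G_ANYEXT (G_TRUNC %x:_(sN)) back to %x when the source and
// destination types match and the replacement is otherwise allowed.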
2672bool CombinerHelper::matchCombineAnyExtTrunc(MachineInstr &MI,
2673 Register &Reg) const {
2674 assert(MI.getOpcode() == TargetOpcode::G_ANYEXT && "Expected a G_ANYEXT");
2675 Register DstReg = MI.getOperand(i: 0).getReg();
2676 Register SrcReg = MI.getOperand(i: 1).getReg();
2677 Register OriginalSrcReg = getSrcRegIgnoringCopies(Reg: SrcReg, MRI);
2678 if (OriginalSrcReg.isValid())
2679 SrcReg = OriginalSrcReg;
2680 LLT DstTy = MRI.getType(Reg: DstReg);
2681 return mi_match(R: SrcReg, MRI,
2682 P: m_GTrunc(Src: m_all_of(preds: m_Reg(R&: Reg), preds: m_SpecificType(Ty: DstTy)))) &&
2683 canReplaceReg(DstReg, SrcReg: Reg, MRI);
2684}
2685
2686bool CombinerHelper::matchCombineZextTrunc(MachineInstr &MI,
2687 Register &Reg) const {
2688 assert(MI.getOpcode() == TargetOpcode::G_ZEXT && "Expected a G_ZEXT");
2689 Register DstReg = MI.getOperand(i: 0).getReg();
2690 Register SrcReg = MI.getOperand(i: 1).getReg();
2691 LLT DstTy = MRI.getType(Reg: DstReg);
2692 if (mi_match(R: SrcReg, MRI,
2693 P: m_GTrunc(Src: m_all_of(preds: m_Reg(R&: Reg), preds: m_SpecificType(Ty: DstTy)))) &&
2694 canReplaceReg(DstReg, SrcReg: Reg, MRI)) {
2695 unsigned DstSize = DstTy.getScalarSizeInBits();
2696 unsigned SrcSize = MRI.getType(Reg: SrcReg).getScalarSizeInBits();
2697 return VT->getKnownBits(R: Reg).countMinLeadingZeros() >= DstSize - SrcSize;
2698 }
2699 return false;
2700}
2701
2702static LLT getMidVTForTruncRightShiftCombine(LLT ShiftTy, LLT TruncTy) {
2703 const unsigned ShiftSize = ShiftTy.getScalarSizeInBits();
2704 const unsigned TruncSize = TruncTy.getScalarSizeInBits();
2705
2706 // ShiftTy > 32 > TruncTy -> 32
2707 if (ShiftSize > 32 && TruncSize < 32)
2708 return ShiftTy.changeElementSize(NewEltSize: 32);
2709
2710 // TODO: We could also reduce to 16 bits, but that's more target-dependent.
2711 // Some targets like it, some don't, some only like it under certain
2712 // conditions/processor versions, etc.
2713 // A TL hook might be needed for this.
2714
2715 // Don't combine
2716 return ShiftTy;
2717}
2718
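// Narrow a shift that feeds a G_TRUNC, e.g. (illustrative types):
//   %s:_(s64) = G_SHL %x, %amt
//   %t:_(s32) = G_TRUNC %s
// --->
//   %xt:_(s32) = G_TRUNC %x
//   %t:_(s32) = G_SHL %xt, %amt
// provided the known range of the shift amount makes the narrower shift safe.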
2719bool CombinerHelper::matchCombineTruncOfShift(
2720 MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) const {
2721 assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC");
2722 Register DstReg = MI.getOperand(i: 0).getReg();
2723 Register SrcReg = MI.getOperand(i: 1).getReg();
2724
2725 if (!MRI.hasOneNonDBGUse(RegNo: SrcReg))
2726 return false;
2727
2728 LLT SrcTy = MRI.getType(Reg: SrcReg);
2729 LLT DstTy = MRI.getType(Reg: DstReg);
2730
2731 MachineInstr *SrcMI = getDefIgnoringCopies(Reg: SrcReg, MRI);
2732 const auto &TL = getTargetLowering();
2733
2734 LLT NewShiftTy;
2735 switch (SrcMI->getOpcode()) {
2736 default:
2737 return false;
2738 case TargetOpcode::G_SHL: {
2739 NewShiftTy = DstTy;
2740
2741 // Make sure new shift amount is legal.
2742 KnownBits Known = VT->getKnownBits(R: SrcMI->getOperand(i: 2).getReg());
2743 if (Known.getMaxValue().uge(RHS: NewShiftTy.getScalarSizeInBits()))
2744 return false;
2745 break;
2746 }
2747 case TargetOpcode::G_LSHR:
2748 case TargetOpcode::G_ASHR: {
2749 // For right shifts, we conservatively do not do the transform if the TRUNC
2750 // has any STORE users. The reason is that if we change the type of the
2751 // shift, we may break the truncstore combine.
2752 //
2753 // TODO: Fix truncstore combine to handle (trunc(lshr (trunc x), k)).
2754 for (auto &User : MRI.use_instructions(Reg: DstReg))
2755 if (User.getOpcode() == TargetOpcode::G_STORE)
2756 return false;
2757
2758 NewShiftTy = getMidVTForTruncRightShiftCombine(ShiftTy: SrcTy, TruncTy: DstTy);
2759 if (NewShiftTy == SrcTy)
2760 return false;
2761
2762 // Make sure we won't lose information by truncating the high bits.
2763 KnownBits Known = VT->getKnownBits(R: SrcMI->getOperand(i: 2).getReg());
2764 if (Known.getMaxValue().ugt(RHS: NewShiftTy.getScalarSizeInBits() -
2765 DstTy.getScalarSizeInBits()))
2766 return false;
2767 break;
2768 }
2769 }
2770
2771 if (!isLegalOrBeforeLegalizer(
2772 Query: {SrcMI->getOpcode(),
2773 {NewShiftTy, TL.getPreferredShiftAmountTy(ShiftValueTy: NewShiftTy)}}))
2774 return false;
2775
2776 MatchInfo = std::make_pair(x&: SrcMI, y&: NewShiftTy);
2777 return true;
2778}
2779
2780void CombinerHelper::applyCombineTruncOfShift(
2781 MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) const {
2782 MachineInstr *ShiftMI = MatchInfo.first;
2783 LLT NewShiftTy = MatchInfo.second;
2784
2785 Register Dst = MI.getOperand(i: 0).getReg();
2786 LLT DstTy = MRI.getType(Reg: Dst);
2787
2788 Register ShiftAmt = ShiftMI->getOperand(i: 2).getReg();
2789 Register ShiftSrc = ShiftMI->getOperand(i: 1).getReg();
2790 ShiftSrc = Builder.buildTrunc(Res: NewShiftTy, Op: ShiftSrc).getReg(Idx: 0);
2791
2792 Register NewShift =
2793 Builder
2794 .buildInstr(Opc: ShiftMI->getOpcode(), DstOps: {NewShiftTy}, SrcOps: {ShiftSrc, ShiftAmt})
2795 .getReg(Idx: 0);
2796
2797 if (NewShiftTy == DstTy)
2798 replaceRegWith(MRI, FromReg: Dst, ToReg: NewShift);
2799 else
2800 Builder.buildTrunc(Res: Dst, Op: NewShift);
2801
2802 eraseInst(MI);
2803}
2804
2805bool CombinerHelper::matchAnyExplicitUseIsUndef(MachineInstr &MI) const {
2806 return any_of(Range: MI.explicit_uses(), P: [this](const MachineOperand &MO) {
2807 return MO.isReg() &&
2808 getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI);
2809 });
2810}
2811
2812bool CombinerHelper::matchAllExplicitUsesAreUndef(MachineInstr &MI) const {
2813 return all_of(Range: MI.explicit_uses(), P: [this](const MachineOperand &MO) {
2814 return !MO.isReg() ||
2815 getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI);
2816 });
2817}
2818
2819bool CombinerHelper::matchUndefShuffleVectorMask(MachineInstr &MI) const {
2820 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
2821 ArrayRef<int> Mask = MI.getOperand(i: 3).getShuffleMask();
2822 return all_of(Range&: Mask, P: [](int Elt) { return Elt < 0; });
2823}
2824
2825bool CombinerHelper::matchUndefStore(MachineInstr &MI) const {
2826 assert(MI.getOpcode() == TargetOpcode::G_STORE);
2827 return getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MI.getOperand(i: 0).getReg(),
2828 MRI);
2829}
2830
2831bool CombinerHelper::matchUndefSelectCmp(MachineInstr &MI) const {
2832 assert(MI.getOpcode() == TargetOpcode::G_SELECT);
2833 return getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MI.getOperand(i: 1).getReg(),
2834 MRI);
2835}
2836
2837bool CombinerHelper::matchInsertExtractVecEltOutOfBounds(
2838 MachineInstr &MI) const {
2839 assert((MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT ||
2840 MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) &&
2841 "Expected an insert/extract element op");
2842 LLT VecTy = MRI.getType(Reg: MI.getOperand(i: 1).getReg());
2843 if (VecTy.isScalableVector())
2844 return false;
2845
2846 unsigned IdxIdx =
2847 MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT ? 2 : 3;
2848 auto Idx = getIConstantVRegVal(VReg: MI.getOperand(i: IdxIdx).getReg(), MRI);
2849 if (!Idx)
2850 return false;
2851 return Idx->getZExtValue() >= VecTy.getNumElements();
2852}
2853
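// Match a G_SELECT whose condition is a known constant (or constant splat),
// e.g. G_SELECT %true_cond, %a, %b ---> %a; OpIdx selects the surviving
// operand.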
2854bool CombinerHelper::matchConstantSelectCmp(MachineInstr &MI,
2855 unsigned &OpIdx) const {
2856 GSelect &SelMI = cast<GSelect>(Val&: MI);
2857 auto Cst =
2858 isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: SelMI.getCondReg()), MRI);
2859 if (!Cst)
2860 return false;
2861 OpIdx = Cst->isZero() ? 3 : 2;
2862 return true;
2863}
2864
2865void CombinerHelper::eraseInst(MachineInstr &MI) const { MI.eraseFromParent(); }
2866
2867bool CombinerHelper::matchEqualDefs(const MachineOperand &MOP1,
2868 const MachineOperand &MOP2) const {
2869 if (!MOP1.isReg() || !MOP2.isReg())
2870 return false;
2871 auto InstAndDef1 = getDefSrcRegIgnoringCopies(Reg: MOP1.getReg(), MRI);
2872 if (!InstAndDef1)
2873 return false;
2874 auto InstAndDef2 = getDefSrcRegIgnoringCopies(Reg: MOP2.getReg(), MRI);
2875 if (!InstAndDef2)
2876 return false;
2877 MachineInstr *I1 = InstAndDef1->MI;
2878 MachineInstr *I2 = InstAndDef2->MI;
2879
2880 // Handle a case like this:
2881 //
2882 // %0:_(s64), %1:_(s64) = G_UNMERGE_VALUES %2:_(<2 x s64>)
2883 //
2884 // Even though %0 and %1 are produced by the same instruction they are not
2885 // the same values.
2886 if (I1 == I2)
2887 return MOP1.getReg() == MOP2.getReg();
2888
2889 // If we have an instruction which loads or stores, we can't guarantee that
2890 // it is identical.
2891 //
2892 // For example, we may have
2893 //
2894 // %x1 = G_LOAD %addr (load N from @somewhere)
2895 // ...
2896 // call @foo
2897 // ...
2898 // %x2 = G_LOAD %addr (load N from @somewhere)
2899 // ...
2900 // %or = G_OR %x1, %x2
2901 //
2902 // It's possible that @foo will modify whatever lives at the address we're
2903 // loading from. To be safe, let's just assume that all loads and stores
2904 // are different (unless we have something which is guaranteed to not
2905 // change.)
2906 if (I1->mayLoadOrStore() && !I1->isDereferenceableInvariantLoad())
2907 return false;
2908
2909 // If both instructions are loads or stores, they are equal only if both
2910 // are dereferenceable invariant loads with the same number of bits.
2911 if (I1->mayLoadOrStore() && I2->mayLoadOrStore()) {
2912 GLoadStore *LS1 = dyn_cast<GLoadStore>(Val: I1);
2913 GLoadStore *LS2 = dyn_cast<GLoadStore>(Val: I2);
2914 if (!LS1 || !LS2)
2915 return false;
2916
2917 if (!I2->isDereferenceableInvariantLoad() ||
2918 (LS1->getMemSizeInBits() != LS2->getMemSizeInBits()))
2919 return false;
2920 }
2921
2922 // Check for physical registers on the instructions first to avoid cases
2923 // like this:
2924 //
2925 // %a = COPY $physreg
2926 // ...
2927 // SOMETHING implicit-def $physreg
2928 // ...
2929 // %b = COPY $physreg
2930 //
2931 // These copies are not equivalent.
2932 if (any_of(Range: I1->uses(), P: [](const MachineOperand &MO) {
2933 return MO.isReg() && MO.getReg().isPhysical();
2934 })) {
2935 // Check if we have a case like this:
2936 //
2937 // %a = COPY $physreg
2938 // %b = COPY %a
2939 //
2940 // In this case, I1 and I2 will both be equal to %a = COPY $physreg.
2941 // From that, we know that they must have the same value, since they must
2942 // have come from the same COPY.
2943 return I1->isIdenticalTo(Other: *I2);
2944 }
2945
2946 // We don't have any physical registers, so we don't necessarily need the
2947 // same vreg defs.
2948 //
2949 // On the off-chance that there's some target instruction feeding into the
2950 // instruction, let's use produceSameValue instead of isIdenticalTo.
2951 if (Builder.getTII().produceSameValue(MI0: *I1, MI1: *I2, MRI: &MRI)) {
2952 // Handle instructions with multiple defs that produce the same values. The
2953 // values are the same for operands with the same index.
2954 // %0:_(s8), %1:_(s8), %2:_(s8), %3:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>)
2955 // %5:_(s8), %6:_(s8), %7:_(s8), %8:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>)
2956 // I1 and I2 are different instructions that produce the same values: %1 and
2957 // %6 are the same value, but %1 and %7 are not.
2958 return I1->findRegisterDefOperandIdx(Reg: InstAndDef1->Reg, /*TRI=*/nullptr) ==
2959 I2->findRegisterDefOperandIdx(Reg: InstAndDef2->Reg, /*TRI=*/nullptr);
2960 }
2961 return false;
2962}
2963
2964bool CombinerHelper::matchConstantOp(const MachineOperand &MOP,
2965 int64_t C) const {
2966 if (!MOP.isReg())
2967 return false;
2968 auto *MI = MRI.getVRegDef(Reg: MOP.getReg());
2969 auto MaybeCst = isConstantOrConstantSplatVector(MI&: *MI, MRI);
2970 return MaybeCst && MaybeCst->getBitWidth() <= 64 &&
2971 MaybeCst->getSExtValue() == C;
2972}
2973
2974bool CombinerHelper::matchConstantFPOp(const MachineOperand &MOP,
2975 double C) const {
2976 if (!MOP.isReg())
2977 return false;
2978 std::optional<FPValueAndVReg> MaybeCst;
2979 if (!mi_match(R: MOP.getReg(), MRI, P: m_GFCstOrSplat(FPValReg&: MaybeCst)))
2980 return false;
2981
2982 return MaybeCst->Value.isExactlyValue(V: C);
2983}
2984
2985void CombinerHelper::replaceSingleDefInstWithOperand(MachineInstr &MI,
2986 unsigned OpIdx) const {
2987 assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?");
2988 Register OldReg = MI.getOperand(i: 0).getReg();
2989 Register Replacement = MI.getOperand(i: OpIdx).getReg();
2990 assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?");
2991 replaceRegWith(MRI, FromReg: OldReg, ToReg: Replacement);
2992 MI.eraseFromParent();
2993}
2994
2995void CombinerHelper::replaceSingleDefInstWithReg(MachineInstr &MI,
2996 Register Replacement) const {
2997 assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?");
2998 Register OldReg = MI.getOperand(i: 0).getReg();
2999 assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?");
3000 replaceRegWith(MRI, FromReg: OldReg, ToReg: Replacement);
3001 MI.eraseFromParent();
3002}
3003
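// Check whether the constant at operand \p ConstIdx (typically a shift
// amount) is at least as wide as the destination's bit width. Illustrative
// sketch (s32, value names assumed):
//   %amt:_(s32) = G_CONSTANT i32 35
//   %d:_(s32) = G_SHL %x, %amt(s32)
// matches, since 35 >= 32; the caller decides how to rewrite such shifts.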
3004bool CombinerHelper::matchConstantLargerBitWidth(MachineInstr &MI,
3005 unsigned ConstIdx) const {
3006 Register ConstReg = MI.getOperand(i: ConstIdx).getReg();
3007 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
3008
3009 // Get the shift amount
3010 auto VRegAndVal = getIConstantVRegValWithLookThrough(VReg: ConstReg, MRI);
3011 if (!VRegAndVal)
3012 return false;
3013
3014 // Return true if the shift amount is >= the bit width.
3015 return (VRegAndVal->Value.uge(RHS: DstTy.getSizeInBits()));
3016}
3017
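// Rewrite a funnel shift whose constant amount is out of range so that the
// amount is reduced modulo the bit width. Illustrative sketch (s32, value
// names assumed):
//   %amt:_(s32) = G_CONSTANT i32 40
//   %d:_(s32) = G_FSHL %a, %b, %amt(s32)
// becomes G_FSHL %a, %b with a new constant of 8, since 40 % 32 == 8.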
3018void CombinerHelper::applyFunnelShiftConstantModulo(MachineInstr &MI) const {
3019 assert((MI.getOpcode() == TargetOpcode::G_FSHL ||
3020 MI.getOpcode() == TargetOpcode::G_FSHR) &&
3021 "This is not a funnel shift operation");
3022
3023 Register ConstReg = MI.getOperand(i: 3).getReg();
3024 LLT ConstTy = MRI.getType(Reg: ConstReg);
3025 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
3026
3027 auto VRegAndVal = getIConstantVRegValWithLookThrough(VReg: ConstReg, MRI);
3028 assert((VRegAndVal) && "Value is not a constant");
3029
3030 // Calculate the new Shift Amount = Old Shift Amount % BitWidth
3031 APInt NewConst = VRegAndVal->Value.urem(
3032 RHS: APInt(ConstTy.getSizeInBits(), DstTy.getScalarSizeInBits()));
3033
3034 auto NewConstInstr = Builder.buildConstant(Res: ConstTy, Val: NewConst.getZExtValue());
3035 Builder.buildInstr(
3036 Opc: MI.getOpcode(), DstOps: {MI.getOperand(i: 0)},
3037 SrcOps: {MI.getOperand(i: 1), MI.getOperand(i: 2), NewConstInstr.getReg(Idx: 0)});
3038
3039 MI.eraseFromParent();
3040}
3041
3042bool CombinerHelper::matchSelectSameVal(MachineInstr &MI) const {
3043 assert(MI.getOpcode() == TargetOpcode::G_SELECT);
3044 // Match (cond ? x : x)
3045 return matchEqualDefs(MOP1: MI.getOperand(i: 2), MOP2: MI.getOperand(i: 3)) &&
3046 canReplaceReg(DstReg: MI.getOperand(i: 0).getReg(), SrcReg: MI.getOperand(i: 2).getReg(),
3047 MRI);
3048}
3049
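// Match a binary operator whose two sources are the same value, e.g.
// (illustrative) %d:_(s32) = G_OR %x, %x, so %d can be replaced by %x.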
3050bool CombinerHelper::matchBinOpSameVal(MachineInstr &MI) const {
3051 return matchEqualDefs(MOP1: MI.getOperand(i: 1), MOP2: MI.getOperand(i: 2)) &&
3052 canReplaceReg(DstReg: MI.getOperand(i: 0).getReg(), SrcReg: MI.getOperand(i: 1).getReg(),
3053 MRI);
3054}
3055
3056bool CombinerHelper::matchOperandIsUndef(MachineInstr &MI,
3057 unsigned OpIdx) const {
3058 MachineOperand &MO = MI.getOperand(i: OpIdx);
3059 return MO.isReg() &&
3060 getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI);
3061}
3062
3063bool CombinerHelper::matchOperandIsKnownToBeAPowerOfTwo(MachineInstr &MI,
3064 unsigned OpIdx) const {
3065 MachineOperand &MO = MI.getOperand(i: OpIdx);
3066 return isKnownToBeAPowerOfTwo(Val: MO.getReg(), MRI, ValueTracking: VT);
3067}
3068
3069void CombinerHelper::replaceInstWithFConstant(MachineInstr &MI,
3070 double C) const {
3071 assert(MI.getNumDefs() == 1 && "Expected only one def?");
3072 Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: C);
3073 MI.eraseFromParent();
3074}
3075
3076void CombinerHelper::replaceInstWithConstant(MachineInstr &MI,
3077 int64_t C) const {
3078 assert(MI.getNumDefs() == 1 && "Expected only one def?");
3079 Builder.buildConstant(Res: MI.getOperand(i: 0), Val: C);
3080 MI.eraseFromParent();
3081}
3082
3083void CombinerHelper::replaceInstWithConstant(MachineInstr &MI, APInt C) const {
3084 assert(MI.getNumDefs() == 1 && "Expected only one def?");
3085 Builder.buildConstant(Res: MI.getOperand(i: 0), Val: C);
3086 MI.eraseFromParent();
3087}
3088
3089void CombinerHelper::replaceInstWithFConstant(MachineInstr &MI,
3090 ConstantFP *CFP) const {
3091 assert(MI.getNumDefs() == 1 && "Expected only one def?");
3092 Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: CFP->getValueAPF());
3093 MI.eraseFromParent();
3094}
3095
3096void CombinerHelper::replaceInstWithUndef(MachineInstr &MI) const {
3097 assert(MI.getNumDefs() == 1 && "Expected only one def?");
3098 Builder.buildUndef(Res: MI.getOperand(i: 0));
3099 MI.eraseFromParent();
3100}
3101
3102bool CombinerHelper::matchSimplifyAddToSub(
3103 MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) const {
3104 Register LHS = MI.getOperand(i: 1).getReg();
3105 Register RHS = MI.getOperand(i: 2).getReg();
3106 Register &NewLHS = std::get<0>(t&: MatchInfo);
3107 Register &NewRHS = std::get<1>(t&: MatchInfo);
3108
3109 // Helper lambda to check for opportunities for
3110 // ((0-A) + B) -> B - A
3111 // (A + (0-B)) -> A - B
3112 auto CheckFold = [&](Register &MaybeSub, Register &MaybeNewLHS) {
3113 if (!mi_match(R: MaybeSub, MRI, P: m_Neg(Src: m_Reg(R&: NewRHS))))
3114 return false;
3115 NewLHS = MaybeNewLHS;
3116 return true;
3117 };
3118
3119 return CheckFold(LHS, RHS) || CheckFold(RHS, LHS);
3120}
3121
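// Try to fold a chain of G_INSERT_VECTOR_ELTs into a single G_BUILD_VECTOR.
// Illustrative sketch (value names assumed, index constants shown inline):
//   %v0:_(<2 x s32>) = G_IMPLICIT_DEF
//   %v1:_(<2 x s32>) = G_INSERT_VECTOR_ELT %v0, %a(s32), 0
//   %v2:_(<2 x s32>) = G_INSERT_VECTOR_ELT %v1, %b(s32), 1
// can become %v2:_(<2 x s32>) = G_BUILD_VECTOR %a(s32), %b(s32).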
3122bool CombinerHelper::matchCombineInsertVecElts(
3123 MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) const {
3124 assert(MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT &&
3125 "Invalid opcode");
3126 Register DstReg = MI.getOperand(i: 0).getReg();
3127 LLT DstTy = MRI.getType(Reg: DstReg);
3128 assert(DstTy.isVector() && "Invalid G_INSERT_VECTOR_ELT?");
3129
3130 if (DstTy.isScalableVector())
3131 return false;
3132
3133 unsigned NumElts = DstTy.getNumElements();
3134 // If this MI is part of a sequence of insert_vec_elts, then
3135 // don't do the combine in the middle of the sequence.
3136 if (MRI.hasOneUse(RegNo: DstReg) && MRI.use_instr_begin(RegNo: DstReg)->getOpcode() ==
3137 TargetOpcode::G_INSERT_VECTOR_ELT)
3138 return false;
3139 MachineInstr *CurrInst = &MI;
3140 MachineInstr *TmpInst;
3141 int64_t IntImm;
3142 Register TmpReg;
3143 MatchInfo.resize(N: NumElts);
3144 while (mi_match(
3145 R: CurrInst->getOperand(i: 0).getReg(), MRI,
3146 P: m_GInsertVecElt(Src0: m_MInstr(MI&: TmpInst), Src1: m_Reg(R&: TmpReg), Src2: m_ICst(Cst&: IntImm)))) {
3147 if (IntImm >= NumElts || IntImm < 0)
3148 return false;
3149 if (!MatchInfo[IntImm])
3150 MatchInfo[IntImm] = TmpReg;
3151 CurrInst = TmpInst;
3152 }
3153 // Variable index.
3154 if (CurrInst->getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT)
3155 return false;
3156 if (TmpInst->getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
3157 for (unsigned I = 1; I < TmpInst->getNumOperands(); ++I) {
3158 if (!MatchInfo[I - 1].isValid())
3159 MatchInfo[I - 1] = TmpInst->getOperand(i: I).getReg();
3160 }
3161 return true;
3162 }
3163 // If we didn't end in a G_IMPLICIT_DEF and the source is not fully
3164 // overwritten, bail out.
3165 return TmpInst->getOpcode() == TargetOpcode::G_IMPLICIT_DEF ||
3166 all_of(Range&: MatchInfo, P: [](Register Reg) { return !!Reg; });
3167}
3168
3169void CombinerHelper::applyCombineInsertVecElts(
3170 MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) const {
3171 Register UndefReg;
3172 auto GetUndef = [&]() {
3173 if (UndefReg)
3174 return UndefReg;
3175 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
3176 UndefReg = Builder.buildUndef(Res: DstTy.getScalarType()).getReg(Idx: 0);
3177 return UndefReg;
3178 };
3179 for (Register &Reg : MatchInfo) {
3180 if (!Reg)
3181 Reg = GetUndef();
3182 }
3183 Builder.buildBuildVector(Res: MI.getOperand(i: 0).getReg(), Ops: MatchInfo);
3184 MI.eraseFromParent();
3185}
3186
3187void CombinerHelper::applySimplifyAddToSub(
3188 MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) const {
3189 Register SubLHS, SubRHS;
3190 std::tie(args&: SubLHS, args&: SubRHS) = MatchInfo;
3191 Builder.buildSub(Dst: MI.getOperand(i: 0).getReg(), Src0: SubLHS, Src1: SubRHS);
3192 MI.eraseFromParent();
3193}
3194
3195bool CombinerHelper::matchHoistLogicOpWithSameOpcodeHands(
3196 MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) const {
3197 // Matches: logic (hand x, ...), (hand y, ...) -> hand (logic x, y), ...
3198 //
3199 // Creates the new hand + logic instructions (but does not insert them).
3200 //
3201 // On success, MatchInfo is populated with the new instructions. These are
3202 // inserted in applyHoistLogicOpWithSameOpcodeHands.
3203 unsigned LogicOpcode = MI.getOpcode();
3204 assert(LogicOpcode == TargetOpcode::G_AND ||
3205 LogicOpcode == TargetOpcode::G_OR ||
3206 LogicOpcode == TargetOpcode::G_XOR);
3207 MachineIRBuilder MIB(MI);
3208 Register Dst = MI.getOperand(i: 0).getReg();
3209 Register LHSReg = MI.getOperand(i: 1).getReg();
3210 Register RHSReg = MI.getOperand(i: 2).getReg();
3211
3212 // Don't recompute anything.
3213 if (!MRI.hasOneNonDBGUse(RegNo: LHSReg) || !MRI.hasOneNonDBGUse(RegNo: RHSReg))
3214 return false;
3215
3216 // Make sure we have (hand x, ...), (hand y, ...)
3217 MachineInstr *LeftHandInst = getDefIgnoringCopies(Reg: LHSReg, MRI);
3218 MachineInstr *RightHandInst = getDefIgnoringCopies(Reg: RHSReg, MRI);
3219 if (!LeftHandInst || !RightHandInst)
3220 return false;
3221 unsigned HandOpcode = LeftHandInst->getOpcode();
3222 if (HandOpcode != RightHandInst->getOpcode())
3223 return false;
3224 if (LeftHandInst->getNumOperands() < 2 ||
3225 !LeftHandInst->getOperand(i: 1).isReg() ||
3226 RightHandInst->getNumOperands() < 2 ||
3227 !RightHandInst->getOperand(i: 1).isReg())
3228 return false;
3229
3230 // Make sure the types match up, and if we're doing this post-legalization,
3231 // we end up with legal types.
3232 Register X = LeftHandInst->getOperand(i: 1).getReg();
3233 Register Y = RightHandInst->getOperand(i: 1).getReg();
3234 LLT XTy = MRI.getType(Reg: X);
3235 LLT YTy = MRI.getType(Reg: Y);
3236 if (!XTy.isValid() || XTy != YTy)
3237 return false;
3238
3239 // Optional extra source register.
3240 Register ExtraHandOpSrcReg;
3241 switch (HandOpcode) {
3242 default:
3243 return false;
3244 case TargetOpcode::G_ANYEXT:
3245 case TargetOpcode::G_SEXT:
3246 case TargetOpcode::G_ZEXT: {
3247 // Match: logic (ext X), (ext Y) --> ext (logic X, Y)
3248 break;
3249 }
3250 case TargetOpcode::G_TRUNC: {
3251 // Match: logic (trunc X), (trunc Y) -> trunc (logic X, Y)
3252 const MachineFunction *MF = MI.getMF();
3253 LLVMContext &Ctx = MF->getFunction().getContext();
3254
3255 LLT DstTy = MRI.getType(Reg: Dst);
3256 const TargetLowering &TLI = getTargetLowering();
3257
3258 // Be extra careful sinking truncate. If it's free, there's no benefit in
3259 // widening a binop.
3260 if (TLI.isZExtFree(FromTy: DstTy, ToTy: XTy, Ctx) && TLI.isTruncateFree(FromTy: XTy, ToTy: DstTy, Ctx))
3261 return false;
3262 break;
3263 }
3264 case TargetOpcode::G_AND:
3265 case TargetOpcode::G_ASHR:
3266 case TargetOpcode::G_LSHR:
3267 case TargetOpcode::G_SHL: {
3268 // Match: logic (binop x, z), (binop y, z) -> binop (logic x, y), z
3269 MachineOperand &ZOp = LeftHandInst->getOperand(i: 2);
3270 if (!matchEqualDefs(MOP1: ZOp, MOP2: RightHandInst->getOperand(i: 2)))
3271 return false;
3272 ExtraHandOpSrcReg = ZOp.getReg();
3273 break;
3274 }
3275 }
3276
3277 if (!isLegalOrBeforeLegalizer(Query: {LogicOpcode, {XTy, YTy}}))
3278 return false;
3279
3280 // Record the steps to build the new instructions.
3281 //
3282 // Steps to build (logic x, y)
3283 auto NewLogicDst = MRI.createGenericVirtualRegister(Ty: XTy);
3284 OperandBuildSteps LogicBuildSteps = {
3285 [=](MachineInstrBuilder &MIB) { MIB.addDef(RegNo: NewLogicDst); },
3286 [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: X); },
3287 [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: Y); }};
3288 InstructionBuildSteps LogicSteps(LogicOpcode, LogicBuildSteps);
3289
3290 // Steps to build hand (logic x, y), ...z
3291 OperandBuildSteps HandBuildSteps = {
3292 [=](MachineInstrBuilder &MIB) { MIB.addDef(RegNo: Dst); },
3293 [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: NewLogicDst); }};
3294 if (ExtraHandOpSrcReg.isValid())
3295 HandBuildSteps.push_back(
3296 Elt: [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: ExtraHandOpSrcReg); });
3297 InstructionBuildSteps HandSteps(HandOpcode, HandBuildSteps);
3298
3299 MatchInfo = InstructionStepsMatchInfo({LogicSteps, HandSteps});
3300 return true;
3301}
3302
3303void CombinerHelper::applyBuildInstructionSteps(
3304 MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) const {
3305 assert(MatchInfo.InstrsToBuild.size() &&
3306 "Expected at least one instr to build?");
3307 for (auto &InstrToBuild : MatchInfo.InstrsToBuild) {
3308 assert(InstrToBuild.Opcode && "Expected a valid opcode?");
3309 assert(InstrToBuild.OperandFns.size() && "Expected at least one operand?");
3310 MachineInstrBuilder Instr = Builder.buildInstr(Opcode: InstrToBuild.Opcode);
3311 for (auto &OperandFn : InstrToBuild.OperandFns)
3312 OperandFn(Instr);
3313 }
3314 MI.eraseFromParent();
3315}
3316
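// Fold (ashr (shl x, C), C) into a sign extension in register. Illustrative
// sketch (s32, shift constants shown inline):
//   %s:_(s32) = G_SHL %x, 24
//   %d:_(s32) = G_ASHR %s, 24
// becomes %d:_(s32) = G_SEXT_INREG %x, 8.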
3317bool CombinerHelper::matchAshrShlToSextInreg(
3318 MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) const {
3319 assert(MI.getOpcode() == TargetOpcode::G_ASHR);
3320 int64_t ShlCst, AshrCst;
3321 Register Src;
3322 if (!mi_match(R: MI.getOperand(i: 0).getReg(), MRI,
3323 P: m_GAShr(L: m_GShl(L: m_Reg(R&: Src), R: m_ICstOrSplat(Cst&: ShlCst)),
3324 R: m_ICstOrSplat(Cst&: AshrCst))))
3325 return false;
3326 if (ShlCst != AshrCst)
3327 return false;
3328 if (!isLegalOrBeforeLegalizer(
3329 Query: {TargetOpcode::G_SEXT_INREG, {MRI.getType(Reg: Src)}}))
3330 return false;
3331 MatchInfo = std::make_tuple(args&: Src, args&: ShlCst);
3332 return true;
3333}
3334
3335void CombinerHelper::applyAshShlToSextInreg(
3336 MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) const {
3337 assert(MI.getOpcode() == TargetOpcode::G_ASHR);
3338 Register Src;
3339 int64_t ShiftAmt;
3340 std::tie(args&: Src, args&: ShiftAmt) = MatchInfo;
3341 unsigned Size = MRI.getType(Reg: Src).getScalarSizeInBits();
3342 Builder.buildSExtInReg(Res: MI.getOperand(i: 0).getReg(), Op: Src, ImmOp: Size - ShiftAmt);
3343 MI.eraseFromParent();
3344}
3345
3346/// and(and(x, C1), C2) -> C1&C2 ? and(x, C1&C2) : 0
3347bool CombinerHelper::matchOverlappingAnd(
3348 MachineInstr &MI,
3349 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
3350 assert(MI.getOpcode() == TargetOpcode::G_AND);
3351
3352 Register Dst = MI.getOperand(i: 0).getReg();
3353 LLT Ty = MRI.getType(Reg: Dst);
3354
3355 Register R;
3356 int64_t C1;
3357 int64_t C2;
3358 if (!mi_match(
3359 R: Dst, MRI,
3360 P: m_GAnd(L: m_GAnd(L: m_Reg(R), R: m_ICst(Cst&: C1)), R: m_ICst(Cst&: C2))))
3361 return false;
3362
3363 MatchInfo = [=](MachineIRBuilder &B) {
3364 if (C1 & C2) {
3365 B.buildAnd(Dst, Src0: R, Src1: B.buildConstant(Res: Ty, Val: C1 & C2));
3366 return;
3367 }
3368 auto Zero = B.buildConstant(Res: Ty, Val: 0);
3369 replaceRegWith(MRI, FromReg: Dst, ToReg: Zero->getOperand(i: 0).getReg());
3370 };
3371 return true;
3372}
3373
3374bool CombinerHelper::matchRedundantAnd(MachineInstr &MI,
3375 Register &Replacement) const {
3376 // Given
3377 //
3378 // %y:_(sN) = G_SOMETHING
3379 // %x:_(sN) = G_SOMETHING
3380 // %res:_(sN) = G_AND %x, %y
3381 //
3382 // Eliminate the G_AND when it is known that x & y == x or x & y == y.
3383 //
3384 // Patterns like this can appear as a result of legalization. E.g.
3385 //
3386 // %cmp:_(s32) = G_ICMP intpred(pred), %x(s32), %y
3387 // %one:_(s32) = G_CONSTANT i32 1
3388 // %and:_(s32) = G_AND %cmp, %one
3389 //
3390 // In this case, G_ICMP only produces a single bit, so x & 1 == x.
3391 assert(MI.getOpcode() == TargetOpcode::G_AND);
3392 if (!VT)
3393 return false;
3394
3395 Register AndDst = MI.getOperand(i: 0).getReg();
3396 Register LHS = MI.getOperand(i: 1).getReg();
3397 Register RHS = MI.getOperand(i: 2).getReg();
3398
3399 // Check the RHS (maybe a constant) first, and if we have no KnownBits there,
3400 // we can't do anything. If we do, then it depends on whether we have
3401 // KnownBits on the LHS.
3402 KnownBits RHSBits = VT->getKnownBits(R: RHS);
3403 if (RHSBits.isUnknown())
3404 return false;
3405
3406 KnownBits LHSBits = VT->getKnownBits(R: LHS);
3407
3408 // Check that x & Mask == x.
3409 // x & 1 == x, always
3410 // x & 0 == x, only if x is also 0
3411 // Meaning Mask has no effect if every bit is either one in Mask or zero in x.
3412 //
3413 // Check if we can replace AndDst with the LHS of the G_AND
3414 if (canReplaceReg(DstReg: AndDst, SrcReg: LHS, MRI) &&
3415 (LHSBits.Zero | RHSBits.One).isAllOnes()) {
3416 Replacement = LHS;
3417 return true;
3418 }
3419
3420 // Check if we can replace AndDst with the RHS of the G_AND
3421 if (canReplaceReg(DstReg: AndDst, SrcReg: RHS, MRI) &&
3422 (LHSBits.One | RHSBits.Zero).isAllOnes()) {
3423 Replacement = RHS;
3424 return true;
3425 }
3426
3427 return false;
3428}
3429
3430bool CombinerHelper::matchRedundantOr(MachineInstr &MI,
3431 Register &Replacement) const {
3432 // Given
3433 //
3434 // %y:_(sN) = G_SOMETHING
3435 // %x:_(sN) = G_SOMETHING
3436 // %res:_(sN) = G_OR %x, %y
3437 //
3438 // Eliminate the G_OR when it is known that x | y == x or x | y == y.
3439 assert(MI.getOpcode() == TargetOpcode::G_OR);
3440 if (!VT)
3441 return false;
3442
3443 Register OrDst = MI.getOperand(i: 0).getReg();
3444 Register LHS = MI.getOperand(i: 1).getReg();
3445 Register RHS = MI.getOperand(i: 2).getReg();
3446
3447 KnownBits LHSBits = VT->getKnownBits(R: LHS);
3448 KnownBits RHSBits = VT->getKnownBits(R: RHS);
3449
3450 // Check that x | Mask == x.
3451 // x | 0 == x, always
3452 // x | 1 == x, only if x is also 1
3453 // Meaning Mask has no effect if every bit is either zero in Mask or one in x.
3454 //
3455 // Check if we can replace OrDst with the LHS of the G_OR
3456 if (canReplaceReg(DstReg: OrDst, SrcReg: LHS, MRI) &&
3457 (LHSBits.One | RHSBits.Zero).isAllOnes()) {
3458 Replacement = LHS;
3459 return true;
3460 }
3461
3462 // Check if we can replace OrDst with the RHS of the G_OR
3463 if (canReplaceReg(DstReg: OrDst, SrcReg: RHS, MRI) &&
3464 (LHSBits.Zero | RHSBits.One).isAllOnes()) {
3465 Replacement = RHS;
3466 return true;
3467 }
3468
3469 return false;
3470}
3471
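// E.g. (illustrative, s32): if %src is already known to have at least 17 sign
// bits, then %d:_(s32) = G_SEXT_INREG %src, 16 is a no-op (32 - 16 + 1 == 17)
// and %d can be replaced by %src.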
3472bool CombinerHelper::matchRedundantSExtInReg(MachineInstr &MI) const {
3473 // If the input is already sign extended, just drop the extension.
3474 Register Src = MI.getOperand(i: 1).getReg();
3475 unsigned ExtBits = MI.getOperand(i: 2).getImm();
3476 unsigned TypeSize = MRI.getType(Reg: Src).getScalarSizeInBits();
3477 return VT->computeNumSignBits(R: Src) >= (TypeSize - ExtBits + 1);
3478}
3479
3480static bool isConstValidTrue(const TargetLowering &TLI, unsigned ScalarSizeBits,
3481 int64_t Cst, bool IsVector, bool IsFP) {
3482 // For i1, Cst will always be -1 regardless of boolean contents.
3483 return (ScalarSizeBits == 1 && Cst == -1) ||
3484 isConstTrueVal(TLI, Val: Cst, IsVector, IsFP);
3485}
3486
3487 // This pattern aims to match the following shape to avoid extra mov
3488 // instructions:
3489// G_BUILD_VECTOR(
3490// G_UNMERGE_VALUES(src, 0)
3491// G_UNMERGE_VALUES(src, 1)
3492// G_IMPLICIT_DEF
3493// G_IMPLICIT_DEF
3494// )
3495// ->
3496// G_CONCAT_VECTORS(
3497// src,
3498// undef
3499// )
3500bool CombinerHelper::matchCombineBuildUnmerge(MachineInstr &MI,
3501 MachineRegisterInfo &MRI,
3502 Register &UnmergeSrc) const {
3503 auto &BV = cast<GBuildVector>(Val&: MI);
3504
3505 unsigned BuildUseCount = BV.getNumSources();
3506 if (BuildUseCount % 2 != 0)
3507 return false;
3508
3509 unsigned NumUnmerge = BuildUseCount / 2;
3510
3511 auto *Unmerge = getOpcodeDef<GUnmerge>(Reg: BV.getSourceReg(I: 0), MRI);
3512
3513 // Check the first operand is an unmerge and has the correct number of
3514 // operands
3515 if (!Unmerge || Unmerge->getNumDefs() != NumUnmerge)
3516 return false;
3517
3518 UnmergeSrc = Unmerge->getSourceReg();
3519
3520 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
3521 LLT UnmergeSrcTy = MRI.getType(Reg: UnmergeSrc);
3522
3523 if (!UnmergeSrcTy.isVector())
3524 return false;
3525
3526 // Ensure we only generate legal instructions post-legalizer
3527 if (!IsPreLegalize &&
3528 !isLegal(Query: {TargetOpcode::G_CONCAT_VECTORS, {DstTy, UnmergeSrcTy}}))
3529 return false;
3530
3531 // Check that all of the operands before the midpoint come from the same
3532 // unmerge and are in the same order as they are used in the build_vector
3533 for (unsigned I = 0; I < NumUnmerge; ++I) {
3534 auto MaybeUnmergeReg = BV.getSourceReg(I);
3535 auto *LoopUnmerge = getOpcodeDef<GUnmerge>(Reg: MaybeUnmergeReg, MRI);
3536
3537 if (!LoopUnmerge || LoopUnmerge != Unmerge)
3538 return false;
3539
3540 if (LoopUnmerge->getOperand(i: I).getReg() != MaybeUnmergeReg)
3541 return false;
3542 }
3543
3544 // Check that all of the unmerged values are used
3545 if (Unmerge->getNumDefs() != NumUnmerge)
3546 return false;
3547
3548 // Check that all of the operands after the midpoint are undefs.
3549 for (unsigned I = NumUnmerge; I < BuildUseCount; ++I) {
3550 auto *Undef = getDefIgnoringCopies(Reg: BV.getSourceReg(I), MRI);
3551
3552 if (Undef->getOpcode() != TargetOpcode::G_IMPLICIT_DEF)
3553 return false;
3554 }
3555
3556 return true;
3557}
3558
3559void CombinerHelper::applyCombineBuildUnmerge(MachineInstr &MI,
3560 MachineRegisterInfo &MRI,
3561 MachineIRBuilder &B,
3562 Register &UnmergeSrc) const {
3563 assert(UnmergeSrc && "Expected there to be one matching G_UNMERGE_VALUES");
3564 B.setInstrAndDebugLoc(MI);
3565
3566 Register UndefVec = B.buildUndef(Res: MRI.getType(Reg: UnmergeSrc)).getReg(Idx: 0);
3567 B.buildConcatVectors(Res: MI.getOperand(i: 0), Ops: {UnmergeSrc, UndefVec});
3568
3569 MI.eraseFromParent();
3570}
3571
3572 // This combine tries to reduce the number of scalarised G_TRUNC instructions
3573 // by using vector truncates instead.
3574 //
3575 // EXAMPLE:
3576 // %a(s32), %b(s32) = G_UNMERGE_VALUES %src(<2 x s32>)
3577 // %T_a(s16) = G_TRUNC %a(s32)
3578 // %T_b(s16) = G_TRUNC %b(s32)
3579 // %Undef(s16) = G_IMPLICIT_DEF
3580 // %dst(<4 x s16>) = G_BUILD_VECTOR %T_a(s16), %T_b(s16), %Undef(s16), %Undef(s16)
3581 //
3582 // ===>
3583 // %Undef(<2 x s32>) = G_IMPLICIT_DEF
3584 // %Mid(<4 x s32>) = G_CONCAT_VECTORS %src(<2 x s32>), %Undef(<2 x s32>)
3585 // %dst(<4 x s16>) = G_TRUNC %Mid(<4 x s32>)
3586//
3587// Only matches sources made up of G_TRUNCs followed by G_IMPLICIT_DEFs
3588bool CombinerHelper::matchUseVectorTruncate(MachineInstr &MI,
3589 Register &MatchInfo) const {
3590 auto BuildMI = cast<GBuildVector>(Val: &MI);
3591 unsigned NumOperands = BuildMI->getNumSources();
3592 LLT DstTy = MRI.getType(Reg: BuildMI->getReg(Idx: 0));
3593
3594 // Check the G_BUILD_VECTOR sources
3595 unsigned I;
3596 MachineInstr *UnmergeMI = nullptr;
3597
3598 // Check all source TRUNCs come from the same UNMERGE instruction
3599 // and that the element order matches (BUILD_VECTOR position I
3600 // corresponds to UNMERGE result I)
3601 for (I = 0; I < NumOperands; ++I) {
3602 auto SrcMI = MRI.getVRegDef(Reg: BuildMI->getSourceReg(I));
3603 auto SrcMIOpc = SrcMI->getOpcode();
3604
3605 // Check if the G_TRUNC instructions all come from the same MI
3606 if (SrcMIOpc == TargetOpcode::G_TRUNC) {
3607 Register TruncSrcReg = SrcMI->getOperand(i: 1).getReg();
3608 if (!UnmergeMI) {
3609 UnmergeMI = MRI.getVRegDef(Reg: TruncSrcReg);
3610 if (UnmergeMI->getOpcode() != TargetOpcode::G_UNMERGE_VALUES)
3611 return false;
3612 } else {
3613 auto UnmergeSrcMI = MRI.getVRegDef(Reg: TruncSrcReg);
3614 if (UnmergeMI != UnmergeSrcMI)
3615 return false;
3616 }
3617 // Verify element ordering: BUILD_VECTOR position I must use
3618 // UNMERGE result I, otherwise the fold would reorder the elements.
3619 if (UnmergeMI->getOperand(i: I).getReg() != TruncSrcReg)
3620 return false;
3621 } else {
3622 break;
3623 }
3624 }
3625 if (I < 2)
3626 return false;
3627
3628 // Check the remaining source elements are only G_IMPLICIT_DEF
3629 for (; I < NumOperands; ++I) {
3630 auto SrcMI = MRI.getVRegDef(Reg: BuildMI->getSourceReg(I));
3631 auto SrcMIOpc = SrcMI->getOpcode();
3632
3633 if (SrcMIOpc != TargetOpcode::G_IMPLICIT_DEF)
3634 return false;
3635 }
3636
3637 // Check the size of unmerge source
3638 MatchInfo = cast<GUnmerge>(Val: UnmergeMI)->getSourceReg();
3639 LLT UnmergeSrcTy = MRI.getType(Reg: MatchInfo);
3640 if (!DstTy.getElementCount().isKnownMultipleOf(RHS: UnmergeSrcTy.getNumElements()))
3641 return false;
3642
3643 // Check the unmerge source and destination element types match
3644 LLT UnmergeSrcEltTy = UnmergeSrcTy.getElementType();
3645 Register UnmergeDstReg = UnmergeMI->getOperand(i: 0).getReg();
3646 LLT UnmergeDstEltTy = MRI.getType(Reg: UnmergeDstReg);
3647 if (UnmergeSrcEltTy != UnmergeDstEltTy)
3648 return false;
3649
3650 // Only generate legal instructions post-legalizer
3651 if (!IsPreLegalize) {
3652 LLT MidTy = DstTy.changeElementType(NewEltTy: UnmergeSrcTy.getScalarType());
3653
3654 if (DstTy.getElementCount() != UnmergeSrcTy.getElementCount() &&
3655 !isLegal(Query: {TargetOpcode::G_CONCAT_VECTORS, {MidTy, UnmergeSrcTy}}))
3656 return false;
3657
3658 if (!isLegal(Query: {TargetOpcode::G_TRUNC, {DstTy, MidTy}}))
3659 return false;
3660 }
3661
3662 return true;
3663}
3664
3665void CombinerHelper::applyUseVectorTruncate(MachineInstr &MI,
3666 Register &MatchInfo) const {
3667 Register MidReg;
3668 auto BuildMI = cast<GBuildVector>(Val: &MI);
3669 Register DstReg = BuildMI->getReg(Idx: 0);
3670 LLT DstTy = MRI.getType(Reg: DstReg);
3671 LLT UnmergeSrcTy = MRI.getType(Reg: MatchInfo);
3672 unsigned DstTyNumElt = DstTy.getNumElements();
3673 unsigned UnmergeSrcTyNumElt = UnmergeSrcTy.getNumElements();
3674
3675 // No need to pad vector if only G_TRUNC is needed
3676 if (DstTyNumElt / UnmergeSrcTyNumElt == 1) {
3677 MidReg = MatchInfo;
3678 } else {
3679 Register UndefReg = Builder.buildUndef(Res: UnmergeSrcTy).getReg(Idx: 0);
3680 SmallVector<Register> ConcatRegs = {MatchInfo};
3681 for (unsigned I = 1; I < DstTyNumElt / UnmergeSrcTyNumElt; ++I)
3682 ConcatRegs.push_back(Elt: UndefReg);
3683
3684 auto MidTy = DstTy.changeElementType(NewEltTy: UnmergeSrcTy.getScalarType());
3685 MidReg = Builder.buildConcatVectors(Res: MidTy, Ops: ConcatRegs).getReg(Idx: 0);
3686 }
3687
3688 Builder.buildTrunc(Res: DstReg, Op: MidReg);
3689 MI.eraseFromParent();
3690}
3691
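// Fold (xor (cmp ...), true) by inverting the comparison predicate.
// Illustrative sketch (s1 condition, value names assumed):
//   %c:_(s1) = G_ICMP intpred(eq), %x, %y
//   %one:_(s1) = G_CONSTANT i1 1
//   %d:_(s1) = G_XOR %c, %one
// becomes %d:_(s1) = G_ICMP intpred(ne), %x, %y. ANDs/ORs of comparisons are
// handled via De Morgan's laws, negating each leaf comparison.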
3692bool CombinerHelper::matchNotCmp(
3693 MachineInstr &MI, SmallVectorImpl<Register> &RegsToNegate) const {
3694 assert(MI.getOpcode() == TargetOpcode::G_XOR);
3695 LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
3696 const auto &TLI = *Builder.getMF().getSubtarget().getTargetLowering();
3697 Register XorSrc;
3698 Register CstReg;
3699 // We match xor(src, true) here.
3700 if (!mi_match(R: MI.getOperand(i: 0).getReg(), MRI,
3701 P: m_GXor(L: m_Reg(R&: XorSrc), R: m_Reg(R&: CstReg))))
3702 return false;
3703
3704 if (!MRI.hasOneNonDBGUse(RegNo: XorSrc))
3705 return false;
3706
3707 // Check that XorSrc is the root of a tree of comparisons combined with ANDs
3708 // and ORs. The suffix of RegsToNegate starting from index I is used as a
3709 // work list of tree nodes to visit.
3710 RegsToNegate.push_back(Elt: XorSrc);
3711 // Remember whether the comparisons are all integer or all floating point.
3712 bool IsInt = false;
3713 bool IsFP = false;
3714 for (unsigned I = 0; I < RegsToNegate.size(); ++I) {
3715 Register Reg = RegsToNegate[I];
3716 if (!MRI.hasOneNonDBGUse(RegNo: Reg))
3717 return false;
3718 MachineInstr *Def = MRI.getVRegDef(Reg);
3719 switch (Def->getOpcode()) {
3720 default:
3721 // Don't match if the tree contains anything other than ANDs, ORs and
3722 // comparisons.
3723 return false;
3724 case TargetOpcode::G_ICMP:
3725 if (IsFP)
3726 return false;
3727 IsInt = true;
3728 // When we apply the combine we will invert the predicate.
3729 break;
3730 case TargetOpcode::G_FCMP:
3731 if (IsInt)
3732 return false;
3733 IsFP = true;
3734 // When we apply the combine we will invert the predicate.
3735 break;
3736 case TargetOpcode::G_AND:
3737 case TargetOpcode::G_OR:
3738 // Implement De Morgan's laws:
3739 // ~(x & y) -> ~x | ~y
3740 // ~(x | y) -> ~x & ~y
3741 // When we apply the combine we will change the opcode and recursively
3742 // negate the operands.
3743 RegsToNegate.push_back(Elt: Def->getOperand(i: 1).getReg());
3744 RegsToNegate.push_back(Elt: Def->getOperand(i: 2).getReg());
3745 break;
3746 }
3747 }
3748
3749 // Now we know whether the comparisons are integer or floating point, check
3750 // the constant in the xor.
3751 int64_t Cst;
3752 if (Ty.isVector()) {
3753 MachineInstr *CstDef = MRI.getVRegDef(Reg: CstReg);
3754 auto MaybeCst = getIConstantSplatSExtVal(MI: *CstDef, MRI);
3755 if (!MaybeCst)
3756 return false;
3757 if (!isConstValidTrue(TLI, ScalarSizeBits: Ty.getScalarSizeInBits(), Cst: *MaybeCst, IsVector: true, IsFP))
3758 return false;
3759 } else {
3760 if (!mi_match(R: CstReg, MRI, P: m_ICst(Cst)))
3761 return false;
3762 if (!isConstValidTrue(TLI, ScalarSizeBits: Ty.getSizeInBits(), Cst, IsVector: false, IsFP))
3763 return false;
3764 }
3765
3766 return true;
3767}
3768
3769void CombinerHelper::applyNotCmp(
3770 MachineInstr &MI, SmallVectorImpl<Register> &RegsToNegate) const {
3771 for (Register Reg : RegsToNegate) {
3772 MachineInstr *Def = MRI.getVRegDef(Reg);
3773 Observer.changingInstr(MI&: *Def);
3774 // For each comparison, invert the opcode. For each AND and OR, change the
3775 // opcode.
3776 switch (Def->getOpcode()) {
3777 default:
3778 llvm_unreachable("Unexpected opcode");
3779 case TargetOpcode::G_ICMP:
3780 case TargetOpcode::G_FCMP: {
3781 MachineOperand &PredOp = Def->getOperand(i: 1);
3782 CmpInst::Predicate NewP = CmpInst::getInversePredicate(
3783 pred: (CmpInst::Predicate)PredOp.getPredicate());
3784 PredOp.setPredicate(NewP);
3785 break;
3786 }
3787 case TargetOpcode::G_AND:
3788 Def->setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_OR));
3789 break;
3790 case TargetOpcode::G_OR:
3791 Def->setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_AND));
3792 break;
3793 }
3794 Observer.changedInstr(MI&: *Def);
3795 }
3796
3797 replaceRegWith(MRI, FromReg: MI.getOperand(i: 0).getReg(), ToReg: MI.getOperand(i: 1).getReg());
3798 MI.eraseFromParent();
3799}
3800
3801bool CombinerHelper::matchXorOfAndWithSameReg(
3802 MachineInstr &MI, std::pair<Register, Register> &MatchInfo) const {
3803 // Match (xor (and x, y), y) (or any of its commuted cases)
3804 assert(MI.getOpcode() == TargetOpcode::G_XOR);
3805 Register &X = MatchInfo.first;
3806 Register &Y = MatchInfo.second;
3807 Register AndReg = MI.getOperand(i: 1).getReg();
3808 Register SharedReg = MI.getOperand(i: 2).getReg();
3809
3810 // Find a G_AND on either side of the G_XOR.
3811 // Look for one of
3812 //
3813 // (xor (and x, y), SharedReg)
3814 // (xor SharedReg, (and x, y))
3815 if (!mi_match(R: AndReg, MRI, P: m_GAnd(L: m_Reg(R&: X), R: m_Reg(R&: Y)))) {
3816 std::swap(a&: AndReg, b&: SharedReg);
3817 if (!mi_match(R: AndReg, MRI, P: m_GAnd(L: m_Reg(R&: X), R: m_Reg(R&: Y))))
3818 return false;
3819 }
3820
3821 // Only do this if we'll eliminate the G_AND.
3822 if (!MRI.hasOneNonDBGUse(RegNo: AndReg))
3823 return false;
3824
3825 // We can combine if SharedReg is the same as either the LHS or RHS of the
3826 // G_AND.
3827 if (Y != SharedReg)
3828 std::swap(a&: X, b&: Y);
3829 return Y == SharedReg;
3830}
3831
3832void CombinerHelper::applyXorOfAndWithSameReg(
3833 MachineInstr &MI, std::pair<Register, Register> &MatchInfo) const {
3834 // Fold (xor (and x, y), y) -> (and (not x), y)
3835 Register X, Y;
3836 std::tie(args&: X, args&: Y) = MatchInfo;
3837 auto Not = Builder.buildNot(Dst: MRI.getType(Reg: X), Src0: X);
3838 Observer.changingInstr(MI);
3839 MI.setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_AND));
3840 MI.getOperand(i: 1).setReg(Not->getOperand(i: 0).getReg());
3841 MI.getOperand(i: 2).setReg(Y);
3842 Observer.changedInstr(MI);
3843}
3844
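// Fold a G_PTR_ADD with an all-zero base into a plain G_INTTOPTR of the
// offset. Illustrative sketch (p0/s64, value names assumed):
//   %zero:_(p0) = G_CONSTANT i64 0
//   %d:_(p0) = G_PTR_ADD %zero, %off(s64)
// becomes %d:_(p0) = G_INTTOPTR %off(s64).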
3845bool CombinerHelper::matchPtrAddZero(MachineInstr &MI) const {
3846 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
3847 Register DstReg = PtrAdd.getReg(Idx: 0);
3848 LLT Ty = MRI.getType(Reg: DstReg);
3849 const DataLayout &DL = Builder.getMF().getDataLayout();
3850
3851 if (DL.isNonIntegralAddressSpace(AddrSpace: Ty.getScalarType().getAddressSpace()))
3852 return false;
3853
3854 if (Ty.isPointer()) {
3855 auto ConstVal = getIConstantVRegVal(VReg: PtrAdd.getBaseReg(), MRI);
3856 return ConstVal && *ConstVal == 0;
3857 }
3858
3859 assert(Ty.isVector() && "Expecting a vector type");
3860 const MachineInstr *VecMI = MRI.getVRegDef(Reg: PtrAdd.getBaseReg());
3861 return isBuildVectorAllZeros(MI: *VecMI, MRI);
3862}
3863
3864void CombinerHelper::applyPtrAddZero(MachineInstr &MI) const {
3865 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
3866 Builder.buildIntToPtr(Dst: PtrAdd.getReg(Idx: 0), Src: PtrAdd.getOffsetReg());
3867 PtrAdd.eraseFromParent();
3868}
3869
3870/// The second source operand is known to be a power of 2.
3871void CombinerHelper::applySimplifyURemByPow2(MachineInstr &MI) const {
3872 Register DstReg = MI.getOperand(i: 0).getReg();
3873 Register Src0 = MI.getOperand(i: 1).getReg();
3874 Register Pow2Src1 = MI.getOperand(i: 2).getReg();
3875 LLT Ty = MRI.getType(Reg: DstReg);
3876
3877 // Fold (urem x, pow2) -> (and x, pow2-1)
3878 auto NegOne = Builder.buildConstant(Res: Ty, Val: -1);
3879 auto Add = Builder.buildAdd(Dst: Ty, Src0: Pow2Src1, Src1: NegOne);
3880 Builder.buildAnd(Dst: DstReg, Src0, Src1: Add);
3881 MI.eraseFromParent();
3882}
3883
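// Match a binop whose LHS or RHS is a single-use G_SELECT of constants, so
// the operation can be folded into both arms of the select. Illustrative
// sketch (s32, constants shown symbolically):
//   %sel:_(s32) = G_SELECT %cond, %five, %seven
//   %d:_(s32) = G_ADD %sel, %one
// can become %d:_(s32) = G_SELECT %cond, (5 + 1), (7 + 1).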
3884bool CombinerHelper::matchFoldBinOpIntoSelect(MachineInstr &MI,
3885 unsigned &SelectOpNo) const {
3886 Register LHS = MI.getOperand(i: 1).getReg();
3887 Register RHS = MI.getOperand(i: 2).getReg();
3888
3889 Register OtherOperandReg = RHS;
3890 SelectOpNo = 1;
3891 MachineInstr *Select = MRI.getVRegDef(Reg: LHS);
3892
3893 // Don't do this unless the old select is going away. We want to eliminate the
3894 // binary operator, not replace a binop with a select.
3895 if (Select->getOpcode() != TargetOpcode::G_SELECT ||
3896 !MRI.hasOneNonDBGUse(RegNo: LHS)) {
3897 OtherOperandReg = LHS;
3898 SelectOpNo = 2;
3899 Select = MRI.getVRegDef(Reg: RHS);
3900 if (Select->getOpcode() != TargetOpcode::G_SELECT ||
3901 !MRI.hasOneNonDBGUse(RegNo: RHS))
3902 return false;
3903 }
3904
3905 MachineInstr *SelectLHS = MRI.getVRegDef(Reg: Select->getOperand(i: 2).getReg());
3906 MachineInstr *SelectRHS = MRI.getVRegDef(Reg: Select->getOperand(i: 3).getReg());
3907
3908 if (!isConstantOrConstantVector(MI: *SelectLHS, MRI,
3909 /*AllowFP*/ true,
3910 /*AllowOpaqueConstants*/ false))
3911 return false;
3912 if (!isConstantOrConstantVector(MI: *SelectRHS, MRI,
3913 /*AllowFP*/ true,
3914 /*AllowOpaqueConstants*/ false))
3915 return false;
3916
3917 unsigned BinOpcode = MI.getOpcode();
3918
3919 // We know that one of the operands is a select of constants. Now verify that
3920 // the other binary operator operand is either a constant, or we can handle a
3921 // variable.
3922 bool CanFoldNonConst =
3923 (BinOpcode == TargetOpcode::G_AND || BinOpcode == TargetOpcode::G_OR) &&
3924 (isNullOrNullSplat(MI: *SelectLHS, MRI) ||
3925 isAllOnesOrAllOnesSplat(MI: *SelectLHS, MRI)) &&
3926 (isNullOrNullSplat(MI: *SelectRHS, MRI) ||
3927 isAllOnesOrAllOnesSplat(MI: *SelectRHS, MRI));
3928 if (CanFoldNonConst)
3929 return true;
3930
3931 return isConstantOrConstantVector(MI: *MRI.getVRegDef(Reg: OtherOperandReg), MRI,
3932 /*AllowFP*/ true,
3933 /*AllowOpaqueConstants*/ false);
3934}
3935
3936/// \p SelectOperand is the operand in binary operator \p MI that is the select
3937/// to fold.
3938void CombinerHelper::applyFoldBinOpIntoSelect(
3939 MachineInstr &MI, const unsigned &SelectOperand) const {
3940 Register Dst = MI.getOperand(i: 0).getReg();
3941 Register LHS = MI.getOperand(i: 1).getReg();
3942 Register RHS = MI.getOperand(i: 2).getReg();
3943 MachineInstr *Select = MRI.getVRegDef(Reg: MI.getOperand(i: SelectOperand).getReg());
3944
3945 Register SelectCond = Select->getOperand(i: 1).getReg();
3946 Register SelectTrue = Select->getOperand(i: 2).getReg();
3947 Register SelectFalse = Select->getOperand(i: 3).getReg();
3948
3949 LLT Ty = MRI.getType(Reg: Dst);
3950 unsigned BinOpcode = MI.getOpcode();
3951
3952 Register FoldTrue, FoldFalse;
3953
3954 // We have a select-of-constants followed by a binary operator with a
3955 // constant. Eliminate the binop by pulling the constant math into the select.
3956 // Example: add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO
3957 if (SelectOperand == 1) {
3958 // TODO: SelectionDAG verifies this actually constant folds before
3959 // committing to the combine.
3960
3961 FoldTrue = Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {SelectTrue, RHS}).getReg(Idx: 0);
3962 FoldFalse =
3963 Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {SelectFalse, RHS}).getReg(Idx: 0);
3964 } else {
3965 FoldTrue = Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {LHS, SelectTrue}).getReg(Idx: 0);
3966 FoldFalse =
3967 Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {LHS, SelectFalse}).getReg(Idx: 0);
3968 }
3969
3970 Builder.buildSelect(Res: Dst, Tst: SelectCond, Op0: FoldTrue, Op1: FoldFalse, Flags: MI.getFlags());
3971 MI.eraseFromParent();
3972}
3973
3974std::optional<SmallVector<Register, 8>>
3975CombinerHelper::findCandidatesForLoadOrCombine(const MachineInstr *Root) const {
3976 assert(Root->getOpcode() == TargetOpcode::G_OR && "Expected G_OR only!");
3977 // We want to detect if Root is part of a tree which represents a bunch
3978 // of loads being merged into a larger load. We'll try to recognize patterns
3979 // like, for example:
3980 //
3981 // Reg Reg
3982 // \ /
3983 // OR_1 Reg
3984 // \ /
3985 // OR_2
3986 // \ Reg
3987 // .. /
3988 // Root
3989 //
3990 // Reg Reg Reg Reg
3991 // \ / \ /
3992 // OR_1 OR_2
3993 // \ /
3994 // \ /
3995 // ...
3996 // Root
3997 //
3998 // Each "Reg" may have been produced by a load + some arithmetic. This
3999 // function will save each of them.
4000 SmallVector<Register, 8> RegsToVisit;
4001 SmallVector<const MachineInstr *, 7> Ors = {Root};
4002
4003 // In the "worst" case, we're dealing with a load for each byte. So, there
4004 // are at most #bytes - 1 ORs.
4005 const unsigned MaxIter =
4006 MRI.getType(Reg: Root->getOperand(i: 0).getReg()).getSizeInBytes() - 1;
4007 for (unsigned Iter = 0; Iter < MaxIter; ++Iter) {
4008 if (Ors.empty())
4009 break;
4010 const MachineInstr *Curr = Ors.pop_back_val();
4011 Register OrLHS = Curr->getOperand(i: 1).getReg();
4012 Register OrRHS = Curr->getOperand(i: 2).getReg();
4013
4014 // In the combine, we want to eliminate the entire tree.
4015 if (!MRI.hasOneNonDBGUse(RegNo: OrLHS) || !MRI.hasOneNonDBGUse(RegNo: OrRHS))
4016 return std::nullopt;
4017
4018 // If it's a G_OR, save it and continue to walk. If it's not, then it's
4019 // something that may be a load + arithmetic.
4020 if (const MachineInstr *Or = getOpcodeDef(Opcode: TargetOpcode::G_OR, Reg: OrLHS, MRI))
4021 Ors.push_back(Elt: Or);
4022 else
4023 RegsToVisit.push_back(Elt: OrLHS);
4024 if (const MachineInstr *Or = getOpcodeDef(Opcode: TargetOpcode::G_OR, Reg: OrRHS, MRI))
4025 Ors.push_back(Elt: Or);
4026 else
4027 RegsToVisit.push_back(Elt: OrRHS);
4028 }
4029
4030 // We're going to try and merge each register into a wider power-of-2 type,
4031 // so we ought to have an even number of registers.
4032 if (RegsToVisit.empty() || RegsToVisit.size() % 2 != 0)
4033 return std::nullopt;
4034 return RegsToVisit;
4035}
4036
4037/// Helper function for findLoadOffsetsForLoadOrCombine.
4038///
4039/// Check if \p Reg is the result of loading a \p MemSizeInBits wide value,
4040/// and then moving that value into a specific byte offset.
4041///
4042/// e.g. x[i] << 24
4043///
4044/// \returns The load instruction and the byte offset it is moved into.
4045static std::optional<std::pair<GZExtLoad *, int64_t>>
4046matchLoadAndBytePosition(Register Reg, unsigned MemSizeInBits,
4047 const MachineRegisterInfo &MRI) {
4048 assert(MRI.hasOneNonDBGUse(Reg) &&
4049 "Expected Reg to only have one non-debug use?");
4050 Register MaybeLoad;
4051 int64_t Shift;
4052 if (!mi_match(R: Reg, MRI,
4053 P: m_OneNonDBGUse(SP: m_GShl(L: m_Reg(R&: MaybeLoad), R: m_ICst(Cst&: Shift))))) {
4054 Shift = 0;
4055 MaybeLoad = Reg;
4056 }
4057
4058 if (Shift % MemSizeInBits != 0)
4059 return std::nullopt;
4060
4061 // TODO: Handle other types of loads.
4062 auto *Load = getOpcodeDef<GZExtLoad>(Reg: MaybeLoad, MRI);
4063 if (!Load)
4064 return std::nullopt;
4065
4066 if (!Load->isUnordered() || Load->getMemSizeInBits() != MemSizeInBits)
4067 return std::nullopt;
4068
4069 return std::make_pair(x&: Load, y: Shift / MemSizeInBits);
4070}
4071
4072std::optional<std::tuple<GZExtLoad *, int64_t, GZExtLoad *>>
4073CombinerHelper::findLoadOffsetsForLoadOrCombine(
4074 SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx,
4075 const SmallVector<Register, 8> &RegsToVisit,
4076 const unsigned MemSizeInBits) const {
4077
4078 // Each load found for the pattern. There should be one for each RegsToVisit.
4079 SmallSetVector<const MachineInstr *, 8> Loads;
4080
4081 // The lowest index used in any load. (The lowest "i" for each x[i].)
4082 int64_t LowestIdx = INT64_MAX;
4083
4084 // The load which uses the lowest index.
4085 GZExtLoad *LowestIdxLoad = nullptr;
4086
4087 // Keeps track of the load indices we see. We shouldn't see any indices twice.
4088 SmallSet<int64_t, 8> SeenIdx;
4089
4090 // Ensure each load is in the same MBB.
4091 // TODO: Support multiple MachineBasicBlocks.
4092 MachineBasicBlock *MBB = nullptr;
4093 const MachineMemOperand *MMO = nullptr;
4094
4095 // Earliest instruction-order load in the pattern.
4096 GZExtLoad *EarliestLoad = nullptr;
4097
4098 // Latest instruction-order load in the pattern.
4099 GZExtLoad *LatestLoad = nullptr;
4100
4101 // Base pointer which every load should share.
4102 Register BasePtr;
4103
4104 // We want to find a load for each register. Each load should have some
4105 // appropriate bit twiddling arithmetic. During this loop, we will also keep
4106 // track of the load which uses the lowest index. Later, we will check if we
4107 // can use its pointer in the final, combined load.
4108 for (auto Reg : RegsToVisit) {
4109 // Find the load, and find the byte position that its value will end up at
4110 // (e.g. after being shifted).
4111 auto LoadAndPos = matchLoadAndBytePosition(Reg, MemSizeInBits, MRI);
4112 if (!LoadAndPos)
4113 return std::nullopt;
4114 GZExtLoad *Load;
4115 int64_t DstPos;
4116 std::tie(args&: Load, args&: DstPos) = *LoadAndPos;
4117
4118 // TODO: Handle multiple MachineBasicBlocks. Currently not handled because
4119 // it is difficult to check for stores/calls/etc between loads.
4120 MachineBasicBlock *LoadMBB = Load->getParent();
4121 if (!MBB)
4122 MBB = LoadMBB;
4123 if (LoadMBB != MBB)
4124 return std::nullopt;
4125
4126 // Make sure that the MachineMemOperands of every seen load are compatible.
4127 auto &LoadMMO = Load->getMMO();
4128 if (!MMO)
4129 MMO = &LoadMMO;
4130 if (MMO->getAddrSpace() != LoadMMO.getAddrSpace())
4131 return std::nullopt;
4132
4133 // Find out what the base pointer and index for the load is.
4134 Register LoadPtr;
4135 int64_t Idx;
4136 if (!mi_match(R: Load->getOperand(i: 1).getReg(), MRI,
4137 P: m_GPtrAdd(L: m_Reg(R&: LoadPtr), R: m_ICst(Cst&: Idx)))) {
4138 LoadPtr = Load->getOperand(i: 1).getReg();
4139 Idx = 0;
4140 }
4141
4142 // Don't combine things like a[i], a[i] -> a bigger load.
4143 if (!SeenIdx.insert(V: Idx).second)
4144 return std::nullopt;
4145
4146 // Every load must share the same base pointer; don't combine things like:
4147 //
4148 // a[i], b[i + 1] -> a bigger load.
4149 if (!BasePtr.isValid())
4150 BasePtr = LoadPtr;
4151 if (BasePtr != LoadPtr)
4152 return std::nullopt;
4153
4154 if (Idx < LowestIdx) {
4155 LowestIdx = Idx;
4156 LowestIdxLoad = Load;
4157 }
4158
4159 // Keep track of the byte offset that this load ends up at. If we have seen
4160 // the byte offset, then stop here. We do not want to combine:
4161 //
4162 // a[i] << 16, a[i + k] << 16 -> a bigger load.
4163 if (!MemOffset2Idx.try_emplace(Key: DstPos, Args&: Idx).second)
4164 return std::nullopt;
4165 Loads.insert(X: Load);
4166
4167 // Keep track of the position of the earliest/latest loads in the pattern.
4168 // We will check that there are no load fold barriers between them later
4169 // on.
4170 //
4171 // FIXME: Is there a better way to check for load fold barriers?
4172 if (!EarliestLoad || dominates(DefMI: *Load, UseMI: *EarliestLoad))
4173 EarliestLoad = Load;
4174 if (!LatestLoad || dominates(DefMI: *LatestLoad, UseMI: *Load))
4175 LatestLoad = Load;
4176 }
4177
4178 // We found a load for each register. Let's check if each load satisfies the
4179 // pattern.
4180 assert(Loads.size() == RegsToVisit.size() &&
4181 "Expected to find a load for each register?");
4182 assert(EarliestLoad != LatestLoad && EarliestLoad && LatestLoad &&
4183 "Expected at least two loads?");
4184
4185 // Check if there are any stores, calls, etc. between any of the loads. If
4186 // there are, then we can't safely perform the combine.
4187 //
4188 // MaxIter is chosen based on the (worst case) number of iterations it
4189 // typically takes to succeed in the LLVM test suite plus some padding.
4190 //
4191 // FIXME: Is there a better way to check for load fold barriers?
4192 const unsigned MaxIter = 20;
4193 unsigned Iter = 0;
4194 for (const auto &MI : instructionsWithoutDebug(It: EarliestLoad->getIterator(),
4195 End: LatestLoad->getIterator())) {
4196 if (Loads.count(key: &MI))
4197 continue;
4198 if (MI.isLoadFoldBarrier())
4199 return std::nullopt;
4200 if (Iter++ == MaxIter)
4201 return std::nullopt;
4202 }
4203
4204 return std::make_tuple(args&: LowestIdxLoad, args&: LowestIdx, args&: LatestLoad);
4205}
4206
4207bool CombinerHelper::matchLoadOrCombine(
4208 MachineInstr &MI,
4209 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4210 assert(MI.getOpcode() == TargetOpcode::G_OR);
4211 MachineFunction &MF = *MI.getMF();
4212 // Assuming a little-endian target, transform:
4213 // s8 *a = ...
4214 // s32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
4215 // =>
4216 // s32 val = *((i32)a)
4217 //
4218 // s8 *a = ...
4219 // s32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
4220 // =>
4221 // s32 val = BSWAP(*((s32)a))
4222 Register Dst = MI.getOperand(i: 0).getReg();
4223 LLT Ty = MRI.getType(Reg: Dst);
4224 if (Ty.isVector())
4225 return false;
4226
4227 // We need to combine at least two loads into this type. Since the smallest
4228 // possible load is into a byte, we need at least a 16-bit wide type.
4229 const unsigned WideMemSizeInBits = Ty.getSizeInBits();
4230 if (WideMemSizeInBits < 16 || WideMemSizeInBits % 8 != 0)
4231 return false;
4232
4233 // Match a collection of non-OR instructions in the pattern.
4234 auto RegsToVisit = findCandidatesForLoadOrCombine(Root: &MI);
4235 if (!RegsToVisit)
4236 return false;
4237
4238 // We have a collection of non-OR instructions. Figure out how wide each of
4239 // the small loads should be based off of the number of potential loads we
4240 // found.
4241 const unsigned NarrowMemSizeInBits = WideMemSizeInBits / RegsToVisit->size();
4242 if (NarrowMemSizeInBits % 8 != 0)
4243 return false;
4244
4245 // Check if each register feeding into each OR is a load from the same
4246 // base pointer + some arithmetic.
4247 //
4248 // e.g. a[0], a[1] << 8, a[2] << 16, etc.
4249 //
4250 // Also verify that each of these ends up putting a[i] into the same memory
4251 // offset as a load into a wide type would.
4252 SmallDenseMap<int64_t, int64_t, 8> MemOffset2Idx;
4253 GZExtLoad *LowestIdxLoad, *LatestLoad;
4254 int64_t LowestIdx;
4255 auto MaybeLoadInfo = findLoadOffsetsForLoadOrCombine(
4256 MemOffset2Idx, RegsToVisit: *RegsToVisit, MemSizeInBits: NarrowMemSizeInBits);
4257 if (!MaybeLoadInfo)
4258 return false;
4259 std::tie(args&: LowestIdxLoad, args&: LowestIdx, args&: LatestLoad) = *MaybeLoadInfo;
4260
4261 // We have a bunch of loads being OR'd together. Using the addresses + offsets
4262 // we found before, check if this corresponds to a big or little endian byte
4263 // pattern. If it does, then we can represent it using a load + possibly a
4264 // BSWAP.
4265 bool IsBigEndianTarget = MF.getDataLayout().isBigEndian();
4266 std::optional<bool> IsBigEndian = isBigEndian(MemOffset2Idx, LowestIdx);
4267 if (!IsBigEndian)
4268 return false;
4269 bool NeedsBSwap = IsBigEndianTarget != *IsBigEndian;
4270 if (NeedsBSwap && !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_BSWAP, {Ty}}))
4271 return false;
4272
4273 // Make sure that the load from the lowest index produces offset 0 in the
4274 // final value.
4275 //
4276 // This ensures that we won't combine something like this:
4277 //
4278 // load x[i] -> byte 2
4279 // load x[i+1] -> byte 0 ---> wide_load x[i]
4280 // load x[i+2] -> byte 1
4281 const unsigned NumLoadsInTy = WideMemSizeInBits / NarrowMemSizeInBits;
4282 const unsigned ZeroByteOffset =
4283 *IsBigEndian
4284 ? bigEndianByteAt(ByteWidth: NumLoadsInTy, I: 0)
4285 : littleEndianByteAt(ByteWidth: NumLoadsInTy, I: 0);
4286 auto ZeroOffsetIdx = MemOffset2Idx.find(Val: ZeroByteOffset);
4287 if (ZeroOffsetIdx == MemOffset2Idx.end() ||
4288 ZeroOffsetIdx->second != LowestIdx)
4289 return false;
4290
4291 // We will reuse the pointer from the load which ends up at byte offset 0. It
4292 // may not use index 0.
4293 Register Ptr = LowestIdxLoad->getPointerReg();
4294 const MachineMemOperand &MMO = LowestIdxLoad->getMMO();
4295 LegalityQuery::MemDesc MMDesc(MMO);
4296 MMDesc.MemoryTy = Ty;
4297 if (!isLegalOrBeforeLegalizer(
4298 Query: {TargetOpcode::G_LOAD, {Ty, MRI.getType(Reg: Ptr)}, {MMDesc}}))
4299 return false;
4300 auto PtrInfo = MMO.getPointerInfo();
4301 auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Size: WideMemSizeInBits / 8);
4302
4303 // Load must be allowed and fast on the target.
4304 LLVMContext &C = MF.getFunction().getContext();
4305 auto &DL = MF.getDataLayout();
4306 unsigned Fast = 0;
4307 if (!getTargetLowering().allowsMemoryAccess(Context&: C, DL, Ty, MMO: *NewMMO, Fast: &Fast) ||
4308 !Fast)
4309 return false;
4310
4311 MatchInfo = [=](MachineIRBuilder &MIB) {
4312 MIB.setInstrAndDebugLoc(*LatestLoad);
4313 Register LoadDst = NeedsBSwap ? MRI.cloneVirtualRegister(VReg: Dst) : Dst;
4314 MIB.buildLoad(Res: LoadDst, Addr: Ptr, MMO&: *NewMMO);
4315 if (NeedsBSwap)
4316 MIB.buildBSwap(Dst, Src0: LoadDst);
4317 };
4318 return true;
4319}
4320
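// Try to push an extend of a phi into each incoming block. Illustrative
// sketch (block and value names assumed):
//   bb.2:
//     %phi:_(s8) = G_PHI %a(s8), %bb.0, %b(s8), %bb.1
//     %ext:_(s32) = G_SEXT %phi(s8)
// becomes a G_PHI of two s32 values, with a G_SEXT of %a and of %b inserted
// at the end of %bb.0 and %bb.1 respectively.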
4321bool CombinerHelper::matchExtendThroughPhis(MachineInstr &MI,
4322 MachineInstr *&ExtMI) const {
4323 auto &PHI = cast<GPhi>(Val&: MI);
4324 Register DstReg = PHI.getReg(Idx: 0);
4325
4326 // TODO: Extending a vector may be expensive, don't do this until heuristics
4327 // are better.
4328 if (MRI.getType(Reg: DstReg).isVector())
4329 return false;
4330
4331 // Try to match a phi, whose only use is an extend.
4332 if (!MRI.hasOneNonDBGUse(RegNo: DstReg))
4333 return false;
4334 ExtMI = &*MRI.use_instr_nodbg_begin(RegNo: DstReg);
4335 switch (ExtMI->getOpcode()) {
4336 case TargetOpcode::G_ANYEXT:
4337 return true; // G_ANYEXT is usually free.
4338 case TargetOpcode::G_ZEXT:
4339 case TargetOpcode::G_SEXT:
4340 break;
4341 default:
4342 return false;
4343 }
4344
4345 // If the target is likely to fold this extend away, don't propagate.
4346 if (Builder.getTII().isExtendLikelyToBeFolded(ExtMI&: *ExtMI, MRI))
4347 return false;
4348
4349 // We don't want to propagate the extends unless there's a good chance that
4350 // they'll be optimized in some way.
4351 // Collect the unique incoming values.
4352 SmallPtrSet<MachineInstr *, 4> InSrcs;
4353 for (unsigned I = 0; I < PHI.getNumIncomingValues(); ++I) {
4354 auto *DefMI = getDefIgnoringCopies(Reg: PHI.getIncomingValue(I), MRI);
4355 switch (DefMI->getOpcode()) {
4356 case TargetOpcode::G_LOAD:
4357 case TargetOpcode::G_TRUNC:
4358 case TargetOpcode::G_SEXT:
4359 case TargetOpcode::G_ZEXT:
4360 case TargetOpcode::G_ANYEXT:
4361 case TargetOpcode::G_CONSTANT:
4362 InSrcs.insert(Ptr: DefMI);
4363 // Don't try to propagate if there are too many places to create new
4364 // extends, chances are it'll increase code size.
4365 if (InSrcs.size() > 2)
4366 return false;
4367 break;
4368 default:
4369 return false;
4370 }
4371 }
4372 return true;
4373}
4374
4375void CombinerHelper::applyExtendThroughPhis(MachineInstr &MI,
4376 MachineInstr *&ExtMI) const {
4377 auto &PHI = cast<GPhi>(Val&: MI);
4378 Register DstReg = ExtMI->getOperand(i: 0).getReg();
4379 LLT ExtTy = MRI.getType(Reg: DstReg);
4380
4381 // Propagate the extension into each incoming register's defining block.
4382 // Use a SetVector here because PHIs can have duplicate edges, and we want
4383 // deterministic iteration order.
4384 SmallSetVector<MachineInstr *, 8> SrcMIs;
4385 SmallDenseMap<MachineInstr *, MachineInstr *, 8> OldToNewSrcMap;
4386 for (unsigned I = 0; I < PHI.getNumIncomingValues(); ++I) {
4387 auto SrcReg = PHI.getIncomingValue(I);
4388 auto *SrcMI = MRI.getVRegDef(Reg: SrcReg);
4389 if (!SrcMIs.insert(X: SrcMI))
4390 continue;
4391
4392 // Build an extend after each src inst.
4393 auto *MBB = SrcMI->getParent();
4394 MachineBasicBlock::iterator InsertPt = ++SrcMI->getIterator();
4395 if (InsertPt != MBB->end() && InsertPt->isPHI())
4396 InsertPt = MBB->getFirstNonPHI();
4397
4398 Builder.setInsertPt(MBB&: *SrcMI->getParent(), II: InsertPt);
4399 Builder.setDebugLoc(MI.getDebugLoc());
4400 auto NewExt = Builder.buildExtOrTrunc(ExtOpc: ExtMI->getOpcode(), Res: ExtTy, Op: SrcReg);
4401 OldToNewSrcMap[SrcMI] = NewExt;
4402 }
4403
4404 // Create a new phi with the extended inputs.
4405 Builder.setInstrAndDebugLoc(MI);
4406 auto NewPhi = Builder.buildInstrNoInsert(Opcode: TargetOpcode::G_PHI);
4407 NewPhi.addDef(RegNo: DstReg);
4408 for (const MachineOperand &MO : llvm::drop_begin(RangeOrContainer: MI.operands())) {
4409 if (!MO.isReg()) {
4410 NewPhi.addMBB(MBB: MO.getMBB());
4411 continue;
4412 }
4413 auto *NewSrc = OldToNewSrcMap[MRI.getVRegDef(Reg: MO.getReg())];
4414 NewPhi.addUse(RegNo: NewSrc->getOperand(i: 0).getReg());
4415 }
4416 Builder.insertInstr(MIB: NewPhi);
4417 ExtMI->eraseFromParent();
4418}
4419
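// E.g. (illustrative, index constant shown inline): given
//   %v:_(<4 x s32>) = G_BUILD_VECTOR %a, %b, %c, %d
//   %e:_(s32) = G_EXTRACT_VECTOR_ELT %v, 1
// %e can be replaced by %b, provided the build_vector has one use or the
// target aggressively prefers build_vector sources.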
4420bool CombinerHelper::matchExtractVecEltBuildVec(MachineInstr &MI,
4421 Register &Reg) const {
4422 assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT);
4423 // If we have a constant index, look for a G_BUILD_VECTOR source
4424 // and find the source register that the index maps to.
  Register SrcVec = MI.getOperand(i: 1).getReg();
  LLT SrcTy = MRI.getType(Reg: SrcVec);
  if (SrcTy.isScalableVector())
    return false;

  auto Cst = getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI);
  if (!Cst || Cst->Value.getZExtValue() >= SrcTy.getNumElements())
    return false;

  unsigned VecIdx = Cst->Value.getZExtValue();

  // Check if we have a build_vector or build_vector_trunc with an optional
  // trunc in front.
  MachineInstr *SrcVecMI = MRI.getVRegDef(Reg: SrcVec);
  if (SrcVecMI->getOpcode() == TargetOpcode::G_TRUNC)
    SrcVecMI = MRI.getVRegDef(Reg: SrcVecMI->getOperand(i: 1).getReg());

  if (SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR &&
      SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR_TRUNC)
    return false;

  EVT Ty(getMVTForLLT(Ty: SrcTy));
  if (!MRI.hasOneNonDBGUse(RegNo: SrcVec) &&
      !getTargetLowering().aggressivelyPreferBuildVectorSources(VecVT: Ty))
    return false;

  Reg = SrcVecMI->getOperand(i: VecIdx + 1).getReg();
  return true;
}

void CombinerHelper::applyExtractVecEltBuildVec(MachineInstr &MI,
                                                Register &Reg) const {
  // Check the type of the register, since it may have come from a
  // G_BUILD_VECTOR_TRUNC.
  LLT ScalarTy = MRI.getType(Reg);
  Register DstReg = MI.getOperand(i: 0).getReg();
  LLT DstTy = MRI.getType(Reg: DstReg);

  if (ScalarTy != DstTy) {
    assert(ScalarTy.getSizeInBits() > DstTy.getSizeInBits());
    Builder.buildTrunc(Res: DstReg, Op: Reg);
    MI.eraseFromParent();
    return;
  }
  replaceSingleDefInstWithReg(MI, Replacement: Reg);
}

bool CombinerHelper::matchExtractAllEltsFromBuildVector(
    MachineInstr &MI,
    SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) const {
  assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
  // This combine tries to find build_vector's which have every source element
  // extracted using G_EXTRACT_VECTOR_ELT. This can happen when transforms like
  // masked load scalarization are run late in the pipeline. There's already a
  // combine for a similar pattern starting from the extract, but that doesn't
  // attempt to do it if there are multiple uses of the build_vector, which in
  // this case is true. Starting the combine from the build_vector feels more
  // natural than trying to find sibling nodes of extracts.
  // E.g.
  //  %vec(<4 x s32>) = G_BUILD_VECTOR %s1(s32), %s2, %s3, %s4
  //  %ext1 = G_EXTRACT_VECTOR_ELT %vec, 0
  //  %ext2 = G_EXTRACT_VECTOR_ELT %vec, 1
  //  %ext3 = G_EXTRACT_VECTOR_ELT %vec, 2
  //  %ext4 = G_EXTRACT_VECTOR_ELT %vec, 3
  // ==>
  // replace ext{1,2,3,4} with %s{1,2,3,4}

  Register DstReg = MI.getOperand(i: 0).getReg();
  LLT DstTy = MRI.getType(Reg: DstReg);
  unsigned NumElts = DstTy.getNumElements();

  SmallBitVector ExtractedElts(NumElts);
  for (MachineInstr &II : MRI.use_nodbg_instructions(Reg: DstReg)) {
    if (II.getOpcode() != TargetOpcode::G_EXTRACT_VECTOR_ELT)
      return false;
    auto Cst = getIConstantVRegVal(VReg: II.getOperand(i: 2).getReg(), MRI);
    if (!Cst)
      return false;
    unsigned Idx = Cst->getZExtValue();
    if (Idx >= NumElts)
      return false; // Out of range.
    ExtractedElts.set(Idx);
    SrcDstPairs.emplace_back(
        Args: std::make_pair(x: MI.getOperand(i: Idx + 1).getReg(), y: &II));
  }
  // Match if every element was extracted.
  return ExtractedElts.all();
}

void CombinerHelper::applyExtractAllEltsFromBuildVector(
    MachineInstr &MI,
    SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) const {
  assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
  for (auto &Pair : SrcDstPairs) {
    auto *ExtMI = Pair.second;
    replaceRegWith(MRI, FromReg: ExtMI->getOperand(i: 0).getReg(), ToReg: Pair.first);
    ExtMI->eraseFromParent();
  }
  MI.eraseFromParent();
}

void CombinerHelper::applyBuildFn(
    MachineInstr &MI,
    std::function<void(MachineIRBuilder &)> &MatchInfo) const {
  applyBuildFnNoErase(MI, MatchInfo);
  MI.eraseFromParent();
}

void CombinerHelper::applyBuildFnNoErase(
    MachineInstr &MI,
    std::function<void(MachineIRBuilder &)> &MatchInfo) const {
  MatchInfo(Builder);
}

bool CombinerHelper::matchOrShiftToFunnelShift(MachineInstr &MI,
                                               bool AllowScalarConstants,
                                               BuildFnTy &MatchInfo) const {
  assert(MI.getOpcode() == TargetOpcode::G_OR);

  Register Dst = MI.getOperand(i: 0).getReg();
  LLT Ty = MRI.getType(Reg: Dst);
  unsigned BitWidth = Ty.getScalarSizeInBits();

  Register ShlSrc, ShlAmt, LShrSrc, LShrAmt, Amt;
  unsigned FshOpc = 0;

  // Match (or (shl ...), (lshr ...)).
  if (!mi_match(R: Dst, MRI,
                // m_GOr() handles the commuted version as well.
                P: m_GOr(L: m_GShl(L: m_Reg(R&: ShlSrc), R: m_Reg(R&: ShlAmt)),
                       R: m_GLShr(L: m_Reg(R&: LShrSrc), R: m_Reg(R&: LShrAmt)))))
    return false;

  // Given constants C0 and C1 such that C0 + C1 is bit-width:
  // (or (shl x, C0), (lshr y, C1)) -> (fshl x, y, C0) or (fshr x, y, C1)
  int64_t CstShlAmt = 0, CstLShrAmt;
  if (mi_match(R: ShlAmt, MRI, P: m_ICstOrSplat(Cst&: CstShlAmt)) &&
      mi_match(R: LShrAmt, MRI, P: m_ICstOrSplat(Cst&: CstLShrAmt)) &&
      CstShlAmt + CstLShrAmt == BitWidth) {
    FshOpc = TargetOpcode::G_FSHR;
    Amt = LShrAmt;
  } else if (mi_match(R: LShrAmt, MRI,
                      P: m_GSub(L: m_SpecificICstOrSplat(RequestedValue: BitWidth), R: m_Reg(R&: Amt))) &&
             ShlAmt == Amt) {
    // (or (shl x, amt), (lshr y, (sub bw, amt))) -> (fshl x, y, amt)
    FshOpc = TargetOpcode::G_FSHL;
  } else if (mi_match(R: ShlAmt, MRI,
                      P: m_GSub(L: m_SpecificICstOrSplat(RequestedValue: BitWidth), R: m_Reg(R&: Amt))) &&
             LShrAmt == Amt) {
    // (or (shl x, (sub bw, amt)), (lshr y, amt)) -> (fshr x, y, amt)
    FshOpc = TargetOpcode::G_FSHR;
  } else {
    return false;
  }

  LLT AmtTy = MRI.getType(Reg: Amt);
  if (!isLegalOrBeforeLegalizer(Query: {FshOpc, {Ty, AmtTy}}) &&
      (!AllowScalarConstants || CstShlAmt == 0 || !Ty.isScalar()))
    return false;

  MatchInfo = [=](MachineIRBuilder &B) {
    B.buildInstr(Opc: FshOpc, DstOps: {Dst}, SrcOps: {ShlSrc, LShrSrc, Amt});
  };
  return true;
}

/// Match an FSHL or FSHR that can be combined to a ROTR or ROTL rotate.
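/// E.g. %d = G_FSHL %x, %x, %amt can become %d = G_ROTL %x, %amt (when G_ROTL
/// is available): a funnel shift whose two data operands are the same
/// register is just a rotate.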
bool CombinerHelper::matchFunnelShiftToRotate(MachineInstr &MI) const {
  unsigned Opc = MI.getOpcode();
  assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR);
  Register X = MI.getOperand(i: 1).getReg();
  Register Y = MI.getOperand(i: 2).getReg();
  if (X != Y)
    return false;
  unsigned RotateOpc =
      Opc == TargetOpcode::G_FSHL ? TargetOpcode::G_ROTL : TargetOpcode::G_ROTR;
  return isLegalOrBeforeLegalizer(Query: {RotateOpc, {MRI.getType(Reg: X), MRI.getType(Reg: Y)}});
}

void CombinerHelper::applyFunnelShiftToRotate(MachineInstr &MI) const {
  unsigned Opc = MI.getOpcode();
  assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR);
  bool IsFSHL = Opc == TargetOpcode::G_FSHL;
  Observer.changingInstr(MI);
  MI.setDesc(Builder.getTII().get(Opcode: IsFSHL ? TargetOpcode::G_ROTL
                                        : TargetOpcode::G_ROTR));
  MI.removeOperand(OpNo: 2);
  Observer.changedInstr(MI);
}

// Fold (rot x, c) -> (rot x, c % BitSize)
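// E.g. on s32, (rotl x, 40) -> (rotl x, 8): rotating by the bit width is the
// identity, so only the amount modulo the width matters.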
bool CombinerHelper::matchRotateOutOfRange(MachineInstr &MI) const {
  assert(MI.getOpcode() == TargetOpcode::G_ROTL ||
         MI.getOpcode() == TargetOpcode::G_ROTR);
  unsigned Bitsize =
      MRI.getType(Reg: MI.getOperand(i: 0).getReg()).getScalarSizeInBits();
  Register AmtReg = MI.getOperand(i: 2).getReg();
  bool OutOfRange = false;
  auto MatchOutOfRange = [Bitsize, &OutOfRange](const Constant *C) {
    if (auto *CI = dyn_cast<ConstantInt>(Val: C))
      OutOfRange |= CI->getValue().uge(RHS: Bitsize);
    return true;
  };
  return matchUnaryPredicate(MRI, Reg: AmtReg, Match: MatchOutOfRange) && OutOfRange;
}

void CombinerHelper::applyRotateOutOfRange(MachineInstr &MI) const {
  assert(MI.getOpcode() == TargetOpcode::G_ROTL ||
         MI.getOpcode() == TargetOpcode::G_ROTR);
  unsigned Bitsize =
      MRI.getType(Reg: MI.getOperand(i: 0).getReg()).getScalarSizeInBits();
  Register Amt = MI.getOperand(i: 2).getReg();
  LLT AmtTy = MRI.getType(Reg: Amt);
  auto Bits = Builder.buildConstant(Res: AmtTy, Val: Bitsize);
  Amt = Builder.buildURem(Dst: AmtTy, Src0: MI.getOperand(i: 2).getReg(), Src1: Bits).getReg(Idx: 0);
  Observer.changingInstr(MI);
  MI.getOperand(i: 2).setReg(Amt);
  Observer.changedInstr(MI);
}

bool CombinerHelper::matchICmpToTrueFalseKnownBits(MachineInstr &MI,
                                                   int64_t &MatchInfo) const {
  assert(MI.getOpcode() == TargetOpcode::G_ICMP);
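  // Fold the compare to a constant when known bits decide it. E.g.
  // (icmp ult (and x, 7), 8) is always true because the LHS is at most 7.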
  auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(i: 1).getPredicate());

  // We want to avoid calling KnownBits on the LHS if possible, as this combine
  // has no filter and runs on every G_ICMP instruction. We can avoid calling
  // KnownBits on the LHS in two cases:
  //
  // - The RHS is unknown: Constants are always on the RHS. If the RHS is
  //   unknown, we cannot do any transforms, so we can safely bail out early.
  // - The RHS is zero: we don't need to know the LHS to fold the unsigned
  //   comparisons x u< 0 (always false) and x u>= 0 (always true).
  auto KnownRHS = VT->getKnownBits(R: MI.getOperand(i: 3).getReg());
  if (KnownRHS.isUnknown())
    return false;

  std::optional<bool> KnownVal;
  if (KnownRHS.isZero()) {
    // ? uge 0 -> always true
    // ? ult 0 -> always false
    if (Pred == CmpInst::ICMP_UGE)
      KnownVal = true;
    else if (Pred == CmpInst::ICMP_ULT)
      KnownVal = false;
  }

  if (!KnownVal) {
    auto KnownLHS = VT->getKnownBits(R: MI.getOperand(i: 2).getReg());
    KnownVal = ICmpInst::compare(LHS: KnownLHS, RHS: KnownRHS, Pred);
  }

  if (!KnownVal)
    return false;
  MatchInfo =
      *KnownVal
          ? getICmpTrueVal(TLI: getTargetLowering(),
                           /*IsVector = */
                           MRI.getType(Reg: MI.getOperand(i: 0).getReg()).isVector(),
                           /* IsFP = */ false)
          : 0;
  return true;
}

bool CombinerHelper::matchICmpToLHSKnownBits(
    MachineInstr &MI,
    std::function<void(MachineIRBuilder &)> &MatchInfo) const {
  assert(MI.getOpcode() == TargetOpcode::G_ICMP);
  // Given:
  //
  // %x = G_WHATEVER (... x is known to be 0 or 1 ...)
  // %cmp = G_ICMP ne %x, 0
  //
  // Or:
  //
  // %x = G_WHATEVER (... x is known to be 0 or 1 ...)
  // %cmp = G_ICMP eq %x, 1
  //
  // We can replace %cmp with %x assuming true is 1 on the target.
  auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(i: 1).getPredicate());
  if (!CmpInst::isEquality(pred: Pred))
    return false;
  Register Dst = MI.getOperand(i: 0).getReg();
  LLT DstTy = MRI.getType(Reg: Dst);
  if (getICmpTrueVal(TLI: getTargetLowering(), IsVector: DstTy.isVector(),
                     /* IsFP = */ false) != 1)
    return false;
  int64_t OneOrZero = Pred == CmpInst::ICMP_EQ;
  if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICst(RequestedValue: OneOrZero)))
    return false;
  Register LHS = MI.getOperand(i: 2).getReg();
  auto KnownLHS = VT->getKnownBits(R: LHS);
  if (KnownLHS.getMinValue() != 0 || KnownLHS.getMaxValue() != 1)
    return false;
  // Make sure replacing Dst with the LHS is a legal operation.
  LLT LHSTy = MRI.getType(Reg: LHS);
  unsigned LHSSize = LHSTy.getSizeInBits();
  unsigned DstSize = DstTy.getSizeInBits();
  unsigned Op = TargetOpcode::COPY;
  if (DstSize != LHSSize)
    Op = DstSize < LHSSize ? TargetOpcode::G_TRUNC : TargetOpcode::G_ZEXT;
  if (!isLegalOrBeforeLegalizer(Query: {Op, {DstTy, LHSTy}}))
    return false;
  MatchInfo = [=](MachineIRBuilder &B) { B.buildInstr(Opc: Op, DstOps: {Dst}, SrcOps: {LHS}); };
  return true;
}

// Replace (and (or x, c1), c2) with (and x, c2) iff c1 & c2 == 0
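// E.g. (and (or x, 0xF0), 0x0F) -> (and x, 0x0F): the OR can only set bits
// that the AND mask clears anyway.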
bool CombinerHelper::matchAndOrDisjointMask(
    MachineInstr &MI,
    std::function<void(MachineIRBuilder &)> &MatchInfo) const {
  assert(MI.getOpcode() == TargetOpcode::G_AND);

  // Ignore vector types to simplify matching the two constants.
  // TODO: do this for vectors and scalars via a demanded bits analysis.
  LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
  if (Ty.isVector())
    return false;

  Register Src;
  Register AndMaskReg;
  int64_t AndMaskBits;
  int64_t OrMaskBits;
  if (!mi_match(MI, MRI,
                P: m_GAnd(L: m_GOr(L: m_Reg(R&: Src), R: m_ICst(Cst&: OrMaskBits)),
                        R: m_all_of(preds: m_ICst(Cst&: AndMaskBits), preds: m_Reg(R&: AndMaskReg)))))
    return false;

  // Check if OrMask could turn on any bits in Src.
  if (AndMaskBits & OrMaskBits)
    return false;

  MatchInfo = [=, &MI](MachineIRBuilder &B) {
    Observer.changingInstr(MI);
    // Canonicalize the result to have the constant on the RHS.
    if (MI.getOperand(i: 1).getReg() == AndMaskReg)
      MI.getOperand(i: 2).setReg(AndMaskReg);
    MI.getOperand(i: 1).setReg(Src);
    Observer.changedInstr(MI);
  };
  return true;
}

/// Form a G_SBFX from a G_SEXT_INREG fed by a right shift.
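/// E.g. %sra = G_ASHR %x, 4; %r = G_SEXT_INREG %sra, 8 becomes
/// %r = G_SBFX %x, 4, 8, i.e. sign-extract 8 bits starting at bit 4.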
bool CombinerHelper::matchBitfieldExtractFromSExtInReg(
    MachineInstr &MI,
    std::function<void(MachineIRBuilder &)> &MatchInfo) const {
  assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
  Register Dst = MI.getOperand(i: 0).getReg();
  Register Src = MI.getOperand(i: 1).getReg();
  LLT Ty = MRI.getType(Reg: Src);
  LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
  if (!LI || !LI->isLegalOrCustom(Query: {TargetOpcode::G_SBFX, {Ty, ExtractTy}}))
    return false;
  int64_t Width = MI.getOperand(i: 2).getImm();
  Register ShiftSrc;
  int64_t ShiftImm;
  if (!mi_match(
          R: Src, MRI,
          P: m_OneNonDBGUse(SP: m_any_of(preds: m_GAShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: ShiftImm)),
                                   preds: m_GLShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: ShiftImm))))))
    return false;
  if (ShiftImm < 0 || ShiftImm + Width > Ty.getScalarSizeInBits())
    return false;

  MatchInfo = [=](MachineIRBuilder &B) {
    auto Cst1 = B.buildConstant(Res: ExtractTy, Val: ShiftImm);
    auto Cst2 = B.buildConstant(Res: ExtractTy, Val: Width);
    B.buildSbfx(Dst, Src: ShiftSrc, LSB: Cst1, Width: Cst2);
  };
  return true;
}

/// Form a G_UBFX from "(a srl b) & mask", where b and mask are constants.
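/// E.g. %shr = G_LSHR %x, 3; %r = G_AND %shr, 0xFF becomes
/// %r = G_UBFX %x, 3, 8, i.e. zero-extract 8 bits starting at bit 3.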
bool CombinerHelper::matchBitfieldExtractFromAnd(MachineInstr &MI,
                                                 BuildFnTy &MatchInfo) const {
  GAnd *And = cast<GAnd>(Val: &MI);
  Register Dst = And->getReg(Idx: 0);
  LLT Ty = MRI.getType(Reg: Dst);
  LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
  // Note that isLegalOrBeforeLegalizer is stricter and does not take custom
  // into account.
  if (LI && !LI->isLegalOrCustom(Query: {TargetOpcode::G_UBFX, {Ty, ExtractTy}}))
    return false;

  int64_t AndImm, LSBImm;
  Register ShiftSrc;
  const unsigned Size = Ty.getScalarSizeInBits();
  if (!mi_match(R: And->getReg(Idx: 0), MRI,
                P: m_GAnd(L: m_OneNonDBGUse(SP: m_GLShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: LSBImm))),
                        R: m_ICst(Cst&: AndImm))))
    return false;

  // The mask is a mask of the low bits iff imm & (imm + 1) == 0.
  auto MaybeMask = static_cast<uint64_t>(AndImm);
  if (MaybeMask & (MaybeMask + 1))
    return false;

  // LSB must fit within the register.
  if (static_cast<uint64_t>(LSBImm) >= Size)
    return false;

  uint64_t Width = APInt(Size, AndImm).countr_one();
  MatchInfo = [=](MachineIRBuilder &B) {
    auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width);
    auto LSBCst = B.buildConstant(Res: ExtractTy, Val: LSBImm);
    B.buildInstr(Opc: TargetOpcode::G_UBFX, DstOps: {Dst}, SrcOps: {ShiftSrc, LSBCst, WidthCst});
  };
  return true;
}

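/// Form a G_SBFX/G_UBFX from a left shift followed by a right shift, e.g.
///   %shl = G_SHL %x, 4; %r = G_LSHR %shl, 8
/// becomes %r = G_UBFX %x, 4, 24 on s32 (Pos = 8 - 4, Width = 32 - 8).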
bool CombinerHelper::matchBitfieldExtractFromShr(
    MachineInstr &MI,
    std::function<void(MachineIRBuilder &)> &MatchInfo) const {
  const unsigned Opcode = MI.getOpcode();
  assert(Opcode == TargetOpcode::G_ASHR || Opcode == TargetOpcode::G_LSHR);

  const Register Dst = MI.getOperand(i: 0).getReg();

  const unsigned ExtrOpcode = Opcode == TargetOpcode::G_ASHR
                                  ? TargetOpcode::G_SBFX
                                  : TargetOpcode::G_UBFX;

  // Check if the type we would use for the extract is legal.
  LLT Ty = MRI.getType(Reg: Dst);
  LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
  if (!LI || !LI->isLegalOrCustom(Query: {ExtrOpcode, {Ty, ExtractTy}}))
    return false;

  Register ShlSrc;
  int64_t ShrAmt;
  int64_t ShlAmt;
  const unsigned Size = Ty.getScalarSizeInBits();

  // Try to match shr (shl x, c1), c2.
  if (!mi_match(R: Dst, MRI,
                P: m_BinOp(Opcode,
                         L: m_OneNonDBGUse(SP: m_GShl(L: m_Reg(R&: ShlSrc), R: m_ICst(Cst&: ShlAmt))),
                         R: m_ICst(Cst&: ShrAmt))))
    return false;

  // Make sure that the shift sizes can fit a bitfield extract.
  if (ShlAmt < 0 || ShlAmt > ShrAmt || ShrAmt >= Size)
    return false;

  // Skip this combine if the G_SEXT_INREG combine could handle it.
  if (Opcode == TargetOpcode::G_ASHR && ShlAmt == ShrAmt)
    return false;

  // Calculate start position and width of the extract.
  const int64_t Pos = ShrAmt - ShlAmt;
  const int64_t Width = Size - ShrAmt;

  MatchInfo = [=](MachineIRBuilder &B) {
    auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width);
    auto PosCst = B.buildConstant(Res: ExtractTy, Val: Pos);
    B.buildInstr(Opc: ExtrOpcode, DstOps: {Dst}, SrcOps: {ShlSrc, PosCst, WidthCst});
  };
  return true;
}

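/// Form a G_UBFX from a right shift of a masked value, e.g.
///   %and = G_AND %x, 0xFF0; %r = G_LSHR %and, 4
/// becomes %r = G_UBFX %x, 4, 8, provided the combined mask has no holes.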
bool CombinerHelper::matchBitfieldExtractFromShrAnd(
    MachineInstr &MI,
    std::function<void(MachineIRBuilder &)> &MatchInfo) const {
  const unsigned Opcode = MI.getOpcode();
  assert(Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_ASHR);

  const Register Dst = MI.getOperand(i: 0).getReg();
  LLT Ty = MRI.getType(Reg: Dst);
  LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
  if (LI && !LI->isLegalOrCustom(Query: {TargetOpcode::G_UBFX, {Ty, ExtractTy}}))
    return false;

  // Try to match shr (and x, c1), c2.
  Register AndSrc;
  int64_t ShrAmt;
  int64_t SMask;
  if (!mi_match(R: Dst, MRI,
                P: m_BinOp(Opcode,
                         L: m_OneNonDBGUse(SP: m_GAnd(L: m_Reg(R&: AndSrc), R: m_ICst(Cst&: SMask))),
                         R: m_ICst(Cst&: ShrAmt))))
    return false;

  const unsigned Size = Ty.getScalarSizeInBits();
  if (ShrAmt < 0 || ShrAmt >= Size)
    return false;

  // If the shift subsumes the mask, emit the 0 directly.
  if (0 == (SMask >> ShrAmt)) {
    MatchInfo = [=](MachineIRBuilder &B) { B.buildConstant(Res: Dst, Val: 0); };
    return true;
  }

  // Check that ubfx can do the extraction, with no holes in the mask.
  uint64_t UMask = SMask;
  UMask |= maskTrailingOnes<uint64_t>(N: ShrAmt);
  UMask &= maskTrailingOnes<uint64_t>(N: Size);
  if (!isMask_64(Value: UMask))
    return false;

  // Calculate start position and width of the extract.
  const int64_t Pos = ShrAmt;
  const int64_t Width = llvm::countr_one(Value: UMask) - ShrAmt;

  // It's preferable to keep the shift, rather than form G_SBFX.
  // TODO: remove the G_AND via demanded bits analysis.
  if (Opcode == TargetOpcode::G_ASHR && Width + ShrAmt == Size)
    return false;

  MatchInfo = [=](MachineIRBuilder &B) {
    auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width);
    auto PosCst = B.buildConstant(Res: ExtractTy, Val: Pos);
    B.buildInstr(Opc: TargetOpcode::G_UBFX, DstOps: {Dst}, SrcOps: {AndSrc, PosCst, WidthCst});
  };
  return true;
}

bool CombinerHelper::reassociationCanBreakAddressingModePattern(
    MachineInstr &MI) const {
  auto &PtrAdd = cast<GPtrAdd>(Val&: MI);

  Register Src1Reg = PtrAdd.getBaseReg();
  auto *Src1Def = getOpcodeDef<GPtrAdd>(Reg: Src1Reg, MRI);
  if (!Src1Def)
    return false;

  Register Src2Reg = PtrAdd.getOffsetReg();

  if (MRI.hasOneNonDBGUse(RegNo: Src1Reg))
    return false;

  auto C1 = getIConstantVRegVal(VReg: Src1Def->getOffsetReg(), MRI);
  if (!C1)
    return false;
  auto C2 = getIConstantVRegVal(VReg: Src2Reg, MRI);
  if (!C2)
    return false;

  const APInt &C1APIntVal = *C1;
  const APInt &C2APIntVal = *C2;
  const int64_t CombinedValue = (C1APIntVal + C2APIntVal).getSExtValue();

  for (auto &UseMI : MRI.use_nodbg_instructions(Reg: PtrAdd.getReg(Idx: 0))) {
    // This combine may end up running before ptrtoint/inttoptr combines
    // manage to eliminate redundant conversions, so try to look through them.
    MachineInstr *ConvUseMI = &UseMI;
    unsigned ConvUseOpc = ConvUseMI->getOpcode();
    while (ConvUseOpc == TargetOpcode::G_INTTOPTR ||
           ConvUseOpc == TargetOpcode::G_PTRTOINT) {
      Register DefReg = ConvUseMI->getOperand(i: 0).getReg();
      if (!MRI.hasOneNonDBGUse(RegNo: DefReg))
        break;
      ConvUseMI = &*MRI.use_instr_nodbg_begin(RegNo: DefReg);
      ConvUseOpc = ConvUseMI->getOpcode();
    }
    auto *LdStMI = dyn_cast<GLoadStore>(Val: ConvUseMI);
    if (!LdStMI)
      continue;
    // Is x[offset2] already not a legal addressing mode? If so then
    // reassociating the constants breaks nothing (we test offset2 because
    // that's the one we hope to fold into the load or store).
    TargetLoweringBase::AddrMode AM;
    AM.HasBaseReg = true;
    AM.BaseOffs = C2APIntVal.getSExtValue();
    unsigned AS = MRI.getType(Reg: LdStMI->getPointerReg()).getAddressSpace();
    Type *AccessTy = getTypeForLLT(Ty: LdStMI->getMMO().getMemoryType(),
                                   C&: PtrAdd.getMF()->getFunction().getContext());
    const auto &TLI = *PtrAdd.getMF()->getSubtarget().getTargetLowering();
    if (!TLI.isLegalAddressingMode(DL: PtrAdd.getMF()->getDataLayout(), AM,
                                   Ty: AccessTy, AddrSpace: AS))
      continue;

    // Would x[offset1+offset2] still be a legal addressing mode?
    AM.BaseOffs = CombinedValue;
    if (!TLI.isLegalAddressingMode(DL: PtrAdd.getMF()->getDataLayout(), AM,
                                   Ty: AccessTy, AddrSpace: AS))
      return true;
  }

  return false;
}

bool CombinerHelper::matchReassocConstantInnerRHS(GPtrAdd &MI,
                                                  MachineInstr *RHS,
                                                  BuildFnTy &MatchInfo) const {
  // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C)
  Register Src1Reg = MI.getOperand(i: 1).getReg();
  if (RHS->getOpcode() != TargetOpcode::G_ADD)
    return false;
  auto C2 = getIConstantVRegVal(VReg: RHS->getOperand(i: 2).getReg(), MRI);
  if (!C2)
    return false;

  // If both additions are nuw, the reassociated additions are also nuw.
  // If the original G_PTR_ADD is additionally nusw, X and C are both
  // non-negative, so BASE+X is between BASE and BASE+(X+C). The new
  // G_PTR_ADDs are therefore also nusw.
  // If the original G_PTR_ADD is additionally inbounds (which implies nusw),
  // the new G_PTR_ADDs are then also inbounds.
  unsigned PtrAddFlags = MI.getFlags();
  unsigned AddFlags = RHS->getFlags();
  bool IsNoUWrap = PtrAddFlags & AddFlags & MachineInstr::MIFlag::NoUWrap;
  bool IsNoUSWrap = IsNoUWrap && (PtrAddFlags & MachineInstr::MIFlag::NoUSWrap);
  bool IsInBounds = IsNoUWrap && (PtrAddFlags & MachineInstr::MIFlag::InBounds);
  unsigned Flags = 0;
  if (IsNoUWrap)
    Flags |= MachineInstr::MIFlag::NoUWrap;
  if (IsNoUSWrap)
    Flags |= MachineInstr::MIFlag::NoUSWrap;
  if (IsInBounds)
    Flags |= MachineInstr::MIFlag::InBounds;

  MatchInfo = [=, &MI](MachineIRBuilder &B) {
    LLT PtrTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());

    auto NewBase =
        Builder.buildPtrAdd(Res: PtrTy, Op0: Src1Reg, Op1: RHS->getOperand(i: 1).getReg(), Flags);
    Observer.changingInstr(MI);
    MI.getOperand(i: 1).setReg(NewBase.getReg(Idx: 0));
    MI.getOperand(i: 2).setReg(RHS->getOperand(i: 2).getReg());
    MI.setFlags(Flags);
    Observer.changedInstr(MI);
  };
  return !reassociationCanBreakAddressingModePattern(MI);
}

bool CombinerHelper::matchReassocConstantInnerLHS(GPtrAdd &MI,
                                                  MachineInstr *LHS,
                                                  MachineInstr *RHS,
                                                  BuildFnTy &MatchInfo) const {
  // (G_PTR_ADD (G_PTR_ADD X, C), Y) -> (G_PTR_ADD (G_PTR_ADD X, Y), C)
  // if and only if (G_PTR_ADD X, C) has one use.
  Register LHSBase;
  std::optional<ValueAndVReg> LHSCstOff;
  if (!mi_match(R: MI.getBaseReg(), MRI,
                P: m_OneNonDBGUse(SP: m_GPtrAdd(L: m_Reg(R&: LHSBase), R: m_GCst(ValReg&: LHSCstOff)))))
    return false;

  auto *LHSPtrAdd = cast<GPtrAdd>(Val: LHS);

  // Reassociating nuw additions preserves nuw. If both original G_PTR_ADDs are
  // nuw and inbounds (which implies nusw), the offsets are both non-negative,
  // so the new G_PTR_ADDs are also inbounds.
  unsigned PtrAddFlags = MI.getFlags();
  unsigned LHSPtrAddFlags = LHSPtrAdd->getFlags();
  bool IsNoUWrap = PtrAddFlags & LHSPtrAddFlags & MachineInstr::MIFlag::NoUWrap;
  bool IsNoUSWrap = IsNoUWrap && (PtrAddFlags & LHSPtrAddFlags &
                                  MachineInstr::MIFlag::NoUSWrap);
  bool IsInBounds = IsNoUWrap && (PtrAddFlags & LHSPtrAddFlags &
                                  MachineInstr::MIFlag::InBounds);
  unsigned Flags = 0;
  if (IsNoUWrap)
    Flags |= MachineInstr::MIFlag::NoUWrap;
  if (IsNoUSWrap)
    Flags |= MachineInstr::MIFlag::NoUSWrap;
  if (IsInBounds)
    Flags |= MachineInstr::MIFlag::InBounds;

  MatchInfo = [=, &MI](MachineIRBuilder &B) {
    // When we change LHSPtrAdd's offset register we might cause it to use a
    // reg before its def. Sink the instruction to just before the outer
    // PTR_ADD to ensure this doesn't happen.
    LHSPtrAdd->moveBefore(MovePos: &MI);
    Register RHSReg = MI.getOffsetReg();
    // Setting the register directly would cause a type mismatch if the
    // constant came through an extend or trunc, so build a fresh constant of
    // the offset type instead.
    auto NewCst = B.buildConstant(Res: MRI.getType(Reg: RHSReg), Val: LHSCstOff->Value);
    Observer.changingInstr(MI);
    MI.getOperand(i: 2).setReg(NewCst.getReg(Idx: 0));
    MI.setFlags(Flags);
    Observer.changedInstr(MI);
    Observer.changingInstr(MI&: *LHSPtrAdd);
    LHSPtrAdd->getOperand(i: 2).setReg(RHSReg);
    LHSPtrAdd->setFlags(Flags);
    Observer.changedInstr(MI&: *LHSPtrAdd);
  };
  return !reassociationCanBreakAddressingModePattern(MI);
}

bool CombinerHelper::matchReassocFoldConstantsInSubTree(
    GPtrAdd &MI, MachineInstr *LHS, MachineInstr *RHS,
    BuildFnTy &MatchInfo) const {
  // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2)
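  // E.g. G_PTR_ADD(G_PTR_ADD(%base, 16), 8) -> G_PTR_ADD(%base, 24), letting
  // a single immediate-offset addressing mode absorb both constants.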
  auto *LHSPtrAdd = dyn_cast<GPtrAdd>(Val: LHS);
  if (!LHSPtrAdd)
    return false;

  Register Src2Reg = MI.getOperand(i: 2).getReg();
  Register LHSSrc1 = LHSPtrAdd->getBaseReg();
  Register LHSSrc2 = LHSPtrAdd->getOffsetReg();
  auto C1 = getIConstantVRegVal(VReg: LHSSrc2, MRI);
  if (!C1)
    return false;
  auto C2 = getIConstantVRegVal(VReg: Src2Reg, MRI);
  if (!C2)
    return false;

  // Reassociating nuw additions preserves nuw. If both original G_PTR_ADDs are
  // inbounds, reaching the same result in one G_PTR_ADD is also inbounds.
  // The nusw constraints are satisfied because imm1+imm2 cannot exceed the
  // largest signed integer that fits into the index type, which is the maximum
  // size of allocated objects according to the IR Language Reference.
  unsigned PtrAddFlags = MI.getFlags();
  unsigned LHSPtrAddFlags = LHSPtrAdd->getFlags();
  bool IsNoUWrap = PtrAddFlags & LHSPtrAddFlags & MachineInstr::MIFlag::NoUWrap;
  bool IsInBounds =
      PtrAddFlags & LHSPtrAddFlags & MachineInstr::MIFlag::InBounds;
  unsigned Flags = 0;
  if (IsNoUWrap)
    Flags |= MachineInstr::MIFlag::NoUWrap;
  if (IsInBounds) {
    Flags |= MachineInstr::MIFlag::InBounds;
    Flags |= MachineInstr::MIFlag::NoUSWrap;
  }

  MatchInfo = [=, &MI](MachineIRBuilder &B) {
    auto NewCst = B.buildConstant(Res: MRI.getType(Reg: Src2Reg), Val: *C1 + *C2);
    Observer.changingInstr(MI);
    MI.getOperand(i: 1).setReg(LHSSrc1);
    MI.getOperand(i: 2).setReg(NewCst.getReg(Idx: 0));
    MI.setFlags(Flags);
    Observer.changedInstr(MI);
  };
  return !reassociationCanBreakAddressingModePattern(MI);
}

bool CombinerHelper::matchReassocPtrAdd(MachineInstr &MI,
                                        BuildFnTy &MatchInfo) const {
  auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
  // We're trying to match a few pointer computation patterns here for
  // re-association opportunities.
  // 1) Isolating a constant operand to be on the RHS, e.g.:
  //    G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C)
  //
  // 2) Folding two constants in each sub-tree as long as such folding
  //    doesn't break a legal addressing mode.
  //    G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2)
  //
  // 3) Move a constant from the LHS of an inner op to the RHS of the outer.
  //    (G_PTR_ADD (G_PTR_ADD X, C), Y) -> (G_PTR_ADD (G_PTR_ADD X, Y), C)
  //    iff (G_PTR_ADD X, C) has one use.
  MachineInstr *LHS = MRI.getVRegDef(Reg: PtrAdd.getBaseReg());
  MachineInstr *RHS = MRI.getVRegDef(Reg: PtrAdd.getOffsetReg());

  // Try to match example 2.
  if (matchReassocFoldConstantsInSubTree(MI&: PtrAdd, LHS, RHS, MatchInfo))
    return true;

  // Try to match example 3.
  if (matchReassocConstantInnerLHS(MI&: PtrAdd, LHS, RHS, MatchInfo))
    return true;

  // Try to match example 1.
  if (matchReassocConstantInnerRHS(MI&: PtrAdd, RHS, MatchInfo))
    return true;

  return false;
}

bool CombinerHelper::tryReassocBinOp(unsigned Opc, Register DstReg,
                                     Register OpLHS, Register OpRHS,
                                     BuildFnTy &MatchInfo) const {
  LLT OpRHSTy = MRI.getType(Reg: OpRHS);
  MachineInstr *OpLHSDef = MRI.getVRegDef(Reg: OpLHS);

  if (OpLHSDef->getOpcode() != Opc)
    return false;

  MachineInstr *OpRHSDef = MRI.getVRegDef(Reg: OpRHS);
  Register OpLHSLHS = OpLHSDef->getOperand(i: 1).getReg();
  Register OpLHSRHS = OpLHSDef->getOperand(i: 2).getReg();

  // If the inner op is (X op C), pull the constant out so it can be folded
  // with other constants in the expression tree. Folding is not guaranteed, so
  // we might have (C1 op C2). In that case do not pull a constant out because
  // it won't help and can lead to infinite loops.
  if (isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: OpLHSRHS), MRI) &&
      !isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: OpLHSLHS), MRI)) {
    if (isConstantOrConstantSplatVector(MI&: *OpRHSDef, MRI)) {
      // (Opc (Opc X, C1), C2) -> (Opc X, (Opc C1, C2))
      MatchInfo = [=](MachineIRBuilder &B) {
        auto NewCst = B.buildInstr(Opc, DstOps: {OpRHSTy}, SrcOps: {OpLHSRHS, OpRHS});
        B.buildInstr(Opc, DstOps: {DstReg}, SrcOps: {OpLHSLHS, NewCst});
      };
      return true;
    }
    if (getTargetLowering().isReassocProfitable(MRI, N0: OpLHS, N1: OpRHS)) {
      // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1)
      // iff (op x, c1) has one use.
      MatchInfo = [=](MachineIRBuilder &B) {
        auto NewLHSLHS = B.buildInstr(Opc, DstOps: {OpRHSTy}, SrcOps: {OpLHSLHS, OpRHS});
        B.buildInstr(Opc, DstOps: {DstReg}, SrcOps: {NewLHSLHS, OpLHSRHS});
      };
      return true;
    }
  }

  return false;
}

bool CombinerHelper::matchReassocCommBinOp(MachineInstr &MI,
                                           BuildFnTy &MatchInfo) const {
  // We don't check if the reassociation will break a legal addressing mode
  // here since pointer arithmetic is handled by G_PTR_ADD.
  unsigned Opc = MI.getOpcode();
  Register DstReg = MI.getOperand(i: 0).getReg();
  Register LHSReg = MI.getOperand(i: 1).getReg();
  Register RHSReg = MI.getOperand(i: 2).getReg();

  if (tryReassocBinOp(Opc, DstReg, OpLHS: LHSReg, OpRHS: RHSReg, MatchInfo))
    return true;
  if (tryReassocBinOp(Opc, DstReg, OpLHS: RHSReg, OpRHS: LHSReg, MatchInfo))
    return true;
  return false;
}

bool CombinerHelper::matchConstantFoldCastOp(MachineInstr &MI,
                                             APInt &MatchInfo) const {
  LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
  Register SrcOp = MI.getOperand(i: 1).getReg();

  if (auto MaybeCst = ConstantFoldCastOp(Opcode: MI.getOpcode(), DstTy, Op0: SrcOp, MRI)) {
    MatchInfo = *MaybeCst;
    return true;
  }

  return false;
}

bool CombinerHelper::matchConstantFoldBinOp(MachineInstr &MI,
                                            APInt &MatchInfo) const {
  Register Op1 = MI.getOperand(i: 1).getReg();
  Register Op2 = MI.getOperand(i: 2).getReg();
  auto MaybeCst = ConstantFoldBinOp(Opcode: MI.getOpcode(), Op1, Op2, MRI);
  if (!MaybeCst)
    return false;
  MatchInfo = *MaybeCst;
  return true;
}

bool CombinerHelper::matchConstantFoldFPBinOp(MachineInstr &MI,
                                              ConstantFP *&MatchInfo) const {
  Register Op1 = MI.getOperand(i: 1).getReg();
  Register Op2 = MI.getOperand(i: 2).getReg();
  auto MaybeCst = ConstantFoldFPBinOp(Opcode: MI.getOpcode(), Op1, Op2, MRI);
  if (!MaybeCst)
    return false;
  MatchInfo =
      ConstantFP::get(Context&: MI.getMF()->getFunction().getContext(), V: *MaybeCst);
  return true;
}

bool CombinerHelper::matchConstantFoldFMA(MachineInstr &MI,
                                          ConstantFP *&MatchInfo) const {
  assert(MI.getOpcode() == TargetOpcode::G_FMA ||
         MI.getOpcode() == TargetOpcode::G_FMAD);
  auto [_, Op1, Op2, Op3] = MI.getFirst4Regs();

  const ConstantFP *Op3Cst = getConstantFPVRegVal(VReg: Op3, MRI);
  if (!Op3Cst)
    return false;

  const ConstantFP *Op2Cst = getConstantFPVRegVal(VReg: Op2, MRI);
  if (!Op2Cst)
    return false;

  const ConstantFP *Op1Cst = getConstantFPVRegVal(VReg: Op1, MRI);
  if (!Op1Cst)
    return false;

  APFloat Op1F = Op1Cst->getValueAPF();
  Op1F.fusedMultiplyAdd(Multiplicand: Op2Cst->getValueAPF(), Addend: Op3Cst->getValueAPF(),
                        RM: APFloat::rmNearestTiesToEven);
  MatchInfo = ConstantFP::get(Context&: MI.getMF()->getFunction().getContext(), V: Op1F);
  return true;
}

bool CombinerHelper::matchNarrowBinopFeedingAnd(
    MachineInstr &MI,
    std::function<void(MachineIRBuilder &)> &MatchInfo) const {
  // Look for a binop feeding into an AND with a mask:
  //
  // %add = G_ADD %lhs, %rhs
  // %and = G_AND %add, 000...11111111
  //
  // Check if it's possible to perform the binop at a narrower width and zext
  // back to the original width like so:
  //
  // %narrow_lhs = G_TRUNC %lhs
  // %narrow_rhs = G_TRUNC %rhs
  // %narrow_add = G_ADD %narrow_lhs, %narrow_rhs
  // %new_add = G_ZEXT %narrow_add
  // %and = G_AND %new_add, 000...11111111
  //
  // This can allow later combines to eliminate the G_AND if it turns out
  // that the mask is irrelevant.
  assert(MI.getOpcode() == TargetOpcode::G_AND);
  Register Dst = MI.getOperand(i: 0).getReg();
  Register AndLHS = MI.getOperand(i: 1).getReg();
  Register AndRHS = MI.getOperand(i: 2).getReg();
  LLT WideTy = MRI.getType(Reg: Dst);

  // If the potential binop has more than one use, then it's possible that one
  // of those uses will need its full width.
  if (!WideTy.isScalar() || !MRI.hasOneNonDBGUse(RegNo: AndLHS))
    return false;

  // Check if the LHS feeding the AND is impacted by the high bits that we're
  // masking out.
  //
  // e.g. for 64-bit x, y:
  //
  // add_64(x, y) & 65535 == zext(add_16(trunc(x), trunc(y))) & 65535
  MachineInstr *LHSInst = getDefIgnoringCopies(Reg: AndLHS, MRI);
  if (!LHSInst)
    return false;
  unsigned LHSOpc = LHSInst->getOpcode();
  switch (LHSOpc) {
  default:
    return false;
  case TargetOpcode::G_ADD:
  case TargetOpcode::G_SUB:
  case TargetOpcode::G_MUL:
  case TargetOpcode::G_AND:
  case TargetOpcode::G_OR:
  case TargetOpcode::G_XOR:
    break;
  }

  // Find the mask on the RHS.
  auto Cst = getIConstantVRegValWithLookThrough(VReg: AndRHS, MRI);
  if (!Cst)
    return false;
  auto Mask = Cst->Value;
  if (!Mask.isMask())
    return false;

  // No point in combining if there's nothing to truncate.
  unsigned NarrowWidth = Mask.countr_one();
  if (NarrowWidth == WideTy.getSizeInBits())
    return false;
  LLT NarrowTy = LLT::scalar(SizeInBits: NarrowWidth);

  // Check if adding the zext + truncates could be harmful.
  auto &MF = *MI.getMF();
  const auto &TLI = getTargetLowering();
  LLVMContext &Ctx = MF.getFunction().getContext();
  if (!TLI.isTruncateFree(FromTy: WideTy, ToTy: NarrowTy, Ctx) ||
      !TLI.isZExtFree(FromTy: NarrowTy, ToTy: WideTy, Ctx))
    return false;
  if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_TRUNC, {NarrowTy, WideTy}}) ||
      !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ZEXT, {WideTy, NarrowTy}}))
    return false;
  Register BinOpLHS = LHSInst->getOperand(i: 1).getReg();
  Register BinOpRHS = LHSInst->getOperand(i: 2).getReg();
  MatchInfo = [=, &MI](MachineIRBuilder &B) {
    auto NarrowLHS = Builder.buildTrunc(Res: NarrowTy, Op: BinOpLHS);
    auto NarrowRHS = Builder.buildTrunc(Res: NarrowTy, Op: BinOpRHS);
    auto NarrowBinOp =
        Builder.buildInstr(Opc: LHSOpc, DstOps: {NarrowTy}, SrcOps: {NarrowLHS, NarrowRHS});
    auto Ext = Builder.buildZExt(Res: WideTy, Op: NarrowBinOp);
    Observer.changingInstr(MI);
    MI.getOperand(i: 1).setReg(Ext.getReg(Idx: 0));
    Observer.changedInstr(MI);
  };
  return true;
}

bool CombinerHelper::matchMulOBy2(MachineInstr &MI,
                                  BuildFnTy &MatchInfo) const {
  unsigned Opc = MI.getOpcode();
  assert(Opc == TargetOpcode::G_UMULO || Opc == TargetOpcode::G_SMULO);
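  // (G_*MULO x, 2) -> (G_*ADDO x, x): doubling a value overflows exactly when
  // adding the value to itself does.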

  if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 2)))
    return false;

  MatchInfo = [=, &MI](MachineIRBuilder &B) {
    Observer.changingInstr(MI);
    unsigned NewOpc = Opc == TargetOpcode::G_UMULO ? TargetOpcode::G_UADDO
                                                   : TargetOpcode::G_SADDO;
    MI.setDesc(Builder.getTII().get(Opcode: NewOpc));
    MI.getOperand(i: 3).setReg(MI.getOperand(i: 2).getReg());
    Observer.changedInstr(MI);
  };
  return true;
}

bool CombinerHelper::matchMulOBy0(MachineInstr &MI,
                                  BuildFnTy &MatchInfo) const {
  // (G_*MULO x, 0) -> 0 + no carry out
  assert(MI.getOpcode() == TargetOpcode::G_UMULO ||
         MI.getOpcode() == TargetOpcode::G_SMULO);
  if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 0)))
    return false;
  Register Dst = MI.getOperand(i: 0).getReg();
  Register Carry = MI.getOperand(i: 1).getReg();
  if (!isConstantLegalOrBeforeLegalizer(Ty: MRI.getType(Reg: Dst)) ||
      !isConstantLegalOrBeforeLegalizer(Ty: MRI.getType(Reg: Carry)))
    return false;
  MatchInfo = [=](MachineIRBuilder &B) {
    B.buildConstant(Res: Dst, Val: 0);
    B.buildConstant(Res: Carry, Val: 0);
  };
  return true;
}

bool CombinerHelper::matchAddEToAddO(MachineInstr &MI,
                                     BuildFnTy &MatchInfo) const {
  // (G_*ADDE x, y, 0) -> (G_*ADDO x, y)
  // (G_*SUBE x, y, 0) -> (G_*SUBO x, y)
  assert(MI.getOpcode() == TargetOpcode::G_UADDE ||
         MI.getOpcode() == TargetOpcode::G_SADDE ||
         MI.getOpcode() == TargetOpcode::G_USUBE ||
         MI.getOpcode() == TargetOpcode::G_SSUBE);
  if (!mi_match(R: MI.getOperand(i: 4).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 0)))
    return false;
  MatchInfo = [&](MachineIRBuilder &B) {
    unsigned NewOpcode;
    switch (MI.getOpcode()) {
    case TargetOpcode::G_UADDE:
      NewOpcode = TargetOpcode::G_UADDO;
      break;
    case TargetOpcode::G_SADDE:
      NewOpcode = TargetOpcode::G_SADDO;
      break;
    case TargetOpcode::G_USUBE:
      NewOpcode = TargetOpcode::G_USUBO;
      break;
    case TargetOpcode::G_SSUBE:
      NewOpcode = TargetOpcode::G_SSUBO;
      break;
    default:
      llvm_unreachable("Unexpected opcode; checked by the assert above");
    }
    Observer.changingInstr(MI);
    MI.setDesc(B.getTII().get(Opcode: NewOpcode));
    MI.removeOperand(OpNo: 4);
    Observer.changedInstr(MI);
  };
  return true;
}

bool CombinerHelper::matchSubAddSameReg(MachineInstr &MI,
                                        BuildFnTy &MatchInfo) const {
  assert(MI.getOpcode() == TargetOpcode::G_SUB);
  Register Dst = MI.getOperand(i: 0).getReg();
  // (x + y) - z -> x (if y == z)
  // (x + y) - z -> y (if x == z)
  Register X, Y, Z;
  if (mi_match(R: Dst, MRI, P: m_GSub(L: m_GAdd(L: m_Reg(R&: X), R: m_Reg(R&: Y)), R: m_Reg(R&: Z)))) {
    Register ReplaceReg;
    int64_t CstX, CstY;
    if (Y == Z || (mi_match(R: Y, MRI, P: m_ICstOrSplat(Cst&: CstY)) &&
                   mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstY))))
      ReplaceReg = X;
    else if (X == Z || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) &&
                        mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX))))
      ReplaceReg = Y;
    if (ReplaceReg) {
      MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Res: Dst, Op: ReplaceReg); };
      return true;
    }
  }

  // x - (y + z) -> 0 - y (if x == z)
  // x - (y + z) -> 0 - z (if x == y)
  if (mi_match(R: Dst, MRI, P: m_GSub(L: m_Reg(R&: X), R: m_GAdd(L: m_Reg(R&: Y), R: m_Reg(R&: Z))))) {
    Register ReplaceReg;
    int64_t CstX;
    if (X == Z || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) &&
                   mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX))))
      ReplaceReg = Y;
    else if (X == Y || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) &&
                        mi_match(R: Y, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX))))
      ReplaceReg = Z;
    if (ReplaceReg) {
      MatchInfo = [=](MachineIRBuilder &B) {
        auto Zero = B.buildConstant(Res: MRI.getType(Reg: Dst), Val: 0);
        B.buildSub(Dst, Src0: Zero, Src1: ReplaceReg);
      };
      return true;
    }
  }
  return false;
}

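// Lower a udiv/urem by constant into a multiply by a "magic" fixed-point
// reciprocal followed by shifts (see Hacker's Delight, chapter 10). As an
// illustrative example, a 32-bit udiv by 10 needs no pre-shift and no NPQ
// fixup and lowers to
//   %q = G_LSHR (G_UMULH %x, 0xCCCCCCCD), 3
// since floor((x * 0xCCCCCCCD) / 2^35) == floor(x / 10) for all 32-bit x.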
MachineInstr *CombinerHelper::buildUDivOrURemUsingMul(MachineInstr &MI) const {
  unsigned Opcode = MI.getOpcode();
  assert(Opcode == TargetOpcode::G_UDIV || Opcode == TargetOpcode::G_UREM);
  auto &UDivorRem = cast<GenericMachineInstr>(Val&: MI);
  Register Dst = UDivorRem.getReg(Idx: 0);
  Register LHS = UDivorRem.getReg(Idx: 1);
  Register RHS = UDivorRem.getReg(Idx: 2);
  LLT Ty = MRI.getType(Reg: Dst);
  LLT ScalarTy = Ty.getScalarType();
  const unsigned EltBits = ScalarTy.getScalarSizeInBits();
  LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
  LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType();

  auto &MIB = Builder;

  bool UseSRL = false;
  SmallVector<Register, 16> Shifts, Factors;
  auto *RHSDefInstr = cast<GenericMachineInstr>(Val: getDefIgnoringCopies(Reg: RHS, MRI));
  bool IsSplat = getIConstantSplatVal(MI: *RHSDefInstr, MRI).has_value();

  auto BuildExactUDIVPattern = [&](const Constant *C) {
    // Don't recompute inverses for each splat element.
    if (IsSplat && !Factors.empty()) {
      Shifts.push_back(Elt: Shifts[0]);
      Factors.push_back(Elt: Factors[0]);
      return true;
    }

    auto *CI = cast<ConstantInt>(Val: C);
    APInt Divisor = CI->getValue();
    unsigned Shift = Divisor.countr_zero();
    if (Shift) {
      Divisor.lshrInPlace(ShiftAmt: Shift);
      UseSRL = true;
    }

    // Calculate the multiplicative inverse modulo BW.
    APInt Factor = Divisor.multiplicativeInverse();
    Shifts.push_back(Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: Shift).getReg(Idx: 0));
    Factors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Factor).getReg(Idx: 0));
    return true;
  };

  if (MI.getFlag(Flag: MachineInstr::MIFlag::IsExact)) {
    // Collect all magic values from the build vector.
    if (!matchUnaryPredicate(MRI, Reg: RHS, Match: BuildExactUDIVPattern))
      llvm_unreachable("Expected unary predicate match to succeed");

    Register Shift, Factor;
    if (Ty.isVector()) {
      Shift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: Shifts).getReg(Idx: 0);
      Factor = MIB.buildBuildVector(Res: Ty, Ops: Factors).getReg(Idx: 0);
    } else {
      Shift = Shifts[0];
      Factor = Factors[0];
    }

    Register Res = LHS;

    if (UseSRL)
      Res = MIB.buildLShr(Dst: Ty, Src0: Res, Src1: Shift, Flags: MachineInstr::IsExact).getReg(Idx: 0);

    return MIB.buildMul(Dst: Ty, Src0: Res, Src1: Factor);
  }

  unsigned KnownLeadingZeros =
      VT ? VT->getKnownBits(R: LHS).countMinLeadingZeros() : 0;

  bool UseNPQ = false;
  SmallVector<Register, 16> PreShifts, PostShifts, MagicFactors, NPQFactors;
  auto BuildUDIVPattern = [&](const Constant *C) {
    auto *CI = cast<ConstantInt>(Val: C);
    const APInt &Divisor = CI->getValue();

    bool SelNPQ = false;
    APInt Magic(Divisor.getBitWidth(), 0);
    unsigned PreShift = 0, PostShift = 0;

    // Magic algorithm doesn't work for division by 1. We need to emit a select
    // at the end.
    // TODO: Use undef values for divisor of 1.
    if (!Divisor.isOne()) {
      // UnsignedDivisionByConstantInfo doesn't work correctly if the leading
      // zeros in the dividend exceed the leading zeros of the divisor.
      UnsignedDivisionByConstantInfo magics =
          UnsignedDivisionByConstantInfo::get(
              D: Divisor, LeadingZeros: std::min(a: KnownLeadingZeros, b: Divisor.countl_zero()));

      Magic = std::move(magics.Magic);

      assert(magics.PreShift < Divisor.getBitWidth() &&
             "We shouldn't generate an undefined shift!");
      assert(magics.PostShift < Divisor.getBitWidth() &&
             "We shouldn't generate an undefined shift!");
      assert((!magics.IsAdd || magics.PreShift == 0) && "Unexpected pre-shift");
      PreShift = magics.PreShift;
      PostShift = magics.PostShift;
      SelNPQ = magics.IsAdd;
    }

    PreShifts.push_back(
        Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: PreShift).getReg(Idx: 0));
    MagicFactors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Magic).getReg(Idx: 0));
    NPQFactors.push_back(
        Elt: MIB.buildConstant(Res: ScalarTy,
                           Val: SelNPQ ? APInt::getOneBitSet(numBits: EltBits, BitNo: EltBits - 1)
                                    : APInt::getZero(numBits: EltBits))
            .getReg(Idx: 0));
    PostShifts.push_back(
        Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: PostShift).getReg(Idx: 0));
    UseNPQ |= SelNPQ;
    return true;
  };

  // Collect the shifts/magic values from each element.
  bool Matched = matchUnaryPredicate(MRI, Reg: RHS, Match: BuildUDIVPattern);
  (void)Matched;
  assert(Matched && "Expected unary predicate match to succeed");

  Register PreShift, PostShift, MagicFactor, NPQFactor;
  auto *RHSDef = getOpcodeDef<GBuildVector>(Reg: RHS, MRI);
  if (RHSDef) {
    PreShift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: PreShifts).getReg(Idx: 0);
    MagicFactor = MIB.buildBuildVector(Res: Ty, Ops: MagicFactors).getReg(Idx: 0);
    NPQFactor = MIB.buildBuildVector(Res: Ty, Ops: NPQFactors).getReg(Idx: 0);
    PostShift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: PostShifts).getReg(Idx: 0);
  } else {
    assert(MRI.getType(RHS).isScalar() &&
           "Non-build_vector operation should have been a scalar");
    PreShift = PreShifts[0];
    MagicFactor = MagicFactors[0];
    PostShift = PostShifts[0];
  }

  Register Q = LHS;
  Q = MIB.buildLShr(Dst: Ty, Src0: Q, Src1: PreShift).getReg(Idx: 0);

  // Multiply the numerator (operand 0) by the magic value.
  Q = MIB.buildUMulH(Dst: Ty, Src0: Q, Src1: MagicFactor).getReg(Idx: 0);

  if (UseNPQ) {
    Register NPQ = MIB.buildSub(Dst: Ty, Src0: LHS, Src1: Q).getReg(Idx: 0);

    // For vectors we might have a mix of non-NPQ/NPQ paths, so use
    // G_UMULH to act as a SRL-by-1 for NPQ, else multiply by zero.
    if (Ty.isVector())
      NPQ = MIB.buildUMulH(Dst: Ty, Src0: NPQ, Src1: NPQFactor).getReg(Idx: 0);
    else
      NPQ = MIB.buildLShr(Dst: Ty, Src0: NPQ, Src1: MIB.buildConstant(Res: ShiftAmtTy, Val: 1)).getReg(Idx: 0);

    Q = MIB.buildAdd(Dst: Ty, Src0: NPQ, Src1: Q).getReg(Idx: 0);
  }

  Q = MIB.buildLShr(Dst: Ty, Src0: Q, Src1: PostShift).getReg(Idx: 0);
  auto One = MIB.buildConstant(Res: Ty, Val: 1);
  auto IsOne = MIB.buildICmp(
      Pred: CmpInst::Predicate::ICMP_EQ,
      Res: Ty.isScalar() ? LLT::scalar(SizeInBits: 1) : Ty.changeElementSize(NewEltSize: 1), Op0: RHS, Op1: One);
  auto ret = MIB.buildSelect(Res: Ty, Tst: IsOne, Op0: LHS, Op1: Q);

  if (Opcode == TargetOpcode::G_UREM) {
    auto Prod = MIB.buildMul(Dst: Ty, Src0: ret, Src1: RHS);
    return MIB.buildSub(Dst: Ty, Src0: LHS, Src1: Prod);
  }
  return ret;
}

bool CombinerHelper::matchUDivOrURemByConst(MachineInstr &MI) const {
  unsigned Opcode = MI.getOpcode();
  assert(Opcode == TargetOpcode::G_UDIV || Opcode == TargetOpcode::G_UREM);
  Register Dst = MI.getOperand(i: 0).getReg();
  Register RHS = MI.getOperand(i: 2).getReg();
  LLT DstTy = MRI.getType(Reg: Dst);

  auto &MF = *MI.getMF();
  AttributeList Attr = MF.getFunction().getAttributes();
  const auto &TLI = getTargetLowering();
  LLVMContext &Ctx = MF.getFunction().getContext();
  if (DstTy.getScalarSizeInBits() == 1 ||
      TLI.isIntDivCheap(VT: getApproximateEVTForLLT(Ty: DstTy, Ctx), Attr))
    return false;

  // Don't do this for minsize because the instruction sequence is usually
  // larger.
  if (MF.getFunction().hasMinSize())
    return false;

  if (Opcode == TargetOpcode::G_UDIV &&
      MI.getFlag(Flag: MachineInstr::MIFlag::IsExact)) {
    return matchUnaryPredicate(
        MRI, Reg: RHS, Match: [](const Constant *C) { return C && !C->isNullValue(); });
  }

  auto *RHSDef = MRI.getVRegDef(Reg: RHS);
  if (!isConstantOrConstantVector(MI&: *RHSDef, MRI))
    return false;

  // Don't do this if the types are not going to be legal.
  if (LI) {
    if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_MUL, {DstTy, DstTy}}))
      return false;
    if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_UMULH, {DstTy}}))
      return false;
    if (!isLegalOrBeforeLegalizer(
            Query: {TargetOpcode::G_ICMP,
             {DstTy.isVector() ? DstTy.changeElementSize(NewEltSize: 1) : LLT::scalar(SizeInBits: 1),
              DstTy}}))
      return false;
    if (Opcode == TargetOpcode::G_UREM &&
        !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SUB, {DstTy, DstTy}}))
      return false;
  }

  return matchUnaryPredicate(
      MRI, Reg: RHS, Match: [](const Constant *C) { return C && !C->isNullValue(); });
}

void CombinerHelper::applyUDivOrURemByConst(MachineInstr &MI) const {
  auto *NewMI = buildUDivOrURemUsingMul(MI);
  replaceSingleDefInstWithReg(MI, Replacement: NewMI->getOperand(i: 0).getReg());
}

bool CombinerHelper::matchSDivOrSRemByConst(MachineInstr &MI) const {
  unsigned Opcode = MI.getOpcode();
  assert(Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_SREM);
  Register Dst = MI.getOperand(i: 0).getReg();
  Register RHS = MI.getOperand(i: 2).getReg();
  LLT DstTy = MRI.getType(Reg: Dst);
  auto SizeInBits = DstTy.getScalarSizeInBits();
  LLT WideTy = DstTy.changeElementSize(NewEltSize: SizeInBits * 2);

  auto &MF = *MI.getMF();
  AttributeList Attr = MF.getFunction().getAttributes();
  const auto &TLI = getTargetLowering();
  LLVMContext &Ctx = MF.getFunction().getContext();
  if (DstTy.getScalarSizeInBits() < 3 ||
      TLI.isIntDivCheap(VT: getApproximateEVTForLLT(Ty: DstTy, Ctx), Attr))
    return false;

  // Don't do this for minsize because the instruction sequence is usually
  // larger.
  if (MF.getFunction().hasMinSize())
    return false;

  // If the sdiv has an 'exact' flag we can use a simpler lowering.
  if (Opcode == TargetOpcode::G_SDIV &&
      MI.getFlag(Flag: MachineInstr::MIFlag::IsExact)) {
    return matchUnaryPredicate(
        MRI, Reg: RHS, Match: [](const Constant *C) { return C && !C->isNullValue(); });
  }

  auto *RHSDef = MRI.getVRegDef(Reg: RHS);
  if (!isConstantOrConstantVector(MI&: *RHSDef, MRI))
    return false;

  // Don't do this if the types are not going to be legal.
  if (LI) {
    if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_MUL, {DstTy, DstTy}}))
      return false;
    if (!isLegal(Query: {TargetOpcode::G_SMULH, {DstTy}}) &&
        !isLegalOrHasWidenScalar(Query: {TargetOpcode::G_MUL, {WideTy, WideTy}}))
      return false;
    if (Opcode == TargetOpcode::G_SREM &&
        !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SUB, {DstTy, DstTy}}))
      return false;
  }

  return matchUnaryPredicate(
      MRI, Reg: RHS, Match: [](const Constant *C) { return C && !C->isNullValue(); });
}

void CombinerHelper::applySDivOrSRemByConst(MachineInstr &MI) const {
  auto *NewMI = buildSDivOrSRemUsingMul(MI);
  replaceSingleDefInstWithReg(MI, Replacement: NewMI->getOperand(i: 0).getReg());
}

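// Lower an sdiv/srem by constant into a multiply by a "magic" value plus
// shift and sign fixups (see Hacker's Delight, chapter 10). As an
// illustrative example, a 32-bit sdiv by 7 lowers to:
//   %h  = G_SMULH %x, 0x92492493  ; magic is negative, so add the numerator
//   %q0 = G_ADD %h, %x
//   %q1 = G_ASHR %q0, 2
//   %q  = G_ADD %q1, (G_LSHR %q1, 31) ; add one if the quotient is negative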
5787MachineInstr *CombinerHelper::buildSDivOrSRemUsingMul(MachineInstr &MI) const {
5788 unsigned Opcode = MI.getOpcode();
5789 assert(MI.getOpcode() == TargetOpcode::G_SDIV ||
5790 Opcode == TargetOpcode::G_SREM);
5791 auto &SDivorRem = cast<GenericMachineInstr>(Val&: MI);
5792 Register Dst = SDivorRem.getReg(Idx: 0);
5793 Register LHS = SDivorRem.getReg(Idx: 1);
5794 Register RHS = SDivorRem.getReg(Idx: 2);
5795 LLT Ty = MRI.getType(Reg: Dst);
5796 LLT ScalarTy = Ty.getScalarType();
5797 const unsigned EltBits = ScalarTy.getScalarSizeInBits();
5798 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
5799 LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType();
5800 auto &MIB = Builder;
5801
5802 bool UseSRA = false;
5803 SmallVector<Register, 16> ExactShifts, ExactFactors;
5804
5805 auto *RHSDefInstr = cast<GenericMachineInstr>(Val: getDefIgnoringCopies(Reg: RHS, MRI));
5806 bool IsSplat = getIConstantSplatVal(MI: *RHSDefInstr, MRI).has_value();
5807
5808 auto BuildExactSDIVPattern = [&](const Constant *C) {
5809 // Don't recompute inverses for each splat element.
5810 if (IsSplat && !ExactFactors.empty()) {
5811 ExactShifts.push_back(Elt: ExactShifts[0]);
5812 ExactFactors.push_back(Elt: ExactFactors[0]);
5813 return true;
5814 }
5815
5816 auto *CI = cast<ConstantInt>(Val: C);
5817 APInt Divisor = CI->getValue();
5818 unsigned Shift = Divisor.countr_zero();
5819 if (Shift) {
5820 Divisor.ashrInPlace(ShiftAmt: Shift);
5821 UseSRA = true;
5822 }
5823
5824 // Calculate the multiplicative inverse modulo BW.
5825 // 2^W requires W + 1 bits, so we have to extend and then truncate.
5826 APInt Factor = Divisor.multiplicativeInverse();
5827 ExactShifts.push_back(Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: Shift).getReg(Idx: 0));
5828 ExactFactors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Factor).getReg(Idx: 0));
5829 return true;
5830 };
5831
5832 if (MI.getFlag(Flag: MachineInstr::MIFlag::IsExact)) {
5833 // Collect all magic values from the build vector.
5834 bool Matched = matchUnaryPredicate(MRI, Reg: RHS, Match: BuildExactSDIVPattern);
5835 (void)Matched;
5836 assert(Matched && "Expected unary predicate match to succeed");
5837
5838 Register Shift, Factor;
5839 if (Ty.isVector()) {
5840 Shift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: ExactShifts).getReg(Idx: 0);
5841 Factor = MIB.buildBuildVector(Res: Ty, Ops: ExactFactors).getReg(Idx: 0);
5842 } else {
5843 Shift = ExactShifts[0];
5844 Factor = ExactFactors[0];
5845 }
5846
5847 Register Res = LHS;
5848
5849 if (UseSRA)
5850 Res = MIB.buildAShr(Dst: Ty, Src0: Res, Src1: Shift, Flags: MachineInstr::IsExact).getReg(Idx: 0);
5851
5852 return MIB.buildMul(Dst: Ty, Src0: Res, Src1: Factor);
5853 }
5854
5855 SmallVector<Register, 16> MagicFactors, Factors, Shifts, ShiftMasks;
5856
5857 auto BuildSDIVPattern = [&](const Constant *C) {
5858 auto *CI = cast<ConstantInt>(Val: C);
5859 const APInt &Divisor = CI->getValue();
5860
5861 SignedDivisionByConstantInfo Magics =
5862 SignedDivisionByConstantInfo::get(D: Divisor);
5863 int NumeratorFactor = 0;
5864 int ShiftMask = -1;
5865
5866 if (Divisor.isOne() || Divisor.isAllOnes()) {
5867 // If d is +1/-1, we just multiply the numerator by +1/-1.
5868 NumeratorFactor = Divisor.getSExtValue();
5869 Magics.Magic = 0;
5870 Magics.ShiftAmount = 0;
5871 ShiftMask = 0;
5872 } else if (Divisor.isStrictlyPositive() && Magics.Magic.isNegative()) {
5873 // If d > 0 and m < 0, add the numerator.
5874 NumeratorFactor = 1;
5875 } else if (Divisor.isNegative() && Magics.Magic.isStrictlyPositive()) {
5876 // If d < 0 and m > 0, subtract the numerator.
5877 NumeratorFactor = -1;
5878 }
5879
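    // E.g. for an i32 divide by 7 this yields Magic = 0x92492493 (negative)
    // with ShiftAmount = 2; since d > 0 and m < 0 the numerator is added
    // back after the high multiply (NumeratorFactor = 1).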
5880 MagicFactors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Magics.Magic).getReg(Idx: 0));
5881 Factors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: NumeratorFactor).getReg(Idx: 0));
5882 Shifts.push_back(
5883 Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: Magics.ShiftAmount).getReg(Idx: 0));
5884 ShiftMasks.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: ShiftMask).getReg(Idx: 0));
5885
5886 return true;
5887 };
5888
5889 // Collect the shifts/magic values from each element.
5890 bool Matched = matchUnaryPredicate(MRI, Reg: RHS, Match: BuildSDIVPattern);
5891 (void)Matched;
5892 assert(Matched && "Expected unary predicate match to succeed");
5893
5894 Register MagicFactor, Factor, Shift, ShiftMask;
5895 auto *RHSDef = getOpcodeDef<GBuildVector>(Reg: RHS, MRI);
5896 if (RHSDef) {
5897 MagicFactor = MIB.buildBuildVector(Res: Ty, Ops: MagicFactors).getReg(Idx: 0);
5898 Factor = MIB.buildBuildVector(Res: Ty, Ops: Factors).getReg(Idx: 0);
5899 Shift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: Shifts).getReg(Idx: 0);
5900 ShiftMask = MIB.buildBuildVector(Res: Ty, Ops: ShiftMasks).getReg(Idx: 0);
5901 } else {
5902 assert(MRI.getType(RHS).isScalar() &&
5903 "Non-build_vector operation should have been a scalar");
5904 MagicFactor = MagicFactors[0];
5905 Factor = Factors[0];
5906 Shift = Shifts[0];
5907 ShiftMask = ShiftMasks[0];
5908 }
5909
  // Multiply the numerator by the magic value.
  Register Q = MIB.buildSMulH(Dst: Ty, Src0: LHS, Src1: MagicFactor).getReg(Idx: 0);
5912
5913 // (Optionally) Add/subtract the numerator using Factor.
5914 Factor = MIB.buildMul(Dst: Ty, Src0: LHS, Src1: Factor).getReg(Idx: 0);
5915 Q = MIB.buildAdd(Dst: Ty, Src0: Q, Src1: Factor).getReg(Idx: 0);
5916
5917 // Shift right algebraic by shift value.
5918 Q = MIB.buildAShr(Dst: Ty, Src0: Q, Src1: Shift).getReg(Idx: 0);
5919
5920 // Extract the sign bit, mask it and add it to the quotient.
5921 auto SignShift = MIB.buildConstant(Res: ShiftAmtTy, Val: EltBits - 1);
5922 auto T = MIB.buildLShr(Dst: Ty, Src0: Q, Src1: SignShift);
5923 T = MIB.buildAnd(Dst: Ty, Src0: T, Src1: ShiftMask);
5924 auto ret = MIB.buildAdd(Dst: Ty, Src0: Q, Src1: T);
5925
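  // For G_SREM, apply the identity n % d == n - (n / d) * d to the quotient
  // computed above.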
5926 if (Opcode == TargetOpcode::G_SREM) {
5927 auto Prod = MIB.buildMul(Dst: Ty, Src0: ret, Src1: RHS);
5928 return MIB.buildSub(Dst: Ty, Src0: LHS, Src1: Prod);
5929 }
5930 return ret;
5931}
5932
5933bool CombinerHelper::matchDivByPow2(MachineInstr &MI, bool IsSigned) const {
5934 assert((MI.getOpcode() == TargetOpcode::G_SDIV ||
5935 MI.getOpcode() == TargetOpcode::G_UDIV) &&
5936 "Expected SDIV or UDIV");
5937 auto &Div = cast<GenericMachineInstr>(Val&: MI);
5938 Register RHS = Div.getReg(Idx: 2);
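  // Every element must be a power of two (e.g. 16) or, for signed division
  // only, a negated power of two (e.g. -16).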
5939 auto MatchPow2 = [&](const Constant *C) {
5940 auto *CI = dyn_cast<ConstantInt>(Val: C);
5941 return CI && (CI->getValue().isPowerOf2() ||
5942 (IsSigned && CI->getValue().isNegatedPowerOf2()));
5943 };
5944 return matchUnaryPredicate(MRI, Reg: RHS, Match: MatchPow2, /*AllowUndefs=*/false);
5945}
5946
5947void CombinerHelper::applySDivByPow2(MachineInstr &MI) const {
5948 assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV");
5949 auto &SDiv = cast<GenericMachineInstr>(Val&: MI);
5950 Register Dst = SDiv.getReg(Idx: 0);
5951 Register LHS = SDiv.getReg(Idx: 1);
5952 Register RHS = SDiv.getReg(Idx: 2);
5953 LLT Ty = MRI.getType(Reg: Dst);
5954 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
5955 LLT CCVT =
5956 Ty.isVector() ? LLT::vector(EC: Ty.getElementCount(), ScalarSizeInBits: 1) : LLT::scalar(SizeInBits: 1);
5957
  // Effectively we want to lower G_SDIV %lhs, %rhs, where %rhs is a power of
  // 2, to the following version:
  //
  // %c1 = G_CTTZ %rhs
  // %inexact = G_SUB $bitwidth, %c1
  // %sign = G_ASHR %lhs, $(bitwidth - 1)
  // %lshr = G_LSHR %sign, %inexact
  // %add = G_ADD %lhs, %lshr
  // %ashr = G_ASHR %add, %c1
  // %isoneorallones = G_ICMP EQ %rhs, $(1 or -1)
  // %ashr = G_SELECT %isoneorallones, %lhs, %ashr
  // %zero = G_CONSTANT $0
  // %neg = G_NEG %ashr
  // %isneg = G_ICMP SLT %rhs, %zero
  // %res = G_SELECT %isneg, %neg, %ashr
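  //
  // E.g. for %rhs = 8 on i32: %c1 = 3, %inexact = 29, and %lshr contributes
  // 7 for a negative %lhs and 0 otherwise, so the final arithmetic shift
  // rounds toward zero as sdiv requires.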
5972
5973 unsigned BitWidth = Ty.getScalarSizeInBits();
5974 auto Zero = Builder.buildConstant(Res: Ty, Val: 0);
5975
5976 auto Bits = Builder.buildConstant(Res: ShiftAmtTy, Val: BitWidth);
5977 auto C1 = Builder.buildCTTZ(Dst: ShiftAmtTy, Src0: RHS);
5978 auto Inexact = Builder.buildSub(Dst: ShiftAmtTy, Src0: Bits, Src1: C1);
5979 // Splat the sign bit into the register
5980 auto Sign = Builder.buildAShr(
5981 Dst: Ty, Src0: LHS, Src1: Builder.buildConstant(Res: ShiftAmtTy, Val: BitWidth - 1));
5982
  // Add (LHS < 0) ? (1 << %c1) - 1 : 0 so the shift rounds toward zero.
5984 auto LSrl = Builder.buildLShr(Dst: Ty, Src0: Sign, Src1: Inexact);
5985 auto Add = Builder.buildAdd(Dst: Ty, Src0: LHS, Src1: LSrl);
5986 auto AShr = Builder.buildAShr(Dst: Ty, Src0: Add, Src1: C1);
5987
  // Special case: (sdiv X, 1) -> X
  // Special case: (sdiv X, -1) -> 0-X
5990 auto One = Builder.buildConstant(Res: Ty, Val: 1);
5991 auto MinusOne = Builder.buildConstant(Res: Ty, Val: -1);
5992 auto IsOne = Builder.buildICmp(Pred: CmpInst::Predicate::ICMP_EQ, Res: CCVT, Op0: RHS, Op1: One);
5993 auto IsMinusOne =
5994 Builder.buildICmp(Pred: CmpInst::Predicate::ICMP_EQ, Res: CCVT, Op0: RHS, Op1: MinusOne);
5995 auto IsOneOrMinusOne = Builder.buildOr(Dst: CCVT, Src0: IsOne, Src1: IsMinusOne);
5996 AShr = Builder.buildSelect(Res: Ty, Tst: IsOneOrMinusOne, Op0: LHS, Op1: AShr);
5997
5998 // If divided by a positive value, we're done. Otherwise, the result must be
5999 // negated.
6000 auto Neg = Builder.buildNeg(Dst: Ty, Src0: AShr);
6001 auto IsNeg = Builder.buildICmp(Pred: CmpInst::Predicate::ICMP_SLT, Res: CCVT, Op0: RHS, Op1: Zero);
6002 Builder.buildSelect(Res: MI.getOperand(i: 0).getReg(), Tst: IsNeg, Op0: Neg, Op1: AShr);
6003 MI.eraseFromParent();
6004}
6005
6006void CombinerHelper::applyUDivByPow2(MachineInstr &MI) const {
6007 assert(MI.getOpcode() == TargetOpcode::G_UDIV && "Expected UDIV");
6008 auto &UDiv = cast<GenericMachineInstr>(Val&: MI);
6009 Register Dst = UDiv.getReg(Idx: 0);
6010 Register LHS = UDiv.getReg(Idx: 1);
6011 Register RHS = UDiv.getReg(Idx: 2);
6012 LLT Ty = MRI.getType(Reg: Dst);
6013 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
6014
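  // A udiv by a power of two is just a logical shift right by its log2, e.g.
  // x / 16 == x >> 4; G_CTTZ of the known power-of-two divisor computes that
  // log2.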
6015 auto C1 = Builder.buildCTTZ(Dst: ShiftAmtTy, Src0: RHS);
6016 Builder.buildLShr(Dst: MI.getOperand(i: 0).getReg(), Src0: LHS, Src1: C1);
6017 MI.eraseFromParent();
6018}
6019
6020bool CombinerHelper::matchUMulHToLShr(MachineInstr &MI) const {
6021 assert(MI.getOpcode() == TargetOpcode::G_UMULH);
6022 Register RHS = MI.getOperand(i: 2).getReg();
6023 Register Dst = MI.getOperand(i: 0).getReg();
6024 LLT Ty = MRI.getType(Reg: Dst);
6025 LLT RHSTy = MRI.getType(Reg: RHS);
6026 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
6027 auto MatchPow2ExceptOne = [&](const Constant *C) {
6028 if (auto *CI = dyn_cast<ConstantInt>(Val: C))
6029 return CI->getValue().isPowerOf2() && !CI->getValue().isOne();
6030 return false;
6031 };
6032 if (!matchUnaryPredicate(MRI, Reg: RHS, Match: MatchPow2ExceptOne, AllowUndefs: false))
6033 return false;
  // We need to check both G_LSHR and G_CTLZ because the combine uses G_CTLZ to
  // get the log base 2, and G_CTLZ is not always legal on a target.
6036 return isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_LSHR, {Ty, ShiftAmtTy}}) &&
6037 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_CTLZ, {RHSTy, RHSTy}});
6038}
6039
6040void CombinerHelper::applyUMulHToLShr(MachineInstr &MI) const {
6041 Register LHS = MI.getOperand(i: 1).getReg();
6042 Register RHS = MI.getOperand(i: 2).getReg();
6043 Register Dst = MI.getOperand(i: 0).getReg();
6044 LLT Ty = MRI.getType(Reg: Dst);
6045 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
6046 unsigned NumEltBits = Ty.getScalarSizeInBits();
6047
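  // umulh(x, 2^k) is the high half of the widened product, which equals
  // x >> (bitwidth - k); e.g. for i32, umulh(x, 8) == x >> 29.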
6048 auto LogBase2 = buildLogBase2(V: RHS, MIB&: Builder);
6049 auto ShiftAmt =
6050 Builder.buildSub(Dst: Ty, Src0: Builder.buildConstant(Res: Ty, Val: NumEltBits), Src1: LogBase2);
6051 auto Trunc = Builder.buildZExtOrTrunc(Res: ShiftAmtTy, Op: ShiftAmt);
6052 Builder.buildLShr(Dst, Src0: LHS, Src1: Trunc);
6053 MI.eraseFromParent();
6054}
6055
6056bool CombinerHelper::matchTruncSSatS(MachineInstr &MI,
6057 Register &MatchInfo) const {
6058 Register Dst = MI.getOperand(i: 0).getReg();
6059 Register Src = MI.getOperand(i: 1).getReg();
6060 LLT DstTy = MRI.getType(Reg: Dst);
6061 LLT SrcTy = MRI.getType(Reg: Src);
6062 unsigned NumDstBits = DstTy.getScalarSizeInBits();
6063 unsigned NumSrcBits = SrcTy.getScalarSizeInBits();
6064 assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
6065
6066 if (!LI || !isLegalOrHasFewerElements(
6067 Query: {TargetOpcode::G_TRUNC_SSAT_S, {DstTy, SrcTy}}))
6068 return false;
6069
6070 APInt SignedMax = APInt::getSignedMaxValue(numBits: NumDstBits).sext(width: NumSrcBits);
6071 APInt SignedMin = APInt::getSignedMinValue(numBits: NumDstBits).sext(width: NumSrcBits);
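  // E.g. truncating i32 to i16 matches a clamp to [-32768, 32767], in either
  // the smin(smax(x, MIN), MAX) or the smax(smin(x, MAX), MIN) nesting.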
6072 return mi_match(R: Src, MRI,
6073 P: m_GSMin(L: m_GSMax(L: m_Reg(R&: MatchInfo),
6074 R: m_SpecificICstOrSplat(RequestedValue: SignedMin)),
6075 R: m_SpecificICstOrSplat(RequestedValue: SignedMax))) ||
6076 mi_match(R: Src, MRI,
6077 P: m_GSMax(L: m_GSMin(L: m_Reg(R&: MatchInfo),
6078 R: m_SpecificICstOrSplat(RequestedValue: SignedMax)),
6079 R: m_SpecificICstOrSplat(RequestedValue: SignedMin)));
6080}
6081
6082void CombinerHelper::applyTruncSSatS(MachineInstr &MI,
6083 Register &MatchInfo) const {
6084 Register Dst = MI.getOperand(i: 0).getReg();
6085 Builder.buildTruncSSatS(Res: Dst, Op: MatchInfo);
6086 MI.eraseFromParent();
6087}
6088
6089bool CombinerHelper::matchTruncSSatU(MachineInstr &MI,
6090 Register &MatchInfo) const {
6091 Register Dst = MI.getOperand(i: 0).getReg();
6092 Register Src = MI.getOperand(i: 1).getReg();
6093 LLT DstTy = MRI.getType(Reg: Dst);
6094 LLT SrcTy = MRI.getType(Reg: Src);
6095 unsigned NumDstBits = DstTy.getScalarSizeInBits();
6096 unsigned NumSrcBits = SrcTy.getScalarSizeInBits();
6097 assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
6098
6099 if (!LI || !isLegalOrHasFewerElements(
6100 Query: {TargetOpcode::G_TRUNC_SSAT_U, {DstTy, SrcTy}}))
6101 return false;
6102 APInt UnsignedMax = APInt::getMaxValue(numBits: NumDstBits).zext(width: NumSrcBits);
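  // E.g. truncating i32 to i8 matches a clamp to [0, 255], including the
  // umin(smax(x, 0), 255) form.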
6103 return mi_match(R: Src, MRI,
6104 P: m_GSMin(L: m_GSMax(L: m_Reg(R&: MatchInfo), R: m_SpecificICstOrSplat(RequestedValue: 0)),
6105 R: m_SpecificICstOrSplat(RequestedValue: UnsignedMax))) ||
6106 mi_match(R: Src, MRI,
6107 P: m_GSMax(L: m_GSMin(L: m_Reg(R&: MatchInfo),
6108 R: m_SpecificICstOrSplat(RequestedValue: UnsignedMax)),
6109 R: m_SpecificICstOrSplat(RequestedValue: 0))) ||
6110 mi_match(R: Src, MRI,
6111 P: m_GUMin(L: m_GSMax(L: m_Reg(R&: MatchInfo), R: m_SpecificICstOrSplat(RequestedValue: 0)),
6112 R: m_SpecificICstOrSplat(RequestedValue: UnsignedMax)));
6113}
6114
6115void CombinerHelper::applyTruncSSatU(MachineInstr &MI,
6116 Register &MatchInfo) const {
6117 Register Dst = MI.getOperand(i: 0).getReg();
6118 Builder.buildTruncSSatU(Res: Dst, Op: MatchInfo);
6119 MI.eraseFromParent();
6120}
6121
6122bool CombinerHelper::matchTruncUSatU(MachineInstr &MI,
6123 MachineInstr &MinMI) const {
6124 Register Min = MinMI.getOperand(i: 2).getReg();
6125 Register Val = MinMI.getOperand(i: 1).getReg();
6126 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6127 LLT SrcTy = MRI.getType(Reg: Val);
6128 unsigned NumDstBits = DstTy.getScalarSizeInBits();
6129 unsigned NumSrcBits = SrcTy.getScalarSizeInBits();
6130 assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
6131
  if (!LI || !isLegalOrHasFewerElements(
                 Query: {TargetOpcode::G_TRUNC_USAT_U, {DstTy, SrcTy}}))
6134 return false;
6135 APInt UnsignedMax = APInt::getMaxValue(numBits: NumDstBits).zext(width: NumSrcBits);
6136 return mi_match(R: Min, MRI, P: m_SpecificICstOrSplat(RequestedValue: UnsignedMax)) &&
6137 !mi_match(R: Val, MRI, P: m_GSMax(L: m_Reg(), R: m_Reg()));
6138}
6139
6140bool CombinerHelper::matchTruncUSatUToFPTOUISat(MachineInstr &MI,
6141 MachineInstr &SrcMI) const {
6142 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6143 LLT SrcTy = MRI.getType(Reg: SrcMI.getOperand(i: 1).getReg());
6144
6145 return LI &&
6146 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FPTOUI_SAT, {DstTy, SrcTy}});
6147}
6148
6149bool CombinerHelper::matchRedundantNegOperands(MachineInstr &MI,
6150 BuildFnTy &MatchInfo) const {
6151 unsigned Opc = MI.getOpcode();
6152 assert(Opc == TargetOpcode::G_FADD || Opc == TargetOpcode::G_FSUB ||
6153 Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV ||
6154 Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA);
6155
6156 Register Dst = MI.getOperand(i: 0).getReg();
6157 Register X = MI.getOperand(i: 1).getReg();
6158 Register Y = MI.getOperand(i: 2).getReg();
6159 LLT Type = MRI.getType(Reg: Dst);
6160
6161 // fold (fadd x, fneg(y)) -> (fsub x, y)
6162 // fold (fadd fneg(y), x) -> (fsub x, y)
  // G_FADD is commutative, so both cases are checked by m_GFAdd.
6164 if (mi_match(R: Dst, MRI, P: m_GFAdd(L: m_Reg(R&: X), R: m_GFNeg(Src: m_Reg(R&: Y)))) &&
6165 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FSUB, {Type}})) {
6166 Opc = TargetOpcode::G_FSUB;
6167 }
  // fold (fsub x, fneg(y)) -> (fadd x, y)
6169 else if (mi_match(R: Dst, MRI, P: m_GFSub(L: m_Reg(R&: X), R: m_GFNeg(Src: m_Reg(R&: Y)))) &&
6170 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FADD, {Type}})) {
6171 Opc = TargetOpcode::G_FADD;
6172 }
6173 // fold (fmul fneg(x), fneg(y)) -> (fmul x, y)
6174 // fold (fdiv fneg(x), fneg(y)) -> (fdiv x, y)
6175 // fold (fmad fneg(x), fneg(y), z) -> (fmad x, y, z)
6176 // fold (fma fneg(x), fneg(y), z) -> (fma x, y, z)
6177 else if ((Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV ||
6178 Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA) &&
6179 mi_match(R: X, MRI, P: m_GFNeg(Src: m_Reg(R&: X))) &&
6180 mi_match(R: Y, MRI, P: m_GFNeg(Src: m_Reg(R&: Y)))) {
6181 // no opcode change
6182 } else
6183 return false;
6184
6185 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6186 Observer.changingInstr(MI);
6187 MI.setDesc(B.getTII().get(Opcode: Opc));
6188 MI.getOperand(i: 1).setReg(X);
6189 MI.getOperand(i: 2).setReg(Y);
6190 Observer.changedInstr(MI);
6191 };
6192 return true;
6193}
6194
6195bool CombinerHelper::matchFsubToFneg(MachineInstr &MI,
6196 Register &MatchInfo) const {
6197 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
6198
6199 Register LHS = MI.getOperand(i: 1).getReg();
6200 MatchInfo = MI.getOperand(i: 2).getReg();
6201 LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6202
6203 const auto LHSCst = Ty.isVector()
6204 ? getFConstantSplat(VReg: LHS, MRI, /* allowUndef */ AllowUndef: true)
6205 : getFConstantVRegValWithLookThrough(VReg: LHS, MRI);
6206 if (!LHSCst)
6207 return false;
6208
6209 // -0.0 is always allowed
6210 if (LHSCst->Value.isNegZero())
6211 return true;
6212
6213 // +0.0 is only allowed if nsz is set.
6214 if (LHSCst->Value.isPosZero())
6215 return MI.getFlag(Flag: MachineInstr::FmNsz);
6216
6217 return false;
6218}
6219
6220void CombinerHelper::applyFsubToFneg(MachineInstr &MI,
6221 Register &MatchInfo) const {
6222 Register Dst = MI.getOperand(i: 0).getReg();
6223 Builder.buildFNeg(
6224 Dst, Src0: Builder.buildFCanonicalize(Dst: MRI.getType(Reg: Dst), Src0: MatchInfo).getReg(Idx: 0));
6225 eraseInst(MI);
6226}
6227
6228/// Checks if \p MI is TargetOpcode::G_FMUL and contractable either
6229/// due to global flags or MachineInstr flags.
6230static bool isContractableFMul(MachineInstr &MI, bool AllowFusionGlobally) {
6231 if (MI.getOpcode() != TargetOpcode::G_FMUL)
6232 return false;
6233 return AllowFusionGlobally || MI.getFlag(Flag: MachineInstr::MIFlag::FmContract);
6234}
6235
6236static bool hasMoreUses(const MachineInstr &MI0, const MachineInstr &MI1,
6237 const MachineRegisterInfo &MRI) {
6238 return std::distance(first: MRI.use_instr_nodbg_begin(RegNo: MI0.getOperand(i: 0).getReg()),
6239 last: MRI.use_instr_nodbg_end()) >
6240 std::distance(first: MRI.use_instr_nodbg_begin(RegNo: MI1.getOperand(i: 0).getReg()),
6241 last: MRI.use_instr_nodbg_end());
6242}
6243
6244bool CombinerHelper::canCombineFMadOrFMA(MachineInstr &MI,
6245 bool &AllowFusionGlobally,
6246 bool &HasFMAD, bool &Aggressive,
6247 bool CanReassociate) const {
6248
6249 auto *MF = MI.getMF();
6250 const auto &TLI = *MF->getSubtarget().getTargetLowering();
6251 const TargetOptions &Options = MF->getTarget().Options;
6252 LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6253
6254 if (CanReassociate && !MI.getFlag(Flag: MachineInstr::MIFlag::FmReassoc))
6255 return false;
6256
6257 // Floating-point multiply-add with intermediate rounding.
6258 HasFMAD = (!isPreLegalize() && TLI.isFMADLegal(MI, Ty: DstType));
6259 // Floating-point multiply-add without intermediate rounding.
6260 bool HasFMA = TLI.isFMAFasterThanFMulAndFAdd(MF: *MF, DstType) &&
6261 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FMA, {DstType}});
6262 // No valid opcode, do not combine.
6263 if (!HasFMAD && !HasFMA)
6264 return false;
6265
6266 AllowFusionGlobally = Options.AllowFPOpFusion == FPOpFusion::Fast || HasFMAD;
6267 // If the addition is not contractable, do not combine.
6268 if (!AllowFusionGlobally && !MI.getFlag(Flag: MachineInstr::MIFlag::FmContract))
6269 return false;
6270
6271 Aggressive = TLI.enableAggressiveFMAFusion(Ty: DstType);
6272 return true;
6273}
6274
6275bool CombinerHelper::matchCombineFAddFMulToFMadOrFMA(
6276 MachineInstr &MI,
6277 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6278 assert(MI.getOpcode() == TargetOpcode::G_FADD);
6279
6280 bool AllowFusionGlobally, HasFMAD, Aggressive;
6281 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6282 return false;
6283
6284 Register Op1 = MI.getOperand(i: 1).getReg();
6285 Register Op2 = MI.getOperand(i: 2).getReg();
6286 DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1};
6287 DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2};
6288 unsigned PreferredFusedOpcode =
6289 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6290
6291 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
6292 // prefer to fold the multiply with fewer uses.
6293 if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6294 isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) {
6295 if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI))
6296 std::swap(a&: LHS, b&: RHS);
6297 }
6298
6299 // fold (fadd (fmul x, y), z) -> (fma x, y, z)
6300 if (isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6301 (Aggressive || MRI.hasOneNonDBGUse(RegNo: LHS.Reg))) {
6302 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6303 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6304 SrcOps: {LHS.MI->getOperand(i: 1).getReg(),
6305 LHS.MI->getOperand(i: 2).getReg(), RHS.Reg});
6306 };
6307 return true;
6308 }
6309
6310 // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
6311 if (isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally) &&
6312 (Aggressive || MRI.hasOneNonDBGUse(RegNo: RHS.Reg))) {
6313 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6314 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6315 SrcOps: {RHS.MI->getOperand(i: 1).getReg(),
6316 RHS.MI->getOperand(i: 2).getReg(), LHS.Reg});
6317 };
6318 return true;
6319 }
6320
6321 return false;
6322}
6323
6324bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMA(
6325 MachineInstr &MI,
6326 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6327 assert(MI.getOpcode() == TargetOpcode::G_FADD);
6328
6329 bool AllowFusionGlobally, HasFMAD, Aggressive;
6330 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6331 return false;
6332
6333 const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
6334 Register Op1 = MI.getOperand(i: 1).getReg();
6335 Register Op2 = MI.getOperand(i: 2).getReg();
6336 DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1};
6337 DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2};
6338 LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6339
6340 unsigned PreferredFusedOpcode =
6341 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6342
6343 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
6344 // prefer to fold the multiply with fewer uses.
6345 if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6346 isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) {
6347 if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI))
6348 std::swap(a&: LHS, b&: RHS);
6349 }
6350
6351 // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
6352 MachineInstr *FpExtSrc;
6353 if (mi_match(R: LHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FpExtSrc))) &&
6354 isContractableFMul(MI&: *FpExtSrc, AllowFusionGlobally) &&
6355 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
6356 SrcTy: MRI.getType(Reg: FpExtSrc->getOperand(i: 1).getReg()))) {
6357 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6358 auto FpExtX = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 1).getReg());
6359 auto FpExtY = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 2).getReg());
6360 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6361 SrcOps: {FpExtX.getReg(Idx: 0), FpExtY.getReg(Idx: 0), RHS.Reg});
6362 };
6363 return true;
6364 }
6365
6366 // fold (fadd z, (fpext (fmul x, y))) -> (fma (fpext x), (fpext y), z)
6367 // Note: Commutes FADD operands.
6368 if (mi_match(R: RHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FpExtSrc))) &&
6369 isContractableFMul(MI&: *FpExtSrc, AllowFusionGlobally) &&
6370 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
6371 SrcTy: MRI.getType(Reg: FpExtSrc->getOperand(i: 1).getReg()))) {
6372 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6373 auto FpExtX = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 1).getReg());
6374 auto FpExtY = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 2).getReg());
6375 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6376 SrcOps: {FpExtX.getReg(Idx: 0), FpExtY.getReg(Idx: 0), LHS.Reg});
6377 };
6378 return true;
6379 }
6380
6381 return false;
6382}
6383
6384bool CombinerHelper::matchCombineFAddFMAFMulToFMadOrFMA(
6385 MachineInstr &MI,
6386 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6387 assert(MI.getOpcode() == TargetOpcode::G_FADD);
6388
6389 bool AllowFusionGlobally, HasFMAD, Aggressive;
6390 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive, CanReassociate: true))
6391 return false;
6392
6393 Register Op1 = MI.getOperand(i: 1).getReg();
6394 Register Op2 = MI.getOperand(i: 2).getReg();
6395 DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1};
6396 DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2};
6397 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6398
6399 unsigned PreferredFusedOpcode =
6400 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6401
6402 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
6403 // prefer to fold the multiply with fewer uses.
6404 if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6405 isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) {
6406 if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI))
6407 std::swap(a&: LHS, b&: RHS);
6408 }
6409
6410 MachineInstr *FMA = nullptr;
6411 Register Z;
6412 // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y, (fma u, v, z))
6413 if (LHS.MI->getOpcode() == PreferredFusedOpcode &&
6414 (MRI.getVRegDef(Reg: LHS.MI->getOperand(i: 3).getReg())->getOpcode() ==
6415 TargetOpcode::G_FMUL) &&
6416 MRI.hasOneNonDBGUse(RegNo: LHS.MI->getOperand(i: 0).getReg()) &&
6417 MRI.hasOneNonDBGUse(RegNo: LHS.MI->getOperand(i: 3).getReg())) {
6418 FMA = LHS.MI;
6419 Z = RHS.Reg;
6420 }
6421 // fold (fadd z, (fma x, y, (fmul u, v))) -> (fma x, y, (fma u, v, z))
6422 else if (RHS.MI->getOpcode() == PreferredFusedOpcode &&
6423 (MRI.getVRegDef(Reg: RHS.MI->getOperand(i: 3).getReg())->getOpcode() ==
6424 TargetOpcode::G_FMUL) &&
6425 MRI.hasOneNonDBGUse(RegNo: RHS.MI->getOperand(i: 0).getReg()) &&
6426 MRI.hasOneNonDBGUse(RegNo: RHS.MI->getOperand(i: 3).getReg())) {
6427 Z = LHS.Reg;
6428 FMA = RHS.MI;
6429 }
6430
6431 if (FMA) {
6432 MachineInstr *FMulMI = MRI.getVRegDef(Reg: FMA->getOperand(i: 3).getReg());
6433 Register X = FMA->getOperand(i: 1).getReg();
6434 Register Y = FMA->getOperand(i: 2).getReg();
6435 Register U = FMulMI->getOperand(i: 1).getReg();
6436 Register V = FMulMI->getOperand(i: 2).getReg();
6437
6438 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6439 Register InnerFMA = MRI.createGenericVirtualRegister(Ty: DstTy);
6440 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {InnerFMA}, SrcOps: {U, V, Z});
6441 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6442 SrcOps: {X, Y, InnerFMA});
6443 };
6444 return true;
6445 }
6446
6447 return false;
6448}
6449
6450bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMAAggressive(
6451 MachineInstr &MI,
6452 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6453 assert(MI.getOpcode() == TargetOpcode::G_FADD);
6454
6455 bool AllowFusionGlobally, HasFMAD, Aggressive;
6456 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6457 return false;
6458
6459 if (!Aggressive)
6460 return false;
6461
6462 const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
6463 LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6464 Register Op1 = MI.getOperand(i: 1).getReg();
6465 Register Op2 = MI.getOperand(i: 2).getReg();
6466 DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1};
6467 DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2};
6468
6469 unsigned PreferredFusedOpcode =
6470 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6471
6472 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
6473 // prefer to fold the multiply with fewer uses.
6474 if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6475 isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) {
6476 if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI))
6477 std::swap(a&: LHS, b&: RHS);
6478 }
6479
6480 // Builds: (fma x, y, (fma (fpext u), (fpext v), z))
6481 auto buildMatchInfo = [=, &MI](Register U, Register V, Register Z, Register X,
6482 Register Y, MachineIRBuilder &B) {
6483 Register FpExtU = B.buildFPExt(Res: DstType, Op: U).getReg(Idx: 0);
6484 Register FpExtV = B.buildFPExt(Res: DstType, Op: V).getReg(Idx: 0);
6485 Register InnerFMA =
6486 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {DstType}, SrcOps: {FpExtU, FpExtV, Z})
6487 .getReg(Idx: 0);
6488 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6489 SrcOps: {X, Y, InnerFMA});
6490 };
6491
6492 MachineInstr *FMulMI, *FMAMI;
6493 // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
6494 // -> (fma x, y, (fma (fpext u), (fpext v), z))
6495 if (LHS.MI->getOpcode() == PreferredFusedOpcode &&
6496 mi_match(R: LHS.MI->getOperand(i: 3).getReg(), MRI,
6497 P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) &&
6498 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6499 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
6500 SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) {
6501 MatchInfo = [=](MachineIRBuilder &B) {
6502 buildMatchInfo(FMulMI->getOperand(i: 1).getReg(),
6503 FMulMI->getOperand(i: 2).getReg(), RHS.Reg,
6504 LHS.MI->getOperand(i: 1).getReg(),
6505 LHS.MI->getOperand(i: 2).getReg(), B);
6506 };
6507 return true;
6508 }
6509
6510 // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
6511 // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
6512 // FIXME: This turns two single-precision and one double-precision
6513 // operation into two double-precision operations, which might not be
6514 // interesting for all targets, especially GPUs.
6515 if (mi_match(R: LHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMAMI))) &&
6516 FMAMI->getOpcode() == PreferredFusedOpcode) {
6517 MachineInstr *FMulMI = MRI.getVRegDef(Reg: FMAMI->getOperand(i: 3).getReg());
6518 if (isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6519 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
6520 SrcTy: MRI.getType(Reg: FMAMI->getOperand(i: 0).getReg()))) {
6521 MatchInfo = [=](MachineIRBuilder &B) {
6522 Register X = FMAMI->getOperand(i: 1).getReg();
6523 Register Y = FMAMI->getOperand(i: 2).getReg();
6524 X = B.buildFPExt(Res: DstType, Op: X).getReg(Idx: 0);
6525 Y = B.buildFPExt(Res: DstType, Op: Y).getReg(Idx: 0);
6526 buildMatchInfo(FMulMI->getOperand(i: 1).getReg(),
6527 FMulMI->getOperand(i: 2).getReg(), RHS.Reg, X, Y, B);
6528 };
6529
6530 return true;
6531 }
6532 }
6533
6534 // fold (fadd z, (fma x, y, (fpext (fmul u, v)))
6535 // -> (fma x, y, (fma (fpext u), (fpext v), z))
6536 if (RHS.MI->getOpcode() == PreferredFusedOpcode &&
6537 mi_match(R: RHS.MI->getOperand(i: 3).getReg(), MRI,
6538 P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) &&
6539 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6540 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
6541 SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) {
6542 MatchInfo = [=](MachineIRBuilder &B) {
6543 buildMatchInfo(FMulMI->getOperand(i: 1).getReg(),
6544 FMulMI->getOperand(i: 2).getReg(), LHS.Reg,
6545 RHS.MI->getOperand(i: 1).getReg(),
6546 RHS.MI->getOperand(i: 2).getReg(), B);
6547 };
6548 return true;
6549 }
6550
6551 // fold (fadd z, (fpext (fma x, y, (fmul u, v)))
6552 // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
6553 // FIXME: This turns two single-precision and one double-precision
6554 // operation into two double-precision operations, which might not be
6555 // interesting for all targets, especially GPUs.
6556 if (mi_match(R: RHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMAMI))) &&
6557 FMAMI->getOpcode() == PreferredFusedOpcode) {
6558 MachineInstr *FMulMI = MRI.getVRegDef(Reg: FMAMI->getOperand(i: 3).getReg());
6559 if (isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6560 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
6561 SrcTy: MRI.getType(Reg: FMAMI->getOperand(i: 0).getReg()))) {
6562 MatchInfo = [=](MachineIRBuilder &B) {
6563 Register X = FMAMI->getOperand(i: 1).getReg();
6564 Register Y = FMAMI->getOperand(i: 2).getReg();
6565 X = B.buildFPExt(Res: DstType, Op: X).getReg(Idx: 0);
6566 Y = B.buildFPExt(Res: DstType, Op: Y).getReg(Idx: 0);
6567 buildMatchInfo(FMulMI->getOperand(i: 1).getReg(),
6568 FMulMI->getOperand(i: 2).getReg(), LHS.Reg, X, Y, B);
6569 };
6570 return true;
6571 }
6572 }
6573
6574 return false;
6575}
6576
6577bool CombinerHelper::matchCombineFSubFMulToFMadOrFMA(
6578 MachineInstr &MI,
6579 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6580 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
6581
6582 bool AllowFusionGlobally, HasFMAD, Aggressive;
6583 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6584 return false;
6585
6586 Register Op1 = MI.getOperand(i: 1).getReg();
6587 Register Op2 = MI.getOperand(i: 2).getReg();
6588 DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1};
6589 DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2};
6590 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6591
6592 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
6593 // prefer to fold the multiply with fewer uses.
  bool FirstMulHasFewerUses = true;
6595 if (isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6596 isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally) &&
6597 hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI))
6598 FirstMulHasFewerUses = false;
6599
6600 unsigned PreferredFusedOpcode =
6601 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6602
6603 // fold (fsub (fmul x, y), z) -> (fma x, y, -z)
6604 if (FirstMulHasFewerUses &&
6605 (isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6606 (Aggressive || MRI.hasOneNonDBGUse(RegNo: LHS.Reg)))) {
6607 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6608 Register NegZ = B.buildFNeg(Dst: DstTy, Src0: RHS.Reg).getReg(Idx: 0);
6609 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6610 SrcOps: {LHS.MI->getOperand(i: 1).getReg(),
6611 LHS.MI->getOperand(i: 2).getReg(), NegZ});
6612 };
6613 return true;
6614 }
6615 // fold (fsub x, (fmul y, z)) -> (fma -y, z, x)
6616 else if ((isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally) &&
6617 (Aggressive || MRI.hasOneNonDBGUse(RegNo: RHS.Reg)))) {
6618 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6619 Register NegY =
6620 B.buildFNeg(Dst: DstTy, Src0: RHS.MI->getOperand(i: 1).getReg()).getReg(Idx: 0);
6621 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6622 SrcOps: {NegY, RHS.MI->getOperand(i: 2).getReg(), LHS.Reg});
6623 };
6624 return true;
6625 }
6626
6627 return false;
6628}
6629
6630bool CombinerHelper::matchCombineFSubFNegFMulToFMadOrFMA(
6631 MachineInstr &MI,
6632 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6633 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
6634
6635 bool AllowFusionGlobally, HasFMAD, Aggressive;
6636 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6637 return false;
6638
6639 Register LHSReg = MI.getOperand(i: 1).getReg();
6640 Register RHSReg = MI.getOperand(i: 2).getReg();
6641 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6642
6643 unsigned PreferredFusedOpcode =
6644 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6645
6646 MachineInstr *FMulMI;
6647 // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z))
6648 if (mi_match(R: LHSReg, MRI, P: m_GFNeg(Src: m_MInstr(MI&: FMulMI))) &&
6649 (Aggressive || (MRI.hasOneNonDBGUse(RegNo: LHSReg) &&
6650 MRI.hasOneNonDBGUse(RegNo: FMulMI->getOperand(i: 0).getReg()))) &&
6651 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally)) {
6652 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6653 Register NegX =
6654 B.buildFNeg(Dst: DstTy, Src0: FMulMI->getOperand(i: 1).getReg()).getReg(Idx: 0);
6655 Register NegZ = B.buildFNeg(Dst: DstTy, Src0: RHSReg).getReg(Idx: 0);
6656 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6657 SrcOps: {NegX, FMulMI->getOperand(i: 2).getReg(), NegZ});
6658 };
6659 return true;
6660 }
6661
  // fold (fsub x, (fneg (fmul y, z))) -> (fma y, z, x)
6663 if (mi_match(R: RHSReg, MRI, P: m_GFNeg(Src: m_MInstr(MI&: FMulMI))) &&
6664 (Aggressive || (MRI.hasOneNonDBGUse(RegNo: RHSReg) &&
6665 MRI.hasOneNonDBGUse(RegNo: FMulMI->getOperand(i: 0).getReg()))) &&
6666 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally)) {
6667 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6668 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6669 SrcOps: {FMulMI->getOperand(i: 1).getReg(),
6670 FMulMI->getOperand(i: 2).getReg(), LHSReg});
6671 };
6672 return true;
6673 }
6674
6675 return false;
6676}
6677
6678bool CombinerHelper::matchCombineFSubFpExtFMulToFMadOrFMA(
6679 MachineInstr &MI,
6680 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6681 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
6682
6683 bool AllowFusionGlobally, HasFMAD, Aggressive;
6684 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6685 return false;
6686
6687 Register LHSReg = MI.getOperand(i: 1).getReg();
6688 Register RHSReg = MI.getOperand(i: 2).getReg();
6689 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6690
6691 unsigned PreferredFusedOpcode =
6692 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6693
6694 MachineInstr *FMulMI;
6695 // fold (fsub (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), (fneg z))
6696 if (mi_match(R: LHSReg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) &&
6697 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6698 (Aggressive || MRI.hasOneNonDBGUse(RegNo: LHSReg))) {
6699 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6700 Register FpExtX =
6701 B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 1).getReg()).getReg(Idx: 0);
6702 Register FpExtY =
6703 B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 2).getReg()).getReg(Idx: 0);
6704 Register NegZ = B.buildFNeg(Dst: DstTy, Src0: RHSReg).getReg(Idx: 0);
6705 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6706 SrcOps: {FpExtX, FpExtY, NegZ});
6707 };
6708 return true;
6709 }
6710
6711 // fold (fsub x, (fpext (fmul y, z))) -> (fma (fneg (fpext y)), (fpext z), x)
6712 if (mi_match(R: RHSReg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) &&
6713 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6714 (Aggressive || MRI.hasOneNonDBGUse(RegNo: RHSReg))) {
6715 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6716 Register FpExtY =
6717 B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 1).getReg()).getReg(Idx: 0);
6718 Register NegY = B.buildFNeg(Dst: DstTy, Src0: FpExtY).getReg(Idx: 0);
6719 Register FpExtZ =
6720 B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 2).getReg()).getReg(Idx: 0);
6721 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6722 SrcOps: {NegY, FpExtZ, LHSReg});
6723 };
6724 return true;
6725 }
6726
6727 return false;
6728}
6729
6730bool CombinerHelper::matchCombineFSubFpExtFNegFMulToFMadOrFMA(
6731 MachineInstr &MI,
6732 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6733 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
6734
6735 bool AllowFusionGlobally, HasFMAD, Aggressive;
6736 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6737 return false;
6738
6739 const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
6740 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6741 Register LHSReg = MI.getOperand(i: 1).getReg();
6742 Register RHSReg = MI.getOperand(i: 2).getReg();
6743
6744 unsigned PreferredFusedOpcode =
6745 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6746
6747 auto buildMatchInfo = [=](Register Dst, Register X, Register Y, Register Z,
6748 MachineIRBuilder &B) {
6749 Register FpExtX = B.buildFPExt(Res: DstTy, Op: X).getReg(Idx: 0);
6750 Register FpExtY = B.buildFPExt(Res: DstTy, Op: Y).getReg(Idx: 0);
6751 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {Dst}, SrcOps: {FpExtX, FpExtY, Z});
6752 };
6753
6754 MachineInstr *FMulMI;
6755 // fold (fsub (fpext (fneg (fmul x, y))), z) ->
6756 // (fneg (fma (fpext x), (fpext y), z))
6757 // fold (fsub (fneg (fpext (fmul x, y))), z) ->
6758 // (fneg (fma (fpext x), (fpext y), z))
6759 if ((mi_match(R: LHSReg, MRI, P: m_GFPExt(Src: m_GFNeg(Src: m_MInstr(MI&: FMulMI)))) ||
6760 mi_match(R: LHSReg, MRI, P: m_GFNeg(Src: m_GFPExt(Src: m_MInstr(MI&: FMulMI))))) &&
6761 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6762 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstTy,
6763 SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) {
6764 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6765 Register FMAReg = MRI.createGenericVirtualRegister(Ty: DstTy);
6766 buildMatchInfo(FMAReg, FMulMI->getOperand(i: 1).getReg(),
6767 FMulMI->getOperand(i: 2).getReg(), RHSReg, B);
6768 B.buildFNeg(Dst: MI.getOperand(i: 0).getReg(), Src0: FMAReg);
6769 };
6770 return true;
6771 }
6772
6773 // fold (fsub x, (fpext (fneg (fmul y, z)))) -> (fma (fpext y), (fpext z), x)
6774 // fold (fsub x, (fneg (fpext (fmul y, z)))) -> (fma (fpext y), (fpext z), x)
6775 if ((mi_match(R: RHSReg, MRI, P: m_GFPExt(Src: m_GFNeg(Src: m_MInstr(MI&: FMulMI)))) ||
6776 mi_match(R: RHSReg, MRI, P: m_GFNeg(Src: m_GFPExt(Src: m_MInstr(MI&: FMulMI))))) &&
6777 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6778 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstTy,
6779 SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) {
6780 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6781 buildMatchInfo(MI.getOperand(i: 0).getReg(), FMulMI->getOperand(i: 1).getReg(),
6782 FMulMI->getOperand(i: 2).getReg(), LHSReg, B);
6783 };
6784 return true;
6785 }
6786
6787 return false;
6788}
6789
6790bool CombinerHelper::matchCombineFMinMaxNaN(MachineInstr &MI,
6791 unsigned &IdxToPropagate) const {
6792 bool PropagateNaN;
6793 switch (MI.getOpcode()) {
6794 default:
6795 return false;
6796 case TargetOpcode::G_FMINNUM:
6797 case TargetOpcode::G_FMAXNUM:
6798 PropagateNaN = false;
6799 break;
6800 case TargetOpcode::G_FMINIMUM:
6801 case TargetOpcode::G_FMAXIMUM:
6802 PropagateNaN = true;
6803 break;
6804 }
6805
6806 auto MatchNaN = [&](unsigned Idx) {
6807 Register MaybeNaNReg = MI.getOperand(i: Idx).getReg();
6808 const ConstantFP *MaybeCst = getConstantFPVRegVal(VReg: MaybeNaNReg, MRI);
6809 if (!MaybeCst || !MaybeCst->getValueAPF().isNaN())
6810 return false;
6811 IdxToPropagate = PropagateNaN ? Idx : (Idx == 1 ? 2 : 1);
6812 return true;
6813 };
6814
6815 return MatchNaN(1) || MatchNaN(2);
6816}
6817
6818// Combine multiple FDIVs with the same divisor into multiple FMULs by the
6819// reciprocal.
6820// E.g., (a / Y; b / Y;) -> (recip = 1.0 / Y; a * recip; b * recip)
6821bool CombinerHelper::matchRepeatedFPDivisor(
6822 MachineInstr &MI, SmallVector<MachineInstr *> &MatchInfo) const {
6823 assert(MI.getOpcode() == TargetOpcode::G_FDIV);
6824
6825 Register X = MI.getOperand(i: 1).getReg();
6826 Register Y = MI.getOperand(i: 2).getReg();
6827
6828 if (!MI.getFlag(Flag: MachineInstr::MIFlag::FmArcp))
6829 return false;
6830
6831 // Skip if current node is a reciprocal/fneg-reciprocal.
6832 auto N0CFP = isConstantOrConstantSplatVectorFP(MI&: *MRI.getVRegDef(Reg: X), MRI);
6833 if (N0CFP && (N0CFP->isExactlyValue(V: 1.0) || N0CFP->isExactlyValue(V: -1.0)))
6834 return false;
6835
6836 // Exit early if the target does not want this transform or if there can't
6837 // possibly be enough uses of the divisor to make the transform worthwhile.
6838 unsigned MinUses = getTargetLowering().combineRepeatedFPDivisors();
6839 if (!MinUses)
6840 return false;
6841
6842 // Find all FDIV users of the same divisor. For the moment we limit all
6843 // instructions to a single BB and use the first Instr in MatchInfo as the
6844 // dominating position.
6845 MatchInfo.push_back(Elt: &MI);
6846 for (auto &U : MRI.use_nodbg_instructions(Reg: Y)) {
6847 if (&U == &MI || U.getParent() != MI.getParent())
6848 continue;
6849 if (U.getOpcode() == TargetOpcode::G_FDIV &&
6850 U.getOperand(i: 2).getReg() == Y && U.getOperand(i: 1).getReg() != Y) {
      // This division is eligible for the optimization only if it also allows
      // reciprocal formation, i.e. it carries the arcp fast-math flag.
6853 if (U.getFlag(Flag: MachineInstr::MIFlag::FmArcp)) {
6854 MatchInfo.push_back(Elt: &U);
6855 if (dominates(DefMI: U, UseMI: *MatchInfo[0]))
6856 std::swap(a&: MatchInfo[0], b&: MatchInfo.back());
6857 }
6858 }
6859 }
6860
6861 // Now that we have the actual number of divisor uses, make sure it meets
6862 // the minimum threshold specified by the target.
6863 return MatchInfo.size() >= MinUses;
6864}
6865
6866void CombinerHelper::applyRepeatedFPDivisor(
6867 SmallVector<MachineInstr *> &MatchInfo) const {
  // Generate the new div at the position of the first instruction, which we
  // have ensured dominates all the other instructions.
6870 Builder.setInsertPt(MBB&: *MatchInfo[0]->getParent(), II: MatchInfo[0]);
6871 LLT Ty = MRI.getType(Reg: MatchInfo[0]->getOperand(i: 0).getReg());
6872 auto Div = Builder.buildFDiv(Dst: Ty, Src0: Builder.buildFConstant(Res: Ty, Val: 1.0),
6873 Src1: MatchInfo[0]->getOperand(i: 2).getReg(),
6874 Flags: MatchInfo[0]->getFlags());
6875
6876 // Replace all found div's with fmul instructions.
6877 for (MachineInstr *MI : MatchInfo) {
6878 Builder.setInsertPt(MBB&: *MI->getParent(), II: MI);
6879 Builder.buildFMul(Dst: MI->getOperand(i: 0).getReg(), Src0: MI->getOperand(i: 1).getReg(),
6880 Src1: Div->getOperand(i: 0).getReg(), Flags: MI->getFlags());
6881 MI->eraseFromParent();
6882 }
6883}
6884
6885bool CombinerHelper::matchAddSubSameReg(MachineInstr &MI, Register &Src) const {
6886 assert(MI.getOpcode() == TargetOpcode::G_ADD && "Expected a G_ADD");
6887 Register LHS = MI.getOperand(i: 1).getReg();
6888 Register RHS = MI.getOperand(i: 2).getReg();
6889
6890 // Helper lambda to check for opportunities for
6891 // A + (B - A) -> B
6892 // (B - A) + A -> B
6893 auto CheckFold = [&](Register MaybeSub, Register MaybeSameReg) {
6894 Register Reg;
6895 return mi_match(R: MaybeSub, MRI, P: m_GSub(L: m_Reg(R&: Src), R: m_Reg(R&: Reg))) &&
6896 Reg == MaybeSameReg;
6897 };
6898 return CheckFold(LHS, RHS) || CheckFold(RHS, LHS);
6899}
6900
6901bool CombinerHelper::matchBuildVectorIdentityFold(MachineInstr &MI,
6902 Register &MatchInfo) const {
6903 // This combine folds the following patterns:
6904 //
6905 // G_BUILD_VECTOR_TRUNC (G_BITCAST(x), G_LSHR(G_BITCAST(x), k))
6906 // G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), G_TRUNC(G_LSHR(G_BITCAST(x), k)))
6907 // into
6908 // x
6909 // if
6910 // k == sizeof(VecEltTy)/2
6911 // type(x) == type(dst)
6912 //
6913 // G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), undef)
6914 // into
6915 // x
6916 // if
6917 // type(x) == type(dst)
6918
6919 LLT DstVecTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6920 LLT DstEltTy = DstVecTy.getElementType();
6921
6922 Register Lo, Hi;
6923
6924 if (mi_match(
6925 MI, MRI,
6926 P: m_GBuildVector(L: m_GTrunc(Src: m_GBitcast(Src: m_Reg(R&: Lo))), R: m_GImplicitDef()))) {
6927 MatchInfo = Lo;
6928 return MRI.getType(Reg: MatchInfo) == DstVecTy;
6929 }
6930
6931 std::optional<ValueAndVReg> ShiftAmount;
6932 const auto LoPattern = m_GBitcast(Src: m_Reg(R&: Lo));
6933 const auto HiPattern = m_GLShr(L: m_GBitcast(Src: m_Reg(R&: Hi)), R: m_GCst(ValReg&: ShiftAmount));
6934 if (mi_match(
6935 MI, MRI,
6936 P: m_any_of(preds: m_GBuildVectorTrunc(L: LoPattern, R: HiPattern),
6937 preds: m_GBuildVector(L: m_GTrunc(Src: LoPattern), R: m_GTrunc(Src: HiPattern))))) {
6938 if (Lo == Hi && ShiftAmount->Value == DstEltTy.getSizeInBits()) {
6939 MatchInfo = Lo;
6940 return MRI.getType(Reg: MatchInfo) == DstVecTy;
6941 }
6942 }
6943
6944 return false;
6945}
6946
6947bool CombinerHelper::matchTruncBuildVectorFold(MachineInstr &MI,
6948 Register &MatchInfo) const {
6949 // Replace (G_TRUNC (G_BITCAST (G_BUILD_VECTOR x, y)) with just x
6950 // if type(x) == type(G_TRUNC)
6951 if (!mi_match(R: MI.getOperand(i: 1).getReg(), MRI,
6952 P: m_GBitcast(Src: m_GBuildVector(L: m_Reg(R&: MatchInfo), R: m_Reg()))))
6953 return false;
6954
6955 return MRI.getType(Reg: MatchInfo) == MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6956}
6957
6958bool CombinerHelper::matchTruncLshrBuildVectorFold(MachineInstr &MI,
6959 Register &MatchInfo) const {
6960 // Replace (G_TRUNC (G_LSHR (G_BITCAST (G_BUILD_VECTOR x, y)), K)) with
6961 // y if K == size of vector element type
6962 std::optional<ValueAndVReg> ShiftAmt;
6963 if (!mi_match(R: MI.getOperand(i: 1).getReg(), MRI,
6964 P: m_GLShr(L: m_GBitcast(Src: m_GBuildVector(L: m_Reg(), R: m_Reg(R&: MatchInfo))),
6965 R: m_GCst(ValReg&: ShiftAmt))))
6966 return false;
6967
6968 LLT MatchTy = MRI.getType(Reg: MatchInfo);
6969 return ShiftAmt->Value.getZExtValue() == MatchTy.getSizeInBits() &&
6970 MatchTy == MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6971}
6972
6973unsigned CombinerHelper::getFPMinMaxOpcForSelect(
6974 CmpInst::Predicate Pred, LLT DstTy,
6975 SelectPatternNaNBehaviour VsNaNRetVal) const {
6976 assert(VsNaNRetVal != SelectPatternNaNBehaviour::NOT_APPLICABLE &&
6977 "Expected a NaN behaviour?");
6978 // Choose an opcode based off of legality or the behaviour when one of the
6979 // LHS/RHS may be NaN.
6980 switch (Pred) {
6981 default:
6982 return 0;
6983 case CmpInst::FCMP_UGT:
6984 case CmpInst::FCMP_UGE:
6985 case CmpInst::FCMP_OGT:
6986 case CmpInst::FCMP_OGE:
6987 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER)
6988 return TargetOpcode::G_FMAXNUM;
6989 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN)
6990 return TargetOpcode::G_FMAXIMUM;
6991 if (isLegal(Query: {TargetOpcode::G_FMAXNUM, {DstTy}}))
6992 return TargetOpcode::G_FMAXNUM;
6993 if (isLegal(Query: {TargetOpcode::G_FMAXIMUM, {DstTy}}))
6994 return TargetOpcode::G_FMAXIMUM;
6995 return 0;
6996 case CmpInst::FCMP_ULT:
6997 case CmpInst::FCMP_ULE:
6998 case CmpInst::FCMP_OLT:
6999 case CmpInst::FCMP_OLE:
7000 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER)
7001 return TargetOpcode::G_FMINNUM;
7002 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN)
7003 return TargetOpcode::G_FMINIMUM;
7004 if (isLegal(Query: {TargetOpcode::G_FMINNUM, {DstTy}}))
7005 return TargetOpcode::G_FMINNUM;
    if (isLegal(Query: {TargetOpcode::G_FMINIMUM, {DstTy}}))
      return TargetOpcode::G_FMINIMUM;
    return 0;
7009 }
7010}
7011
7012CombinerHelper::SelectPatternNaNBehaviour
7013CombinerHelper::computeRetValAgainstNaN(Register LHS, Register RHS,
7014 bool IsOrderedComparison) const {
7015 bool LHSSafe = isKnownNeverNaN(Val: LHS, MRI);
7016 bool RHSSafe = isKnownNeverNaN(Val: RHS, MRI);
7017 // Completely unsafe.
7018 if (!LHSSafe && !RHSSafe)
7019 return SelectPatternNaNBehaviour::NOT_APPLICABLE;
7020 if (LHSSafe && RHSSafe)
7021 return SelectPatternNaNBehaviour::RETURNS_ANY;
7022 // An ordered comparison will return false when given a NaN, so it
7023 // returns the RHS.
7024 if (IsOrderedComparison)
7025 return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_NAN
7026 : SelectPatternNaNBehaviour::RETURNS_OTHER;
7027 // An unordered comparison will return true when given a NaN, so it
7028 // returns the LHS.
7029 return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_OTHER
7030 : SelectPatternNaNBehaviour::RETURNS_NAN;
7031}
7032
7033bool CombinerHelper::matchFPSelectToMinMax(Register Dst, Register Cond,
7034 Register TrueVal, Register FalseVal,
7035 BuildFnTy &MatchInfo) const {
7036 // Match: select (fcmp cond x, y) x, y
7037 // select (fcmp cond x, y) y, x
7038 // And turn it into fminnum/fmaxnum or fmin/fmax based off of the condition.
7039 LLT DstTy = MRI.getType(Reg: Dst);
7040 // Bail out early on pointers, since we'll never want to fold to a min/max.
7041 if (DstTy.isPointer())
7042 return false;
7043 // Match a floating point compare with a less-than/greater-than predicate.
7044 // TODO: Allow multiple users of the compare if they are all selects.
7045 CmpInst::Predicate Pred;
7046 Register CmpLHS, CmpRHS;
7047 if (!mi_match(R: Cond, MRI,
7048 P: m_OneNonDBGUse(
7049 SP: m_GFCmp(P: m_Pred(P&: Pred), L: m_Reg(R&: CmpLHS), R: m_Reg(R&: CmpRHS)))) ||
7050 CmpInst::isEquality(pred: Pred))
7051 return false;
7052 SelectPatternNaNBehaviour ResWithKnownNaNInfo =
7053 computeRetValAgainstNaN(LHS: CmpLHS, RHS: CmpRHS, IsOrderedComparison: CmpInst::isOrdered(predicate: Pred));
7054 if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::NOT_APPLICABLE)
7055 return false;
7056 if (TrueVal == CmpRHS && FalseVal == CmpLHS) {
7057 std::swap(a&: CmpLHS, b&: CmpRHS);
7058 Pred = CmpInst::getSwappedPredicate(pred: Pred);
7059 if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_NAN)
7060 ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_OTHER;
7061 else if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_OTHER)
7062 ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_NAN;
7063 }
7064 if (TrueVal != CmpLHS || FalseVal != CmpRHS)
7065 return false;
7066 // Decide what type of max/min this should be based off of the predicate.
7067 unsigned Opc = getFPMinMaxOpcForSelect(Pred, DstTy, VsNaNRetVal: ResWithKnownNaNInfo);
7068 if (!Opc || !isLegal(Query: {Opc, {DstTy}}))
7069 return false;
7070 // Comparisons between signed zero and zero may have different results...
7071 // unless we have fmaximum/fminimum. In that case, we know -0 < 0.
7072 if (Opc != TargetOpcode::G_FMAXIMUM && Opc != TargetOpcode::G_FMINIMUM) {
7073 // We don't know if a comparison between two 0s will give us a consistent
7074 // result. Be conservative and only proceed if at least one side is
7075 // non-zero.
7076 auto KnownNonZeroSide = getFConstantVRegValWithLookThrough(VReg: CmpLHS, MRI);
7077 if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero()) {
7078 KnownNonZeroSide = getFConstantVRegValWithLookThrough(VReg: CmpRHS, MRI);
7079 if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero())
7080 return false;
7081 }
7082 }
7083 MatchInfo = [=](MachineIRBuilder &B) {
7084 B.buildInstr(Opc, DstOps: {Dst}, SrcOps: {CmpLHS, CmpRHS});
7085 };
7086 return true;
7087}
7088
7089bool CombinerHelper::matchSimplifySelectToMinMax(MachineInstr &MI,
7090 BuildFnTy &MatchInfo) const {
7091 // TODO: Handle integer cases.
7092 assert(MI.getOpcode() == TargetOpcode::G_SELECT);
7093 // Condition may be fed by a truncated compare.
7094 Register Cond = MI.getOperand(i: 1).getReg();
7095 Register MaybeTrunc;
7096 if (mi_match(R: Cond, MRI, P: m_OneNonDBGUse(SP: m_GTrunc(Src: m_Reg(R&: MaybeTrunc)))))
7097 Cond = MaybeTrunc;
7098 Register Dst = MI.getOperand(i: 0).getReg();
7099 Register TrueVal = MI.getOperand(i: 2).getReg();
7100 Register FalseVal = MI.getOperand(i: 3).getReg();
7101 return matchFPSelectToMinMax(Dst, Cond, TrueVal, FalseVal, MatchInfo);
7102}
7103
7104bool CombinerHelper::matchRedundantBinOpInEquality(MachineInstr &MI,
7105 BuildFnTy &MatchInfo) const {
7106 assert(MI.getOpcode() == TargetOpcode::G_ICMP);
7107 // (X + Y) == X --> Y == 0
7108 // (X + Y) != X --> Y != 0
7109 // (X - Y) == X --> Y == 0
7110 // (X - Y) != X --> Y != 0
7111 // (X ^ Y) == X --> Y == 0
7112 // (X ^ Y) != X --> Y != 0
7113 Register Dst = MI.getOperand(i: 0).getReg();
7114 CmpInst::Predicate Pred;
7115 Register X, Y, OpLHS, OpRHS;
7116 bool MatchedSub = mi_match(
7117 R: Dst, MRI,
7118 P: m_c_GICmp(P: m_Pred(P&: Pred), L: m_Reg(R&: X), R: m_GSub(L: m_Reg(R&: OpLHS), R: m_Reg(R&: Y))));
7119 if (MatchedSub && X != OpLHS)
7120 return false;
7121 if (!MatchedSub) {
7122 if (!mi_match(R: Dst, MRI,
7123 P: m_c_GICmp(P: m_Pred(P&: Pred), L: m_Reg(R&: X),
7124 R: m_any_of(preds: m_GAdd(L: m_Reg(R&: OpLHS), R: m_Reg(R&: OpRHS)),
7125 preds: m_GXor(L: m_Reg(R&: OpLHS), R: m_Reg(R&: OpRHS))))))
7126 return false;
7127 Y = X == OpLHS ? OpRHS : X == OpRHS ? OpLHS : Register();
7128 }
7129 MatchInfo = [=](MachineIRBuilder &B) {
7130 auto Zero = B.buildConstant(Res: MRI.getType(Reg: Y), Val: 0);
7131 B.buildICmp(Pred, Res: Dst, Op0: Y, Op1: Zero);
7132 };
7133 return CmpInst::isEquality(pred: Pred) && Y.isValid();
7134}
7135
7136/// Return the minimum useless shift amount that results in complete loss of the
7137/// source value. Return std::nullopt when it cannot determine a value.
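/// E.g. an i32 value with 24 known leading zero bits loses all significant
/// bits once logically shifted right by 32 - 24 = 8, so 8 is the minimum
/// useless G_LSHR amount and \p Result is set to 0.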
7138static std::optional<unsigned>
7139getMinUselessShift(KnownBits ValueKB, unsigned Opcode,
7140 std::optional<int64_t> &Result) {
7141 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_LSHR ||
7142 Opcode == TargetOpcode::G_ASHR) &&
7143 "Expect G_SHL, G_LSHR or G_ASHR.");
  unsigned SignificantBits = 0;
7145 switch (Opcode) {
7146 case TargetOpcode::G_SHL:
7147 SignificantBits = ValueKB.countMinTrailingZeros();
7148 Result = 0;
7149 break;
7150 case TargetOpcode::G_LSHR:
7151 Result = 0;
7152 SignificantBits = ValueKB.countMinLeadingZeros();
7153 break;
7154 case TargetOpcode::G_ASHR:
7155 if (ValueKB.isNonNegative()) {
7156 SignificantBits = ValueKB.countMinLeadingZeros();
7157 Result = 0;
7158 } else if (ValueKB.isNegative()) {
7159 SignificantBits = ValueKB.countMinLeadingOnes();
7160 Result = -1;
    } else {
      // Cannot determine shift result.
      Result = std::nullopt;
      return std::nullopt;
    }
7165 break;
7166 default:
7167 break;
7168 }
7169 return ValueKB.getBitWidth() - SignificantBits;
7170}
7171
bool CombinerHelper::matchShiftsTooBig(
    MachineInstr &MI, std::optional<int64_t> &MatchInfo) const {
  Register ShiftVal = MI.getOperand(1).getReg();
  Register ShiftReg = MI.getOperand(2).getReg();
  LLT ResTy = MRI.getType(MI.getOperand(0).getReg());
  auto IsShiftTooBig = [&](const Constant *C) {
    auto *CI = dyn_cast<ConstantInt>(C);
    if (!CI)
      return false;
    if (CI->uge(ResTy.getScalarSizeInBits())) {
      MatchInfo = std::nullopt;
      return true;
    }
    auto MinUselessShift = getMinUselessShift(VT->getKnownBits(ShiftVal),
                                              MI.getOpcode(), MatchInfo);
    return MinUselessShift && CI->uge(*MinUselessShift);
  };
  return matchUnaryPredicate(MRI, ShiftReg, IsShiftTooBig);
}

bool CombinerHelper::matchCommuteConstantToRHS(MachineInstr &MI) const {
  unsigned LHSOpndIdx = 1;
  unsigned RHSOpndIdx = 2;
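  // The overflow opcodes define two results (value and carry-out), so their
  // value operands start at index 2 instead of 1.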
  switch (MI.getOpcode()) {
  case TargetOpcode::G_UADDO:
  case TargetOpcode::G_SADDO:
  case TargetOpcode::G_UMULO:
  case TargetOpcode::G_SMULO:
    LHSOpndIdx = 2;
    RHSOpndIdx = 3;
    break;
  default:
    break;
  }
  Register LHS = MI.getOperand(LHSOpndIdx).getReg();
  Register RHS = MI.getOperand(RHSOpndIdx).getReg();
  if (!getIConstantVRegVal(LHS, MRI)) {
    // Skip commuting if LHS is not a constant. But, LHS may be a
    // G_CONSTANT_FOLD_BARRIER. If so we commute as long as we don't already
    // have a constant on the RHS.
    if (MRI.getVRegDef(LHS)->getOpcode() !=
        TargetOpcode::G_CONSTANT_FOLD_BARRIER)
      return false;
  }
  // Commute as long as RHS is not a constant or G_CONSTANT_FOLD_BARRIER.
  return MRI.getVRegDef(RHS)->getOpcode() !=
             TargetOpcode::G_CONSTANT_FOLD_BARRIER &&
         !getIConstantVRegVal(RHS, MRI);
}

bool CombinerHelper::matchCommuteFPConstantToRHS(MachineInstr &MI) const {
  Register LHS = MI.getOperand(1).getReg();
  Register RHS = MI.getOperand(2).getReg();
  std::optional<FPValueAndVReg> ValAndVReg;
  if (!mi_match(LHS, MRI, m_GFCstOrSplat(ValAndVReg)))
    return false;
  return !mi_match(RHS, MRI, m_GFCstOrSplat(ValAndVReg));
}

void CombinerHelper::applyCommuteBinOpOperands(MachineInstr &MI) const {
  Observer.changingInstr(MI);
  unsigned LHSOpndIdx = 1;
  unsigned RHSOpndIdx = 2;
  switch (MI.getOpcode()) {
  case TargetOpcode::G_UADDO:
  case TargetOpcode::G_SADDO:
  case TargetOpcode::G_UMULO:
  case TargetOpcode::G_SMULO:
    LHSOpndIdx = 2;
    RHSOpndIdx = 3;
    break;
  default:
    break;
  }
  Register LHSReg = MI.getOperand(LHSOpndIdx).getReg();
  Register RHSReg = MI.getOperand(RHSOpndIdx).getReg();
  MI.getOperand(LHSOpndIdx).setReg(RHSReg);
  MI.getOperand(RHSOpndIdx).setReg(LHSReg);
  Observer.changedInstr(MI);
}

bool CombinerHelper::isOneOrOneSplat(Register Src, bool AllowUndefs) const {
  LLT SrcTy = MRI.getType(Src);
  if (SrcTy.isFixedVector())
    return isConstantSplatVector(Src, 1, AllowUndefs);
  if (SrcTy.isScalar()) {
    if (AllowUndefs && getOpcodeDef<GImplicitDef>(Src, MRI) != nullptr)
      return true;
    auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI);
    return IConstant && IConstant->Value == 1;
  }
  return false; // scalable vector
}

bool CombinerHelper::isZeroOrZeroSplat(Register Src, bool AllowUndefs) const {
  LLT SrcTy = MRI.getType(Src);
  if (SrcTy.isFixedVector())
    return isConstantSplatVector(Src, 0, AllowUndefs);
  if (SrcTy.isScalar()) {
    if (AllowUndefs && getOpcodeDef<GImplicitDef>(Src, MRI) != nullptr)
      return true;
    auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI);
    return IConstant && IConstant->Value == 0;
  }
  return false; // scalable vector
}

// Ignores COPYs when checking that each element matches the splat value.
// FIXME: Handle scalable vectors.
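// E.g. with AllowUndefs set, the following (illustrative MIR) matches
// SplatValue == C when %c is a G_CONSTANT with value C:
//
//   %v:_(<4 x s32>) = G_BUILD_VECTOR %c, %c, %undef, %c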
bool CombinerHelper::isConstantSplatVector(Register Src, int64_t SplatValue,
                                           bool AllowUndefs) const {
  GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI);
  if (!BuildVector)
    return false;
  unsigned NumSources = BuildVector->getNumSources();

  for (unsigned I = 0; I < NumSources; ++I) {
    GImplicitDef *ImplicitDef =
        getOpcodeDef<GImplicitDef>(BuildVector->getSourceReg(I), MRI);
    if (ImplicitDef && AllowUndefs)
      continue;
    if (ImplicitDef && !AllowUndefs)
      return false;
    std::optional<ValueAndVReg> IConstant =
        getIConstantVRegValWithLookThrough(BuildVector->getSourceReg(I), MRI);
    if (IConstant && IConstant->Value == SplatValue)
      continue;
    return false;
  }
  return true;
}

// Ignores COPYs during lookups.
// FIXME: Handle scalable vectors.
std::optional<APInt>
CombinerHelper::getConstantOrConstantSplatVector(Register Src) const {
  auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI);
  if (IConstant)
    return IConstant->Value;

  GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI);
  if (!BuildVector)
    return std::nullopt;
  unsigned NumSources = BuildVector->getNumSources();

  std::optional<APInt> Value = std::nullopt;
  for (unsigned I = 0; I < NumSources; ++I) {
    std::optional<ValueAndVReg> IConstant =
        getIConstantVRegValWithLookThrough(BuildVector->getSourceReg(I), MRI);
    if (!IConstant)
      return std::nullopt;
    if (!Value)
      Value = IConstant->Value;
    else if (*Value != IConstant->Value)
      return std::nullopt;
  }
  return Value;
}

// FIXME: Handle G_SPLAT_VECTOR.
bool CombinerHelper::isConstantOrConstantVectorI(Register Src) const {
  auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI);
  if (IConstant)
    return true;

  GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI);
  if (!BuildVector)
    return false;

  unsigned NumSources = BuildVector->getNumSources();
  for (unsigned I = 0; I < NumSources; ++I) {
    std::optional<ValueAndVReg> IConstant =
        getIConstantVRegValWithLookThrough(BuildVector->getSourceReg(I), MRI);
    if (!IConstant)
      return false;
  }
  return true;
}

// TODO: use knownbits to determine zeros
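// Fold a select of two scalar constants into arithmetic on the condition,
// e.g. (illustrative MIR, built via zext-or-trunc):
//
//   %res:_(s32) = G_SELECT %cond(s1), %one, %zero
// -->
//   %res:_(s32) = G_ZEXT %cond(s1)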
bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
                                              BuildFnTy &MatchInfo) const {
  uint32_t Flags = Select->getFlags();
  Register Dest = Select->getReg(0);
  Register Cond = Select->getCondReg();
  Register True = Select->getTrueReg();
  Register False = Select->getFalseReg();
  LLT CondTy = MRI.getType(Select->getCondReg());
  LLT TrueTy = MRI.getType(Select->getTrueReg());

  // We only do this combine for scalar boolean conditions.
  if (CondTy != LLT::scalar(1))
    return false;

  if (TrueTy.isPointer())
    return false;

  // Both are scalars.
  std::optional<ValueAndVReg> TrueOpt =
      getIConstantVRegValWithLookThrough(True, MRI);
  std::optional<ValueAndVReg> FalseOpt =
      getIConstantVRegValWithLookThrough(False, MRI);

  if (!TrueOpt || !FalseOpt)
    return false;

  APInt TrueValue = TrueOpt->Value;
  APInt FalseValue = FalseOpt->Value;

  // select Cond, 1, 0 --> zext (Cond)
  if (TrueValue.isOne() && FalseValue.isZero()) {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.setInstrAndDebugLoc(*Select);
      B.buildZExtOrTrunc(Dest, Cond);
    };
    return true;
  }

  // select Cond, -1, 0 --> sext (Cond)
  if (TrueValue.isAllOnes() && FalseValue.isZero()) {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.setInstrAndDebugLoc(*Select);
      B.buildSExtOrTrunc(Dest, Cond);
    };
    return true;
  }

  // select Cond, 0, 1 --> zext (!Cond)
  if (TrueValue.isZero() && FalseValue.isOne()) {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.setInstrAndDebugLoc(*Select);
      Register Inner = MRI.createGenericVirtualRegister(CondTy);
      B.buildNot(Inner, Cond);
      B.buildZExtOrTrunc(Dest, Inner);
    };
    return true;
  }

  // select Cond, 0, -1 --> sext (!Cond)
  if (TrueValue.isZero() && FalseValue.isAllOnes()) {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.setInstrAndDebugLoc(*Select);
      Register Inner = MRI.createGenericVirtualRegister(CondTy);
      B.buildNot(Inner, Cond);
      B.buildSExtOrTrunc(Dest, Inner);
    };
    return true;
  }

  // select Cond, C1, C1-1 --> add (zext Cond), C1-1
  if (TrueValue - 1 == FalseValue) {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.setInstrAndDebugLoc(*Select);
      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
      B.buildZExtOrTrunc(Inner, Cond);
      B.buildAdd(Dest, Inner, False);
    };
    return true;
  }

  // select Cond, C1, C1+1 --> add (sext Cond), C1+1
  if (TrueValue + 1 == FalseValue) {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.setInstrAndDebugLoc(*Select);
      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
      B.buildSExtOrTrunc(Inner, Cond);
      B.buildAdd(Dest, Inner, False);
    };
    return true;
  }

  // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
  if (TrueValue.isPowerOf2() && FalseValue.isZero()) {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.setInstrAndDebugLoc(*Select);
      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
      B.buildZExtOrTrunc(Inner, Cond);
      // The shift amount must be scalar.
      LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy;
      auto ShAmtC = B.buildConstant(ShiftTy, TrueValue.exactLogBase2());
      B.buildShl(Dest, Inner, ShAmtC, Flags);
    };
    return true;
  }

  // select Cond, 0, Pow2 --> (zext (!Cond)) << log2(Pow2)
  if (FalseValue.isPowerOf2() && TrueValue.isZero()) {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.setInstrAndDebugLoc(*Select);
      Register Not = MRI.createGenericVirtualRegister(CondTy);
      B.buildNot(Not, Cond);
      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
      B.buildZExtOrTrunc(Inner, Not);
      // The shift amount must be scalar.
      LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy;
      auto ShAmtC = B.buildConstant(ShiftTy, FalseValue.exactLogBase2());
      B.buildShl(Dest, Inner, ShAmtC, Flags);
    };
    return true;
  }

  // select Cond, -1, C --> or (sext Cond), C
  if (TrueValue.isAllOnes()) {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.setInstrAndDebugLoc(*Select);
      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
      B.buildSExtOrTrunc(Inner, Cond);
      B.buildOr(Dest, Inner, False, Flags);
    };
    return true;
  }

  // select Cond, C, -1 --> or (sext (not Cond)), C
  if (FalseValue.isAllOnes()) {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.setInstrAndDebugLoc(*Select);
      Register Not = MRI.createGenericVirtualRegister(CondTy);
      B.buildNot(Not, Cond);
      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
      B.buildSExtOrTrunc(Inner, Not);
      B.buildOr(Dest, Inner, True, Flags);
    };
    return true;
  }

  return false;
}

// TODO: use knownbits to determine zeros
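// Fold a select whose operands are all booleans into G_OR/G_AND. The operand
// that the logic op would evaluate unconditionally is frozen first: G_SELECT
// only propagates poison from the chosen operand, while or/and propagate it
// from either, so the freeze keeps the transform poison-safe.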
bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
                                              BuildFnTy &MatchInfo) const {
  uint32_t Flags = Select->getFlags();
  Register DstReg = Select->getReg(0);
  Register Cond = Select->getCondReg();
  Register True = Select->getTrueReg();
  Register False = Select->getFalseReg();
  LLT CondTy = MRI.getType(Select->getCondReg());
  LLT TrueTy = MRI.getType(Select->getTrueReg());

  // Boolean or fixed vector of booleans.
  if (CondTy.isScalableVector() ||
      (CondTy.isFixedVector() &&
       CondTy.getElementType().getScalarSizeInBits() != 1) ||
      CondTy.getScalarSizeInBits() != 1)
    return false;

  if (CondTy != TrueTy)
    return false;

  // select Cond, Cond, F --> or Cond, F
  // select Cond, 1, F --> or Cond, F
  if ((Cond == True) || isOneOrOneSplat(True, /* AllowUndefs */ true)) {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.setInstrAndDebugLoc(*Select);
      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
      B.buildZExtOrTrunc(Ext, Cond);
      auto FreezeFalse = B.buildFreeze(TrueTy, False);
      B.buildOr(DstReg, Ext, FreezeFalse, Flags);
    };
    return true;
  }

  // select Cond, T, Cond --> and Cond, T
  // select Cond, T, 0 --> and Cond, T
  if ((Cond == False) || isZeroOrZeroSplat(False, /* AllowUndefs */ true)) {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.setInstrAndDebugLoc(*Select);
      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
      B.buildZExtOrTrunc(Ext, Cond);
      auto FreezeTrue = B.buildFreeze(TrueTy, True);
      B.buildAnd(DstReg, Ext, FreezeTrue);
    };
    return true;
  }

  // select Cond, T, 1 --> or (not Cond), T
  if (isOneOrOneSplat(False, /* AllowUndefs */ true)) {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.setInstrAndDebugLoc(*Select);
      // First the not.
      Register Inner = MRI.createGenericVirtualRegister(CondTy);
      B.buildNot(Inner, Cond);
      // Then an ext to match the destination register.
      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
      B.buildZExtOrTrunc(Ext, Inner);
      auto FreezeTrue = B.buildFreeze(TrueTy, True);
      B.buildOr(DstReg, Ext, FreezeTrue, Flags);
    };
    return true;
  }

  // select Cond, 0, F --> and (not Cond), F
  if (isZeroOrZeroSplat(True, /* AllowUndefs */ true)) {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.setInstrAndDebugLoc(*Select);
      // First the not.
      Register Inner = MRI.createGenericVirtualRegister(CondTy);
      B.buildNot(Inner, Cond);
      // Then an ext to match the destination register.
      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
      B.buildZExtOrTrunc(Ext, Inner);
      auto FreezeFalse = B.buildFreeze(TrueTy, False);
      B.buildAnd(DstReg, Ext, FreezeFalse);
    };
    return true;
  }

  return false;
}

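// Fold (icmp Pred X, Y) ? X : Y into the corresponding integer min/max, e.g.
// (illustrative MIR):
//
//   %c:_(s1) = G_ICMP intpred(ugt), %x(s32), %y(s32)
//   %res:_(s32) = G_SELECT %c, %x, %y
// -->
//   %res:_(s32) = G_UMAX %x, %y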
bool CombinerHelper::matchSelectIMinMax(const MachineOperand &MO,
                                        BuildFnTy &MatchInfo) const {
  GSelect *Select = cast<GSelect>(MRI.getVRegDef(MO.getReg()));
  GICmp *Cmp = cast<GICmp>(MRI.getVRegDef(Select->getCondReg()));

  Register DstReg = Select->getReg(0);
  Register True = Select->getTrueReg();
  Register False = Select->getFalseReg();
  LLT DstTy = MRI.getType(DstReg);

  if (DstTy.isPointerOrPointerVector())
    return false;

  // We want to fold the icmp and replace the select.
  if (!MRI.hasOneNonDBGUse(Cmp->getReg(0)))
    return false;

  CmpInst::Predicate Pred = Cmp->getCond();
  // We need a relational (greater/less) predicate for canonicalization;
  // equality predicates cannot form a min/max.
  if (CmpInst::isEquality(Pred))
    return false;

  Register CmpLHS = Cmp->getLHSReg();
  Register CmpRHS = Cmp->getRHSReg();

  // We can swap CmpLHS and CmpRHS for a higher hit rate.
  if (True == CmpRHS && False == CmpLHS) {
    std::swap(CmpLHS, CmpRHS);
    Pred = CmpInst::getSwappedPredicate(Pred);
  }

  // (icmp X, Y) ? X : Y -> integer minmax.
  // see matchSelectPattern in ValueTracking.
  // Legality between G_SELECT and integer minmax can differ.
  if (True != CmpLHS || False != CmpRHS)
    return false;

  switch (Pred) {
  case ICmpInst::ICMP_UGT:
  case ICmpInst::ICMP_UGE: {
    if (!isLegalOrBeforeLegalizer({TargetOpcode::G_UMAX, DstTy}))
      return false;
    MatchInfo = [=](MachineIRBuilder &B) { B.buildUMax(DstReg, True, False); };
    return true;
  }
  case ICmpInst::ICMP_SGT:
  case ICmpInst::ICMP_SGE: {
    if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SMAX, DstTy}))
      return false;
    MatchInfo = [=](MachineIRBuilder &B) { B.buildSMax(DstReg, True, False); };
    return true;
  }
  case ICmpInst::ICMP_ULT:
  case ICmpInst::ICMP_ULE: {
    if (!isLegalOrBeforeLegalizer({TargetOpcode::G_UMIN, DstTy}))
      return false;
    MatchInfo = [=](MachineIRBuilder &B) { B.buildUMin(DstReg, True, False); };
    return true;
  }
  case ICmpInst::ICMP_SLT:
  case ICmpInst::ICMP_SLE: {
    if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SMIN, DstTy}))
      return false;
    MatchInfo = [=](MachineIRBuilder &B) { B.buildSMin(DstReg, True, False); };
    return true;
  }
  default:
    return false;
  }
}

// (neg (min/max x, (neg x))) --> (max/min x, (neg x))
bool CombinerHelper::matchSimplifyNegMinMax(MachineInstr &MI,
                                            BuildFnTy &MatchInfo) const {
  assert(MI.getOpcode() == TargetOpcode::G_SUB);
  Register DestReg = MI.getOperand(0).getReg();
  LLT DestTy = MRI.getType(DestReg);

  Register X;
  Register Sub0;
  auto NegPattern = m_all_of(m_Neg(m_DeferredReg(X)), m_Reg(Sub0));
  if (mi_match(DestReg, MRI,
               m_Neg(m_OneUse(m_any_of(m_GSMin(m_Reg(X), NegPattern),
                                       m_GSMax(m_Reg(X), NegPattern),
                                       m_GUMin(m_Reg(X), NegPattern),
                                       m_GUMax(m_Reg(X), NegPattern)))))) {
    MachineInstr *MinMaxMI = MRI.getVRegDef(MI.getOperand(2).getReg());
    unsigned NewOpc = getInverseGMinMaxOpcode(MinMaxMI->getOpcode());
    if (isLegal({NewOpc, {DestTy}})) {
      MatchInfo = [=](MachineIRBuilder &B) {
        B.buildInstr(NewOpc, {DestReg}, {X, Sub0});
      };
      return true;
    }
  }

  return false;
}

bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) const {
  GSelect *Select = cast<GSelect>(&MI);

  if (tryFoldSelectOfConstants(Select, MatchInfo))
    return true;

  if (tryFoldBoolSelectToLogic(Select, MatchInfo))
    return true;

  return false;
}

/// Fold (icmp Pred1 V1, C1) && (icmp Pred2 V2, C2)
/// or   (icmp Pred1 V1, C1) || (icmp Pred2 V2, C2)
/// into a single comparison using range-based reasoning.
/// see InstCombinerImpl::foldAndOrOfICmpsUsingRanges.
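///
/// E.g. (x == 13) || (x == 14) yields the ranges [13, 14) and [14, 15); their
/// exact union is [13, 15), whose equivalent icmp is (x + -13) u< 2.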
bool CombinerHelper::tryFoldAndOrOrICmpsUsingRanges(
    GLogicalBinOp *Logic, BuildFnTy &MatchInfo) const {
  assert(Logic->getOpcode() != TargetOpcode::G_XOR && "unexpected xor");
  bool IsAnd = Logic->getOpcode() == TargetOpcode::G_AND;
  Register DstReg = Logic->getReg(0);
  Register LHS = Logic->getLHSReg();
  Register RHS = Logic->getRHSReg();
  unsigned Flags = Logic->getFlags();

  // We need a G_ICMP on the LHS register.
  GICmp *Cmp1 = getOpcodeDef<GICmp>(LHS, MRI);
  if (!Cmp1)
    return false;

  // We need a G_ICMP on the RHS register.
  GICmp *Cmp2 = getOpcodeDef<GICmp>(RHS, MRI);
  if (!Cmp2)
    return false;

  // We want to fold the icmps.
  if (!MRI.hasOneNonDBGUse(Cmp1->getReg(0)) ||
      !MRI.hasOneNonDBGUse(Cmp2->getReg(0)))
    return false;

  APInt C1;
  APInt C2;
  std::optional<ValueAndVReg> MaybeC1 =
      getIConstantVRegValWithLookThrough(Cmp1->getRHSReg(), MRI);
  if (!MaybeC1)
    return false;
  C1 = MaybeC1->Value;

  std::optional<ValueAndVReg> MaybeC2 =
      getIConstantVRegValWithLookThrough(Cmp2->getRHSReg(), MRI);
  if (!MaybeC2)
    return false;
  C2 = MaybeC2->Value;

  Register R1 = Cmp1->getLHSReg();
  Register R2 = Cmp2->getLHSReg();
  CmpInst::Predicate Pred1 = Cmp1->getCond();
  CmpInst::Predicate Pred2 = Cmp2->getCond();
  LLT CmpTy = MRI.getType(Cmp1->getReg(0));
  LLT CmpOperandTy = MRI.getType(R1);

  if (CmpOperandTy.isPointer())
    return false;

  // We build ands, adds, and constants of type CmpOperandTy.
  // They must be legal to build.
  if (!isLegalOrBeforeLegalizer({TargetOpcode::G_AND, CmpOperandTy}) ||
      !isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, CmpOperandTy}) ||
      !isConstantLegalOrBeforeLegalizer(CmpOperandTy))
    return false;

  // Look through add of a constant offset on R1, R2, or both operands. This
  // allows us to interpret the R + C' < C'' range idiom into a proper range.
  std::optional<APInt> Offset1;
  std::optional<APInt> Offset2;
  if (R1 != R2) {
    if (GAdd *Add = getOpcodeDef<GAdd>(R1, MRI)) {
      std::optional<ValueAndVReg> MaybeOffset1 =
          getIConstantVRegValWithLookThrough(Add->getRHSReg(), MRI);
      if (MaybeOffset1) {
        R1 = Add->getLHSReg();
        Offset1 = MaybeOffset1->Value;
      }
    }
    if (GAdd *Add = getOpcodeDef<GAdd>(R2, MRI)) {
      std::optional<ValueAndVReg> MaybeOffset2 =
          getIConstantVRegValWithLookThrough(Add->getRHSReg(), MRI);
      if (MaybeOffset2) {
        R2 = Add->getLHSReg();
        Offset2 = MaybeOffset2->Value;
      }
    }
  }

  if (R1 != R2)
    return false;

  // We calculate the icmp ranges including maybe offsets.
  ConstantRange CR1 = ConstantRange::makeExactICmpRegion(
      IsAnd ? ICmpInst::getInversePredicate(Pred1) : Pred1, C1);
  if (Offset1)
    CR1 = CR1.subtract(*Offset1);

  ConstantRange CR2 = ConstantRange::makeExactICmpRegion(
      IsAnd ? ICmpInst::getInversePredicate(Pred2) : Pred2, C2);
  if (Offset2)
    CR2 = CR2.subtract(*Offset2);

  bool CreateMask = false;
  APInt LowerDiff;
  std::optional<ConstantRange> CR = CR1.exactUnionWith(CR2);
  if (!CR) {
    // We need non-wrapping ranges.
    if (CR1.isWrappedSet() || CR2.isWrappedSet())
      return false;

    // Check whether we have equal-size ranges that only differ by one bit.
    // In that case we can apply a mask to map one range onto the other.
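    // E.g. [0, 8) and [16, 24) differ only in bit 4, so clearing bit 4 with
    // (x & ~16) maps the second range onto the first.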
    LowerDiff = CR1.getLower() ^ CR2.getLower();
    APInt UpperDiff = (CR1.getUpper() - 1) ^ (CR2.getUpper() - 1);
    APInt CR1Size = CR1.getUpper() - CR1.getLower();
    if (!LowerDiff.isPowerOf2() || LowerDiff != UpperDiff ||
        CR1Size != CR2.getUpper() - CR2.getLower())
      return false;

    CR = CR1.getLower().ult(CR2.getLower()) ? CR1 : CR2;
    CreateMask = true;
  }

  if (IsAnd)
    CR = CR->inverse();

  CmpInst::Predicate NewPred;
  APInt NewC, Offset;
  CR->getEquivalentICmp(NewPred, NewC, Offset);

  // We take the result type of one of the original icmps, CmpTy, for the
  // icmp to be built. The operand type, CmpOperandTy, is used for the other
  // instructions and constants to be built. The parameter and result types
  // are the same for add and and. CmpTy and the type of DstReg might differ,
  // which is why we zext or trunc the icmp into the destination register.

  MatchInfo = [=](MachineIRBuilder &B) {
    if (CreateMask && Offset != 0) {
      auto TildeLowerDiff = B.buildConstant(CmpOperandTy, ~LowerDiff);
      auto And = B.buildAnd(CmpOperandTy, R1, TildeLowerDiff); // the mask.
      auto OffsetC = B.buildConstant(CmpOperandTy, Offset);
      auto Add = B.buildAdd(CmpOperandTy, And, OffsetC, Flags);
      auto NewCon = B.buildConstant(CmpOperandTy, NewC);
      auto ICmp = B.buildICmp(NewPred, CmpTy, Add, NewCon);
      B.buildZExtOrTrunc(DstReg, ICmp);
    } else if (CreateMask && Offset == 0) {
      auto TildeLowerDiff = B.buildConstant(CmpOperandTy, ~LowerDiff);
      auto And = B.buildAnd(CmpOperandTy, R1, TildeLowerDiff); // the mask.
      auto NewCon = B.buildConstant(CmpOperandTy, NewC);
      auto ICmp = B.buildICmp(NewPred, CmpTy, And, NewCon);
      B.buildZExtOrTrunc(DstReg, ICmp);
    } else if (!CreateMask && Offset != 0) {
      auto OffsetC = B.buildConstant(CmpOperandTy, Offset);
      auto Add = B.buildAdd(CmpOperandTy, R1, OffsetC, Flags);
      auto NewCon = B.buildConstant(CmpOperandTy, NewC);
      auto ICmp = B.buildICmp(NewPred, CmpTy, Add, NewCon);
      B.buildZExtOrTrunc(DstReg, ICmp);
    } else if (!CreateMask && Offset == 0) {
      auto NewCon = B.buildConstant(CmpOperandTy, NewC);
      auto ICmp = B.buildICmp(NewPred, CmpTy, R1, NewCon);
      B.buildZExtOrTrunc(DstReg, ICmp);
    } else {
      llvm_unreachable("unexpected configuration of CreateMask and Offset");
    }
  };
  return true;
}

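// Fold logic of two fcmps on the same operands into a single fcmp with the
// combined predicate, e.g. (fcmp olt x, y) || (fcmp ogt x, y) -> fcmp one x, y.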
bool CombinerHelper::tryFoldLogicOfFCmps(GLogicalBinOp *Logic,
                                         BuildFnTy &MatchInfo) const {
  assert(Logic->getOpcode() != TargetOpcode::G_XOR && "unexpected xor");
  Register DestReg = Logic->getReg(0);
  Register LHS = Logic->getLHSReg();
  Register RHS = Logic->getRHSReg();
  bool IsAnd = Logic->getOpcode() == TargetOpcode::G_AND;

  // We need a compare on the LHS register.
  GFCmp *Cmp1 = getOpcodeDef<GFCmp>(LHS, MRI);
  if (!Cmp1)
    return false;

  // We need a compare on the RHS register.
  GFCmp *Cmp2 = getOpcodeDef<GFCmp>(RHS, MRI);
  if (!Cmp2)
    return false;

  LLT CmpTy = MRI.getType(Cmp1->getReg(0));
  LLT CmpOperandTy = MRI.getType(Cmp1->getLHSReg());

  // We build one fcmp, want to fold the fcmps, replace the logic op,
  // and the fcmps must have the same shape.
  if (!isLegalOrBeforeLegalizer(
          {TargetOpcode::G_FCMP, {CmpTy, CmpOperandTy}}) ||
      !MRI.hasOneNonDBGUse(Logic->getReg(0)) ||
      !MRI.hasOneNonDBGUse(Cmp1->getReg(0)) ||
      !MRI.hasOneNonDBGUse(Cmp2->getReg(0)) ||
      MRI.getType(Cmp1->getLHSReg()) != MRI.getType(Cmp2->getLHSReg()))
    return false;

  CmpInst::Predicate PredL = Cmp1->getCond();
  CmpInst::Predicate PredR = Cmp2->getCond();
  Register LHS0 = Cmp1->getLHSReg();
  Register LHS1 = Cmp1->getRHSReg();
  Register RHS0 = Cmp2->getLHSReg();
  Register RHS1 = Cmp2->getRHSReg();

  if (LHS0 == RHS1 && LHS1 == RHS0) {
    // Swap RHS operands to match LHS.
    PredR = CmpInst::getSwappedPredicate(PredR);
    std::swap(RHS0, RHS1);
  }

  if (LHS0 == RHS0 && LHS1 == RHS1) {
    // We determine the new predicate.
    unsigned CmpCodeL = getFCmpCode(PredL);
    unsigned CmpCodeR = getFCmpCode(PredR);
    unsigned NewPred = IsAnd ? CmpCodeL & CmpCodeR : CmpCodeL | CmpCodeR;
    unsigned Flags = Cmp1->getFlags() | Cmp2->getFlags();
    MatchInfo = [=](MachineIRBuilder &B) {
      // The fcmp predicates fill the lower part of the enum.
      FCmpInst::Predicate Pred = static_cast<FCmpInst::Predicate>(NewPred);
      if (Pred == FCmpInst::FCMP_FALSE &&
          isConstantLegalOrBeforeLegalizer(CmpTy)) {
        auto False = B.buildConstant(CmpTy, 0);
        B.buildZExtOrTrunc(DestReg, False);
      } else if (Pred == FCmpInst::FCMP_TRUE &&
                 isConstantLegalOrBeforeLegalizer(CmpTy)) {
        auto True = B.buildConstant(
            CmpTy, getICmpTrueVal(getTargetLowering(),
                                  CmpTy.isVector() /*isVector*/,
                                  true /*isFP*/));
        B.buildZExtOrTrunc(DestReg, True);
      } else { // We take the predicate without predicate optimizations.
        auto Cmp = B.buildFCmp(Pred, CmpTy, LHS0, LHS1, Flags);
        B.buildZExtOrTrunc(DestReg, Cmp);
      }
    };
    return true;
  }

  return false;
}

bool CombinerHelper::matchAnd(MachineInstr &MI, BuildFnTy &MatchInfo) const {
  GAnd *And = cast<GAnd>(&MI);

  if (tryFoldAndOrOrICmpsUsingRanges(And, MatchInfo))
    return true;

  if (tryFoldLogicOfFCmps(And, MatchInfo))
    return true;

  return false;
}

bool CombinerHelper::matchOr(MachineInstr &MI, BuildFnTy &MatchInfo) const {
  GOr *Or = cast<GOr>(&MI);

  if (tryFoldAndOrOrICmpsUsingRanges(Or, MatchInfo))
    return true;

  if (tryFoldLogicOfFCmps(Or, MatchInfo))
    return true;

  return false;
}

bool CombinerHelper::matchAddOverflow(MachineInstr &MI,
                                      BuildFnTy &MatchInfo) const {
  GAddCarryOut *Add = cast<GAddCarryOut>(&MI);

  // Addo has no flags.
  Register Dst = Add->getReg(0);
  Register Carry = Add->getReg(1);
  Register LHS = Add->getLHSReg();
  Register RHS = Add->getRHSReg();
  bool IsSigned = Add->isSigned();
  LLT DstTy = MRI.getType(Dst);
  LLT CarryTy = MRI.getType(Carry);

  // Fold addo, if the carry is dead -> add, undef.
  if (MRI.use_nodbg_empty(Carry) &&
      isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, {DstTy}})) {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.buildAdd(Dst, LHS, RHS);
      B.buildUndef(Carry);
    };
    return true;
  }

  // Canonicalize constant to RHS.
  if (isConstantOrConstantVectorI(LHS) && !isConstantOrConstantVectorI(RHS)) {
    if (IsSigned) {
      MatchInfo = [=](MachineIRBuilder &B) {
        B.buildSAddo(Dst, Carry, RHS, LHS);
      };
      return true;
    }
    // !IsSigned
    MatchInfo = [=](MachineIRBuilder &B) {
      B.buildUAddo(Dst, Carry, RHS, LHS);
    };
    return true;
  }

  std::optional<APInt> MaybeLHS = getConstantOrConstantSplatVector(LHS);
  std::optional<APInt> MaybeRHS = getConstantOrConstantSplatVector(RHS);

  // Fold addo(c1, c2) -> c3, carry.
  if (MaybeLHS && MaybeRHS && isConstantLegalOrBeforeLegalizer(DstTy) &&
      isConstantLegalOrBeforeLegalizer(CarryTy)) {
    bool Overflow;
    APInt Result = IsSigned ? MaybeLHS->sadd_ov(*MaybeRHS, Overflow)
                            : MaybeLHS->uadd_ov(*MaybeRHS, Overflow);
    MatchInfo = [=](MachineIRBuilder &B) {
      B.buildConstant(Dst, Result);
      B.buildConstant(Carry, Overflow);
    };
    return true;
  }

  // Fold (addo x, 0) -> x, no carry
  if (MaybeRHS && *MaybeRHS == 0 && isConstantLegalOrBeforeLegalizer(CarryTy)) {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.buildCopy(Dst, LHS);
      B.buildConstant(Carry, 0);
    };
    return true;
  }

  // Given 2 constant operands whose sum does not overflow:
  // uaddo (X +nuw C0), C1 -> uaddo X, C0 + C1
  // saddo (X +nsw C0), C1 -> saddo X, C0 + C1
  GAdd *AddLHS = getOpcodeDef<GAdd>(LHS, MRI);
  if (MaybeRHS && AddLHS && MRI.hasOneNonDBGUse(Add->getReg(0)) &&
      ((IsSigned && AddLHS->getFlag(MachineInstr::MIFlag::NoSWrap)) ||
       (!IsSigned && AddLHS->getFlag(MachineInstr::MIFlag::NoUWrap)))) {
    std::optional<APInt> MaybeAddRHS =
        getConstantOrConstantSplatVector(AddLHS->getRHSReg());
    if (MaybeAddRHS) {
      bool Overflow;
      APInt NewC = IsSigned ? MaybeAddRHS->sadd_ov(*MaybeRHS, Overflow)
                            : MaybeAddRHS->uadd_ov(*MaybeRHS, Overflow);
      if (!Overflow && isConstantLegalOrBeforeLegalizer(DstTy)) {
        if (IsSigned) {
          MatchInfo = [=](MachineIRBuilder &B) {
            auto ConstRHS = B.buildConstant(DstTy, NewC);
            B.buildSAddo(Dst, Carry, AddLHS->getLHSReg(), ConstRHS);
          };
          return true;
        }
        // !IsSigned
        MatchInfo = [=](MachineIRBuilder &B) {
          auto ConstRHS = B.buildConstant(DstTy, NewC);
          B.buildUAddo(Dst, Carry, AddLHS->getLHSReg(), ConstRHS);
        };
        return true;
      }
    }
  }

  // We try to combine addo to non-overflowing add.
  if (!isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, {DstTy}}) ||
      !isConstantLegalOrBeforeLegalizer(CarryTy))
    return false;

  // We try to combine uaddo to non-overflowing add.
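  // E.g. if the top bit of both operands is known zero, the unsigned ranges
  // cannot sum past the bit width and the carry is known to be 0.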
  if (!IsSigned) {
    ConstantRange CRLHS =
        ConstantRange::fromKnownBits(VT->getKnownBits(LHS), /*IsSigned=*/false);
    ConstantRange CRRHS =
        ConstantRange::fromKnownBits(VT->getKnownBits(RHS), /*IsSigned=*/false);

    switch (CRLHS.unsignedAddMayOverflow(CRRHS)) {
    case ConstantRange::OverflowResult::MayOverflow:
      return false;
    case ConstantRange::OverflowResult::NeverOverflows: {
      MatchInfo = [=](MachineIRBuilder &B) {
        B.buildAdd(Dst, LHS, RHS, MachineInstr::MIFlag::NoUWrap);
        B.buildConstant(Carry, 0);
      };
      return true;
    }
    case ConstantRange::OverflowResult::AlwaysOverflowsLow:
    case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
      MatchInfo = [=](MachineIRBuilder &B) {
        B.buildAdd(Dst, LHS, RHS);
        B.buildConstant(Carry, 1);
      };
      return true;
    }
    }
    return false;
  }

  // We try to combine saddo to non-overflowing add.

  // If LHS and RHS each have at least two sign bits, then there is no signed
  // overflow.
  if (VT->computeNumSignBits(RHS) > 1 && VT->computeNumSignBits(LHS) > 1) {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.buildAdd(Dst, LHS, RHS, MachineInstr::MIFlag::NoSWrap);
      B.buildConstant(Carry, 0);
    };
    return true;
  }

  ConstantRange CRLHS =
      ConstantRange::fromKnownBits(VT->getKnownBits(LHS), /*IsSigned=*/true);
  ConstantRange CRRHS =
      ConstantRange::fromKnownBits(VT->getKnownBits(RHS), /*IsSigned=*/true);

  switch (CRLHS.signedAddMayOverflow(CRRHS)) {
  case ConstantRange::OverflowResult::MayOverflow:
    return false;
  case ConstantRange::OverflowResult::NeverOverflows: {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.buildAdd(Dst, LHS, RHS, MachineInstr::MIFlag::NoSWrap);
      B.buildConstant(Carry, 0);
    };
    return true;
  }
  case ConstantRange::OverflowResult::AlwaysOverflowsLow:
  case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.buildAdd(Dst, LHS, RHS);
      B.buildConstant(Carry, 1);
    };
    return true;
  }
  }

  return false;
}

void CombinerHelper::applyBuildFnMO(const MachineOperand &MO,
                                    BuildFnTy &MatchInfo) const {
  MachineInstr *Root = getDefIgnoringCopies(MO.getReg(), MRI);
  MatchInfo(Builder);
  Root->eraseFromParent();
}

bool CombinerHelper::matchFPowIExpansion(MachineInstr &MI,
                                         int64_t Exponent) const {
  bool OptForSize = MI.getMF()->getFunction().hasOptSize();
  return getTargetLowering().isBeneficialToExpandPowI(Exponent, OptForSize);
}

void CombinerHelper::applyExpandFPowI(MachineInstr &MI,
                                      int64_t Exponent) const {
  auto [Dst, Base] = MI.getFirst2Regs();
  LLT Ty = MRI.getType(Dst);
  int64_t ExpVal = Exponent;

  if (ExpVal == 0) {
    Builder.buildFConstant(Dst, 1.0);
    MI.removeFromParent();
    return;
  }

  if (ExpVal < 0)
    ExpVal = -ExpVal;

  // We use the simple binary decomposition method from SelectionDAG ExpandPowI
  // to generate the multiply sequence. There are more optimal ways to do this
  // (for example, powi(x,15) generates one more multiply than it should), but
  // this has the benefit of being both really simple and much better than a
  // libcall.
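  // E.g. powi(x, 13): 13 = 0b1101, so the loop below computes
  // Res = x * x^4 * x^8 while CurSquare steps through x, x^2, x^4, x^8.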
  std::optional<SrcOp> Res;
  SrcOp CurSquare = Base;
  while (ExpVal > 0) {
    if (ExpVal & 1) {
      if (!Res)
        Res = CurSquare;
      else
        Res = Builder.buildFMul(Ty, *Res, CurSquare);
    }

    CurSquare = Builder.buildFMul(Ty, CurSquare, CurSquare);
    ExpVal >>= 1;
  }

  // If the original exponent was negative, invert the result, producing
  // 1/(x*x*x).
  if (Exponent < 0)
    Res = Builder.buildFDiv(Ty, Builder.buildFConstant(Ty, 1.0), *Res,
                            MI.getFlags());

  Builder.buildCopy(Dst, *Res);
  MI.eraseFromParent();
}

bool CombinerHelper::matchFoldAPlusC1MinusC2(const MachineInstr &MI,
                                             BuildFnTy &MatchInfo) const {
  // fold (A+C1)-C2 -> A+(C1-C2)
  const GSub *Sub = cast<GSub>(&MI);
  GAdd *Add = cast<GAdd>(MRI.getVRegDef(Sub->getLHSReg()));

  if (!MRI.hasOneNonDBGUse(Add->getReg(0)))
    return false;

  APInt C2 = getIConstantFromReg(Sub->getRHSReg(), MRI);
  APInt C1 = getIConstantFromReg(Add->getRHSReg(), MRI);

  Register Dst = Sub->getReg(0);
  LLT DstTy = MRI.getType(Dst);

  MatchInfo = [=](MachineIRBuilder &B) {
    auto Const = B.buildConstant(DstTy, C1 - C2);
    B.buildAdd(Dst, Add->getLHSReg(), Const);
  };

  return true;
}

bool CombinerHelper::matchFoldC2MinusAPlusC1(const MachineInstr &MI,
                                             BuildFnTy &MatchInfo) const {
  // fold C2-(A+C1) -> (C2-C1)-A
  const GSub *Sub = cast<GSub>(&MI);
  GAdd *Add = cast<GAdd>(MRI.getVRegDef(Sub->getRHSReg()));

  if (!MRI.hasOneNonDBGUse(Add->getReg(0)))
    return false;

  APInt C2 = getIConstantFromReg(Sub->getLHSReg(), MRI);
  APInt C1 = getIConstantFromReg(Add->getRHSReg(), MRI);

  Register Dst = Sub->getReg(0);
  LLT DstTy = MRI.getType(Dst);

  MatchInfo = [=](MachineIRBuilder &B) {
    auto Const = B.buildConstant(DstTy, C2 - C1);
    B.buildSub(Dst, Const, Add->getLHSReg());
  };

  return true;
}

bool CombinerHelper::matchFoldAMinusC1MinusC2(const MachineInstr &MI,
                                              BuildFnTy &MatchInfo) const {
  // fold (A-C1)-C2 -> A-(C1+C2)
  const GSub *Sub1 = cast<GSub>(&MI);
  GSub *Sub2 = cast<GSub>(MRI.getVRegDef(Sub1->getLHSReg()));

  if (!MRI.hasOneNonDBGUse(Sub2->getReg(0)))
    return false;

  APInt C2 = getIConstantFromReg(Sub1->getRHSReg(), MRI);
  APInt C1 = getIConstantFromReg(Sub2->getRHSReg(), MRI);

  Register Dst = Sub1->getReg(0);
  LLT DstTy = MRI.getType(Dst);

  MatchInfo = [=](MachineIRBuilder &B) {
    auto Const = B.buildConstant(DstTy, C1 + C2);
    B.buildSub(Dst, Sub2->getLHSReg(), Const);
  };

  return true;
}

bool CombinerHelper::matchFoldC1Minus2MinusC2(const MachineInstr &MI,
                                              BuildFnTy &MatchInfo) const {
  // fold (C1-A)-C2 -> (C1-C2)-A
  const GSub *Sub1 = cast<GSub>(&MI);
  GSub *Sub2 = cast<GSub>(MRI.getVRegDef(Sub1->getLHSReg()));

  if (!MRI.hasOneNonDBGUse(Sub2->getReg(0)))
    return false;

  APInt C2 = getIConstantFromReg(Sub1->getRHSReg(), MRI);
  APInt C1 = getIConstantFromReg(Sub2->getLHSReg(), MRI);

  Register Dst = Sub1->getReg(0);
  LLT DstTy = MRI.getType(Dst);

  MatchInfo = [=](MachineIRBuilder &B) {
    auto Const = B.buildConstant(DstTy, C1 - C2);
    B.buildSub(Dst, Const, Sub2->getRHSReg());
  };

  return true;
}

bool CombinerHelper::matchFoldAMinusC1PlusC2(const MachineInstr &MI,
                                             BuildFnTy &MatchInfo) const {
  // fold ((A-C1)+C2) -> (A+(C2-C1))
  const GAdd *Add = cast<GAdd>(&MI);
  GSub *Sub = cast<GSub>(MRI.getVRegDef(Add->getLHSReg()));

  if (!MRI.hasOneNonDBGUse(Sub->getReg(0)))
    return false;

  APInt C2 = getIConstantFromReg(Add->getRHSReg(), MRI);
  APInt C1 = getIConstantFromReg(Sub->getRHSReg(), MRI);

  Register Dst = Add->getReg(0);
  LLT DstTy = MRI.getType(Dst);

  MatchInfo = [=](MachineIRBuilder &B) {
    auto Const = B.buildConstant(DstTy, C2 - C1);
    B.buildAdd(Dst, Sub->getLHSReg(), Const);
  };

  return true;
}

bool CombinerHelper::matchUnmergeValuesAnyExtBuildVector(
    const MachineInstr &MI, BuildFnTy &MatchInfo) const {
  const GUnmerge *Unmerge = cast<GUnmerge>(&MI);

  if (!MRI.hasOneNonDBGUse(Unmerge->getSourceReg()))
    return false;

  const MachineInstr *Source = MRI.getVRegDef(Unmerge->getSourceReg());

  LLT DstTy = MRI.getType(Unmerge->getReg(0));

  // $bv:_(<8 x s8>) = G_BUILD_VECTOR ....
  // $any:_(<8 x s16>) = G_ANYEXT $bv
  // $uv:_(<4 x s16>), $uv1:_(<4 x s16>) = G_UNMERGE_VALUES $any
  //
  // ->
  //
  // $any:_(s16) = G_ANYEXT $bv[0]
  // $any1:_(s16) = G_ANYEXT $bv[1]
  // $any2:_(s16) = G_ANYEXT $bv[2]
  // $any3:_(s16) = G_ANYEXT $bv[3]
  // $any4:_(s16) = G_ANYEXT $bv[4]
  // $any5:_(s16) = G_ANYEXT $bv[5]
  // $any6:_(s16) = G_ANYEXT $bv[6]
  // $any7:_(s16) = G_ANYEXT $bv[7]
  // $uv:_(<4 x s16>) = G_BUILD_VECTOR $any, $any1, $any2, $any3
  // $uv1:_(<4 x s16>) = G_BUILD_VECTOR $any4, $any5, $any6, $any7

  // We want to unmerge into vectors.
  if (!DstTy.isFixedVector())
    return false;

  const GAnyExt *Any = dyn_cast<GAnyExt>(Source);
  if (!Any)
    return false;

  const MachineInstr *NextSource = MRI.getVRegDef(Any->getSrcReg());

  if (const GBuildVector *BV = dyn_cast<GBuildVector>(NextSource)) {
    // G_UNMERGE_VALUES G_ANYEXT G_BUILD_VECTOR

    if (!MRI.hasOneNonDBGUse(BV->getReg(0)))
      return false;

    // FIXME: check element types?
    if (BV->getNumSources() % Unmerge->getNumDefs() != 0)
      return false;

    LLT BigBvTy = MRI.getType(BV->getReg(0));
    LLT SmallBvTy = DstTy;
    LLT SmallBvElementTy = SmallBvTy.getElementType();

    if (!isLegalOrBeforeLegalizer(
            {TargetOpcode::G_BUILD_VECTOR, {SmallBvTy, SmallBvElementTy}}))
      return false;

    // We check the legality of scalar anyext.
    if (!isLegalOrBeforeLegalizer(
            {TargetOpcode::G_ANYEXT,
             {SmallBvElementTy, BigBvTy.getElementType()}}))
      return false;

    MatchInfo = [=](MachineIRBuilder &B) {
      // For each G_UNMERGE_VALUES def, build a small build vector from
      // anyexts of the corresponding source build vector elements.
      for (unsigned I = 0; I < Unmerge->getNumDefs(); ++I) {
        SmallVector<Register> Ops;
        for (unsigned J = 0; J < SmallBvTy.getNumElements(); ++J) {
          Register SourceArray =
              BV->getSourceReg(I * SmallBvTy.getNumElements() + J);
          auto AnyExt = B.buildAnyExt(SmallBvElementTy, SourceArray);
          Ops.push_back(AnyExt.getReg(0));
        }
        B.buildBuildVector(Unmerge->getOperand(I).getReg(), Ops);
      }
    };
    return true;
  }

  return false;
}

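// Replace mask indices that select from the second shuffle source with -1
// (undef lane). This is only sound when that source is undef, which the
// matching combine rule is expected to guarantee; the helper does not
// re-check it here.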
bool CombinerHelper::matchShuffleUndefRHS(MachineInstr &MI,
                                          BuildFnTy &MatchInfo) const {
  bool Changed = false;
  auto &Shuffle = cast<GShuffleVector>(MI);
  ArrayRef<int> OrigMask = Shuffle.getMask();
  SmallVector<int, 16> NewMask;
  const LLT SrcTy = MRI.getType(Shuffle.getSrc1Reg());
  const unsigned NumSrcElems = SrcTy.isVector() ? SrcTy.getNumElements() : 1;
  const unsigned NumDstElts = OrigMask.size();
  for (unsigned i = 0; i != NumDstElts; ++i) {
    int Idx = OrigMask[i];
    if (Idx >= (int)NumSrcElems) {
      Idx = -1;
      Changed = true;
    }
    NewMask.push_back(Idx);
  }

  if (!Changed)
    return false;

  MatchInfo = [&, NewMask = std::move(NewMask)](MachineIRBuilder &B) {
    B.buildShuffleVector(MI.getOperand(0), MI.getOperand(1), MI.getOperand(2),
                         std::move(NewMask));
  };

  return true;
}

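// Remap a shuffle mask after swapping the two sources: indices that referred
// to the first source now refer to the second and vice versa.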
static void commuteMask(MutableArrayRef<int> Mask, const unsigned NumElems) {
  const unsigned MaskSize = Mask.size();
  for (unsigned I = 0; I < MaskSize; ++I) {
    int Idx = Mask[I];
    if (Idx < 0)
      continue;

    if (Idx < (int)NumElems)
      Mask[I] = Idx + NumElems;
    else
      Mask[I] = Idx - NumElems;
  }
}

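// If the shuffle mask reads from only one of the two sources, move that
// source to the first operand, commute the mask if needed, and use undef for
// the second operand, e.g. (illustrative MIR):
//
//   %res:_(<4 x s32>) = G_SHUFFLE_VECTOR %a, %b, shufflemask(4, 5, 6, 7)
// -->
//   %res:_(<4 x s32>) = G_SHUFFLE_VECTOR %b, %undef, shufflemask(0, 1, 2, 3)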
bool CombinerHelper::matchShuffleDisjointMask(MachineInstr &MI,
                                              BuildFnTy &MatchInfo) const {
  auto &Shuffle = cast<GShuffleVector>(MI);
  // If either input is already undef, don't check the mask again, to prevent
  // an infinite combine loop.
  if (getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, Shuffle.getSrc1Reg(), MRI))
    return false;

  if (getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, Shuffle.getSrc2Reg(), MRI))
    return false;

  const LLT DstTy = MRI.getType(Shuffle.getReg(0));
  const LLT Src1Ty = MRI.getType(Shuffle.getSrc1Reg());
  if (!isLegalOrBeforeLegalizer(
          {TargetOpcode::G_SHUFFLE_VECTOR, {DstTy, Src1Ty}}))
    return false;

  ArrayRef<int> Mask = Shuffle.getMask();
  const unsigned NumSrcElems = Src1Ty.getNumElements();

  bool TouchesSrc1 = false;
  bool TouchesSrc2 = false;
  const unsigned NumElems = Mask.size();
  for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
    if (Mask[Idx] < 0)
      continue;

    if (Mask[Idx] < (int)NumSrcElems)
      TouchesSrc1 = true;
    else
      TouchesSrc2 = true;
  }

  if (TouchesSrc1 == TouchesSrc2)
    return false;

  Register NewSrc1 = Shuffle.getSrc1Reg();
  SmallVector<int, 16> NewMask(Mask);
  if (TouchesSrc2) {
    NewSrc1 = Shuffle.getSrc2Reg();
    commuteMask(NewMask, NumSrcElems);
  }

  MatchInfo = [=, &Shuffle](MachineIRBuilder &B) {
    auto Undef = B.buildUndef(Src1Ty);
    B.buildShuffleVector(Shuffle.getReg(0), NewSrc1, Undef, NewMask);
  };

  return true;
}

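// Fold G_SSUBO/G_USUBO into a plain G_SUB plus a constant carry-out when the
// known-bits ranges of the operands prove whether the subtraction overflows.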
bool CombinerHelper::matchSuboCarryOut(const MachineInstr &MI,
                                       BuildFnTy &MatchInfo) const {
  const GSubCarryOut *Subo = cast<GSubCarryOut>(&MI);

  Register Dst = Subo->getReg(0);
  Register LHS = Subo->getLHSReg();
  Register RHS = Subo->getRHSReg();
  Register Carry = Subo->getCarryOutReg();
  LLT DstTy = MRI.getType(Dst);
  LLT CarryTy = MRI.getType(Carry);

  // Check legality before known bits.
  if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SUB, {DstTy}}) ||
      !isConstantLegalOrBeforeLegalizer(CarryTy))
    return false;

  ConstantRange KBLHS =
      ConstantRange::fromKnownBits(VT->getKnownBits(LHS),
                                   /*IsSigned=*/Subo->isSigned());
  ConstantRange KBRHS =
      ConstantRange::fromKnownBits(VT->getKnownBits(RHS),
                                   /*IsSigned=*/Subo->isSigned());

  if (Subo->isSigned()) {
    // G_SSUBO
    switch (KBLHS.signedSubMayOverflow(KBRHS)) {
    case ConstantRange::OverflowResult::MayOverflow:
      return false;
    case ConstantRange::OverflowResult::NeverOverflows: {
      MatchInfo = [=](MachineIRBuilder &B) {
        B.buildSub(Dst, LHS, RHS, MachineInstr::MIFlag::NoSWrap);
        B.buildConstant(Carry, 0);
      };
      return true;
    }
    case ConstantRange::OverflowResult::AlwaysOverflowsLow:
    case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
      MatchInfo = [=](MachineIRBuilder &B) {
        B.buildSub(Dst, LHS, RHS);
        B.buildConstant(Carry, getICmpTrueVal(getTargetLowering(),
                                              /*isVector=*/CarryTy.isVector(),
                                              /*isFP=*/false));
      };
      return true;
    }
    }
    return false;
  }

  // G_USUBO
  switch (KBLHS.unsignedSubMayOverflow(KBRHS)) {
  case ConstantRange::OverflowResult::MayOverflow:
    return false;
  case ConstantRange::OverflowResult::NeverOverflows: {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.buildSub(Dst, LHS, RHS, MachineInstr::MIFlag::NoUWrap);
      B.buildConstant(Carry, 0);
    };
    return true;
  }
  case ConstantRange::OverflowResult::AlwaysOverflowsLow:
  case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.buildSub(Dst, LHS, RHS);
      B.buildConstant(Carry, getICmpTrueVal(getTargetLowering(),
                                            /*isVector=*/CarryTy.isVector(),
                                            /*isFP=*/false));
    };
    return true;
  }
  }

  return false;
}

// Fold (ctlz (xor x, (sra x, bitwidth-1))) -> (add (ctls x), 1).
// Fold (ctlz (or (shl (xor x, (sra x, bitwidth-1)), 1), 1)) -> (ctls x).
bool CombinerHelper::matchCtls(MachineInstr &CtlzMI,
                               BuildFnTy &MatchInfo) const {
  assert((CtlzMI.getOpcode() == TargetOpcode::G_CTLZ ||
          CtlzMI.getOpcode() == TargetOpcode::G_CTLZ_ZERO_UNDEF) &&
         "Expected G_CTLZ variant");

  const Register Dst = CtlzMI.getOperand(0).getReg();
  Register Src = CtlzMI.getOperand(1).getReg();

  LLT Ty = MRI.getType(Dst);
  LLT SrcTy = MRI.getType(Src);

  if (!(Ty.isValid() && Ty.isScalar()))
    return false;

  if (!LI)
    return false;

  SmallVector<LLT, 2> QueryTypes = {Ty, SrcTy};
  LegalityQuery Query(TargetOpcode::G_CTLS, QueryTypes);

  switch (LI->getAction(Query).Action) {
  default:
    return false;
  case LegalizeActions::Legal:
  case LegalizeActions::Custom:
  case LegalizeActions::WidenScalar:
    break;
  }

  // Src = or(shl(V, 1), 1) -> Src = V, NeedAdd = false.
  Register V;
  bool NeedAdd = true;
  if (mi_match(Src, MRI,
               m_OneUse(m_GOr(m_OneUse(m_GShl(m_Reg(V), m_SpecificICst(1))),
                              m_SpecificICst(1))))) {
    NeedAdd = false;
    Src = V;
  }

  unsigned BitWidth = Ty.getScalarSizeInBits();

  Register X;
  if (!mi_match(Src, MRI,
                m_OneUse(m_GXor(m_Reg(X),
                                m_OneUse(m_GAShr(m_DeferredReg(X),
                                                 m_SpecificICst(BitWidth - 1)))))))
    return false;

  MatchInfo = [=](MachineIRBuilder &B) {
    if (!NeedAdd) {
      B.buildCTLS(Dst, X);
      return;
    }

    auto Ctls = B.buildCTLS(Ty, X);
    auto One = B.buildConstant(Ty, 1);

    B.buildAdd(Dst, Ctls, One);
  };

  return true;
}
