//===-- lib/CodeGen/GlobalISel/CombinerHelper.cpp -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
9#include "llvm/ADT/APFloat.h"
10#include "llvm/ADT/STLExtras.h"
11#include "llvm/ADT/SetVector.h"
12#include "llvm/ADT/SmallBitVector.h"
13#include "llvm/Analysis/CmpInstAnalysis.h"
14#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
15#include "llvm/CodeGen/GlobalISel/GISelValueTracking.h"
16#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
17#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
18#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
19#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
20#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
21#include "llvm/CodeGen/GlobalISel/Utils.h"
22#include "llvm/CodeGen/LowLevelTypeUtils.h"
23#include "llvm/CodeGen/MachineBasicBlock.h"
24#include "llvm/CodeGen/MachineDominators.h"
25#include "llvm/CodeGen/MachineInstr.h"
26#include "llvm/CodeGen/MachineMemOperand.h"
27#include "llvm/CodeGen/MachineRegisterInfo.h"
28#include "llvm/CodeGen/Register.h"
29#include "llvm/CodeGen/RegisterBankInfo.h"
30#include "llvm/CodeGen/TargetInstrInfo.h"
31#include "llvm/CodeGen/TargetLowering.h"
32#include "llvm/CodeGen/TargetOpcodes.h"
33#include "llvm/IR/ConstantRange.h"
34#include "llvm/IR/DataLayout.h"
35#include "llvm/IR/InstrTypes.h"
36#include "llvm/IR/PatternMatch.h"
37#include "llvm/Support/Casting.h"
38#include "llvm/Support/DivisionByConstantInfo.h"
39#include "llvm/Support/ErrorHandling.h"
40#include "llvm/Support/MathExtras.h"
41#include "llvm/Target/TargetMachine.h"
42#include <cmath>
43#include <optional>
44#include <tuple>
45
46#define DEBUG_TYPE "gi-combiner"
47
48using namespace llvm;
49using namespace MIPatternMatch;
50
51// Option to allow testing of the combiner while no targets know about indexed
52// addressing.
53static cl::opt<bool>
54 ForceLegalIndexing("force-legal-indexing", cl::Hidden, cl::init(Val: false),
55 cl::desc("Force all indexed operations to be "
56 "legal for the GlobalISel combiner"));
57
58CombinerHelper::CombinerHelper(GISelChangeObserver &Observer,
59 MachineIRBuilder &B, bool IsPreLegalize,
60 GISelValueTracking *VT,
61 MachineDominatorTree *MDT,
62 const LegalizerInfo *LI)
63 : Builder(B), MRI(Builder.getMF().getRegInfo()), Observer(Observer), VT(VT),
64 MDT(MDT), IsPreLegalize(IsPreLegalize), LI(LI),
65 RBI(Builder.getMF().getSubtarget().getRegBankInfo()),
66 TRI(Builder.getMF().getSubtarget().getRegisterInfo()) {
67 (void)this->VT;
68}
69
70const TargetLowering &CombinerHelper::getTargetLowering() const {
71 return *Builder.getMF().getSubtarget().getTargetLowering();
72}
73
74const MachineFunction &CombinerHelper::getMachineFunction() const {
75 return Builder.getMF();
76}
77
78const DataLayout &CombinerHelper::getDataLayout() const {
79 return getMachineFunction().getDataLayout();
80}
81
82LLVMContext &CombinerHelper::getContext() const { return Builder.getContext(); }
83
84/// \returns The little endian in-memory byte position of byte \p I in a
85/// \p ByteWidth bytes wide type.
86///
87/// E.g. Given a 4-byte type x, x[0] -> byte 0
88static unsigned littleEndianByteAt(const unsigned ByteWidth, const unsigned I) {
89 assert(I < ByteWidth && "I must be in [0, ByteWidth)");
90 return I;
91}
92
/// Determines the LogBase2 value for a non-zero input value using the
/// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
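/// For example, with a 32-bit V = 16: ctlz(16) = 27, so LogBase2(16) = 31 - 27 = 4.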
95static Register buildLogBase2(Register V, MachineIRBuilder &MIB) {
96 auto &MRI = *MIB.getMRI();
97 LLT Ty = MRI.getType(Reg: V);
98 auto Ctlz = MIB.buildCTLZ(Dst: Ty, Src0: V);
99 auto Base = MIB.buildConstant(Res: Ty, Val: Ty.getScalarSizeInBits() - 1);
100 return MIB.buildSub(Dst: Ty, Src0: Base, Src1: Ctlz).getReg(Idx: 0);
101}
102
103/// \returns The big endian in-memory byte position of byte \p I in a
104/// \p ByteWidth bytes wide type.
105///
106/// E.g. Given a 4-byte type x, x[0] -> byte 3
107static unsigned bigEndianByteAt(const unsigned ByteWidth, const unsigned I) {
108 assert(I < ByteWidth && "I must be in [0, ByteWidth)");
109 return ByteWidth - I - 1;
110}
111
112/// Given a map from byte offsets in memory to indices in a load/store,
113/// determine if that map corresponds to a little or big endian byte pattern.
114///
115/// \param MemOffset2Idx maps memory offsets to address offsets.
116/// \param LowestIdx is the lowest index in \p MemOffset2Idx.
117///
118/// \returns true if the map corresponds to a big endian byte pattern, false if
119/// it corresponds to a little endian byte pattern, and std::nullopt otherwise.
120///
121/// E.g. given a 32-bit type x, and x[AddrOffset], the in-memory byte patterns
122/// are as follows:
123///
124/// AddrOffset Little endian Big endian
125/// 0 0 3
126/// 1 1 2
127/// 2 2 1
128/// 3 3 0
129static std::optional<bool>
130isBigEndian(const SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx,
131 int64_t LowestIdx) {
132 // Need at least two byte positions to decide on endianness.
133 unsigned Width = MemOffset2Idx.size();
134 if (Width < 2)
135 return std::nullopt;
136 bool BigEndian = true, LittleEndian = true;
for (unsigned MemOffset = 0; MemOffset < Width; ++MemOffset) {
138 auto MemOffsetAndIdx = MemOffset2Idx.find(Val: MemOffset);
139 if (MemOffsetAndIdx == MemOffset2Idx.end())
140 return std::nullopt;
141 const int64_t Idx = MemOffsetAndIdx->second - LowestIdx;
142 assert(Idx >= 0 && "Expected non-negative byte offset?");
143 LittleEndian &= Idx == littleEndianByteAt(ByteWidth: Width, I: MemOffset);
144 BigEndian &= Idx == bigEndianByteAt(ByteWidth: Width, I: MemOffset);
145 if (!BigEndian && !LittleEndian)
146 return std::nullopt;
147 }
148
149 assert((BigEndian != LittleEndian) &&
150 "Pattern cannot be both big and little endian!");
151 return BigEndian;
152}
153
154bool CombinerHelper::isPreLegalize() const { return IsPreLegalize; }
155
156bool CombinerHelper::isLegal(const LegalityQuery &Query) const {
157 assert(LI && "Must have LegalizerInfo to query isLegal!");
158 return LI->getAction(Query).Action == LegalizeActions::Legal;
159}
160
161bool CombinerHelper::isLegalOrBeforeLegalizer(
162 const LegalityQuery &Query) const {
163 return isPreLegalize() || isLegal(Query);
164}
165
166bool CombinerHelper::isLegalOrHasWidenScalar(const LegalityQuery &Query) const {
167 return isLegal(Query) ||
168 LI->getAction(Query).Action == LegalizeActions::WidenScalar;
169}
170
171bool CombinerHelper::isLegalOrHasFewerElements(
172 const LegalityQuery &Query) const {
173 LegalizeAction Action = LI->getAction(Query).Action;
174 return Action == LegalizeActions::Legal ||
175 Action == LegalizeActions::FewerElements;
176}
177
178bool CombinerHelper::isConstantLegalOrBeforeLegalizer(const LLT Ty) const {
179 if (!Ty.isVector())
180 return isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_CONSTANT, {Ty}});
181 // Vector constants are represented as a G_BUILD_VECTOR of scalar G_CONSTANTs.
182 if (isPreLegalize())
183 return true;
184 LLT EltTy = Ty.getElementType();
185 return isLegal(Query: {TargetOpcode::G_BUILD_VECTOR, {Ty, EltTy}}) &&
186 isLegal(Query: {TargetOpcode::G_CONSTANT, {EltTy}});
187}
188
189void CombinerHelper::replaceRegWith(MachineRegisterInfo &MRI, Register FromReg,
190 Register ToReg) const {
191 Observer.changingAllUsesOfReg(MRI, Reg: FromReg);
192
193 if (MRI.constrainRegAttrs(Reg: ToReg, ConstrainingReg: FromReg))
194 MRI.replaceRegWith(FromReg, ToReg);
195 else
196 Builder.buildCopy(Res: FromReg, Op: ToReg);
197
198 Observer.finishedChangingAllUsesOfReg();
199}
200
201void CombinerHelper::replaceRegOpWith(MachineRegisterInfo &MRI,
202 MachineOperand &FromRegOp,
203 Register ToReg) const {
204 assert(FromRegOp.getParent() && "Expected an operand in an MI");
205 Observer.changingInstr(MI&: *FromRegOp.getParent());
206
207 FromRegOp.setReg(ToReg);
208
209 Observer.changedInstr(MI&: *FromRegOp.getParent());
210}
211
212void CombinerHelper::replaceOpcodeWith(MachineInstr &FromMI,
213 unsigned ToOpcode) const {
214 Observer.changingInstr(MI&: FromMI);
215
216 FromMI.setDesc(Builder.getTII().get(Opcode: ToOpcode));
217
218 Observer.changedInstr(MI&: FromMI);
219}
220
221const RegisterBank *CombinerHelper::getRegBank(Register Reg) const {
222 return RBI->getRegBank(Reg, MRI, TRI: *TRI);
223}
224
225void CombinerHelper::setRegBank(Register Reg,
226 const RegisterBank *RegBank) const {
227 if (RegBank)
228 MRI.setRegBank(Reg, RegBank: *RegBank);
229}
230
231bool CombinerHelper::tryCombineCopy(MachineInstr &MI) const {
232 if (matchCombineCopy(MI)) {
233 applyCombineCopy(MI);
234 return true;
235 }
236 return false;
237}
238bool CombinerHelper::matchCombineCopy(MachineInstr &MI) const {
239 if (MI.getOpcode() != TargetOpcode::COPY)
240 return false;
241 Register DstReg = MI.getOperand(i: 0).getReg();
242 Register SrcReg = MI.getOperand(i: 1).getReg();
243 return canReplaceReg(DstReg, SrcReg, MRI);
244}
245void CombinerHelper::applyCombineCopy(MachineInstr &MI) const {
246 Register DstReg = MI.getOperand(i: 0).getReg();
247 Register SrcReg = MI.getOperand(i: 1).getReg();
248 replaceRegWith(MRI, FromReg: DstReg, ToReg: SrcReg);
249 MI.eraseFromParent();
250}
251
252bool CombinerHelper::matchFreezeOfSingleMaybePoisonOperand(
253 MachineInstr &MI, BuildFnTy &MatchInfo) const {
254 // Ported from InstCombinerImpl::pushFreezeToPreventPoisonFromPropagating.
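//
// Schematically (illustrative MIR, register names made up), when exactly one
// operand of the defining op may be poison:
//   %op:_ = G_OP %maybe_poison, %other
//   %dst:_ = G_FREEZE %op
// becomes
//   %frozen:_ = G_FREEZE %maybe_poison
//   %op:_ = G_OP %frozen, %other
// and users of %dst are rewritten to use %op directly.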
255 Register DstOp = MI.getOperand(i: 0).getReg();
256 Register OrigOp = MI.getOperand(i: 1).getReg();
257
258 if (!MRI.hasOneNonDBGUse(RegNo: OrigOp))
259 return false;
260
261 MachineInstr *OrigDef = MRI.getUniqueVRegDef(Reg: OrigOp);
262 // Even if only a single operand of the PHI is not guaranteed non-poison,
263 // moving freeze() backwards across a PHI can cause optimization issues for
264 // other users of that operand.
265 //
266 // Moving freeze() from one of the output registers of a G_UNMERGE_VALUES to
267 // the source register is unprofitable because it makes the freeze() more
268 // strict than is necessary (it would affect the whole register instead of
269 // just the subreg being frozen).
270 if (OrigDef->isPHI() || isa<GUnmerge>(Val: OrigDef))
271 return false;
272
273 if (canCreateUndefOrPoison(Reg: OrigOp, MRI,
274 /*ConsiderFlagsAndMetadata=*/false))
275 return false;
276
277 std::optional<MachineOperand> MaybePoisonOperand;
278 for (MachineOperand &Operand : OrigDef->uses()) {
279 if (!Operand.isReg())
280 return false;
281
282 if (isGuaranteedNotToBeUndefOrPoison(Reg: Operand.getReg(), MRI))
283 continue;
284
285 if (!MaybePoisonOperand)
286 MaybePoisonOperand = Operand;
287 else {
288 // We have more than one maybe-poison operand. Moving the freeze is
289 // unsafe.
290 return false;
291 }
292 }
293
294 // Eliminate freeze if all operands are guaranteed non-poison.
295 if (!MaybePoisonOperand) {
296 MatchInfo = [=](MachineIRBuilder &B) {
297 Observer.changingInstr(MI&: *OrigDef);
298 cast<GenericMachineInstr>(Val: OrigDef)->dropPoisonGeneratingFlags();
299 Observer.changedInstr(MI&: *OrigDef);
300 B.buildCopy(Res: DstOp, Op: OrigOp);
301 };
302 return true;
303 }
304
305 Register MaybePoisonOperandReg = MaybePoisonOperand->getReg();
306 LLT MaybePoisonOperandRegTy = MRI.getType(Reg: MaybePoisonOperandReg);
307
308 MatchInfo = [=](MachineIRBuilder &B) mutable {
309 Observer.changingInstr(MI&: *OrigDef);
310 cast<GenericMachineInstr>(Val: OrigDef)->dropPoisonGeneratingFlags();
311 Observer.changedInstr(MI&: *OrigDef);
312 B.setInsertPt(MBB&: *OrigDef->getParent(), II: OrigDef->getIterator());
313 auto Freeze = B.buildFreeze(Dst: MaybePoisonOperandRegTy, Src: MaybePoisonOperandReg);
314 replaceRegOpWith(
315 MRI, FromRegOp&: *OrigDef->findRegisterUseOperand(Reg: MaybePoisonOperandReg, TRI),
316 ToReg: Freeze.getReg(Idx: 0));
317 replaceRegWith(MRI, FromReg: DstOp, ToReg: OrigOp);
318 };
319 return true;
320}
321
322bool CombinerHelper::matchCombineConcatVectors(
323 MachineInstr &MI, SmallVector<Register> &Ops) const {
324 assert(MI.getOpcode() == TargetOpcode::G_CONCAT_VECTORS &&
325 "Invalid instruction");
326 bool IsUndef = true;
327 MachineInstr *Undef = nullptr;
328
// Walk over all the operands of the concat vectors and check if they are
// build_vectors themselves or undef.
// Then collect their operands in Ops.
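// For example (illustrative):
//   %a:_(<2 x s32>) = G_BUILD_VECTOR %x, %y
//   %b:_(<2 x s32>) = G_BUILD_VECTOR %z, %w
//   %c:_(<4 x s32>) = G_CONCAT_VECTORS %a, %b
// can be flattened into
//   %c:_(<4 x s32>) = G_BUILD_VECTOR %x, %y, %z, %w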
332 for (const MachineOperand &MO : MI.uses()) {
333 Register Reg = MO.getReg();
334 MachineInstr *Def = MRI.getVRegDef(Reg);
335 assert(Def && "Operand not defined");
336 if (!MRI.hasOneNonDBGUse(RegNo: Reg))
337 return false;
338 switch (Def->getOpcode()) {
339 case TargetOpcode::G_BUILD_VECTOR:
340 IsUndef = false;
341 // Remember the operands of the build_vector to fold
342 // them into the yet-to-build flattened concat vectors.
343 for (const MachineOperand &BuildVecMO : Def->uses())
344 Ops.push_back(Elt: BuildVecMO.getReg());
345 break;
346 case TargetOpcode::G_IMPLICIT_DEF: {
347 LLT OpType = MRI.getType(Reg);
348 // Keep one undef value for all the undef operands.
349 if (!Undef) {
350 Builder.setInsertPt(MBB&: *MI.getParent(), II: MI);
351 Undef = Builder.buildUndef(Res: OpType.getScalarType());
352 }
353 assert(MRI.getType(Undef->getOperand(0).getReg()) ==
354 OpType.getScalarType() &&
355 "All undefs should have the same type");
356 // Break the undef vector in as many scalar elements as needed
357 // for the flattening.
358 for (unsigned EltIdx = 0, EltEnd = OpType.getNumElements();
359 EltIdx != EltEnd; ++EltIdx)
360 Ops.push_back(Elt: Undef->getOperand(i: 0).getReg());
361 break;
362 }
363 default:
364 return false;
365 }
366 }
367
368 // Check if the combine is illegal
369 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
370 if (!isLegalOrBeforeLegalizer(
371 Query: {TargetOpcode::G_BUILD_VECTOR, {DstTy, MRI.getType(Reg: Ops[0])}})) {
372 return false;
373 }
374
375 if (IsUndef)
376 Ops.clear();
377
378 return true;
379}
380void CombinerHelper::applyCombineConcatVectors(
381 MachineInstr &MI, SmallVector<Register> &Ops) const {
// We determined that the concat_vectors can be flattened.
383 // Generate the flattened build_vector.
384 Register DstReg = MI.getOperand(i: 0).getReg();
385 Builder.setInsertPt(MBB&: *MI.getParent(), II: MI);
386 Register NewDstReg = MRI.cloneVirtualRegister(VReg: DstReg);
387
// Note: IsUndef is sort of redundant. We could have determined it by
// checking that all Ops are undef. Alternatively, we could have
// generated a build_vector of undefs and relied on another combine to
// clean that up. For now, given we already gather this information
// in matchCombineConcatVectors, just save compile time and issue the
// right thing.
394 if (Ops.empty())
395 Builder.buildUndef(Res: NewDstReg);
396 else
397 Builder.buildBuildVector(Res: NewDstReg, Ops);
398 replaceRegWith(MRI, FromReg: DstReg, ToReg: NewDstReg);
399 MI.eraseFromParent();
400}
401
402void CombinerHelper::applyCombineShuffleToBuildVector(MachineInstr &MI) const {
403 auto &Shuffle = cast<GShuffleVector>(Val&: MI);
404
405 Register SrcVec1 = Shuffle.getSrc1Reg();
406 Register SrcVec2 = Shuffle.getSrc2Reg();
407 LLT EltTy = MRI.getType(Reg: SrcVec1).getElementType();
408 int Width = MRI.getType(Reg: SrcVec1).getNumElements();
409
410 auto Unmerge1 = Builder.buildUnmerge(Res: EltTy, Op: SrcVec1);
411 auto Unmerge2 = Builder.buildUnmerge(Res: EltTy, Op: SrcVec2);
412
413 SmallVector<Register> Extracts;
414 // Select only applicable elements from unmerged values.
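// E.g. (illustrative) with two <2 x s32> sources, mask <1, -1, 2> picks
// Unmerge1[1], an undef element, and Unmerge2[0].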
415 for (int Val : Shuffle.getMask()) {
416 if (Val == -1)
417 Extracts.push_back(Elt: Builder.buildUndef(Res: EltTy).getReg(Idx: 0));
418 else if (Val < Width)
419 Extracts.push_back(Elt: Unmerge1.getReg(Idx: Val));
420 else
421 Extracts.push_back(Elt: Unmerge2.getReg(Idx: Val - Width));
422 }
423 assert(Extracts.size() > 0 && "Expected at least one element in the shuffle");
424 if (Extracts.size() == 1)
425 Builder.buildCopy(Res: MI.getOperand(i: 0).getReg(), Op: Extracts[0]);
426 else
427 Builder.buildBuildVector(Res: MI.getOperand(i: 0).getReg(), Ops: Extracts);
428 MI.eraseFromParent();
429}
430
431bool CombinerHelper::matchCombineShuffleConcat(
432 MachineInstr &MI, SmallVector<Register> &Ops) const {
433 ArrayRef<int> Mask = MI.getOperand(i: 3).getShuffleMask();
434 auto ConcatMI1 =
435 dyn_cast<GConcatVectors>(Val: MRI.getVRegDef(Reg: MI.getOperand(i: 1).getReg()));
436 auto ConcatMI2 =
437 dyn_cast<GConcatVectors>(Val: MRI.getVRegDef(Reg: MI.getOperand(i: 2).getReg()));
438 if (!ConcatMI1 || !ConcatMI2)
439 return false;
440
441 // Check that the sources of the Concat instructions have the same type
442 if (MRI.getType(Reg: ConcatMI1->getSourceReg(I: 0)) !=
443 MRI.getType(Reg: ConcatMI2->getSourceReg(I: 0)))
444 return false;
445
446 LLT ConcatSrcTy = MRI.getType(Reg: ConcatMI1->getReg(Idx: 1));
447 LLT ShuffleSrcTy1 = MRI.getType(Reg: MI.getOperand(i: 1).getReg());
448 unsigned ConcatSrcNumElt = ConcatSrcTy.getNumElements();
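// E.g. (illustrative) with concat sources of type <2 x s32> and mask
// <2, 3, 4, 5>, each group of ConcatSrcNumElt indices selects one whole
// concat source, so the shuffle can be rebuilt as a G_CONCAT_VECTORS of
// those sources.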
449 for (unsigned i = 0; i < Mask.size(); i += ConcatSrcNumElt) {
450 // Check if the index takes a whole source register from G_CONCAT_VECTORS
451 // Assumes that all Sources of G_CONCAT_VECTORS are the same type
452 if (Mask[i] == -1) {
453 for (unsigned j = 1; j < ConcatSrcNumElt; j++) {
454 if (i + j >= Mask.size())
455 return false;
456 if (Mask[i + j] != -1)
457 return false;
458 }
459 if (!isLegalOrBeforeLegalizer(
460 Query: {TargetOpcode::G_IMPLICIT_DEF, {ConcatSrcTy}}))
461 return false;
462 Ops.push_back(Elt: 0);
463 } else if (Mask[i] % ConcatSrcNumElt == 0) {
464 for (unsigned j = 1; j < ConcatSrcNumElt; j++) {
465 if (i + j >= Mask.size())
466 return false;
467 if (Mask[i + j] != Mask[i] + static_cast<int>(j))
468 return false;
469 }
470 // Retrieve the source register from its respective G_CONCAT_VECTORS
471 // instruction
472 if (Mask[i] < ShuffleSrcTy1.getNumElements()) {
473 Ops.push_back(Elt: ConcatMI1->getSourceReg(I: Mask[i] / ConcatSrcNumElt));
474 } else {
475 Ops.push_back(Elt: ConcatMI2->getSourceReg(I: Mask[i] / ConcatSrcNumElt -
476 ConcatMI1->getNumSources()));
477 }
478 } else {
479 return false;
480 }
481 }
482
483 if (!isLegalOrBeforeLegalizer(
484 Query: {TargetOpcode::G_CONCAT_VECTORS,
485 {MRI.getType(Reg: MI.getOperand(i: 0).getReg()), ConcatSrcTy}}))
486 return false;
487
488 return !Ops.empty();
489}
490
491void CombinerHelper::applyCombineShuffleConcat(
492 MachineInstr &MI, SmallVector<Register> &Ops) const {
493 LLT SrcTy;
494 for (Register &Reg : Ops) {
495 if (Reg != 0)
496 SrcTy = MRI.getType(Reg);
497 }
498 assert(SrcTy.isValid() && "Unexpected full undef vector in concat combine");
499
500 Register UndefReg = 0;
501
502 for (Register &Reg : Ops) {
503 if (Reg == 0) {
504 if (UndefReg == 0)
505 UndefReg = Builder.buildUndef(Res: SrcTy).getReg(Idx: 0);
506 Reg = UndefReg;
507 }
508 }
509
510 if (Ops.size() > 1)
511 Builder.buildConcatVectors(Res: MI.getOperand(i: 0).getReg(), Ops);
512 else
513 Builder.buildCopy(Res: MI.getOperand(i: 0).getReg(), Op: Ops[0]);
514 MI.eraseFromParent();
515}
516
517bool CombinerHelper::tryCombineShuffleVector(MachineInstr &MI) const {
518 SmallVector<Register, 4> Ops;
519 if (matchCombineShuffleVector(MI, Ops)) {
520 applyCombineShuffleVector(MI, Ops);
521 return true;
522 }
523 return false;
524}
525
526bool CombinerHelper::matchCombineShuffleVector(
527 MachineInstr &MI, SmallVectorImpl<Register> &Ops) const {
528 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR &&
529 "Invalid instruction kind");
530 LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
531 Register Src1 = MI.getOperand(i: 1).getReg();
532 LLT SrcType = MRI.getType(Reg: Src1);
533
534 unsigned DstNumElts = DstType.getNumElements();
535 unsigned SrcNumElts = SrcType.getNumElements();
536
// If the resulting vector is smaller than the size of the source
// vectors being concatenated, we won't be able to replace the
// shuffle vector with a concat_vectors.
540 //
541 // Note: We may still be able to produce a concat_vectors fed by
542 // extract_vector_elt and so on. It is less clear that would
543 // be better though, so don't bother for now.
544 //
545 // If the destination is a scalar, the size of the sources doesn't
// matter. We will lower the shuffle to a plain copy. This will
547 // work only if the source and destination have the same size. But
548 // that's covered by the next condition.
549 //
// TODO: If the sizes of the source and destination don't match,
// we could still emit an extract vector element in that case.
552 if (DstNumElts < 2 * SrcNumElts)
553 return false;
554
555 // Check that the shuffle mask can be broken evenly between the
556 // different sources.
557 if (DstNumElts % SrcNumElts != 0)
558 return false;
559
560 // Mask length is a multiple of the source vector length.
561 // Check if the shuffle is some kind of concatenation of the input
562 // vectors.
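// E.g. (illustrative) shuffling two <2 x s32> sources with mask <0, 1, 2, 3>
// is just a concatenation of Src1 and Src2, and mask <2, 3, 0, 1>
// concatenates them in the opposite order.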
563 unsigned NumConcat = DstNumElts / SrcNumElts;
564 SmallVector<int, 8> ConcatSrcs(NumConcat, -1);
565 ArrayRef<int> Mask = MI.getOperand(i: 3).getShuffleMask();
566 for (unsigned i = 0; i != DstNumElts; ++i) {
567 int Idx = Mask[i];
568 // Undef value.
569 if (Idx < 0)
570 continue;
571 // Ensure the indices in each SrcType sized piece are sequential and that
572 // the same source is used for the whole piece.
573 if ((Idx % SrcNumElts != (i % SrcNumElts)) ||
574 (ConcatSrcs[i / SrcNumElts] >= 0 &&
575 ConcatSrcs[i / SrcNumElts] != (int)(Idx / SrcNumElts)))
576 return false;
577 // Remember which source this index came from.
578 ConcatSrcs[i / SrcNumElts] = Idx / SrcNumElts;
579 }
580
581 // The shuffle is concatenating multiple vectors together.
582 // Collect the different operands for that.
583 Register UndefReg;
584 Register Src2 = MI.getOperand(i: 2).getReg();
585 for (auto Src : ConcatSrcs) {
586 if (Src < 0) {
587 if (!UndefReg) {
588 Builder.setInsertPt(MBB&: *MI.getParent(), II: MI);
589 UndefReg = Builder.buildUndef(Res: SrcType).getReg(Idx: 0);
590 }
591 Ops.push_back(Elt: UndefReg);
592 } else if (Src == 0)
593 Ops.push_back(Elt: Src1);
594 else
595 Ops.push_back(Elt: Src2);
596 }
597 return true;
598}
599
600void CombinerHelper::applyCombineShuffleVector(MachineInstr &MI,
601 ArrayRef<Register> Ops) const {
602 Register DstReg = MI.getOperand(i: 0).getReg();
603 Builder.setInsertPt(MBB&: *MI.getParent(), II: MI);
604 Register NewDstReg = MRI.cloneVirtualRegister(VReg: DstReg);
605
606 if (Ops.size() == 1)
607 Builder.buildCopy(Res: NewDstReg, Op: Ops[0]);
608 else
609 Builder.buildMergeLikeInstr(Res: NewDstReg, Ops);
610
611 replaceRegWith(MRI, FromReg: DstReg, ToReg: NewDstReg);
612 MI.eraseFromParent();
613}
614
615namespace {
616
/// Select a preference between two uses. CurrentUse is the current preference
/// while the *ForCandidate arguments describe the candidate under consideration.
619PreferredTuple ChoosePreferredUse(MachineInstr &LoadMI,
620 PreferredTuple &CurrentUse,
621 const LLT TyForCandidate,
622 unsigned OpcodeForCandidate,
623 MachineInstr *MIForCandidate) {
624 if (!CurrentUse.Ty.isValid()) {
625 if (CurrentUse.ExtendOpcode == OpcodeForCandidate ||
626 CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT)
627 return {.Ty: TyForCandidate, .ExtendOpcode: OpcodeForCandidate, .MI: MIForCandidate};
628 return CurrentUse;
629 }
630
631 // We permit the extend to hoist through basic blocks but this is only
632 // sensible if the target has extending loads. If you end up lowering back
633 // into a load and extend during the legalizer then the end result is
634 // hoisting the extend up to the load.
635
636 // Prefer defined extensions to undefined extensions as these are more
637 // likely to reduce the number of instructions.
638 if (OpcodeForCandidate == TargetOpcode::G_ANYEXT &&
639 CurrentUse.ExtendOpcode != TargetOpcode::G_ANYEXT)
640 return CurrentUse;
641 else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT &&
642 OpcodeForCandidate != TargetOpcode::G_ANYEXT)
643 return {.Ty: TyForCandidate, .ExtendOpcode: OpcodeForCandidate, .MI: MIForCandidate};
644
645 // Prefer sign extensions to zero extensions as sign-extensions tend to be
646 // more expensive. Don't do this if the load is already a zero-extend load
647 // though, otherwise we'll rewrite a zero-extend load into a sign-extend
648 // later.
649 if (!isa<GZExtLoad>(Val: LoadMI) && CurrentUse.Ty == TyForCandidate) {
650 if (CurrentUse.ExtendOpcode == TargetOpcode::G_SEXT &&
651 OpcodeForCandidate == TargetOpcode::G_ZEXT)
652 return CurrentUse;
653 else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ZEXT &&
654 OpcodeForCandidate == TargetOpcode::G_SEXT)
655 return {.Ty: TyForCandidate, .ExtendOpcode: OpcodeForCandidate, .MI: MIForCandidate};
656 }
657
658 // This is potentially target specific. We've chosen the largest type
659 // because G_TRUNC is usually free. One potential catch with this is that
660 // some targets have a reduced number of larger registers than smaller
661 // registers and this choice potentially increases the live-range for the
662 // larger value.
663 if (TyForCandidate.getSizeInBits() > CurrentUse.Ty.getSizeInBits()) {
664 return {.Ty: TyForCandidate, .ExtendOpcode: OpcodeForCandidate, .MI: MIForCandidate};
665 }
666 return CurrentUse;
667}
668
669/// Find a suitable place to insert some instructions and insert them. This
670/// function accounts for special cases like inserting before a PHI node.
/// The current strategy for inserting before PHIs is to duplicate the
672/// instructions for each predecessor. However, while that's ok for G_TRUNC
673/// on most targets since it generally requires no code, other targets/cases may
674/// want to try harder to find a dominating block.
675static void InsertInsnsWithoutSideEffectsBeforeUse(
676 MachineIRBuilder &Builder, MachineInstr &DefMI, MachineOperand &UseMO,
677 std::function<void(MachineBasicBlock *, MachineBasicBlock::iterator,
678 MachineOperand &UseMO)>
679 Inserter) {
680 MachineInstr &UseMI = *UseMO.getParent();
681
682 MachineBasicBlock *InsertBB = UseMI.getParent();
683
684 // If the use is a PHI then we want the predecessor block instead.
685 if (UseMI.isPHI()) {
686 MachineOperand *PredBB = std::next(x: &UseMO);
687 InsertBB = PredBB->getMBB();
688 }
689
690 // If the block is the same block as the def then we want to insert just after
691 // the def instead of at the start of the block.
692 if (InsertBB == DefMI.getParent()) {
693 MachineBasicBlock::iterator InsertPt = &DefMI;
694 Inserter(InsertBB, std::next(x: InsertPt), UseMO);
695 return;
696 }
697
698 // Otherwise we want the start of the BB
699 Inserter(InsertBB, InsertBB->getFirstNonPHI(), UseMO);
700}
701} // end anonymous namespace
702
703bool CombinerHelper::tryCombineExtendingLoads(MachineInstr &MI) const {
704 PreferredTuple Preferred;
705 if (matchCombineExtendingLoads(MI, MatchInfo&: Preferred)) {
706 applyCombineExtendingLoads(MI, MatchInfo&: Preferred);
707 return true;
708 }
709 return false;
710}
711
712static unsigned getExtLoadOpcForExtend(unsigned ExtOpc) {
713 unsigned CandidateLoadOpc;
714 switch (ExtOpc) {
715 case TargetOpcode::G_ANYEXT:
716 CandidateLoadOpc = TargetOpcode::G_LOAD;
717 break;
718 case TargetOpcode::G_SEXT:
719 CandidateLoadOpc = TargetOpcode::G_SEXTLOAD;
720 break;
721 case TargetOpcode::G_ZEXT:
722 CandidateLoadOpc = TargetOpcode::G_ZEXTLOAD;
723 break;
724 default:
725 llvm_unreachable("Unexpected extend opc");
726 }
727 return CandidateLoadOpc;
728}
729
730bool CombinerHelper::matchCombineExtendingLoads(
731 MachineInstr &MI, PreferredTuple &Preferred) const {
732 // We match the loads and follow the uses to the extend instead of matching
733 // the extends and following the def to the load. This is because the load
734 // must remain in the same position for correctness (unless we also add code
735 // to find a safe place to sink it) whereas the extend is freely movable.
736 // It also prevents us from duplicating the load for the volatile case or just
737 // for performance.
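//
// E.g. (illustrative):
//   %ld:_(s16) = G_LOAD %ptr
//   %ext:_(s32) = G_SEXT %ld(s16)
// becomes
//   %ext:_(s32) = G_SEXTLOAD %ptr
// with any remaining non-extend users of %ld fed by a G_TRUNC of %ext.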
738 GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(Val: &MI);
739 if (!LoadMI)
740 return false;
741
742 Register LoadReg = LoadMI->getDstReg();
743
744 LLT LoadValueTy = MRI.getType(Reg: LoadReg);
745 if (!LoadValueTy.isScalar())
746 return false;
747
748 // Most architectures are going to legalize <s8 loads into at least a 1 byte
749 // load, and the MMOs can only describe memory accesses in multiples of bytes.
750 // If we try to perform extload combining on those, we can end up with
751 // %a(s8) = extload %ptr (load 1 byte from %ptr)
752 // ... which is an illegal extload instruction.
753 if (LoadValueTy.getSizeInBits() < 8)
754 return false;
755
756 // For non power-of-2 types, they will very likely be legalized into multiple
757 // loads. Don't bother trying to match them into extending loads.
758 if (!llvm::has_single_bit<uint32_t>(Value: LoadValueTy.getSizeInBits()))
759 return false;
760
761 // Find the preferred type aside from the any-extends (unless it's the only
762 // one) and non-extending ops. We'll emit an extending load to that type and
// emit a variant of (extend (trunc X)) for the others according to the
764 // relative type sizes. At the same time, pick an extend to use based on the
765 // extend involved in the chosen type.
766 unsigned PreferredOpcode =
767 isa<GLoad>(Val: &MI)
768 ? TargetOpcode::G_ANYEXT
769 : isa<GSExtLoad>(Val: &MI) ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT;
770 Preferred = {.Ty: LLT(), .ExtendOpcode: PreferredOpcode, .MI: nullptr};
771 for (auto &UseMI : MRI.use_nodbg_instructions(Reg: LoadReg)) {
772 if (UseMI.getOpcode() == TargetOpcode::G_SEXT ||
773 UseMI.getOpcode() == TargetOpcode::G_ZEXT ||
774 (UseMI.getOpcode() == TargetOpcode::G_ANYEXT)) {
775 const auto &MMO = LoadMI->getMMO();
776 // Don't do anything for atomics.
777 if (MMO.isAtomic())
778 continue;
779 // Check for legality.
780 if (!isPreLegalize()) {
781 LegalityQuery::MemDesc MMDesc(MMO);
782 unsigned CandidateLoadOpc = getExtLoadOpcForExtend(ExtOpc: UseMI.getOpcode());
783 LLT UseTy = MRI.getType(Reg: UseMI.getOperand(i: 0).getReg());
784 LLT SrcTy = MRI.getType(Reg: LoadMI->getPointerReg());
785 if (LI->getAction(Query: {CandidateLoadOpc, {UseTy, SrcTy}, {MMDesc}})
786 .Action != LegalizeActions::Legal)
787 continue;
788 }
789 Preferred = ChoosePreferredUse(LoadMI&: MI, CurrentUse&: Preferred,
790 TyForCandidate: MRI.getType(Reg: UseMI.getOperand(i: 0).getReg()),
791 OpcodeForCandidate: UseMI.getOpcode(), MIForCandidate: &UseMI);
792 }
793 }
794
795 // There were no extends
796 if (!Preferred.MI)
797 return false;
// It should be impossible to choose an extend without selecting a different
799 // type since by definition the result of an extend is larger.
800 assert(Preferred.Ty != LoadValueTy && "Extending to same type?");
801
802 LLVM_DEBUG(dbgs() << "Preferred use is: " << *Preferred.MI);
803 return true;
804}
805
806void CombinerHelper::applyCombineExtendingLoads(
807 MachineInstr &MI, PreferredTuple &Preferred) const {
808 // Rewrite the load to the chosen extending load.
809 Register ChosenDstReg = Preferred.MI->getOperand(i: 0).getReg();
810
811 // Inserter to insert a truncate back to the original type at a given point
812 // with some basic CSE to limit truncate duplication to one per BB.
813 DenseMap<MachineBasicBlock *, MachineInstr *> EmittedInsns;
814 auto InsertTruncAt = [&](MachineBasicBlock *InsertIntoBB,
815 MachineBasicBlock::iterator InsertBefore,
816 MachineOperand &UseMO) {
817 MachineInstr *PreviouslyEmitted = EmittedInsns.lookup(Val: InsertIntoBB);
818 if (PreviouslyEmitted) {
819 Observer.changingInstr(MI&: *UseMO.getParent());
820 UseMO.setReg(PreviouslyEmitted->getOperand(i: 0).getReg());
821 Observer.changedInstr(MI&: *UseMO.getParent());
822 return;
823 }
824
825 Builder.setInsertPt(MBB&: *InsertIntoBB, II: InsertBefore);
826 Register NewDstReg = MRI.cloneVirtualRegister(VReg: MI.getOperand(i: 0).getReg());
827 MachineInstr *NewMI = Builder.buildTrunc(Res: NewDstReg, Op: ChosenDstReg);
828 EmittedInsns[InsertIntoBB] = NewMI;
829 replaceRegOpWith(MRI, FromRegOp&: UseMO, ToReg: NewDstReg);
830 };
831
832 Observer.changingInstr(MI);
833 unsigned LoadOpc = getExtLoadOpcForExtend(ExtOpc: Preferred.ExtendOpcode);
834 MI.setDesc(Builder.getTII().get(Opcode: LoadOpc));
835
836 // Rewrite all the uses to fix up the types.
837 auto &LoadValue = MI.getOperand(i: 0);
838 SmallVector<MachineOperand *, 4> Uses(
839 llvm::make_pointer_range(Range: MRI.use_operands(Reg: LoadValue.getReg())));
840
841 for (auto *UseMO : Uses) {
842 MachineInstr *UseMI = UseMO->getParent();
843
844 // If the extend is compatible with the preferred extend then we should fix
845 // up the type and extend so that it uses the preferred use.
846 if (UseMI->getOpcode() == Preferred.ExtendOpcode ||
847 UseMI->getOpcode() == TargetOpcode::G_ANYEXT) {
848 Register UseDstReg = UseMI->getOperand(i: 0).getReg();
849 MachineOperand &UseSrcMO = UseMI->getOperand(i: 1);
850 const LLT UseDstTy = MRI.getType(Reg: UseDstReg);
851 if (UseDstReg != ChosenDstReg) {
852 if (Preferred.Ty == UseDstTy) {
853 // If the use has the same type as the preferred use, then merge
854 // the vregs and erase the extend. For example:
855 // %1:_(s8) = G_LOAD ...
856 // %2:_(s32) = G_SEXT %1(s8)
857 // %3:_(s32) = G_ANYEXT %1(s8)
858 // ... = ... %3(s32)
859 // rewrites to:
860 // %2:_(s32) = G_SEXTLOAD ...
861 // ... = ... %2(s32)
862 replaceRegWith(MRI, FromReg: UseDstReg, ToReg: ChosenDstReg);
863 Observer.erasingInstr(MI&: *UseMO->getParent());
864 UseMO->getParent()->eraseFromParent();
865 } else if (Preferred.Ty.getSizeInBits() < UseDstTy.getSizeInBits()) {
866 // If the preferred size is smaller, then keep the extend but extend
867 // from the result of the extending load. For example:
868 // %1:_(s8) = G_LOAD ...
869 // %2:_(s32) = G_SEXT %1(s8)
870 // %3:_(s64) = G_ANYEXT %1(s8)
871 // ... = ... %3(s64)
// rewrites to:
873 // %2:_(s32) = G_SEXTLOAD ...
874 // %3:_(s64) = G_ANYEXT %2:_(s32)
875 // ... = ... %3(s64)
876 replaceRegOpWith(MRI, FromRegOp&: UseSrcMO, ToReg: ChosenDstReg);
877 } else {
// If the preferred size is larger, then insert a truncate. For
// example:
//   %1:_(s8) = G_LOAD ...
//   %2:_(s64) = G_SEXT %1(s8)
//   %3:_(s32) = G_ZEXT %1(s8)
//   ... = ... %3(s32)
// rewrites to:
//   %2:_(s64) = G_SEXTLOAD ...
//   %4:_(s8) = G_TRUNC %2:_(s64)
//   %3:_(s32) = G_ZEXT %4:_(s8)
//   ... = ... %3(s32)
889 InsertInsnsWithoutSideEffectsBeforeUse(Builder, DefMI&: MI, UseMO&: *UseMO,
890 Inserter: InsertTruncAt);
891 }
892 continue;
893 }
894 // The use is (one of) the uses of the preferred use we chose earlier.
895 // We're going to update the load to def this value later so just erase
896 // the old extend.
897 Observer.erasingInstr(MI&: *UseMO->getParent());
898 UseMO->getParent()->eraseFromParent();
899 continue;
900 }
901
902 // The use isn't an extend. Truncate back to the type we originally loaded.
903 // This is free on many targets.
904 InsertInsnsWithoutSideEffectsBeforeUse(Builder, DefMI&: MI, UseMO&: *UseMO, Inserter: InsertTruncAt);
905 }
906
907 MI.getOperand(i: 0).setReg(ChosenDstReg);
908 Observer.changedInstr(MI);
909}
910
911bool CombinerHelper::matchCombineLoadWithAndMask(MachineInstr &MI,
912 BuildFnTy &MatchInfo) const {
913 assert(MI.getOpcode() == TargetOpcode::G_AND);
914
915 // If we have the following code:
916 // %mask = G_CONSTANT 255
917 // %ld = G_LOAD %ptr, (load s16)
918 // %and = G_AND %ld, %mask
919 //
920 // Try to fold it into
921 // %ld = G_ZEXTLOAD %ptr, (load s8)
922
923 Register Dst = MI.getOperand(i: 0).getReg();
924 if (MRI.getType(Reg: Dst).isVector())
925 return false;
926
927 auto MaybeMask =
928 getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI);
929 if (!MaybeMask)
930 return false;
931
932 APInt MaskVal = MaybeMask->Value;
933
934 if (!MaskVal.isMask())
935 return false;
936
937 Register SrcReg = MI.getOperand(i: 1).getReg();
938 // Don't use getOpcodeDef() here since intermediate instructions may have
939 // multiple users.
940 GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(Val: MRI.getVRegDef(Reg: SrcReg));
941 if (!LoadMI || !MRI.hasOneNonDBGUse(RegNo: LoadMI->getDstReg()))
942 return false;
943
944 Register LoadReg = LoadMI->getDstReg();
945 LLT RegTy = MRI.getType(Reg: LoadReg);
946 Register PtrReg = LoadMI->getPointerReg();
947 unsigned RegSize = RegTy.getSizeInBits();
948 LocationSize LoadSizeBits = LoadMI->getMemSizeInBits();
949 unsigned MaskSizeBits = MaskVal.countr_one();
950
951 // The mask may not be larger than the in-memory type, as it might cover sign
952 // extended bits
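// E.g. (illustrative) for %ld:_(s32) = G_SEXTLOAD %ptr (load s8) used by
// G_AND %ld, 0xFFF: the 12-bit mask covers bits that were sign-extended from
// the 8-bit memory value, so narrowing to a zero-extending load would change
// the result.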
953 if (MaskSizeBits > LoadSizeBits.getValue())
954 return false;
955
956 // If the mask covers the whole destination register, there's nothing to
957 // extend
958 if (MaskSizeBits >= RegSize)
959 return false;
960
961 // Most targets cannot deal with loads of size < 8 and need to re-legalize to
962 // at least byte loads. Avoid creating such loads here
963 if (MaskSizeBits < 8 || !isPowerOf2_32(Value: MaskSizeBits))
964 return false;
965
966 const MachineMemOperand &MMO = LoadMI->getMMO();
967 LegalityQuery::MemDesc MemDesc(MMO);
968
969 // Don't modify the memory access size if this is atomic/volatile, but we can
970 // still adjust the opcode to indicate the high bit behavior.
971 if (LoadMI->isSimple())
972 MemDesc.MemoryTy = LLT::scalar(SizeInBits: MaskSizeBits);
973 else if (LoadSizeBits.getValue() > MaskSizeBits ||
974 LoadSizeBits.getValue() == RegSize)
975 return false;
976
977 // TODO: Could check if it's legal with the reduced or original memory size.
978 if (!isLegalOrBeforeLegalizer(
979 Query: {TargetOpcode::G_ZEXTLOAD, {RegTy, MRI.getType(Reg: PtrReg)}, {MemDesc}}))
980 return false;
981
982 MatchInfo = [=](MachineIRBuilder &B) {
983 B.setInstrAndDebugLoc(*LoadMI);
984 auto &MF = B.getMF();
985 auto PtrInfo = MMO.getPointerInfo();
986 auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Ty: MemDesc.MemoryTy);
987 B.buildLoadInstr(Opcode: TargetOpcode::G_ZEXTLOAD, Res: Dst, Addr: PtrReg, MMO&: *NewMMO);
988 LoadMI->eraseFromParent();
989 };
990 return true;
991}
992
993bool CombinerHelper::isPredecessor(const MachineInstr &DefMI,
994 const MachineInstr &UseMI) const {
995 assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() &&
996 "shouldn't consider debug uses");
997 assert(DefMI.getParent() == UseMI.getParent());
998 if (&DefMI == &UseMI)
999 return true;
1000 const MachineBasicBlock &MBB = *DefMI.getParent();
1001 auto DefOrUse = find_if(Range: MBB, P: [&DefMI, &UseMI](const MachineInstr &MI) {
1002 return &MI == &DefMI || &MI == &UseMI;
1003 });
1004 if (DefOrUse == MBB.end())
1005 llvm_unreachable("Block must contain both DefMI and UseMI!");
1006 return &*DefOrUse == &DefMI;
1007}
1008
1009bool CombinerHelper::dominates(const MachineInstr &DefMI,
1010 const MachineInstr &UseMI) const {
1011 assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() &&
1012 "shouldn't consider debug uses");
1013 if (MDT)
1014 return MDT->dominates(A: &DefMI, B: &UseMI);
1015 else if (DefMI.getParent() != UseMI.getParent())
1016 return false;
1017
1018 return isPredecessor(DefMI, UseMI);
1019}
1020
1021bool CombinerHelper::matchSextTruncSextLoad(MachineInstr &MI) const {
1022 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
1023 Register SrcReg = MI.getOperand(i: 1).getReg();
1024 Register LoadUser = SrcReg;
1025
1026 if (MRI.getType(Reg: SrcReg).isVector())
1027 return false;
1028
1029 Register TruncSrc;
1030 if (mi_match(R: SrcReg, MRI, P: m_GTrunc(Src: m_Reg(R&: TruncSrc))))
1031 LoadUser = TruncSrc;
1032
1033 uint64_t SizeInBits = MI.getOperand(i: 2).getImm();
1034 // If the source is a G_SEXTLOAD from the same bit width, then we don't
1035 // need any extend at all, just a truncate.
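// E.g. (illustrative):
//   %ld:_(s32) = G_SEXTLOAD %ptr (load 1)
//   %ext:_(s32) = G_SEXT_INREG %ld, 8
// The G_SEXT_INREG is redundant since the load already sign-extended from 8
// bits, so %ext can simply be a copy of %ld.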
1036 if (auto *LoadMI = getOpcodeDef<GSExtLoad>(Reg: LoadUser, MRI)) {
1037 // If truncating more than the original extended value, abort.
1038 auto LoadSizeBits = LoadMI->getMemSizeInBits();
1039 if (TruncSrc &&
1040 MRI.getType(Reg: TruncSrc).getSizeInBits() < LoadSizeBits.getValue())
1041 return false;
1042 if (LoadSizeBits == SizeInBits)
1043 return true;
1044 }
1045 return false;
1046}
1047
1048void CombinerHelper::applySextTruncSextLoad(MachineInstr &MI) const {
1049 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
1050 Builder.buildCopy(Res: MI.getOperand(i: 0).getReg(), Op: MI.getOperand(i: 1).getReg());
1051 MI.eraseFromParent();
1052}
1053
1054bool CombinerHelper::matchSextInRegOfLoad(
1055 MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) const {
1056 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
1057
1058 Register DstReg = MI.getOperand(i: 0).getReg();
1059 LLT RegTy = MRI.getType(Reg: DstReg);
1060
1061 // Only supports scalars for now.
1062 if (RegTy.isVector())
1063 return false;
1064
1065 Register SrcReg = MI.getOperand(i: 1).getReg();
1066 auto *LoadDef = getOpcodeDef<GLoad>(Reg: SrcReg, MRI);
1067 if (!LoadDef || !MRI.hasOneNonDBGUse(RegNo: SrcReg))
1068 return false;
1069
1070 uint64_t MemBits = LoadDef->getMemSizeInBits().getValue();
1071
1072 // If the sign extend extends from a narrower width than the load's width,
1073 // then we can narrow the load width when we combine to a G_SEXTLOAD.
1074 // Avoid widening the load at all.
1075 unsigned NewSizeBits = std::min(a: (uint64_t)MI.getOperand(i: 2).getImm(), b: MemBits);
1076
1077 // Don't generate G_SEXTLOADs with a < 1 byte width.
1078 if (NewSizeBits < 8)
1079 return false;
1080 // Don't bother creating a non-power-2 sextload, it will likely be broken up
1081 // anyway for most targets.
1082 if (!isPowerOf2_32(Value: NewSizeBits))
1083 return false;
1084
1085 const MachineMemOperand &MMO = LoadDef->getMMO();
1086 LegalityQuery::MemDesc MMDesc(MMO);
1087
1088 // Don't modify the memory access size if this is atomic/volatile, but we can
1089 // still adjust the opcode to indicate the high bit behavior.
1090 if (LoadDef->isSimple())
1091 MMDesc.MemoryTy = LLT::scalar(SizeInBits: NewSizeBits);
1092 else if (MemBits > NewSizeBits || MemBits == RegTy.getSizeInBits())
1093 return false;
1094
1095 // TODO: Could check if it's legal with the reduced or original memory size.
1096 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SEXTLOAD,
1097 {MRI.getType(Reg: LoadDef->getDstReg()),
1098 MRI.getType(Reg: LoadDef->getPointerReg())},
1099 {MMDesc}}))
1100 return false;
1101
1102 MatchInfo = std::make_tuple(args: LoadDef->getDstReg(), args&: NewSizeBits);
1103 return true;
1104}
1105
1106void CombinerHelper::applySextInRegOfLoad(
1107 MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) const {
1108 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
1109 Register LoadReg;
1110 unsigned ScalarSizeBits;
1111 std::tie(args&: LoadReg, args&: ScalarSizeBits) = MatchInfo;
1112 GLoad *LoadDef = cast<GLoad>(Val: MRI.getVRegDef(Reg: LoadReg));
1113
1114 // If we have the following:
1115 // %ld = G_LOAD %ptr, (load 2)
1116 // %ext = G_SEXT_INREG %ld, 8
1117 // ==>
1118 // %ld = G_SEXTLOAD %ptr (load 1)
1119
1120 auto &MMO = LoadDef->getMMO();
1121 Builder.setInstrAndDebugLoc(*LoadDef);
1122 auto &MF = Builder.getMF();
1123 auto PtrInfo = MMO.getPointerInfo();
1124 auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Size: ScalarSizeBits / 8);
1125 Builder.buildLoadInstr(Opcode: TargetOpcode::G_SEXTLOAD, Res: MI.getOperand(i: 0).getReg(),
1126 Addr: LoadDef->getPointerReg(), MMO&: *NewMMO);
1127 MI.eraseFromParent();
1128
1129 // Not all loads can be deleted, so make sure the old one is removed.
1130 LoadDef->eraseFromParent();
1131}
1132
/// Return true if 'MI' is a load or a store that can fold its address
/// operand into the load / store addressing mode.
1135static bool canFoldInAddressingMode(GLoadStore *MI, const TargetLowering &TLI,
1136 MachineRegisterInfo &MRI) {
1137 TargetLowering::AddrMode AM;
1138 auto *MF = MI->getMF();
1139 auto *Addr = getOpcodeDef<GPtrAdd>(Reg: MI->getPointerReg(), MRI);
1140 if (!Addr)
1141 return false;
1142
1143 AM.HasBaseReg = true;
1144 if (auto CstOff = getIConstantVRegVal(VReg: Addr->getOffsetReg(), MRI))
1145 AM.BaseOffs = CstOff->getSExtValue(); // [reg +/- imm]
1146 else
1147 AM.Scale = 1; // [reg +/- reg]
1148
1149 return TLI.isLegalAddressingMode(
1150 DL: MF->getDataLayout(), AM,
1151 Ty: getTypeForLLT(Ty: MI->getMMO().getMemoryType(),
1152 C&: MF->getFunction().getContext()),
1153 AddrSpace: MI->getMMO().getAddrSpace());
1154}
1155
1156static unsigned getIndexedOpc(unsigned LdStOpc) {
1157 switch (LdStOpc) {
1158 case TargetOpcode::G_LOAD:
1159 return TargetOpcode::G_INDEXED_LOAD;
1160 case TargetOpcode::G_STORE:
1161 return TargetOpcode::G_INDEXED_STORE;
1162 case TargetOpcode::G_ZEXTLOAD:
1163 return TargetOpcode::G_INDEXED_ZEXTLOAD;
1164 case TargetOpcode::G_SEXTLOAD:
1165 return TargetOpcode::G_INDEXED_SEXTLOAD;
1166 default:
1167 llvm_unreachable("Unexpected opcode");
1168 }
1169}
1170
1171bool CombinerHelper::isIndexedLoadStoreLegal(GLoadStore &LdSt) const {
1172 // Check for legality.
1173 LLT PtrTy = MRI.getType(Reg: LdSt.getPointerReg());
1174 LLT Ty = MRI.getType(Reg: LdSt.getReg(Idx: 0));
1175 LLT MemTy = LdSt.getMMO().getMemoryType();
1176 SmallVector<LegalityQuery::MemDesc, 2> MemDescrs(
1177 {{MemTy, MemTy.getSizeInBits().getKnownMinValue(),
1178 AtomicOrdering::NotAtomic, AtomicOrdering::NotAtomic}});
1179 unsigned IndexedOpc = getIndexedOpc(LdStOpc: LdSt.getOpcode());
1180 SmallVector<LLT> OpTys;
1181 if (IndexedOpc == TargetOpcode::G_INDEXED_STORE)
1182 OpTys = {PtrTy, Ty, Ty};
1183 else
1184 OpTys = {Ty, PtrTy}; // For G_INDEXED_LOAD, G_INDEXED_[SZ]EXTLOAD
1185
1186 LegalityQuery Q(IndexedOpc, OpTys, MemDescrs);
1187 return isLegal(Query: Q);
1188}
1189
1190static cl::opt<unsigned> PostIndexUseThreshold(
1191 "post-index-use-threshold", cl::Hidden, cl::init(Val: 32),
1192 cl::desc("Number of uses of a base pointer to check before it is no longer "
1193 "considered for post-indexing."));
1194
1195bool CombinerHelper::findPostIndexCandidate(GLoadStore &LdSt, Register &Addr,
1196 Register &Base, Register &Offset,
1197 bool &RematOffset) const {
1198 // We're looking for the following pattern, for either load or store:
1199 // %baseptr:_(p0) = ...
1200 // G_STORE %val(s64), %baseptr(p0)
1201 // %offset:_(s64) = G_CONSTANT i64 -256
1202 // %new_addr:_(p0) = G_PTR_ADD %baseptr, %offset(s64)
1203 const auto &TLI = getTargetLowering();
1204
1205 Register Ptr = LdSt.getPointerReg();
1206 // If the store is the only use, don't bother.
1207 if (MRI.hasOneNonDBGUse(RegNo: Ptr))
1208 return false;
1209
1210 if (!isIndexedLoadStoreLegal(LdSt))
1211 return false;
1212
1213 if (getOpcodeDef(Opcode: TargetOpcode::G_FRAME_INDEX, Reg: Ptr, MRI))
1214 return false;
1215
1216 MachineInstr *StoredValDef = getDefIgnoringCopies(Reg: LdSt.getReg(Idx: 0), MRI);
1217 auto *PtrDef = MRI.getVRegDef(Reg: Ptr);
1218
1219 unsigned NumUsesChecked = 0;
1220 for (auto &Use : MRI.use_nodbg_instructions(Reg: Ptr)) {
1221 if (++NumUsesChecked > PostIndexUseThreshold)
1222 return false; // Try to avoid exploding compile time.
1223
1224 auto *PtrAdd = dyn_cast<GPtrAdd>(Val: &Use);
1225 // The use itself might be dead. This can happen during combines if DCE
1226 // hasn't had a chance to run yet. Don't allow it to form an indexed op.
1227 if (!PtrAdd || MRI.use_nodbg_empty(RegNo: PtrAdd->getReg(Idx: 0)))
1228 continue;
1229
// Check that the user of this isn't the store, otherwise we'd be generating
// an indexed store defining its own use.
1232 if (StoredValDef == &Use)
1233 continue;
1234
1235 Offset = PtrAdd->getOffsetReg();
1236 if (!ForceLegalIndexing &&
1237 !TLI.isIndexingLegal(MI&: LdSt, Base: PtrAdd->getBaseReg(), Offset,
1238 /*IsPre*/ false, MRI))
1239 continue;
1240
1241 // Make sure the offset calculation is before the potentially indexed op.
1242 MachineInstr *OffsetDef = MRI.getVRegDef(Reg: Offset);
1243 RematOffset = false;
1244 if (!dominates(DefMI: *OffsetDef, UseMI: LdSt)) {
1245 // If the offset however is just a G_CONSTANT, we can always just
1246 // rematerialize it where we need it.
1247 if (OffsetDef->getOpcode() != TargetOpcode::G_CONSTANT)
1248 continue;
1249 RematOffset = true;
1250 }
1251
1252 for (auto &BasePtrUse : MRI.use_nodbg_instructions(Reg: PtrAdd->getBaseReg())) {
1253 if (&BasePtrUse == PtrDef)
1254 continue;
1255
1256 // If the user is a later load/store that can be post-indexed, then don't
1257 // combine this one.
1258 auto *BasePtrLdSt = dyn_cast<GLoadStore>(Val: &BasePtrUse);
1259 if (BasePtrLdSt && BasePtrLdSt != &LdSt &&
1260 dominates(DefMI: LdSt, UseMI: *BasePtrLdSt) &&
1261 isIndexedLoadStoreLegal(LdSt&: *BasePtrLdSt))
1262 return false;
1263
1264 // Now we're looking for the key G_PTR_ADD instruction, which contains
1265 // the offset add that we want to fold.
1266 if (auto *BasePtrUseDef = dyn_cast<GPtrAdd>(Val: &BasePtrUse)) {
1267 Register PtrAddDefReg = BasePtrUseDef->getReg(Idx: 0);
1268 for (auto &BaseUseUse : MRI.use_nodbg_instructions(Reg: PtrAddDefReg)) {
1269 // If the use is in a different block, then we may produce worse code
1270 // due to the extra register pressure.
1271 if (BaseUseUse.getParent() != LdSt.getParent())
1272 return false;
1273
1274 if (auto *UseUseLdSt = dyn_cast<GLoadStore>(Val: &BaseUseUse))
1275 if (canFoldInAddressingMode(MI: UseUseLdSt, TLI, MRI))
1276 return false;
1277 }
1278 if (!dominates(DefMI: LdSt, UseMI: BasePtrUse))
return false; // All uses must be dominated by the load/store.
1280 }
1281 }
1282
1283 Addr = PtrAdd->getReg(Idx: 0);
1284 Base = PtrAdd->getBaseReg();
1285 return true;
1286 }
1287
1288 return false;
1289}
1290
1291bool CombinerHelper::findPreIndexCandidate(GLoadStore &LdSt, Register &Addr,
1292 Register &Base,
1293 Register &Offset) const {
1294 auto &MF = *LdSt.getParent()->getParent();
1295 const auto &TLI = *MF.getSubtarget().getTargetLowering();
1296
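// We're looking for a pattern like this (illustrative), where the pointer
// arithmetic can be folded into a pre-indexed access:
//   %addr:_(p0) = G_PTR_ADD %base, %offset
//   G_STORE %val(s64), %addr(p0)
// so the resulting indexed store also defines %addr as a writeback result.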
1297 Addr = LdSt.getPointerReg();
1298 if (!mi_match(R: Addr, MRI, P: m_GPtrAdd(L: m_Reg(R&: Base), R: m_Reg(R&: Offset))) ||
1299 MRI.hasOneNonDBGUse(RegNo: Addr))
1300 return false;
1301
1302 if (!ForceLegalIndexing &&
1303 !TLI.isIndexingLegal(MI&: LdSt, Base, Offset, /*IsPre*/ true, MRI))
1304 return false;
1305
1306 if (!isIndexedLoadStoreLegal(LdSt))
1307 return false;
1308
1309 MachineInstr *BaseDef = getDefIgnoringCopies(Reg: Base, MRI);
1310 if (BaseDef->getOpcode() == TargetOpcode::G_FRAME_INDEX)
1311 return false;
1312
1313 if (auto *St = dyn_cast<GStore>(Val: &LdSt)) {
1314 // Would require a copy.
1315 if (Base == St->getValueReg())
1316 return false;
1317
1318 // We're expecting one use of Addr in MI, but it could also be the
1319 // value stored, which isn't actually dominated by the instruction.
1320 if (St->getValueReg() == Addr)
1321 return false;
1322 }
1323
1324 // Avoid increasing cross-block register pressure.
1325 for (auto &AddrUse : MRI.use_nodbg_instructions(Reg: Addr))
1326 if (AddrUse.getParent() != LdSt.getParent())
1327 return false;
1328
1329 // FIXME: check whether all uses of the base pointer are constant PtrAdds.
1330 // That might allow us to end base's liveness here by adjusting the constant.
1331 bool RealUse = false;
1332 for (auto &AddrUse : MRI.use_nodbg_instructions(Reg: Addr)) {
1333 if (!dominates(DefMI: LdSt, UseMI: AddrUse))
return false; // All uses must be dominated by the load/store.
1335
1336 // If Ptr may be folded in addressing mode of other use, then it's
1337 // not profitable to do this transformation.
1338 if (auto *UseLdSt = dyn_cast<GLoadStore>(Val: &AddrUse)) {
1339 if (!canFoldInAddressingMode(MI: UseLdSt, TLI, MRI))
1340 RealUse = true;
1341 } else {
1342 RealUse = true;
1343 }
1344 }
1345 return RealUse;
1346}
1347
1348bool CombinerHelper::matchCombineExtractedVectorLoad(
1349 MachineInstr &MI, BuildFnTy &MatchInfo) const {
1350 assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT);
1351
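// The idea (illustrative): instead of loading the whole vector and extracting
// one lane,
//   %vec:_(<4 x s32>) = G_LOAD %ptr
//   %elt:_(s32) = G_EXTRACT_VECTOR_ELT %vec, %idx
// load only the selected element through a pointer derived from %ptr and
// %idx:
//   %elt:_(s32) = G_LOAD %eltptr   ; %eltptr = %ptr + %idx * sizeof(elt)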
1352 // Check if there is a load that defines the vector being extracted from.
1353 auto *LoadMI = getOpcodeDef<GLoad>(Reg: MI.getOperand(i: 1).getReg(), MRI);
1354 if (!LoadMI)
1355 return false;
1356
1357 Register Vector = MI.getOperand(i: 1).getReg();
1358 LLT VecEltTy = MRI.getType(Reg: Vector).getElementType();
1359
1360 assert(MRI.getType(MI.getOperand(0).getReg()) == VecEltTy);
1361
1362 // Checking whether we should reduce the load width.
1363 if (!MRI.hasOneNonDBGUse(RegNo: Vector))
1364 return false;
1365
1366 // Check if the defining load is simple.
1367 if (!LoadMI->isSimple())
1368 return false;
1369
1370 // If the vector element type is not a multiple of a byte then we are unable
1371 // to correctly compute an address to load only the extracted element as a
1372 // scalar.
1373 if (!VecEltTy.isByteSized())
1374 return false;
1375
1376 // Check for load fold barriers between the extraction and the load.
1377 if (MI.getParent() != LoadMI->getParent())
1378 return false;
1379 const unsigned MaxIter = 20;
1380 unsigned Iter = 0;
1381 for (auto II = LoadMI->getIterator(), IE = MI.getIterator(); II != IE; ++II) {
1382 if (II->isLoadFoldBarrier())
1383 return false;
1384 if (Iter++ == MaxIter)
1385 return false;
1386 }
1387
1388 // Check if the new load that we are going to create is legal
1389 // if we are in the post-legalization phase.
1390 MachineMemOperand MMO = LoadMI->getMMO();
1391 Align Alignment = MMO.getAlign();
1392 MachinePointerInfo PtrInfo;
1393 uint64_t Offset;
1394
// Finding the appropriate PtrInfo if the offset is a known constant.
// This is required to create the memory operand for the narrowed load.
// This machine memory operand object helps us check legality
// before we proceed to combine the instruction.
1399 if (auto CVal = getIConstantVRegVal(VReg: Vector, MRI)) {
1400 int Elt = CVal->getZExtValue();
1401 // FIXME: should be (ABI size)*Elt.
1402 Offset = VecEltTy.getSizeInBits() * Elt / 8;
1403 PtrInfo = MMO.getPointerInfo().getWithOffset(O: Offset);
1404 } else {
1405 // Discard the pointer info except the address space because the memory
1406 // operand can't represent this new access since the offset is variable.
1407 Offset = VecEltTy.getSizeInBits() / 8;
1408 PtrInfo = MachinePointerInfo(MMO.getPointerInfo().getAddrSpace());
1409 }
1410
1411 Alignment = commonAlignment(A: Alignment, Offset);
1412
1413 Register VecPtr = LoadMI->getPointerReg();
1414 LLT PtrTy = MRI.getType(Reg: VecPtr);
1415
1416 MachineFunction &MF = *MI.getMF();
1417 auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Ty: VecEltTy);
1418
1419 LegalityQuery::MemDesc MMDesc(*NewMMO);
1420
1421 if (!isLegalOrBeforeLegalizer(
1422 Query: {TargetOpcode::G_LOAD, {VecEltTy, PtrTy}, {MMDesc}}))
1423 return false;
1424
1425 // Load must be allowed and fast on the target.
1426 LLVMContext &C = MF.getFunction().getContext();
1427 auto &DL = MF.getDataLayout();
1428 unsigned Fast = 0;
1429 if (!getTargetLowering().allowsMemoryAccess(Context&: C, DL, Ty: VecEltTy, MMO: *NewMMO,
1430 Fast: &Fast) ||
1431 !Fast)
1432 return false;
1433
1434 Register Result = MI.getOperand(i: 0).getReg();
1435 Register Index = MI.getOperand(i: 2).getReg();
1436
1437 MatchInfo = [=](MachineIRBuilder &B) {
1438 GISelObserverWrapper DummyObserver;
1439 LegalizerHelper Helper(B.getMF(), DummyObserver, B);
// Get a pointer to the vector element.
1441 Register finalPtr = Helper.getVectorElementPointer(
1442 VecPtr: LoadMI->getPointerReg(), VecTy: MRI.getType(Reg: LoadMI->getOperand(i: 0).getReg()),
1443 Index);
1444 // New G_LOAD instruction.
1445 B.buildLoad(Res: Result, Addr: finalPtr, PtrInfo, Alignment);
1446 // Remove original GLOAD instruction.
1447 LoadMI->eraseFromParent();
1448 };
1449
1450 return true;
1451}
1452
1453bool CombinerHelper::matchCombineIndexedLoadStore(
1454 MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) const {
1455 auto &LdSt = cast<GLoadStore>(Val&: MI);
1456
1457 if (LdSt.isAtomic())
1458 return false;
1459
1460 MatchInfo.IsPre = findPreIndexCandidate(LdSt, Addr&: MatchInfo.Addr, Base&: MatchInfo.Base,
1461 Offset&: MatchInfo.Offset);
1462 if (!MatchInfo.IsPre &&
1463 !findPostIndexCandidate(LdSt, Addr&: MatchInfo.Addr, Base&: MatchInfo.Base,
1464 Offset&: MatchInfo.Offset, RematOffset&: MatchInfo.RematOffset))
1465 return false;
1466
1467 return true;
1468}
1469
1470void CombinerHelper::applyCombineIndexedLoadStore(
1471 MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) const {
1472 MachineInstr &AddrDef = *MRI.getUniqueVRegDef(Reg: MatchInfo.Addr);
1473 unsigned Opcode = MI.getOpcode();
1474 bool IsStore = Opcode == TargetOpcode::G_STORE;
1475 unsigned NewOpcode = getIndexedOpc(LdStOpc: Opcode);
1476
1477 // If the offset constant didn't happen to dominate the load/store, we can
1478 // just clone it as needed.
1479 if (MatchInfo.RematOffset) {
1480 auto *OldCst = MRI.getVRegDef(Reg: MatchInfo.Offset);
1481 auto NewCst = Builder.buildConstant(Res: MRI.getType(Reg: MatchInfo.Offset),
1482 Val: *OldCst->getOperand(i: 1).getCImm());
1483 MatchInfo.Offset = NewCst.getReg(Idx: 0);
1484 }
1485
1486 auto MIB = Builder.buildInstr(Opcode: NewOpcode);
1487 if (IsStore) {
1488 MIB.addDef(RegNo: MatchInfo.Addr);
1489 MIB.addUse(RegNo: MI.getOperand(i: 0).getReg());
1490 } else {
1491 MIB.addDef(RegNo: MI.getOperand(i: 0).getReg());
1492 MIB.addDef(RegNo: MatchInfo.Addr);
1493 }
1494
1495 MIB.addUse(RegNo: MatchInfo.Base);
1496 MIB.addUse(RegNo: MatchInfo.Offset);
1497 MIB.addImm(Val: MatchInfo.IsPre);
1498 MIB->cloneMemRefs(MF&: *MI.getMF(), MI);
1499 MI.eraseFromParent();
1500 AddrDef.eraseFromParent();
1501
1502 LLVM_DEBUG(dbgs() << " Combined to indexed operation");
1503}
1504
1505bool CombinerHelper::matchCombineDivRem(MachineInstr &MI,
1506 MachineInstr *&OtherMI) const {
1507 unsigned Opcode = MI.getOpcode();
1508 bool IsDiv, IsSigned;
1509
1510 switch (Opcode) {
1511 default:
1512 llvm_unreachable("Unexpected opcode!");
1513 case TargetOpcode::G_SDIV:
1514 case TargetOpcode::G_UDIV: {
1515 IsDiv = true;
1516 IsSigned = Opcode == TargetOpcode::G_SDIV;
1517 break;
1518 }
1519 case TargetOpcode::G_SREM:
1520 case TargetOpcode::G_UREM: {
1521 IsDiv = false;
1522 IsSigned = Opcode == TargetOpcode::G_SREM;
1523 break;
1524 }
1525 }
1526
1527 Register Src1 = MI.getOperand(i: 1).getReg();
1528 unsigned DivOpcode, RemOpcode, DivremOpcode;
1529 if (IsSigned) {
1530 DivOpcode = TargetOpcode::G_SDIV;
1531 RemOpcode = TargetOpcode::G_SREM;
1532 DivremOpcode = TargetOpcode::G_SDIVREM;
1533 } else {
1534 DivOpcode = TargetOpcode::G_UDIV;
1535 RemOpcode = TargetOpcode::G_UREM;
1536 DivremOpcode = TargetOpcode::G_UDIVREM;
1537 }
1538
1539 if (!isLegalOrBeforeLegalizer(Query: {DivremOpcode, {MRI.getType(Reg: Src1)}}))
1540 return false;
1541
1542 // Combine:
1543 // %div:_ = G_[SU]DIV %src1:_, %src2:_
1544 // %rem:_ = G_[SU]REM %src1:_, %src2:_
1545 // into:
1546 // %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_
1547
1548 // Combine:
1549 // %rem:_ = G_[SU]REM %src1:_, %src2:_
1550 // %div:_ = G_[SU]DIV %src1:_, %src2:_
1551 // into:
1552 // %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_
1553
1554 for (auto &UseMI : MRI.use_nodbg_instructions(Reg: Src1)) {
1555 if (MI.getParent() == UseMI.getParent() &&
1556 ((IsDiv && UseMI.getOpcode() == RemOpcode) ||
1557 (!IsDiv && UseMI.getOpcode() == DivOpcode)) &&
1558 matchEqualDefs(MOP1: MI.getOperand(i: 2), MOP2: UseMI.getOperand(i: 2)) &&
1559 matchEqualDefs(MOP1: MI.getOperand(i: 1), MOP2: UseMI.getOperand(i: 1))) {
1560 OtherMI = &UseMI;
1561 return true;
1562 }
1563 }
1564
1565 return false;
1566}
1567
1568void CombinerHelper::applyCombineDivRem(MachineInstr &MI,
1569 MachineInstr *&OtherMI) const {
1570 unsigned Opcode = MI.getOpcode();
1571 assert(OtherMI && "OtherMI shouldn't be empty.");
1572
1573 Register DestDivReg, DestRemReg;
1574 if (Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_UDIV) {
1575 DestDivReg = MI.getOperand(i: 0).getReg();
1576 DestRemReg = OtherMI->getOperand(i: 0).getReg();
1577 } else {
1578 DestDivReg = OtherMI->getOperand(i: 0).getReg();
1579 DestRemReg = MI.getOperand(i: 0).getReg();
1580 }
1581
1582 bool IsSigned =
1583 Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_SREM;
1584
1585 // Check which instruction is first in the block so we don't break def-use
1586 // deps by "moving" the instruction incorrectly. Also keep track of which
1587 // instruction is first so we pick its operands, avoiding use-before-def
1588 // bugs.
1589 MachineInstr *FirstInst = dominates(DefMI: MI, UseMI: *OtherMI) ? &MI : OtherMI;
1590 Builder.setInstrAndDebugLoc(*FirstInst);
1591
1592 Builder.buildInstr(Opc: IsSigned ? TargetOpcode::G_SDIVREM
1593 : TargetOpcode::G_UDIVREM,
1594 DstOps: {DestDivReg, DestRemReg},
1595 SrcOps: { FirstInst->getOperand(i: 1), FirstInst->getOperand(i: 2) });
1596 MI.eraseFromParent();
1597 OtherMI->eraseFromParent();
1598}
1599
1600bool CombinerHelper::matchOptBrCondByInvertingCond(
1601 MachineInstr &MI, MachineInstr *&BrCond) const {
1602 assert(MI.getOpcode() == TargetOpcode::G_BR);
1603
1604 // Try to match the following:
1605 // bb1:
1606 // G_BRCOND %c1, %bb2
1607 // G_BR %bb3
1608 // bb2:
1609 // ...
1610 // bb3:
1611
1612 // The above pattern does not have a fallthrough to the successor bb2, always
1613 // resulting in a branch no matter which path is taken. Here we try to find
1614 // and replace that pattern with a conditional branch to bb3 and otherwise a
1615 // fallthrough to bb2. This is generally better for branch predictors.
1616
1617 MachineBasicBlock *MBB = MI.getParent();
1618 MachineBasicBlock::iterator BrIt(MI);
1619 if (BrIt == MBB->begin())
1620 return false;
1621 assert(std::next(BrIt) == MBB->end() && "expected G_BR to be a terminator");
1622
1623 BrCond = &*std::prev(x: BrIt);
1624 if (BrCond->getOpcode() != TargetOpcode::G_BRCOND)
1625 return false;
1626
1627 // Check that the next block is the conditional branch target. Also make sure
1628 // that it isn't the same as the G_BR's target (otherwise, this will loop).
1629 MachineBasicBlock *BrCondTarget = BrCond->getOperand(i: 1).getMBB();
1630 return BrCondTarget != MI.getOperand(i: 0).getMBB() &&
1631 MBB->isLayoutSuccessor(MBB: BrCondTarget);
1632}
1633
1634void CombinerHelper::applyOptBrCondByInvertingCond(
1635 MachineInstr &MI, MachineInstr *&BrCond) const {
1636 MachineBasicBlock *BrTarget = MI.getOperand(i: 0).getMBB();
1637 Builder.setInstrAndDebugLoc(*BrCond);
1638 LLT Ty = MRI.getType(Reg: BrCond->getOperand(i: 0).getReg());
1639 // FIXME: Does int/fp matter for this? If so, we might need to restrict
1640 // this to i1 only since we might not know for sure what kind of
1641 // compare generated the condition value.
1642 auto True = Builder.buildConstant(
1643 Res: Ty, Val: getICmpTrueVal(TLI: getTargetLowering(), IsVector: false, IsFP: false));
1644 auto Xor = Builder.buildXor(Dst: Ty, Src0: BrCond->getOperand(i: 0), Src1: True);
1645
1646 auto *FallthroughBB = BrCond->getOperand(i: 1).getMBB();
1647 Observer.changingInstr(MI);
1648 MI.getOperand(i: 0).setMBB(FallthroughBB);
1649 Observer.changedInstr(MI);
1650
1651 // Change the conditional branch to use the inverted condition and
1652 // new target block.
1653 Observer.changingInstr(MI&: *BrCond);
1654 BrCond->getOperand(i: 0).setReg(Xor.getReg(Idx: 0));
1655 BrCond->getOperand(i: 1).setMBB(BrTarget);
1656 Observer.changedInstr(MI&: *BrCond);
1657}
1658
1659bool CombinerHelper::tryEmitMemcpyInline(MachineInstr &MI) const {
1660 MachineIRBuilder HelperBuilder(MI);
1661 GISelObserverWrapper DummyObserver;
1662 LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder);
1663 return Helper.lowerMemcpyInline(MI) ==
1664 LegalizerHelper::LegalizeResult::Legalized;
1665}
1666
1667bool CombinerHelper::tryCombineMemCpyFamily(MachineInstr &MI,
1668 unsigned MaxLen) const {
1669 MachineIRBuilder HelperBuilder(MI);
1670 GISelObserverWrapper DummyObserver;
1671 LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder);
1672 return Helper.lowerMemCpyFamily(MI, MaxLen) ==
1673 LegalizerHelper::LegalizeResult::Legalized;
1674}
1675
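// Constant-fold a unary floating-point operation (G_FNEG, G_FABS, the various
// rounding opcodes, G_FPEXT/G_FPTRUNC, G_FSQRT or G_FLOG2) applied to \p Val.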
1676static APFloat constantFoldFpUnary(const MachineInstr &MI,
1677 const MachineRegisterInfo &MRI,
1678 const APFloat &Val) {
1679 APFloat Result(Val);
1680 switch (MI.getOpcode()) {
1681 default:
1682 llvm_unreachable("Unexpected opcode!");
1683 case TargetOpcode::G_FNEG: {
1684 Result.changeSign();
1685 return Result;
1686 }
1687 case TargetOpcode::G_FABS: {
1688 Result.clearSign();
1689 return Result;
1690 }
1691 case TargetOpcode::G_FCEIL:
1692 Result.roundToIntegral(RM: APFloat::rmTowardPositive);
1693 return Result;
1694 case TargetOpcode::G_FFLOOR:
1695 Result.roundToIntegral(RM: APFloat::rmTowardNegative);
1696 return Result;
1697 case TargetOpcode::G_INTRINSIC_TRUNC:
1698 Result.roundToIntegral(RM: APFloat::rmTowardZero);
1699 return Result;
1700 case TargetOpcode::G_INTRINSIC_ROUND:
1701 Result.roundToIntegral(RM: APFloat::rmNearestTiesToAway);
1702 return Result;
1703 case TargetOpcode::G_INTRINSIC_ROUNDEVEN:
1704 Result.roundToIntegral(RM: APFloat::rmNearestTiesToEven);
1705 return Result;
1706 case TargetOpcode::G_FRINT:
1707 case TargetOpcode::G_FNEARBYINT:
1708 // Use default rounding mode (round to nearest, ties to even)
1709 Result.roundToIntegral(RM: APFloat::rmNearestTiesToEven);
1710 return Result;
1711 case TargetOpcode::G_FPEXT:
1712 case TargetOpcode::G_FPTRUNC: {
1713 bool Unused;
1714 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
1715 Result.convert(ToSemantics: getFltSemanticForLLT(Ty: DstTy), RM: APFloat::rmNearestTiesToEven,
1716 losesInfo: &Unused);
1717 return Result;
1718 }
1719 case TargetOpcode::G_FSQRT: {
1720 bool Unused;
1721 Result.convert(ToSemantics: APFloat::IEEEdouble(), RM: APFloat::rmNearestTiesToEven,
1722 losesInfo: &Unused);
1723 Result = APFloat(sqrt(x: Result.convertToDouble()));
1724 break;
1725 }
1726 case TargetOpcode::G_FLOG2: {
1727 bool Unused;
1728 Result.convert(ToSemantics: APFloat::IEEEdouble(), RM: APFloat::rmNearestTiesToEven,
1729 losesInfo: &Unused);
1730 Result = APFloat(log2(x: Result.convertToDouble()));
1731 break;
1732 }
1733 }
1734 // Convert the result back to the semantics of the original constant; otherwise
1735 // `buildFConstant` will assert on a size mismatch. Only `G_FSQRT` and
1736 // `G_FLOG2` reach here.
1737 bool Unused;
1738 Result.convert(ToSemantics: Val.getSemantics(), RM: APFloat::rmNearestTiesToEven, losesInfo: &Unused);
1739 return Result;
1740}
1741
1742void CombinerHelper::applyCombineConstantFoldFpUnary(
1743 MachineInstr &MI, const ConstantFP *Cst) const {
1744 APFloat Folded = constantFoldFpUnary(MI, MRI, Val: Cst->getValue());
1745 const ConstantFP *NewCst = ConstantFP::get(Context&: Builder.getContext(), V: Folded);
1746 Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: *NewCst);
1747 MI.eraseFromParent();
1748}
1749
1750bool CombinerHelper::matchPtrAddImmedChain(MachineInstr &MI,
1751 PtrAddChain &MatchInfo) const {
1752 // We're trying to match the following pattern:
1753 // %t1 = G_PTR_ADD %base, G_CONSTANT imm1
1754 // %root = G_PTR_ADD %t1, G_CONSTANT imm2
1755 // -->
1756 // %root = G_PTR_ADD %base, G_CONSTANT (imm1 + imm2)
1757
1758 if (MI.getOpcode() != TargetOpcode::G_PTR_ADD)
1759 return false;
1760
1761 Register Add2 = MI.getOperand(i: 1).getReg();
1762 Register Imm1 = MI.getOperand(i: 2).getReg();
1763 auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: Imm1, MRI);
1764 if (!MaybeImmVal)
1765 return false;
1766
1767 MachineInstr *Add2Def = MRI.getVRegDef(Reg: Add2);
1768 if (!Add2Def || Add2Def->getOpcode() != TargetOpcode::G_PTR_ADD)
1769 return false;
1770
1771 Register Base = Add2Def->getOperand(i: 1).getReg();
1772 Register Imm2 = Add2Def->getOperand(i: 2).getReg();
1773 auto MaybeImm2Val = getIConstantVRegValWithLookThrough(VReg: Imm2, MRI);
1774 if (!MaybeImm2Val)
1775 return false;
1776
1777 // Check if the new combined immediate forms an illegal addressing mode.
1778 // Do not combine if it was legal before but would get illegal.
1779 // To do so, we need to find a load/store user of the pointer to get
1780 // the access type.
1781 Type *AccessTy = nullptr;
1782 auto &MF = *MI.getMF();
1783 for (auto &UseMI : MRI.use_nodbg_instructions(Reg: MI.getOperand(i: 0).getReg())) {
1784 if (auto *LdSt = dyn_cast<GLoadStore>(Val: &UseMI)) {
1785 AccessTy = getTypeForLLT(Ty: MRI.getType(Reg: LdSt->getReg(Idx: 0)),
1786 C&: MF.getFunction().getContext());
1787 break;
1788 }
1789 }
1790 TargetLoweringBase::AddrMode AMNew;
1791 APInt CombinedImm = MaybeImmVal->Value + MaybeImm2Val->Value;
1792 AMNew.BaseOffs = CombinedImm.getSExtValue();
1793 if (AccessTy) {
1794 AMNew.HasBaseReg = true;
1795 TargetLoweringBase::AddrMode AMOld;
1796 AMOld.BaseOffs = MaybeImmVal->Value.getSExtValue();
1797 AMOld.HasBaseReg = true;
1798 unsigned AS = MRI.getType(Reg: Add2).getAddressSpace();
1799 const auto &TLI = *MF.getSubtarget().getTargetLowering();
1800 if (TLI.isLegalAddressingMode(DL: MF.getDataLayout(), AM: AMOld, Ty: AccessTy, AddrSpace: AS) &&
1801 !TLI.isLegalAddressingMode(DL: MF.getDataLayout(), AM: AMNew, Ty: AccessTy, AddrSpace: AS))
1802 return false;
1803 }
1804
1805 // Reassociating nuw additions preserves nuw. If both original G_PTR_ADDs are
1806 // inbounds, reaching the same result in one G_PTR_ADD is also inbounds.
1807 // The nusw constraints are satisfied because imm1+imm2 cannot exceed the
1808 // largest signed integer that fits into the index type, which is the maximum
1809 // size of allocated objects according to the IR Language Reference.
1810 unsigned PtrAddFlags = MI.getFlags();
1811 unsigned LHSPtrAddFlags = Add2Def->getFlags();
1812 bool IsNoUWrap = PtrAddFlags & LHSPtrAddFlags & MachineInstr::MIFlag::NoUWrap;
1813 bool IsInBounds =
1814 PtrAddFlags & LHSPtrAddFlags & MachineInstr::MIFlag::InBounds;
1815 unsigned Flags = 0;
1816 if (IsNoUWrap)
1817 Flags |= MachineInstr::MIFlag::NoUWrap;
1818 if (IsInBounds) {
1819 Flags |= MachineInstr::MIFlag::InBounds;
1820 Flags |= MachineInstr::MIFlag::NoUSWrap;
1821 }
1822
1823 // Pass the combined immediate to the apply function.
1824 MatchInfo.Imm = AMNew.BaseOffs;
1825 MatchInfo.Base = Base;
1826 MatchInfo.Bank = getRegBank(Reg: Imm2);
1827 MatchInfo.Flags = Flags;
1828 return true;
1829}
1830
1831void CombinerHelper::applyPtrAddImmedChain(MachineInstr &MI,
1832 PtrAddChain &MatchInfo) const {
1833 assert(MI.getOpcode() == TargetOpcode::G_PTR_ADD && "Expected G_PTR_ADD");
1834 MachineIRBuilder MIB(MI);
1835 LLT OffsetTy = MRI.getType(Reg: MI.getOperand(i: 2).getReg());
1836 auto NewOffset = MIB.buildConstant(Res: OffsetTy, Val: MatchInfo.Imm);
1837 setRegBank(Reg: NewOffset.getReg(Idx: 0), RegBank: MatchInfo.Bank);
1838 Observer.changingInstr(MI);
1839 MI.getOperand(i: 1).setReg(MatchInfo.Base);
1840 MI.getOperand(i: 2).setReg(NewOffset.getReg(Idx: 0));
1841 MI.setFlags(MatchInfo.Flags);
1842 Observer.changedInstr(MI);
1843}
1844
1845bool CombinerHelper::matchShiftImmedChain(MachineInstr &MI,
1846 RegisterImmPair &MatchInfo) const {
1847 // We're trying to match the following pattern with any of
1848 // G_SHL/G_ASHR/G_LSHR/G_SSHLSAT/G_USHLSAT shift instructions:
1849 // %t1 = SHIFT %base, G_CONSTANT imm1
1850 // %root = SHIFT %t1, G_CONSTANT imm2
1851 // -->
1852 // %root = SHIFT %base, G_CONSTANT (imm1 + imm2)
1853
1854 unsigned Opcode = MI.getOpcode();
1855 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR ||
1856 Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT ||
1857 Opcode == TargetOpcode::G_USHLSAT) &&
1858 "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT");
1859
1860 Register Shl2 = MI.getOperand(i: 1).getReg();
1861 Register Imm1 = MI.getOperand(i: 2).getReg();
1862 auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: Imm1, MRI);
1863 if (!MaybeImmVal)
1864 return false;
1865
1866 MachineInstr *Shl2Def = MRI.getUniqueVRegDef(Reg: Shl2);
1867 if (Shl2Def->getOpcode() != Opcode)
1868 return false;
1869
1870 Register Base = Shl2Def->getOperand(i: 1).getReg();
1871 Register Imm2 = Shl2Def->getOperand(i: 2).getReg();
1872 auto MaybeImm2Val = getIConstantVRegValWithLookThrough(VReg: Imm2, MRI);
1873 if (!MaybeImm2Val)
1874 return false;
1875
1876 // Pass the combined immediate to the apply function.
1877 MatchInfo.Imm =
1878 (MaybeImmVal->Value.getZExtValue() + MaybeImm2Val->Value).getZExtValue();
1879 MatchInfo.Reg = Base;
1880
1881 // There is no simple replacement for a saturating unsigned left shift that
1882 // exceeds the scalar size.
1883 if (Opcode == TargetOpcode::G_USHLSAT &&
1884 MatchInfo.Imm >= MRI.getType(Reg: Shl2).getScalarSizeInBits())
1885 return false;
1886
1887 return true;
1888}
1889
1890void CombinerHelper::applyShiftImmedChain(MachineInstr &MI,
1891 RegisterImmPair &MatchInfo) const {
1892 unsigned Opcode = MI.getOpcode();
1893 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR ||
1894 Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT ||
1895 Opcode == TargetOpcode::G_USHLSAT) &&
1896 "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT");
1897
1898 LLT Ty = MRI.getType(Reg: MI.getOperand(i: 1).getReg());
1899 unsigned const ScalarSizeInBits = Ty.getScalarSizeInBits();
1900 auto Imm = MatchInfo.Imm;
1901
1902 if (Imm >= ScalarSizeInBits) {
1903 // Any logical shift that exceeds scalar size will produce zero.
1904 if (Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_LSHR) {
1905 Builder.buildConstant(Res: MI.getOperand(i: 0), Val: 0);
1906 MI.eraseFromParent();
1907 return;
1908 }
1909 // Arithmetic shift and saturating signed left shift have no effect beyond
1910 // scalar size.
1911 Imm = ScalarSizeInBits - 1;
1912 }
1913
1914 LLT ImmTy = MRI.getType(Reg: MI.getOperand(i: 2).getReg());
1915 Register NewImm = Builder.buildConstant(Res: ImmTy, Val: Imm).getReg(Idx: 0);
1916 Observer.changingInstr(MI);
1917 MI.getOperand(i: 1).setReg(MatchInfo.Reg);
1918 MI.getOperand(i: 2).setReg(NewImm);
1919 Observer.changedInstr(MI);
1920}
1921
1922bool CombinerHelper::matchShiftOfShiftedLogic(
1923 MachineInstr &MI, ShiftOfShiftedLogic &MatchInfo) const {
1924 // We're trying to match the following pattern with any of
1925 // G_SHL/G_ASHR/G_LSHR/G_USHLSAT/G_SSHLSAT shift instructions in combination
1926 // with any of G_AND/G_OR/G_XOR logic instructions.
1927 // %t1 = SHIFT %X, G_CONSTANT C0
1928 // %t2 = LOGIC %t1, %Y
1929 // %root = SHIFT %t2, G_CONSTANT C1
1930 // -->
1931 // %t3 = SHIFT %X, G_CONSTANT (C0+C1)
1932 // %t4 = SHIFT %Y, G_CONSTANT C1
1933 // %root = LOGIC %t3, %t4
1934 unsigned ShiftOpcode = MI.getOpcode();
1935 assert((ShiftOpcode == TargetOpcode::G_SHL ||
1936 ShiftOpcode == TargetOpcode::G_ASHR ||
1937 ShiftOpcode == TargetOpcode::G_LSHR ||
1938 ShiftOpcode == TargetOpcode::G_USHLSAT ||
1939 ShiftOpcode == TargetOpcode::G_SSHLSAT) &&
1940         "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT or G_SSHLSAT");
1941
1942 // Match a one-use bitwise logic op.
1943 Register LogicDest = MI.getOperand(i: 1).getReg();
1944 if (!MRI.hasOneNonDBGUse(RegNo: LogicDest))
1945 return false;
1946
1947 MachineInstr *LogicMI = MRI.getUniqueVRegDef(Reg: LogicDest);
1948 unsigned LogicOpcode = LogicMI->getOpcode();
1949 if (LogicOpcode != TargetOpcode::G_AND && LogicOpcode != TargetOpcode::G_OR &&
1950 LogicOpcode != TargetOpcode::G_XOR)
1951 return false;
1952
1953 // Find a matching one-use shift by constant.
1954 const Register C1 = MI.getOperand(i: 2).getReg();
1955 auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: C1, MRI);
1956 if (!MaybeImmVal || MaybeImmVal->Value == 0)
1957 return false;
1958
1959 const uint64_t C1Val = MaybeImmVal->Value.getZExtValue();
1960
1961 auto matchFirstShift = [&](const MachineInstr *MI, uint64_t &ShiftVal) {
1962    // Shift should match the previous one and have a single use.
1963 if (MI->getOpcode() != ShiftOpcode ||
1964 !MRI.hasOneNonDBGUse(RegNo: MI->getOperand(i: 0).getReg()))
1965 return false;
1966
1967 // Must be a constant.
1968 auto MaybeImmVal =
1969 getIConstantVRegValWithLookThrough(VReg: MI->getOperand(i: 2).getReg(), MRI);
1970 if (!MaybeImmVal)
1971 return false;
1972
1973 ShiftVal = MaybeImmVal->Value.getSExtValue();
1974 return true;
1975 };
1976
1977 // Logic ops are commutative, so check each operand for a match.
1978 Register LogicMIReg1 = LogicMI->getOperand(i: 1).getReg();
1979 MachineInstr *LogicMIOp1 = MRI.getUniqueVRegDef(Reg: LogicMIReg1);
1980 Register LogicMIReg2 = LogicMI->getOperand(i: 2).getReg();
1981 MachineInstr *LogicMIOp2 = MRI.getUniqueVRegDef(Reg: LogicMIReg2);
1982 uint64_t C0Val;
1983
1984 if (matchFirstShift(LogicMIOp1, C0Val)) {
1985 MatchInfo.LogicNonShiftReg = LogicMIReg2;
1986 MatchInfo.Shift2 = LogicMIOp1;
1987 } else if (matchFirstShift(LogicMIOp2, C0Val)) {
1988 MatchInfo.LogicNonShiftReg = LogicMIReg1;
1989 MatchInfo.Shift2 = LogicMIOp2;
1990 } else
1991 return false;
1992
1993 MatchInfo.ValSum = C0Val + C1Val;
1994
1995 // The fold is not valid if the sum of the shift values exceeds bitwidth.
1996 if (MatchInfo.ValSum >= MRI.getType(Reg: LogicDest).getScalarSizeInBits())
1997 return false;
1998
1999 MatchInfo.Logic = LogicMI;
2000 return true;
2001}
2002
2003void CombinerHelper::applyShiftOfShiftedLogic(
2004 MachineInstr &MI, ShiftOfShiftedLogic &MatchInfo) const {
2005 unsigned Opcode = MI.getOpcode();
2006 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR ||
2007 Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_USHLSAT ||
2008 Opcode == TargetOpcode::G_SSHLSAT) &&
2009         "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT or G_SSHLSAT");
2010
2011 LLT ShlType = MRI.getType(Reg: MI.getOperand(i: 2).getReg());
2012 LLT DestType = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
2013
2014 Register Const = Builder.buildConstant(Res: ShlType, Val: MatchInfo.ValSum).getReg(Idx: 0);
2015
2016 Register Shift1Base = MatchInfo.Shift2->getOperand(i: 1).getReg();
2017 Register Shift1 =
2018 Builder.buildInstr(Opc: Opcode, DstOps: {DestType}, SrcOps: {Shift1Base, Const}).getReg(Idx: 0);
2019
2020 // If LogicNonShiftReg is the same as Shift1Base, and the shift1 constant is
2021 // the same as the MatchInfo.Shift2 constant, CSEMIRBuilder will reuse the old
2022 // shift1 when building shift2. If we erased MatchInfo.Shift2 only at the end,
2023 // we would actually remove the old shift1 and crash later. So erase it earlier
2024 // to avoid the crash.
2025 MatchInfo.Shift2->eraseFromParent();
2026
2027 Register Shift2Const = MI.getOperand(i: 2).getReg();
2028 Register Shift2 = Builder
2029 .buildInstr(Opc: Opcode, DstOps: {DestType},
2030 SrcOps: {MatchInfo.LogicNonShiftReg, Shift2Const})
2031 .getReg(Idx: 0);
2032
2033 Register Dest = MI.getOperand(i: 0).getReg();
2034 Builder.buildInstr(Opc: MatchInfo.Logic->getOpcode(), DstOps: {Dest}, SrcOps: {Shift1, Shift2});
2035
2036 // This was one use so it's safe to remove it.
2037 MatchInfo.Logic->eraseFromParent();
2038
2039 MI.eraseFromParent();
2040}
2041
2042bool CombinerHelper::matchCommuteShift(MachineInstr &MI,
2043 BuildFnTy &MatchInfo) const {
2044 assert(MI.getOpcode() == TargetOpcode::G_SHL && "Expected G_SHL");
2045 // Combine (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
2046 // Combine (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
2047 auto &Shl = cast<GenericMachineInstr>(Val&: MI);
2048 Register DstReg = Shl.getReg(Idx: 0);
2049 Register SrcReg = Shl.getReg(Idx: 1);
2050 Register ShiftReg = Shl.getReg(Idx: 2);
2051 Register X, C1;
2052
2053 if (!getTargetLowering().isDesirableToCommuteWithShift(MI, IsAfterLegal: !isPreLegalize()))
2054 return false;
2055
2056 if (!mi_match(R: SrcReg, MRI,
2057 P: m_OneNonDBGUse(SP: m_any_of(preds: m_GAdd(L: m_Reg(R&: X), R: m_Reg(R&: C1)),
2058 preds: m_GOr(L: m_Reg(R&: X), R: m_Reg(R&: C1))))))
2059 return false;
2060
2061 APInt C1Val, C2Val;
2062 if (!mi_match(R: C1, MRI, P: m_ICstOrSplat(Cst&: C1Val)) ||
2063 !mi_match(R: ShiftReg, MRI, P: m_ICstOrSplat(Cst&: C2Val)))
2064 return false;
2065
2066 auto *SrcDef = MRI.getVRegDef(Reg: SrcReg);
2067 assert((SrcDef->getOpcode() == TargetOpcode::G_ADD ||
2068 SrcDef->getOpcode() == TargetOpcode::G_OR) && "Unexpected op");
2069 LLT SrcTy = MRI.getType(Reg: SrcReg);
2070 MatchInfo = [=](MachineIRBuilder &B) {
2071 auto S1 = B.buildShl(Dst: SrcTy, Src0: X, Src1: ShiftReg);
2072 auto S2 = B.buildShl(Dst: SrcTy, Src0: C1, Src1: ShiftReg);
2073 B.buildInstr(Opc: SrcDef->getOpcode(), DstOps: {DstReg}, SrcOps: {S1, S2});
2074 };
2075 return true;
2076}
2077
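// Fold a logical shift right of a truncated logical shift right into a single
// wider shift followed by a truncate (plus a mask when the truncation could
// otherwise keep stray high bits), e.g.:
//   %s1:_(s64) = G_LSHR %x:_(s64), 32
//   %t:_(s32) = G_TRUNC %s1
//   %root:_(s32) = G_LSHR %t, 4
// -->
//   %s:_(s64) = G_LSHR %x, 36
//   %root:_(s32) = G_TRUNC %s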
2078bool CombinerHelper::matchLshrOfTruncOfLshr(MachineInstr &MI,
2079 LshrOfTruncOfLshr &MatchInfo,
2080 MachineInstr &ShiftMI) const {
2081 assert(MI.getOpcode() == TargetOpcode::G_LSHR && "Expected a G_LSHR");
2082
2083 Register N0 = MI.getOperand(i: 1).getReg();
2084 Register N1 = MI.getOperand(i: 2).getReg();
2085 unsigned OpSizeInBits = MRI.getType(Reg: N0).getScalarSizeInBits();
2086
2087 APInt N1C, N001C;
2088 if (!mi_match(R: N1, MRI, P: m_ICstOrSplat(Cst&: N1C)))
2089 return false;
2090 auto N001 = ShiftMI.getOperand(i: 2).getReg();
2091 if (!mi_match(R: N001, MRI, P: m_ICstOrSplat(Cst&: N001C)))
2092 return false;
2093
2094 if (N001C.getBitWidth() > N1C.getBitWidth())
2095 N1C = N1C.zext(width: N001C.getBitWidth());
2096 else
2097 N001C = N001C.zext(width: N1C.getBitWidth());
2098
2099 Register InnerShift = ShiftMI.getOperand(i: 0).getReg();
2100 LLT InnerShiftTy = MRI.getType(Reg: InnerShift);
2101 uint64_t InnerShiftSize = InnerShiftTy.getScalarSizeInBits();
2102 if ((N1C + N001C).ult(RHS: InnerShiftSize)) {
2103 MatchInfo.Src = ShiftMI.getOperand(i: 1).getReg();
2104 MatchInfo.ShiftAmt = N1C + N001C;
2105 MatchInfo.ShiftAmtTy = MRI.getType(Reg: N001);
2106 MatchInfo.InnerShiftTy = InnerShiftTy;
2107
2108 if ((N001C + OpSizeInBits) == InnerShiftSize)
2109 return true;
2110 if (MRI.hasOneUse(RegNo: N0) && MRI.hasOneUse(RegNo: InnerShift)) {
2111 MatchInfo.Mask = true;
2112 MatchInfo.MaskVal = APInt(N1C.getBitWidth(), OpSizeInBits) - N1C;
2113 return true;
2114 }
2115 }
2116 return false;
2117}
2118
2119void CombinerHelper::applyLshrOfTruncOfLshr(
2120 MachineInstr &MI, LshrOfTruncOfLshr &MatchInfo) const {
2121 assert(MI.getOpcode() == TargetOpcode::G_LSHR && "Expected a G_LSHR");
2122
2123 Register Dst = MI.getOperand(i: 0).getReg();
2124 auto ShiftAmt =
2125 Builder.buildConstant(Res: MatchInfo.ShiftAmtTy, Val: MatchInfo.ShiftAmt);
2126 auto Shift =
2127 Builder.buildLShr(Dst: MatchInfo.InnerShiftTy, Src0: MatchInfo.Src, Src1: ShiftAmt);
2128 if (MatchInfo.Mask) {
2129 APInt MaskVal =
2130 APInt::getLowBitsSet(numBits: MatchInfo.InnerShiftTy.getScalarSizeInBits(),
2131 loBitsSet: MatchInfo.MaskVal.getZExtValue());
2132 auto Mask = Builder.buildConstant(Res: MatchInfo.InnerShiftTy, Val: MaskVal);
2133 auto And = Builder.buildAnd(Dst: MatchInfo.InnerShiftTy, Src0: Shift, Src1: Mask);
2134 Builder.buildTrunc(Res: Dst, Op: And);
2135 } else
2136 Builder.buildTrunc(Res: Dst, Op: Shift);
2137 MI.eraseFromParent();
2138}
2139
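// Fold a multiplication by a power of two into a left shift, e.g.:
//   %dst:_ = G_MUL %x:_, G_CONSTANT 8
// -->
//   %dst:_ = G_SHL %x:_, G_CONSTANT 3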
2140bool CombinerHelper::matchCombineMulToShl(MachineInstr &MI,
2141 unsigned &ShiftVal) const {
2142 assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL");
2143 auto MaybeImmVal =
2144 getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI);
2145 if (!MaybeImmVal)
2146 return false;
2147
2148 ShiftVal = MaybeImmVal->Value.exactLogBase2();
2149 return (static_cast<int32_t>(ShiftVal) != -1);
2150}
2151
2152void CombinerHelper::applyCombineMulToShl(MachineInstr &MI,
2153 unsigned &ShiftVal) const {
2154 assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL");
2155 MachineIRBuilder MIB(MI);
2156 LLT ShiftTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
2157 auto ShiftCst = MIB.buildConstant(Res: ShiftTy, Val: ShiftVal);
2158 Observer.changingInstr(MI);
2159 MI.setDesc(MIB.getTII().get(Opcode: TargetOpcode::G_SHL));
2160 MI.getOperand(i: 2).setReg(ShiftCst.getReg(Idx: 0));
2161 if (ShiftVal == ShiftTy.getScalarSizeInBits() - 1)
2162 MI.clearFlag(Flag: MachineInstr::MIFlag::NoSWrap);
2163 Observer.changedInstr(MI);
2164}
2165
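// Rewrite a subtraction of a constant as an addition of the negated constant,
// e.g.:
//   %dst:_ = G_SUB %x:_, G_CONSTANT 5
// -->
//   %dst:_ = G_ADD %x:_, G_CONSTANT -5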
2166bool CombinerHelper::matchCombineSubToAdd(MachineInstr &MI,
2167 BuildFnTy &MatchInfo) const {
2168 GSub &Sub = cast<GSub>(Val&: MI);
2169
2170 LLT Ty = MRI.getType(Reg: Sub.getReg(Idx: 0));
2171
2172 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ADD, {Ty}}))
2173 return false;
2174
2175 if (!isConstantLegalOrBeforeLegalizer(Ty))
2176 return false;
2177
2178 APInt Imm = getIConstantFromReg(VReg: Sub.getRHSReg(), MRI);
2179
2180 MatchInfo = [=, &MI](MachineIRBuilder &B) {
2181 auto NegCst = B.buildConstant(Res: Ty, Val: -Imm);
2182 Observer.changingInstr(MI);
2183 MI.setDesc(B.getTII().get(Opcode: TargetOpcode::G_ADD));
2184 MI.getOperand(i: 2).setReg(NegCst.getReg(Idx: 0));
2185 MI.clearFlag(Flag: MachineInstr::MIFlag::NoUWrap);
2186 if (Imm.isMinSignedValue())
2187 MI.clearFlags(flags: MachineInstr::MIFlag::NoSWrap);
2188 Observer.changedInstr(MI);
2189 };
2190 return true;
2191}
2192
2193// shl ([sza]ext x), y => zext (shl x, y), if shift does not overflow source
2194bool CombinerHelper::matchCombineShlOfExtend(MachineInstr &MI,
2195 RegisterImmPair &MatchData) const {
2196 assert(MI.getOpcode() == TargetOpcode::G_SHL && VT);
2197 if (!getTargetLowering().isDesirableToPullExtFromShl(MI))
2198 return false;
2199
2200 Register LHS = MI.getOperand(i: 1).getReg();
2201
2202 Register ExtSrc;
2203 if (!mi_match(R: LHS, MRI, P: m_GAnyExt(Src: m_Reg(R&: ExtSrc))) &&
2204 !mi_match(R: LHS, MRI, P: m_GZExt(Src: m_Reg(R&: ExtSrc))) &&
2205 !mi_match(R: LHS, MRI, P: m_GSExt(Src: m_Reg(R&: ExtSrc))))
2206 return false;
2207
2208 Register RHS = MI.getOperand(i: 2).getReg();
2209 MachineInstr *MIShiftAmt = MRI.getVRegDef(Reg: RHS);
2210 auto MaybeShiftAmtVal = isConstantOrConstantSplatVector(MI&: *MIShiftAmt, MRI);
2211 if (!MaybeShiftAmtVal)
2212 return false;
2213
2214 if (LI) {
2215 LLT SrcTy = MRI.getType(Reg: ExtSrc);
2216
2217 // We only really care about the legality with the shifted value. We can
2218    // pick any type for the constant shift amount, so ask the target what to
2219 // use. Otherwise we would have to guess and hope it is reported as legal.
2220 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: SrcTy);
2221 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SHL, {SrcTy, ShiftAmtTy}}))
2222 return false;
2223 }
2224
2225 int64_t ShiftAmt = MaybeShiftAmtVal->getSExtValue();
2226 MatchData.Reg = ExtSrc;
2227 MatchData.Imm = ShiftAmt;
2228
2229 unsigned MinLeadingZeros = VT->getKnownZeroes(R: ExtSrc).countl_one();
2230 unsigned SrcTySize = MRI.getType(Reg: ExtSrc).getScalarSizeInBits();
2231 return MinLeadingZeros >= ShiftAmt && ShiftAmt < SrcTySize;
2232}
2233
2234void CombinerHelper::applyCombineShlOfExtend(
2235 MachineInstr &MI, const RegisterImmPair &MatchData) const {
2236 Register ExtSrcReg = MatchData.Reg;
2237 int64_t ShiftAmtVal = MatchData.Imm;
2238
2239 LLT ExtSrcTy = MRI.getType(Reg: ExtSrcReg);
2240 auto ShiftAmt = Builder.buildConstant(Res: ExtSrcTy, Val: ShiftAmtVal);
2241 auto NarrowShift =
2242 Builder.buildShl(Dst: ExtSrcTy, Src0: ExtSrcReg, Src1: ShiftAmt, Flags: MI.getFlags());
2243 Builder.buildZExt(Res: MI.getOperand(i: 0), Op: NarrowShift);
2244 MI.eraseFromParent();
2245}
2246
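// Fold a merge of all the results of an unmerge back into the unmerge source,
// e.g.:
//   %a:_(s32), %b:_(s32) = G_UNMERGE_VALUES %x:_(s64)
//   %y:_(s64) = G_MERGE_VALUES %a, %b
// -->
//   replace %y with %x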
2247bool CombinerHelper::matchCombineMergeUnmerge(MachineInstr &MI,
2248 Register &MatchInfo) const {
2249 GMerge &Merge = cast<GMerge>(Val&: MI);
2250 SmallVector<Register, 16> MergedValues;
2251 for (unsigned I = 0; I < Merge.getNumSources(); ++I)
2252 MergedValues.emplace_back(Args: Merge.getSourceReg(I));
2253
2254 auto *Unmerge = getOpcodeDef<GUnmerge>(Reg: MergedValues[0], MRI);
2255 if (!Unmerge || Unmerge->getNumDefs() != Merge.getNumSources())
2256 return false;
2257
2258 for (unsigned I = 0; I < MergedValues.size(); ++I)
2259 if (MergedValues[I] != Unmerge->getReg(Idx: I))
2260 return false;
2261
2262 MatchInfo = Unmerge->getSourceReg();
2263 return true;
2264}
2265
2266static Register peekThroughBitcast(Register Reg,
2267 const MachineRegisterInfo &MRI) {
2268 while (mi_match(R: Reg, MRI, P: m_GBitcast(Src: m_Reg(R&: Reg))))
2269 ;
2270
2271 return Reg;
2272}
2273
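// Fold an unmerge of a merge-like instruction (possibly looking through
// bitcasts) back into the original merge sources, e.g.:
//   %x:_(s64) = G_MERGE_VALUES %a:_(s32), %b:_(s32)
//   %c:_(s32), %d:_(s32) = G_UNMERGE_VALUES %x
// -->
//   replace %c with %a and %d with %b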
2274bool CombinerHelper::matchCombineUnmergeMergeToPlainValues(
2275 MachineInstr &MI, SmallVectorImpl<Register> &Operands) const {
2276 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2277 "Expected an unmerge");
2278 auto &Unmerge = cast<GUnmerge>(Val&: MI);
2279 Register SrcReg = peekThroughBitcast(Reg: Unmerge.getSourceReg(), MRI);
2280
2281 auto *SrcInstr = getOpcodeDef<GMergeLikeInstr>(Reg: SrcReg, MRI);
2282 if (!SrcInstr)
2283 return false;
2284
2285 // Check the source type of the merge.
2286 LLT SrcMergeTy = MRI.getType(Reg: SrcInstr->getSourceReg(I: 0));
2287 LLT Dst0Ty = MRI.getType(Reg: Unmerge.getReg(Idx: 0));
2288 bool SameSize = Dst0Ty.getSizeInBits() == SrcMergeTy.getSizeInBits();
2289 if (SrcMergeTy != Dst0Ty && !SameSize)
2290 return false;
2291 // They are the same now (modulo a bitcast).
2292 // We can collect all the src registers.
2293 for (unsigned Idx = 0; Idx < SrcInstr->getNumSources(); ++Idx)
2294 Operands.push_back(Elt: SrcInstr->getSourceReg(I: Idx));
2295 return true;
2296}
2297
2298void CombinerHelper::applyCombineUnmergeMergeToPlainValues(
2299 MachineInstr &MI, SmallVectorImpl<Register> &Operands) const {
2300 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2301 "Expected an unmerge");
2302 assert((MI.getNumOperands() - 1 == Operands.size()) &&
2303 "Not enough operands to replace all defs");
2304 unsigned NumElems = MI.getNumOperands() - 1;
2305
2306 LLT SrcTy = MRI.getType(Reg: Operands[0]);
2307 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
2308 bool CanReuseInputDirectly = DstTy == SrcTy;
2309 for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
2310 Register DstReg = MI.getOperand(i: Idx).getReg();
2311 Register SrcReg = Operands[Idx];
2312
2313 // This combine may run after RegBankSelect, so we need to be aware of
2314 // register banks.
2315 const auto &DstCB = MRI.getRegClassOrRegBank(Reg: DstReg);
2316 if (!DstCB.isNull() && DstCB != MRI.getRegClassOrRegBank(Reg: SrcReg)) {
2317 SrcReg = Builder.buildCopy(Res: MRI.getType(Reg: SrcReg), Op: SrcReg).getReg(Idx: 0);
2318 MRI.setRegClassOrRegBank(Reg: SrcReg, RCOrRB: DstCB);
2319 }
2320
2321 if (CanReuseInputDirectly)
2322 replaceRegWith(MRI, FromReg: DstReg, ToReg: SrcReg);
2323 else
2324 Builder.buildCast(Dst: DstReg, Src: SrcReg);
2325 }
2326 MI.eraseFromParent();
2327}
2328
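// Break an unmerge of a G_CONSTANT or G_FCONSTANT into per-result constants,
// with the low bits going to the first result, e.g.:
//   %c:_(s16) = G_CONSTANT i16 4660    ; 0x1234
//   %lo:_(s8), %hi:_(s8) = G_UNMERGE_VALUES %c
// -->
//   %lo:_(s8) = G_CONSTANT i8 52       ; 0x34
//   %hi:_(s8) = G_CONSTANT i8 18       ; 0x12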
2329bool CombinerHelper::matchCombineUnmergeConstant(
2330 MachineInstr &MI, SmallVectorImpl<APInt> &Csts) const {
2331 unsigned SrcIdx = MI.getNumOperands() - 1;
2332 Register SrcReg = MI.getOperand(i: SrcIdx).getReg();
2333 MachineInstr *SrcInstr = MRI.getVRegDef(Reg: SrcReg);
2334 if (SrcInstr->getOpcode() != TargetOpcode::G_CONSTANT &&
2335 SrcInstr->getOpcode() != TargetOpcode::G_FCONSTANT)
2336 return false;
2337 // Break down the big constant into smaller ones.
2338 const MachineOperand &CstVal = SrcInstr->getOperand(i: 1);
2339 APInt Val = SrcInstr->getOpcode() == TargetOpcode::G_CONSTANT
2340 ? CstVal.getCImm()->getValue()
2341 : CstVal.getFPImm()->getValueAPF().bitcastToAPInt();
2342
2343 LLT Dst0Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
2344 unsigned ShiftAmt = Dst0Ty.getSizeInBits();
2345 // Unmerge a constant.
2346 for (unsigned Idx = 0; Idx != SrcIdx; ++Idx) {
2347 Csts.emplace_back(Args: Val.trunc(width: ShiftAmt));
2348 Val = Val.lshr(shiftAmt: ShiftAmt);
2349 }
2350
2351 return true;
2352}
2353
2354void CombinerHelper::applyCombineUnmergeConstant(
2355 MachineInstr &MI, SmallVectorImpl<APInt> &Csts) const {
2356 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2357 "Expected an unmerge");
2358 assert((MI.getNumOperands() - 1 == Csts.size()) &&
2359 "Not enough operands to replace all defs");
2360 unsigned NumElems = MI.getNumOperands() - 1;
2361
2362 Register SrcReg = MI.getOperand(i: NumElems).getReg();
2363
2364 if (MRI.getType(Reg: SrcReg).isFloat()) {
2365 APFloat Val(getFltSemanticForLLT(Ty: MRI.getType(Reg: MI.getOperand(i: 0).getReg())));
2366
2367 for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
2368 Register DstReg = MI.getOperand(i: Idx).getReg();
2369      Val.convertFromAPInt(Input: Csts[Idx], IsSigned: false, RM: APFloat::rmTowardZero);
2370 Builder.buildFConstant(Res: DstReg, Val);
2371 }
2372 } else {
2373 for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
2374 Register DstReg = MI.getOperand(i: Idx).getReg();
2375 Builder.buildConstant(Res: DstReg, Val: Csts[Idx]);
2376 }
2377 }
2378
2379 MI.eraseFromParent();
2380}
2381
2382bool CombinerHelper::matchCombineUnmergeUndef(
2383 MachineInstr &MI,
2384 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
2385 unsigned SrcIdx = MI.getNumOperands() - 1;
2386 Register SrcReg = MI.getOperand(i: SrcIdx).getReg();
2387 MatchInfo = [&MI](MachineIRBuilder &B) {
2388 unsigned NumElems = MI.getNumOperands() - 1;
2389 for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
2390 Register DstReg = MI.getOperand(i: Idx).getReg();
2391 B.buildUndef(Res: DstReg);
2392 }
2393 };
2394 return isa<GImplicitDef>(Val: MRI.getVRegDef(Reg: SrcReg));
2395}
2396
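// If every result of an unmerge except the first one is dead, the unmerge can
// be replaced by a truncate of the source, e.g.:
//   %lo:_(s32), %hi:_(s32) = G_UNMERGE_VALUES %x:_(s64)   ; %hi is unused
// -->
//   %lo:_(s32) = G_TRUNC %x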
2397bool CombinerHelper::matchCombineUnmergeWithDeadLanesToTrunc(
2398 MachineInstr &MI) const {
2399 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2400 "Expected an unmerge");
2401 if (MRI.getType(Reg: MI.getOperand(i: 0).getReg()).isVector() ||
2402 MRI.getType(Reg: MI.getOperand(i: MI.getNumDefs()).getReg()).isVector())
2403 return false;
2404 // Check that all the lanes are dead except the first one.
2405 for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) {
2406 if (!MRI.use_nodbg_empty(RegNo: MI.getOperand(i: Idx).getReg()))
2407 return false;
2408 }
2409 return true;
2410}
2411
2412void CombinerHelper::applyCombineUnmergeWithDeadLanesToTrunc(
2413 MachineInstr &MI) const {
2414 Register SrcReg = MI.getOperand(i: MI.getNumDefs()).getReg();
2415 Register Dst0Reg = MI.getOperand(i: 0).getReg();
2416 Builder.buildTrunc(Res: Dst0Reg, Op: SrcReg);
2417 MI.eraseFromParent();
2418}
2419
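// Fold an unmerge of a scalar G_ZEXT: the first result becomes a zext (or a
// plain replacement) of the zext source and the remaining results become zero,
// e.g.:
//   %z:_(s64) = G_ZEXT %x:_(s16)
//   %lo:_(s32), %hi:_(s32) = G_UNMERGE_VALUES %z
// -->
//   %lo:_(s32) = G_ZEXT %x
//   %hi:_(s32) = G_CONSTANT i32 0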
2420bool CombinerHelper::matchCombineUnmergeZExtToZExt(MachineInstr &MI) const {
2421 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2422 "Expected an unmerge");
2423 Register Dst0Reg = MI.getOperand(i: 0).getReg();
2424 LLT Dst0Ty = MRI.getType(Reg: Dst0Reg);
2425 // G_ZEXT on a vector applies to each lane, so it will
2426 // affect all destinations. Therefore we won't be able
2427 // to simplify the unmerge to just the first definition.
2428 if (Dst0Ty.isVector())
2429 return false;
2430 Register SrcReg = MI.getOperand(i: MI.getNumDefs()).getReg();
2431 LLT SrcTy = MRI.getType(Reg: SrcReg);
2432 if (SrcTy.isVector())
2433 return false;
2434
2435 Register ZExtSrcReg;
2436 if (!mi_match(R: SrcReg, MRI, P: m_GZExt(Src: m_Reg(R&: ZExtSrcReg))))
2437 return false;
2438
2439 // Finally we can replace the first definition with
2440 // a zext of the source if the definition is big enough to hold
2441 // all of ZExtSrc's bits.
2442 LLT ZExtSrcTy = MRI.getType(Reg: ZExtSrcReg);
2443 return ZExtSrcTy.getSizeInBits() <= Dst0Ty.getSizeInBits();
2444}
2445
2446void CombinerHelper::applyCombineUnmergeZExtToZExt(MachineInstr &MI) const {
2447 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2448 "Expected an unmerge");
2449
2450 Register Dst0Reg = MI.getOperand(i: 0).getReg();
2451
2452 MachineInstr *ZExtInstr =
2453 MRI.getVRegDef(Reg: MI.getOperand(i: MI.getNumDefs()).getReg());
2454 assert(ZExtInstr && ZExtInstr->getOpcode() == TargetOpcode::G_ZEXT &&
2455 "Expecting a G_ZEXT");
2456
2457 Register ZExtSrcReg = ZExtInstr->getOperand(i: 1).getReg();
2458 LLT Dst0Ty = MRI.getType(Reg: Dst0Reg);
2459 LLT ZExtSrcTy = MRI.getType(Reg: ZExtSrcReg);
2460
2461 if (Dst0Ty.getSizeInBits() > ZExtSrcTy.getSizeInBits()) {
2462 Builder.buildZExt(Res: Dst0Reg, Op: ZExtSrcReg);
2463 } else {
2464 assert(Dst0Ty.getSizeInBits() == ZExtSrcTy.getSizeInBits() &&
2465 "ZExt src doesn't fit in destination");
2466 replaceRegWith(MRI, FromReg: Dst0Reg, ToReg: ZExtSrcReg);
2467 }
2468
2469 Register ZeroReg;
2470 for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) {
2471 if (!ZeroReg)
2472 ZeroReg = Builder.buildConstant(Res: Dst0Ty, Val: 0).getReg(Idx: 0);
2473 replaceRegWith(MRI, FromReg: MI.getOperand(i: Idx).getReg(), ToReg: ZeroReg);
2474 }
2475 MI.eraseFromParent();
2476}
2477
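// Match a scalar shift by a constant amount of at least half the bit width.
// Such a shift can instead be performed on the halves produced by an unmerge;
// see applyCombineShiftToUnmerge below for the per-opcode expansions.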
2478bool CombinerHelper::matchCombineShiftToUnmerge(MachineInstr &MI,
2479 unsigned TargetShiftSize,
2480 unsigned &ShiftVal) const {
2481 assert((MI.getOpcode() == TargetOpcode::G_SHL ||
2482 MI.getOpcode() == TargetOpcode::G_LSHR ||
2483 MI.getOpcode() == TargetOpcode::G_ASHR) && "Expected a shift");
2484
2485 LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
2486 if (Ty.isVector()) // TODO: Handle vector types.
2487 return false;
2488
2489 // Don't narrow further than the requested size.
2490 unsigned Size = Ty.getSizeInBits();
2491 if (Size <= TargetShiftSize)
2492 return false;
2493
2494 auto MaybeImmVal =
2495 getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI);
2496 if (!MaybeImmVal)
2497 return false;
2498
2499 ShiftVal = MaybeImmVal->Value.getSExtValue();
2500 return ShiftVal >= Size / 2 && ShiftVal < Size;
2501}
2502
2503void CombinerHelper::applyCombineShiftToUnmerge(
2504 MachineInstr &MI, const unsigned &ShiftVal) const {
2505 Register DstReg = MI.getOperand(i: 0).getReg();
2506 Register SrcReg = MI.getOperand(i: 1).getReg();
2507 LLT Ty = MRI.getType(Reg: SrcReg);
2508 unsigned Size = Ty.getSizeInBits();
2509 unsigned HalfSize = Size / 2;
2510 assert(ShiftVal >= HalfSize);
2511
2512 LLT HalfTy = LLT::scalar(SizeInBits: HalfSize);
2513
2514 auto Unmerge = Builder.buildUnmerge(Res: HalfTy, Op: SrcReg);
2515 unsigned NarrowShiftAmt = ShiftVal - HalfSize;
2516
2517 if (MI.getOpcode() == TargetOpcode::G_LSHR) {
2518 Register Narrowed = Unmerge.getReg(Idx: 1);
2519
2520 // dst = G_LSHR s64:x, C for C >= 32
2521 // =>
2522 // lo, hi = G_UNMERGE_VALUES x
2523 // dst = G_MERGE_VALUES (G_LSHR hi, C - 32), 0
2524
2525 if (NarrowShiftAmt != 0) {
2526 Narrowed = Builder.buildLShr(Dst: HalfTy, Src0: Narrowed,
2527 Src1: Builder.buildConstant(Res: HalfTy, Val: NarrowShiftAmt)).getReg(Idx: 0);
2528 }
2529
2530 auto Zero = Builder.buildConstant(Res: HalfTy, Val: 0);
2531 Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Narrowed, Zero});
2532 } else if (MI.getOpcode() == TargetOpcode::G_SHL) {
2533 Register Narrowed = Unmerge.getReg(Idx: 0);
2534 // dst = G_SHL s64:x, C for C >= 32
2535 // =>
2536 // lo, hi = G_UNMERGE_VALUES x
2537 // dst = G_MERGE_VALUES 0, (G_SHL hi, C - 32)
2538 if (NarrowShiftAmt != 0) {
2539 Narrowed = Builder.buildShl(Dst: HalfTy, Src0: Narrowed,
2540 Src1: Builder.buildConstant(Res: HalfTy, Val: NarrowShiftAmt)).getReg(Idx: 0);
2541 }
2542
2543 auto Zero = Builder.buildConstant(Res: HalfTy, Val: 0);
2544 Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Zero, Narrowed});
2545 } else {
2546 assert(MI.getOpcode() == TargetOpcode::G_ASHR);
2547 auto Hi = Builder.buildAShr(
2548 Dst: HalfTy, Src0: Unmerge.getReg(Idx: 1),
2549 Src1: Builder.buildConstant(Res: HalfTy, Val: HalfSize - 1));
2550
2551 if (ShiftVal == HalfSize) {
2552 // (G_ASHR i64:x, 32) ->
2553 // G_MERGE_VALUES hi_32(x), (G_ASHR hi_32(x), 31)
2554 Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Unmerge.getReg(Idx: 1), Hi});
2555 } else if (ShiftVal == Size - 1) {
2556 // Don't need a second shift.
2557 // (G_ASHR i64:x, 63) ->
2558 // %narrowed = (G_ASHR hi_32(x), 31)
2559 // G_MERGE_VALUES %narrowed, %narrowed
2560 Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Hi, Hi});
2561 } else {
2562 auto Lo = Builder.buildAShr(
2563 Dst: HalfTy, Src0: Unmerge.getReg(Idx: 1),
2564 Src1: Builder.buildConstant(Res: HalfTy, Val: ShiftVal - HalfSize));
2565
2566      // (G_ASHR i64:x, C) -> for C >= 32
2567 // G_MERGE_VALUES (G_ASHR hi_32(x), C - 32), (G_ASHR hi_32(x), 31)
2568 Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Lo, Hi});
2569 }
2570 }
2571
2572 MI.eraseFromParent();
2573}
2574
2575bool CombinerHelper::tryCombineShiftToUnmerge(
2576 MachineInstr &MI, unsigned TargetShiftAmount) const {
2577 unsigned ShiftAmt;
2578 if (matchCombineShiftToUnmerge(MI, TargetShiftSize: TargetShiftAmount, ShiftVal&: ShiftAmt)) {
2579 applyCombineShiftToUnmerge(MI, ShiftVal: ShiftAmt);
2580 return true;
2581 }
2582
2583 return false;
2584}
2585
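// Fold an inttoptr of a ptrtoint when the pointer types match, e.g.:
//   %i:_(s64) = G_PTRTOINT %p:_(p0)
//   %q:_(p0) = G_INTTOPTR %i
// -->
//   %q:_(p0) = COPY %p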
2586bool CombinerHelper::matchCombineI2PToP2I(MachineInstr &MI,
2587 Register &Reg) const {
2588 assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR");
2589 Register DstReg = MI.getOperand(i: 0).getReg();
2590 LLT DstTy = MRI.getType(Reg: DstReg);
2591 Register SrcReg = MI.getOperand(i: 1).getReg();
2592 return mi_match(R: SrcReg, MRI,
2593 P: m_GPtrToInt(Src: m_all_of(preds: m_SpecificType(Ty: DstTy), preds: m_Reg(R&: Reg))));
2594}
2595
2596void CombinerHelper::applyCombineI2PToP2I(MachineInstr &MI,
2597 Register &Reg) const {
2598 assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR");
2599 Register DstReg = MI.getOperand(i: 0).getReg();
2600 Builder.buildCopy(Res: DstReg, Op: Reg);
2601 MI.eraseFromParent();
2602}
2603
2604void CombinerHelper::applyCombineP2IToI2P(MachineInstr &MI,
2605 Register &Reg) const {
2606 assert(MI.getOpcode() == TargetOpcode::G_PTRTOINT && "Expected a G_PTRTOINT");
2607 Register DstReg = MI.getOperand(i: 0).getReg();
2608 Builder.buildZExtOrTrunc(Res: DstReg, Op: Reg);
2609 MI.eraseFromParent();
2610}
2611
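// Reassociate a G_ADD of a G_PTRTOINT into pointer arithmetic, e.g.:
//   %i:_(s64) = G_PTRTOINT %p:_(p0)
//   %a:_(s64) = G_ADD %i, %x
// -->
//   %sum:_(p0) = G_PTR_ADD %p, %x
//   %a:_(s64) = G_PTRTOINT %sum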
2612bool CombinerHelper::matchCombineAddP2IToPtrAdd(
2613 MachineInstr &MI, std::pair<Register, bool> &PtrReg) const {
2614 assert(MI.getOpcode() == TargetOpcode::G_ADD);
2615 Register LHS = MI.getOperand(i: 1).getReg();
2616 Register RHS = MI.getOperand(i: 2).getReg();
2617 LLT IntTy = MRI.getType(Reg: LHS);
2618
2619 // G_PTR_ADD always has the pointer in the LHS, so we may need to commute the
2620 // instruction.
2621 PtrReg.second = false;
2622 for (Register SrcReg : {LHS, RHS}) {
2623 if (mi_match(R: SrcReg, MRI, P: m_GPtrToInt(Src: m_Reg(R&: PtrReg.first)))) {
2624 // Don't handle cases where the integer is implicitly converted to the
2625 // pointer width.
2626 LLT PtrTy = MRI.getType(Reg: PtrReg.first);
2627 if (PtrTy.getScalarSizeInBits() == IntTy.getScalarSizeInBits())
2628 return true;
2629 }
2630
2631 PtrReg.second = true;
2632 }
2633
2634 return false;
2635}
2636
2637void CombinerHelper::applyCombineAddP2IToPtrAdd(
2638 MachineInstr &MI, std::pair<Register, bool> &PtrReg) const {
2639 Register Dst = MI.getOperand(i: 0).getReg();
2640 Register LHS = MI.getOperand(i: 1).getReg();
2641 Register RHS = MI.getOperand(i: 2).getReg();
2642
2643 const bool DoCommute = PtrReg.second;
2644 if (DoCommute)
2645 std::swap(a&: LHS, b&: RHS);
2646 LHS = PtrReg.first;
2647
2648 LLT PtrTy = MRI.getType(Reg: LHS);
2649
2650 auto PtrAdd = Builder.buildPtrAdd(Res: PtrTy, Op0: LHS, Op1: RHS);
2651 Builder.buildPtrToInt(Dst, Src: PtrAdd);
2652 MI.eraseFromParent();
2653}
2654
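// Fold a G_PTR_ADD of a constant G_INTTOPTR base and a constant offset into a
// single constant pointer value, e.g.:
//   %base:_(p0) = G_INTTOPTR %c       ; %c = G_CONSTANT i64 100
//   %p:_(p0) = G_PTR_ADD %base, %off  ; %off = G_CONSTANT i64 20
// -->
//   %p is rebuilt as the constant 120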
2655bool CombinerHelper::matchCombineConstPtrAddToI2P(MachineInstr &MI,
2656 APInt &NewCst) const {
2657 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
2658 Register LHS = PtrAdd.getBaseReg();
2659 Register RHS = PtrAdd.getOffsetReg();
2660 MachineRegisterInfo &MRI = Builder.getMF().getRegInfo();
2661
2662 if (auto RHSCst = getIConstantVRegVal(VReg: RHS, MRI)) {
2663 APInt Cst;
2664 if (mi_match(R: LHS, MRI, P: m_GIntToPtr(Src: m_ICst(Cst)))) {
2665 auto DstTy = MRI.getType(Reg: PtrAdd.getReg(Idx: 0));
2666 // G_INTTOPTR uses zero-extension
2667 NewCst = Cst.zextOrTrunc(width: DstTy.getSizeInBits());
2668 NewCst += RHSCst->sextOrTrunc(width: DstTy.getSizeInBits());
2669 return true;
2670 }
2671 }
2672
2673 return false;
2674}
2675
2676void CombinerHelper::applyCombineConstPtrAddToI2P(MachineInstr &MI,
2677 APInt &NewCst) const {
2678 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
2679 Register Dst = PtrAdd.getReg(Idx: 0);
2680
2681 Builder.buildConstant(Res: Dst, Val: NewCst);
2682 PtrAdd.eraseFromParent();
2683}
2684
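// Fold (G_ANYEXT (G_TRUNC x)) -> x when x already has the destination type,
// e.g.:
//   %t:_(s32) = G_TRUNC %x:_(s64)
//   %a:_(s64) = G_ANYEXT %t
// -->
//   replace %a with %x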
2685bool CombinerHelper::matchCombineAnyExtTrunc(MachineInstr &MI,
2686 Register &Reg) const {
2687 assert(MI.getOpcode() == TargetOpcode::G_ANYEXT && "Expected a G_ANYEXT");
2688 Register DstReg = MI.getOperand(i: 0).getReg();
2689 Register SrcReg = MI.getOperand(i: 1).getReg();
2690 Register OriginalSrcReg = getSrcRegIgnoringCopies(Reg: SrcReg, MRI);
2691 if (OriginalSrcReg.isValid())
2692 SrcReg = OriginalSrcReg;
2693 LLT DstTy = MRI.getType(Reg: DstReg);
2694 return mi_match(R: SrcReg, MRI,
2695 P: m_GTrunc(Src: m_all_of(preds: m_Reg(R&: Reg), preds: m_SpecificType(Ty: DstTy)))) &&
2696 canReplaceReg(DstReg, SrcReg: Reg, MRI);
2697}
2698
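// Fold (G_ZEXT (G_TRUNC x)) -> x when x has the destination type and its high
// bits are known to be zero, e.g.:
//   %t:_(s32) = G_TRUNC %x:_(s64)
//   %z:_(s64) = G_ZEXT %t
// -->
//   replace %z with %x, provided the top 32 bits of %x are known zero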
2699bool CombinerHelper::matchCombineZextTrunc(MachineInstr &MI,
2700 Register &Reg) const {
2701 assert(MI.getOpcode() == TargetOpcode::G_ZEXT && "Expected a G_ZEXT");
2702 Register DstReg = MI.getOperand(i: 0).getReg();
2703 Register SrcReg = MI.getOperand(i: 1).getReg();
2704 LLT DstTy = MRI.getType(Reg: DstReg);
2705 if (mi_match(R: SrcReg, MRI,
2706 P: m_GTrunc(Src: m_all_of(preds: m_Reg(R&: Reg), preds: m_SpecificType(Ty: DstTy)))) &&
2707 canReplaceReg(DstReg, SrcReg: Reg, MRI)) {
2708 unsigned DstSize = DstTy.getScalarSizeInBits();
2709 unsigned SrcSize = MRI.getType(Reg: SrcReg).getScalarSizeInBits();
2710 return VT->getKnownBits(R: Reg).countMinLeadingZeros() >= DstSize - SrcSize;
2711 }
2712 return false;
2713}
2714
2715static LLT getMidVTForTruncRightShiftCombine(LLT ShiftTy, LLT TruncTy) {
2716 const unsigned ShiftSize = ShiftTy.getScalarSizeInBits();
2717 const unsigned TruncSize = TruncTy.getScalarSizeInBits();
2718
2719 // ShiftTy > 32 > TruncTy -> 32
2720 if (ShiftSize > 32 && TruncSize < 32)
2721 return ShiftTy.changeElementSize(NewEltSize: 32);
2722
2723 // TODO: We could also reduce to 16 bits, but that's more target-dependent.
2724 // Some targets like it, some don't, some only like it under certain
2725 // conditions/processor versions, etc.
2726 // A TL hook might be needed for this.
2727
2728 // Don't combine
2729 return ShiftTy;
2730}
2731
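// Try to narrow a truncated single-use shift by performing the shift in a
// narrower type (the destination type for G_SHL, possibly a wider intermediate
// type for right shifts), e.g.:
//   %s:_(s64) = G_SHL %x:_(s64), %amt
//   %t:_(s32) = G_TRUNC %s
// -->
//   %xt:_(s32) = G_TRUNC %x
//   %t:_(s32) = G_SHL %xt, %amt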
2732bool CombinerHelper::matchCombineTruncOfShift(
2733 MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) const {
2734 assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC");
2735 Register DstReg = MI.getOperand(i: 0).getReg();
2736 Register SrcReg = MI.getOperand(i: 1).getReg();
2737
2738 if (!MRI.hasOneNonDBGUse(RegNo: SrcReg))
2739 return false;
2740
2741 LLT SrcTy = MRI.getType(Reg: SrcReg);
2742 LLT DstTy = MRI.getType(Reg: DstReg);
2743
2744 MachineInstr *SrcMI = getDefIgnoringCopies(Reg: SrcReg, MRI);
2745 const auto &TL = getTargetLowering();
2746
2747 LLT NewShiftTy;
2748 switch (SrcMI->getOpcode()) {
2749 default:
2750 return false;
2751 case TargetOpcode::G_SHL: {
2752 NewShiftTy = DstTy;
2753
2754 // Make sure new shift amount is legal.
2755 KnownBits Known = VT->getKnownBits(R: SrcMI->getOperand(i: 2).getReg());
2756 if (Known.getMaxValue().uge(RHS: NewShiftTy.getScalarSizeInBits()))
2757 return false;
2758 break;
2759 }
2760 case TargetOpcode::G_LSHR:
2761 case TargetOpcode::G_ASHR: {
2762 // For right shifts, we conservatively do not do the transform if the TRUNC
2763 // has any STORE users. The reason is that if we change the type of the
2764 // shift, we may break the truncstore combine.
2765 //
2766 // TODO: Fix truncstore combine to handle (trunc(lshr (trunc x), k)).
2767 for (auto &User : MRI.use_instructions(Reg: DstReg))
2768 if (User.getOpcode() == TargetOpcode::G_STORE)
2769 return false;
2770
2771 NewShiftTy = getMidVTForTruncRightShiftCombine(ShiftTy: SrcTy, TruncTy: DstTy);
2772 if (NewShiftTy == SrcTy)
2773 return false;
2774
2775 // Make sure we won't lose information by truncating the high bits.
2776 KnownBits Known = VT->getKnownBits(R: SrcMI->getOperand(i: 2).getReg());
2777 if (Known.getMaxValue().ugt(RHS: NewShiftTy.getScalarSizeInBits() -
2778 DstTy.getScalarSizeInBits()))
2779 return false;
2780 break;
2781 }
2782 }
2783
2784 if (!isLegalOrBeforeLegalizer(
2785 Query: {SrcMI->getOpcode(),
2786 {NewShiftTy, TL.getPreferredShiftAmountTy(ShiftValueTy: NewShiftTy)}}))
2787 return false;
2788
2789 MatchInfo = std::make_pair(x&: SrcMI, y&: NewShiftTy);
2790 return true;
2791}
2792
2793void CombinerHelper::applyCombineTruncOfShift(
2794 MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) const {
2795 MachineInstr *ShiftMI = MatchInfo.first;
2796 LLT NewShiftTy = MatchInfo.second;
2797
2798 Register Dst = MI.getOperand(i: 0).getReg();
2799 LLT DstTy = MRI.getType(Reg: Dst);
2800
2801 Register ShiftAmt = ShiftMI->getOperand(i: 2).getReg();
2802 Register ShiftSrc = ShiftMI->getOperand(i: 1).getReg();
2803 ShiftSrc = Builder.buildTrunc(Res: NewShiftTy, Op: ShiftSrc).getReg(Idx: 0);
2804
2805 Register NewShift =
2806 Builder
2807 .buildInstr(Opc: ShiftMI->getOpcode(), DstOps: {NewShiftTy}, SrcOps: {ShiftSrc, ShiftAmt})
2808 .getReg(Idx: 0);
2809
2810 if (NewShiftTy == DstTy)
2811 replaceRegWith(MRI, FromReg: Dst, ToReg: NewShift);
2812 else
2813 Builder.buildTrunc(Res: Dst, Op: NewShift);
2814
2815 eraseInst(MI);
2816}
2817
2818bool CombinerHelper::matchAnyExplicitUseIsUndef(MachineInstr &MI) const {
2819 return any_of(Range: MI.explicit_uses(), P: [this](const MachineOperand &MO) {
2820 return MO.isReg() &&
2821 getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI);
2822 });
2823}
2824
2825bool CombinerHelper::matchAllExplicitUsesAreUndef(MachineInstr &MI) const {
2826 return all_of(Range: MI.explicit_uses(), P: [this](const MachineOperand &MO) {
2827 return !MO.isReg() ||
2828 getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI);
2829 });
2830}
2831
2832bool CombinerHelper::matchUndefShuffleVectorMask(MachineInstr &MI) const {
2833 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
2834 ArrayRef<int> Mask = MI.getOperand(i: 3).getShuffleMask();
2835 return all_of(Range&: Mask, P: [](int Elt) { return Elt < 0; });
2836}
2837
2838bool CombinerHelper::matchUndefStore(MachineInstr &MI) const {
2839 assert(MI.getOpcode() == TargetOpcode::G_STORE);
2840 return getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MI.getOperand(i: 0).getReg(),
2841 MRI);
2842}
2843
2844bool CombinerHelper::matchUndefSelectCmp(MachineInstr &MI) const {
2845 assert(MI.getOpcode() == TargetOpcode::G_SELECT);
2846 return getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MI.getOperand(i: 1).getReg(),
2847 MRI);
2848}
2849
2850bool CombinerHelper::matchInsertExtractVecEltOutOfBounds(
2851 MachineInstr &MI) const {
2852 assert((MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT ||
2853 MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) &&
2854 "Expected an insert/extract element op");
2855 LLT VecTy = MRI.getType(Reg: MI.getOperand(i: 1).getReg());
2856 if (VecTy.isScalableVector())
2857 return false;
2858
2859 unsigned IdxIdx =
2860 MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT ? 2 : 3;
2861 auto Idx = getIConstantVRegVal(VReg: MI.getOperand(i: IdxIdx).getReg(), MRI);
2862 if (!Idx)
2863 return false;
2864 return Idx->getZExtValue() >= VecTy.getNumElements();
2865}
2866
2867bool CombinerHelper::matchConstantSelectCmp(MachineInstr &MI,
2868 unsigned &OpIdx) const {
2869 GSelect &SelMI = cast<GSelect>(Val&: MI);
2870 auto Cst =
2871 isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: SelMI.getCondReg()), MRI);
2872 if (!Cst)
2873 return false;
2874 OpIdx = Cst->isZero() ? 3 : 2;
2875 return true;
2876}
2877
2878void CombinerHelper::eraseInst(MachineInstr &MI) const { MI.eraseFromParent(); }
2879
2880bool CombinerHelper::matchEqualDefs(const MachineOperand &MOP1,
2881 const MachineOperand &MOP2) const {
2882 if (!MOP1.isReg() || !MOP2.isReg())
2883 return false;
2884 auto InstAndDef1 = getDefSrcRegIgnoringCopies(Reg: MOP1.getReg(), MRI);
2885 if (!InstAndDef1)
2886 return false;
2887 auto InstAndDef2 = getDefSrcRegIgnoringCopies(Reg: MOP2.getReg(), MRI);
2888 if (!InstAndDef2)
2889 return false;
2890 MachineInstr *I1 = InstAndDef1->MI;
2891 MachineInstr *I2 = InstAndDef2->MI;
2892
2893 // Handle a case like this:
2894 //
2895 // %0:_(s64), %1:_(s64) = G_UNMERGE_VALUES %2:_(<2 x s64>)
2896 //
2897 // Even though %0 and %1 are produced by the same instruction they are not
2898 // the same values.
2899 if (I1 == I2)
2900 return MOP1.getReg() == MOP2.getReg();
2901
2902 // If we have an instruction which loads or stores, we can't guarantee that
2903 // it is identical.
2904 //
2905 // For example, we may have
2906 //
2907 // %x1 = G_LOAD %addr (load N from @somewhere)
2908 // ...
2909 // call @foo
2910 // ...
2911 // %x2 = G_LOAD %addr (load N from @somewhere)
2912 // ...
2913 // %or = G_OR %x1, %x2
2914 //
2915 // It's possible that @foo will modify whatever lives at the address we're
2916 // loading from. To be safe, let's just assume that all loads and stores
2917 // are different (unless we have something which is guaranteed to not
2918 // change.)
2919 if (I1->mayLoadOrStore() && !I1->isDereferenceableInvariantLoad())
2920 return false;
2921
2922 // If both instructions are loads or stores, they are equal only if both
2923 // are dereferenceable invariant loads with the same number of bits.
2924 if (I1->mayLoadOrStore() && I2->mayLoadOrStore()) {
2925 GLoadStore *LS1 = dyn_cast<GLoadStore>(Val: I1);
2926 GLoadStore *LS2 = dyn_cast<GLoadStore>(Val: I2);
2927 if (!LS1 || !LS2)
2928 return false;
2929
2930 if (!I2->isDereferenceableInvariantLoad() ||
2931 (LS1->getMemSizeInBits() != LS2->getMemSizeInBits()))
2932 return false;
2933 }
2934
2935 // Check for physical registers on the instructions first to avoid cases
2936 // like this:
2937 //
2938 // %a = COPY $physreg
2939 // ...
2940 // SOMETHING implicit-def $physreg
2941 // ...
2942 // %b = COPY $physreg
2943 //
2944 // These copies are not equivalent.
2945 if (any_of(Range: I1->uses(), P: [](const MachineOperand &MO) {
2946 return MO.isReg() && MO.getReg().isPhysical();
2947 })) {
2948 // Check if we have a case like this:
2949 //
2950 // %a = COPY $physreg
2951 // %b = COPY %a
2952 //
2953 // In this case, I1 and I2 will both be equal to %a = COPY $physreg.
2954 // From that, we know that they must have the same value, since they must
2955 // have come from the same COPY.
2956 return I1->isIdenticalTo(Other: *I2);
2957 }
2958
2959 // We don't have any physical registers, so we don't necessarily need the
2960 // same vreg defs.
2961 //
2962 // On the off-chance that there's some target instruction feeding into the
2963 // instruction, let's use produceSameValue instead of isIdenticalTo.
2964 if (Builder.getTII().produceSameValue(MI0: *I1, MI1: *I2, MRI: &MRI)) {
2965    // Handle instructions with multiple defs that produce the same values. The
2966    // values are equal only for defs at the same operand index.
2967 // %0:_(s8), %1:_(s8), %2:_(s8), %3:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>)
2968 // %5:_(s8), %6:_(s8), %7:_(s8), %8:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>)
2969    // I1 and I2 are different instructions but produce the same values:
2970    // %1 and %6 are the same value, while %1 and %7 are not.
2971 return I1->findRegisterDefOperandIdx(Reg: InstAndDef1->Reg, /*TRI=*/nullptr) ==
2972 I2->findRegisterDefOperandIdx(Reg: InstAndDef2->Reg, /*TRI=*/nullptr);
2973 }
2974 return false;
2975}
2976
2977bool CombinerHelper::matchConstantOp(const MachineOperand &MOP,
2978 int64_t C) const {
2979 if (!MOP.isReg())
2980 return false;
2981 auto *MI = MRI.getVRegDef(Reg: MOP.getReg());
2982 auto MaybeCst = isConstantOrConstantSplatVector(MI&: *MI, MRI);
2983 return MaybeCst && MaybeCst->getBitWidth() <= 64 &&
2984 MaybeCst->getSExtValue() == C;
2985}
2986
2987bool CombinerHelper::matchConstantFPOp(const MachineOperand &MOP,
2988 double C) const {
2989 if (!MOP.isReg())
2990 return false;
2991 std::optional<FPValueAndVReg> MaybeCst;
2992 if (!mi_match(R: MOP.getReg(), MRI, P: m_GFCstOrSplat(FPValReg&: MaybeCst)))
2993 return false;
2994
2995 return MaybeCst->Value.isExactlyValue(V: C);
2996}
2997
2998void CombinerHelper::replaceSingleDefInstWithOperand(MachineInstr &MI,
2999 unsigned OpIdx) const {
3000 assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?");
3001 Register OldReg = MI.getOperand(i: 0).getReg();
3002 Register Replacement = MI.getOperand(i: OpIdx).getReg();
3003 assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?");
3004 replaceRegWith(MRI, FromReg: OldReg, ToReg: Replacement);
3005 MI.eraseFromParent();
3006}
3007
3008void CombinerHelper::replaceSingleDefInstWithReg(MachineInstr &MI,
3009 Register Replacement) const {
3010 assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?");
3011 Register OldReg = MI.getOperand(i: 0).getReg();
3012 assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?");
3013 replaceRegWith(MRI, FromReg: OldReg, ToReg: Replacement);
3014 MI.eraseFromParent();
3015}
3016
3017bool CombinerHelper::matchConstantLargerBitWidth(MachineInstr &MI,
3018 unsigned ConstIdx) const {
3019 Register ConstReg = MI.getOperand(i: ConstIdx).getReg();
3020 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
3021
3022 // Get the shift amount
3023 auto VRegAndVal = getIConstantVRegValWithLookThrough(VReg: ConstReg, MRI);
3024 if (!VRegAndVal)
3025 return false;
3026
3027  // Return true if the shift amount is >= the bitwidth.
3028 return (VRegAndVal->Value.uge(RHS: DstTy.getSizeInBits()));
3029}
3030
3031void CombinerHelper::applyFunnelShiftConstantModulo(MachineInstr &MI) const {
3032 assert((MI.getOpcode() == TargetOpcode::G_FSHL ||
3033 MI.getOpcode() == TargetOpcode::G_FSHR) &&
3034 "This is not a funnel shift operation");
3035
3036 Register ConstReg = MI.getOperand(i: 3).getReg();
3037 LLT ConstTy = MRI.getType(Reg: ConstReg);
3038 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
3039
3040 auto VRegAndVal = getIConstantVRegValWithLookThrough(VReg: ConstReg, MRI);
3041 assert((VRegAndVal) && "Value is not a constant");
3042
3043 // Calculate the new Shift Amount = Old Shift Amount % BitWidth
3044 APInt NewConst = VRegAndVal->Value.urem(
3045 RHS: APInt(ConstTy.getSizeInBits(), DstTy.getScalarSizeInBits()));
3046
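  // Rebuild the funnel shift with the reduced shift amount.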
3047 auto NewConstInstr = Builder.buildConstant(Res: ConstTy, Val: NewConst.getZExtValue());
3048 Builder.buildInstr(
3049 Opc: MI.getOpcode(), DstOps: {MI.getOperand(i: 0)},
3050 SrcOps: {MI.getOperand(i: 1), MI.getOperand(i: 2), NewConstInstr.getReg(Idx: 0)});
3051
3052 MI.eraseFromParent();
3053}
3054
3055bool CombinerHelper::matchSelectSameVal(MachineInstr &MI) const {
3056 assert(MI.getOpcode() == TargetOpcode::G_SELECT);
3057 // Match (cond ? x : x)
3058 return matchEqualDefs(MOP1: MI.getOperand(i: 2), MOP2: MI.getOperand(i: 3)) &&
3059 canReplaceReg(DstReg: MI.getOperand(i: 0).getReg(), SrcReg: MI.getOperand(i: 2).getReg(),
3060 MRI);
3061}
3062
3063bool CombinerHelper::matchBinOpSameVal(MachineInstr &MI) const {
3064 return matchEqualDefs(MOP1: MI.getOperand(i: 1), MOP2: MI.getOperand(i: 2)) &&
3065 canReplaceReg(DstReg: MI.getOperand(i: 0).getReg(), SrcReg: MI.getOperand(i: 1).getReg(),
3066 MRI);
3067}
3068
3069bool CombinerHelper::matchOperandIsUndef(MachineInstr &MI,
3070 unsigned OpIdx) const {
3071 MachineOperand &MO = MI.getOperand(i: OpIdx);
3072 return MO.isReg() &&
3073 getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI);
3074}
3075
3076bool CombinerHelper::matchOperandIsKnownToBeAPowerOfTwo(MachineInstr &MI,
3077 unsigned OpIdx) const {
3078 MachineOperand &MO = MI.getOperand(i: OpIdx);
3079 return isKnownToBeAPowerOfTwo(Val: MO.getReg(), MRI, ValueTracking: VT);
3080}
3081
3082void CombinerHelper::replaceInstWithFConstant(MachineInstr &MI,
3083 double C) const {
3084 assert(MI.getNumDefs() == 1 && "Expected only one def?");
3085 Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: C);
3086 MI.eraseFromParent();
3087}
3088
3089void CombinerHelper::replaceInstWithConstant(MachineInstr &MI,
3090 int64_t C) const {
3091 assert(MI.getNumDefs() == 1 && "Expected only one def?");
3092 Builder.buildConstant(Res: MI.getOperand(i: 0), Val: C);
3093 MI.eraseFromParent();
3094}
3095
3096void CombinerHelper::replaceInstWithConstant(MachineInstr &MI, APInt C) const {
3097 assert(MI.getNumDefs() == 1 && "Expected only one def?");
3098 Builder.buildConstant(Res: MI.getOperand(i: 0), Val: C);
3099 MI.eraseFromParent();
3100}
3101
3102void CombinerHelper::replaceInstWithFConstant(MachineInstr &MI,
3103 ConstantFP *CFP) const {
3104 assert(MI.getNumDefs() == 1 && "Expected only one def?");
3105 Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: CFP->getValueAPF());
3106 MI.eraseFromParent();
3107}
3108
3109void CombinerHelper::replaceInstWithUndef(MachineInstr &MI) const {
3110 assert(MI.getNumDefs() == 1 && "Expected only one def?");
3111 Builder.buildUndef(Res: MI.getOperand(i: 0));
3112 MI.eraseFromParent();
3113}
3114
3115bool CombinerHelper::matchSimplifyAddToSub(
3116 MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) const {
3117 Register LHS = MI.getOperand(i: 1).getReg();
3118 Register RHS = MI.getOperand(i: 2).getReg();
3119 Register &NewLHS = std::get<0>(t&: MatchInfo);
3120 Register &NewRHS = std::get<1>(t&: MatchInfo);
3121
3122 // Helper lambda to check for opportunities for
3123 // ((0-A) + B) -> B - A
3124 // (A + (0-B)) -> A - B
3125 auto CheckFold = [&](Register &MaybeSub, Register &MaybeNewLHS) {
3126 if (!mi_match(R: MaybeSub, MRI, P: m_Neg(Src: m_Reg(R&: NewRHS))))
3127 return false;
3128 NewLHS = MaybeNewLHS;
3129 return true;
3130 };
3131
3132 return CheckFold(LHS, RHS) || CheckFold(RHS, LHS);
3133}
3134
3135bool CombinerHelper::matchCombineInsertVecElts(
3136 MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) const {
3137 assert(MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT &&
3138 "Invalid opcode");
3139 Register DstReg = MI.getOperand(i: 0).getReg();
3140 LLT DstTy = MRI.getType(Reg: DstReg);
3141 assert(DstTy.isVector() && "Invalid G_INSERT_VECTOR_ELT?");
3142
3143 if (DstTy.isScalableVector())
3144 return false;
3145
3146 unsigned NumElts = DstTy.getNumElements();
3147 // If this MI is part of a sequence of insert_vec_elts, then
3148 // don't do the combine in the middle of the sequence.
3149 if (MRI.hasOneUse(RegNo: DstReg) && MRI.use_instr_begin(RegNo: DstReg)->getOpcode() ==
3150 TargetOpcode::G_INSERT_VECTOR_ELT)
3151 return false;
3152 MachineInstr *CurrInst = &MI;
3153 MachineInstr *TmpInst;
3154 int64_t IntImm;
3155 Register TmpReg;
3156 MatchInfo.resize(N: NumElts);
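  // Walk up the chain of G_INSERT_VECTOR_ELTs, recording the last value written
  // to each constant index.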
3157 while (mi_match(
3158 R: CurrInst->getOperand(i: 0).getReg(), MRI,
3159 P: m_GInsertVecElt(Src0: m_MInstr(MI&: TmpInst), Src1: m_Reg(R&: TmpReg), Src2: m_ICst(Cst&: IntImm)))) {
3160 if (IntImm >= NumElts || IntImm < 0)
3161 return false;
3162 if (!MatchInfo[IntImm])
3163 MatchInfo[IntImm] = TmpReg;
3164 CurrInst = TmpInst;
3165 }
3166 // Variable index.
3167 if (CurrInst->getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT)
3168 return false;
3169 if (TmpInst->getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
3170 for (unsigned I = 1; I < TmpInst->getNumOperands(); ++I) {
3171 if (!MatchInfo[I - 1].isValid())
3172 MatchInfo[I - 1] = TmpInst->getOperand(i: I).getReg();
3173 }
3174 return true;
3175 }
3176 // If we didn't end in a G_IMPLICIT_DEF and the source is not fully
3177 // overwritten, bail out.
3178 return TmpInst->getOpcode() == TargetOpcode::G_IMPLICIT_DEF ||
3179 all_of(Range&: MatchInfo, P: [](Register Reg) { return !!Reg; });
3180}
3181
3182void CombinerHelper::applyCombineInsertVecElts(
3183 MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) const {
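  // Lazily create a single scalar undef to fill elements that were never
  // written.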
3184 Register UndefReg;
3185 auto GetUndef = [&]() {
3186 if (UndefReg)
3187 return UndefReg;
3188 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
3189 UndefReg = Builder.buildUndef(Res: DstTy.getScalarType()).getReg(Idx: 0);
3190 return UndefReg;
3191 };
3192 for (Register &Reg : MatchInfo) {
3193 if (!Reg)
3194 Reg = GetUndef();
3195 }
3196 Builder.buildBuildVector(Res: MI.getOperand(i: 0).getReg(), Ops: MatchInfo);
3197 MI.eraseFromParent();
3198}
3199
3200void CombinerHelper::applySimplifyAddToSub(
3201 MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) const {
3202 Register SubLHS, SubRHS;
3203 std::tie(args&: SubLHS, args&: SubRHS) = MatchInfo;
3204 Builder.buildSub(Dst: MI.getOperand(i: 0).getReg(), Src0: SubLHS, Src1: SubRHS);
3205 MI.eraseFromParent();
3206}
3207
3208bool CombinerHelper::matchBinopWithNegInner(Register MInner, Register Other,
3209 unsigned RootOpc, Register Dst,
3210 LLT Ty,
3211 BuildFnTy &MatchInfo) const {
3212  // Helper function for matchBinopWithNeg: tries to match one commuted form
3213  // of `a bitwiseop (~b +/- c)` -> `a bitwiseop ~(b -/+ c)`.
3214 MachineInstr *InnerDef = MRI.getVRegDef(Reg: MInner);
3215 if (!InnerDef)
3216 return false;
3217
3218 unsigned InnerOpc = InnerDef->getOpcode();
3219 if (InnerOpc != TargetOpcode::G_ADD && InnerOpc != TargetOpcode::G_SUB)
3220 return false;
3221
3222 if (!MRI.hasOneNonDBGUse(RegNo: MInner))
3223 return false;
3224
3225 Register InnerLHS = InnerDef->getOperand(i: 1).getReg();
3226 Register InnerRHS = InnerDef->getOperand(i: 2).getReg();
3227 Register NotSrc;
3228 Register B, C;
3229
3230 // Check if either operand is ~b
3231 auto TryMatch = [&](Register MaybeNot, Register Other) {
3232 if (mi_match(R: MaybeNot, MRI, P: m_Not(Src: m_Reg(R&: NotSrc)))) {
3233 if (!MRI.hasOneNonDBGUse(RegNo: MaybeNot))
3234 return false;
3235 B = NotSrc;
3236 C = Other;
3237 return true;
3238 }
3239 return false;
3240 };
3241
3242 if (!TryMatch(InnerLHS, InnerRHS) && !TryMatch(InnerRHS, InnerLHS))
3243 return false;
3244
3245 // Flip add/sub
3246 unsigned FlippedOpc = (InnerOpc == TargetOpcode::G_ADD) ? TargetOpcode::G_SUB
3247 : TargetOpcode::G_ADD;
3248
3249 Register A = Other;
3250 MatchInfo = [=](MachineIRBuilder &Builder) {
3251 auto NewInner = Builder.buildInstr(Opc: FlippedOpc, DstOps: {Ty}, SrcOps: {B, C});
3252 auto NewNot = Builder.buildNot(Dst: Ty, Src0: NewInner);
3253 Builder.buildInstr(Opc: RootOpc, DstOps: {Dst}, SrcOps: {A, NewNot});
3254 };
3255 return true;
3256}
3257
3258bool CombinerHelper::matchBinopWithNeg(MachineInstr &MI,
3259 BuildFnTy &MatchInfo) const {
3260 // Fold `a bitwiseop (~b +/- c)` -> `a bitwiseop ~(b -/+ c)`
3261 // Root MI is one of G_AND, G_OR, G_XOR.
3262  // We also look for commuted forms of the operation. The pattern shouldn't
3263  // apply if the inner operations have more than one use.
3264
3265 unsigned RootOpc = MI.getOpcode();
3266 Register Dst = MI.getOperand(i: 0).getReg();
3267 LLT Ty = MRI.getType(Reg: Dst);
3268
3269 Register LHS = MI.getOperand(i: 1).getReg();
3270 Register RHS = MI.getOperand(i: 2).getReg();
3271 // Check the commuted and uncommuted forms of the operation.
3272 return matchBinopWithNegInner(MInner: LHS, Other: RHS, RootOpc, Dst, Ty, MatchInfo) ||
3273 matchBinopWithNegInner(MInner: RHS, Other: LHS, RootOpc, Dst, Ty, MatchInfo);
3274}
3275
3276bool CombinerHelper::matchHoistLogicOpWithSameOpcodeHands(
3277 MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) const {
3278 // Matches: logic (hand x, ...), (hand y, ...) -> hand (logic x, y), ...
3279 //
3280 // Creates the new hand + logic instruction (but does not insert them.)
3281 //
3282 // On success, MatchInfo is populated with the new instructions. These are
3283 // inserted in applyHoistLogicOpWithSameOpcodeHands.
3284 unsigned LogicOpcode = MI.getOpcode();
3285 assert(LogicOpcode == TargetOpcode::G_AND ||
3286 LogicOpcode == TargetOpcode::G_OR ||
3287 LogicOpcode == TargetOpcode::G_XOR);
3288 MachineIRBuilder MIB(MI);
3289 Register Dst = MI.getOperand(i: 0).getReg();
3290 Register LHSReg = MI.getOperand(i: 1).getReg();
3291 Register RHSReg = MI.getOperand(i: 2).getReg();
3292
3293 // Don't recompute anything.
3294 if (!MRI.hasOneNonDBGUse(RegNo: LHSReg) || !MRI.hasOneNonDBGUse(RegNo: RHSReg))
3295 return false;
3296
3297 // Make sure we have (hand x, ...), (hand y, ...)
3298 MachineInstr *LeftHandInst = getDefIgnoringCopies(Reg: LHSReg, MRI);
3299 MachineInstr *RightHandInst = getDefIgnoringCopies(Reg: RHSReg, MRI);
3300 if (!LeftHandInst || !RightHandInst)
3301 return false;
3302 unsigned HandOpcode = LeftHandInst->getOpcode();
3303 if (HandOpcode != RightHandInst->getOpcode())
3304 return false;
3305 if (LeftHandInst->getNumOperands() < 2 ||
3306 !LeftHandInst->getOperand(i: 1).isReg() ||
3307 RightHandInst->getNumOperands() < 2 ||
3308 !RightHandInst->getOperand(i: 1).isReg())
3309 return false;
3310
3311 // Make sure the types match up, and if we're doing this post-legalization,
3312 // we end up with legal types.
3313 Register X = LeftHandInst->getOperand(i: 1).getReg();
3314 Register Y = RightHandInst->getOperand(i: 1).getReg();
3315 LLT XTy = MRI.getType(Reg: X);
3316 LLT YTy = MRI.getType(Reg: Y);
3317 if (!XTy.isValid() || XTy != YTy)
3318 return false;
3319
3320 // Optional extra source register.
3321 Register ExtraHandOpSrcReg;
3322 switch (HandOpcode) {
3323 default:
3324 return false;
3325 case TargetOpcode::G_ANYEXT:
3326 case TargetOpcode::G_SEXT:
3327 case TargetOpcode::G_ZEXT: {
3328 // Match: logic (ext X), (ext Y) --> ext (logic X, Y)
3329 break;
3330 }
3331 case TargetOpcode::G_TRUNC: {
3332 // Match: logic (trunc X), (trunc Y) -> trunc (logic X, Y)
3333 const MachineFunction *MF = MI.getMF();
3334 LLVMContext &Ctx = MF->getFunction().getContext();
3335
3336 LLT DstTy = MRI.getType(Reg: Dst);
3337 const TargetLowering &TLI = getTargetLowering();
3338
3339 // Be extra careful sinking truncate. If it's free, there's no benefit in
3340 // widening a binop.
3341 if (TLI.isZExtFree(FromTy: DstTy, ToTy: XTy, Ctx) && TLI.isTruncateFree(FromTy: XTy, ToTy: DstTy, Ctx))
3342 return false;
3343 break;
3344 }
3345 case TargetOpcode::G_AND:
3346 case TargetOpcode::G_ASHR:
3347 case TargetOpcode::G_LSHR:
3348 case TargetOpcode::G_SHL: {
3349 // Match: logic (binop x, z), (binop y, z) -> binop (logic x, y), z
3350 MachineOperand &ZOp = LeftHandInst->getOperand(i: 2);
3351 if (!matchEqualDefs(MOP1: ZOp, MOP2: RightHandInst->getOperand(i: 2)))
3352 return false;
3353 ExtraHandOpSrcReg = ZOp.getReg();
3354 break;
3355 }
3356 }
3357
3358 if (!isLegalOrBeforeLegalizer(Query: {LogicOpcode, {XTy, YTy}}))
3359 return false;
3360
3361 // Record the steps to build the new instructions.
3362 //
3363 // Steps to build (logic x, y)
3364 auto NewLogicDst = MRI.createGenericVirtualRegister(Ty: XTy);
3365 OperandBuildSteps LogicBuildSteps = {
3366 [=](MachineInstrBuilder &MIB) { MIB.addDef(RegNo: NewLogicDst); },
3367 [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: X); },
3368 [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: Y); }};
3369 InstructionBuildSteps LogicSteps(LogicOpcode, LogicBuildSteps);
3370
3371 // Steps to build hand (logic x, y), ...z
3372 OperandBuildSteps HandBuildSteps = {
3373 [=](MachineInstrBuilder &MIB) { MIB.addDef(RegNo: Dst); },
3374 [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: NewLogicDst); }};
3375 if (ExtraHandOpSrcReg.isValid())
3376 HandBuildSteps.push_back(
3377 Elt: [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: ExtraHandOpSrcReg); });
3378 InstructionBuildSteps HandSteps(HandOpcode, HandBuildSteps);
3379
3380 MatchInfo = InstructionStepsMatchInfo({LogicSteps, HandSteps});
3381 return true;
3382}
3383
3384void CombinerHelper::applyBuildInstructionSteps(
3385 MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) const {
3386 assert(MatchInfo.InstrsToBuild.size() &&
3387 "Expected at least one instr to build?");
3388 for (auto &InstrToBuild : MatchInfo.InstrsToBuild) {
3389 assert(InstrToBuild.Opcode && "Expected a valid opcode?");
3390 assert(InstrToBuild.OperandFns.size() && "Expected at least one operand?");
3391 MachineInstrBuilder Instr = Builder.buildInstr(Opcode: InstrToBuild.Opcode);
3392 for (auto &OperandFn : InstrToBuild.OperandFns)
3393 OperandFn(Instr);
3394 }
3395 MI.eraseFromParent();
3396}
3397
3398bool CombinerHelper::matchAshrShlToSextInreg(
3399 MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) const {
3400 assert(MI.getOpcode() == TargetOpcode::G_ASHR);
3401 int64_t ShlCst, AshrCst;
3402 Register Src;
3403 if (!mi_match(R: MI.getOperand(i: 0).getReg(), MRI,
3404 P: m_GAShr(L: m_GShl(L: m_Reg(R&: Src), R: m_ICstOrSplat(Cst&: ShlCst)),
3405 R: m_ICstOrSplat(Cst&: AshrCst))))
3406 return false;
3407 if (ShlCst != AshrCst)
3408 return false;
3409 if (!isLegalOrBeforeLegalizer(
3410 Query: {TargetOpcode::G_SEXT_INREG, {MRI.getType(Reg: Src)}}))
3411 return false;
3412 MatchInfo = std::make_tuple(args&: Src, args&: ShlCst);
3413 return true;
3414}
3415
3416void CombinerHelper::applyAshShlToSextInreg(
3417 MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) const {
3418 assert(MI.getOpcode() == TargetOpcode::G_ASHR);
3419 Register Src;
3420 int64_t ShiftAmt;
3421 std::tie(args&: Src, args&: ShiftAmt) = MatchInfo;
3422 unsigned Size = MRI.getType(Reg: Src).getScalarSizeInBits();
3423 Builder.buildSExtInReg(Res: MI.getOperand(i: 0).getReg(), Op: Src, ImmOp: Size - ShiftAmt);
3424 MI.eraseFromParent();
3425}
3426
3427/// and(and(x, C1), C2) -> C1&C2 ? and(x, C1&C2) : 0
3428bool CombinerHelper::matchOverlappingAnd(
3429 MachineInstr &MI,
3430 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
3431 assert(MI.getOpcode() == TargetOpcode::G_AND);
3432
3433 Register Dst = MI.getOperand(i: 0).getReg();
3434 LLT Ty = MRI.getType(Reg: Dst);
3435
3436 Register R;
3437 int64_t C1;
3438 int64_t C2;
3439 if (!mi_match(
3440 R: Dst, MRI,
3441 P: m_GAnd(L: m_GAnd(L: m_Reg(R), R: m_ICst(Cst&: C1)), R: m_ICst(Cst&: C2))))
3442 return false;
3443
3444 MatchInfo = [=](MachineIRBuilder &B) {
3445 if (C1 & C2) {
3446 B.buildAnd(Dst, Src0: R, Src1: B.buildConstant(Res: Ty, Val: C1 & C2));
3447 return;
3448 }
3449 auto Zero = B.buildConstant(Res: Ty, Val: 0);
3450 replaceRegWith(MRI, FromReg: Dst, ToReg: Zero->getOperand(i: 0).getReg());
3451 };
3452 return true;
3453}
3454
3455bool CombinerHelper::matchRedundantAnd(MachineInstr &MI,
3456 Register &Replacement) const {
3457 // Given
3458 //
3459 // %y:_(sN) = G_SOMETHING
3460 // %x:_(sN) = G_SOMETHING
3461 // %res:_(sN) = G_AND %x, %y
3462 //
3463 // Eliminate the G_AND when it is known that x & y == x or x & y == y.
3464 //
3465 // Patterns like this can appear as a result of legalization. E.g.
3466 //
3467 // %cmp:_(s32) = G_ICMP intpred(pred), %x(s32), %y
3468 // %one:_(s32) = G_CONSTANT i32 1
3469 // %and:_(s32) = G_AND %cmp, %one
3470 //
3471 // In this case, G_ICMP only produces a single bit, so x & 1 == x.
3472 assert(MI.getOpcode() == TargetOpcode::G_AND);
3473 if (!VT)
3474 return false;
3475
3476 Register AndDst = MI.getOperand(i: 0).getReg();
3477 Register LHS = MI.getOperand(i: 1).getReg();
3478 Register RHS = MI.getOperand(i: 2).getReg();
3479
3480 // Check the RHS (maybe a constant) first, and if we have no KnownBits there,
3481 // we can't do anything. If we do, then it depends on whether we have
3482 // KnownBits on the LHS.
3483 KnownBits RHSBits = VT->getKnownBits(R: RHS);
3484 if (RHSBits.isUnknown())
3485 return false;
3486
3487 KnownBits LHSBits = VT->getKnownBits(R: LHS);
3488
3489 // Check that x & Mask == x.
3490 // x & 1 == x, always
3491 // x & 0 == x, only if x is also 0
3492 // Meaning Mask has no effect if every bit is either one in Mask or zero in x.
3493 //
3494 // Check if we can replace AndDst with the LHS of the G_AND
3495 if (canReplaceReg(DstReg: AndDst, SrcReg: LHS, MRI) &&
3496 (LHSBits.Zero | RHSBits.One).isAllOnes()) {
3497 Replacement = LHS;
3498 return true;
3499 }
3500
3501 // Check if we can replace AndDst with the RHS of the G_AND
3502 if (canReplaceReg(DstReg: AndDst, SrcReg: RHS, MRI) &&
3503 (LHSBits.One | RHSBits.Zero).isAllOnes()) {
3504 Replacement = RHS;
3505 return true;
3506 }
3507
3508 return false;
3509}
3510
3511bool CombinerHelper::matchRedundantOr(MachineInstr &MI,
3512 Register &Replacement) const {
3513 // Given
3514 //
3515 // %y:_(sN) = G_SOMETHING
3516 // %x:_(sN) = G_SOMETHING
3517 // %res:_(sN) = G_OR %x, %y
3518 //
3519 // Eliminate the G_OR when it is known that x | y == x or x | y == y.
3520 assert(MI.getOpcode() == TargetOpcode::G_OR);
3521 if (!VT)
3522 return false;
3523
3524 Register OrDst = MI.getOperand(i: 0).getReg();
3525 Register LHS = MI.getOperand(i: 1).getReg();
3526 Register RHS = MI.getOperand(i: 2).getReg();
3527
3528 KnownBits LHSBits = VT->getKnownBits(R: LHS);
3529 KnownBits RHSBits = VT->getKnownBits(R: RHS);
3530
3531 // Check that x | Mask == x.
3532 // x | 0 == x, always
3533 // x | 1 == x, only if x is also 1
3534 // Meaning Mask has no effect if every bit is either zero in Mask or one in x.
3535 //
3536 // Check if we can replace OrDst with the LHS of the G_OR
3537 if (canReplaceReg(DstReg: OrDst, SrcReg: LHS, MRI) &&
3538 (LHSBits.One | RHSBits.Zero).isAllOnes()) {
3539 Replacement = LHS;
3540 return true;
3541 }
3542
3543 // Check if we can replace OrDst with the RHS of the G_OR
3544 if (canReplaceReg(DstReg: OrDst, SrcReg: RHS, MRI) &&
3545 (LHSBits.Zero | RHSBits.One).isAllOnes()) {
3546 Replacement = RHS;
3547 return true;
3548 }
3549
3550 return false;
3551}
3552
3553bool CombinerHelper::matchRedundantSExtInReg(MachineInstr &MI) const {
3554 // If the input is already sign extended, just drop the extension.
3555 Register Src = MI.getOperand(i: 1).getReg();
3556 unsigned ExtBits = MI.getOperand(i: 2).getImm();
3557 unsigned TypeSize = MRI.getType(Reg: Src).getScalarSizeInBits();
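  // The extension is a no-op if the top (TypeSize - ExtBits + 1) bits are
  // already sign bits.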
3558 return VT->computeNumSignBits(R: Src) >= (TypeSize - ExtBits + 1);
3559}
3560
3561static bool isConstValidTrue(const TargetLowering &TLI, unsigned ScalarSizeBits,
3562 int64_t Cst, bool IsVector, bool IsFP) {
3563 // For i1, Cst will always be -1 regardless of boolean contents.
3564 return (ScalarSizeBits == 1 && Cst == -1) ||
3565 isConstTrueVal(TLI, Val: Cst, IsVector, IsFP);
3566}
3567
3568// This pattern aims to match the following shape to avoid extra mov
3569// instructions
3570// G_BUILD_VECTOR(
3571// G_UNMERGE_VALUES(src, 0)
3572// G_UNMERGE_VALUES(src, 1)
3573// G_IMPLICIT_DEF
3574// G_IMPLICIT_DEF
3575// )
3576// ->
3577// G_CONCAT_VECTORS(
3578// src,
3579// undef
3580// )
3581bool CombinerHelper::matchCombineBuildUnmerge(MachineInstr &MI,
3582 MachineRegisterInfo &MRI,
3583 Register &UnmergeSrc) const {
3584 auto &BV = cast<GBuildVector>(Val&: MI);
3585
3586 unsigned BuildUseCount = BV.getNumSources();
3587 if (BuildUseCount % 2 != 0)
3588 return false;
3589
3590 unsigned NumUnmerge = BuildUseCount / 2;
3591
3592 auto *Unmerge = getOpcodeDef<GUnmerge>(Reg: BV.getSourceReg(I: 0), MRI);
3593
3594  // Check that the first source is an unmerge with the expected number of
3595  // defs.
3596 if (!Unmerge || Unmerge->getNumDefs() != NumUnmerge)
3597 return false;
3598
3599 UnmergeSrc = Unmerge->getSourceReg();
3600
3601 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
3602 LLT UnmergeSrcTy = MRI.getType(Reg: UnmergeSrc);
3603
3604 if (!UnmergeSrcTy.isVector())
3605 return false;
3606
3607 // Ensure we only generate legal instructions post-legalizer
3608 if (!IsPreLegalize &&
3609 !isLegal(Query: {TargetOpcode::G_CONCAT_VECTORS, {DstTy, UnmergeSrcTy}}))
3610 return false;
3611
3612 // Check that all of the operands before the midpoint come from the same
3613 // unmerge and are in the same order as they are used in the build_vector
3614 for (unsigned I = 0; I < NumUnmerge; ++I) {
3615 auto MaybeUnmergeReg = BV.getSourceReg(I);
3616 auto *LoopUnmerge = getOpcodeDef<GUnmerge>(Reg: MaybeUnmergeReg, MRI);
3617
3618 if (!LoopUnmerge || LoopUnmerge != Unmerge)
3619 return false;
3620
3621 if (LoopUnmerge->getOperand(i: I).getReg() != MaybeUnmergeReg)
3622 return false;
3623 }
3624
3625 // Check that all of the unmerged values are used
3626 if (Unmerge->getNumDefs() != NumUnmerge)
3627 return false;
3628
3629 // Check that all of the operands after the mid point are undefs.
3630 for (unsigned I = NumUnmerge; I < BuildUseCount; ++I) {
3631 auto *Undef = getDefIgnoringCopies(Reg: BV.getSourceReg(I), MRI);
3632
3633 if (Undef->getOpcode() != TargetOpcode::G_IMPLICIT_DEF)
3634 return false;
3635 }
3636
3637 return true;
3638}
3639
3640void CombinerHelper::applyCombineBuildUnmerge(MachineInstr &MI,
3641 MachineRegisterInfo &MRI,
3642 MachineIRBuilder &B,
3643 Register &UnmergeSrc) const {
3644 assert(UnmergeSrc && "Expected there to be one matching G_UNMERGE_VALUES");
3645 B.setInstrAndDebugLoc(MI);
3646
3647 Register UndefVec = B.buildUndef(Res: MRI.getType(Reg: UnmergeSrc)).getReg(Idx: 0);
3648 B.buildConcatVectors(Res: MI.getOperand(i: 0), Ops: {UnmergeSrc, UndefVec});
3649
3650 MI.eraseFromParent();
3651}
3652
3653// This combine tries to reduce the number of scalarised G_TRUNC instructions by
3654// using vector truncates instead
3655//
3656// EXAMPLE:
3657// %a(i32), %b(i32) = G_UNMERGE_VALUES %src(<2 x i32>)
3658// %T_a(i16) = G_TRUNC %a(i32)
3659// %T_b(i16) = G_TRUNC %b(i32)
3660// %Undef(i16) = G_IMPLICIT_DEF(i16)
3661 // %dst(v4i16) = G_BUILD_VECTOR %T_a(i16), %T_b(i16), %Undef(i16), %Undef(i16)
3662//
3663// ===>
3664// %Undef(<2 x i32>) = G_IMPLICIT_DEF(<2 x i32>)
3665// %Mid(<4 x s32>) = G_CONCAT_VECTORS %src(<2 x i32>), %Undef(<2 x i32>)
3666// %dst(<4 x s16>) = G_TRUNC %Mid(<4 x s32>)
3667//
3668// Only matches sources made up of G_TRUNCs followed by G_IMPLICIT_DEFs
3669bool CombinerHelper::matchUseVectorTruncate(MachineInstr &MI,
3670 Register &MatchInfo) const {
3671 auto BuildMI = cast<GBuildVector>(Val: &MI);
3672 unsigned NumOperands = BuildMI->getNumSources();
3673 LLT DstTy = MRI.getType(Reg: BuildMI->getReg(Idx: 0));
3674
3675 // Check the G_BUILD_VECTOR sources
3676 unsigned I;
3677 MachineInstr *UnmergeMI = nullptr;
3678
3679 // Check all source TRUNCs come from the same UNMERGE instruction
3680 // and that the element order matches (BUILD_VECTOR position I
3681 // corresponds to UNMERGE result I)
3682 for (I = 0; I < NumOperands; ++I) {
3683 auto SrcMI = MRI.getVRegDef(Reg: BuildMI->getSourceReg(I));
3684 auto SrcMIOpc = SrcMI->getOpcode();
3685
3686 // Check if the G_TRUNC instructions all come from the same MI
3687 if (SrcMIOpc == TargetOpcode::G_TRUNC) {
3688 Register TruncSrcReg = SrcMI->getOperand(i: 1).getReg();
3689 if (!UnmergeMI) {
3690 UnmergeMI = MRI.getVRegDef(Reg: TruncSrcReg);
3691 if (UnmergeMI->getOpcode() != TargetOpcode::G_UNMERGE_VALUES)
3692 return false;
3693 } else {
3694 auto UnmergeSrcMI = MRI.getVRegDef(Reg: TruncSrcReg);
3695 if (UnmergeMI != UnmergeSrcMI)
3696 return false;
3697 }
3698      // Verify element ordering: BUILD_VECTOR position I must use
3699      // UNMERGE result I; otherwise the fold would not preserve the element order.
3700 if (UnmergeMI->getOperand(i: I).getReg() != TruncSrcReg)
3701 return false;
3702 } else {
3703 break;
3704 }
3705 }
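  // Require at least two leading G_TRUNC sources; otherwise there is nothing
  // worth turning into a vector truncate.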
3706 if (I < 2)
3707 return false;
3708
3709 // Check the remaining source elements are only G_IMPLICIT_DEF
3710 for (; I < NumOperands; ++I) {
3711 auto SrcMI = MRI.getVRegDef(Reg: BuildMI->getSourceReg(I));
3712 auto SrcMIOpc = SrcMI->getOpcode();
3713
3714 if (SrcMIOpc != TargetOpcode::G_IMPLICIT_DEF)
3715 return false;
3716 }
3717
3718 // Check the size of unmerge source
3719 MatchInfo = cast<GUnmerge>(Val: UnmergeMI)->getSourceReg();
3720 LLT UnmergeSrcTy = MRI.getType(Reg: MatchInfo);
3721 if (!DstTy.getElementCount().isKnownMultipleOf(RHS: UnmergeSrcTy.getNumElements()))
3722 return false;
3723
3724 // Check the unmerge source and destination element types match
3725 LLT UnmergeSrcEltTy = UnmergeSrcTy.getElementType();
3726 Register UnmergeDstReg = UnmergeMI->getOperand(i: 0).getReg();
3727 LLT UnmergeDstEltTy = MRI.getType(Reg: UnmergeDstReg);
3728 if (UnmergeSrcEltTy != UnmergeDstEltTy)
3729 return false;
3730
3731 // Only generate legal instructions post-legalizer
3732 if (!IsPreLegalize) {
3733 LLT MidTy = DstTy.changeElementType(NewEltTy: UnmergeSrcTy.getScalarType());
3734
3735 if (DstTy.getElementCount() != UnmergeSrcTy.getElementCount() &&
3736 !isLegal(Query: {TargetOpcode::G_CONCAT_VECTORS, {MidTy, UnmergeSrcTy}}))
3737 return false;
3738
3739 if (!isLegal(Query: {TargetOpcode::G_TRUNC, {DstTy, MidTy}}))
3740 return false;
3741 }
3742
3743 return true;
3744}
3745
3746void CombinerHelper::applyUseVectorTruncate(MachineInstr &MI,
3747 Register &MatchInfo) const {
3748 Register MidReg;
3749 auto BuildMI = cast<GBuildVector>(Val: &MI);
3750 Register DstReg = BuildMI->getReg(Idx: 0);
3751 LLT DstTy = MRI.getType(Reg: DstReg);
3752 LLT UnmergeSrcTy = MRI.getType(Reg: MatchInfo);
3753 unsigned DstTyNumElt = DstTy.getNumElements();
3754 unsigned UnmergeSrcTyNumElt = UnmergeSrcTy.getNumElements();
3755
3756 // No need to pad vector if only G_TRUNC is needed
3757 if (DstTyNumElt / UnmergeSrcTyNumElt == 1) {
3758 MidReg = MatchInfo;
3759 } else {
3760 Register UndefReg = Builder.buildUndef(Res: UnmergeSrcTy).getReg(Idx: 0);
3761 SmallVector<Register> ConcatRegs = {MatchInfo};
3762 for (unsigned I = 1; I < DstTyNumElt / UnmergeSrcTyNumElt; ++I)
3763 ConcatRegs.push_back(Elt: UndefReg);
3764
3765 auto MidTy = DstTy.changeElementType(NewEltTy: UnmergeSrcTy.getScalarType());
3766 MidReg = Builder.buildConcatVectors(Res: MidTy, Ops: ConcatRegs).getReg(Idx: 0);
3767 }
3768
3769 Builder.buildTrunc(Res: DstReg, Op: MidReg);
3770 MI.eraseFromParent();
3771}
3772
3773bool CombinerHelper::matchNotCmp(
3774 MachineInstr &MI, SmallVectorImpl<Register> &RegsToNegate) const {
3775 assert(MI.getOpcode() == TargetOpcode::G_XOR);
3776 LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
3777 const auto &TLI = *Builder.getMF().getSubtarget().getTargetLowering();
3778 Register XorSrc;
3779 Register CstReg;
3780  // We match xor(src, cst) here; cst is verified to be 'true' further below.
3781 if (!mi_match(R: MI.getOperand(i: 0).getReg(), MRI,
3782 P: m_GXor(L: m_Reg(R&: XorSrc), R: m_Reg(R&: CstReg))))
3783 return false;
3784
3785 if (!MRI.hasOneNonDBGUse(RegNo: XorSrc))
3786 return false;
3787
3788 // Check that XorSrc is the root of a tree of comparisons combined with ANDs
3789  // and ORs. The suffix of RegsToNegate starting from index I is used as a work
3790 // list of tree nodes to visit.
3791 RegsToNegate.push_back(Elt: XorSrc);
3792 // Remember whether the comparisons are all integer or all floating point.
3793 bool IsInt = false;
3794 bool IsFP = false;
3795 for (unsigned I = 0; I < RegsToNegate.size(); ++I) {
3796 Register Reg = RegsToNegate[I];
3797 if (!MRI.hasOneNonDBGUse(RegNo: Reg))
3798 return false;
3799 MachineInstr *Def = MRI.getVRegDef(Reg);
3800 switch (Def->getOpcode()) {
3801 default:
3802 // Don't match if the tree contains anything other than ANDs, ORs and
3803 // comparisons.
3804 return false;
3805 case TargetOpcode::G_ICMP:
3806 if (IsFP)
3807 return false;
3808 IsInt = true;
3809 // When we apply the combine we will invert the predicate.
3810 break;
3811 case TargetOpcode::G_FCMP:
3812 if (IsInt)
3813 return false;
3814 IsFP = true;
3815 // When we apply the combine we will invert the predicate.
3816 break;
3817 case TargetOpcode::G_AND:
3818 case TargetOpcode::G_OR:
3819 // Implement De Morgan's laws:
3820 // ~(x & y) -> ~x | ~y
3821 // ~(x | y) -> ~x & ~y
3822 // When we apply the combine we will change the opcode and recursively
3823 // negate the operands.
3824 RegsToNegate.push_back(Elt: Def->getOperand(i: 1).getReg());
3825 RegsToNegate.push_back(Elt: Def->getOperand(i: 2).getReg());
3826 break;
3827 }
3828 }
3829
3830 // Now we know whether the comparisons are integer or floating point, check
3831 // the constant in the xor.
3832 int64_t Cst;
3833 if (Ty.isVector()) {
3834 MachineInstr *CstDef = MRI.getVRegDef(Reg: CstReg);
3835 auto MaybeCst = getIConstantSplatSExtVal(MI: *CstDef, MRI);
3836 if (!MaybeCst)
3837 return false;
3838 if (!isConstValidTrue(TLI, ScalarSizeBits: Ty.getScalarSizeInBits(), Cst: *MaybeCst, IsVector: true, IsFP))
3839 return false;
3840 } else {
3841 if (!mi_match(R: CstReg, MRI, P: m_ICst(Cst)))
3842 return false;
3843 if (!isConstValidTrue(TLI, ScalarSizeBits: Ty.getSizeInBits(), Cst, IsVector: false, IsFP))
3844 return false;
3845 }
3846
3847 return true;
3848}
3849
3850void CombinerHelper::applyNotCmp(
3851 MachineInstr &MI, SmallVectorImpl<Register> &RegsToNegate) const {
3852 for (Register Reg : RegsToNegate) {
3853 MachineInstr *Def = MRI.getVRegDef(Reg);
3854 Observer.changingInstr(MI&: *Def);
3855 // For each comparison, invert the opcode. For each AND and OR, change the
3856 // opcode.
3857 switch (Def->getOpcode()) {
3858 default:
3859 llvm_unreachable("Unexpected opcode");
3860 case TargetOpcode::G_ICMP:
3861 case TargetOpcode::G_FCMP: {
3862 MachineOperand &PredOp = Def->getOperand(i: 1);
3863 CmpInst::Predicate NewP = CmpInst::getInversePredicate(
3864 pred: (CmpInst::Predicate)PredOp.getPredicate());
3865 PredOp.setPredicate(NewP);
3866 break;
3867 }
3868 case TargetOpcode::G_AND:
3869 Def->setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_OR));
3870 break;
3871 case TargetOpcode::G_OR:
3872 Def->setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_AND));
3873 break;
3874 }
3875 Observer.changedInstr(MI&: *Def);
3876 }
3877
3878 replaceRegWith(MRI, FromReg: MI.getOperand(i: 0).getReg(), ToReg: MI.getOperand(i: 1).getReg());
3879 MI.eraseFromParent();
3880}
3881
3882bool CombinerHelper::matchXorOfAndWithSameReg(
3883 MachineInstr &MI, std::pair<Register, Register> &MatchInfo) const {
3884 // Match (xor (and x, y), y) (or any of its commuted cases)
3885 assert(MI.getOpcode() == TargetOpcode::G_XOR);
3886 Register &X = MatchInfo.first;
3887 Register &Y = MatchInfo.second;
3888 Register AndReg = MI.getOperand(i: 1).getReg();
3889 Register SharedReg = MI.getOperand(i: 2).getReg();
3890
3891 // Find a G_AND on either side of the G_XOR.
3892 // Look for one of
3893 //
3894 // (xor (and x, y), SharedReg)
3895 // (xor SharedReg, (and x, y))
3896 if (!mi_match(R: AndReg, MRI, P: m_GAnd(L: m_Reg(R&: X), R: m_Reg(R&: Y)))) {
3897 std::swap(a&: AndReg, b&: SharedReg);
3898 if (!mi_match(R: AndReg, MRI, P: m_GAnd(L: m_Reg(R&: X), R: m_Reg(R&: Y))))
3899 return false;
3900 }
3901
3902 // Only do this if we'll eliminate the G_AND.
3903 if (!MRI.hasOneNonDBGUse(RegNo: AndReg))
3904 return false;
3905
3906 // We can combine if SharedReg is the same as either the LHS or RHS of the
3907 // G_AND.
3908 if (Y != SharedReg)
3909 std::swap(a&: X, b&: Y);
3910 return Y == SharedReg;
3911}
3912
3913void CombinerHelper::applyXorOfAndWithSameReg(
3914 MachineInstr &MI, std::pair<Register, Register> &MatchInfo) const {
3915 // Fold (xor (and x, y), y) -> (and (not x), y)
3916 Register X, Y;
3917 std::tie(args&: X, args&: Y) = MatchInfo;
3918 auto Not = Builder.buildNot(Dst: MRI.getType(Reg: X), Src0: X);
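  // Rewrite the G_XOR in place into G_AND (not x), y.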
3919 Observer.changingInstr(MI);
3920 MI.setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_AND));
3921 MI.getOperand(i: 1).setReg(Not->getOperand(i: 0).getReg());
3922 MI.getOperand(i: 2).setReg(Y);
3923 Observer.changedInstr(MI);
3924}
3925
3926bool CombinerHelper::matchPtrAddZero(MachineInstr &MI) const {
3927 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
3928 Register DstReg = PtrAdd.getReg(Idx: 0);
3929 LLT Ty = MRI.getType(Reg: DstReg);
3930 const DataLayout &DL = Builder.getMF().getDataLayout();
3931
3932 if (DL.isNonIntegralAddressSpace(AddrSpace: Ty.getScalarType().getAddressSpace()))
3933 return false;
3934
3935 if (Ty.isPointer()) {
3936 auto ConstVal = getIConstantVRegVal(VReg: PtrAdd.getBaseReg(), MRI);
3937 return ConstVal && *ConstVal == 0;
3938 }
3939
3940 assert(Ty.isVector() && "Expecting a vector type");
3941 const MachineInstr *VecMI = MRI.getVRegDef(Reg: PtrAdd.getBaseReg());
3942 return isBuildVectorAllZeros(MI: *VecMI, MRI);
3943}
3944
3945void CombinerHelper::applyPtrAddZero(MachineInstr &MI) const {
3946 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
3947 Builder.buildIntToPtr(Dst: PtrAdd.getReg(Idx: 0), Src: PtrAdd.getOffsetReg());
3948 PtrAdd.eraseFromParent();
3949}
3950
3951/// The second source operand is known to be a power of 2.
3952void CombinerHelper::applySimplifyURemByPow2(MachineInstr &MI) const {
3953 Register DstReg = MI.getOperand(i: 0).getReg();
3954 Register Src0 = MI.getOperand(i: 1).getReg();
3955 Register Pow2Src1 = MI.getOperand(i: 2).getReg();
3956 LLT Ty = MRI.getType(Reg: DstReg);
3957
3958 // Fold (urem x, pow2) -> (and x, pow2-1)
3959 auto NegOne = Builder.buildConstant(Res: Ty, Val: -1);
3960 auto Add = Builder.buildAdd(Dst: Ty, Src0: Pow2Src1, Src1: NegOne);
3961 Builder.buildAnd(Dst: DstReg, Src0, Src1: Add);
3962 MI.eraseFromParent();
3963}
3964
3965bool CombinerHelper::matchFoldBinOpIntoSelect(MachineInstr &MI,
3966 unsigned &SelectOpNo) const {
3967 Register LHS = MI.getOperand(i: 1).getReg();
3968 Register RHS = MI.getOperand(i: 2).getReg();
3969
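  // Start by assuming the select feeds the LHS (operand 1); if not, fall back
  // to the RHS below.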
3970 Register OtherOperandReg = RHS;
3971 SelectOpNo = 1;
3972 MachineInstr *Select = MRI.getVRegDef(Reg: LHS);
3973
3974 // Don't do this unless the old select is going away. We want to eliminate the
3975 // binary operator, not replace a binop with a select.
3976 if (Select->getOpcode() != TargetOpcode::G_SELECT ||
3977 !MRI.hasOneNonDBGUse(RegNo: LHS)) {
3978 OtherOperandReg = LHS;
3979 SelectOpNo = 2;
3980 Select = MRI.getVRegDef(Reg: RHS);
3981 if (Select->getOpcode() != TargetOpcode::G_SELECT ||
3982 !MRI.hasOneNonDBGUse(RegNo: RHS))
3983 return false;
3984 }
3985
3986 MachineInstr *SelectLHS = MRI.getVRegDef(Reg: Select->getOperand(i: 2).getReg());
3987 MachineInstr *SelectRHS = MRI.getVRegDef(Reg: Select->getOperand(i: 3).getReg());
3988
3989 if (!isConstantOrConstantVector(MI: *SelectLHS, MRI,
3990 /*AllowFP*/ true,
3991 /*AllowOpaqueConstants*/ false))
3992 return false;
3993 if (!isConstantOrConstantVector(MI: *SelectRHS, MRI,
3994 /*AllowFP*/ true,
3995 /*AllowOpaqueConstants*/ false))
3996 return false;
3997
3998 unsigned BinOpcode = MI.getOpcode();
3999
4000 // We know that one of the operands is a select of constants. Now verify that
4001 // the other binary operator operand is either a constant, or we can handle a
4002 // variable.
4003 bool CanFoldNonConst =
4004 (BinOpcode == TargetOpcode::G_AND || BinOpcode == TargetOpcode::G_OR) &&
4005 (isNullOrNullSplat(MI: *SelectLHS, MRI) ||
4006 isAllOnesOrAllOnesSplat(MI: *SelectLHS, MRI)) &&
4007 (isNullOrNullSplat(MI: *SelectRHS, MRI) ||
4008 isAllOnesOrAllOnesSplat(MI: *SelectRHS, MRI));
4009 if (CanFoldNonConst)
4010 return true;
4011
4012 return isConstantOrConstantVector(MI: *MRI.getVRegDef(Reg: OtherOperandReg), MRI,
4013 /*AllowFP*/ true,
4014 /*AllowOpaqueConstants*/ false);
4015}
4016
4017/// \p SelectOperand is the operand in binary operator \p MI that is the select
4018/// to fold.
4019void CombinerHelper::applyFoldBinOpIntoSelect(
4020 MachineInstr &MI, const unsigned &SelectOperand) const {
4021 Register Dst = MI.getOperand(i: 0).getReg();
4022 Register LHS = MI.getOperand(i: 1).getReg();
4023 Register RHS = MI.getOperand(i: 2).getReg();
4024 MachineInstr *Select = MRI.getVRegDef(Reg: MI.getOperand(i: SelectOperand).getReg());
4025
4026 Register SelectCond = Select->getOperand(i: 1).getReg();
4027 Register SelectTrue = Select->getOperand(i: 2).getReg();
4028 Register SelectFalse = Select->getOperand(i: 3).getReg();
4029
4030 LLT Ty = MRI.getType(Reg: Dst);
4031 unsigned BinOpcode = MI.getOpcode();
4032
4033 Register FoldTrue, FoldFalse;
4034
4035 // We have a select-of-constants followed by a binary operator with a
4036 // constant. Eliminate the binop by pulling the constant math into the select.
4037 // Example: add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO
4038 if (SelectOperand == 1) {
4039 // TODO: SelectionDAG verifies this actually constant folds before
4040 // committing to the combine.
4041
4042 FoldTrue = Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {SelectTrue, RHS}).getReg(Idx: 0);
4043 FoldFalse =
4044 Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {SelectFalse, RHS}).getReg(Idx: 0);
4045 } else {
4046 FoldTrue = Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {LHS, SelectTrue}).getReg(Idx: 0);
4047 FoldFalse =
4048 Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {LHS, SelectFalse}).getReg(Idx: 0);
4049 }
4050
4051 Builder.buildSelect(Res: Dst, Tst: SelectCond, Op0: FoldTrue, Op1: FoldFalse, Flags: MI.getFlags());
4052 MI.eraseFromParent();
4053}
4054
4055std::optional<SmallVector<Register, 8>>
4056CombinerHelper::findCandidatesForLoadOrCombine(const MachineInstr *Root) const {
4057 assert(Root->getOpcode() == TargetOpcode::G_OR && "Expected G_OR only!");
4058 // We want to detect if Root is part of a tree which represents a bunch
4059 // of loads being merged into a larger load. We'll try to recognize patterns
4060 // like, for example:
4061 //
4062 // Reg Reg
4063 // \ /
4064 // OR_1 Reg
4065 // \ /
4066 // OR_2
4067 // \ Reg
4068 // .. /
4069 // Root
4070 //
4071 // Reg Reg Reg Reg
4072 // \ / \ /
4073 // OR_1 OR_2
4074 // \ /
4075 // \ /
4076 // ...
4077 // Root
4078 //
4079 // Each "Reg" may have been produced by a load + some arithmetic. This
4080 // function will save each of them.
4081 SmallVector<Register, 8> RegsToVisit;
4082 SmallVector<const MachineInstr *, 7> Ors = {Root};
4083
4084 // In the "worst" case, we're dealing with a load for each byte. So, there
4085 // are at most #bytes - 1 ORs.
4086 const unsigned MaxIter =
4087 MRI.getType(Reg: Root->getOperand(i: 0).getReg()).getSizeInBytes() - 1;
4088 for (unsigned Iter = 0; Iter < MaxIter; ++Iter) {
4089 if (Ors.empty())
4090 break;
4091 const MachineInstr *Curr = Ors.pop_back_val();
4092 Register OrLHS = Curr->getOperand(i: 1).getReg();
4093 Register OrRHS = Curr->getOperand(i: 2).getReg();
4094
4095    // In the combine, we want to eliminate the entire tree.
4096 if (!MRI.hasOneNonDBGUse(RegNo: OrLHS) || !MRI.hasOneNonDBGUse(RegNo: OrRHS))
4097 return std::nullopt;
4098
4099 // If it's a G_OR, save it and continue to walk. If it's not, then it's
4100 // something that may be a load + arithmetic.
4101 if (const MachineInstr *Or = getOpcodeDef(Opcode: TargetOpcode::G_OR, Reg: OrLHS, MRI))
4102 Ors.push_back(Elt: Or);
4103 else
4104 RegsToVisit.push_back(Elt: OrLHS);
4105 if (const MachineInstr *Or = getOpcodeDef(Opcode: TargetOpcode::G_OR, Reg: OrRHS, MRI))
4106 Ors.push_back(Elt: Or);
4107 else
4108 RegsToVisit.push_back(Elt: OrRHS);
4109 }
4110
4111 // We're going to try and merge each register into a wider power-of-2 type,
4112 // so we ought to have an even number of registers.
4113 if (RegsToVisit.empty() || RegsToVisit.size() % 2 != 0)
4114 return std::nullopt;
4115 return RegsToVisit;
4116}
4117
4118/// Helper function for findLoadOffsetsForLoadOrCombine.
4119///
4120/// Check if \p Reg is the result of loading a \p MemSizeInBits wide value,
4121/// and then moving that value into a specific byte offset.
4122///
4123/// e.g. x[i] << 24
4124///
4125/// \returns The load instruction and the byte offset it is moved into.
4126static std::optional<std::pair<GZExtLoad *, int64_t>>
4127matchLoadAndBytePosition(Register Reg, unsigned MemSizeInBits,
4128 const MachineRegisterInfo &MRI) {
4129 assert(MRI.hasOneNonDBGUse(Reg) &&
4130 "Expected Reg to only have one non-debug use?");
4131 Register MaybeLoad;
4132 int64_t Shift;
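  // If Reg isn't fed by a shift, its value lands at byte position 0.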
4133 if (!mi_match(R: Reg, MRI,
4134 P: m_OneNonDBGUse(SP: m_GShl(L: m_Reg(R&: MaybeLoad), R: m_ICst(Cst&: Shift))))) {
4135 Shift = 0;
4136 MaybeLoad = Reg;
4137 }
4138
4139 if (Shift % MemSizeInBits != 0)
4140 return std::nullopt;
4141
4142 // TODO: Handle other types of loads.
4143 auto *Load = getOpcodeDef<GZExtLoad>(Reg: MaybeLoad, MRI);
4144 if (!Load)
4145 return std::nullopt;
4146
4147 if (!Load->isUnordered() || Load->getMemSizeInBits() != MemSizeInBits)
4148 return std::nullopt;
4149
4150 return std::make_pair(x&: Load, y: Shift / MemSizeInBits);
4151}
4152
4153std::optional<std::tuple<GZExtLoad *, int64_t, GZExtLoad *>>
4154CombinerHelper::findLoadOffsetsForLoadOrCombine(
4155 SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx,
4156 const SmallVector<Register, 8> &RegsToVisit,
4157 const unsigned MemSizeInBits) const {
4158
4159 // Each load found for the pattern. There should be one for each RegsToVisit.
4160 SmallSetVector<const MachineInstr *, 8> Loads;
4161
4162 // The lowest index used in any load. (The lowest "i" for each x[i].)
4163 int64_t LowestIdx = INT64_MAX;
4164
4165 // The load which uses the lowest index.
4166 GZExtLoad *LowestIdxLoad = nullptr;
4167
4168 // Keeps track of the load indices we see. We shouldn't see any indices twice.
4169 SmallSet<int64_t, 8> SeenIdx;
4170
4171 // Ensure each load is in the same MBB.
4172 // TODO: Support multiple MachineBasicBlocks.
4173 MachineBasicBlock *MBB = nullptr;
4174 const MachineMemOperand *MMO = nullptr;
4175
4176 // Earliest instruction-order load in the pattern.
4177 GZExtLoad *EarliestLoad = nullptr;
4178
4179 // Latest instruction-order load in the pattern.
4180 GZExtLoad *LatestLoad = nullptr;
4181
4182 // Base pointer which every load should share.
4183 Register BasePtr;
4184
4185 // We want to find a load for each register. Each load should have some
4186 // appropriate bit twiddling arithmetic. During this loop, we will also keep
4187 // track of the load which uses the lowest index. Later, we will check if we
4188 // can use its pointer in the final, combined load.
4189 for (auto Reg : RegsToVisit) {
4190    // Find the load, and the byte position that its value will end up at
4191    // (e.g. after being shifted).
4192 auto LoadAndPos = matchLoadAndBytePosition(Reg, MemSizeInBits, MRI);
4193 if (!LoadAndPos)
4194 return std::nullopt;
4195 GZExtLoad *Load;
4196 int64_t DstPos;
4197 std::tie(args&: Load, args&: DstPos) = *LoadAndPos;
4198
4199 // TODO: Handle multiple MachineBasicBlocks. Currently not handled because
4200 // it is difficult to check for stores/calls/etc between loads.
4201 MachineBasicBlock *LoadMBB = Load->getParent();
4202 if (!MBB)
4203 MBB = LoadMBB;
4204 if (LoadMBB != MBB)
4205 return std::nullopt;
4206
4207 // Make sure that the MachineMemOperands of every seen load are compatible.
4208 auto &LoadMMO = Load->getMMO();
4209 if (!MMO)
4210 MMO = &LoadMMO;
4211 if (MMO->getAddrSpace() != LoadMMO.getAddrSpace())
4212 return std::nullopt;
4213
4214 // Find out what the base pointer and index for the load is.
4215 Register LoadPtr;
4216 int64_t Idx;
4217 if (!mi_match(R: Load->getOperand(i: 1).getReg(), MRI,
4218 P: m_GPtrAdd(L: m_Reg(R&: LoadPtr), R: m_ICst(Cst&: Idx)))) {
4219 LoadPtr = Load->getOperand(i: 1).getReg();
4220 Idx = 0;
4221 }
4222
4223 // Don't combine things like a[i], a[i] -> a bigger load.
4224 if (!SeenIdx.insert(V: Idx).second)
4225 return std::nullopt;
4226
4227 // Every load must share the same base pointer; don't combine things like:
4228 //
4229 // a[i], b[i + 1] -> a bigger load.
4230 if (!BasePtr.isValid())
4231 BasePtr = LoadPtr;
4232 if (BasePtr != LoadPtr)
4233 return std::nullopt;
4234
4235 if (Idx < LowestIdx) {
4236 LowestIdx = Idx;
4237 LowestIdxLoad = Load;
4238 }
4239
4240    // Keep track of the byte offset that this load ends up at. If we have
4241    // already seen this byte offset, then stop here. We do not want to combine:
4242 //
4243 // a[i] << 16, a[i + k] << 16 -> a bigger load.
4244 if (!MemOffset2Idx.try_emplace(Key: DstPos, Args&: Idx).second)
4245 return std::nullopt;
4246 Loads.insert(X: Load);
4247
4248 // Keep track of the position of the earliest/latest loads in the pattern.
4249 // We will check that there are no load fold barriers between them later
4250 // on.
4251 //
4252 // FIXME: Is there a better way to check for load fold barriers?
4253 if (!EarliestLoad || dominates(DefMI: *Load, UseMI: *EarliestLoad))
4254 EarliestLoad = Load;
4255 if (!LatestLoad || dominates(DefMI: *LatestLoad, UseMI: *Load))
4256 LatestLoad = Load;
4257 }
4258
4259 // We found a load for each register. Let's check if each load satisfies the
4260 // pattern.
4261 assert(Loads.size() == RegsToVisit.size() &&
4262 "Expected to find a load for each register?");
4263 assert(EarliestLoad != LatestLoad && EarliestLoad &&
4264 LatestLoad && "Expected at least two loads?");
4265
4266 // Check if there are any stores, calls, etc. between any of the loads. If
4267 // there are, then we can't safely perform the combine.
4268 //
4269 // MaxIter is chosen based off the (worst case) number of iterations it
4270 // typically takes to succeed in the LLVM test suite plus some padding.
4271 //
4272 // FIXME: Is there a better way to check for load fold barriers?
4273 const unsigned MaxIter = 20;
4274 unsigned Iter = 0;
4275 for (const auto &MI : instructionsWithoutDebug(It: EarliestLoad->getIterator(),
4276 End: LatestLoad->getIterator())) {
4277 if (Loads.count(key: &MI))
4278 continue;
4279 if (MI.isLoadFoldBarrier())
4280 return std::nullopt;
4281 if (Iter++ == MaxIter)
4282 return std::nullopt;
4283 }
4284
4285 return std::make_tuple(args&: LowestIdxLoad, args&: LowestIdx, args&: LatestLoad);
4286}
4287
4288bool CombinerHelper::matchLoadOrCombine(
4289 MachineInstr &MI,
4290 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4291 assert(MI.getOpcode() == TargetOpcode::G_OR);
4292 MachineFunction &MF = *MI.getMF();
4293 // Assuming a little-endian target, transform:
4294 // s8 *a = ...
4295 // s32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
4296 // =>
4297 // s32 val = *((i32)a)
4298 //
4299 // s8 *a = ...
4300 // s32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
4301 // =>
4302 // s32 val = BSWAP(*((s32)a))
4303 Register Dst = MI.getOperand(i: 0).getReg();
4304 LLT Ty = MRI.getType(Reg: Dst);
4305 if (Ty.isVector())
4306 return false;
4307
4308 // We need to combine at least two loads into this type. Since the smallest
4309 // possible load is into a byte, we need at least a 16-bit wide type.
4310 const unsigned WideMemSizeInBits = Ty.getSizeInBits();
4311 if (WideMemSizeInBits < 16 || WideMemSizeInBits % 8 != 0)
4312 return false;
4313
4314 // Match a collection of non-OR instructions in the pattern.
4315 auto RegsToVisit = findCandidatesForLoadOrCombine(Root: &MI);
4316 if (!RegsToVisit)
4317 return false;
4318
4319 // We have a collection of non-OR instructions. Figure out how wide each of
4320 // the small loads should be based off of the number of potential loads we
4321 // found.
4322 const unsigned NarrowMemSizeInBits = WideMemSizeInBits / RegsToVisit->size();
4323 if (NarrowMemSizeInBits % 8 != 0)
4324 return false;
4325
4326 // Check if each register feeding into each OR is a load from the same
4327 // base pointer + some arithmetic.
4328 //
4329 // e.g. a[0], a[1] << 8, a[2] << 16, etc.
4330 //
4331 // Also verify that each of these ends up putting a[i] into the same memory
4332 // offset as a load into a wide type would.
4333 SmallDenseMap<int64_t, int64_t, 8> MemOffset2Idx;
4334 GZExtLoad *LowestIdxLoad, *LatestLoad;
4335 int64_t LowestIdx;
4336 auto MaybeLoadInfo = findLoadOffsetsForLoadOrCombine(
4337 MemOffset2Idx, RegsToVisit: *RegsToVisit, MemSizeInBits: NarrowMemSizeInBits);
4338 if (!MaybeLoadInfo)
4339 return false;
4340 std::tie(args&: LowestIdxLoad, args&: LowestIdx, args&: LatestLoad) = *MaybeLoadInfo;
4341
4342 // We have a bunch of loads being OR'd together. Using the addresses + offsets
4343 // we found before, check if this corresponds to a big or little endian byte
4344 // pattern. If it does, then we can represent it using a load + possibly a
4345 // BSWAP.
4346 bool IsBigEndianTarget = MF.getDataLayout().isBigEndian();
4347 std::optional<bool> IsBigEndian = isBigEndian(MemOffset2Idx, LowestIdx);
4348 if (!IsBigEndian)
4349 return false;
4350 bool NeedsBSwap = IsBigEndianTarget != *IsBigEndian;
4351 if (NeedsBSwap && !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_BSWAP, {Ty}}))
4352 return false;
4353
4354 // Make sure that the load from the lowest index produces offset 0 in the
4355 // final value.
4356 //
4357 // This ensures that we won't combine something like this:
4358 //
4359 // load x[i] -> byte 2
4360 // load x[i+1] -> byte 0 ---> wide_load x[i]
4361 // load x[i+2] -> byte 1
4362 const unsigned NumLoadsInTy = WideMemSizeInBits / NarrowMemSizeInBits;
4363 const unsigned ZeroByteOffset =
4364 *IsBigEndian
4365 ? bigEndianByteAt(ByteWidth: NumLoadsInTy, I: 0)
4366 : littleEndianByteAt(ByteWidth: NumLoadsInTy, I: 0);
4367 auto ZeroOffsetIdx = MemOffset2Idx.find(Val: ZeroByteOffset);
4368 if (ZeroOffsetIdx == MemOffset2Idx.end() ||
4369 ZeroOffsetIdx->second != LowestIdx)
4370 return false;
4371
4372  // We will reuse the pointer from the load which ends up at byte offset 0. It
4373 // may not use index 0.
4374 Register Ptr = LowestIdxLoad->getPointerReg();
4375 const MachineMemOperand &MMO = LowestIdxLoad->getMMO();
4376 LegalityQuery::MemDesc MMDesc(MMO);
4377 MMDesc.MemoryTy = Ty;
4378 if (!isLegalOrBeforeLegalizer(
4379 Query: {TargetOpcode::G_LOAD, {Ty, MRI.getType(Reg: Ptr)}, {MMDesc}}))
4380 return false;
4381 auto PtrInfo = MMO.getPointerInfo();
4382 auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Size: WideMemSizeInBits / 8);
4383
4384 // Load must be allowed and fast on the target.
4385 LLVMContext &C = MF.getFunction().getContext();
4386 auto &DL = MF.getDataLayout();
4387 unsigned Fast = 0;
4388 if (!getTargetLowering().allowsMemoryAccess(Context&: C, DL, Ty, MMO: *NewMMO, Fast: &Fast) ||
4389 !Fast)
4390 return false;
4391
4392 MatchInfo = [=](MachineIRBuilder &MIB) {
4393 MIB.setInstrAndDebugLoc(*LatestLoad);
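    // When a byte swap is needed, load into a fresh register and bswap into Dst.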
4394 Register LoadDst = NeedsBSwap ? MRI.cloneVirtualRegister(VReg: Dst) : Dst;
4395 MIB.buildLoad(Res: LoadDst, Addr: Ptr, MMO&: *NewMMO);
4396 if (NeedsBSwap)
4397 MIB.buildBSwap(Dst, Src0: LoadDst);
4398 };
4399 return true;
4400}
4401
4402bool CombinerHelper::matchExtendThroughPhis(MachineInstr &MI,
4403 MachineInstr *&ExtMI) const {
4404 auto &PHI = cast<GPhi>(Val&: MI);
4405 Register DstReg = PHI.getReg(Idx: 0);
4406
4407 // TODO: Extending a vector may be expensive, don't do this until heuristics
4408 // are better.
4409 if (MRI.getType(Reg: DstReg).isVector())
4410 return false;
4411
4412  // Try to match a phi whose only use is an extend.
4413 if (!MRI.hasOneNonDBGUse(RegNo: DstReg))
4414 return false;
4415 ExtMI = &*MRI.use_instr_nodbg_begin(RegNo: DstReg);
4416 switch (ExtMI->getOpcode()) {
4417 case TargetOpcode::G_ANYEXT:
4418 return true; // G_ANYEXT is usually free.
4419 case TargetOpcode::G_ZEXT:
4420 case TargetOpcode::G_SEXT:
4421 break;
4422 default:
4423 return false;
4424 }
4425
4426 // If the target is likely to fold this extend away, don't propagate.
4427 if (Builder.getTII().isExtendLikelyToBeFolded(ExtMI&: *ExtMI, MRI))
4428 return false;
4429
4430 // We don't want to propagate the extends unless there's a good chance that
4431 // they'll be optimized in some way.
4432 // Collect the unique incoming values.
4433 SmallPtrSet<MachineInstr *, 4> InSrcs;
4434 for (unsigned I = 0; I < PHI.getNumIncomingValues(); ++I) {
4435 auto *DefMI = getDefIgnoringCopies(Reg: PHI.getIncomingValue(I), MRI);
4436 switch (DefMI->getOpcode()) {
4437 case TargetOpcode::G_LOAD:
4438 case TargetOpcode::G_TRUNC:
4439 case TargetOpcode::G_SEXT:
4440 case TargetOpcode::G_ZEXT:
4441 case TargetOpcode::G_ANYEXT:
4442 case TargetOpcode::G_CONSTANT:
4443 InSrcs.insert(Ptr: DefMI);
4444 // Don't try to propagate if there are too many places to create new
4445       // extends; chances are it'll increase code size.
4446 if (InSrcs.size() > 2)
4447 return false;
4448 break;
4449 default:
4450 return false;
4451 }
4452 }
4453 return true;
4454}
4455
4456void CombinerHelper::applyExtendThroughPhis(MachineInstr &MI,
4457 MachineInstr *&ExtMI) const {
4458 auto &PHI = cast<GPhi>(Val&: MI);
4459 Register DstReg = ExtMI->getOperand(i: 0).getReg();
4460 LLT ExtTy = MRI.getType(Reg: DstReg);
4461
4462   // Propagate the extension into each incoming register's defining block.
4463 // Use a SetVector here because PHIs can have duplicate edges, and we want
4464 // deterministic iteration order.
4465 SmallSetVector<MachineInstr *, 8> SrcMIs;
4466 SmallDenseMap<MachineInstr *, MachineInstr *, 8> OldToNewSrcMap;
4467 for (unsigned I = 0; I < PHI.getNumIncomingValues(); ++I) {
4468 auto SrcReg = PHI.getIncomingValue(I);
4469 auto *SrcMI = MRI.getVRegDef(Reg: SrcReg);
4470 if (!SrcMIs.insert(X: SrcMI))
4471 continue;
4472
4473 // Build an extend after each src inst.
4474 auto *MBB = SrcMI->getParent();
4475 MachineBasicBlock::iterator InsertPt = ++SrcMI->getIterator();
4476 if (InsertPt != MBB->end() && InsertPt->isPHI())
4477 InsertPt = MBB->getFirstNonPHI();
4478
4479 Builder.setInsertPt(MBB&: *SrcMI->getParent(), II: InsertPt);
4480 Builder.setDebugLoc(MI.getDebugLoc());
4481 auto NewExt = Builder.buildExtOrTrunc(ExtOpc: ExtMI->getOpcode(), Res: ExtTy, Op: SrcReg);
4482 OldToNewSrcMap[SrcMI] = NewExt;
4483 }
4484
4485 // Create a new phi with the extended inputs.
4486 Builder.setInstrAndDebugLoc(MI);
4487 auto NewPhi = Builder.buildInstrNoInsert(Opcode: TargetOpcode::G_PHI);
4488 NewPhi.addDef(RegNo: DstReg);
4489 for (const MachineOperand &MO : llvm::drop_begin(RangeOrContainer: MI.operands())) {
4490 if (!MO.isReg()) {
4491 NewPhi.addMBB(MBB: MO.getMBB());
4492 continue;
4493 }
4494 auto *NewSrc = OldToNewSrcMap[MRI.getVRegDef(Reg: MO.getReg())];
4495 NewPhi.addUse(RegNo: NewSrc->getOperand(i: 0).getReg());
4496 }
4497 Builder.insertInstr(MIB: NewPhi);
4498 ExtMI->eraseFromParent();
4499}
4500
4501bool CombinerHelper::matchExtractVecEltBuildVec(MachineInstr &MI,
4502 Register &Reg) const {
4503 assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT);
4504 // If we have a constant index, look for a G_BUILD_VECTOR source
4505 // and find the source register that the index maps to.
4506 Register SrcVec = MI.getOperand(i: 1).getReg();
4507 LLT SrcTy = MRI.getType(Reg: SrcVec);
4508 if (SrcTy.isScalableVector())
4509 return false;
4510
4511 auto Cst = getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI);
4512 if (!Cst || Cst->Value.getZExtValue() >= SrcTy.getNumElements())
4513 return false;
4514
4515 unsigned VecIdx = Cst->Value.getZExtValue();
4516
4517 // Check if we have a build_vector or build_vector_trunc with an optional
4518 // trunc in front.
4519 MachineInstr *SrcVecMI = MRI.getVRegDef(Reg: SrcVec);
4520 if (SrcVecMI->getOpcode() == TargetOpcode::G_TRUNC) {
4521 SrcVecMI = MRI.getVRegDef(Reg: SrcVecMI->getOperand(i: 1).getReg());
4522 }
4523
4524 if (SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR &&
4525 SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR_TRUNC)
4526 return false;
4527
4528 EVT Ty(getMVTForLLT(Ty: SrcTy));
4529 if (!MRI.hasOneNonDBGUse(RegNo: SrcVec) &&
4530 !getTargetLowering().aggressivelyPreferBuildVectorSources(VecVT: Ty))
4531 return false;
4532
4533 Reg = SrcVecMI->getOperand(i: VecIdx + 1).getReg();
4534 return true;
4535}
4536
4537void CombinerHelper::applyExtractVecEltBuildVec(MachineInstr &MI,
4538 Register &Reg) const {
4539 // Check the type of the register, since it may have come from a
4540 // G_BUILD_VECTOR_TRUNC.
4541 LLT ScalarTy = MRI.getType(Reg);
4542 Register DstReg = MI.getOperand(i: 0).getReg();
4543 LLT DstTy = MRI.getType(Reg: DstReg);
4544
4545 if (ScalarTy != DstTy) {
4546 assert(ScalarTy.getSizeInBits() > DstTy.getSizeInBits());
4547 Builder.buildTrunc(Res: DstReg, Op: Reg);
4548 MI.eraseFromParent();
4549 return;
4550 }
4551 replaceSingleDefInstWithReg(MI, Replacement: Reg);
4552}
4553
4554bool CombinerHelper::matchExtractAllEltsFromBuildVector(
4555 MachineInstr &MI,
4556 SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) const {
4557 assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
4558   // This combine tries to find build_vectors where every source element is
4559   // extracted using G_EXTRACT_VECTOR_ELT. This can happen when transforms like
4560   // masked load scalarization are run late in the pipeline. There's already
4561 // a combine for a similar pattern starting from the extract, but that
4562 // doesn't attempt to do it if there are multiple uses of the build_vector,
4563 // which in this case is true. Starting the combine from the build_vector
4564 // feels more natural than trying to find sibling nodes of extracts.
4565 // E.g.
4566 // %vec(<4 x s32>) = G_BUILD_VECTOR %s1(s32), %s2, %s3, %s4
4567 // %ext1 = G_EXTRACT_VECTOR_ELT %vec, 0
4568 // %ext2 = G_EXTRACT_VECTOR_ELT %vec, 1
4569 // %ext3 = G_EXTRACT_VECTOR_ELT %vec, 2
4570 // %ext4 = G_EXTRACT_VECTOR_ELT %vec, 3
4571 // ==>
4572 // replace ext{1,2,3,4} with %s{1,2,3,4}
4573
4574 Register DstReg = MI.getOperand(i: 0).getReg();
4575 LLT DstTy = MRI.getType(Reg: DstReg);
4576 unsigned NumElts = DstTy.getNumElements();
4577
4578 SmallBitVector ExtractedElts(NumElts);
4579 for (MachineInstr &II : MRI.use_nodbg_instructions(Reg: DstReg)) {
4580 if (II.getOpcode() != TargetOpcode::G_EXTRACT_VECTOR_ELT)
4581 return false;
4582 auto Cst = getIConstantVRegVal(VReg: II.getOperand(i: 2).getReg(), MRI);
4583 if (!Cst)
4584 return false;
4585 unsigned Idx = Cst->getZExtValue();
4586 if (Idx >= NumElts)
4587 return false; // Out of range.
4588 ExtractedElts.set(Idx);
4589 SrcDstPairs.emplace_back(
4590 Args: std::make_pair(x: MI.getOperand(i: Idx + 1).getReg(), y: &II));
4591 }
4592 // Match if every element was extracted.
4593 return ExtractedElts.all();
4594}
4595
4596void CombinerHelper::applyExtractAllEltsFromBuildVector(
4597 MachineInstr &MI,
4598 SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) const {
4599 assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
4600 for (auto &Pair : SrcDstPairs) {
4601 auto *ExtMI = Pair.second;
4602 replaceRegWith(MRI, FromReg: ExtMI->getOperand(i: 0).getReg(), ToReg: Pair.first);
4603 ExtMI->eraseFromParent();
4604 }
4605 MI.eraseFromParent();
4606}
4607
4608void CombinerHelper::applyBuildFn(
4609 MachineInstr &MI,
4610 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4611 applyBuildFnNoErase(MI, MatchInfo);
4612 MI.eraseFromParent();
4613}
4614
4615void CombinerHelper::applyBuildFnNoErase(
4616 MachineInstr &MI,
4617 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4618 MatchInfo(Builder);
4619}
4620
4621bool CombinerHelper::matchOrShiftToFunnelShift(MachineInstr &MI,
4622 bool AllowScalarConstants,
4623 BuildFnTy &MatchInfo) const {
4624 assert(MI.getOpcode() == TargetOpcode::G_OR);
4625
4626 Register Dst = MI.getOperand(i: 0).getReg();
4627 LLT Ty = MRI.getType(Reg: Dst);
4628 unsigned BitWidth = Ty.getScalarSizeInBits();
4629
4630 Register ShlSrc, ShlAmt, LShrSrc, LShrAmt, Amt;
4631 unsigned FshOpc = 0;
4632
4633 // Match (or (shl ...), (lshr ...)).
4634 if (!mi_match(R: Dst, MRI,
4635 // m_GOr() handles the commuted version as well.
4636 P: m_GOr(L: m_GShl(L: m_Reg(R&: ShlSrc), R: m_Reg(R&: ShlAmt)),
4637 R: m_GLShr(L: m_Reg(R&: LShrSrc), R: m_Reg(R&: LShrAmt)))))
4638 return false;
4639
4640 // Given constants C0 and C1 such that C0 + C1 is bit-width:
4641 // (or (shl x, C0), (lshr y, C1)) -> (fshl x, y, C0) or (fshr x, y, C1)
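  // E.g. for a 32-bit scalar (widths purely illustrative):
  //   (or (shl x, 8), (lshr y, 24)) -> (fshr x, y, 24)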
4642 int64_t CstShlAmt = 0, CstLShrAmt;
4643 if (mi_match(R: ShlAmt, MRI, P: m_ICstOrSplat(Cst&: CstShlAmt)) &&
4644 mi_match(R: LShrAmt, MRI, P: m_ICstOrSplat(Cst&: CstLShrAmt)) &&
4645 CstShlAmt + CstLShrAmt == BitWidth) {
4646 FshOpc = TargetOpcode::G_FSHR;
4647 Amt = LShrAmt;
4648 } else if (mi_match(R: LShrAmt, MRI,
4649 P: m_GSub(L: m_SpecificICstOrSplat(RequestedValue: BitWidth), R: m_Reg(R&: Amt))) &&
4650 ShlAmt == Amt) {
4651 // (or (shl x, amt), (lshr y, (sub bw, amt))) -> (fshl x, y, amt)
4652 FshOpc = TargetOpcode::G_FSHL;
4653 } else if (mi_match(R: ShlAmt, MRI,
4654 P: m_GSub(L: m_SpecificICstOrSplat(RequestedValue: BitWidth), R: m_Reg(R&: Amt))) &&
4655 LShrAmt == Amt) {
4656 // (or (shl x, (sub bw, amt)), (lshr y, amt)) -> (fshr x, y, amt)
4657 FshOpc = TargetOpcode::G_FSHR;
4658 } else {
4659 return false;
4660 }
4661
4662 LLT AmtTy = MRI.getType(Reg: Amt);
4663 if (!isLegalOrBeforeLegalizer(Query: {FshOpc, {Ty, AmtTy}}) &&
4664 (!AllowScalarConstants || CstShlAmt == 0 || !Ty.isScalar()))
4665 return false;
4666
4667 MatchInfo = [=](MachineIRBuilder &B) {
4668 B.buildInstr(Opc: FshOpc, DstOps: {Dst}, SrcOps: {ShlSrc, LShrSrc, Amt});
4669 };
4670 return true;
4671}
4672
4673/// Match an FSHL or FSHR that can be combined to a ROTR or ROTL rotate.
4674bool CombinerHelper::matchFunnelShiftToRotate(MachineInstr &MI) const {
4675 unsigned Opc = MI.getOpcode();
4676 assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR);
4677 Register X = MI.getOperand(i: 1).getReg();
4678 Register Y = MI.getOperand(i: 2).getReg();
4679 if (X != Y)
4680 return false;
4681 unsigned RotateOpc =
4682 Opc == TargetOpcode::G_FSHL ? TargetOpcode::G_ROTL : TargetOpcode::G_ROTR;
4683 return isLegalOrBeforeLegalizer(Query: {RotateOpc, {MRI.getType(Reg: X), MRI.getType(Reg: Y)}});
4684}
4685
4686void CombinerHelper::applyFunnelShiftToRotate(MachineInstr &MI) const {
4687 unsigned Opc = MI.getOpcode();
4688 assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR);
4689 bool IsFSHL = Opc == TargetOpcode::G_FSHL;
4690 Observer.changingInstr(MI);
4691 MI.setDesc(Builder.getTII().get(Opcode: IsFSHL ? TargetOpcode::G_ROTL
4692 : TargetOpcode::G_ROTR));
4693 MI.removeOperand(OpNo: 2);
4694 Observer.changedInstr(MI);
4695}
4696
4697// Fold (rot x, c) -> (rot x, c % BitSize)
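// E.g. for a 32-bit rotate (width chosen for illustration):
//   (G_ROTL x:s32, 40) -> (G_ROTL x:s32, 8)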
4698bool CombinerHelper::matchRotateOutOfRange(MachineInstr &MI) const {
4699 assert(MI.getOpcode() == TargetOpcode::G_ROTL ||
4700 MI.getOpcode() == TargetOpcode::G_ROTR);
4701 unsigned Bitsize =
4702 MRI.getType(Reg: MI.getOperand(i: 0).getReg()).getScalarSizeInBits();
4703 Register AmtReg = MI.getOperand(i: 2).getReg();
4704 bool OutOfRange = false;
4705 auto MatchOutOfRange = [Bitsize, &OutOfRange](const Constant *C) {
4706 if (auto *CI = dyn_cast<ConstantInt>(Val: C))
4707 OutOfRange |= CI->getValue().uge(RHS: Bitsize);
4708 return true;
4709 };
4710 return matchUnaryPredicate(MRI, Reg: AmtReg, Match: MatchOutOfRange) && OutOfRange;
4711}
4712
4713void CombinerHelper::applyRotateOutOfRange(MachineInstr &MI) const {
4714 assert(MI.getOpcode() == TargetOpcode::G_ROTL ||
4715 MI.getOpcode() == TargetOpcode::G_ROTR);
4716 unsigned Bitsize =
4717 MRI.getType(Reg: MI.getOperand(i: 0).getReg()).getScalarSizeInBits();
4718 Register Amt = MI.getOperand(i: 2).getReg();
4719 LLT AmtTy = MRI.getType(Reg: Amt);
4720 auto Bits = Builder.buildConstant(Res: AmtTy, Val: Bitsize);
4721 Amt = Builder.buildURem(Dst: AmtTy, Src0: MI.getOperand(i: 2).getReg(), Src1: Bits).getReg(Idx: 0);
4722 Observer.changingInstr(MI);
4723 MI.getOperand(i: 2).setReg(Amt);
4724 Observer.changedInstr(MI);
4725}
4726
4727bool CombinerHelper::matchICmpToTrueFalseKnownBits(MachineInstr &MI,
4728 int64_t &MatchInfo) const {
4729 assert(MI.getOpcode() == TargetOpcode::G_ICMP);
4730 auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(i: 1).getPredicate());
4731
4732 // We want to avoid calling KnownBits on the LHS if possible, as this combine
4733 // has no filter and runs on every G_ICMP instruction. We can avoid calling
4734 // KnownBits on the LHS in two cases:
4735 //
4736 // - The RHS is unknown: Constants are always on RHS. If the RHS is unknown
4737 // we cannot do any transforms so we can safely bail out early.
4738 // - The RHS is zero: we don't need to know the LHS to do unsigned <0 and
4739 // >=0.
4740 auto KnownRHS = VT->getKnownBits(R: MI.getOperand(i: 3).getReg());
4741 if (KnownRHS.isUnknown())
4742 return false;
4743
4744 std::optional<bool> KnownVal;
4745 if (KnownRHS.isZero()) {
4746 // ? uge 0 -> always true
4747 // ? ult 0 -> always false
4748 if (Pred == CmpInst::ICMP_UGE)
4749 KnownVal = true;
4750 else if (Pred == CmpInst::ICMP_ULT)
4751 KnownVal = false;
4752 }
4753
4754 if (!KnownVal) {
4755 auto KnownLHS = VT->getKnownBits(R: MI.getOperand(i: 2).getReg());
4756 KnownVal = ICmpInst::compare(LHS: KnownLHS, RHS: KnownRHS, Pred);
4757 }
4758
4759 if (!KnownVal)
4760 return false;
4761 MatchInfo =
4762 *KnownVal
4763 ? getICmpTrueVal(TLI: getTargetLowering(),
4764 /*IsVector = */
4765 MRI.getType(Reg: MI.getOperand(i: 0).getReg()).isVector(),
4766 /* IsFP = */ false)
4767 : 0;
4768 return true;
4769}
4770
4771bool CombinerHelper::matchICmpToLHSKnownBits(
4772 MachineInstr &MI,
4773 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4774 assert(MI.getOpcode() == TargetOpcode::G_ICMP);
4775 // Given:
4776 //
4777 // %x = G_WHATEVER (... x is known to be 0 or 1 ...)
4778 // %cmp = G_ICMP ne %x, 0
4779 //
4780 // Or:
4781 //
4782 // %x = G_WHATEVER (... x is known to be 0 or 1 ...)
4783 // %cmp = G_ICMP eq %x, 1
4784 //
4785 // We can replace %cmp with %x assuming true is 1 on the target.
4786 auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(i: 1).getPredicate());
4787 if (!CmpInst::isEquality(pred: Pred))
4788 return false;
4789 Register Dst = MI.getOperand(i: 0).getReg();
4790 LLT DstTy = MRI.getType(Reg: Dst);
4791 if (getICmpTrueVal(TLI: getTargetLowering(), IsVector: DstTy.isVector(),
4792 /* IsFP = */ false) != 1)
4793 return false;
4794 int64_t OneOrZero = Pred == CmpInst::ICMP_EQ;
4795 if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICst(RequestedValue: OneOrZero)))
4796 return false;
4797 Register LHS = MI.getOperand(i: 2).getReg();
4798 auto KnownLHS = VT->getKnownBits(R: LHS);
4799 if (KnownLHS.getMinValue() != 0 || KnownLHS.getMaxValue() != 1)
4800 return false;
4801 // Make sure replacing Dst with the LHS is a legal operation.
4802 LLT LHSTy = MRI.getType(Reg: LHS);
4803 unsigned LHSSize = LHSTy.getSizeInBits();
4804 unsigned DstSize = DstTy.getSizeInBits();
4805 unsigned Op = TargetOpcode::COPY;
4806 if (DstSize != LHSSize)
4807 Op = DstSize < LHSSize ? TargetOpcode::G_TRUNC : TargetOpcode::G_ZEXT;
4808 if (!isLegalOrBeforeLegalizer(Query: {Op, {DstTy, LHSTy}}))
4809 return false;
4810 MatchInfo = [=](MachineIRBuilder &B) { B.buildInstr(Opc: Op, DstOps: {Dst}, SrcOps: {LHS}); };
4811 return true;
4812}
4813
4814// Replace (and (or x, c1), c2) with (and x, c2) iff c1 & c2 == 0
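// E.g. (and (or x, 0xF0), 0x0F) -> (and x, 0x0F): since 0xF0 & 0x0F == 0, the
// bits set by the or cannot survive the mask anyway.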
4815bool CombinerHelper::matchAndOrDisjointMask(
4816 MachineInstr &MI,
4817 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4818 assert(MI.getOpcode() == TargetOpcode::G_AND);
4819
4820 // Ignore vector types to simplify matching the two constants.
4821 // TODO: do this for vectors and scalars via a demanded bits analysis.
4822 LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
4823 if (Ty.isVector())
4824 return false;
4825
4826 Register Src;
4827 Register AndMaskReg;
4828 int64_t AndMaskBits;
4829 int64_t OrMaskBits;
4830 if (!mi_match(MI, MRI,
4831 P: m_GAnd(L: m_GOr(L: m_Reg(R&: Src), R: m_ICst(Cst&: OrMaskBits)),
4832 R: m_all_of(preds: m_ICst(Cst&: AndMaskBits), preds: m_Reg(R&: AndMaskReg)))))
4833 return false;
4834
4835   // Check if the OR mask could turn on any bits that the AND mask keeps.
4836 if (AndMaskBits & OrMaskBits)
4837 return false;
4838
4839 MatchInfo = [=, &MI](MachineIRBuilder &B) {
4840 Observer.changingInstr(MI);
4841 // Canonicalize the result to have the constant on the RHS.
4842 if (MI.getOperand(i: 1).getReg() == AndMaskReg)
4843 MI.getOperand(i: 2).setReg(AndMaskReg);
4844 MI.getOperand(i: 1).setReg(Src);
4845 Observer.changedInstr(MI);
4846 };
4847 return true;
4848}
4849
4850/// Form a G_SBFX from a G_SEXT_INREG fed by a right shift.
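/// E.g. for an s32 source (shift and width chosen for illustration):
///   (G_SEXT_INREG (G_LSHR x, 4), 8) -> (G_SBFX x, 4, 8)
/// i.e. sign-extract the 8-bit field that starts at bit 4.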
4851bool CombinerHelper::matchBitfieldExtractFromSExtInReg(
4852 MachineInstr &MI,
4853 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4854 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
4855 Register Dst = MI.getOperand(i: 0).getReg();
4856 Register Src = MI.getOperand(i: 1).getReg();
4857 LLT Ty = MRI.getType(Reg: Src);
4858 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4859 if (!LI || !LI->isLegalOrCustom(Query: {TargetOpcode::G_SBFX, {Ty, ExtractTy}}))
4860 return false;
4861 int64_t Width = MI.getOperand(i: 2).getImm();
4862 Register ShiftSrc;
4863 int64_t ShiftImm;
4864 if (!mi_match(
4865 R: Src, MRI,
4866 P: m_OneNonDBGUse(SP: m_any_of(preds: m_GAShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: ShiftImm)),
4867 preds: m_GLShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: ShiftImm))))))
4868 return false;
4869 if (ShiftImm < 0 || ShiftImm + Width > Ty.getScalarSizeInBits())
4870 return false;
4871
4872 MatchInfo = [=](MachineIRBuilder &B) {
4873 auto Cst1 = B.buildConstant(Res: ExtractTy, Val: ShiftImm);
4874 auto Cst2 = B.buildConstant(Res: ExtractTy, Val: Width);
4875 B.buildSbfx(Dst, Src: ShiftSrc, LSB: Cst1, Width: Cst2);
4876 };
4877 return true;
4878}
4879
4880/// Form a G_UBFX from "(a srl b) & mask", where b and mask are constants.
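/// E.g. for an s32 source (constants chosen for illustration):
///   (G_AND (G_LSHR x, 8), 0xFF) -> (G_UBFX x, 8, 8)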
4881bool CombinerHelper::matchBitfieldExtractFromAnd(MachineInstr &MI,
4882 BuildFnTy &MatchInfo) const {
4883 GAnd *And = cast<GAnd>(Val: &MI);
4884 Register Dst = And->getReg(Idx: 0);
4885 LLT Ty = MRI.getType(Reg: Dst);
4886 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4887 // Note that isLegalOrBeforeLegalizer is stricter and does not take custom
4888 // into account.
4889 if (LI && !LI->isLegalOrCustom(Query: {TargetOpcode::G_UBFX, {Ty, ExtractTy}}))
4890 return false;
4891
4892 int64_t AndImm, LSBImm;
4893 Register ShiftSrc;
4894 const unsigned Size = Ty.getScalarSizeInBits();
4895 if (!mi_match(R: And->getReg(Idx: 0), MRI,
4896 P: m_GAnd(L: m_OneNonDBGUse(SP: m_GLShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: LSBImm))),
4897 R: m_ICst(Cst&: AndImm))))
4898 return false;
4899
4900 // The mask is a mask of the low bits iff imm & (imm+1) == 0.
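  // E.g. 0xFF qualifies (0xFF & 0x100 == 0), while 0xF0 does not
  // (0xF0 & 0xF1 == 0xF0).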
4901 auto MaybeMask = static_cast<uint64_t>(AndImm);
4902 if (MaybeMask & (MaybeMask + 1))
4903 return false;
4904
4905 // LSB must fit within the register.
4906 if (static_cast<uint64_t>(LSBImm) >= Size)
4907 return false;
4908
4909 uint64_t Width = APInt(Size, AndImm).countr_one();
4910 MatchInfo = [=](MachineIRBuilder &B) {
4911 auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width);
4912 auto LSBCst = B.buildConstant(Res: ExtractTy, Val: LSBImm);
4913 B.buildInstr(Opc: TargetOpcode::G_UBFX, DstOps: {Dst}, SrcOps: {ShiftSrc, LSBCst, WidthCst});
4914 };
4915 return true;
4916}
4917
4918bool CombinerHelper::matchBitfieldExtractFromShr(
4919 MachineInstr &MI,
4920 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4921 const unsigned Opcode = MI.getOpcode();
4922 assert(Opcode == TargetOpcode::G_ASHR || Opcode == TargetOpcode::G_LSHR);
4923
4924 const Register Dst = MI.getOperand(i: 0).getReg();
4925
4926 const unsigned ExtrOpcode = Opcode == TargetOpcode::G_ASHR
4927 ? TargetOpcode::G_SBFX
4928 : TargetOpcode::G_UBFX;
4929
4930 // Check if the type we would use for the extract is legal
4931 LLT Ty = MRI.getType(Reg: Dst);
4932 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4933 if (!LI || !LI->isLegalOrCustom(Query: {ExtrOpcode, {Ty, ExtractTy}}))
4934 return false;
4935
4936 Register ShlSrc;
4937 int64_t ShrAmt;
4938 int64_t ShlAmt;
4939 const unsigned Size = Ty.getScalarSizeInBits();
4940
4941 // Try to match shr (shl x, c1), c2
4942 if (!mi_match(R: Dst, MRI,
4943 P: m_BinOp(Opcode,
4944 L: m_OneNonDBGUse(SP: m_GShl(L: m_Reg(R&: ShlSrc), R: m_ICst(Cst&: ShlAmt))),
4945 R: m_ICst(Cst&: ShrAmt))))
4946 return false;
4947
4948 // Make sure that the shift sizes can fit a bitfield extract
4949 if (ShlAmt < 0 || ShlAmt > ShrAmt || ShrAmt >= Size)
4950 return false;
4951
4952 // Skip this combine if the G_SEXT_INREG combine could handle it
4953 if (Opcode == TargetOpcode::G_ASHR && ShlAmt == ShrAmt)
4954 return false;
4955
4956 // Calculate start position and width of the extract
4957 const int64_t Pos = ShrAmt - ShlAmt;
4958 const int64_t Width = Size - ShrAmt;
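  // E.g. for an s32 value (amounts chosen for illustration):
  //   (G_LSHR (G_SHL x, 8), 20) -> (G_UBFX x, 12, 12)
  // i.e. Pos = 20 - 8 and Width = 32 - 20.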
4959
4960 MatchInfo = [=](MachineIRBuilder &B) {
4961 auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width);
4962 auto PosCst = B.buildConstant(Res: ExtractTy, Val: Pos);
4963 B.buildInstr(Opc: ExtrOpcode, DstOps: {Dst}, SrcOps: {ShlSrc, PosCst, WidthCst});
4964 };
4965 return true;
4966}
4967
4968bool CombinerHelper::matchBitfieldExtractFromShrAnd(
4969 MachineInstr &MI,
4970 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
4971 const unsigned Opcode = MI.getOpcode();
4972 assert(Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_ASHR);
4973
4974 const Register Dst = MI.getOperand(i: 0).getReg();
4975 LLT Ty = MRI.getType(Reg: Dst);
4976 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4977 if (LI && !LI->isLegalOrCustom(Query: {TargetOpcode::G_UBFX, {Ty, ExtractTy}}))
4978 return false;
4979
4980 // Try to match shr (and x, c1), c2
4981 Register AndSrc;
4982 int64_t ShrAmt;
4983 int64_t SMask;
4984 if (!mi_match(R: Dst, MRI,
4985 P: m_BinOp(Opcode,
4986 L: m_OneNonDBGUse(SP: m_GAnd(L: m_Reg(R&: AndSrc), R: m_ICst(Cst&: SMask))),
4987 R: m_ICst(Cst&: ShrAmt))))
4988 return false;
4989
4990 const unsigned Size = Ty.getScalarSizeInBits();
4991 if (ShrAmt < 0 || ShrAmt >= Size)
4992 return false;
4993
4994 // If the shift subsumes the mask, emit the 0 directly.
4995 if (0 == (SMask >> ShrAmt)) {
4996 MatchInfo = [=](MachineIRBuilder &B) {
4997 B.buildConstant(Res: Dst, Val: 0);
4998 };
4999 return true;
5000 }
5001
5002 // Check that ubfx can do the extraction, with no holes in the mask.
5003 uint64_t UMask = SMask;
5004 UMask |= maskTrailingOnes<uint64_t>(N: ShrAmt);
5005 UMask &= maskTrailingOnes<uint64_t>(N: Size);
5006 if (!isMask_64(Value: UMask))
5007 return false;
5008
5009 // Calculate start position and width of the extract.
5010 const int64_t Pos = ShrAmt;
5011 const int64_t Width = llvm::countr_one(Value: UMask) - ShrAmt;
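  // E.g. for an s32 value (constants chosen for illustration):
  //   (G_LSHR (G_AND x, 0x00FFFF00), 8) -> (G_UBFX x, 8, 16)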
5012
5013 // It's preferable to keep the shift, rather than form G_SBFX.
5014 // TODO: remove the G_AND via demanded bits analysis.
5015 if (Opcode == TargetOpcode::G_ASHR && Width + ShrAmt == Size)
5016 return false;
5017
5018 MatchInfo = [=](MachineIRBuilder &B) {
5019 auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width);
5020 auto PosCst = B.buildConstant(Res: ExtractTy, Val: Pos);
5021 B.buildInstr(Opc: TargetOpcode::G_UBFX, DstOps: {Dst}, SrcOps: {AndSrc, PosCst, WidthCst});
5022 };
5023 return true;
5024}
5025
5026bool CombinerHelper::reassociationCanBreakAddressingModePattern(
5027 MachineInstr &MI) const {
5028 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
5029
5030 Register Src1Reg = PtrAdd.getBaseReg();
5031 auto *Src1Def = getOpcodeDef<GPtrAdd>(Reg: Src1Reg, MRI);
5032 if (!Src1Def)
5033 return false;
5034
5035 Register Src2Reg = PtrAdd.getOffsetReg();
5036
5037 if (MRI.hasOneNonDBGUse(RegNo: Src1Reg))
5038 return false;
5039
5040 auto C1 = getIConstantVRegVal(VReg: Src1Def->getOffsetReg(), MRI);
5041 if (!C1)
5042 return false;
5043 auto C2 = getIConstantVRegVal(VReg: Src2Reg, MRI);
5044 if (!C2)
5045 return false;
5046
5047 const APInt &C1APIntVal = *C1;
5048 const APInt &C2APIntVal = *C2;
5049 const int64_t CombinedValue = (C1APIntVal + C2APIntVal).getSExtValue();
5050
5051 for (auto &UseMI : MRI.use_nodbg_instructions(Reg: PtrAdd.getReg(Idx: 0))) {
5052 // This combine may end up running before ptrtoint/inttoptr combines
5053 // manage to eliminate redundant conversions, so try to look through them.
5054 MachineInstr *ConvUseMI = &UseMI;
5055 unsigned ConvUseOpc = ConvUseMI->getOpcode();
5056 while (ConvUseOpc == TargetOpcode::G_INTTOPTR ||
5057 ConvUseOpc == TargetOpcode::G_PTRTOINT) {
5058 Register DefReg = ConvUseMI->getOperand(i: 0).getReg();
5059 if (!MRI.hasOneNonDBGUse(RegNo: DefReg))
5060 break;
5061 ConvUseMI = &*MRI.use_instr_nodbg_begin(RegNo: DefReg);
5062 ConvUseOpc = ConvUseMI->getOpcode();
5063 }
5064 auto *LdStMI = dyn_cast<GLoadStore>(Val: ConvUseMI);
5065 if (!LdStMI)
5066 continue;
5067 // Is x[offset2] already not a legal addressing mode? If so then
5068 // reassociating the constants breaks nothing (we test offset2 because
5069 // that's the one we hope to fold into the load or store).
5070 TargetLoweringBase::AddrMode AM;
5071 AM.HasBaseReg = true;
5072 AM.BaseOffs = C2APIntVal.getSExtValue();
5073 unsigned AS = MRI.getType(Reg: LdStMI->getPointerReg()).getAddressSpace();
5074 Type *AccessTy = getTypeForLLT(Ty: LdStMI->getMMO().getMemoryType(),
5075 C&: PtrAdd.getMF()->getFunction().getContext());
5076 const auto &TLI = *PtrAdd.getMF()->getSubtarget().getTargetLowering();
5077 if (!TLI.isLegalAddressingMode(DL: PtrAdd.getMF()->getDataLayout(), AM,
5078 Ty: AccessTy, AddrSpace: AS))
5079 continue;
5080
5081 // Would x[offset1+offset2] still be a legal addressing mode?
5082 AM.BaseOffs = CombinedValue;
5083 if (!TLI.isLegalAddressingMode(DL: PtrAdd.getMF()->getDataLayout(), AM,
5084 Ty: AccessTy, AddrSpace: AS))
5085 return true;
5086 }
5087
5088 return false;
5089}
5090
5091bool CombinerHelper::matchReassocConstantInnerRHS(GPtrAdd &MI,
5092 MachineInstr *RHS,
5093 BuildFnTy &MatchInfo) const {
5094 // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C)
5095 Register Src1Reg = MI.getOperand(i: 1).getReg();
5096 if (RHS->getOpcode() != TargetOpcode::G_ADD)
5097 return false;
5098 auto C2 = getIConstantVRegVal(VReg: RHS->getOperand(i: 2).getReg(), MRI);
5099 if (!C2)
5100 return false;
5101
5102 // If both additions are nuw, the reassociated additions are also nuw.
5103   // If the original G_PTR_ADD is additionally nusw, X and C are both
5104   // non-negative, so BASE+X is between BASE and BASE+(X+C). The new
5105   // G_PTR_ADDs are therefore also nusw.
5106 // If the original G_PTR_ADD is additionally inbounds (which implies nusw),
5107 // the new G_PTR_ADDs are then also inbounds.
5108 unsigned PtrAddFlags = MI.getFlags();
5109 unsigned AddFlags = RHS->getFlags();
5110 bool IsNoUWrap = PtrAddFlags & AddFlags & MachineInstr::MIFlag::NoUWrap;
5111 bool IsNoUSWrap = IsNoUWrap && (PtrAddFlags & MachineInstr::MIFlag::NoUSWrap);
5112 bool IsInBounds = IsNoUWrap && (PtrAddFlags & MachineInstr::MIFlag::InBounds);
5113 unsigned Flags = 0;
5114 if (IsNoUWrap)
5115 Flags |= MachineInstr::MIFlag::NoUWrap;
5116 if (IsNoUSWrap)
5117 Flags |= MachineInstr::MIFlag::NoUSWrap;
5118 if (IsInBounds)
5119 Flags |= MachineInstr::MIFlag::InBounds;
5120
5121 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5122 LLT PtrTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
5123
5124 auto NewBase =
5125 Builder.buildPtrAdd(Res: PtrTy, Op0: Src1Reg, Op1: RHS->getOperand(i: 1).getReg(), Flags);
5126 Observer.changingInstr(MI);
5127 MI.getOperand(i: 1).setReg(NewBase.getReg(Idx: 0));
5128 MI.getOperand(i: 2).setReg(RHS->getOperand(i: 2).getReg());
5129 MI.setFlags(Flags);
5130 Observer.changedInstr(MI);
5131 };
5132 return !reassociationCanBreakAddressingModePattern(MI);
5133}
5134
5135bool CombinerHelper::matchReassocConstantInnerLHS(GPtrAdd &MI,
5136 MachineInstr *LHS,
5137 MachineInstr *RHS,
5138 BuildFnTy &MatchInfo) const {
5139   // (G_PTR_ADD (G_PTR_ADD X, C), Y) -> (G_PTR_ADD (G_PTR_ADD X, Y), C)
5140 // if and only if (G_PTR_ADD X, C) has one use.
5141 Register LHSBase;
5142 std::optional<ValueAndVReg> LHSCstOff;
5143 if (!mi_match(R: MI.getBaseReg(), MRI,
5144 P: m_OneNonDBGUse(SP: m_GPtrAdd(L: m_Reg(R&: LHSBase), R: m_GCst(ValReg&: LHSCstOff)))))
5145 return false;
5146
5147 auto *LHSPtrAdd = cast<GPtrAdd>(Val: LHS);
5148
5149 // Reassociating nuw additions preserves nuw. If both original G_PTR_ADDs are
5150 // nuw and inbounds (which implies nusw), the offsets are both non-negative,
5151 // so the new G_PTR_ADDs are also inbounds.
5152 unsigned PtrAddFlags = MI.getFlags();
5153 unsigned LHSPtrAddFlags = LHSPtrAdd->getFlags();
5154 bool IsNoUWrap = PtrAddFlags & LHSPtrAddFlags & MachineInstr::MIFlag::NoUWrap;
5155 bool IsNoUSWrap = IsNoUWrap && (PtrAddFlags & LHSPtrAddFlags &
5156 MachineInstr::MIFlag::NoUSWrap);
5157 bool IsInBounds = IsNoUWrap && (PtrAddFlags & LHSPtrAddFlags &
5158 MachineInstr::MIFlag::InBounds);
5159 unsigned Flags = 0;
5160 if (IsNoUWrap)
5161 Flags |= MachineInstr::MIFlag::NoUWrap;
5162 if (IsNoUSWrap)
5163 Flags |= MachineInstr::MIFlag::NoUSWrap;
5164 if (IsInBounds)
5165 Flags |= MachineInstr::MIFlag::InBounds;
5166
5167 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5168 // When we change LHSPtrAdd's offset register we might cause it to use a reg
5169     // before its def. Sink the instruction to just before the outer G_PTR_ADD
5170     // to ensure this doesn't happen.
5171 LHSPtrAdd->moveBefore(MovePos: &MI);
5172 Register RHSReg = MI.getOffsetReg();
5173     // Build a fresh constant; LHSCstOff's VReg may have the wrong type if it came from an extend/trunc.
5174 auto NewCst = B.buildConstant(Res: MRI.getType(Reg: RHSReg), Val: LHSCstOff->Value);
5175 Observer.changingInstr(MI);
5176 MI.getOperand(i: 2).setReg(NewCst.getReg(Idx: 0));
5177 MI.setFlags(Flags);
5178 Observer.changedInstr(MI);
5179 Observer.changingInstr(MI&: *LHSPtrAdd);
5180 LHSPtrAdd->getOperand(i: 2).setReg(RHSReg);
5181 LHSPtrAdd->setFlags(Flags);
5182 Observer.changedInstr(MI&: *LHSPtrAdd);
5183 };
5184 return !reassociationCanBreakAddressingModePattern(MI);
5185}
5186
5187bool CombinerHelper::matchReassocFoldConstantsInSubTree(
5188 GPtrAdd &MI, MachineInstr *LHS, MachineInstr *RHS,
5189 BuildFnTy &MatchInfo) const {
5190 // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2)
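  // E.g. (G_PTR_ADD (G_PTR_ADD %base, 16), 8) -> (G_PTR_ADD %base, 24), provided
  // the wider offset doesn't break a legal addressing mode.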
5191 auto *LHSPtrAdd = dyn_cast<GPtrAdd>(Val: LHS);
5192 if (!LHSPtrAdd)
5193 return false;
5194
5195 Register Src2Reg = MI.getOperand(i: 2).getReg();
5196 Register LHSSrc1 = LHSPtrAdd->getBaseReg();
5197 Register LHSSrc2 = LHSPtrAdd->getOffsetReg();
5198 auto C1 = getIConstantVRegVal(VReg: LHSSrc2, MRI);
5199 if (!C1)
5200 return false;
5201 auto C2 = getIConstantVRegVal(VReg: Src2Reg, MRI);
5202 if (!C2)
5203 return false;
5204
5205 // Reassociating nuw additions preserves nuw. If both original G_PTR_ADDs are
5206 // inbounds, reaching the same result in one G_PTR_ADD is also inbounds.
5207 // The nusw constraints are satisfied because imm1+imm2 cannot exceed the
5208 // largest signed integer that fits into the index type, which is the maximum
5209 // size of allocated objects according to the IR Language Reference.
5210 unsigned PtrAddFlags = MI.getFlags();
5211 unsigned LHSPtrAddFlags = LHSPtrAdd->getFlags();
5212 bool IsNoUWrap = PtrAddFlags & LHSPtrAddFlags & MachineInstr::MIFlag::NoUWrap;
5213 bool IsInBounds =
5214 PtrAddFlags & LHSPtrAddFlags & MachineInstr::MIFlag::InBounds;
5215 unsigned Flags = 0;
5216 if (IsNoUWrap)
5217 Flags |= MachineInstr::MIFlag::NoUWrap;
5218 if (IsInBounds) {
5219 Flags |= MachineInstr::MIFlag::InBounds;
5220 Flags |= MachineInstr::MIFlag::NoUSWrap;
5221 }
5222
5223 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5224 auto NewCst = B.buildConstant(Res: MRI.getType(Reg: Src2Reg), Val: *C1 + *C2);
5225 Observer.changingInstr(MI);
5226 MI.getOperand(i: 1).setReg(LHSSrc1);
5227 MI.getOperand(i: 2).setReg(NewCst.getReg(Idx: 0));
5228 MI.setFlags(Flags);
5229 Observer.changedInstr(MI);
5230 };
5231 return !reassociationCanBreakAddressingModePattern(MI);
5232}
5233
5234bool CombinerHelper::matchReassocPtrAdd(MachineInstr &MI,
5235 BuildFnTy &MatchInfo) const {
5236 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
5237 // We're trying to match a few pointer computation patterns here for
5238 // re-association opportunities.
5239 // 1) Isolating a constant operand to be on the RHS, e.g.:
5240 // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C)
5241 //
5242 // 2) Folding two constants in each sub-tree as long as such folding
5243 // doesn't break a legal addressing mode.
5244 // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2)
5245 //
5246 // 3) Move a constant from the LHS of an inner op to the RHS of the outer.
5247   //    (G_PTR_ADD (G_PTR_ADD X, C), Y) -> (G_PTR_ADD (G_PTR_ADD X, Y), C)
5248   //    iff (G_PTR_ADD X, C) has one use.
5249 MachineInstr *LHS = MRI.getVRegDef(Reg: PtrAdd.getBaseReg());
5250 MachineInstr *RHS = MRI.getVRegDef(Reg: PtrAdd.getOffsetReg());
5251
5252 // Try to match example 2.
5253 if (matchReassocFoldConstantsInSubTree(MI&: PtrAdd, LHS, RHS, MatchInfo))
5254 return true;
5255
5256 // Try to match example 3.
5257 if (matchReassocConstantInnerLHS(MI&: PtrAdd, LHS, RHS, MatchInfo))
5258 return true;
5259
5260 // Try to match example 1.
5261 if (matchReassocConstantInnerRHS(MI&: PtrAdd, RHS, MatchInfo))
5262 return true;
5263
5264 return false;
5265}
5266bool CombinerHelper::tryReassocBinOp(unsigned Opc, Register DstReg,
5267 Register OpLHS, Register OpRHS,
5268 BuildFnTy &MatchInfo) const {
5269 LLT OpRHSTy = MRI.getType(Reg: OpRHS);
5270 MachineInstr *OpLHSDef = MRI.getVRegDef(Reg: OpLHS);
5271
5272 if (OpLHSDef->getOpcode() != Opc)
5273 return false;
5274
5275 MachineInstr *OpRHSDef = MRI.getVRegDef(Reg: OpRHS);
5276 Register OpLHSLHS = OpLHSDef->getOperand(i: 1).getReg();
5277 Register OpLHSRHS = OpLHSDef->getOperand(i: 2).getReg();
5278
5279 // If the inner op is (X op C), pull the constant out so it can be folded with
5280 // other constants in the expression tree. Folding is not guaranteed so we
5281 // might have (C1 op C2). In that case do not pull a constant out because it
5282 // won't help and can lead to infinite loops.
5283 if (isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: OpLHSRHS), MRI) &&
5284 !isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: OpLHSLHS), MRI)) {
5285 if (isConstantOrConstantSplatVector(MI&: *OpRHSDef, MRI)) {
5286 // (Opc (Opc X, C1), C2) -> (Opc X, (Opc C1, C2))
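      // E.g. (G_ADD (G_ADD x, 4), 8) -> (G_ADD x, (G_ADD 4, 8)), which later
      // constant folding turns into (G_ADD x, 12).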
5287 MatchInfo = [=](MachineIRBuilder &B) {
5288 auto NewCst = B.buildInstr(Opc, DstOps: {OpRHSTy}, SrcOps: {OpLHSRHS, OpRHS});
5289 B.buildInstr(Opc, DstOps: {DstReg}, SrcOps: {OpLHSLHS, NewCst});
5290 };
5291 return true;
5292 }
5293 if (getTargetLowering().isReassocProfitable(MRI, N0: OpLHS, N1: OpRHS)) {
5294 // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1)
5295 // iff (op x, c1) has one use
5296 MatchInfo = [=](MachineIRBuilder &B) {
5297 auto NewLHSLHS = B.buildInstr(Opc, DstOps: {OpRHSTy}, SrcOps: {OpLHSLHS, OpRHS});
5298 B.buildInstr(Opc, DstOps: {DstReg}, SrcOps: {NewLHSLHS, OpLHSRHS});
5299 };
5300 return true;
5301 }
5302 }
5303
5304 return false;
5305}
5306
5307bool CombinerHelper::matchReassocCommBinOp(MachineInstr &MI,
5308 BuildFnTy &MatchInfo) const {
5309 // We don't check if the reassociation will break a legal addressing mode
5310 // here since pointer arithmetic is handled by G_PTR_ADD.
5311 unsigned Opc = MI.getOpcode();
5312 Register DstReg = MI.getOperand(i: 0).getReg();
5313 Register LHSReg = MI.getOperand(i: 1).getReg();
5314 Register RHSReg = MI.getOperand(i: 2).getReg();
5315
5316 if (tryReassocBinOp(Opc, DstReg, OpLHS: LHSReg, OpRHS: RHSReg, MatchInfo))
5317 return true;
5318 if (tryReassocBinOp(Opc, DstReg, OpLHS: RHSReg, OpRHS: LHSReg, MatchInfo))
5319 return true;
5320 return false;
5321}
5322
5323bool CombinerHelper::matchConstantFoldCastOp(MachineInstr &MI,
5324 APInt &MatchInfo) const {
5325 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
5326 Register SrcOp = MI.getOperand(i: 1).getReg();
5327
5328 if (auto MaybeCst = ConstantFoldCastOp(Opcode: MI.getOpcode(), DstTy, Op0: SrcOp, MRI)) {
5329 MatchInfo = *MaybeCst;
5330 return true;
5331 }
5332
5333 return false;
5334}
5335
5336bool CombinerHelper::matchConstantFoldBinOp(MachineInstr &MI,
5337 APInt &MatchInfo) const {
5338 Register Op1 = MI.getOperand(i: 1).getReg();
5339 Register Op2 = MI.getOperand(i: 2).getReg();
5340 auto MaybeCst = ConstantFoldBinOp(Opcode: MI.getOpcode(), Op1, Op2, MRI);
5341 if (!MaybeCst)
5342 return false;
5343 MatchInfo = *MaybeCst;
5344 return true;
5345}
5346
5347bool CombinerHelper::matchConstantFoldFPBinOp(MachineInstr &MI,
5348 ConstantFP *&MatchInfo) const {
5349 Register Op1 = MI.getOperand(i: 1).getReg();
5350 Register Op2 = MI.getOperand(i: 2).getReg();
5351 auto MaybeCst = ConstantFoldFPBinOp(Opcode: MI.getOpcode(), Op1, Op2, MRI);
5352 if (!MaybeCst)
5353 return false;
5354 MatchInfo =
5355 ConstantFP::get(Context&: MI.getMF()->getFunction().getContext(), V: *MaybeCst);
5356 return true;
5357}
5358
5359bool CombinerHelper::matchConstantFoldFMA(MachineInstr &MI,
5360 ConstantFP *&MatchInfo) const {
5361 assert(MI.getOpcode() == TargetOpcode::G_FMA ||
5362 MI.getOpcode() == TargetOpcode::G_FMAD);
5363 auto [_, Op1, Op2, Op3] = MI.getFirst4Regs();
5364
5365 const ConstantFP *Op3Cst = getConstantFPVRegVal(VReg: Op3, MRI);
5366 if (!Op3Cst)
5367 return false;
5368
5369 const ConstantFP *Op2Cst = getConstantFPVRegVal(VReg: Op2, MRI);
5370 if (!Op2Cst)
5371 return false;
5372
5373 const ConstantFP *Op1Cst = getConstantFPVRegVal(VReg: Op1, MRI);
5374 if (!Op1Cst)
5375 return false;
5376
5377 APFloat Op1F = Op1Cst->getValueAPF();
5378 Op1F.fusedMultiplyAdd(Multiplicand: Op2Cst->getValueAPF(), Addend: Op3Cst->getValueAPF(),
5379 RM: APFloat::rmNearestTiesToEven);
5380 MatchInfo = ConstantFP::get(Context&: MI.getMF()->getFunction().getContext(), V: Op1F);
5381 return true;
5382}
5383
5384bool CombinerHelper::matchNarrowBinopFeedingAnd(
5385 MachineInstr &MI,
5386 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
5387 // Look for a binop feeding into an AND with a mask:
5388 //
5389 // %add = G_ADD %lhs, %rhs
5390 // %and = G_AND %add, 000...11111111
5391 //
5392 // Check if it's possible to perform the binop at a narrower width and zext
5393 // back to the original width like so:
5394 //
5395 // %narrow_lhs = G_TRUNC %lhs
5396 // %narrow_rhs = G_TRUNC %rhs
5397 // %narrow_add = G_ADD %narrow_lhs, %narrow_rhs
5398 // %new_add = G_ZEXT %narrow_add
5399 // %and = G_AND %new_add, 000...11111111
5400 //
5401 // This can allow later combines to eliminate the G_AND if it turns out
5402 // that the mask is irrelevant.
5403 assert(MI.getOpcode() == TargetOpcode::G_AND);
5404 Register Dst = MI.getOperand(i: 0).getReg();
5405 Register AndLHS = MI.getOperand(i: 1).getReg();
5406 Register AndRHS = MI.getOperand(i: 2).getReg();
5407 LLT WideTy = MRI.getType(Reg: Dst);
5408
5409 // If the potential binop has more than one use, then it's possible that one
5410 // of those uses will need its full width.
5411 if (!WideTy.isScalar() || !MRI.hasOneNonDBGUse(RegNo: AndLHS))
5412 return false;
5413
5414 // Check if the LHS feeding the AND is impacted by the high bits that we're
5415 // masking out.
5416 //
5417 // e.g. for 64-bit x, y:
5418 //
5419 // add_64(x, y) & 65535 == zext(add_16(trunc(x), trunc(y))) & 65535
5420 MachineInstr *LHSInst = getDefIgnoringCopies(Reg: AndLHS, MRI);
5421 if (!LHSInst)
5422 return false;
5423 unsigned LHSOpc = LHSInst->getOpcode();
5424 switch (LHSOpc) {
5425 default:
5426 return false;
5427 case TargetOpcode::G_ADD:
5428 case TargetOpcode::G_SUB:
5429 case TargetOpcode::G_MUL:
5430 case TargetOpcode::G_AND:
5431 case TargetOpcode::G_OR:
5432 case TargetOpcode::G_XOR:
5433 break;
5434 }
5435
5436 // Find the mask on the RHS.
5437 auto Cst = getIConstantVRegValWithLookThrough(VReg: AndRHS, MRI);
5438 if (!Cst)
5439 return false;
5440 auto Mask = Cst->Value;
5441 if (!Mask.isMask())
5442 return false;
5443
5444 // No point in combining if there's nothing to truncate.
5445 unsigned NarrowWidth = Mask.countr_one();
5446 if (NarrowWidth == WideTy.getSizeInBits())
5447 return false;
5448 LLT NarrowTy = LLT::scalar(SizeInBits: NarrowWidth);
5449
5450 // Check if adding the zext + truncates could be harmful.
5451 auto &MF = *MI.getMF();
5452 const auto &TLI = getTargetLowering();
5453 LLVMContext &Ctx = MF.getFunction().getContext();
5454 if (!TLI.isTruncateFree(FromTy: WideTy, ToTy: NarrowTy, Ctx) ||
5455 !TLI.isZExtFree(FromTy: NarrowTy, ToTy: WideTy, Ctx))
5456 return false;
5457 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_TRUNC, {NarrowTy, WideTy}}) ||
5458 !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ZEXT, {WideTy, NarrowTy}}))
5459 return false;
5460 Register BinOpLHS = LHSInst->getOperand(i: 1).getReg();
5461 Register BinOpRHS = LHSInst->getOperand(i: 2).getReg();
5462 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5463 auto NarrowLHS = Builder.buildTrunc(Res: NarrowTy, Op: BinOpLHS);
5464 auto NarrowRHS = Builder.buildTrunc(Res: NarrowTy, Op: BinOpRHS);
5465 auto NarrowBinOp =
5466 Builder.buildInstr(Opc: LHSOpc, DstOps: {NarrowTy}, SrcOps: {NarrowLHS, NarrowRHS});
5467 auto Ext = Builder.buildZExt(Res: WideTy, Op: NarrowBinOp);
5468 Observer.changingInstr(MI);
5469 MI.getOperand(i: 1).setReg(Ext.getReg(Idx: 0));
5470 Observer.changedInstr(MI);
5471 };
5472 return true;
5473}
5474
5475bool CombinerHelper::matchMulOBy2(MachineInstr &MI,
5476 BuildFnTy &MatchInfo) const {
5477 unsigned Opc = MI.getOpcode();
5478 assert(Opc == TargetOpcode::G_UMULO || Opc == TargetOpcode::G_SMULO);
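  // (G_UMULO x, 2) -> (G_UADDO x, x) and (G_SMULO x, 2) -> (G_SADDO x, x);
  // multiplying by 2 overflows exactly when adding the value to itself does.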
5479
5480 if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 2)))
5481 return false;
5482
5483 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5484 Observer.changingInstr(MI);
5485 unsigned NewOpc = Opc == TargetOpcode::G_UMULO ? TargetOpcode::G_UADDO
5486 : TargetOpcode::G_SADDO;
5487 MI.setDesc(Builder.getTII().get(Opcode: NewOpc));
5488 MI.getOperand(i: 3).setReg(MI.getOperand(i: 2).getReg());
5489 Observer.changedInstr(MI);
5490 };
5491 return true;
5492}
5493
5494bool CombinerHelper::matchMulOBy0(MachineInstr &MI,
5495 BuildFnTy &MatchInfo) const {
5496 // (G_*MULO x, 0) -> 0 + no carry out
5497 assert(MI.getOpcode() == TargetOpcode::G_UMULO ||
5498 MI.getOpcode() == TargetOpcode::G_SMULO);
5499 if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 0)))
5500 return false;
5501 Register Dst = MI.getOperand(i: 0).getReg();
5502 Register Carry = MI.getOperand(i: 1).getReg();
5503 if (!isConstantLegalOrBeforeLegalizer(Ty: MRI.getType(Reg: Dst)) ||
5504 !isConstantLegalOrBeforeLegalizer(Ty: MRI.getType(Reg: Carry)))
5505 return false;
5506 MatchInfo = [=](MachineIRBuilder &B) {
5507 B.buildConstant(Res: Dst, Val: 0);
5508 B.buildConstant(Res: Carry, Val: 0);
5509 };
5510 return true;
5511}
5512
5513bool CombinerHelper::matchAddEToAddO(MachineInstr &MI,
5514 BuildFnTy &MatchInfo) const {
5515 // (G_*ADDE x, y, 0) -> (G_*ADDO x, y)
5516 // (G_*SUBE x, y, 0) -> (G_*SUBO x, y)
5517 assert(MI.getOpcode() == TargetOpcode::G_UADDE ||
5518 MI.getOpcode() == TargetOpcode::G_SADDE ||
5519 MI.getOpcode() == TargetOpcode::G_USUBE ||
5520 MI.getOpcode() == TargetOpcode::G_SSUBE);
5521 if (!mi_match(R: MI.getOperand(i: 4).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 0)))
5522 return false;
5523 MatchInfo = [&](MachineIRBuilder &B) {
5524 unsigned NewOpcode;
5525 switch (MI.getOpcode()) {
5526 case TargetOpcode::G_UADDE:
5527 NewOpcode = TargetOpcode::G_UADDO;
5528 break;
5529 case TargetOpcode::G_SADDE:
5530 NewOpcode = TargetOpcode::G_SADDO;
5531 break;
5532 case TargetOpcode::G_USUBE:
5533 NewOpcode = TargetOpcode::G_USUBO;
5534 break;
5535 case TargetOpcode::G_SSUBE:
5536 NewOpcode = TargetOpcode::G_SSUBO;
5537 break;
5538 }
5539 Observer.changingInstr(MI);
5540 MI.setDesc(B.getTII().get(Opcode: NewOpcode));
5541 MI.removeOperand(OpNo: 4);
5542 Observer.changedInstr(MI);
5543 };
5544 return true;
5545}
5546
5547bool CombinerHelper::matchSubAddSameReg(MachineInstr &MI,
5548 BuildFnTy &MatchInfo) const {
5549 assert(MI.getOpcode() == TargetOpcode::G_SUB);
5550 Register Dst = MI.getOperand(i: 0).getReg();
5551 // (x + y) - z -> x (if y == z)
5552 // (x + y) - z -> y (if x == z)
5553 Register X, Y, Z;
5554 if (mi_match(R: Dst, MRI, P: m_GSub(L: m_GAdd(L: m_Reg(R&: X), R: m_Reg(R&: Y)), R: m_Reg(R&: Z)))) {
5555 Register ReplaceReg;
5556 int64_t CstX, CstY;
5557 if (Y == Z || (mi_match(R: Y, MRI, P: m_ICstOrSplat(Cst&: CstY)) &&
5558 mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstY))))
5559 ReplaceReg = X;
5560 else if (X == Z || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) &&
5561 mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX))))
5562 ReplaceReg = Y;
5563 if (ReplaceReg) {
5564 MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Res: Dst, Op: ReplaceReg); };
5565 return true;
5566 }
5567 }
5568
5569 // x - (y + z) -> 0 - y (if x == z)
5570 // x - (y + z) -> 0 - z (if x == y)
5571 if (mi_match(R: Dst, MRI, P: m_GSub(L: m_Reg(R&: X), R: m_GAdd(L: m_Reg(R&: Y), R: m_Reg(R&: Z))))) {
5572 Register ReplaceReg;
5573 int64_t CstX;
5574 if (X == Z || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) &&
5575 mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX))))
5576 ReplaceReg = Y;
5577 else if (X == Y || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) &&
5578 mi_match(R: Y, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX))))
5579 ReplaceReg = Z;
5580 if (ReplaceReg) {
5581 MatchInfo = [=](MachineIRBuilder &B) {
5582 auto Zero = B.buildConstant(Res: MRI.getType(Reg: Dst), Val: 0);
5583 B.buildSub(Dst, Src0: Zero, Src1: ReplaceReg);
5584 };
5585 return true;
5586 }
5587 }
5588 return false;
5589}
5590
5591MachineInstr *CombinerHelper::buildUDivOrURemUsingMul(MachineInstr &MI) const {
5592 unsigned Opcode = MI.getOpcode();
5593 assert(Opcode == TargetOpcode::G_UDIV || Opcode == TargetOpcode::G_UREM);
5594 auto &UDivorRem = cast<GenericMachineInstr>(Val&: MI);
5595 Register Dst = UDivorRem.getReg(Idx: 0);
5596 Register LHS = UDivorRem.getReg(Idx: 1);
5597 Register RHS = UDivorRem.getReg(Idx: 2);
5598 LLT Ty = MRI.getType(Reg: Dst);
5599 LLT ScalarTy = Ty.getScalarType();
5600 const unsigned EltBits = ScalarTy.getScalarSizeInBits();
5601 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
5602 LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType();
5603
5604 auto &MIB = Builder;
5605
5606 bool UseSRL = false;
5607 SmallVector<Register, 16> Shifts, Factors;
5608 auto *RHSDefInstr = cast<GenericMachineInstr>(Val: getDefIgnoringCopies(Reg: RHS, MRI));
5609 bool IsSplat = getIConstantSplatVal(MI: *RHSDefInstr, MRI).has_value();
5610
5611 auto BuildExactUDIVPattern = [&](const Constant *C) {
5612 // Don't recompute inverses for each splat element.
5613 if (IsSplat && !Factors.empty()) {
5614 Shifts.push_back(Elt: Shifts[0]);
5615 Factors.push_back(Elt: Factors[0]);
5616 return true;
5617 }
5618
5619 auto *CI = cast<ConstantInt>(Val: C);
5620 APInt Divisor = CI->getValue();
5621 unsigned Shift = Divisor.countr_zero();
5622 if (Shift) {
5623 Divisor.lshrInPlace(ShiftAmt: Shift);
5624 UseSRL = true;
5625 }
5626
5627 // Calculate the multiplicative inverse modulo BW.
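    // E.g. an exact s32 udiv by 6 becomes a logical shift right by 1 followed by
    // a multiply by 0xAAAAAAAB, the multiplicative inverse of 3 modulo 2^32.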
5628 APInt Factor = Divisor.multiplicativeInverse();
5629 Shifts.push_back(Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: Shift).getReg(Idx: 0));
5630 Factors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Factor).getReg(Idx: 0));
5631 return true;
5632 };
5633
5634 if (MI.getFlag(Flag: MachineInstr::MIFlag::IsExact)) {
5635 // Collect all magic values from the build vector.
5636 if (!matchUnaryPredicate(MRI, Reg: RHS, Match: BuildExactUDIVPattern))
5637 llvm_unreachable("Expected unary predicate match to succeed");
5638
5639 Register Shift, Factor;
5640 if (Ty.isVector()) {
5641 Shift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: Shifts).getReg(Idx: 0);
5642 Factor = MIB.buildBuildVector(Res: Ty, Ops: Factors).getReg(Idx: 0);
5643 } else {
5644 Shift = Shifts[0];
5645 Factor = Factors[0];
5646 }
5647
5648 Register Res = LHS;
5649
5650 if (UseSRL)
5651 Res = MIB.buildLShr(Dst: Ty, Src0: Res, Src1: Shift, Flags: MachineInstr::IsExact).getReg(Idx: 0);
5652
5653 return MIB.buildMul(Dst: Ty, Src0: Res, Src1: Factor);
5654 }
5655
5656 unsigned KnownLeadingZeros =
5657 VT ? VT->getKnownBits(R: LHS).countMinLeadingZeros() : 0;
5658
5659 bool UseNPQ = false;
5660 SmallVector<Register, 16> PreShifts, PostShifts, MagicFactors, NPQFactors;
5661 auto BuildUDIVPattern = [&](const Constant *C) {
5662 auto *CI = cast<ConstantInt>(Val: C);
5663 const APInt &Divisor = CI->getValue();
5664
5665 bool SelNPQ = false;
5666 APInt Magic(Divisor.getBitWidth(), 0);
5667 unsigned PreShift = 0, PostShift = 0;
5668
5669 // Magic algorithm doesn't work for division by 1. We need to emit a select
5670 // at the end.
5671 // TODO: Use undef values for divisor of 1.
5672 if (!Divisor.isOne()) {
5673
5674 // UnsignedDivisionByConstantInfo doesn't work correctly if leading zeros
5675       // in the dividend exceed the leading zeros for the divisor.
5676 UnsignedDivisionByConstantInfo magics =
5677 UnsignedDivisionByConstantInfo::get(
5678 D: Divisor, LeadingZeros: std::min(a: KnownLeadingZeros, b: Divisor.countl_zero()));
5679
5680 Magic = std::move(magics.Magic);
5681
5682 assert(magics.PreShift < Divisor.getBitWidth() &&
5683 "We shouldn't generate an undefined shift!");
5684 assert(magics.PostShift < Divisor.getBitWidth() &&
5685 "We shouldn't generate an undefined shift!");
5686 assert((!magics.IsAdd || magics.PreShift == 0) && "Unexpected pre-shift");
5687 PreShift = magics.PreShift;
5688 PostShift = magics.PostShift;
5689 SelNPQ = magics.IsAdd;
5690 }
5691
5692 PreShifts.push_back(
5693 Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: PreShift).getReg(Idx: 0));
5694 MagicFactors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Magic).getReg(Idx: 0));
5695 NPQFactors.push_back(
5696 Elt: MIB.buildConstant(Res: ScalarTy,
5697 Val: SelNPQ ? APInt::getOneBitSet(numBits: EltBits, BitNo: EltBits - 1)
5698 : APInt::getZero(numBits: EltBits))
5699 .getReg(Idx: 0));
5700 PostShifts.push_back(
5701 Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: PostShift).getReg(Idx: 0));
5702 UseNPQ |= SelNPQ;
5703 return true;
5704 };
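  // As an illustration, an s32 udiv by 5 typically yields PreShift = 0,
  // MagicFactor = 0xCCCCCCCD, no NPQ fixup, and PostShift = 2.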
5705
5706 // Collect the shifts/magic values from each element.
5707 bool Matched = matchUnaryPredicate(MRI, Reg: RHS, Match: BuildUDIVPattern);
5708 (void)Matched;
5709 assert(Matched && "Expected unary predicate match to succeed");
5710
5711 Register PreShift, PostShift, MagicFactor, NPQFactor;
5712 auto *RHSDef = getOpcodeDef<GBuildVector>(Reg: RHS, MRI);
5713 if (RHSDef) {
5714 PreShift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: PreShifts).getReg(Idx: 0);
5715 MagicFactor = MIB.buildBuildVector(Res: Ty, Ops: MagicFactors).getReg(Idx: 0);
5716 NPQFactor = MIB.buildBuildVector(Res: Ty, Ops: NPQFactors).getReg(Idx: 0);
5717 PostShift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: PostShifts).getReg(Idx: 0);
5718 } else {
5719 assert(MRI.getType(RHS).isScalar() &&
5720 "Non-build_vector operation should have been a scalar");
5721 PreShift = PreShifts[0];
5722 MagicFactor = MagicFactors[0];
5723 PostShift = PostShifts[0];
5724 }
5725
5726 Register Q = LHS;
5727 Q = MIB.buildLShr(Dst: Ty, Src0: Q, Src1: PreShift).getReg(Idx: 0);
5728
5729 // Multiply the numerator (operand 0) by the magic value.
5730 Q = MIB.buildUMulH(Dst: Ty, Src0: Q, Src1: MagicFactor).getReg(Idx: 0);
5731
5732 if (UseNPQ) {
5733 Register NPQ = MIB.buildSub(Dst: Ty, Src0: LHS, Src1: Q).getReg(Idx: 0);
5734
5735 // For vectors we might have a mix of non-NPQ/NPQ paths, so use
5736 // G_UMULH to act as a SRL-by-1 for NPQ, else multiply by zero.
5737 if (Ty.isVector())
5738 NPQ = MIB.buildUMulH(Dst: Ty, Src0: NPQ, Src1: NPQFactor).getReg(Idx: 0);
5739 else
5740 NPQ = MIB.buildLShr(Dst: Ty, Src0: NPQ, Src1: MIB.buildConstant(Res: ShiftAmtTy, Val: 1)).getReg(Idx: 0);
5741
5742 Q = MIB.buildAdd(Dst: Ty, Src0: NPQ, Src1: Q).getReg(Idx: 0);
5743 }
5744
5745 Q = MIB.buildLShr(Dst: Ty, Src0: Q, Src1: PostShift).getReg(Idx: 0);
5746 auto One = MIB.buildConstant(Res: Ty, Val: 1);
5747 auto IsOne = MIB.buildICmp(
5748 Pred: CmpInst::Predicate::ICMP_EQ,
5749 Res: Ty.isScalar() ? LLT::scalar(SizeInBits: 1) : Ty.changeElementSize(NewEltSize: 1), Op0: RHS, Op1: One);
5750 auto ret = MIB.buildSelect(Res: Ty, Tst: IsOne, Op0: LHS, Op1: Q);
5751
5752 if (Opcode == TargetOpcode::G_UREM) {
5753 auto Prod = MIB.buildMul(Dst: Ty, Src0: ret, Src1: RHS);
5754 return MIB.buildSub(Dst: Ty, Src0: LHS, Src1: Prod);
5755 }
5756 return ret;
5757}
5758
5759bool CombinerHelper::matchUDivOrURemByConst(MachineInstr &MI) const {
5760 unsigned Opcode = MI.getOpcode();
5761 assert(Opcode == TargetOpcode::G_UDIV || Opcode == TargetOpcode::G_UREM);
5762 Register Dst = MI.getOperand(i: 0).getReg();
5763 Register RHS = MI.getOperand(i: 2).getReg();
5764 LLT DstTy = MRI.getType(Reg: Dst);
5765
5766 auto &MF = *MI.getMF();
5767 AttributeList Attr = MF.getFunction().getAttributes();
5768 const auto &TLI = getTargetLowering();
5769 LLVMContext &Ctx = MF.getFunction().getContext();
5770 if (DstTy.getScalarSizeInBits() == 1 ||
5771 TLI.isIntDivCheap(VT: getApproximateEVTForLLT(Ty: DstTy, Ctx), Attr))
5772 return false;
5773
5774 // Don't do this for minsize because the instruction sequence is usually
5775 // larger.
5776 if (MF.getFunction().hasMinSize())
5777 return false;
5778
5779 if (Opcode == TargetOpcode::G_UDIV &&
5780 MI.getFlag(Flag: MachineInstr::MIFlag::IsExact)) {
5781 return matchUnaryPredicate(
5782 MRI, Reg: RHS, Match: [](const Constant *C) { return C && !C->isNullValue(); });
5783 }
5784
5785 auto *RHSDef = MRI.getVRegDef(Reg: RHS);
5786 if (!isConstantOrConstantVector(MI&: *RHSDef, MRI))
5787 return false;
5788
5789 // Don't do this if the types are not going to be legal.
5790 if (LI) {
5791 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_MUL, {DstTy, DstTy}}))
5792 return false;
5793 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_UMULH, {DstTy}}))
5794 return false;
5795 if (!isLegalOrBeforeLegalizer(
5796 Query: {TargetOpcode::G_ICMP,
5797 {DstTy.isVector() ? DstTy.changeElementSize(NewEltSize: 1) : LLT::scalar(SizeInBits: 1),
5798 DstTy}}))
5799 return false;
5800 if (Opcode == TargetOpcode::G_UREM &&
5801 !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SUB, {DstTy, DstTy}}))
5802 return false;
5803 }
5804
5805 return matchUnaryPredicate(
5806 MRI, Reg: RHS, Match: [](const Constant *C) { return C && !C->isNullValue(); });
5807}
5808
5809void CombinerHelper::applyUDivOrURemByConst(MachineInstr &MI) const {
5810 auto *NewMI = buildUDivOrURemUsingMul(MI);
5811 replaceSingleDefInstWithReg(MI, Replacement: NewMI->getOperand(i: 0).getReg());
5812}
5813
5814bool CombinerHelper::matchSDivOrSRemByConst(MachineInstr &MI) const {
5815 unsigned Opcode = MI.getOpcode();
5816 assert(Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_SREM);
5817 Register Dst = MI.getOperand(i: 0).getReg();
5818 Register RHS = MI.getOperand(i: 2).getReg();
5819 LLT DstTy = MRI.getType(Reg: Dst);
5820 auto SizeInBits = DstTy.getScalarSizeInBits();
5821 LLT WideTy = DstTy.changeElementSize(NewEltSize: SizeInBits * 2);
5822
5823 auto &MF = *MI.getMF();
5824 AttributeList Attr = MF.getFunction().getAttributes();
5825 const auto &TLI = getTargetLowering();
5826 LLVMContext &Ctx = MF.getFunction().getContext();
5827 if (DstTy.getScalarSizeInBits() < 3 ||
5828 TLI.isIntDivCheap(VT: getApproximateEVTForLLT(Ty: DstTy, Ctx), Attr))
5829 return false;
5830
5831 // Don't do this for minsize because the instruction sequence is usually
5832 // larger.
5833 if (MF.getFunction().hasMinSize())
5834 return false;
5835
5836 // If the sdiv has an 'exact' flag we can use a simpler lowering.
5837 if (Opcode == TargetOpcode::G_SDIV &&
5838 MI.getFlag(Flag: MachineInstr::MIFlag::IsExact)) {
5839 return matchUnaryPredicate(
5840 MRI, Reg: RHS, Match: [](const Constant *C) { return C && !C->isNullValue(); });
5841 }
5842
5843 auto *RHSDef = MRI.getVRegDef(Reg: RHS);
5844 if (!isConstantOrConstantVector(MI&: *RHSDef, MRI))
5845 return false;
5846
5847 // Don't do this if the types are not going to be legal.
5848 if (LI) {
5849 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_MUL, {DstTy, DstTy}}))
5850 return false;
5851 if (!isLegal(Query: {TargetOpcode::G_SMULH, {DstTy}}) &&
5852 !isLegalOrHasWidenScalar(Query: {TargetOpcode::G_MUL, {WideTy, WideTy}}))
5853 return false;
5854 if (Opcode == TargetOpcode::G_SREM &&
5855 !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SUB, {DstTy, DstTy}}))
5856 return false;
5857 }
5858
5859 return matchUnaryPredicate(
5860 MRI, Reg: RHS, Match: [](const Constant *C) { return C && !C->isNullValue(); });
5861}
5862
5863void CombinerHelper::applySDivOrSRemByConst(MachineInstr &MI) const {
5864 auto *NewMI = buildSDivOrSRemUsingMul(MI);
5865 replaceSingleDefInstWithReg(MI, Replacement: NewMI->getOperand(i: 0).getReg());
5866}
5867
5868MachineInstr *CombinerHelper::buildSDivOrSRemUsingMul(MachineInstr &MI) const {
5869 unsigned Opcode = MI.getOpcode();
5870 assert(Opcode == TargetOpcode::G_SDIV ||
5871 Opcode == TargetOpcode::G_SREM);
5872 auto &SDivorRem = cast<GenericMachineInstr>(Val&: MI);
5873 Register Dst = SDivorRem.getReg(Idx: 0);
5874 Register LHS = SDivorRem.getReg(Idx: 1);
5875 Register RHS = SDivorRem.getReg(Idx: 2);
5876 LLT Ty = MRI.getType(Reg: Dst);
5877 LLT ScalarTy = Ty.getScalarType();
5878 const unsigned EltBits = ScalarTy.getScalarSizeInBits();
5879 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
5880 LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType();
5881 auto &MIB = Builder;
5882
5883 bool UseSRA = false;
5884 SmallVector<Register, 16> ExactShifts, ExactFactors;
5885
5886 auto *RHSDefInstr = cast<GenericMachineInstr>(Val: getDefIgnoringCopies(Reg: RHS, MRI));
5887 bool IsSplat = getIConstantSplatVal(MI: *RHSDefInstr, MRI).has_value();
5888
5889 auto BuildExactSDIVPattern = [&](const Constant *C) {
5890 // Don't recompute inverses for each splat element.
5891 if (IsSplat && !ExactFactors.empty()) {
5892 ExactShifts.push_back(Elt: ExactShifts[0]);
5893 ExactFactors.push_back(Elt: ExactFactors[0]);
5894 return true;
5895 }
5896
5897 auto *CI = cast<ConstantInt>(Val: C);
5898 APInt Divisor = CI->getValue();
5899 unsigned Shift = Divisor.countr_zero();
5900 if (Shift) {
5901 Divisor.ashrInPlace(ShiftAmt: Shift);
5902 UseSRA = true;
5903 }
5904
5905 // Calculate the multiplicative inverse of the now-odd divisor modulo
5906 // 2^BW; multiplying by it recovers the exact quotient.
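    // For example, an exact i32 division by 6 (= 2 * 3) becomes an arithmetic
    // shift right by 1 followed by a multiply by 0xAAAAAAAB, since
    // 3 * 0xAAAAAAAB == 1 (mod 2^32).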
5907 APInt Factor = Divisor.multiplicativeInverse();
5908 ExactShifts.push_back(Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: Shift).getReg(Idx: 0));
5909 ExactFactors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Factor).getReg(Idx: 0));
5910 return true;
5911 };
5912
5913 if (MI.getFlag(Flag: MachineInstr::MIFlag::IsExact)) {
5914 // Collect the shift amounts and multiplicative-inverse factors for each
     // divisor element.
5915 bool Matched = matchUnaryPredicate(MRI, Reg: RHS, Match: BuildExactSDIVPattern);
5916 (void)Matched;
5917 assert(Matched && "Expected unary predicate match to succeed");
5918
5919 Register Shift, Factor;
5920 if (Ty.isVector()) {
5921 Shift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: ExactShifts).getReg(Idx: 0);
5922 Factor = MIB.buildBuildVector(Res: Ty, Ops: ExactFactors).getReg(Idx: 0);
5923 } else {
5924 Shift = ExactShifts[0];
5925 Factor = ExactFactors[0];
5926 }
5927
5928 Register Res = LHS;
5929
5930 if (UseSRA)
5931 Res = MIB.buildAShr(Dst: Ty, Src0: Res, Src1: Shift, Flags: MachineInstr::IsExact).getReg(Idx: 0);
5932
5933 return MIB.buildMul(Dst: Ty, Src0: Res, Src1: Factor);
5934 }
5935
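  // General case: the classic "magic number" signed-division lowering from
  // Hacker's Delight. Take the high half of LHS * Magic, add LHS times a
  // correction factor (0, +1 or -1), shift right arithmetically, then add the
  // masked sign bit of that result so the quotient truncates toward zero.
  // For a 32-bit division by 3, for instance, the computed Magic is
  // 0x55555556 with ShiftAmount 0 and no numerator correction.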
5936 SmallVector<Register, 16> MagicFactors, Factors, Shifts, ShiftMasks;
5937
5938 auto BuildSDIVPattern = [&](const Constant *C) {
5939 auto *CI = cast<ConstantInt>(Val: C);
5940 const APInt &Divisor = CI->getValue();
5941
5942 SignedDivisionByConstantInfo Magics =
5943 SignedDivisionByConstantInfo::get(D: Divisor);
5944 int NumeratorFactor = 0;
5945 int ShiftMask = -1;
5946
5947 if (Divisor.isOne() || Divisor.isAllOnes()) {
5948 // If d is +1/-1, we just multiply the numerator by +1/-1.
5949 NumeratorFactor = Divisor.getSExtValue();
5950 Magics.Magic = 0;
5951 Magics.ShiftAmount = 0;
5952 ShiftMask = 0;
5953 } else if (Divisor.isStrictlyPositive() && Magics.Magic.isNegative()) {
5954 // If d > 0 and m < 0, add the numerator.
5955 NumeratorFactor = 1;
5956 } else if (Divisor.isNegative() && Magics.Magic.isStrictlyPositive()) {
5957 // If d < 0 and m > 0, subtract the numerator.
5958 NumeratorFactor = -1;
5959 }
5960
5961 MagicFactors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Magics.Magic).getReg(Idx: 0));
5962 Factors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: NumeratorFactor).getReg(Idx: 0));
5963 Shifts.push_back(
5964 Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: Magics.ShiftAmount).getReg(Idx: 0));
5965 ShiftMasks.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: ShiftMask).getReg(Idx: 0));
5966
5967 return true;
5968 };
5969
5970 // Collect the shifts/magic values from each element.
5971 bool Matched = matchUnaryPredicate(MRI, Reg: RHS, Match: BuildSDIVPattern);
5972 (void)Matched;
5973 assert(Matched && "Expected unary predicate match to succeed");
5974
5975 Register MagicFactor, Factor, Shift, ShiftMask;
5976 auto *RHSDef = getOpcodeDef<GBuildVector>(Reg: RHS, MRI);
5977 if (RHSDef) {
5978 MagicFactor = MIB.buildBuildVector(Res: Ty, Ops: MagicFactors).getReg(Idx: 0);
5979 Factor = MIB.buildBuildVector(Res: Ty, Ops: Factors).getReg(Idx: 0);
5980 Shift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: Shifts).getReg(Idx: 0);
5981 ShiftMask = MIB.buildBuildVector(Res: Ty, Ops: ShiftMasks).getReg(Idx: 0);
5982 } else {
5983 assert(MRI.getType(RHS).isScalar() &&
5984 "Non-build_vector operation should have been a scalar");
5985 MagicFactor = MagicFactors[0];
5986 Factor = Factors[0];
5987 Shift = Shifts[0];
5988 ShiftMask = ShiftMasks[0];
5989 }
5990
5991 Register Q = MIB.buildSMulH(Dst: Ty, Src0: LHS, Src1: MagicFactor).getReg(Idx: 0);
5993
5994 // (Optionally) Add/subtract the numerator using Factor.
5995 Factor = MIB.buildMul(Dst: Ty, Src0: LHS, Src1: Factor).getReg(Idx: 0);
5996 Q = MIB.buildAdd(Dst: Ty, Src0: Q, Src1: Factor).getReg(Idx: 0);
5997
5998 // Shift right algebraic by shift value.
5999 Q = MIB.buildAShr(Dst: Ty, Src0: Q, Src1: Shift).getReg(Idx: 0);
6000
6001 // Extract the sign bit, mask it and add it to the quotient.
6002 auto SignShift = MIB.buildConstant(Res: ShiftAmtTy, Val: EltBits - 1);
6003 auto T = MIB.buildLShr(Dst: Ty, Src0: Q, Src1: SignShift);
6004 T = MIB.buildAnd(Dst: Ty, Src0: T, Src1: ShiftMask);
6005 auto ret = MIB.buildAdd(Dst: Ty, Src0: Q, Src1: T);
6006
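  // For G_SREM, fold the quotient back: rem = LHS - (LHS / RHS) * RHS.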
6007 if (Opcode == TargetOpcode::G_SREM) {
6008 auto Prod = MIB.buildMul(Dst: Ty, Src0: ret, Src1: RHS);
6009 return MIB.buildSub(Dst: Ty, Src0: LHS, Src1: Prod);
6010 }
6011 return ret;
6012}
6013
6014bool CombinerHelper::matchDivByPow2(MachineInstr &MI, bool IsSigned) const {
6015 assert((MI.getOpcode() == TargetOpcode::G_SDIV ||
6016 MI.getOpcode() == TargetOpcode::G_UDIV) &&
6017 "Expected SDIV or UDIV");
6018 auto &Div = cast<GenericMachineInstr>(Val&: MI);
6019 Register RHS = Div.getReg(Idx: 2);
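  // For signed division a negated power of two also qualifies: it has the
  // same number of trailing zeros as its absolute value, and applySDivByPow2
  // negates the result for negative divisors.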
6020 auto MatchPow2 = [&](const Constant *C) {
6021 auto *CI = dyn_cast<ConstantInt>(Val: C);
6022 return CI && (CI->getValue().isPowerOf2() ||
6023 (IsSigned && CI->getValue().isNegatedPowerOf2()));
6024 };
6025 return matchUnaryPredicate(MRI, Reg: RHS, Match: MatchPow2, /*AllowUndefs=*/false);
6026}
6027
6028void CombinerHelper::applySDivByPow2(MachineInstr &MI) const {
6029 assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV");
6030 auto &SDiv = cast<GenericMachineInstr>(Val&: MI);
6031 Register Dst = SDiv.getReg(Idx: 0);
6032 Register LHS = SDiv.getReg(Idx: 1);
6033 Register RHS = SDiv.getReg(Idx: 2);
6034 LLT Ty = MRI.getType(Reg: Dst);
6035 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
6036 LLT CCVT =
6037 Ty.isVector() ? LLT::vector(EC: Ty.getElementCount(), ScalarSizeInBits: 1) : LLT::scalar(SizeInBits: 1);
6038
6039 // Effectively we want to lower G_SDIV %lhs, %rhs, where %rhs is a power of 2,
6040 // to the following version:
6041 //
6042 // %c1 = G_CTTZ %rhs
6043 // %inexact = G_SUB $bitwidth, %c1
6044 // %sign = G_ASHR %lhs, $(bitwidth - 1)
6045 // %lshr = G_LSHR %sign, %inexact
6046 // %add = G_ADD %lhs, %lshr
6047 // %ashr = G_ASHR %add, %c1
6048 // %ashr = G_SELECT %isoneorallones, %lhs, %ashr
6049 // %zero = G_CONSTANT $0
6050 // %neg = G_NEG %ashr
6051 // %isneg = G_ICMP SLT %rhs, %zero
6052 // %res = G_SELECT %isneg, %neg, %ashr
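  //
  // For example, with i32 %rhs = 4 (%c1 = 2): %inexact = 30 and %lshr becomes
  // (%lhs a>> 31) u>> 30, i.e. 3 for a negative %lhs and 0 otherwise, so the
  // G_ADD biases negative dividends before the final arithmetic shift and the
  // quotient rounds toward zero (-7 / 4 == -1 rather than -2).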
6053
6054 unsigned BitWidth = Ty.getScalarSizeInBits();
6055 auto Zero = Builder.buildConstant(Res: Ty, Val: 0);
6056
6057 auto Bits = Builder.buildConstant(Res: ShiftAmtTy, Val: BitWidth);
6058 auto C1 = Builder.buildCTTZ(Dst: ShiftAmtTy, Src0: RHS);
6059 auto Inexact = Builder.buildSub(Dst: ShiftAmtTy, Src0: Bits, Src1: C1);
6060 // Splat the sign bit into the register
6061 auto Sign = Builder.buildAShr(
6062 Dst: Ty, Src0: LHS, Src1: Builder.buildConstant(Res: ShiftAmtTy, Val: BitWidth - 1));
6063
6064 // Add (LHS < 0) ? (2^%c1 - 1) : 0 so that negative dividends round toward zero.
6065 auto LSrl = Builder.buildLShr(Dst: Ty, Src0: Sign, Src1: Inexact);
6066 auto Add = Builder.buildAdd(Dst: Ty, Src0: LHS, Src1: LSrl);
6067 auto AShr = Builder.buildAShr(Dst: Ty, Src0: Add, Src1: C1);
6068
6069 // Special case: (sdiv X, 1) -> X
6070 // Special Case: (sdiv X, -1) -> 0-X
6071 auto One = Builder.buildConstant(Res: Ty, Val: 1);
6072 auto MinusOne = Builder.buildConstant(Res: Ty, Val: -1);
6073 auto IsOne = Builder.buildICmp(Pred: CmpInst::Predicate::ICMP_EQ, Res: CCVT, Op0: RHS, Op1: One);
6074 auto IsMinusOne =
6075 Builder.buildICmp(Pred: CmpInst::Predicate::ICMP_EQ, Res: CCVT, Op0: RHS, Op1: MinusOne);
6076 auto IsOneOrMinusOne = Builder.buildOr(Dst: CCVT, Src0: IsOne, Src1: IsMinusOne);
6077 AShr = Builder.buildSelect(Res: Ty, Tst: IsOneOrMinusOne, Op0: LHS, Op1: AShr);
6078
6079 // If divided by a positive value, we're done. Otherwise, the result must be
6080 // negated.
6081 auto Neg = Builder.buildNeg(Dst: Ty, Src0: AShr);
6082 auto IsNeg = Builder.buildICmp(Pred: CmpInst::Predicate::ICMP_SLT, Res: CCVT, Op0: RHS, Op1: Zero);
6083 Builder.buildSelect(Res: MI.getOperand(i: 0).getReg(), Tst: IsNeg, Op0: Neg, Op1: AShr);
6084 MI.eraseFromParent();
6085}
6086
6087void CombinerHelper::applyUDivByPow2(MachineInstr &MI) const {
6088 assert(MI.getOpcode() == TargetOpcode::G_UDIV && "Expected UDIV");
6089 auto &UDiv = cast<GenericMachineInstr>(Val&: MI);
6090 Register Dst = UDiv.getReg(Idx: 0);
6091 Register LHS = UDiv.getReg(Idx: 1);
6092 Register RHS = UDiv.getReg(Idx: 2);
6093 LLT Ty = MRI.getType(Reg: Dst);
6094 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
6095
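  // An unsigned divide by a power of two is a logical shift right by
  // log2(RHS), and for a power of two log2(RHS) == cttz(RHS).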
6096 auto C1 = Builder.buildCTTZ(Dst: ShiftAmtTy, Src0: RHS);
6097 Builder.buildLShr(Dst: MI.getOperand(i: 0).getReg(), Src0: LHS, Src1: C1);
6098 MI.eraseFromParent();
6099}
6100
6101bool CombinerHelper::matchUMulHToLShr(MachineInstr &MI) const {
6102 assert(MI.getOpcode() == TargetOpcode::G_UMULH);
6103 Register RHS = MI.getOperand(i: 2).getReg();
6104 Register Dst = MI.getOperand(i: 0).getReg();
6105 LLT Ty = MRI.getType(Reg: Dst);
6106 LLT RHSTy = MRI.getType(Reg: RHS);
6107 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
6108 auto MatchPow2ExceptOne = [&](const Constant *C) {
6109 if (auto *CI = dyn_cast<ConstantInt>(Val: C))
6110 return CI->getValue().isPowerOf2() && !CI->getValue().isOne();
6111 return false;
6112 };
6113 if (!matchUnaryPredicate(MRI, Reg: RHS, Match: MatchPow2ExceptOne, AllowUndefs: false))
6114 return false;
6115 // We need to check both G_LSHR and G_CTLZ because the combine uses G_CTLZ to
6116 // get the log base 2, and G_CTLZ is not always legal on a target.
6117 return isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_LSHR, {Ty, ShiftAmtTy}}) &&
6118 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_CTLZ, {RHSTy, RHSTy}});
6119}
6120
6121void CombinerHelper::applyUMulHToLShr(MachineInstr &MI) const {
6122 Register LHS = MI.getOperand(i: 1).getReg();
6123 Register RHS = MI.getOperand(i: 2).getReg();
6124 Register Dst = MI.getOperand(i: 0).getReg();
6125 LLT Ty = MRI.getType(Reg: Dst);
6126 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
6127 unsigned NumEltBits = Ty.getScalarSizeInBits();
6128
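  // G_UMULH by 2^k is the high N bits of the 2N-bit product LHS * 2^k, i.e.
  // LHS >> (N - k), so shift right by (bit width - log2(RHS)).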
6129 auto LogBase2 = buildLogBase2(V: RHS, MIB&: Builder);
6130 auto ShiftAmt =
6131 Builder.buildSub(Dst: Ty, Src0: Builder.buildConstant(Res: Ty, Val: NumEltBits), Src1: LogBase2);
6132 auto Trunc = Builder.buildZExtOrTrunc(Res: ShiftAmtTy, Op: ShiftAmt);
6133 Builder.buildLShr(Dst, Src0: LHS, Src1: Trunc);
6134 MI.eraseFromParent();
6135}
6136
6137bool CombinerHelper::matchTruncSSatS(MachineInstr &MI,
6138 Register &MatchInfo) const {
6139 Register Dst = MI.getOperand(i: 0).getReg();
6140 Register Src = MI.getOperand(i: 1).getReg();
6141 LLT DstTy = MRI.getType(Reg: Dst);
6142 LLT SrcTy = MRI.getType(Reg: Src);
6143 unsigned NumDstBits = DstTy.getScalarSizeInBits();
6144 unsigned NumSrcBits = SrcTy.getScalarSizeInBits();
6145 assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
6146
6147 if (!LI || !isLegalOrHasFewerElements(
6148 Query: {TargetOpcode::G_TRUNC_SSAT_S, {DstTy, SrcTy}}))
6149 return false;
6150
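  // A truncate whose source is clamped to the destination's signed range
  // [SignedMin, SignedMax], in either min/max order, is exactly
  // G_TRUNC_SSAT_S.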
6151 APInt SignedMax = APInt::getSignedMaxValue(numBits: NumDstBits).sext(width: NumSrcBits);
6152 APInt SignedMin = APInt::getSignedMinValue(numBits: NumDstBits).sext(width: NumSrcBits);
6153 return mi_match(R: Src, MRI,
6154 P: m_GSMin(L: m_GSMax(L: m_Reg(R&: MatchInfo),
6155 R: m_SpecificICstOrSplat(RequestedValue: SignedMin)),
6156 R: m_SpecificICstOrSplat(RequestedValue: SignedMax))) ||
6157 mi_match(R: Src, MRI,
6158 P: m_GSMax(L: m_GSMin(L: m_Reg(R&: MatchInfo),
6159 R: m_SpecificICstOrSplat(RequestedValue: SignedMax)),
6160 R: m_SpecificICstOrSplat(RequestedValue: SignedMin)));
6161}
6162
6163void CombinerHelper::applyTruncSSatS(MachineInstr &MI,
6164 Register &MatchInfo) const {
6165 Register Dst = MI.getOperand(i: 0).getReg();
6166 Builder.buildTruncSSatS(Res: Dst, Op: MatchInfo);
6167 MI.eraseFromParent();
6168}
6169
6170bool CombinerHelper::matchTruncSSatU(MachineInstr &MI,
6171 Register &MatchInfo) const {
6172 Register Dst = MI.getOperand(i: 0).getReg();
6173 Register Src = MI.getOperand(i: 1).getReg();
6174 LLT DstTy = MRI.getType(Reg: Dst);
6175 LLT SrcTy = MRI.getType(Reg: Src);
6176 unsigned NumDstBits = DstTy.getScalarSizeInBits();
6177 unsigned NumSrcBits = SrcTy.getScalarSizeInBits();
6178 assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
6179
6180 if (!LI || !isLegalOrHasFewerElements(
6181 Query: {TargetOpcode::G_TRUNC_SSAT_U, {DstTy, SrcTy}}))
6182 return false;
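  // Here the source is a signed value clamped to [0, UnsignedMax] of the
  // destination type before the truncate, which is what G_TRUNC_SSAT_U
  // computes.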
6183 APInt UnsignedMax = APInt::getMaxValue(numBits: NumDstBits).zext(width: NumSrcBits);
6184 return mi_match(R: Src, MRI,
6185 P: m_GSMin(L: m_GSMax(L: m_Reg(R&: MatchInfo), R: m_SpecificICstOrSplat(RequestedValue: 0)),
6186 R: m_SpecificICstOrSplat(RequestedValue: UnsignedMax))) ||
6187 mi_match(R: Src, MRI,
6188 P: m_GSMax(L: m_GSMin(L: m_Reg(R&: MatchInfo),
6189 R: m_SpecificICstOrSplat(RequestedValue: UnsignedMax)),
6190 R: m_SpecificICstOrSplat(RequestedValue: 0))) ||
6191 mi_match(R: Src, MRI,
6192 P: m_GUMin(L: m_GSMax(L: m_Reg(R&: MatchInfo), R: m_SpecificICstOrSplat(RequestedValue: 0)),
6193 R: m_SpecificICstOrSplat(RequestedValue: UnsignedMax)));
6194}
6195
6196void CombinerHelper::applyTruncSSatU(MachineInstr &MI,
6197 Register &MatchInfo) const {
6198 Register Dst = MI.getOperand(i: 0).getReg();
6199 Builder.buildTruncSSatU(Res: Dst, Op: MatchInfo);
6200 MI.eraseFromParent();
6201}
6202
6203bool CombinerHelper::matchTruncUSatU(MachineInstr &MI,
6204 MachineInstr &MinMI) const {
6205 Register Min = MinMI.getOperand(i: 2).getReg();
6206 Register Val = MinMI.getOperand(i: 1).getReg();
6207 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6208 LLT SrcTy = MRI.getType(Reg: Val);
6209 unsigned NumDstBits = DstTy.getScalarSizeInBits();
6210 unsigned NumSrcBits = SrcTy.getScalarSizeInBits();
6211 assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
6212
6213 if (!LI || !isLegalOrHasFewerElements(
6214 Query: {TargetOpcode::G_TRUNC_SSAT_U, {DstTy, SrcTy}}))
6215 return false;
6216 APInt UnsignedMax = APInt::getMaxValue(numBits: NumDstBits).zext(width: NumSrcBits);
6217 return mi_match(R: Min, MRI, P: m_SpecificICstOrSplat(RequestedValue: UnsignedMax)) &&
6218 !mi_match(R: Val, MRI, P: m_GSMax(L: m_Reg(), R: m_Reg()));
6219}
6220
6221bool CombinerHelper::matchTruncUSatUToFPTOUISat(MachineInstr &MI,
6222 MachineInstr &SrcMI) const {
6223 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6224 LLT SrcTy = MRI.getType(Reg: SrcMI.getOperand(i: 1).getReg());
6225
6226 return LI &&
6227 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FPTOUI_SAT, {DstTy, SrcTy}});
6228}
6229
6230bool CombinerHelper::matchRedundantNegOperands(MachineInstr &MI,
6231 BuildFnTy &MatchInfo) const {
6232 unsigned Opc = MI.getOpcode();
6233 assert(Opc == TargetOpcode::G_FADD || Opc == TargetOpcode::G_FSUB ||
6234 Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV ||
6235 Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA);
6236
6237 Register Dst = MI.getOperand(i: 0).getReg();
6238 Register X = MI.getOperand(i: 1).getReg();
6239 Register Y = MI.getOperand(i: 2).getReg();
6240 LLT Type = MRI.getType(Reg: Dst);
6241
6242 // fold (fadd x, fneg(y)) -> (fsub x, y)
6243 // fold (fadd fneg(y), x) -> (fsub x, y)
6244 // G_FADD is commutative so both cases are checked by m_GFAdd.
6245 if (mi_match(R: Dst, MRI, P: m_GFAdd(L: m_Reg(R&: X), R: m_GFNeg(Src: m_Reg(R&: Y)))) &&
6246 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FSUB, {Type}})) {
6247 Opc = TargetOpcode::G_FSUB;
6248 }
6249 // fold (fsub x, fneg(y)) -> (fadd x, y)
6250 else if (mi_match(R: Dst, MRI, P: m_GFSub(L: m_Reg(R&: X), R: m_GFNeg(Src: m_Reg(R&: Y)))) &&
6251 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FADD, {Type}})) {
6252 Opc = TargetOpcode::G_FADD;
6253 }
6254 // fold (fmul fneg(x), fneg(y)) -> (fmul x, y)
6255 // fold (fdiv fneg(x), fneg(y)) -> (fdiv x, y)
6256 // fold (fmad fneg(x), fneg(y), z) -> (fmad x, y, z)
6257 // fold (fma fneg(x), fneg(y), z) -> (fma x, y, z)
6258 else if ((Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV ||
6259 Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA) &&
6260 mi_match(R: X, MRI, P: m_GFNeg(Src: m_Reg(R&: X))) &&
6261 mi_match(R: Y, MRI, P: m_GFNeg(Src: m_Reg(R&: Y)))) {
6262 // no opcode change
6263 } else
6264 return false;
6265
6266 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6267 Observer.changingInstr(MI);
6268 MI.setDesc(B.getTII().get(Opcode: Opc));
6269 MI.getOperand(i: 1).setReg(X);
6270 MI.getOperand(i: 2).setReg(Y);
6271 Observer.changedInstr(MI);
6272 };
6273 return true;
6274}
6275
6276bool CombinerHelper::matchFsubToFneg(MachineInstr &MI,
6277 Register &MatchInfo) const {
6278 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
6279
6280 Register LHS = MI.getOperand(i: 1).getReg();
6281 MatchInfo = MI.getOperand(i: 2).getReg();
6282 LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6283
6284 const auto LHSCst = Ty.isVector()
6285 ? getFConstantSplat(VReg: LHS, MRI, /* allowUndef */ AllowUndef: true)
6286 : getFConstantVRegValWithLookThrough(VReg: LHS, MRI);
6287 if (!LHSCst)
6288 return false;
6289
6290 // -0.0 is always allowed
6291 if (LHSCst->Value.isNegZero())
6292 return true;
6293
6294 // +0.0 is only allowed if nsz is set.
6295 if (LHSCst->Value.isPosZero())
6296 return MI.getFlag(Flag: MachineInstr::FmNsz);
6297
6298 return false;
6299}
6300
6301void CombinerHelper::applyFsubToFneg(MachineInstr &MI,
6302 Register &MatchInfo) const {
6303 Register Dst = MI.getOperand(i: 0).getReg();
6304 Builder.buildFNeg(
6305 Dst, Src0: Builder.buildFCanonicalize(Dst: MRI.getType(Reg: Dst), Src0: MatchInfo).getReg(Idx: 0));
6306 eraseInst(MI);
6307}
6308
6309/// Checks if \p MI is TargetOpcode::G_FMUL and contractable either
6310/// due to global flags or MachineInstr flags.
6311static bool isContractableFMul(MachineInstr &MI, bool AllowFusionGlobally) {
6312 if (MI.getOpcode() != TargetOpcode::G_FMUL)
6313 return false;
6314 return AllowFusionGlobally || MI.getFlag(Flag: MachineInstr::MIFlag::FmContract);
6315}
6316
6317static bool hasMoreUses(const MachineInstr &MI0, const MachineInstr &MI1,
6318 const MachineRegisterInfo &MRI) {
6319 return std::distance(first: MRI.use_instr_nodbg_begin(RegNo: MI0.getOperand(i: 0).getReg()),
6320 last: MRI.use_instr_nodbg_end()) >
6321 std::distance(first: MRI.use_instr_nodbg_begin(RegNo: MI1.getOperand(i: 0).getReg()),
6322 last: MRI.use_instr_nodbg_end());
6323}
6324
6325bool CombinerHelper::canCombineFMadOrFMA(MachineInstr &MI,
6326 bool &AllowFusionGlobally,
6327 bool &HasFMAD, bool &Aggressive,
6328 bool CanReassociate) const {
6329
6330 auto *MF = MI.getMF();
6331 const auto &TLI = *MF->getSubtarget().getTargetLowering();
6332 const TargetOptions &Options = MF->getTarget().Options;
6333 LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6334
6335 if (CanReassociate && !MI.getFlag(Flag: MachineInstr::MIFlag::FmReassoc))
6336 return false;
6337
6338 // Floating-point multiply-add with intermediate rounding.
6339 HasFMAD = (!isPreLegalize() && TLI.isFMADLegal(MI, Ty: DstType));
6340 // Floating-point multiply-add without intermediate rounding.
6341 bool HasFMA = TLI.isFMAFasterThanFMulAndFAdd(MF: *MF, DstType) &&
6342 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FMA, {DstType}});
6343 // No valid opcode, do not combine.
6344 if (!HasFMAD && !HasFMA)
6345 return false;
6346
6347 AllowFusionGlobally = Options.AllowFPOpFusion == FPOpFusion::Fast || HasFMAD;
6348 // If the addition is not contractable, do not combine.
6349 if (!AllowFusionGlobally && !MI.getFlag(Flag: MachineInstr::MIFlag::FmContract))
6350 return false;
6351
6352 Aggressive = TLI.enableAggressiveFMAFusion(Ty: DstType);
6353 return true;
6354}
6355
6356bool CombinerHelper::matchCombineFAddFMulToFMadOrFMA(
6357 MachineInstr &MI,
6358 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6359 assert(MI.getOpcode() == TargetOpcode::G_FADD);
6360
6361 bool AllowFusionGlobally, HasFMAD, Aggressive;
6362 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6363 return false;
6364
6365 Register Op1 = MI.getOperand(i: 1).getReg();
6366 Register Op2 = MI.getOperand(i: 2).getReg();
6367 DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1};
6368 DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2};
6369 unsigned PreferredFusedOpcode =
6370 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6371
6372 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
6373 // prefer to fold the multiply with fewer uses.
6374 if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6375 isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) {
6376 if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI))
6377 std::swap(a&: LHS, b&: RHS);
6378 }
6379
6380 // fold (fadd (fmul x, y), z) -> (fma x, y, z)
6381 if (isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6382 (Aggressive || MRI.hasOneNonDBGUse(RegNo: LHS.Reg))) {
6383 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6384 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6385 SrcOps: {LHS.MI->getOperand(i: 1).getReg(),
6386 LHS.MI->getOperand(i: 2).getReg(), RHS.Reg});
6387 };
6388 return true;
6389 }
6390
6391 // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
6392 if (isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally) &&
6393 (Aggressive || MRI.hasOneNonDBGUse(RegNo: RHS.Reg))) {
6394 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6395 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6396 SrcOps: {RHS.MI->getOperand(i: 1).getReg(),
6397 RHS.MI->getOperand(i: 2).getReg(), LHS.Reg});
6398 };
6399 return true;
6400 }
6401
6402 return false;
6403}
6404
6405bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMA(
6406 MachineInstr &MI,
6407 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6408 assert(MI.getOpcode() == TargetOpcode::G_FADD);
6409
6410 bool AllowFusionGlobally, HasFMAD, Aggressive;
6411 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6412 return false;
6413
6414 const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
6415 Register Op1 = MI.getOperand(i: 1).getReg();
6416 Register Op2 = MI.getOperand(i: 2).getReg();
6417 DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1};
6418 DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2};
6419 LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6420
6421 unsigned PreferredFusedOpcode =
6422 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6423
6424 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
6425 // prefer to fold the multiply with fewer uses.
6426 if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6427 isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) {
6428 if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI))
6429 std::swap(a&: LHS, b&: RHS);
6430 }
6431
6432 // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
6433 MachineInstr *FpExtSrc;
6434 if (mi_match(R: LHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FpExtSrc))) &&
6435 isContractableFMul(MI&: *FpExtSrc, AllowFusionGlobally) &&
6436 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
6437 SrcTy: MRI.getType(Reg: FpExtSrc->getOperand(i: 1).getReg()))) {
6438 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6439 auto FpExtX = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 1).getReg());
6440 auto FpExtY = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 2).getReg());
6441 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6442 SrcOps: {FpExtX.getReg(Idx: 0), FpExtY.getReg(Idx: 0), RHS.Reg});
6443 };
6444 return true;
6445 }
6446
6447 // fold (fadd z, (fpext (fmul x, y))) -> (fma (fpext x), (fpext y), z)
6448 // Note: Commutes FADD operands.
6449 if (mi_match(R: RHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FpExtSrc))) &&
6450 isContractableFMul(MI&: *FpExtSrc, AllowFusionGlobally) &&
6451 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
6452 SrcTy: MRI.getType(Reg: FpExtSrc->getOperand(i: 1).getReg()))) {
6453 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6454 auto FpExtX = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 1).getReg());
6455 auto FpExtY = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 2).getReg());
6456 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6457 SrcOps: {FpExtX.getReg(Idx: 0), FpExtY.getReg(Idx: 0), LHS.Reg});
6458 };
6459 return true;
6460 }
6461
6462 return false;
6463}
6464
6465bool CombinerHelper::matchCombineFAddFMAFMulToFMadOrFMA(
6466 MachineInstr &MI,
6467 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6468 assert(MI.getOpcode() == TargetOpcode::G_FADD);
6469
6470 bool AllowFusionGlobally, HasFMAD, Aggressive;
6471 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive, CanReassociate: true))
6472 return false;
6473
6474 Register Op1 = MI.getOperand(i: 1).getReg();
6475 Register Op2 = MI.getOperand(i: 2).getReg();
6476 DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1};
6477 DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2};
6478 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6479
6480 unsigned PreferredFusedOpcode =
6481 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6482
6483 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
6484 // prefer to fold the multiply with fewer uses.
6485 if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6486 isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) {
6487 if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI))
6488 std::swap(a&: LHS, b&: RHS);
6489 }
6490
6491 MachineInstr *FMA = nullptr;
6492 Register Z;
6493 // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y, (fma u, v, z))
6494 if (LHS.MI->getOpcode() == PreferredFusedOpcode &&
6495 (MRI.getVRegDef(Reg: LHS.MI->getOperand(i: 3).getReg())->getOpcode() ==
6496 TargetOpcode::G_FMUL) &&
6497 MRI.hasOneNonDBGUse(RegNo: LHS.MI->getOperand(i: 0).getReg()) &&
6498 MRI.hasOneNonDBGUse(RegNo: LHS.MI->getOperand(i: 3).getReg())) {
6499 FMA = LHS.MI;
6500 Z = RHS.Reg;
6501 }
6502 // fold (fadd z, (fma x, y, (fmul u, v))) -> (fma x, y, (fma u, v, z))
6503 else if (RHS.MI->getOpcode() == PreferredFusedOpcode &&
6504 (MRI.getVRegDef(Reg: RHS.MI->getOperand(i: 3).getReg())->getOpcode() ==
6505 TargetOpcode::G_FMUL) &&
6506 MRI.hasOneNonDBGUse(RegNo: RHS.MI->getOperand(i: 0).getReg()) &&
6507 MRI.hasOneNonDBGUse(RegNo: RHS.MI->getOperand(i: 3).getReg())) {
6508 Z = LHS.Reg;
6509 FMA = RHS.MI;
6510 }
6511
6512 if (FMA) {
6513 MachineInstr *FMulMI = MRI.getVRegDef(Reg: FMA->getOperand(i: 3).getReg());
6514 Register X = FMA->getOperand(i: 1).getReg();
6515 Register Y = FMA->getOperand(i: 2).getReg();
6516 Register U = FMulMI->getOperand(i: 1).getReg();
6517 Register V = FMulMI->getOperand(i: 2).getReg();
6518
6519 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6520 Register InnerFMA = MRI.createGenericVirtualRegister(Ty: DstTy);
6521 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {InnerFMA}, SrcOps: {U, V, Z});
6522 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6523 SrcOps: {X, Y, InnerFMA});
6524 };
6525 return true;
6526 }
6527
6528 return false;
6529}
6530
6531bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMAAggressive(
6532 MachineInstr &MI,
6533 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6534 assert(MI.getOpcode() == TargetOpcode::G_FADD);
6535
6536 bool AllowFusionGlobally, HasFMAD, Aggressive;
6537 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6538 return false;
6539
6540 if (!Aggressive)
6541 return false;
6542
6543 const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
6544 LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6545 Register Op1 = MI.getOperand(i: 1).getReg();
6546 Register Op2 = MI.getOperand(i: 2).getReg();
6547 DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1};
6548 DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2};
6549
6550 unsigned PreferredFusedOpcode =
6551 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6552
6553 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
6554 // prefer to fold the multiply with fewer uses.
6555 if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6556 isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) {
6557 if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI))
6558 std::swap(a&: LHS, b&: RHS);
6559 }
6560
6561 // Builds: (fma x, y, (fma (fpext u), (fpext v), z))
6562 auto buildMatchInfo = [=, &MI](Register U, Register V, Register Z, Register X,
6563 Register Y, MachineIRBuilder &B) {
6564 Register FpExtU = B.buildFPExt(Res: DstType, Op: U).getReg(Idx: 0);
6565 Register FpExtV = B.buildFPExt(Res: DstType, Op: V).getReg(Idx: 0);
6566 Register InnerFMA =
6567 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {DstType}, SrcOps: {FpExtU, FpExtV, Z})
6568 .getReg(Idx: 0);
6569 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6570 SrcOps: {X, Y, InnerFMA});
6571 };
6572
6573 MachineInstr *FMulMI, *FMAMI;
6574 // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
6575 // -> (fma x, y, (fma (fpext u), (fpext v), z))
6576 if (LHS.MI->getOpcode() == PreferredFusedOpcode &&
6577 mi_match(R: LHS.MI->getOperand(i: 3).getReg(), MRI,
6578 P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) &&
6579 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6580 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
6581 SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) {
6582 MatchInfo = [=](MachineIRBuilder &B) {
6583 buildMatchInfo(FMulMI->getOperand(i: 1).getReg(),
6584 FMulMI->getOperand(i: 2).getReg(), RHS.Reg,
6585 LHS.MI->getOperand(i: 1).getReg(),
6586 LHS.MI->getOperand(i: 2).getReg(), B);
6587 };
6588 return true;
6589 }
6590
6591 // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
6592 // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
6593 // FIXME: This turns two single-precision and one double-precision
6594 // operation into two double-precision operations, which might not be
6595 // interesting for all targets, especially GPUs.
6596 if (mi_match(R: LHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMAMI))) &&
6597 FMAMI->getOpcode() == PreferredFusedOpcode) {
6598 MachineInstr *FMulMI = MRI.getVRegDef(Reg: FMAMI->getOperand(i: 3).getReg());
6599 if (isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6600 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
6601 SrcTy: MRI.getType(Reg: FMAMI->getOperand(i: 0).getReg()))) {
6602 MatchInfo = [=](MachineIRBuilder &B) {
6603 Register X = FMAMI->getOperand(i: 1).getReg();
6604 Register Y = FMAMI->getOperand(i: 2).getReg();
6605 X = B.buildFPExt(Res: DstType, Op: X).getReg(Idx: 0);
6606 Y = B.buildFPExt(Res: DstType, Op: Y).getReg(Idx: 0);
6607 buildMatchInfo(FMulMI->getOperand(i: 1).getReg(),
6608 FMulMI->getOperand(i: 2).getReg(), RHS.Reg, X, Y, B);
6609 };
6610
6611 return true;
6612 }
6613 }
6614
6615 // fold (fadd z, (fma x, y, (fpext (fmul u, v)))
6616 // -> (fma x, y, (fma (fpext u), (fpext v), z))
6617 if (RHS.MI->getOpcode() == PreferredFusedOpcode &&
6618 mi_match(R: RHS.MI->getOperand(i: 3).getReg(), MRI,
6619 P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) &&
6620 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6621 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
6622 SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) {
6623 MatchInfo = [=](MachineIRBuilder &B) {
6624 buildMatchInfo(FMulMI->getOperand(i: 1).getReg(),
6625 FMulMI->getOperand(i: 2).getReg(), LHS.Reg,
6626 RHS.MI->getOperand(i: 1).getReg(),
6627 RHS.MI->getOperand(i: 2).getReg(), B);
6628 };
6629 return true;
6630 }
6631
6632 // fold (fadd z, (fpext (fma x, y, (fmul u, v)))
6633 // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
6634 // FIXME: This turns two single-precision and one double-precision
6635 // operation into two double-precision operations, which might not be
6636 // interesting for all targets, especially GPUs.
6637 if (mi_match(R: RHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMAMI))) &&
6638 FMAMI->getOpcode() == PreferredFusedOpcode) {
6639 MachineInstr *FMulMI = MRI.getVRegDef(Reg: FMAMI->getOperand(i: 3).getReg());
6640 if (isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6641 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
6642 SrcTy: MRI.getType(Reg: FMAMI->getOperand(i: 0).getReg()))) {
6643 MatchInfo = [=](MachineIRBuilder &B) {
6644 Register X = FMAMI->getOperand(i: 1).getReg();
6645 Register Y = FMAMI->getOperand(i: 2).getReg();
6646 X = B.buildFPExt(Res: DstType, Op: X).getReg(Idx: 0);
6647 Y = B.buildFPExt(Res: DstType, Op: Y).getReg(Idx: 0);
6648 buildMatchInfo(FMulMI->getOperand(i: 1).getReg(),
6649 FMulMI->getOperand(i: 2).getReg(), LHS.Reg, X, Y, B);
6650 };
6651 return true;
6652 }
6653 }
6654
6655 return false;
6656}
6657
6658bool CombinerHelper::matchCombineFSubFMulToFMadOrFMA(
6659 MachineInstr &MI,
6660 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6661 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
6662
6663 bool AllowFusionGlobally, HasFMAD, Aggressive;
6664 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6665 return false;
6666
6667 Register Op1 = MI.getOperand(i: 1).getReg();
6668 Register Op2 = MI.getOperand(i: 2).getReg();
6669 DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1};
6670 DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2};
6671 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6672
6673 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
6674 // prefer to fold the multiply with fewer uses.
6675 bool FirstMulHasFewerUses = true;
6676 if (isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6677 isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally) &&
6678 hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI))
6679 FirstMulHasFewerUses = false;
6680
6681 unsigned PreferredFusedOpcode =
6682 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6683
6684 // fold (fsub (fmul x, y), z) -> (fma x, y, -z)
6685 if (FirstMulHasFewerUses &&
6686 (isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
6687 (Aggressive || MRI.hasOneNonDBGUse(RegNo: LHS.Reg)))) {
6688 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6689 Register NegZ = B.buildFNeg(Dst: DstTy, Src0: RHS.Reg).getReg(Idx: 0);
6690 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6691 SrcOps: {LHS.MI->getOperand(i: 1).getReg(),
6692 LHS.MI->getOperand(i: 2).getReg(), NegZ});
6693 };
6694 return true;
6695 }
6696 // fold (fsub x, (fmul y, z)) -> (fma -y, z, x)
6697 else if ((isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally) &&
6698 (Aggressive || MRI.hasOneNonDBGUse(RegNo: RHS.Reg)))) {
6699 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6700 Register NegY =
6701 B.buildFNeg(Dst: DstTy, Src0: RHS.MI->getOperand(i: 1).getReg()).getReg(Idx: 0);
6702 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6703 SrcOps: {NegY, RHS.MI->getOperand(i: 2).getReg(), LHS.Reg});
6704 };
6705 return true;
6706 }
6707
6708 return false;
6709}
6710
6711bool CombinerHelper::matchCombineFSubFNegFMulToFMadOrFMA(
6712 MachineInstr &MI,
6713 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6714 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
6715
6716 bool AllowFusionGlobally, HasFMAD, Aggressive;
6717 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6718 return false;
6719
6720 Register LHSReg = MI.getOperand(i: 1).getReg();
6721 Register RHSReg = MI.getOperand(i: 2).getReg();
6722 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6723
6724 unsigned PreferredFusedOpcode =
6725 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6726
6727 MachineInstr *FMulMI;
6728 // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z))
6729 if (mi_match(R: LHSReg, MRI, P: m_GFNeg(Src: m_MInstr(MI&: FMulMI))) &&
6730 (Aggressive || (MRI.hasOneNonDBGUse(RegNo: LHSReg) &&
6731 MRI.hasOneNonDBGUse(RegNo: FMulMI->getOperand(i: 0).getReg()))) &&
6732 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally)) {
6733 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6734 Register NegX =
6735 B.buildFNeg(Dst: DstTy, Src0: FMulMI->getOperand(i: 1).getReg()).getReg(Idx: 0);
6736 Register NegZ = B.buildFNeg(Dst: DstTy, Src0: RHSReg).getReg(Idx: 0);
6737 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6738 SrcOps: {NegX, FMulMI->getOperand(i: 2).getReg(), NegZ});
6739 };
6740 return true;
6741 }
6742
6743 // fold (fsub x, (fneg (fmul, y, z))) -> (fma y, z, x)
6744 if (mi_match(R: RHSReg, MRI, P: m_GFNeg(Src: m_MInstr(MI&: FMulMI))) &&
6745 (Aggressive || (MRI.hasOneNonDBGUse(RegNo: RHSReg) &&
6746 MRI.hasOneNonDBGUse(RegNo: FMulMI->getOperand(i: 0).getReg()))) &&
6747 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally)) {
6748 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6749 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6750 SrcOps: {FMulMI->getOperand(i: 1).getReg(),
6751 FMulMI->getOperand(i: 2).getReg(), LHSReg});
6752 };
6753 return true;
6754 }
6755
6756 return false;
6757}
6758
6759bool CombinerHelper::matchCombineFSubFpExtFMulToFMadOrFMA(
6760 MachineInstr &MI,
6761 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6762 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
6763
6764 bool AllowFusionGlobally, HasFMAD, Aggressive;
6765 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6766 return false;
6767
6768 Register LHSReg = MI.getOperand(i: 1).getReg();
6769 Register RHSReg = MI.getOperand(i: 2).getReg();
6770 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6771
6772 unsigned PreferredFusedOpcode =
6773 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6774
6775 MachineInstr *FMulMI;
6776 // fold (fsub (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), (fneg z))
6777 if (mi_match(R: LHSReg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) &&
6778 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6779 (Aggressive || MRI.hasOneNonDBGUse(RegNo: LHSReg))) {
6780 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6781 Register FpExtX =
6782 B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 1).getReg()).getReg(Idx: 0);
6783 Register FpExtY =
6784 B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 2).getReg()).getReg(Idx: 0);
6785 Register NegZ = B.buildFNeg(Dst: DstTy, Src0: RHSReg).getReg(Idx: 0);
6786 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6787 SrcOps: {FpExtX, FpExtY, NegZ});
6788 };
6789 return true;
6790 }
6791
6792 // fold (fsub x, (fpext (fmul y, z))) -> (fma (fneg (fpext y)), (fpext z), x)
6793 if (mi_match(R: RHSReg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) &&
6794 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6795 (Aggressive || MRI.hasOneNonDBGUse(RegNo: RHSReg))) {
6796 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6797 Register FpExtY =
6798 B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 1).getReg()).getReg(Idx: 0);
6799 Register NegY = B.buildFNeg(Dst: DstTy, Src0: FpExtY).getReg(Idx: 0);
6800 Register FpExtZ =
6801 B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 2).getReg()).getReg(Idx: 0);
6802 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
6803 SrcOps: {NegY, FpExtZ, LHSReg});
6804 };
6805 return true;
6806 }
6807
6808 return false;
6809}
6810
6811bool CombinerHelper::matchCombineFSubFpExtFNegFMulToFMadOrFMA(
6812 MachineInstr &MI,
6813 std::function<void(MachineIRBuilder &)> &MatchInfo) const {
6814 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
6815
6816 bool AllowFusionGlobally, HasFMAD, Aggressive;
6817 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
6818 return false;
6819
6820 const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
6821 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6822 Register LHSReg = MI.getOperand(i: 1).getReg();
6823 Register RHSReg = MI.getOperand(i: 2).getReg();
6824
6825 unsigned PreferredFusedOpcode =
6826 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
6827
6828 auto buildMatchInfo = [=](Register Dst, Register X, Register Y, Register Z,
6829 MachineIRBuilder &B) {
6830 Register FpExtX = B.buildFPExt(Res: DstTy, Op: X).getReg(Idx: 0);
6831 Register FpExtY = B.buildFPExt(Res: DstTy, Op: Y).getReg(Idx: 0);
6832 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {Dst}, SrcOps: {FpExtX, FpExtY, Z});
6833 };
6834
6835 MachineInstr *FMulMI;
6836 // fold (fsub (fpext (fneg (fmul x, y))), z) ->
6837 // (fneg (fma (fpext x), (fpext y), z))
6838 // fold (fsub (fneg (fpext (fmul x, y))), z) ->
6839 // (fneg (fma (fpext x), (fpext y), z))
6840 if ((mi_match(R: LHSReg, MRI, P: m_GFPExt(Src: m_GFNeg(Src: m_MInstr(MI&: FMulMI)))) ||
6841 mi_match(R: LHSReg, MRI, P: m_GFNeg(Src: m_GFPExt(Src: m_MInstr(MI&: FMulMI))))) &&
6842 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6843 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstTy,
6844 SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) {
6845 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6846 Register FMAReg = MRI.createGenericVirtualRegister(Ty: DstTy);
6847 buildMatchInfo(FMAReg, FMulMI->getOperand(i: 1).getReg(),
6848 FMulMI->getOperand(i: 2).getReg(), RHSReg, B);
6849 B.buildFNeg(Dst: MI.getOperand(i: 0).getReg(), Src0: FMAReg);
6850 };
6851 return true;
6852 }
6853
6854 // fold (fsub x, (fpext (fneg (fmul y, z)))) -> (fma (fpext y), (fpext z), x)
6855 // fold (fsub x, (fneg (fpext (fmul y, z)))) -> (fma (fpext y), (fpext z), x)
6856 if ((mi_match(R: RHSReg, MRI, P: m_GFPExt(Src: m_GFNeg(Src: m_MInstr(MI&: FMulMI)))) ||
6857 mi_match(R: RHSReg, MRI, P: m_GFNeg(Src: m_GFPExt(Src: m_MInstr(MI&: FMulMI))))) &&
6858 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
6859 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstTy,
6860 SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) {
6861 MatchInfo = [=, &MI](MachineIRBuilder &B) {
6862 buildMatchInfo(MI.getOperand(i: 0).getReg(), FMulMI->getOperand(i: 1).getReg(),
6863 FMulMI->getOperand(i: 2).getReg(), LHSReg, B);
6864 };
6865 return true;
6866 }
6867
6868 return false;
6869}
6870
6871bool CombinerHelper::matchCombineFMinMaxNaN(MachineInstr &MI,
6872 unsigned &IdxToPropagate) const {
6873 bool PropagateNaN;
6874 switch (MI.getOpcode()) {
6875 default:
6876 return false;
6877 case TargetOpcode::G_FMINNUM:
6878 case TargetOpcode::G_FMAXNUM:
6879 PropagateNaN = false;
6880 break;
6881 case TargetOpcode::G_FMINIMUM:
6882 case TargetOpcode::G_FMAXIMUM:
6883 PropagateNaN = true;
6884 break;
6885 }
6886
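  // With one constant NaN operand, G_FMINNUM/G_FMAXNUM fold to the other
  // operand, while G_FMINIMUM/G_FMAXIMUM propagate the NaN, so choose the
  // operand index to keep accordingly.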
6887 auto MatchNaN = [&](unsigned Idx) {
6888 Register MaybeNaNReg = MI.getOperand(i: Idx).getReg();
6889 const ConstantFP *MaybeCst = getConstantFPVRegVal(VReg: MaybeNaNReg, MRI);
6890 if (!MaybeCst || !MaybeCst->getValueAPF().isNaN())
6891 return false;
6892 IdxToPropagate = PropagateNaN ? Idx : (Idx == 1 ? 2 : 1);
6893 return true;
6894 };
6895
6896 return MatchNaN(1) || MatchNaN(2);
6897}
6898
6899// Combine multiple FDIVs with the same divisor into multiple FMULs by the
6900// reciprocal.
6901// E.g., (a / Y; b / Y;) -> (recip = 1.0 / Y; a * recip; b * recip)
6902bool CombinerHelper::matchRepeatedFPDivisor(
6903 MachineInstr &MI, SmallVector<MachineInstr *> &MatchInfo) const {
6904 assert(MI.getOpcode() == TargetOpcode::G_FDIV);
6905
6906 Register X = MI.getOperand(i: 1).getReg();
6907 Register Y = MI.getOperand(i: 2).getReg();
6908
6909 if (!MI.getFlag(Flag: MachineInstr::MIFlag::FmArcp))
6910 return false;
6911
6912 // Skip if current node is a reciprocal/fneg-reciprocal.
6913 auto N0CFP = isConstantOrConstantSplatVectorFP(MI&: *MRI.getVRegDef(Reg: X), MRI);
6914 if (N0CFP && (N0CFP->isExactlyValue(V: 1.0) || N0CFP->isExactlyValue(V: -1.0)))
6915 return false;
6916
6917 // Exit early if the target does not want this transform: a threshold of
6918 // zero means combining repeated FP divisors is never considered worthwhile.
6919 unsigned MinUses = getTargetLowering().combineRepeatedFPDivisors();
6920 if (!MinUses)
6921 return false;
6922
6923 // Find all FDIV users of the same divisor. For the moment we limit all
6924 // instructions to a single BB and use the first Instr in MatchInfo as the
6925 // dominating position.
6926 MatchInfo.push_back(Elt: &MI);
6927 for (auto &U : MRI.use_nodbg_instructions(Reg: Y)) {
6928 if (&U == &MI || U.getParent() != MI.getParent())
6929 continue;
6930 if (U.getOpcode() == TargetOpcode::G_FDIV &&
6931 U.getOperand(i: 2).getReg() == Y && U.getOperand(i: 1).getReg() != Y) {
6932 // This division is eligible for the optimization only if it permits
6933 // reciprocal formation (it carries the arcp fast-math flag).
6934 if (U.getFlag(Flag: MachineInstr::MIFlag::FmArcp)) {
6935 MatchInfo.push_back(Elt: &U);
6936 if (dominates(DefMI: U, UseMI: *MatchInfo[0]))
6937 std::swap(a&: MatchInfo[0], b&: MatchInfo.back());
6938 }
6939 }
6940 }
6941
6942 // Now that we have the actual number of divisor uses, make sure it meets
6943 // the minimum threshold specified by the target.
6944 return MatchInfo.size() >= MinUses;
6945}
6946
6947void CombinerHelper::applyRepeatedFPDivisor(
6948 SmallVector<MachineInstr *> &MatchInfo) const {
6949 // Generate the new div at the position of the first instruction, which we
6950 // have ensured dominates all the other matched instructions.
6951 Builder.setInsertPt(MBB&: *MatchInfo[0]->getParent(), II: MatchInfo[0]);
6952 LLT Ty = MRI.getType(Reg: MatchInfo[0]->getOperand(i: 0).getReg());
6953 auto Div = Builder.buildFDiv(Dst: Ty, Src0: Builder.buildFConstant(Res: Ty, Val: 1.0),
6954 Src1: MatchInfo[0]->getOperand(i: 2).getReg(),
6955 Flags: MatchInfo[0]->getFlags());
6956
6957 // Replace all found div's with fmul instructions.
6958 for (MachineInstr *MI : MatchInfo) {
6959 Builder.setInsertPt(MBB&: *MI->getParent(), II: MI);
6960 Builder.buildFMul(Dst: MI->getOperand(i: 0).getReg(), Src0: MI->getOperand(i: 1).getReg(),
6961 Src1: Div->getOperand(i: 0).getReg(), Flags: MI->getFlags());
6962 MI->eraseFromParent();
6963 }
6964}
6965
6966bool CombinerHelper::matchAddSubSameReg(MachineInstr &MI, Register &Src) const {
6967 assert(MI.getOpcode() == TargetOpcode::G_ADD && "Expected a G_ADD");
6968 Register LHS = MI.getOperand(i: 1).getReg();
6969 Register RHS = MI.getOperand(i: 2).getReg();
6970
6971 // Helper lambda to check for opportunities for
6972 // A + (B - A) -> B
6973 // (B - A) + A -> B
6974 auto CheckFold = [&](Register MaybeSub, Register MaybeSameReg) {
6975 Register Reg;
6976 return mi_match(R: MaybeSub, MRI, P: m_GSub(L: m_Reg(R&: Src), R: m_Reg(R&: Reg))) &&
6977 Reg == MaybeSameReg;
6978 };
6979 return CheckFold(LHS, RHS) || CheckFold(RHS, LHS);
6980}
6981
6982bool CombinerHelper::matchBuildVectorIdentityFold(MachineInstr &MI,
6983 Register &MatchInfo) const {
6984 // This combine folds the following patterns:
6985 //
6986 // G_BUILD_VECTOR_TRUNC (G_BITCAST(x), G_LSHR(G_BITCAST(x), k))
6987 // G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), G_TRUNC(G_LSHR(G_BITCAST(x), k)))
6988 // into
6989 // x
6990 // if
6991 // k == sizeof(VecEltTy)/2
6992 // type(x) == type(dst)
6993 //
6994 // G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), undef)
6995 // into
6996 // x
6997 // if
6998 // type(x) == type(dst)
6999
7000 LLT DstVecTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
7001 LLT DstEltTy = DstVecTy.getElementType();
7002
7003 Register Lo, Hi;
7004
7005 if (mi_match(
7006 MI, MRI,
7007 P: m_GBuildVector(L: m_GTrunc(Src: m_GBitcast(Src: m_Reg(R&: Lo))), R: m_GImplicitDef()))) {
7008 MatchInfo = Lo;
7009 return MRI.getType(Reg: MatchInfo) == DstVecTy;
7010 }
7011
7012 std::optional<ValueAndVReg> ShiftAmount;
7013 const auto LoPattern = m_GBitcast(Src: m_Reg(R&: Lo));
7014 const auto HiPattern = m_GLShr(L: m_GBitcast(Src: m_Reg(R&: Hi)), R: m_GCst(ValReg&: ShiftAmount));
7015 if (mi_match(
7016 MI, MRI,
7017 P: m_any_of(preds: m_GBuildVectorTrunc(L: LoPattern, R: HiPattern),
7018 preds: m_GBuildVector(L: m_GTrunc(Src: LoPattern), R: m_GTrunc(Src: HiPattern))))) {
7019 if (Lo == Hi && ShiftAmount->Value == DstEltTy.getSizeInBits()) {
7020 MatchInfo = Lo;
7021 return MRI.getType(Reg: MatchInfo) == DstVecTy;
7022 }
7023 }
7024
7025 return false;
7026}
7027
7028bool CombinerHelper::matchTruncBuildVectorFold(MachineInstr &MI,
7029 Register &MatchInfo) const {
7030 // Replace (G_TRUNC (G_BITCAST (G_BUILD_VECTOR x, y)) with just x
7031 // if type(x) == type(G_TRUNC)
7032 if (!mi_match(R: MI.getOperand(i: 1).getReg(), MRI,
7033 P: m_GBitcast(Src: m_GBuildVector(L: m_Reg(R&: MatchInfo), R: m_Reg()))))
7034 return false;
7035
7036 return MRI.getType(Reg: MatchInfo) == MRI.getType(Reg: MI.getOperand(i: 0).getReg());
7037}
7038
7039bool CombinerHelper::matchTruncLshrBuildVectorFold(MachineInstr &MI,
7040 Register &MatchInfo) const {
7041 // Replace (G_TRUNC (G_LSHR (G_BITCAST (G_BUILD_VECTOR x, y)), K)) with
7042 // y if K == size of vector element type
7043 std::optional<ValueAndVReg> ShiftAmt;
7044 if (!mi_match(R: MI.getOperand(i: 1).getReg(), MRI,
7045 P: m_GLShr(L: m_GBitcast(Src: m_GBuildVector(L: m_Reg(), R: m_Reg(R&: MatchInfo))),
7046 R: m_GCst(ValReg&: ShiftAmt))))
7047 return false;
7048
7049 LLT MatchTy = MRI.getType(Reg: MatchInfo);
7050 return ShiftAmt->Value.getZExtValue() == MatchTy.getSizeInBits() &&
7051 MatchTy == MRI.getType(Reg: MI.getOperand(i: 0).getReg());
7052}
7053
7054unsigned CombinerHelper::getFPMinMaxOpcForSelect(
7055 CmpInst::Predicate Pred, LLT DstTy,
7056 SelectPatternNaNBehaviour VsNaNRetVal) const {
7057 assert(VsNaNRetVal != SelectPatternNaNBehaviour::NOT_APPLICABLE &&
7058 "Expected a NaN behaviour?");
7059 // Choose an opcode based off of legality or the behaviour when one of the
7060 // LHS/RHS may be NaN.
7061 switch (Pred) {
7062 default:
7063 return 0;
7064 case CmpInst::FCMP_UGT:
7065 case CmpInst::FCMP_UGE:
7066 case CmpInst::FCMP_OGT:
7067 case CmpInst::FCMP_OGE:
7068 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER)
7069 return TargetOpcode::G_FMAXNUM;
7070 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN)
7071 return TargetOpcode::G_FMAXIMUM;
7072 if (isLegal(Query: {TargetOpcode::G_FMAXNUM, {DstTy}}))
7073 return TargetOpcode::G_FMAXNUM;
7074 if (isLegal(Query: {TargetOpcode::G_FMAXIMUM, {DstTy}}))
7075 return TargetOpcode::G_FMAXIMUM;
7076 return 0;
7077 case CmpInst::FCMP_ULT:
7078 case CmpInst::FCMP_ULE:
7079 case CmpInst::FCMP_OLT:
7080 case CmpInst::FCMP_OLE:
7081 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER)
7082 return TargetOpcode::G_FMINNUM;
7083 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN)
7084 return TargetOpcode::G_FMINIMUM;
7085 if (isLegal(Query: {TargetOpcode::G_FMINNUM, {DstTy}}))
7086 return TargetOpcode::G_FMINNUM;
7087 if (!isLegal(Query: {TargetOpcode::G_FMINIMUM, {DstTy}}))
7088 return 0;
7089 return TargetOpcode::G_FMINIMUM;
7090 }
7091}
7092
7093CombinerHelper::SelectPatternNaNBehaviour
7094CombinerHelper::computeRetValAgainstNaN(Register LHS, Register RHS,
7095 bool IsOrderedComparison) const {
7096 bool LHSSafe = isKnownNeverNaN(Val: LHS, MRI);
7097 bool RHSSafe = isKnownNeverNaN(Val: RHS, MRI);
7098 // Completely unsafe.
7099 if (!LHSSafe && !RHSSafe)
7100 return SelectPatternNaNBehaviour::NOT_APPLICABLE;
7101 if (LHSSafe && RHSSafe)
7102 return SelectPatternNaNBehaviour::RETURNS_ANY;
7103 // An ordered comparison will return false when given a NaN, so it
7104 // returns the RHS.
7105 if (IsOrderedComparison)
7106 return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_NAN
7107 : SelectPatternNaNBehaviour::RETURNS_OTHER;
7108 // An unordered comparison will return true when given a NaN, so it
7109 // returns the LHS.
7110 return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_OTHER
7111 : SelectPatternNaNBehaviour::RETURNS_NAN;
7112}
7113
7114bool CombinerHelper::matchFPSelectToMinMax(Register Dst, Register Cond,
7115 Register TrueVal, Register FalseVal,
7116 BuildFnTy &MatchInfo) const {
7117 // Match: select (fcmp cond x, y) x, y
7118 // select (fcmp cond x, y) y, x
7119 // And turn it into fminnum/fmaxnum or fmin/fmax based off of the condition.
7120 LLT DstTy = MRI.getType(Reg: Dst);
7121 // Bail out early on pointers, since we'll never want to fold to a min/max.
7122 if (DstTy.isPointer())
7123 return false;
7124 // Match a floating point compare with a less-than/greater-than predicate.
7125 // TODO: Allow multiple users of the compare if they are all selects.
7126 CmpInst::Predicate Pred;
7127 Register CmpLHS, CmpRHS;
7128 if (!mi_match(R: Cond, MRI,
7129 P: m_OneNonDBGUse(
7130 SP: m_GFCmp(P: m_Pred(P&: Pred), L: m_Reg(R&: CmpLHS), R: m_Reg(R&: CmpRHS)))) ||
7131 CmpInst::isEquality(pred: Pred))
7132 return false;
7133 SelectPatternNaNBehaviour ResWithKnownNaNInfo =
7134 computeRetValAgainstNaN(LHS: CmpLHS, RHS: CmpRHS, IsOrderedComparison: CmpInst::isOrdered(predicate: Pred));
7135 if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::NOT_APPLICABLE)
7136 return false;
7137 if (TrueVal == CmpRHS && FalseVal == CmpLHS) {
7138 std::swap(a&: CmpLHS, b&: CmpRHS);
7139 Pred = CmpInst::getSwappedPredicate(pred: Pred);
7140 if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_NAN)
7141 ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_OTHER;
7142 else if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_OTHER)
7143 ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_NAN;
7144 }
7145 if (TrueVal != CmpLHS || FalseVal != CmpRHS)
7146 return false;
7147 // Decide what type of max/min this should be based off of the predicate.
7148 unsigned Opc = getFPMinMaxOpcForSelect(Pred, DstTy, VsNaNRetVal: ResWithKnownNaNInfo);
7149 if (!Opc || !isLegal(Query: {Opc, {DstTy}}))
7150 return false;
7151 // Comparisons between signed zero and zero may have different results...
7152 // unless we have fmaximum/fminimum. In that case, we know -0 < 0.
7153 if (Opc != TargetOpcode::G_FMAXIMUM && Opc != TargetOpcode::G_FMINIMUM) {
7154 // We don't know if a comparison between two 0s will give us a consistent
7155 // result. Be conservative and only proceed if at least one side is
7156 // non-zero.
7157 auto KnownNonZeroSide = getFConstantVRegValWithLookThrough(VReg: CmpLHS, MRI);
7158 if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero()) {
7159 KnownNonZeroSide = getFConstantVRegValWithLookThrough(VReg: CmpRHS, MRI);
7160 if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero())
7161 return false;
7162 }
7163 }
7164 MatchInfo = [=](MachineIRBuilder &B) {
7165 B.buildInstr(Opc, DstOps: {Dst}, SrcOps: {CmpLHS, CmpRHS});
7166 };
7167 return true;
7168}
7169
7170bool CombinerHelper::matchSimplifySelectToMinMax(MachineInstr &MI,
7171 BuildFnTy &MatchInfo) const {
7172 // TODO: Handle integer cases.
7173 assert(MI.getOpcode() == TargetOpcode::G_SELECT);
7174 // Condition may be fed by a truncated compare.
7175 Register Cond = MI.getOperand(i: 1).getReg();
7176 Register MaybeTrunc;
7177 if (mi_match(R: Cond, MRI, P: m_OneNonDBGUse(SP: m_GTrunc(Src: m_Reg(R&: MaybeTrunc)))))
7178 Cond = MaybeTrunc;
7179 Register Dst = MI.getOperand(i: 0).getReg();
7180 Register TrueVal = MI.getOperand(i: 2).getReg();
7181 Register FalseVal = MI.getOperand(i: 3).getReg();
7182 return matchFPSelectToMinMax(Dst, Cond, TrueVal, FalseVal, MatchInfo);
7183}
7184
7185bool CombinerHelper::matchRedundantBinOpInEquality(MachineInstr &MI,
7186 BuildFnTy &MatchInfo) const {
7187 assert(MI.getOpcode() == TargetOpcode::G_ICMP);
7188 // (X + Y) == X --> Y == 0
7189 // (X + Y) != X --> Y != 0
7190 // (X - Y) == X --> Y == 0
7191 // (X - Y) != X --> Y != 0
7192 // (X ^ Y) == X --> Y == 0
7193 // (X ^ Y) != X --> Y != 0
7194 Register Dst = MI.getOperand(i: 0).getReg();
7195 CmpInst::Predicate Pred;
7196 Register X, Y, OpLHS, OpRHS;
7197 bool MatchedSub = mi_match(
7198 R: Dst, MRI,
7199 P: m_c_GICmp(P: m_Pred(P&: Pred), L: m_Reg(R&: X), R: m_GSub(L: m_Reg(R&: OpLHS), R: m_Reg(R&: Y))));
7200 if (MatchedSub && X != OpLHS)
7201 return false;
7202 if (!MatchedSub) {
7203 if (!mi_match(R: Dst, MRI,
7204 P: m_c_GICmp(P: m_Pred(P&: Pred), L: m_Reg(R&: X),
7205 R: m_any_of(preds: m_GAdd(L: m_Reg(R&: OpLHS), R: m_Reg(R&: OpRHS)),
7206 preds: m_GXor(L: m_Reg(R&: OpLHS), R: m_Reg(R&: OpRHS))))))
7207 return false;
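    // Y is whichever operand of the add/xor is not X; if X matches neither
    // operand, Y stays invalid and the match is rejected below.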
7208 Y = X == OpLHS ? OpRHS : X == OpRHS ? OpLHS : Register();
7209 }
7210 MatchInfo = [=](MachineIRBuilder &B) {
7211 auto Zero = B.buildConstant(Res: MRI.getType(Reg: Y), Val: 0);
7212 B.buildICmp(Pred, Res: Dst, Op0: Y, Op1: Zero);
7213 };
7214 return CmpInst::isEquality(pred: Pred) && Y.isValid();
7215}
7216
/// Return the minimum useless shift amount, i.e. the smallest shift amount
/// that results in complete loss of the source value. \p Result is set to the
/// constant such an over-shift is known to produce (0 or -1), or to
/// std::nullopt when that constant cannot be determined.
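/// E.g. for a G_SHL of an 8-bit value whose low three bits are known to be
/// zero, any shift amount of at least 5 yields 0, so 5 is returned.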
7219static std::optional<unsigned>
7220getMinUselessShift(KnownBits ValueKB, unsigned Opcode,
7221 std::optional<int64_t> &Result) {
7222 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_LSHR ||
7223 Opcode == TargetOpcode::G_ASHR) &&
7224 "Expect G_SHL, G_LSHR or G_ASHR.");
7225 auto SignificantBits = 0;
7226 switch (Opcode) {
7227 case TargetOpcode::G_SHL:
7228 SignificantBits = ValueKB.countMinTrailingZeros();
7229 Result = 0;
7230 break;
7231 case TargetOpcode::G_LSHR:
7232 Result = 0;
7233 SignificantBits = ValueKB.countMinLeadingZeros();
7234 break;
7235 case TargetOpcode::G_ASHR:
7236 if (ValueKB.isNonNegative()) {
7237 SignificantBits = ValueKB.countMinLeadingZeros();
7238 Result = 0;
7239 } else if (ValueKB.isNegative()) {
7240 SignificantBits = ValueKB.countMinLeadingOnes();
7241 Result = -1;
7242 } else {
7243 // Cannot determine shift result.
7244 Result = std::nullopt;
7245 }
7246 break;
7247 default:
7248 break;
7249 }
7250 return ValueKB.getBitWidth() - SignificantBits;
7251}
7252
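// Match a shift by a constant that is known to wipe out the shifted value.
// MatchInfo is std::nullopt when the shift amount is at least the result's
// bit width; otherwise it holds the constant (0 or -1) the over-shift is
// known to produce.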
7253bool CombinerHelper::matchShiftsTooBig(
7254 MachineInstr &MI, std::optional<int64_t> &MatchInfo) const {
7255 Register ShiftVal = MI.getOperand(i: 1).getReg();
7256 Register ShiftReg = MI.getOperand(i: 2).getReg();
7257 LLT ResTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
7258 auto IsShiftTooBig = [&](const Constant *C) {
7259 auto *CI = dyn_cast<ConstantInt>(Val: C);
7260 if (!CI)
7261 return false;
7262 if (CI->uge(Num: ResTy.getScalarSizeInBits())) {
7263 MatchInfo = std::nullopt;
7264 return true;
7265 }
7266 auto OptMaxUsefulShift = getMinUselessShift(ValueKB: VT->getKnownBits(R: ShiftVal),
7267 Opcode: MI.getOpcode(), Result&: MatchInfo);
7268 return OptMaxUsefulShift && CI->uge(Num: *OptMaxUsefulShift);
7269 };
7270 return matchUnaryPredicate(MRI, Reg: ShiftReg, Match: IsShiftTooBig);
7271}
7272
7273bool CombinerHelper::matchCommuteConstantToRHS(MachineInstr &MI) const {
7274 unsigned LHSOpndIdx = 1;
7275 unsigned RHSOpndIdx = 2;
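  // The overflow opcodes have two results, so their value operands start at
  // index 2 rather than 1.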
7276 switch (MI.getOpcode()) {
7277 case TargetOpcode::G_UADDO:
7278 case TargetOpcode::G_SADDO:
7279 case TargetOpcode::G_UMULO:
7280 case TargetOpcode::G_SMULO:
7281 LHSOpndIdx = 2;
7282 RHSOpndIdx = 3;
7283 break;
7284 default:
7285 break;
7286 }
7287 Register LHS = MI.getOperand(i: LHSOpndIdx).getReg();
7288 Register RHS = MI.getOperand(i: RHSOpndIdx).getReg();
7289 if (!getIConstantVRegVal(VReg: LHS, MRI)) {
7290 // Skip commuting if LHS is not a constant. But, LHS may be a
7291 // G_CONSTANT_FOLD_BARRIER. If so we commute as long as we don't already
7292 // have a constant on the RHS.
7293 if (MRI.getVRegDef(Reg: LHS)->getOpcode() !=
7294 TargetOpcode::G_CONSTANT_FOLD_BARRIER)
7295 return false;
7296 }
7297 // Commute as long as RHS is not a constant or G_CONSTANT_FOLD_BARRIER.
7298 return MRI.getVRegDef(Reg: RHS)->getOpcode() !=
7299 TargetOpcode::G_CONSTANT_FOLD_BARRIER &&
7300 !getIConstantVRegVal(VReg: RHS, MRI);
7301}
7302
7303bool CombinerHelper::matchCommuteFPConstantToRHS(MachineInstr &MI) const {
7304 Register LHS = MI.getOperand(i: 1).getReg();
7305 Register RHS = MI.getOperand(i: 2).getReg();
7306 std::optional<FPValueAndVReg> ValAndVReg;
7307 if (!mi_match(R: LHS, MRI, P: m_GFCstOrSplat(FPValReg&: ValAndVReg)))
7308 return false;
7309 return !mi_match(R: RHS, MRI, P: m_GFCstOrSplat(FPValReg&: ValAndVReg));
7310}
7311
7312void CombinerHelper::applyCommuteBinOpOperands(MachineInstr &MI) const {
7313 Observer.changingInstr(MI);
7314 unsigned LHSOpndIdx = 1;
7315 unsigned RHSOpndIdx = 2;
7316 switch (MI.getOpcode()) {
7317 case TargetOpcode::G_UADDO:
7318 case TargetOpcode::G_SADDO:
7319 case TargetOpcode::G_UMULO:
7320 case TargetOpcode::G_SMULO:
7321 LHSOpndIdx = 2;
7322 RHSOpndIdx = 3;
7323 break;
7324 default:
7325 break;
7326 }
7327 Register LHSReg = MI.getOperand(i: LHSOpndIdx).getReg();
7328 Register RHSReg = MI.getOperand(i: RHSOpndIdx).getReg();
7329 MI.getOperand(i: LHSOpndIdx).setReg(RHSReg);
7330 MI.getOperand(i: RHSOpndIdx).setReg(LHSReg);
7331 Observer.changedInstr(MI);
7332}
7333
7334bool CombinerHelper::isOneOrOneSplat(Register Src, bool AllowUndefs) const {
7335 LLT SrcTy = MRI.getType(Reg: Src);
7336 if (SrcTy.isFixedVector())
7337 return isConstantSplatVector(Src, SplatValue: 1, AllowUndefs);
7338 if (SrcTy.isScalar()) {
7339 if (AllowUndefs && getOpcodeDef<GImplicitDef>(Reg: Src, MRI) != nullptr)
7340 return true;
7341 auto IConstant = getIConstantVRegValWithLookThrough(VReg: Src, MRI);
7342 return IConstant && IConstant->Value == 1;
7343 }
7344 return false; // scalable vector
7345}
7346
7347bool CombinerHelper::isZeroOrZeroSplat(Register Src, bool AllowUndefs) const {
7348 LLT SrcTy = MRI.getType(Reg: Src);
7349 if (SrcTy.isFixedVector())
7350 return isConstantSplatVector(Src, SplatValue: 0, AllowUndefs);
7351 if (SrcTy.isScalar()) {
7352 if (AllowUndefs && getOpcodeDef<GImplicitDef>(Reg: Src, MRI) != nullptr)
7353 return true;
7354 auto IConstant = getIConstantVRegValWithLookThrough(VReg: Src, MRI);
7355 return IConstant && IConstant->Value == 0;
7356 }
7357 return false; // scalable vector
7358}
7359
7360// Ignores COPYs during conformance checks.
7361// FIXME scalable vectors.
7362bool CombinerHelper::isConstantSplatVector(Register Src, int64_t SplatValue,
7363 bool AllowUndefs) const {
7364 GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Reg: Src, MRI);
7365 if (!BuildVector)
7366 return false;
7367 unsigned NumSources = BuildVector->getNumSources();
7368
7369 for (unsigned I = 0; I < NumSources; ++I) {
7370 GImplicitDef *ImplicitDef =
7371 getOpcodeDef<GImplicitDef>(Reg: BuildVector->getSourceReg(I), MRI);
7372 if (ImplicitDef && AllowUndefs)
7373 continue;
7374 if (ImplicitDef && !AllowUndefs)
7375 return false;
7376 std::optional<ValueAndVReg> IConstant =
7377 getIConstantVRegValWithLookThrough(VReg: BuildVector->getSourceReg(I), MRI);
7378 if (IConstant && IConstant->Value == SplatValue)
7379 continue;
7380 return false;
7381 }
7382 return true;
7383}
7384
7385// Ignores COPYs during lookups.
7386// FIXME scalable vectors
7387std::optional<APInt>
7388CombinerHelper::getConstantOrConstantSplatVector(Register Src) const {
7389 auto IConstant = getIConstantVRegValWithLookThrough(VReg: Src, MRI);
7390 if (IConstant)
7391 return IConstant->Value;
7392
7393 GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Reg: Src, MRI);
7394 if (!BuildVector)
7395 return std::nullopt;
7396 unsigned NumSources = BuildVector->getNumSources();
7397
7398 std::optional<APInt> Value = std::nullopt;
7399 for (unsigned I = 0; I < NumSources; ++I) {
7400 std::optional<ValueAndVReg> IConstant =
7401 getIConstantVRegValWithLookThrough(VReg: BuildVector->getSourceReg(I), MRI);
7402 if (!IConstant)
7403 return std::nullopt;
7404 if (!Value)
7405 Value = IConstant->Value;
7406 else if (*Value != IConstant->Value)
7407 return std::nullopt;
7408 }
7409 return Value;
7410}
7411
7412// FIXME G_SPLAT_VECTOR
7413bool CombinerHelper::isConstantOrConstantVectorI(Register Src) const {
7414 auto IConstant = getIConstantVRegValWithLookThrough(VReg: Src, MRI);
7415 if (IConstant)
7416 return true;
7417
7418 GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Reg: Src, MRI);
7419 if (!BuildVector)
7420 return false;
7421
7422 unsigned NumSources = BuildVector->getNumSources();
7423 for (unsigned I = 0; I < NumSources; ++I) {
7424 std::optional<ValueAndVReg> IConstant =
7425 getIConstantVRegValWithLookThrough(VReg: BuildVector->getSourceReg(I), MRI);
7426 if (!IConstant)
7427 return false;
7428 }
7429 return true;
7430}
7431
7432// TODO: use knownbits to determine zeros
7433bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
7434 BuildFnTy &MatchInfo) const {
7435 uint32_t Flags = Select->getFlags();
7436 Register Dest = Select->getReg(Idx: 0);
7437 Register Cond = Select->getCondReg();
7438 Register True = Select->getTrueReg();
7439 Register False = Select->getFalseReg();
7440 LLT CondTy = MRI.getType(Reg: Select->getCondReg());
7441 LLT TrueTy = MRI.getType(Reg: Select->getTrueReg());
7442
7443 // We only do this combine for scalar boolean conditions.
7444 if (CondTy != LLT::scalar(SizeInBits: 1))
7445 return false;
7446
7447 if (TrueTy.isPointer())
7448 return false;
7449
7450 // Both are scalars.
7451 std::optional<ValueAndVReg> TrueOpt =
7452 getIConstantVRegValWithLookThrough(VReg: True, MRI);
7453 std::optional<ValueAndVReg> FalseOpt =
7454 getIConstantVRegValWithLookThrough(VReg: False, MRI);
7455
7456 if (!TrueOpt || !FalseOpt)
7457 return false;
7458
7459 APInt TrueValue = TrueOpt->Value;
7460 APInt FalseValue = FalseOpt->Value;
7461
7462 // select Cond, 1, 0 --> zext (Cond)
7463 if (TrueValue.isOne() && FalseValue.isZero()) {
7464 MatchInfo = [=](MachineIRBuilder &B) {
7465 B.setInstrAndDebugLoc(*Select);
7466 B.buildZExtOrTrunc(Res: Dest, Op: Cond);
7467 };
7468 return true;
7469 }
7470
7471 // select Cond, -1, 0 --> sext (Cond)
7472 if (TrueValue.isAllOnes() && FalseValue.isZero()) {
7473 MatchInfo = [=](MachineIRBuilder &B) {
7474 B.setInstrAndDebugLoc(*Select);
7475 B.buildSExtOrTrunc(Res: Dest, Op: Cond);
7476 };
7477 return true;
7478 }
7479
7480 // select Cond, 0, 1 --> zext (!Cond)
7481 if (TrueValue.isZero() && FalseValue.isOne()) {
7482 MatchInfo = [=](MachineIRBuilder &B) {
7483 B.setInstrAndDebugLoc(*Select);
7484 Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy);
7485 B.buildNot(Dst: Inner, Src0: Cond);
7486 B.buildZExtOrTrunc(Res: Dest, Op: Inner);
7487 };
7488 return true;
7489 }
7490
7491 // select Cond, 0, -1 --> sext (!Cond)
7492 if (TrueValue.isZero() && FalseValue.isAllOnes()) {
7493 MatchInfo = [=](MachineIRBuilder &B) {
7494 B.setInstrAndDebugLoc(*Select);
7495 Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy);
7496 B.buildNot(Dst: Inner, Src0: Cond);
7497 B.buildSExtOrTrunc(Res: Dest, Op: Inner);
7498 };
7499 return true;
7500 }
7501
7502 // select Cond, C1, C1-1 --> add (zext Cond), C1-1
7503 if (TrueValue - 1 == FalseValue) {
7504 MatchInfo = [=](MachineIRBuilder &B) {
7505 B.setInstrAndDebugLoc(*Select);
7506 Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy);
7507 B.buildZExtOrTrunc(Res: Inner, Op: Cond);
7508 B.buildAdd(Dst: Dest, Src0: Inner, Src1: False);
7509 };
7510 return true;
7511 }
7512
7513 // select Cond, C1, C1+1 --> add (sext Cond), C1+1
7514 if (TrueValue + 1 == FalseValue) {
7515 MatchInfo = [=](MachineIRBuilder &B) {
7516 B.setInstrAndDebugLoc(*Select);
7517 Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy);
7518 B.buildSExtOrTrunc(Res: Inner, Op: Cond);
7519 B.buildAdd(Dst: Dest, Src0: Inner, Src1: False);
7520 };
7521 return true;
7522 }
7523
7524 // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
7525 if (TrueValue.isPowerOf2() && FalseValue.isZero()) {
7526 MatchInfo = [=](MachineIRBuilder &B) {
7527 B.setInstrAndDebugLoc(*Select);
7528 Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy);
7529 B.buildZExtOrTrunc(Res: Inner, Op: Cond);
7530 // The shift amount must be scalar.
7531 LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy;
7532 auto ShAmtC = B.buildConstant(Res: ShiftTy, Val: TrueValue.exactLogBase2());
7533 B.buildShl(Dst: Dest, Src0: Inner, Src1: ShAmtC, Flags);
7534 };
7535 return true;
7536 }
7537
7538 // select Cond, 0, Pow2 --> (zext (!Cond)) << log2(Pow2)
7539 if (FalseValue.isPowerOf2() && TrueValue.isZero()) {
7540 MatchInfo = [=](MachineIRBuilder &B) {
7541 B.setInstrAndDebugLoc(*Select);
7542 Register Not = MRI.createGenericVirtualRegister(Ty: CondTy);
7543 B.buildNot(Dst: Not, Src0: Cond);
7544 Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy);
7545 B.buildZExtOrTrunc(Res: Inner, Op: Not);
7546 // The shift amount must be scalar.
7547 LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy;
7548 auto ShAmtC = B.buildConstant(Res: ShiftTy, Val: FalseValue.exactLogBase2());
7549 B.buildShl(Dst: Dest, Src0: Inner, Src1: ShAmtC, Flags);
7550 };
7551 return true;
7552 }
7553
7554 // select Cond, -1, C --> or (sext Cond), C
7555 if (TrueValue.isAllOnes()) {
7556 MatchInfo = [=](MachineIRBuilder &B) {
7557 B.setInstrAndDebugLoc(*Select);
7558 Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy);
7559 B.buildSExtOrTrunc(Res: Inner, Op: Cond);
7560 B.buildOr(Dst: Dest, Src0: Inner, Src1: False, Flags);
7561 };
7562 return true;
7563 }
7564
7565 // select Cond, C, -1 --> or (sext (not Cond)), C
7566 if (FalseValue.isAllOnes()) {
7567 MatchInfo = [=](MachineIRBuilder &B) {
7568 B.setInstrAndDebugLoc(*Select);
7569 Register Not = MRI.createGenericVirtualRegister(Ty: CondTy);
7570 B.buildNot(Dst: Not, Src0: Cond);
7571 Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy);
7572 B.buildSExtOrTrunc(Res: Inner, Op: Not);
7573 B.buildOr(Dst: Dest, Src0: Inner, Src1: True, Flags);
7574 };
7575 return true;
7576 }
7577
7578 return false;
7579}
7580
7581// TODO: use knownbits to determine zeros
7582bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
7583 BuildFnTy &MatchInfo) const {
7584 uint32_t Flags = Select->getFlags();
7585 Register DstReg = Select->getReg(Idx: 0);
7586 Register Cond = Select->getCondReg();
7587 Register True = Select->getTrueReg();
7588 Register False = Select->getFalseReg();
7589 LLT CondTy = MRI.getType(Reg: Select->getCondReg());
7590 LLT TrueTy = MRI.getType(Reg: Select->getTrueReg());
7591
7592 // Boolean or fixed vector of booleans.
7593 if (CondTy.isScalableVector() ||
7594 (CondTy.isFixedVector() &&
7595 CondTy.getElementType().getScalarSizeInBits() != 1) ||
7596 CondTy.getScalarSizeInBits() != 1)
7597 return false;
7598
7599 if (CondTy != TrueTy)
7600 return false;
7601
7602 // select Cond, Cond, F --> or Cond, F
7603 // select Cond, 1, F --> or Cond, F
7604 if ((Cond == True) || isOneOrOneSplat(Src: True, /* AllowUndefs */ true)) {
7605 MatchInfo = [=](MachineIRBuilder &B) {
7606 B.setInstrAndDebugLoc(*Select);
7607 Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy);
7608 B.buildZExtOrTrunc(Res: Ext, Op: Cond);
7609 auto FreezeFalse = B.buildFreeze(Dst: TrueTy, Src: False);
7610 B.buildOr(Dst: DstReg, Src0: Ext, Src1: FreezeFalse, Flags);
7611 };
7612 return true;
7613 }
7614
7615 // select Cond, T, Cond --> and Cond, T
7616 // select Cond, T, 0 --> and Cond, T
7617 if ((Cond == False) || isZeroOrZeroSplat(Src: False, /* AllowUndefs */ true)) {
7618 MatchInfo = [=](MachineIRBuilder &B) {
7619 B.setInstrAndDebugLoc(*Select);
7620 Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy);
7621 B.buildZExtOrTrunc(Res: Ext, Op: Cond);
7622 auto FreezeTrue = B.buildFreeze(Dst: TrueTy, Src: True);
7623 B.buildAnd(Dst: DstReg, Src0: Ext, Src1: FreezeTrue);
7624 };
7625 return true;
7626 }
7627
7628 // select Cond, T, 1 --> or (not Cond), T
7629 if (isOneOrOneSplat(Src: False, /* AllowUndefs */ true)) {
7630 MatchInfo = [=](MachineIRBuilder &B) {
7631 B.setInstrAndDebugLoc(*Select);
7632 // First the not.
7633 Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy);
7634 B.buildNot(Dst: Inner, Src0: Cond);
7635 // Then an ext to match the destination register.
7636 Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy);
7637 B.buildZExtOrTrunc(Res: Ext, Op: Inner);
7638 auto FreezeTrue = B.buildFreeze(Dst: TrueTy, Src: True);
7639 B.buildOr(Dst: DstReg, Src0: Ext, Src1: FreezeTrue, Flags);
7640 };
7641 return true;
7642 }
7643
7644 // select Cond, 0, F --> and (not Cond), F
7645 if (isZeroOrZeroSplat(Src: True, /* AllowUndefs */ true)) {
7646 MatchInfo = [=](MachineIRBuilder &B) {
7647 B.setInstrAndDebugLoc(*Select);
7648 // First the not.
7649 Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy);
7650 B.buildNot(Dst: Inner, Src0: Cond);
7651 // Then an ext to match the destination register.
7652 Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy);
7653 B.buildZExtOrTrunc(Res: Ext, Op: Inner);
7654 auto FreezeFalse = B.buildFreeze(Dst: TrueTy, Src: False);
7655 B.buildAnd(Dst: DstReg, Src0: Ext, Src1: FreezeFalse);
7656 };
7657 return true;
7658 }
7659
7660 return false;
7661}
7662
7663bool CombinerHelper::matchSelectIMinMax(const MachineOperand &MO,
7664 BuildFnTy &MatchInfo) const {
7665 GSelect *Select = cast<GSelect>(Val: MRI.getVRegDef(Reg: MO.getReg()));
7666 GICmp *Cmp = cast<GICmp>(Val: MRI.getVRegDef(Reg: Select->getCondReg()));
7667
7668 Register DstReg = Select->getReg(Idx: 0);
7669 Register True = Select->getTrueReg();
7670 Register False = Select->getFalseReg();
7671 LLT DstTy = MRI.getType(Reg: DstReg);
7672
7673 if (DstTy.isPointerOrPointerVector())
7674 return false;
7675
7676 // We want to fold the icmp and replace the select.
7677 if (!MRI.hasOneNonDBGUse(RegNo: Cmp->getReg(Idx: 0)))
7678 return false;
7679
7680 CmpInst::Predicate Pred = Cmp->getCond();
7681 // We need a larger or smaller predicate for
7682 // canonicalization.
7683 if (CmpInst::isEquality(pred: Pred))
7684 return false;
7685
7686 Register CmpLHS = Cmp->getLHSReg();
7687 Register CmpRHS = Cmp->getRHSReg();
7688
  // We can swap CmpLHS and CmpRHS for a higher hit rate.
7690 if (True == CmpRHS && False == CmpLHS) {
7691 std::swap(a&: CmpLHS, b&: CmpRHS);
7692 Pred = CmpInst::getSwappedPredicate(pred: Pred);
7693 }
7694
7695 // (icmp X, Y) ? X : Y -> integer minmax.
7696 // see matchSelectPattern in ValueTracking.
7697 // Legality between G_SELECT and integer minmax can differ.
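  // E.g. (icmp sgt x, y) ? x : y --> smax x, y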
7698 if (True != CmpLHS || False != CmpRHS)
7699 return false;
7700
7701 switch (Pred) {
7702 case ICmpInst::ICMP_UGT:
7703 case ICmpInst::ICMP_UGE: {
7704 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_UMAX, DstTy}))
7705 return false;
7706 MatchInfo = [=](MachineIRBuilder &B) { B.buildUMax(Dst: DstReg, Src0: True, Src1: False); };
7707 return true;
7708 }
7709 case ICmpInst::ICMP_SGT:
7710 case ICmpInst::ICMP_SGE: {
7711 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SMAX, DstTy}))
7712 return false;
7713 MatchInfo = [=](MachineIRBuilder &B) { B.buildSMax(Dst: DstReg, Src0: True, Src1: False); };
7714 return true;
7715 }
7716 case ICmpInst::ICMP_ULT:
7717 case ICmpInst::ICMP_ULE: {
7718 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_UMIN, DstTy}))
7719 return false;
7720 MatchInfo = [=](MachineIRBuilder &B) { B.buildUMin(Dst: DstReg, Src0: True, Src1: False); };
7721 return true;
7722 }
7723 case ICmpInst::ICMP_SLT:
7724 case ICmpInst::ICMP_SLE: {
7725 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SMIN, DstTy}))
7726 return false;
7727 MatchInfo = [=](MachineIRBuilder &B) { B.buildSMin(Dst: DstReg, Src0: True, Src1: False); };
7728 return true;
7729 }
7730 default:
7731 return false;
7732 }
7733}
7734
7735// (neg (min/max x, (neg x))) --> (max/min x, (neg x))
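// E.g. (sub 0, (smin x, (sub 0, x))) --> (smax x, (sub 0, x))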
7736bool CombinerHelper::matchSimplifyNegMinMax(MachineInstr &MI,
7737 BuildFnTy &MatchInfo) const {
7738 assert(MI.getOpcode() == TargetOpcode::G_SUB);
7739 Register DestReg = MI.getOperand(i: 0).getReg();
7740 LLT DestTy = MRI.getType(Reg: DestReg);
7741
7742 Register X;
7743 Register Sub0;
7744 auto NegPattern = m_all_of(preds: m_Neg(Src: m_DeferredReg(R&: X)), preds: m_Reg(R&: Sub0));
7745 if (mi_match(R: DestReg, MRI,
7746 P: m_Neg(Src: m_OneUse(SP: m_any_of(preds: m_GSMin(L: m_Reg(R&: X), R: NegPattern),
7747 preds: m_GSMax(L: m_Reg(R&: X), R: NegPattern),
7748 preds: m_GUMin(L: m_Reg(R&: X), R: NegPattern),
7749 preds: m_GUMax(L: m_Reg(R&: X), R: NegPattern)))))) {
7750 MachineInstr *MinMaxMI = MRI.getVRegDef(Reg: MI.getOperand(i: 2).getReg());
7751 unsigned NewOpc = getInverseGMinMaxOpcode(MinMaxOpc: MinMaxMI->getOpcode());
7752 if (isLegal(Query: {NewOpc, {DestTy}})) {
7753 MatchInfo = [=](MachineIRBuilder &B) {
7754 B.buildInstr(Opc: NewOpc, DstOps: {DestReg}, SrcOps: {X, Sub0});
7755 };
7756 return true;
7757 }
7758 }
7759
7760 return false;
7761}
7762
7763bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) const {
7764 GSelect *Select = cast<GSelect>(Val: &MI);
7765
7766 if (tryFoldSelectOfConstants(Select, MatchInfo))
7767 return true;
7768
7769 if (tryFoldBoolSelectToLogic(Select, MatchInfo))
7770 return true;
7771
7772 return false;
7773}
7774
7775/// Fold (icmp Pred1 V1, C1) && (icmp Pred2 V2, C2)
7776/// or (icmp Pred1 V1, C1) || (icmp Pred2 V2, C2)
7777/// into a single comparison using range-based reasoning.
7778/// see InstCombinerImpl::foldAndOrOfICmpsUsingRanges.
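/// E.g. (icmp uge x, 4) && (icmp ule x, 10) --> icmp ult (add x, -4), 7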
7779bool CombinerHelper::tryFoldAndOrOrICmpsUsingRanges(
7780 GLogicalBinOp *Logic, BuildFnTy &MatchInfo) const {
7781 assert(Logic->getOpcode() != TargetOpcode::G_XOR && "unexpected xor");
7782 bool IsAnd = Logic->getOpcode() == TargetOpcode::G_AND;
7783 Register DstReg = Logic->getReg(Idx: 0);
7784 Register LHS = Logic->getLHSReg();
7785 Register RHS = Logic->getRHSReg();
7786 unsigned Flags = Logic->getFlags();
7787
  // We need a G_ICMP on the LHS register.
7789 GICmp *Cmp1 = getOpcodeDef<GICmp>(Reg: LHS, MRI);
7790 if (!Cmp1)
7791 return false;
7792
  // We need a G_ICMP on the RHS register.
7794 GICmp *Cmp2 = getOpcodeDef<GICmp>(Reg: RHS, MRI);
7795 if (!Cmp2)
7796 return false;
7797
7798 // We want to fold the icmps.
7799 if (!MRI.hasOneNonDBGUse(RegNo: Cmp1->getReg(Idx: 0)) ||
7800 !MRI.hasOneNonDBGUse(RegNo: Cmp2->getReg(Idx: 0)))
7801 return false;
7802
7803 APInt C1;
7804 APInt C2;
7805 std::optional<ValueAndVReg> MaybeC1 =
7806 getIConstantVRegValWithLookThrough(VReg: Cmp1->getRHSReg(), MRI);
7807 if (!MaybeC1)
7808 return false;
7809 C1 = MaybeC1->Value;
7810
7811 std::optional<ValueAndVReg> MaybeC2 =
7812 getIConstantVRegValWithLookThrough(VReg: Cmp2->getRHSReg(), MRI);
7813 if (!MaybeC2)
7814 return false;
7815 C2 = MaybeC2->Value;
7816
7817 Register R1 = Cmp1->getLHSReg();
7818 Register R2 = Cmp2->getLHSReg();
7819 CmpInst::Predicate Pred1 = Cmp1->getCond();
7820 CmpInst::Predicate Pred2 = Cmp2->getCond();
7821 LLT CmpTy = MRI.getType(Reg: Cmp1->getReg(Idx: 0));
7822 LLT CmpOperandTy = MRI.getType(Reg: R1);
7823
7824 if (CmpOperandTy.isPointer())
7825 return false;
7826
7827 // We build ands, adds, and constants of type CmpOperandTy.
7828 // They must be legal to build.
7829 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_AND, CmpOperandTy}) ||
7830 !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ADD, CmpOperandTy}) ||
7831 !isConstantLegalOrBeforeLegalizer(Ty: CmpOperandTy))
7832 return false;
7833
  // Look through an add of a constant offset on R1, R2, or both operands.
  // This lets us translate the R + C' < C'' range idiom into a proper range.
7836 std::optional<APInt> Offset1;
7837 std::optional<APInt> Offset2;
7838 if (R1 != R2) {
7839 if (GAdd *Add = getOpcodeDef<GAdd>(Reg: R1, MRI)) {
7840 std::optional<ValueAndVReg> MaybeOffset1 =
7841 getIConstantVRegValWithLookThrough(VReg: Add->getRHSReg(), MRI);
7842 if (MaybeOffset1) {
7843 R1 = Add->getLHSReg();
7844 Offset1 = MaybeOffset1->Value;
7845 }
7846 }
7847 if (GAdd *Add = getOpcodeDef<GAdd>(Reg: R2, MRI)) {
7848 std::optional<ValueAndVReg> MaybeOffset2 =
7849 getIConstantVRegValWithLookThrough(VReg: Add->getRHSReg(), MRI);
7850 if (MaybeOffset2) {
7851 R2 = Add->getLHSReg();
7852 Offset2 = MaybeOffset2->Value;
7853 }
7854 }
7855 }
7856
7857 if (R1 != R2)
7858 return false;
7859
7860 // We calculate the icmp ranges including maybe offsets.
7861 ConstantRange CR1 = ConstantRange::makeExactICmpRegion(
7862 Pred: IsAnd ? ICmpInst::getInversePredicate(pred: Pred1) : Pred1, Other: C1);
7863 if (Offset1)
7864 CR1 = CR1.subtract(CI: *Offset1);
7865
7866 ConstantRange CR2 = ConstantRange::makeExactICmpRegion(
7867 Pred: IsAnd ? ICmpInst::getInversePredicate(pred: Pred2) : Pred2, Other: C2);
7868 if (Offset2)
7869 CR2 = CR2.subtract(CI: *Offset2);
7870
7871 bool CreateMask = false;
7872 APInt LowerDiff;
7873 std::optional<ConstantRange> CR = CR1.exactUnionWith(CR: CR2);
7874 if (!CR) {
7875 // We need non-wrapping ranges.
7876 if (CR1.isWrappedSet() || CR2.isWrappedSet())
7877 return false;
7878
7879 // Check whether we have equal-size ranges that only differ by one bit.
7880 // In that case we can apply a mask to map one range onto the other.
7881 LowerDiff = CR1.getLower() ^ CR2.getLower();
7882 APInt UpperDiff = (CR1.getUpper() - 1) ^ (CR2.getUpper() - 1);
7883 APInt CR1Size = CR1.getUpper() - CR1.getLower();
7884 if (!LowerDiff.isPowerOf2() || LowerDiff != UpperDiff ||
7885 CR1Size != CR2.getUpper() - CR2.getLower())
7886 return false;
7887
7888 CR = CR1.getLower().ult(RHS: CR2.getLower()) ? CR1 : CR2;
7889 CreateMask = true;
7890 }
7891
7892 if (IsAnd)
7893 CR = CR->inverse();
7894
7895 CmpInst::Predicate NewPred;
7896 APInt NewC, Offset;
7897 CR->getEquivalentICmp(Pred&: NewPred, RHS&: NewC, Offset);
7898
  // We take the result type of one of the original icmps, CmpTy, for the
  // icmp to be built. The operand type, CmpOperandTy, is used for the other
  // instructions and constants to be built. The operand and result types are
  // the same for G_ADD and G_AND. CmpTy and the type of DstReg might differ;
  // that is why we zext or trunc the icmp into the destination register.
7905
7906 MatchInfo = [=](MachineIRBuilder &B) {
7907 if (CreateMask && Offset != 0) {
7908 auto TildeLowerDiff = B.buildConstant(Res: CmpOperandTy, Val: ~LowerDiff);
7909 auto And = B.buildAnd(Dst: CmpOperandTy, Src0: R1, Src1: TildeLowerDiff); // the mask.
7910 auto OffsetC = B.buildConstant(Res: CmpOperandTy, Val: Offset);
7911 auto Add = B.buildAdd(Dst: CmpOperandTy, Src0: And, Src1: OffsetC, Flags);
7912 auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC);
7913 auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: Add, Op1: NewCon);
7914 B.buildZExtOrTrunc(Res: DstReg, Op: ICmp);
7915 } else if (CreateMask && Offset == 0) {
7916 auto TildeLowerDiff = B.buildConstant(Res: CmpOperandTy, Val: ~LowerDiff);
7917 auto And = B.buildAnd(Dst: CmpOperandTy, Src0: R1, Src1: TildeLowerDiff); // the mask.
7918 auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC);
7919 auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: And, Op1: NewCon);
7920 B.buildZExtOrTrunc(Res: DstReg, Op: ICmp);
7921 } else if (!CreateMask && Offset != 0) {
7922 auto OffsetC = B.buildConstant(Res: CmpOperandTy, Val: Offset);
7923 auto Add = B.buildAdd(Dst: CmpOperandTy, Src0: R1, Src1: OffsetC, Flags);
7924 auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC);
7925 auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: Add, Op1: NewCon);
7926 B.buildZExtOrTrunc(Res: DstReg, Op: ICmp);
7927 } else if (!CreateMask && Offset == 0) {
7928 auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC);
7929 auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: R1, Op1: NewCon);
7930 B.buildZExtOrTrunc(Res: DstReg, Op: ICmp);
7931 } else {
7932 llvm_unreachable("unexpected configuration of CreateMask and Offset");
7933 }
7934 };
7935 return true;
7936}
7937
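// Fold logic of two fcmps with identical operands into a single fcmp whose
// predicate combines the original predicate bits, e.g.
//   (fcmp olt x, y) || (fcmp ogt x, y) --> fcmp one x, y
//   (fcmp olt x, y) && (fcmp ogt x, y) --> false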
7938bool CombinerHelper::tryFoldLogicOfFCmps(GLogicalBinOp *Logic,
7939 BuildFnTy &MatchInfo) const {
  assert(Logic->getOpcode() != TargetOpcode::G_XOR && "unexpected xor");
7941 Register DestReg = Logic->getReg(Idx: 0);
7942 Register LHS = Logic->getLHSReg();
7943 Register RHS = Logic->getRHSReg();
7944 bool IsAnd = Logic->getOpcode() == TargetOpcode::G_AND;
7945
7946 // We need a compare on the LHS register.
7947 GFCmp *Cmp1 = getOpcodeDef<GFCmp>(Reg: LHS, MRI);
7948 if (!Cmp1)
7949 return false;
7950
7951 // We need a compare on the RHS register.
7952 GFCmp *Cmp2 = getOpcodeDef<GFCmp>(Reg: RHS, MRI);
7953 if (!Cmp2)
7954 return false;
7955
7956 LLT CmpTy = MRI.getType(Reg: Cmp1->getReg(Idx: 0));
7957 LLT CmpOperandTy = MRI.getType(Reg: Cmp1->getLHSReg());
7958
  // We build a single fcmp: it must be legal, we must be able to fold away
  // the original fcmps and replace the logic op, and both fcmps must have the
  // same shape (operand types).
7961 if (!isLegalOrBeforeLegalizer(
7962 Query: {TargetOpcode::G_FCMP, {CmpTy, CmpOperandTy}}) ||
7963 !MRI.hasOneNonDBGUse(RegNo: Logic->getReg(Idx: 0)) ||
7964 !MRI.hasOneNonDBGUse(RegNo: Cmp1->getReg(Idx: 0)) ||
7965 !MRI.hasOneNonDBGUse(RegNo: Cmp2->getReg(Idx: 0)) ||
7966 MRI.getType(Reg: Cmp1->getLHSReg()) != MRI.getType(Reg: Cmp2->getLHSReg()))
7967 return false;
7968
7969 CmpInst::Predicate PredL = Cmp1->getCond();
7970 CmpInst::Predicate PredR = Cmp2->getCond();
7971 Register LHS0 = Cmp1->getLHSReg();
7972 Register LHS1 = Cmp1->getRHSReg();
7973 Register RHS0 = Cmp2->getLHSReg();
7974 Register RHS1 = Cmp2->getRHSReg();
7975
7976 if (LHS0 == RHS1 && LHS1 == RHS0) {
7977 // Swap RHS operands to match LHS.
7978 PredR = CmpInst::getSwappedPredicate(pred: PredR);
7979 std::swap(a&: RHS0, b&: RHS1);
7980 }
7981
7982 if (LHS0 == RHS0 && LHS1 == RHS1) {
7983 // We determine the new predicate.
7984 unsigned CmpCodeL = getFCmpCode(CC: PredL);
7985 unsigned CmpCodeR = getFCmpCode(CC: PredR);
7986 unsigned NewPred = IsAnd ? CmpCodeL & CmpCodeR : CmpCodeL | CmpCodeR;
7987 unsigned Flags = Cmp1->getFlags() | Cmp2->getFlags();
7988 MatchInfo = [=](MachineIRBuilder &B) {
7989 // The fcmp predicates fill the lower part of the enum.
7990 FCmpInst::Predicate Pred = static_cast<FCmpInst::Predicate>(NewPred);
7991 if (Pred == FCmpInst::FCMP_FALSE &&
7992 isConstantLegalOrBeforeLegalizer(Ty: CmpTy)) {
7993 auto False = B.buildConstant(Res: CmpTy, Val: 0);
7994 B.buildZExtOrTrunc(Res: DestReg, Op: False);
7995 } else if (Pred == FCmpInst::FCMP_TRUE &&
7996 isConstantLegalOrBeforeLegalizer(Ty: CmpTy)) {
7997 auto True =
7998 B.buildConstant(Res: CmpTy, Val: getICmpTrueVal(TLI: getTargetLowering(),
7999 IsVector: CmpTy.isVector() /*isVector*/,
8000 IsFP: true /*isFP*/));
8001 B.buildZExtOrTrunc(Res: DestReg, Op: True);
8002 } else { // We take the predicate without predicate optimizations.
8003 auto Cmp = B.buildFCmp(Pred, Res: CmpTy, Op0: LHS0, Op1: LHS1, Flags);
8004 B.buildZExtOrTrunc(Res: DestReg, Op: Cmp);
8005 }
8006 };
8007 return true;
8008 }
8009
8010 return false;
8011}
8012
8013bool CombinerHelper::matchAnd(MachineInstr &MI, BuildFnTy &MatchInfo) const {
8014 GAnd *And = cast<GAnd>(Val: &MI);
8015
8016 if (tryFoldAndOrOrICmpsUsingRanges(Logic: And, MatchInfo))
8017 return true;
8018
8019 if (tryFoldLogicOfFCmps(Logic: And, MatchInfo))
8020 return true;
8021
8022 return false;
8023}
8024
8025bool CombinerHelper::matchOr(MachineInstr &MI, BuildFnTy &MatchInfo) const {
8026 GOr *Or = cast<GOr>(Val: &MI);
8027
8028 if (tryFoldAndOrOrICmpsUsingRanges(Logic: Or, MatchInfo))
8029 return true;
8030
8031 if (tryFoldLogicOfFCmps(Logic: Or, MatchInfo))
8032 return true;
8033
8034 return false;
8035}
8036
8037bool CombinerHelper::matchAddOverflow(MachineInstr &MI,
8038 BuildFnTy &MatchInfo) const {
8039 GAddCarryOut *Add = cast<GAddCarryOut>(Val: &MI);
8040
8041 // Addo has no flags
8042 Register Dst = Add->getReg(Idx: 0);
8043 Register Carry = Add->getReg(Idx: 1);
8044 Register LHS = Add->getLHSReg();
8045 Register RHS = Add->getRHSReg();
8046 bool IsSigned = Add->isSigned();
8047 LLT DstTy = MRI.getType(Reg: Dst);
8048 LLT CarryTy = MRI.getType(Reg: Carry);
8049
8050 // Fold addo, if the carry is dead -> add, undef.
8051 if (MRI.use_nodbg_empty(RegNo: Carry) &&
8052 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ADD, {DstTy}})) {
8053 MatchInfo = [=](MachineIRBuilder &B) {
8054 B.buildAdd(Dst, Src0: LHS, Src1: RHS);
8055 B.buildUndef(Res: Carry);
8056 };
8057 return true;
8058 }
8059
8060 // Canonicalize constant to RHS.
8061 if (isConstantOrConstantVectorI(Src: LHS) && !isConstantOrConstantVectorI(Src: RHS)) {
8062 if (IsSigned) {
8063 MatchInfo = [=](MachineIRBuilder &B) {
8064 B.buildSAddo(Res: Dst, CarryOut: Carry, Op0: RHS, Op1: LHS);
8065 };
8066 return true;
8067 }
8068 // !IsSigned
8069 MatchInfo = [=](MachineIRBuilder &B) {
8070 B.buildUAddo(Res: Dst, CarryOut: Carry, Op0: RHS, Op1: LHS);
8071 };
8072 return true;
8073 }
8074
8075 std::optional<APInt> MaybeLHS = getConstantOrConstantSplatVector(Src: LHS);
8076 std::optional<APInt> MaybeRHS = getConstantOrConstantSplatVector(Src: RHS);
8077
8078 // Fold addo(c1, c2) -> c3, carry.
8079 if (MaybeLHS && MaybeRHS && isConstantLegalOrBeforeLegalizer(Ty: DstTy) &&
8080 isConstantLegalOrBeforeLegalizer(Ty: CarryTy)) {
8081 bool Overflow;
8082 APInt Result = IsSigned ? MaybeLHS->sadd_ov(RHS: *MaybeRHS, Overflow)
8083 : MaybeLHS->uadd_ov(RHS: *MaybeRHS, Overflow);
8084 MatchInfo = [=](MachineIRBuilder &B) {
8085 B.buildConstant(Res: Dst, Val: Result);
8086 B.buildConstant(Res: Carry, Val: Overflow);
8087 };
8088 return true;
8089 }
8090
8091 // Fold (addo x, 0) -> x, no carry
8092 if (MaybeRHS && *MaybeRHS == 0 && isConstantLegalOrBeforeLegalizer(Ty: CarryTy)) {
8093 MatchInfo = [=](MachineIRBuilder &B) {
8094 B.buildCopy(Res: Dst, Op: LHS);
8095 B.buildConstant(Res: Carry, Val: 0);
8096 };
8097 return true;
8098 }
8099
8100 // Given 2 constant operands whose sum does not overflow:
8101 // uaddo (X +nuw C0), C1 -> uaddo X, C0 + C1
8102 // saddo (X +nsw C0), C1 -> saddo X, C0 + C1
8103 GAdd *AddLHS = getOpcodeDef<GAdd>(Reg: LHS, MRI);
8104 if (MaybeRHS && AddLHS && MRI.hasOneNonDBGUse(RegNo: Add->getReg(Idx: 0)) &&
8105 ((IsSigned && AddLHS->getFlag(Flag: MachineInstr::MIFlag::NoSWrap)) ||
8106 (!IsSigned && AddLHS->getFlag(Flag: MachineInstr::MIFlag::NoUWrap)))) {
8107 std::optional<APInt> MaybeAddRHS =
8108 getConstantOrConstantSplatVector(Src: AddLHS->getRHSReg());
8109 if (MaybeAddRHS) {
8110 bool Overflow;
8111 APInt NewC = IsSigned ? MaybeAddRHS->sadd_ov(RHS: *MaybeRHS, Overflow)
8112 : MaybeAddRHS->uadd_ov(RHS: *MaybeRHS, Overflow);
8113 if (!Overflow && isConstantLegalOrBeforeLegalizer(Ty: DstTy)) {
8114 if (IsSigned) {
8115 MatchInfo = [=](MachineIRBuilder &B) {
8116 auto ConstRHS = B.buildConstant(Res: DstTy, Val: NewC);
8117 B.buildSAddo(Res: Dst, CarryOut: Carry, Op0: AddLHS->getLHSReg(), Op1: ConstRHS);
8118 };
8119 return true;
8120 }
8121 // !IsSigned
8122 MatchInfo = [=](MachineIRBuilder &B) {
8123 auto ConstRHS = B.buildConstant(Res: DstTy, Val: NewC);
8124 B.buildUAddo(Res: Dst, CarryOut: Carry, Op0: AddLHS->getLHSReg(), Op1: ConstRHS);
8125 };
8126 return true;
8127 }
8128 }
  }
8130
8131 // We try to combine addo to non-overflowing add.
8132 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ADD, {DstTy}}) ||
8133 !isConstantLegalOrBeforeLegalizer(Ty: CarryTy))
8134 return false;
8135
8136 // We try to combine uaddo to non-overflowing add.
8137 if (!IsSigned) {
8138 ConstantRange CRLHS =
8139 ConstantRange::fromKnownBits(Known: VT->getKnownBits(R: LHS), /*IsSigned=*/false);
8140 ConstantRange CRRHS =
8141 ConstantRange::fromKnownBits(Known: VT->getKnownBits(R: RHS), /*IsSigned=*/false);
8142
8143 switch (CRLHS.unsignedAddMayOverflow(Other: CRRHS)) {
8144 case ConstantRange::OverflowResult::MayOverflow:
8145 return false;
8146 case ConstantRange::OverflowResult::NeverOverflows: {
8147 MatchInfo = [=](MachineIRBuilder &B) {
8148 B.buildAdd(Dst, Src0: LHS, Src1: RHS, Flags: MachineInstr::MIFlag::NoUWrap);
8149 B.buildConstant(Res: Carry, Val: 0);
8150 };
8151 return true;
8152 }
8153 case ConstantRange::OverflowResult::AlwaysOverflowsLow:
8154 case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
8155 MatchInfo = [=](MachineIRBuilder &B) {
8156 B.buildAdd(Dst, Src0: LHS, Src1: RHS);
8157 B.buildConstant(Res: Carry, Val: 1);
8158 };
8159 return true;
8160 }
8161 }
8162 return false;
8163 }
8164
8165 // We try to combine saddo to non-overflowing add.
8166
8167 // If LHS and RHS each have at least two sign bits, then there is no signed
8168 // overflow.
8169 if (VT->computeNumSignBits(R: RHS) > 1 && VT->computeNumSignBits(R: LHS) > 1) {
8170 MatchInfo = [=](MachineIRBuilder &B) {
8171 B.buildAdd(Dst, Src0: LHS, Src1: RHS, Flags: MachineInstr::MIFlag::NoSWrap);
8172 B.buildConstant(Res: Carry, Val: 0);
8173 };
8174 return true;
8175 }
8176
8177 ConstantRange CRLHS =
8178 ConstantRange::fromKnownBits(Known: VT->getKnownBits(R: LHS), /*IsSigned=*/true);
8179 ConstantRange CRRHS =
8180 ConstantRange::fromKnownBits(Known: VT->getKnownBits(R: RHS), /*IsSigned=*/true);
8181
8182 switch (CRLHS.signedAddMayOverflow(Other: CRRHS)) {
8183 case ConstantRange::OverflowResult::MayOverflow:
8184 return false;
8185 case ConstantRange::OverflowResult::NeverOverflows: {
8186 MatchInfo = [=](MachineIRBuilder &B) {
8187 B.buildAdd(Dst, Src0: LHS, Src1: RHS, Flags: MachineInstr::MIFlag::NoSWrap);
8188 B.buildConstant(Res: Carry, Val: 0);
8189 };
8190 return true;
8191 }
8192 case ConstantRange::OverflowResult::AlwaysOverflowsLow:
8193 case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
8194 MatchInfo = [=](MachineIRBuilder &B) {
8195 B.buildAdd(Dst, Src0: LHS, Src1: RHS);
8196 B.buildConstant(Res: Carry, Val: 1);
8197 };
8198 return true;
8199 }
8200 }
8201
8202 return false;
8203}
8204
8205void CombinerHelper::applyBuildFnMO(const MachineOperand &MO,
8206 BuildFnTy &MatchInfo) const {
8207 MachineInstr *Root = getDefIgnoringCopies(Reg: MO.getReg(), MRI);
8208 MatchInfo(Builder);
8209 Root->eraseFromParent();
8210}
8211
8212bool CombinerHelper::matchFPowIExpansion(MachineInstr &MI,
8213 int64_t Exponent) const {
8214 bool OptForSize = MI.getMF()->getFunction().hasOptSize();
8215 return getTargetLowering().isBeneficialToExpandPowI(Exponent, OptForSize);
8216}
8217
8218void CombinerHelper::applyExpandFPowI(MachineInstr &MI,
8219 int64_t Exponent) const {
8220 auto [Dst, Base] = MI.getFirst2Regs();
8221 LLT Ty = MRI.getType(Reg: Dst);
8222 int64_t ExpVal = Exponent;
8223
8224 if (ExpVal == 0) {
8225 Builder.buildFConstant(Res: Dst, Val: 1.0);
8226 MI.removeFromParent();
8227 return;
8228 }
8229
8230 if (ExpVal < 0)
8231 ExpVal = -ExpVal;
8232
8233 // We use the simple binary decomposition method from SelectionDAG ExpandPowI
8234 // to generate the multiply sequence. There are more optimal ways to do this
8235 // (for example, powi(x,15) generates one more multiply than it should), but
8236 // this has the benefit of being both really simple and much better than a
8237 // libcall.
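  // E.g. for powi(x, 13) (13 == 0b1101) we build x^2, x^4 and x^8 by repeated
  // squaring and multiply x, x^4 and x^8 together for the final result.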
8238 std::optional<SrcOp> Res;
8239 SrcOp CurSquare = Base;
8240 while (ExpVal > 0) {
8241 if (ExpVal & 1) {
8242 if (!Res)
8243 Res = CurSquare;
8244 else
8245 Res = Builder.buildFMul(Dst: Ty, Src0: *Res, Src1: CurSquare);
8246 }
8247
8248 CurSquare = Builder.buildFMul(Dst: Ty, Src0: CurSquare, Src1: CurSquare);
8249 ExpVal >>= 1;
8250 }
8251
8252 // If the original exponent was negative, invert the result, producing
8253 // 1/(x*x*x).
8254 if (Exponent < 0)
8255 Res = Builder.buildFDiv(Dst: Ty, Src0: Builder.buildFConstant(Res: Ty, Val: 1.0), Src1: *Res,
8256 Flags: MI.getFlags());
8257
8258 Builder.buildCopy(Res: Dst, Op: *Res);
8259 MI.eraseFromParent();
8260}
8261
8262bool CombinerHelper::matchFoldAPlusC1MinusC2(const MachineInstr &MI,
8263 BuildFnTy &MatchInfo) const {
8264 // fold (A+C1)-C2 -> A+(C1-C2)
8265 const GSub *Sub = cast<GSub>(Val: &MI);
8266 GAdd *Add = cast<GAdd>(Val: MRI.getVRegDef(Reg: Sub->getLHSReg()));
8267
8268 if (!MRI.hasOneNonDBGUse(RegNo: Add->getReg(Idx: 0)))
8269 return false;
8270
8271 APInt C2 = getIConstantFromReg(VReg: Sub->getRHSReg(), MRI);
8272 APInt C1 = getIConstantFromReg(VReg: Add->getRHSReg(), MRI);
8273
8274 Register Dst = Sub->getReg(Idx: 0);
8275 LLT DstTy = MRI.getType(Reg: Dst);
8276
8277 MatchInfo = [=](MachineIRBuilder &B) {
8278 auto Const = B.buildConstant(Res: DstTy, Val: C1 - C2);
8279 B.buildAdd(Dst, Src0: Add->getLHSReg(), Src1: Const);
8280 };
8281
8282 return true;
8283}
8284
8285bool CombinerHelper::matchFoldC2MinusAPlusC1(const MachineInstr &MI,
8286 BuildFnTy &MatchInfo) const {
8287 // fold C2-(A+C1) -> (C2-C1)-A
8288 const GSub *Sub = cast<GSub>(Val: &MI);
8289 GAdd *Add = cast<GAdd>(Val: MRI.getVRegDef(Reg: Sub->getRHSReg()));
8290
8291 if (!MRI.hasOneNonDBGUse(RegNo: Add->getReg(Idx: 0)))
8292 return false;
8293
8294 APInt C2 = getIConstantFromReg(VReg: Sub->getLHSReg(), MRI);
8295 APInt C1 = getIConstantFromReg(VReg: Add->getRHSReg(), MRI);
8296
8297 Register Dst = Sub->getReg(Idx: 0);
8298 LLT DstTy = MRI.getType(Reg: Dst);
8299
8300 MatchInfo = [=](MachineIRBuilder &B) {
8301 auto Const = B.buildConstant(Res: DstTy, Val: C2 - C1);
8302 B.buildSub(Dst, Src0: Const, Src1: Add->getLHSReg());
8303 };
8304
8305 return true;
8306}
8307
8308bool CombinerHelper::matchFoldAMinusC1MinusC2(const MachineInstr &MI,
8309 BuildFnTy &MatchInfo) const {
8310 // fold (A-C1)-C2 -> A-(C1+C2)
8311 const GSub *Sub1 = cast<GSub>(Val: &MI);
8312 GSub *Sub2 = cast<GSub>(Val: MRI.getVRegDef(Reg: Sub1->getLHSReg()));
8313
8314 if (!MRI.hasOneNonDBGUse(RegNo: Sub2->getReg(Idx: 0)))
8315 return false;
8316
8317 APInt C2 = getIConstantFromReg(VReg: Sub1->getRHSReg(), MRI);
8318 APInt C1 = getIConstantFromReg(VReg: Sub2->getRHSReg(), MRI);
8319
8320 Register Dst = Sub1->getReg(Idx: 0);
8321 LLT DstTy = MRI.getType(Reg: Dst);
8322
8323 MatchInfo = [=](MachineIRBuilder &B) {
8324 auto Const = B.buildConstant(Res: DstTy, Val: C1 + C2);
8325 B.buildSub(Dst, Src0: Sub2->getLHSReg(), Src1: Const);
8326 };
8327
8328 return true;
8329}
8330
8331bool CombinerHelper::matchFoldC1Minus2MinusC2(const MachineInstr &MI,
8332 BuildFnTy &MatchInfo) const {
8333 // fold (C1-A)-C2 -> (C1-C2)-A
8334 const GSub *Sub1 = cast<GSub>(Val: &MI);
8335 GSub *Sub2 = cast<GSub>(Val: MRI.getVRegDef(Reg: Sub1->getLHSReg()));
8336
8337 if (!MRI.hasOneNonDBGUse(RegNo: Sub2->getReg(Idx: 0)))
8338 return false;
8339
8340 APInt C2 = getIConstantFromReg(VReg: Sub1->getRHSReg(), MRI);
8341 APInt C1 = getIConstantFromReg(VReg: Sub2->getLHSReg(), MRI);
8342
8343 Register Dst = Sub1->getReg(Idx: 0);
8344 LLT DstTy = MRI.getType(Reg: Dst);
8345
8346 MatchInfo = [=](MachineIRBuilder &B) {
8347 auto Const = B.buildConstant(Res: DstTy, Val: C1 - C2);
8348 B.buildSub(Dst, Src0: Const, Src1: Sub2->getRHSReg());
8349 };
8350
8351 return true;
8352}
8353
8354bool CombinerHelper::matchFoldAMinusC1PlusC2(const MachineInstr &MI,
8355 BuildFnTy &MatchInfo) const {
8356 // fold ((A-C1)+C2) -> (A+(C2-C1))
8357 const GAdd *Add = cast<GAdd>(Val: &MI);
8358 GSub *Sub = cast<GSub>(Val: MRI.getVRegDef(Reg: Add->getLHSReg()));
8359
8360 if (!MRI.hasOneNonDBGUse(RegNo: Sub->getReg(Idx: 0)))
8361 return false;
8362
8363 APInt C2 = getIConstantFromReg(VReg: Add->getRHSReg(), MRI);
8364 APInt C1 = getIConstantFromReg(VReg: Sub->getRHSReg(), MRI);
8365
8366 Register Dst = Add->getReg(Idx: 0);
8367 LLT DstTy = MRI.getType(Reg: Dst);
8368
8369 MatchInfo = [=](MachineIRBuilder &B) {
8370 auto Const = B.buildConstant(Res: DstTy, Val: C2 - C1);
8371 B.buildAdd(Dst, Src0: Sub->getLHSReg(), Src1: Const);
8372 };
8373
8374 return true;
8375}
8376
8377bool CombinerHelper::matchUnmergeValuesAnyExtBuildVector(
8378 const MachineInstr &MI, BuildFnTy &MatchInfo) const {
8379 const GUnmerge *Unmerge = cast<GUnmerge>(Val: &MI);
8380
8381 if (!MRI.hasOneNonDBGUse(RegNo: Unmerge->getSourceReg()))
8382 return false;
8383
8384 const MachineInstr *Source = MRI.getVRegDef(Reg: Unmerge->getSourceReg());
8385
8386 LLT DstTy = MRI.getType(Reg: Unmerge->getReg(Idx: 0));
8387
8388 // $bv:_(<8 x s8>) = G_BUILD_VECTOR ....
8389 // $any:_(<8 x s16>) = G_ANYEXT $bv
8390 // $uv:_(<4 x s16>), $uv1:_(<4 x s16>) = G_UNMERGE_VALUES $any
8391 //
8392 // ->
8393 //
8394 // $any:_(s16) = G_ANYEXT $bv[0]
8395 // $any1:_(s16) = G_ANYEXT $bv[1]
8396 // $any2:_(s16) = G_ANYEXT $bv[2]
8397 // $any3:_(s16) = G_ANYEXT $bv[3]
8398 // $any4:_(s16) = G_ANYEXT $bv[4]
8399 // $any5:_(s16) = G_ANYEXT $bv[5]
8400 // $any6:_(s16) = G_ANYEXT $bv[6]
8401 // $any7:_(s16) = G_ANYEXT $bv[7]
8402 // $uv:_(<4 x s16>) = G_BUILD_VECTOR $any, $any1, $any2, $any3
8403 // $uv1:_(<4 x s16>) = G_BUILD_VECTOR $any4, $any5, $any6, $any7
8404
8405 // We want to unmerge into vectors.
8406 if (!DstTy.isFixedVector())
8407 return false;
8408
8409 const GAnyExt *Any = dyn_cast<GAnyExt>(Val: Source);
8410 if (!Any)
8411 return false;
8412
8413 const MachineInstr *NextSource = MRI.getVRegDef(Reg: Any->getSrcReg());
8414
8415 if (const GBuildVector *BV = dyn_cast<GBuildVector>(Val: NextSource)) {
8416 // G_UNMERGE_VALUES G_ANYEXT G_BUILD_VECTOR
8417
8418 if (!MRI.hasOneNonDBGUse(RegNo: BV->getReg(Idx: 0)))
8419 return false;
8420
8421 // FIXME: check element types?
8422 if (BV->getNumSources() % Unmerge->getNumDefs() != 0)
8423 return false;
8424
8425 LLT BigBvTy = MRI.getType(Reg: BV->getReg(Idx: 0));
8426 LLT SmallBvTy = DstTy;
    LLT SmallBvElementTy = SmallBvTy.getElementType();
8428
8429 if (!isLegalOrBeforeLegalizer(
            Query: {TargetOpcode::G_BUILD_VECTOR, {SmallBvTy, SmallBvElementTy}}))
8431 return false;
8432
8433 // We check the legality of scalar anyext.
8434 if (!isLegalOrBeforeLegalizer(
8435 Query: {TargetOpcode::G_ANYEXT,
             {SmallBvElementTy, BigBvTy.getElementType()}}))
8437 return false;
8438
8439 MatchInfo = [=](MachineIRBuilder &B) {
8440 // Build into each G_UNMERGE_VALUES def
8441 // a small build vector with anyext from the source build vector.
8442 for (unsigned I = 0; I < Unmerge->getNumDefs(); ++I) {
8443 SmallVector<Register> Ops;
8444 for (unsigned J = 0; J < SmallBvTy.getNumElements(); ++J) {
8445 Register SourceArray =
8446 BV->getSourceReg(I: I * SmallBvTy.getNumElements() + J);
          auto AnyExt = B.buildAnyExt(Res: SmallBvElementTy, Op: SourceArray);
8448 Ops.push_back(Elt: AnyExt.getReg(Idx: 0));
8449 }
8450 B.buildBuildVector(Res: Unmerge->getOperand(i: I).getReg(), Ops);
      }
8452 };
8453 return true;
  }
8455
8456 return false;
8457}
8458
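// For a shuffle whose second source is undef, canonicalize mask elements that
// refer to that source to -1 (undef) lanes.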
8459bool CombinerHelper::matchShuffleUndefRHS(MachineInstr &MI,
8460 BuildFnTy &MatchInfo) const {
8461
8462 bool Changed = false;
8463 auto &Shuffle = cast<GShuffleVector>(Val&: MI);
8464 ArrayRef<int> OrigMask = Shuffle.getMask();
8465 SmallVector<int, 16> NewMask;
8466 const LLT SrcTy = MRI.getType(Reg: Shuffle.getSrc1Reg());
8467 const unsigned NumSrcElems = SrcTy.isVector() ? SrcTy.getNumElements() : 1;
8468 const unsigned NumDstElts = OrigMask.size();
8469 for (unsigned i = 0; i != NumDstElts; ++i) {
8470 int Idx = OrigMask[i];
8471 if (Idx >= (int)NumSrcElems) {
8472 Idx = -1;
8473 Changed = true;
8474 }
8475 NewMask.push_back(Elt: Idx);
8476 }
8477
8478 if (!Changed)
8479 return false;
8480
8481 MatchInfo = [&, NewMask = std::move(NewMask)](MachineIRBuilder &B) {
8482 B.buildShuffleVector(Res: MI.getOperand(i: 0), Src1: MI.getOperand(i: 1), Src2: MI.getOperand(i: 2),
8483 Mask: std::move(NewMask));
8484 };
8485
8486 return true;
8487}
8488
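// Remap shuffle mask indices to the other source vector: indices that select
// from the first source now select from the second and vice versa.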
8489static void commuteMask(MutableArrayRef<int> Mask, const unsigned NumElems) {
8490 const unsigned MaskSize = Mask.size();
8491 for (unsigned I = 0; I < MaskSize; ++I) {
8492 int Idx = Mask[I];
8493 if (Idx < 0)
8494 continue;
8495
8496 if (Idx < (int)NumElems)
8497 Mask[I] = Idx + NumElems;
8498 else
8499 Mask[I] = Idx - NumElems;
8500 }
8501}
8502
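// If the shuffle only reads lanes from one of its two sources, rewrite it to
// use that single source together with an undef second operand.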
8503bool CombinerHelper::matchShuffleDisjointMask(MachineInstr &MI,
8504 BuildFnTy &MatchInfo) const {
8505
8506 auto &Shuffle = cast<GShuffleVector>(Val&: MI);
  // If either of the two inputs is already undef, don't check the mask again
  // to prevent an infinite loop.
8509 if (getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: Shuffle.getSrc1Reg(), MRI))
8510 return false;
8511
8512 if (getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: Shuffle.getSrc2Reg(), MRI))
8513 return false;
8514
8515 const LLT DstTy = MRI.getType(Reg: Shuffle.getReg(Idx: 0));
8516 const LLT Src1Ty = MRI.getType(Reg: Shuffle.getSrc1Reg());
8517 if (!isLegalOrBeforeLegalizer(
8518 Query: {TargetOpcode::G_SHUFFLE_VECTOR, {DstTy, Src1Ty}}))
8519 return false;
8520
8521 ArrayRef<int> Mask = Shuffle.getMask();
8522 const unsigned NumSrcElems = Src1Ty.getNumElements();
8523
8524 bool TouchesSrc1 = false;
8525 bool TouchesSrc2 = false;
8526 const unsigned NumElems = Mask.size();
8527 for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
8528 if (Mask[Idx] < 0)
8529 continue;
8530
8531 if (Mask[Idx] < (int)NumSrcElems)
8532 TouchesSrc1 = true;
8533 else
8534 TouchesSrc2 = true;
8535 }
8536
8537 if (TouchesSrc1 == TouchesSrc2)
8538 return false;
8539
8540 Register NewSrc1 = Shuffle.getSrc1Reg();
8541 SmallVector<int, 16> NewMask(Mask);
8542 if (TouchesSrc2) {
8543 NewSrc1 = Shuffle.getSrc2Reg();
8544 commuteMask(Mask: NewMask, NumElems: NumSrcElems);
8545 }
8546
8547 MatchInfo = [=, &Shuffle](MachineIRBuilder &B) {
8548 auto Undef = B.buildUndef(Res: Src1Ty);
8549 B.buildShuffleVector(Res: Shuffle.getReg(Idx: 0), Src1: NewSrc1, Src2: Undef, Mask: NewMask);
8550 };
8551
8552 return true;
8553}
8554
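// Fold G_USUBO/G_SSUBO into a plain G_SUB plus a constant carry-out when
// known bits prove that the subtraction can never, or must always, overflow.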
8555bool CombinerHelper::matchSuboCarryOut(const MachineInstr &MI,
8556 BuildFnTy &MatchInfo) const {
8557 const GSubCarryOut *Subo = cast<GSubCarryOut>(Val: &MI);
8558
8559 Register Dst = Subo->getReg(Idx: 0);
8560 Register LHS = Subo->getLHSReg();
8561 Register RHS = Subo->getRHSReg();
8562 Register Carry = Subo->getCarryOutReg();
8563 LLT DstTy = MRI.getType(Reg: Dst);
8564 LLT CarryTy = MRI.getType(Reg: Carry);
8565
8566 // Check legality before known bits.
8567 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SUB, {DstTy}}) ||
8568 !isConstantLegalOrBeforeLegalizer(Ty: CarryTy))
8569 return false;
8570
8571 ConstantRange KBLHS =
8572 ConstantRange::fromKnownBits(Known: VT->getKnownBits(R: LHS),
8573 /* IsSigned=*/Subo->isSigned());
8574 ConstantRange KBRHS =
8575 ConstantRange::fromKnownBits(Known: VT->getKnownBits(R: RHS),
8576 /* IsSigned=*/Subo->isSigned());
8577
8578 if (Subo->isSigned()) {
8579 // G_SSUBO
8580 switch (KBLHS.signedSubMayOverflow(Other: KBRHS)) {
8581 case ConstantRange::OverflowResult::MayOverflow:
8582 return false;
8583 case ConstantRange::OverflowResult::NeverOverflows: {
8584 MatchInfo = [=](MachineIRBuilder &B) {
8585 B.buildSub(Dst, Src0: LHS, Src1: RHS, Flags: MachineInstr::MIFlag::NoSWrap);
8586 B.buildConstant(Res: Carry, Val: 0);
8587 };
8588 return true;
8589 }
8590 case ConstantRange::OverflowResult::AlwaysOverflowsLow:
8591 case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
8592 MatchInfo = [=](MachineIRBuilder &B) {
8593 B.buildSub(Dst, Src0: LHS, Src1: RHS);
8594 B.buildConstant(Res: Carry, Val: getICmpTrueVal(TLI: getTargetLowering(),
8595 /*isVector=*/IsVector: CarryTy.isVector(),
8596 /*isFP=*/IsFP: false));
8597 };
8598 return true;
8599 }
8600 }
8601 return false;
8602 }
8603
8604 // G_USUBO
8605 switch (KBLHS.unsignedSubMayOverflow(Other: KBRHS)) {
8606 case ConstantRange::OverflowResult::MayOverflow:
8607 return false;
8608 case ConstantRange::OverflowResult::NeverOverflows: {
8609 MatchInfo = [=](MachineIRBuilder &B) {
8610 B.buildSub(Dst, Src0: LHS, Src1: RHS, Flags: MachineInstr::MIFlag::NoUWrap);
8611 B.buildConstant(Res: Carry, Val: 0);
8612 };
8613 return true;
8614 }
8615 case ConstantRange::OverflowResult::AlwaysOverflowsLow:
8616 case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
8617 MatchInfo = [=](MachineIRBuilder &B) {
8618 B.buildSub(Dst, Src0: LHS, Src1: RHS);
8619 B.buildConstant(Res: Carry, Val: getICmpTrueVal(TLI: getTargetLowering(),
8620 /*isVector=*/IsVector: CarryTy.isVector(),
8621 /*isFP=*/IsFP: false));
8622 };
8623 return true;
8624 }
8625 }
8626
8627 return false;
8628}
8629
8630// Fold (ctlz (xor x, (sra x, bitwidth-1))) -> (add (ctls x), 1).
// Fold (ctlz (or (shl (xor x, (sra x, bitwidth-1)), 1), 1)) -> (ctls x)
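// The xor with (sra x, bitwidth-1) turns leading copies of the sign bit into
// leading zeros, so the ctlz counts the leading bits equal to the sign bit.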
8632bool CombinerHelper::matchCtls(MachineInstr &CtlzMI,
8633 BuildFnTy &MatchInfo) const {
8634 assert((CtlzMI.getOpcode() == TargetOpcode::G_CTLZ ||
8635 CtlzMI.getOpcode() == TargetOpcode::G_CTLZ_ZERO_UNDEF) &&
8636 "Expected G_CTLZ variant");
8637
8638 const Register Dst = CtlzMI.getOperand(i: 0).getReg();
8639 Register Src = CtlzMI.getOperand(i: 1).getReg();
8640
8641 LLT Ty = MRI.getType(Reg: Dst);
8642 LLT SrcTy = MRI.getType(Reg: Src);
8643
8644 if (!(Ty.isValid() && Ty.isScalar()))
8645 return false;
8646
8647 if (!LI)
8648 return false;
8649
8650 SmallVector<LLT, 2> QueryTypes = {Ty, SrcTy};
8651 LegalityQuery Query(TargetOpcode::G_CTLS, QueryTypes);
8652
8653 switch (LI->getAction(Query).Action) {
8654 default:
8655 return false;
8656 case LegalizeActions::Legal:
8657 case LegalizeActions::Custom:
8658 case LegalizeActions::WidenScalar:
8659 break;
8660 }
8661
8662 // Src = or(shl(V, 1), 1) -> Src=V; NeedAdd = False
8663 Register V;
8664 bool NeedAdd = true;
8665 if (mi_match(R: Src, MRI,
8666 P: m_OneUse(SP: m_GOr(L: m_OneUse(SP: m_GShl(L: m_Reg(R&: V), R: m_SpecificICst(RequestedValue: 1))),
8667 R: m_SpecificICst(RequestedValue: 1))))) {
8668 NeedAdd = false;
8669 Src = V;
8670 }
8671
8672 unsigned BitWidth = Ty.getScalarSizeInBits();
8673
8674 Register X;
8675 if (!mi_match(R: Src, MRI,
8676 P: m_OneUse(SP: m_GXor(L: m_Reg(R&: X), R: m_OneUse(SP: m_GAShr(
8677 L: m_DeferredReg(R&: X),
8678 R: m_SpecificICst(RequestedValue: BitWidth - 1)))))))
8679 return false;
8680
8681 MatchInfo = [=](MachineIRBuilder &B) {
8682 if (!NeedAdd) {
8683 B.buildCTLS(Dst, Src0: X);
8684 return;
8685 }
8686
8687 auto Ctls = B.buildCTLS(Dst: Ty, Src0: X);
8688 auto One = B.buildConstant(Res: Ty, Val: 1);
8689
8690 B.buildAdd(Dst, Src0: Ctls, Src1: One);
8691 };
8692
8693 return true;
8694}
8695