1 | //===-- lib/CodeGen/GlobalISel/GICombinerHelper.cpp -----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | #include "llvm/CodeGen/GlobalISel/CombinerHelper.h" |
9 | #include "llvm/ADT/APFloat.h" |
10 | #include "llvm/ADT/STLExtras.h" |
11 | #include "llvm/ADT/SetVector.h" |
12 | #include "llvm/ADT/SmallBitVector.h" |
13 | #include "llvm/Analysis/CmpInstAnalysis.h" |
14 | #include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h" |
15 | #include "llvm/CodeGen/GlobalISel/GISelKnownBits.h" |
16 | #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h" |
17 | #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" |
18 | #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h" |
19 | #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h" |
20 | #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" |
21 | #include "llvm/CodeGen/GlobalISel/Utils.h" |
22 | #include "llvm/CodeGen/LowLevelTypeUtils.h" |
23 | #include "llvm/CodeGen/MachineBasicBlock.h" |
24 | #include "llvm/CodeGen/MachineDominators.h" |
25 | #include "llvm/CodeGen/MachineInstr.h" |
26 | #include "llvm/CodeGen/MachineMemOperand.h" |
27 | #include "llvm/CodeGen/MachineRegisterInfo.h" |
28 | #include "llvm/CodeGen/RegisterBankInfo.h" |
29 | #include "llvm/CodeGen/TargetInstrInfo.h" |
30 | #include "llvm/CodeGen/TargetLowering.h" |
31 | #include "llvm/CodeGen/TargetOpcodes.h" |
32 | #include "llvm/IR/ConstantRange.h" |
33 | #include "llvm/IR/DataLayout.h" |
34 | #include "llvm/IR/InstrTypes.h" |
35 | #include "llvm/Support/Casting.h" |
36 | #include "llvm/Support/DivisionByConstantInfo.h" |
37 | #include "llvm/Support/ErrorHandling.h" |
38 | #include "llvm/Support/MathExtras.h" |
39 | #include "llvm/Target/TargetMachine.h" |
40 | #include <cmath> |
41 | #include <optional> |
42 | #include <tuple> |
43 | |
44 | #define DEBUG_TYPE "gi-combiner" |
45 | |
46 | using namespace llvm; |
47 | using namespace MIPatternMatch; |
48 | |
49 | // Option to allow testing of the combiner while no targets know about indexed |
50 | // addressing. |
static cl::opt<bool>
    ForceLegalIndexing("force-legal-indexing", cl::Hidden, cl::init(false),
                       cl::desc("Force all indexed operations to be "
                                "legal for the GlobalISel combiner"));
55 | |
56 | CombinerHelper::CombinerHelper(GISelChangeObserver &Observer, |
57 | MachineIRBuilder &B, bool IsPreLegalize, |
58 | GISelKnownBits *KB, MachineDominatorTree *MDT, |
59 | const LegalizerInfo *LI) |
60 | : Builder(B), MRI(Builder.getMF().getRegInfo()), Observer(Observer), KB(KB), |
61 | MDT(MDT), IsPreLegalize(IsPreLegalize), LI(LI), |
62 | RBI(Builder.getMF().getSubtarget().getRegBankInfo()), |
63 | TRI(Builder.getMF().getSubtarget().getRegisterInfo()) { |
64 | (void)this->KB; |
65 | } |
66 | |
67 | const TargetLowering &CombinerHelper::getTargetLowering() const { |
68 | return *Builder.getMF().getSubtarget().getTargetLowering(); |
69 | } |
70 | |
71 | /// \returns The little endian in-memory byte position of byte \p I in a |
72 | /// \p ByteWidth bytes wide type. |
73 | /// |
74 | /// E.g. Given a 4-byte type x, x[0] -> byte 0 |
75 | static unsigned littleEndianByteAt(const unsigned ByteWidth, const unsigned I) { |
76 | assert(I < ByteWidth && "I must be in [0, ByteWidth)" ); |
77 | return I; |
78 | } |
79 | |
80 | /// Determines the LogBase2 value for a non-null input value using the |
81 | /// transform: LogBase2(V) = (EltBits - 1) - ctlz(V). |
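/// E.g. for a 32-bit element and V = 8, ctlz(8) = 28, so
/// LogBase2(8) = (32 - 1) - 28 = 3.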
static Register buildLogBase2(Register V, MachineIRBuilder &MIB) {
  auto &MRI = *MIB.getMRI();
  LLT Ty = MRI.getType(V);
  auto Ctlz = MIB.buildCTLZ(Ty, V);
  auto Base = MIB.buildConstant(Ty, Ty.getScalarSizeInBits() - 1);
  return MIB.buildSub(Ty, Base, Ctlz).getReg(0);
}
89 | |
90 | /// \returns The big endian in-memory byte position of byte \p I in a |
91 | /// \p ByteWidth bytes wide type. |
92 | /// |
93 | /// E.g. Given a 4-byte type x, x[0] -> byte 3 |
94 | static unsigned bigEndianByteAt(const unsigned ByteWidth, const unsigned I) { |
95 | assert(I < ByteWidth && "I must be in [0, ByteWidth)" ); |
96 | return ByteWidth - I - 1; |
97 | } |
98 | |
99 | /// Given a map from byte offsets in memory to indices in a load/store, |
100 | /// determine if that map corresponds to a little or big endian byte pattern. |
101 | /// |
102 | /// \param MemOffset2Idx maps memory offsets to address offsets. |
103 | /// \param LowestIdx is the lowest index in \p MemOffset2Idx. |
104 | /// |
105 | /// \returns true if the map corresponds to a big endian byte pattern, false if |
106 | /// it corresponds to a little endian byte pattern, and std::nullopt otherwise. |
107 | /// |
108 | /// E.g. given a 32-bit type x, and x[AddrOffset], the in-memory byte patterns |
109 | /// are as follows: |
110 | /// |
111 | /// AddrOffset Little endian Big endian |
112 | /// 0 0 3 |
113 | /// 1 1 2 |
114 | /// 2 2 1 |
115 | /// 3 3 0 |
116 | static std::optional<bool> |
117 | isBigEndian(const SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx, |
118 | int64_t LowestIdx) { |
119 | // Need at least two byte positions to decide on endianness. |
120 | unsigned Width = MemOffset2Idx.size(); |
121 | if (Width < 2) |
122 | return std::nullopt; |
123 | bool BigEndian = true, LittleEndian = true; |
124 | for (unsigned MemOffset = 0; MemOffset < Width; ++ MemOffset) { |
125 | auto MemOffsetAndIdx = MemOffset2Idx.find(Val: MemOffset); |
126 | if (MemOffsetAndIdx == MemOffset2Idx.end()) |
127 | return std::nullopt; |
128 | const int64_t Idx = MemOffsetAndIdx->second - LowestIdx; |
129 | assert(Idx >= 0 && "Expected non-negative byte offset?" ); |
130 | LittleEndian &= Idx == littleEndianByteAt(ByteWidth: Width, I: MemOffset); |
131 | BigEndian &= Idx == bigEndianByteAt(ByteWidth: Width, I: MemOffset); |
132 | if (!BigEndian && !LittleEndian) |
133 | return std::nullopt; |
134 | } |
135 | |
136 | assert((BigEndian != LittleEndian) && |
137 | "Pattern cannot be both big and little endian!" ); |
138 | return BigEndian; |
139 | } |
140 | |
141 | bool CombinerHelper::isPreLegalize() const { return IsPreLegalize; } |
142 | |
143 | bool CombinerHelper::isLegal(const LegalityQuery &Query) const { |
144 | assert(LI && "Must have LegalizerInfo to query isLegal!" ); |
145 | return LI->getAction(Query).Action == LegalizeActions::Legal; |
146 | } |
147 | |
148 | bool CombinerHelper::isLegalOrBeforeLegalizer( |
149 | const LegalityQuery &Query) const { |
150 | return isPreLegalize() || isLegal(Query); |
151 | } |
152 | |
153 | bool CombinerHelper::isConstantLegalOrBeforeLegalizer(const LLT Ty) const { |
154 | if (!Ty.isVector()) |
155 | return isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_CONSTANT, {Ty}}); |
156 | // Vector constants are represented as a G_BUILD_VECTOR of scalar G_CONSTANTs. |
157 | if (isPreLegalize()) |
158 | return true; |
159 | LLT EltTy = Ty.getElementType(); |
160 | return isLegal(Query: {TargetOpcode::G_BUILD_VECTOR, {Ty, EltTy}}) && |
161 | isLegal(Query: {TargetOpcode::G_CONSTANT, {EltTy}}); |
162 | } |
163 | |
164 | void CombinerHelper::replaceRegWith(MachineRegisterInfo &MRI, Register FromReg, |
165 | Register ToReg) const { |
166 | Observer.changingAllUsesOfReg(MRI, Reg: FromReg); |
167 | |
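  // Merge the vregs directly if ToReg can also satisfy FromReg's register
  // attributes (class / bank / type); otherwise fall back to an explicit COPY.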
168 | if (MRI.constrainRegAttrs(Reg: ToReg, ConstrainingReg: FromReg)) |
169 | MRI.replaceRegWith(FromReg, ToReg); |
170 | else |
171 | Builder.buildCopy(Res: ToReg, Op: FromReg); |
172 | |
173 | Observer.finishedChangingAllUsesOfReg(); |
174 | } |
175 | |
176 | void CombinerHelper::replaceRegOpWith(MachineRegisterInfo &MRI, |
177 | MachineOperand &FromRegOp, |
178 | Register ToReg) const { |
179 | assert(FromRegOp.getParent() && "Expected an operand in an MI" ); |
180 | Observer.changingInstr(MI&: *FromRegOp.getParent()); |
181 | |
182 | FromRegOp.setReg(ToReg); |
183 | |
184 | Observer.changedInstr(MI&: *FromRegOp.getParent()); |
185 | } |
186 | |
187 | void CombinerHelper::replaceOpcodeWith(MachineInstr &FromMI, |
188 | unsigned ToOpcode) const { |
189 | Observer.changingInstr(MI&: FromMI); |
190 | |
191 | FromMI.setDesc(Builder.getTII().get(Opcode: ToOpcode)); |
192 | |
193 | Observer.changedInstr(MI&: FromMI); |
194 | } |
195 | |
196 | const RegisterBank *CombinerHelper::getRegBank(Register Reg) const { |
197 | return RBI->getRegBank(Reg, MRI, TRI: *TRI); |
198 | } |
199 | |
200 | void CombinerHelper::setRegBank(Register Reg, const RegisterBank *RegBank) { |
201 | if (RegBank) |
202 | MRI.setRegBank(Reg, RegBank: *RegBank); |
203 | } |
204 | |
205 | bool CombinerHelper::tryCombineCopy(MachineInstr &MI) { |
206 | if (matchCombineCopy(MI)) { |
207 | applyCombineCopy(MI); |
208 | return true; |
209 | } |
210 | return false; |
211 | } |
212 | bool CombinerHelper::matchCombineCopy(MachineInstr &MI) { |
213 | if (MI.getOpcode() != TargetOpcode::COPY) |
214 | return false; |
215 | Register DstReg = MI.getOperand(i: 0).getReg(); |
216 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
217 | return canReplaceReg(DstReg, SrcReg, MRI); |
218 | } |
219 | void CombinerHelper::applyCombineCopy(MachineInstr &MI) { |
220 | Register DstReg = MI.getOperand(i: 0).getReg(); |
221 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
222 | MI.eraseFromParent(); |
223 | replaceRegWith(MRI, FromReg: DstReg, ToReg: SrcReg); |
224 | } |
225 | |
226 | bool CombinerHelper::matchFreezeOfSingleMaybePoisonOperand( |
227 | MachineInstr &MI, BuildFnTy &MatchInfo) { |
228 | // Ported from InstCombinerImpl::pushFreezeToPreventPoisonFromPropagating. |
229 | Register DstOp = MI.getOperand(i: 0).getReg(); |
230 | Register OrigOp = MI.getOperand(i: 1).getReg(); |
231 | |
232 | if (!MRI.hasOneNonDBGUse(RegNo: OrigOp)) |
233 | return false; |
234 | |
235 | MachineInstr *OrigDef = MRI.getUniqueVRegDef(Reg: OrigOp); |
236 | // Even if only a single operand of the PHI is not guaranteed non-poison, |
237 | // moving freeze() backwards across a PHI can cause optimization issues for |
238 | // other users of that operand. |
239 | // |
240 | // Moving freeze() from one of the output registers of a G_UNMERGE_VALUES to |
241 | // the source register is unprofitable because it makes the freeze() more |
242 | // strict than is necessary (it would affect the whole register instead of |
243 | // just the subreg being frozen). |
244 | if (OrigDef->isPHI() || isa<GUnmerge>(Val: OrigDef)) |
245 | return false; |
246 | |
247 | if (canCreateUndefOrPoison(Reg: OrigOp, MRI, |
248 | /*ConsiderFlagsAndMetadata=*/false)) |
249 | return false; |
250 | |
251 | std::optional<MachineOperand> MaybePoisonOperand; |
252 | for (MachineOperand &Operand : OrigDef->uses()) { |
253 | if (!Operand.isReg()) |
254 | return false; |
255 | |
256 | if (isGuaranteedNotToBeUndefOrPoison(Reg: Operand.getReg(), MRI)) |
257 | continue; |
258 | |
259 | if (!MaybePoisonOperand) |
260 | MaybePoisonOperand = Operand; |
261 | else { |
262 | // We have more than one maybe-poison operand. Moving the freeze is |
263 | // unsafe. |
264 | return false; |
265 | } |
266 | } |
267 | |
268 | // Eliminate freeze if all operands are guaranteed non-poison. |
269 | if (!MaybePoisonOperand) { |
270 | MatchInfo = [=](MachineIRBuilder &B) { |
271 | Observer.changingInstr(MI&: *OrigDef); |
272 | cast<GenericMachineInstr>(Val: OrigDef)->dropPoisonGeneratingFlags(); |
273 | Observer.changedInstr(MI&: *OrigDef); |
274 | B.buildCopy(Res: DstOp, Op: OrigOp); |
275 | }; |
276 | return true; |
277 | } |
278 | |
279 | Register MaybePoisonOperandReg = MaybePoisonOperand->getReg(); |
280 | LLT MaybePoisonOperandRegTy = MRI.getType(Reg: MaybePoisonOperandReg); |
281 | |
282 | MatchInfo = [=](MachineIRBuilder &B) mutable { |
283 | Observer.changingInstr(MI&: *OrigDef); |
284 | cast<GenericMachineInstr>(Val: OrigDef)->dropPoisonGeneratingFlags(); |
285 | Observer.changedInstr(MI&: *OrigDef); |
286 | B.setInsertPt(MBB&: *OrigDef->getParent(), II: OrigDef->getIterator()); |
287 | auto Freeze = B.buildFreeze(Dst: MaybePoisonOperandRegTy, Src: MaybePoisonOperandReg); |
288 | replaceRegOpWith( |
289 | MRI, FromRegOp&: *OrigDef->findRegisterUseOperand(Reg: MaybePoisonOperandReg, TRI), |
290 | ToReg: Freeze.getReg(Idx: 0)); |
291 | replaceRegWith(MRI, FromReg: DstOp, ToReg: OrigOp); |
292 | }; |
293 | return true; |
294 | } |
295 | |
296 | bool CombinerHelper::matchCombineConcatVectors(MachineInstr &MI, |
297 | SmallVector<Register> &Ops) { |
298 | assert(MI.getOpcode() == TargetOpcode::G_CONCAT_VECTORS && |
299 | "Invalid instruction" ); |
300 | bool IsUndef = true; |
301 | MachineInstr *Undef = nullptr; |
302 | |
303 | // Walk over all the operands of concat vectors and check if they are |
304 | // build_vector themselves or undef. |
305 | // Then collect their operands in Ops. |
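  // E.g. concat_vectors(build_vector(a, b), build_vector(c, d)) can be
  // flattened to build_vector(a, b, c, d).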
306 | for (const MachineOperand &MO : MI.uses()) { |
307 | Register Reg = MO.getReg(); |
308 | MachineInstr *Def = MRI.getVRegDef(Reg); |
309 | assert(Def && "Operand not defined" ); |
310 | if (!MRI.hasOneNonDBGUse(RegNo: Reg)) |
311 | return false; |
312 | switch (Def->getOpcode()) { |
313 | case TargetOpcode::G_BUILD_VECTOR: |
314 | IsUndef = false; |
315 | // Remember the operands of the build_vector to fold |
316 | // them into the yet-to-build flattened concat vectors. |
317 | for (const MachineOperand &BuildVecMO : Def->uses()) |
318 | Ops.push_back(Elt: BuildVecMO.getReg()); |
319 | break; |
320 | case TargetOpcode::G_IMPLICIT_DEF: { |
321 | LLT OpType = MRI.getType(Reg); |
322 | // Keep one undef value for all the undef operands. |
323 | if (!Undef) { |
324 | Builder.setInsertPt(MBB&: *MI.getParent(), II: MI); |
325 | Undef = Builder.buildUndef(Res: OpType.getScalarType()); |
326 | } |
327 | assert(MRI.getType(Undef->getOperand(0).getReg()) == |
328 | OpType.getScalarType() && |
329 | "All undefs should have the same type" ); |
330 | // Break the undef vector in as many scalar elements as needed |
331 | // for the flattening. |
332 | for (unsigned EltIdx = 0, EltEnd = OpType.getNumElements(); |
333 | EltIdx != EltEnd; ++EltIdx) |
334 | Ops.push_back(Elt: Undef->getOperand(i: 0).getReg()); |
335 | break; |
336 | } |
337 | default: |
338 | return false; |
339 | } |
340 | } |
341 | |
342 | // Check if the combine is illegal |
343 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
344 | if (!isLegalOrBeforeLegalizer( |
345 | Query: {TargetOpcode::G_BUILD_VECTOR, {DstTy, MRI.getType(Reg: Ops[0])}})) { |
346 | return false; |
347 | } |
348 | |
349 | if (IsUndef) |
350 | Ops.clear(); |
351 | |
352 | return true; |
353 | } |
354 | void CombinerHelper::applyCombineConcatVectors(MachineInstr &MI, |
355 | SmallVector<Register> &Ops) { |
  // We determined that the concat_vectors can be flattened.
  // Generate the flattened build_vector.
  Register DstReg = MI.getOperand(0).getReg();
  Builder.setInsertPt(*MI.getParent(), MI);
  Register NewDstReg = MRI.cloneVirtualRegister(DstReg);

  // Note: IsUndef is somewhat redundant. We could have determined it by
  // checking that all Ops are undef. Alternatively, we could have generated
  // a build_vector of undefs and relied on another combine to clean that up.
  // For now, given we already gather this information in
  // matchCombineConcatVectors, just save compile time and issue the right
  // thing.
368 | if (Ops.empty()) |
369 | Builder.buildUndef(Res: NewDstReg); |
370 | else |
371 | Builder.buildBuildVector(Res: NewDstReg, Ops); |
372 | MI.eraseFromParent(); |
373 | replaceRegWith(MRI, FromReg: DstReg, ToReg: NewDstReg); |
374 | } |
375 | |
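// Match a G_SHUFFLE_VECTOR of two G_CONCAT_VECTORS where each run of mask
// indices either selects a whole source operand of one of the concats or is
// entirely undef, so the shuffle can be rewritten as a new G_CONCAT_VECTORS
// of those sources. A zero Register in Ops marks a piece that should become
// undef.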
376 | bool CombinerHelper::matchCombineShuffleConcat(MachineInstr &MI, |
377 | SmallVector<Register> &Ops) { |
378 | ArrayRef<int> Mask = MI.getOperand(i: 3).getShuffleMask(); |
379 | auto ConcatMI1 = |
380 | dyn_cast<GConcatVectors>(Val: MRI.getVRegDef(Reg: MI.getOperand(i: 1).getReg())); |
381 | auto ConcatMI2 = |
382 | dyn_cast<GConcatVectors>(Val: MRI.getVRegDef(Reg: MI.getOperand(i: 2).getReg())); |
383 | if (!ConcatMI1 || !ConcatMI2) |
384 | return false; |
385 | |
386 | // Check that the sources of the Concat instructions have the same type |
387 | if (MRI.getType(Reg: ConcatMI1->getSourceReg(I: 0)) != |
388 | MRI.getType(Reg: ConcatMI2->getSourceReg(I: 0))) |
389 | return false; |
390 | |
391 | LLT ConcatSrcTy = MRI.getType(Reg: ConcatMI1->getReg(Idx: 1)); |
392 | LLT ShuffleSrcTy1 = MRI.getType(Reg: MI.getOperand(i: 1).getReg()); |
393 | unsigned ConcatSrcNumElt = ConcatSrcTy.getNumElements(); |
394 | for (unsigned i = 0; i < Mask.size(); i += ConcatSrcNumElt) { |
395 | // Check if the index takes a whole source register from G_CONCAT_VECTORS |
396 | // Assumes that all Sources of G_CONCAT_VECTORS are the same type |
397 | if (Mask[i] == -1) { |
398 | for (unsigned j = 1; j < ConcatSrcNumElt; j++) { |
399 | if (i + j >= Mask.size()) |
400 | return false; |
401 | if (Mask[i + j] != -1) |
402 | return false; |
403 | } |
404 | if (!isLegalOrBeforeLegalizer( |
405 | Query: {TargetOpcode::G_IMPLICIT_DEF, {ConcatSrcTy}})) |
406 | return false; |
407 | Ops.push_back(Elt: 0); |
408 | } else if (Mask[i] % ConcatSrcNumElt == 0) { |
409 | for (unsigned j = 1; j < ConcatSrcNumElt; j++) { |
410 | if (i + j >= Mask.size()) |
411 | return false; |
412 | if (Mask[i + j] != Mask[i] + static_cast<int>(j)) |
413 | return false; |
414 | } |
415 | // Retrieve the source register from its respective G_CONCAT_VECTORS |
416 | // instruction |
417 | if (Mask[i] < ShuffleSrcTy1.getNumElements()) { |
418 | Ops.push_back(Elt: ConcatMI1->getSourceReg(I: Mask[i] / ConcatSrcNumElt)); |
419 | } else { |
420 | Ops.push_back(Elt: ConcatMI2->getSourceReg(I: Mask[i] / ConcatSrcNumElt - |
421 | ConcatMI1->getNumSources())); |
422 | } |
423 | } else { |
424 | return false; |
425 | } |
426 | } |
427 | |
428 | if (!isLegalOrBeforeLegalizer( |
429 | Query: {TargetOpcode::G_CONCAT_VECTORS, |
430 | {MRI.getType(Reg: MI.getOperand(i: 0).getReg()), ConcatSrcTy}})) |
431 | return false; |
432 | |
433 | return !Ops.empty(); |
434 | } |
435 | |
436 | void CombinerHelper::applyCombineShuffleConcat(MachineInstr &MI, |
437 | SmallVector<Register> &Ops) { |
438 | LLT SrcTy = MRI.getType(Reg: Ops[0]); |
439 | Register UndefReg = 0; |
440 | |
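  // A zero register recorded by the match means "undef piece"; materialize a
  // single G_IMPLICIT_DEF of the source type and reuse it for all such pieces.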
441 | for (Register &Reg : Ops) { |
442 | if (Reg == 0) { |
443 | if (UndefReg == 0) |
444 | UndefReg = Builder.buildUndef(Res: SrcTy).getReg(Idx: 0); |
445 | Reg = UndefReg; |
446 | } |
447 | } |
448 | |
449 | if (Ops.size() > 1) |
450 | Builder.buildConcatVectors(Res: MI.getOperand(i: 0).getReg(), Ops); |
451 | else |
452 | Builder.buildCopy(Res: MI.getOperand(i: 0).getReg(), Op: Ops[0]); |
453 | MI.eraseFromParent(); |
454 | } |
455 | |
456 | bool CombinerHelper::tryCombineShuffleVector(MachineInstr &MI) { |
457 | SmallVector<Register, 4> Ops; |
458 | if (matchCombineShuffleVector(MI, Ops)) { |
459 | applyCombineShuffleVector(MI, Ops); |
460 | return true; |
461 | } |
462 | return false; |
463 | } |
464 | |
465 | bool CombinerHelper::matchCombineShuffleVector(MachineInstr &MI, |
466 | SmallVectorImpl<Register> &Ops) { |
467 | assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR && |
468 | "Invalid instruction kind" ); |
469 | LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
470 | Register Src1 = MI.getOperand(i: 1).getReg(); |
471 | LLT SrcType = MRI.getType(Reg: Src1); |
472 | // As bizarre as it may look, shuffle vector can actually produce |
473 | // scalar! This is because at the IR level a <1 x ty> shuffle |
474 | // vector is perfectly valid. |
475 | unsigned DstNumElts = DstType.isVector() ? DstType.getNumElements() : 1; |
476 | unsigned SrcNumElts = SrcType.isVector() ? SrcType.getNumElements() : 1; |
477 | |
478 | // If the resulting vector is smaller than the size of the source |
479 | // vectors being concatenated, we won't be able to replace the |
480 | // shuffle vector into a concat_vectors. |
481 | // |
482 | // Note: We may still be able to produce a concat_vectors fed by |
483 | // extract_vector_elt and so on. It is less clear that would |
484 | // be better though, so don't bother for now. |
485 | // |
486 | // If the destination is a scalar, the size of the sources doesn't |
  // matter: we will lower the shuffle to a plain copy. This will
488 | // work only if the source and destination have the same size. But |
489 | // that's covered by the next condition. |
490 | // |
491 | // TODO: If the size between the source and destination don't match |
492 | // we could still emit an extract vector element in that case. |
493 | if (DstNumElts < 2 * SrcNumElts && DstNumElts != 1) |
494 | return false; |
495 | |
496 | // Check that the shuffle mask can be broken evenly between the |
497 | // different sources. |
498 | if (DstNumElts % SrcNumElts != 0) |
499 | return false; |
500 | |
501 | // Mask length is a multiple of the source vector length. |
502 | // Check if the shuffle is some kind of concatenation of the input |
503 | // vectors. |
504 | unsigned NumConcat = DstNumElts / SrcNumElts; |
505 | SmallVector<int, 8> ConcatSrcs(NumConcat, -1); |
506 | ArrayRef<int> Mask = MI.getOperand(i: 3).getShuffleMask(); |
507 | for (unsigned i = 0; i != DstNumElts; ++i) { |
508 | int Idx = Mask[i]; |
509 | // Undef value. |
510 | if (Idx < 0) |
511 | continue; |
512 | // Ensure the indices in each SrcType sized piece are sequential and that |
513 | // the same source is used for the whole piece. |
514 | if ((Idx % SrcNumElts != (i % SrcNumElts)) || |
515 | (ConcatSrcs[i / SrcNumElts] >= 0 && |
516 | ConcatSrcs[i / SrcNumElts] != (int)(Idx / SrcNumElts))) |
517 | return false; |
518 | // Remember which source this index came from. |
519 | ConcatSrcs[i / SrcNumElts] = Idx / SrcNumElts; |
520 | } |
521 | |
522 | // The shuffle is concatenating multiple vectors together. |
523 | // Collect the different operands for that. |
524 | Register UndefReg; |
525 | Register Src2 = MI.getOperand(i: 2).getReg(); |
526 | for (auto Src : ConcatSrcs) { |
527 | if (Src < 0) { |
528 | if (!UndefReg) { |
529 | Builder.setInsertPt(MBB&: *MI.getParent(), II: MI); |
530 | UndefReg = Builder.buildUndef(Res: SrcType).getReg(Idx: 0); |
531 | } |
532 | Ops.push_back(Elt: UndefReg); |
533 | } else if (Src == 0) |
534 | Ops.push_back(Elt: Src1); |
535 | else |
536 | Ops.push_back(Elt: Src2); |
537 | } |
538 | return true; |
539 | } |
540 | |
541 | void CombinerHelper::applyCombineShuffleVector(MachineInstr &MI, |
542 | const ArrayRef<Register> Ops) { |
543 | Register DstReg = MI.getOperand(i: 0).getReg(); |
544 | Builder.setInsertPt(MBB&: *MI.getParent(), II: MI); |
545 | Register NewDstReg = MRI.cloneVirtualRegister(VReg: DstReg); |
546 | |
547 | if (Ops.size() == 1) |
548 | Builder.buildCopy(Res: NewDstReg, Op: Ops[0]); |
549 | else |
550 | Builder.buildMergeLikeInstr(Res: NewDstReg, Ops); |
551 | |
552 | MI.eraseFromParent(); |
553 | replaceRegWith(MRI, FromReg: DstReg, ToReg: NewDstReg); |
554 | } |
555 | |
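// Match a G_SHUFFLE_VECTOR whose mask has exactly one element; such a shuffle
// simply selects a single lane (or undef) and is lowered by the apply below to
// an undef, a copy, or an extract_vector_elt.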
bool CombinerHelper::matchShuffleToExtract(MachineInstr &MI) {
  assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR &&
         "Invalid instruction kind");

  ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
  return Mask.size() == 1;
}
563 | |
void CombinerHelper::applyShuffleToExtract(MachineInstr &MI) {
565 | Register DstReg = MI.getOperand(i: 0).getReg(); |
566 | Builder.setInsertPt(MBB&: *MI.getParent(), II: MI); |
567 | |
568 | int I = MI.getOperand(i: 3).getShuffleMask()[0]; |
569 | Register Src1 = MI.getOperand(i: 1).getReg(); |
570 | LLT Src1Ty = MRI.getType(Reg: Src1); |
571 | int Src1NumElts = Src1Ty.isVector() ? Src1Ty.getNumElements() : 1; |
572 | Register SrcReg; |
573 | if (I >= Src1NumElts) { |
574 | SrcReg = MI.getOperand(i: 2).getReg(); |
575 | I -= Src1NumElts; |
576 | } else if (I >= 0) |
577 | SrcReg = Src1; |
578 | |
579 | if (I < 0) |
580 | Builder.buildUndef(Res: DstReg); |
581 | else if (!MRI.getType(Reg: SrcReg).isVector()) |
582 | Builder.buildCopy(Res: DstReg, Op: SrcReg); |
583 | else |
584 | Builder.buildExtractVectorElementConstant(Res: DstReg, Val: SrcReg, Idx: I); |
585 | |
586 | MI.eraseFromParent(); |
587 | } |
588 | |
589 | namespace { |
590 | |
591 | /// Select a preference between two uses. CurrentUse is the current preference |
/// while *ForCandidate are the attributes of the candidate under consideration.
593 | PreferredTuple ChoosePreferredUse(MachineInstr &LoadMI, |
594 | PreferredTuple &CurrentUse, |
595 | const LLT TyForCandidate, |
596 | unsigned OpcodeForCandidate, |
597 | MachineInstr *MIForCandidate) { |
598 | if (!CurrentUse.Ty.isValid()) { |
599 | if (CurrentUse.ExtendOpcode == OpcodeForCandidate || |
600 | CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT) |
      return {TyForCandidate, OpcodeForCandidate, MIForCandidate};
602 | return CurrentUse; |
603 | } |
604 | |
605 | // We permit the extend to hoist through basic blocks but this is only |
606 | // sensible if the target has extending loads. If you end up lowering back |
607 | // into a load and extend during the legalizer then the end result is |
608 | // hoisting the extend up to the load. |
609 | |
610 | // Prefer defined extensions to undefined extensions as these are more |
611 | // likely to reduce the number of instructions. |
612 | if (OpcodeForCandidate == TargetOpcode::G_ANYEXT && |
613 | CurrentUse.ExtendOpcode != TargetOpcode::G_ANYEXT) |
614 | return CurrentUse; |
615 | else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT && |
616 | OpcodeForCandidate != TargetOpcode::G_ANYEXT) |
    return {TyForCandidate, OpcodeForCandidate, MIForCandidate};
618 | |
619 | // Prefer sign extensions to zero extensions as sign-extensions tend to be |
620 | // more expensive. Don't do this if the load is already a zero-extend load |
621 | // though, otherwise we'll rewrite a zero-extend load into a sign-extend |
622 | // later. |
623 | if (!isa<GZExtLoad>(Val: LoadMI) && CurrentUse.Ty == TyForCandidate) { |
624 | if (CurrentUse.ExtendOpcode == TargetOpcode::G_SEXT && |
625 | OpcodeForCandidate == TargetOpcode::G_ZEXT) |
626 | return CurrentUse; |
627 | else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ZEXT && |
628 | OpcodeForCandidate == TargetOpcode::G_SEXT) |
      return {TyForCandidate, OpcodeForCandidate, MIForCandidate};
630 | } |
631 | |
632 | // This is potentially target specific. We've chosen the largest type |
633 | // because G_TRUNC is usually free. One potential catch with this is that |
634 | // some targets have a reduced number of larger registers than smaller |
635 | // registers and this choice potentially increases the live-range for the |
636 | // larger value. |
637 | if (TyForCandidate.getSizeInBits() > CurrentUse.Ty.getSizeInBits()) { |
    return {TyForCandidate, OpcodeForCandidate, MIForCandidate};
639 | } |
640 | return CurrentUse; |
641 | } |
642 | |
643 | /// Find a suitable place to insert some instructions and insert them. This |
644 | /// function accounts for special cases like inserting before a PHI node. |
645 | /// The current strategy for inserting before PHI's is to duplicate the |
646 | /// instructions for each predecessor. However, while that's ok for G_TRUNC |
647 | /// on most targets since it generally requires no code, other targets/cases may |
648 | /// want to try harder to find a dominating block. |
649 | static void InsertInsnsWithoutSideEffectsBeforeUse( |
650 | MachineIRBuilder &Builder, MachineInstr &DefMI, MachineOperand &UseMO, |
651 | std::function<void(MachineBasicBlock *, MachineBasicBlock::iterator, |
652 | MachineOperand &UseMO)> |
653 | Inserter) { |
654 | MachineInstr &UseMI = *UseMO.getParent(); |
655 | |
656 | MachineBasicBlock *InsertBB = UseMI.getParent(); |
657 | |
658 | // If the use is a PHI then we want the predecessor block instead. |
659 | if (UseMI.isPHI()) { |
660 | MachineOperand *PredBB = std::next(x: &UseMO); |
661 | InsertBB = PredBB->getMBB(); |
662 | } |
663 | |
664 | // If the block is the same block as the def then we want to insert just after |
665 | // the def instead of at the start of the block. |
666 | if (InsertBB == DefMI.getParent()) { |
667 | MachineBasicBlock::iterator InsertPt = &DefMI; |
668 | Inserter(InsertBB, std::next(x: InsertPt), UseMO); |
669 | return; |
670 | } |
671 | |
672 | // Otherwise we want the start of the BB |
673 | Inserter(InsertBB, InsertBB->getFirstNonPHI(), UseMO); |
674 | } |
675 | } // end anonymous namespace |
676 | |
677 | bool CombinerHelper::tryCombineExtendingLoads(MachineInstr &MI) { |
678 | PreferredTuple Preferred; |
679 | if (matchCombineExtendingLoads(MI, MatchInfo&: Preferred)) { |
680 | applyCombineExtendingLoads(MI, MatchInfo&: Preferred); |
681 | return true; |
682 | } |
683 | return false; |
684 | } |
685 | |
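/// Map an extend opcode to the corresponding extending load opcode:
/// G_ANYEXT -> G_LOAD, G_SEXT -> G_SEXTLOAD, G_ZEXT -> G_ZEXTLOAD.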
686 | static unsigned getExtLoadOpcForExtend(unsigned ExtOpc) { |
687 | unsigned CandidateLoadOpc; |
688 | switch (ExtOpc) { |
689 | case TargetOpcode::G_ANYEXT: |
690 | CandidateLoadOpc = TargetOpcode::G_LOAD; |
691 | break; |
692 | case TargetOpcode::G_SEXT: |
693 | CandidateLoadOpc = TargetOpcode::G_SEXTLOAD; |
694 | break; |
695 | case TargetOpcode::G_ZEXT: |
696 | CandidateLoadOpc = TargetOpcode::G_ZEXTLOAD; |
697 | break; |
698 | default: |
699 | llvm_unreachable("Unexpected extend opc" ); |
700 | } |
701 | return CandidateLoadOpc; |
702 | } |
703 | |
704 | bool CombinerHelper::matchCombineExtendingLoads(MachineInstr &MI, |
705 | PreferredTuple &Preferred) { |
706 | // We match the loads and follow the uses to the extend instead of matching |
707 | // the extends and following the def to the load. This is because the load |
708 | // must remain in the same position for correctness (unless we also add code |
709 | // to find a safe place to sink it) whereas the extend is freely movable. |
710 | // It also prevents us from duplicating the load for the volatile case or just |
711 | // for performance. |
712 | GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(Val: &MI); |
713 | if (!LoadMI) |
714 | return false; |
715 | |
716 | Register LoadReg = LoadMI->getDstReg(); |
717 | |
718 | LLT LoadValueTy = MRI.getType(Reg: LoadReg); |
719 | if (!LoadValueTy.isScalar()) |
720 | return false; |
721 | |
722 | // Most architectures are going to legalize <s8 loads into at least a 1 byte |
723 | // load, and the MMOs can only describe memory accesses in multiples of bytes. |
724 | // If we try to perform extload combining on those, we can end up with |
725 | // %a(s8) = extload %ptr (load 1 byte from %ptr) |
726 | // ... which is an illegal extload instruction. |
727 | if (LoadValueTy.getSizeInBits() < 8) |
728 | return false; |
729 | |
730 | // For non power-of-2 types, they will very likely be legalized into multiple |
731 | // loads. Don't bother trying to match them into extending loads. |
732 | if (!llvm::has_single_bit<uint32_t>(Value: LoadValueTy.getSizeInBits())) |
733 | return false; |
734 | |
735 | // Find the preferred type aside from the any-extends (unless it's the only |
  // one) and non-extending ops. We'll emit an extending load to that type and
  // emit a variant of (extend (trunc X)) for the others according to the
738 | // relative type sizes. At the same time, pick an extend to use based on the |
739 | // extend involved in the chosen type. |
740 | unsigned PreferredOpcode = |
741 | isa<GLoad>(Val: &MI) |
742 | ? TargetOpcode::G_ANYEXT |
743 | : isa<GSExtLoad>(Val: &MI) ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT; |
  Preferred = {LLT(), PreferredOpcode, nullptr};
745 | for (auto &UseMI : MRI.use_nodbg_instructions(Reg: LoadReg)) { |
746 | if (UseMI.getOpcode() == TargetOpcode::G_SEXT || |
747 | UseMI.getOpcode() == TargetOpcode::G_ZEXT || |
748 | (UseMI.getOpcode() == TargetOpcode::G_ANYEXT)) { |
749 | const auto &MMO = LoadMI->getMMO(); |
750 | // Don't do anything for atomics. |
751 | if (MMO.isAtomic()) |
752 | continue; |
753 | // Check for legality. |
754 | if (!isPreLegalize()) { |
755 | LegalityQuery::MemDesc MMDesc(MMO); |
756 | unsigned CandidateLoadOpc = getExtLoadOpcForExtend(ExtOpc: UseMI.getOpcode()); |
757 | LLT UseTy = MRI.getType(Reg: UseMI.getOperand(i: 0).getReg()); |
758 | LLT SrcTy = MRI.getType(Reg: LoadMI->getPointerReg()); |
759 | if (LI->getAction(Query: {CandidateLoadOpc, {UseTy, SrcTy}, {MMDesc}}) |
760 | .Action != LegalizeActions::Legal) |
761 | continue; |
762 | } |
763 | Preferred = ChoosePreferredUse(LoadMI&: MI, CurrentUse&: Preferred, |
764 | TyForCandidate: MRI.getType(Reg: UseMI.getOperand(i: 0).getReg()), |
765 | OpcodeForCandidate: UseMI.getOpcode(), MIForCandidate: &UseMI); |
766 | } |
767 | } |
768 | |
769 | // There were no extends |
770 | if (!Preferred.MI) |
771 | return false; |
  // It should be impossible to choose an extend without selecting a different
773 | // type since by definition the result of an extend is larger. |
774 | assert(Preferred.Ty != LoadValueTy && "Extending to same type?" ); |
775 | |
776 | LLVM_DEBUG(dbgs() << "Preferred use is: " << *Preferred.MI); |
777 | return true; |
778 | } |
779 | |
780 | void CombinerHelper::applyCombineExtendingLoads(MachineInstr &MI, |
781 | PreferredTuple &Preferred) { |
782 | // Rewrite the load to the chosen extending load. |
783 | Register ChosenDstReg = Preferred.MI->getOperand(i: 0).getReg(); |
784 | |
785 | // Inserter to insert a truncate back to the original type at a given point |
786 | // with some basic CSE to limit truncate duplication to one per BB. |
787 | DenseMap<MachineBasicBlock *, MachineInstr *> EmittedInsns; |
788 | auto InsertTruncAt = [&](MachineBasicBlock *InsertIntoBB, |
789 | MachineBasicBlock::iterator InsertBefore, |
790 | MachineOperand &UseMO) { |
791 | MachineInstr *PreviouslyEmitted = EmittedInsns.lookup(Val: InsertIntoBB); |
792 | if (PreviouslyEmitted) { |
793 | Observer.changingInstr(MI&: *UseMO.getParent()); |
794 | UseMO.setReg(PreviouslyEmitted->getOperand(i: 0).getReg()); |
795 | Observer.changedInstr(MI&: *UseMO.getParent()); |
796 | return; |
797 | } |
798 | |
799 | Builder.setInsertPt(MBB&: *InsertIntoBB, II: InsertBefore); |
800 | Register NewDstReg = MRI.cloneVirtualRegister(VReg: MI.getOperand(i: 0).getReg()); |
801 | MachineInstr *NewMI = Builder.buildTrunc(Res: NewDstReg, Op: ChosenDstReg); |
802 | EmittedInsns[InsertIntoBB] = NewMI; |
803 | replaceRegOpWith(MRI, FromRegOp&: UseMO, ToReg: NewDstReg); |
804 | }; |
805 | |
806 | Observer.changingInstr(MI); |
807 | unsigned LoadOpc = getExtLoadOpcForExtend(ExtOpc: Preferred.ExtendOpcode); |
808 | MI.setDesc(Builder.getTII().get(Opcode: LoadOpc)); |
809 | |
810 | // Rewrite all the uses to fix up the types. |
811 | auto &LoadValue = MI.getOperand(i: 0); |
812 | SmallVector<MachineOperand *, 4> Uses; |
813 | for (auto &UseMO : MRI.use_operands(Reg: LoadValue.getReg())) |
814 | Uses.push_back(Elt: &UseMO); |
815 | |
816 | for (auto *UseMO : Uses) { |
817 | MachineInstr *UseMI = UseMO->getParent(); |
818 | |
819 | // If the extend is compatible with the preferred extend then we should fix |
820 | // up the type and extend so that it uses the preferred use. |
821 | if (UseMI->getOpcode() == Preferred.ExtendOpcode || |
822 | UseMI->getOpcode() == TargetOpcode::G_ANYEXT) { |
823 | Register UseDstReg = UseMI->getOperand(i: 0).getReg(); |
824 | MachineOperand &UseSrcMO = UseMI->getOperand(i: 1); |
825 | const LLT UseDstTy = MRI.getType(Reg: UseDstReg); |
826 | if (UseDstReg != ChosenDstReg) { |
827 | if (Preferred.Ty == UseDstTy) { |
828 | // If the use has the same type as the preferred use, then merge |
829 | // the vregs and erase the extend. For example: |
830 | // %1:_(s8) = G_LOAD ... |
831 | // %2:_(s32) = G_SEXT %1(s8) |
832 | // %3:_(s32) = G_ANYEXT %1(s8) |
833 | // ... = ... %3(s32) |
834 | // rewrites to: |
835 | // %2:_(s32) = G_SEXTLOAD ... |
836 | // ... = ... %2(s32) |
837 | replaceRegWith(MRI, FromReg: UseDstReg, ToReg: ChosenDstReg); |
838 | Observer.erasingInstr(MI&: *UseMO->getParent()); |
839 | UseMO->getParent()->eraseFromParent(); |
840 | } else if (Preferred.Ty.getSizeInBits() < UseDstTy.getSizeInBits()) { |
841 | // If the preferred size is smaller, then keep the extend but extend |
842 | // from the result of the extending load. For example: |
843 | // %1:_(s8) = G_LOAD ... |
844 | // %2:_(s32) = G_SEXT %1(s8) |
845 | // %3:_(s64) = G_ANYEXT %1(s8) |
846 | // ... = ... %3(s64) |
          // rewrites to:
848 | // %2:_(s32) = G_SEXTLOAD ... |
849 | // %3:_(s64) = G_ANYEXT %2:_(s32) |
850 | // ... = ... %3(s64) |
851 | replaceRegOpWith(MRI, FromRegOp&: UseSrcMO, ToReg: ChosenDstReg); |
        } else {
          // If the preferred size is larger, then insert a truncate. For
          // example:
          //   %1:_(s8) = G_LOAD ...
          //   %2:_(s64) = G_SEXT %1(s8)
          //   %3:_(s32) = G_ZEXT %1(s8)
          //   ... = ... %3(s32)
          // rewrites to:
          //   %2:_(s64) = G_SEXTLOAD ...
          //   %4:_(s8) = G_TRUNC %2:_(s64)
          //   %3:_(s32) = G_ZEXT %4:_(s8)
          //   ... = ... %3(s32)
864 | InsertInsnsWithoutSideEffectsBeforeUse(Builder, DefMI&: MI, UseMO&: *UseMO, |
865 | Inserter: InsertTruncAt); |
866 | } |
867 | continue; |
868 | } |
869 | // The use is (one of) the uses of the preferred use we chose earlier. |
870 | // We're going to update the load to def this value later so just erase |
871 | // the old extend. |
872 | Observer.erasingInstr(MI&: *UseMO->getParent()); |
873 | UseMO->getParent()->eraseFromParent(); |
874 | continue; |
875 | } |
876 | |
877 | // The use isn't an extend. Truncate back to the type we originally loaded. |
878 | // This is free on many targets. |
879 | InsertInsnsWithoutSideEffectsBeforeUse(Builder, DefMI&: MI, UseMO&: *UseMO, Inserter: InsertTruncAt); |
880 | } |
881 | |
882 | MI.getOperand(i: 0).setReg(ChosenDstReg); |
883 | Observer.changedInstr(MI); |
884 | } |
885 | |
886 | bool CombinerHelper::matchCombineLoadWithAndMask(MachineInstr &MI, |
887 | BuildFnTy &MatchInfo) { |
888 | assert(MI.getOpcode() == TargetOpcode::G_AND); |
889 | |
890 | // If we have the following code: |
891 | // %mask = G_CONSTANT 255 |
892 | // %ld = G_LOAD %ptr, (load s16) |
893 | // %and = G_AND %ld, %mask |
894 | // |
895 | // Try to fold it into |
896 | // %ld = G_ZEXTLOAD %ptr, (load s8) |
897 | |
898 | Register Dst = MI.getOperand(i: 0).getReg(); |
899 | if (MRI.getType(Reg: Dst).isVector()) |
900 | return false; |
901 | |
902 | auto MaybeMask = |
903 | getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI); |
904 | if (!MaybeMask) |
905 | return false; |
906 | |
907 | APInt MaskVal = MaybeMask->Value; |
908 | |
909 | if (!MaskVal.isMask()) |
910 | return false; |
911 | |
912 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
913 | // Don't use getOpcodeDef() here since intermediate instructions may have |
914 | // multiple users. |
915 | GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(Val: MRI.getVRegDef(Reg: SrcReg)); |
916 | if (!LoadMI || !MRI.hasOneNonDBGUse(RegNo: LoadMI->getDstReg())) |
917 | return false; |
918 | |
919 | Register LoadReg = LoadMI->getDstReg(); |
920 | LLT RegTy = MRI.getType(Reg: LoadReg); |
921 | Register PtrReg = LoadMI->getPointerReg(); |
922 | unsigned RegSize = RegTy.getSizeInBits(); |
923 | LocationSize LoadSizeBits = LoadMI->getMemSizeInBits(); |
924 | unsigned MaskSizeBits = MaskVal.countr_one(); |
925 | |
926 | // The mask may not be larger than the in-memory type, as it might cover sign |
927 | // extended bits |
928 | if (MaskSizeBits > LoadSizeBits.getValue()) |
929 | return false; |
930 | |
931 | // If the mask covers the whole destination register, there's nothing to |
932 | // extend |
933 | if (MaskSizeBits >= RegSize) |
934 | return false; |
935 | |
936 | // Most targets cannot deal with loads of size < 8 and need to re-legalize to |
937 | // at least byte loads. Avoid creating such loads here |
938 | if (MaskSizeBits < 8 || !isPowerOf2_32(Value: MaskSizeBits)) |
939 | return false; |
940 | |
941 | const MachineMemOperand &MMO = LoadMI->getMMO(); |
942 | LegalityQuery::MemDesc MemDesc(MMO); |
943 | |
944 | // Don't modify the memory access size if this is atomic/volatile, but we can |
945 | // still adjust the opcode to indicate the high bit behavior. |
946 | if (LoadMI->isSimple()) |
947 | MemDesc.MemoryTy = LLT::scalar(SizeInBits: MaskSizeBits); |
948 | else if (LoadSizeBits.getValue() > MaskSizeBits || |
949 | LoadSizeBits.getValue() == RegSize) |
950 | return false; |
951 | |
952 | // TODO: Could check if it's legal with the reduced or original memory size. |
953 | if (!isLegalOrBeforeLegalizer( |
954 | Query: {TargetOpcode::G_ZEXTLOAD, {RegTy, MRI.getType(Reg: PtrReg)}, {MemDesc}})) |
955 | return false; |
956 | |
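  // Rebuild the access as a G_ZEXTLOAD of the narrowed memory type, defining
  // the G_AND's destination directly, and erase the original load.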
957 | MatchInfo = [=](MachineIRBuilder &B) { |
958 | B.setInstrAndDebugLoc(*LoadMI); |
959 | auto &MF = B.getMF(); |
960 | auto PtrInfo = MMO.getPointerInfo(); |
961 | auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Ty: MemDesc.MemoryTy); |
962 | B.buildLoadInstr(Opcode: TargetOpcode::G_ZEXTLOAD, Res: Dst, Addr: PtrReg, MMO&: *NewMMO); |
963 | LoadMI->eraseFromParent(); |
964 | }; |
965 | return true; |
966 | } |
967 | |
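// Returns true if DefMI precedes UseMI (or is UseMI itself) within their
// common basic block, determined by scanning the block in program order.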
968 | bool CombinerHelper::isPredecessor(const MachineInstr &DefMI, |
969 | const MachineInstr &UseMI) { |
970 | assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() && |
971 | "shouldn't consider debug uses" ); |
972 | assert(DefMI.getParent() == UseMI.getParent()); |
973 | if (&DefMI == &UseMI) |
974 | return true; |
975 | const MachineBasicBlock &MBB = *DefMI.getParent(); |
976 | auto DefOrUse = find_if(Range: MBB, P: [&DefMI, &UseMI](const MachineInstr &MI) { |
977 | return &MI == &DefMI || &MI == &UseMI; |
978 | }); |
979 | if (DefOrUse == MBB.end()) |
980 | llvm_unreachable("Block must contain both DefMI and UseMI!" ); |
981 | return &*DefOrUse == &DefMI; |
982 | } |
983 | |
984 | bool CombinerHelper::dominates(const MachineInstr &DefMI, |
985 | const MachineInstr &UseMI) { |
986 | assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() && |
987 | "shouldn't consider debug uses" ); |
988 | if (MDT) |
989 | return MDT->dominates(A: &DefMI, B: &UseMI); |
990 | else if (DefMI.getParent() != UseMI.getParent()) |
991 | return false; |
992 | |
993 | return isPredecessor(DefMI, UseMI); |
994 | } |
995 | |
996 | bool CombinerHelper::matchSextTruncSextLoad(MachineInstr &MI) { |
997 | assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG); |
998 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
999 | Register LoadUser = SrcReg; |
1000 | |
1001 | if (MRI.getType(Reg: SrcReg).isVector()) |
1002 | return false; |
1003 | |
1004 | Register TruncSrc; |
1005 | if (mi_match(R: SrcReg, MRI, P: m_GTrunc(Src: m_Reg(R&: TruncSrc)))) |
1006 | LoadUser = TruncSrc; |
1007 | |
1008 | uint64_t SizeInBits = MI.getOperand(i: 2).getImm(); |
1009 | // If the source is a G_SEXTLOAD from the same bit width, then we don't |
1010 | // need any extend at all, just a truncate. |
1011 | if (auto *LoadMI = getOpcodeDef<GSExtLoad>(Reg: LoadUser, MRI)) { |
1012 | // If truncating more than the original extended value, abort. |
1013 | auto LoadSizeBits = LoadMI->getMemSizeInBits(); |
1014 | if (TruncSrc && |
1015 | MRI.getType(Reg: TruncSrc).getSizeInBits() < LoadSizeBits.getValue()) |
1016 | return false; |
1017 | if (LoadSizeBits == SizeInBits) |
1018 | return true; |
1019 | } |
1020 | return false; |
1021 | } |
1022 | |
1023 | void CombinerHelper::applySextTruncSextLoad(MachineInstr &MI) { |
1024 | assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG); |
1025 | Builder.buildCopy(Res: MI.getOperand(i: 0).getReg(), Op: MI.getOperand(i: 1).getReg()); |
1026 | MI.eraseFromParent(); |
1027 | } |
1028 | |
1029 | bool CombinerHelper::matchSextInRegOfLoad( |
1030 | MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) { |
1031 | assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG); |
1032 | |
1033 | Register DstReg = MI.getOperand(i: 0).getReg(); |
1034 | LLT RegTy = MRI.getType(Reg: DstReg); |
1035 | |
1036 | // Only supports scalars for now. |
1037 | if (RegTy.isVector()) |
1038 | return false; |
1039 | |
1040 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
1041 | auto *LoadDef = getOpcodeDef<GLoad>(Reg: SrcReg, MRI); |
1042 | if (!LoadDef || !MRI.hasOneNonDBGUse(RegNo: DstReg)) |
1043 | return false; |
1044 | |
1045 | uint64_t MemBits = LoadDef->getMemSizeInBits().getValue(); |
1046 | |
1047 | // If the sign extend extends from a narrower width than the load's width, |
1048 | // then we can narrow the load width when we combine to a G_SEXTLOAD. |
1049 | // Avoid widening the load at all. |
1050 | unsigned NewSizeBits = std::min(a: (uint64_t)MI.getOperand(i: 2).getImm(), b: MemBits); |
1051 | |
1052 | // Don't generate G_SEXTLOADs with a < 1 byte width. |
1053 | if (NewSizeBits < 8) |
1054 | return false; |
1055 | // Don't bother creating a non-power-2 sextload, it will likely be broken up |
1056 | // anyway for most targets. |
1057 | if (!isPowerOf2_32(Value: NewSizeBits)) |
1058 | return false; |
1059 | |
1060 | const MachineMemOperand &MMO = LoadDef->getMMO(); |
1061 | LegalityQuery::MemDesc MMDesc(MMO); |
1062 | |
1063 | // Don't modify the memory access size if this is atomic/volatile, but we can |
1064 | // still adjust the opcode to indicate the high bit behavior. |
1065 | if (LoadDef->isSimple()) |
1066 | MMDesc.MemoryTy = LLT::scalar(SizeInBits: NewSizeBits); |
1067 | else if (MemBits > NewSizeBits || MemBits == RegTy.getSizeInBits()) |
1068 | return false; |
1069 | |
1070 | // TODO: Could check if it's legal with the reduced or original memory size. |
1071 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SEXTLOAD, |
1072 | {MRI.getType(Reg: LoadDef->getDstReg()), |
1073 | MRI.getType(Reg: LoadDef->getPointerReg())}, |
1074 | {MMDesc}})) |
1075 | return false; |
1076 | |
1077 | MatchInfo = std::make_tuple(args: LoadDef->getDstReg(), args&: NewSizeBits); |
1078 | return true; |
1079 | } |
1080 | |
1081 | void CombinerHelper::applySextInRegOfLoad( |
1082 | MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) { |
1083 | assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG); |
1084 | Register LoadReg; |
1085 | unsigned ScalarSizeBits; |
1086 | std::tie(args&: LoadReg, args&: ScalarSizeBits) = MatchInfo; |
1087 | GLoad *LoadDef = cast<GLoad>(Val: MRI.getVRegDef(Reg: LoadReg)); |
1088 | |
1089 | // If we have the following: |
1090 | // %ld = G_LOAD %ptr, (load 2) |
1091 | // %ext = G_SEXT_INREG %ld, 8 |
1092 | // ==> |
1093 | // %ld = G_SEXTLOAD %ptr (load 1) |
1094 | |
1095 | auto &MMO = LoadDef->getMMO(); |
1096 | Builder.setInstrAndDebugLoc(*LoadDef); |
1097 | auto &MF = Builder.getMF(); |
1098 | auto PtrInfo = MMO.getPointerInfo(); |
1099 | auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Size: ScalarSizeBits / 8); |
1100 | Builder.buildLoadInstr(Opcode: TargetOpcode::G_SEXTLOAD, Res: MI.getOperand(i: 0).getReg(), |
1101 | Addr: LoadDef->getPointerReg(), MMO&: *NewMMO); |
1102 | MI.eraseFromParent(); |
1103 | } |
1104 | |
/// Return true if 'MI' is a load or a store that may fold its address
/// operand into the load / store addressing mode.
1107 | static bool canFoldInAddressingMode(GLoadStore *MI, const TargetLowering &TLI, |
1108 | MachineRegisterInfo &MRI) { |
1109 | TargetLowering::AddrMode AM; |
1110 | auto *MF = MI->getMF(); |
1111 | auto *Addr = getOpcodeDef<GPtrAdd>(Reg: MI->getPointerReg(), MRI); |
1112 | if (!Addr) |
1113 | return false; |
1114 | |
1115 | AM.HasBaseReg = true; |
1116 | if (auto CstOff = getIConstantVRegVal(VReg: Addr->getOffsetReg(), MRI)) |
1117 | AM.BaseOffs = CstOff->getSExtValue(); // [reg +/- imm] |
1118 | else |
1119 | AM.Scale = 1; // [reg +/- reg] |
1120 | |
1121 | return TLI.isLegalAddressingMode( |
1122 | DL: MF->getDataLayout(), AM, |
1123 | Ty: getTypeForLLT(Ty: MI->getMMO().getMemoryType(), |
1124 | C&: MF->getFunction().getContext()), |
1125 | AddrSpace: MI->getMMO().getAddrSpace()); |
1126 | } |
1127 | |
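/// Map a load/store opcode to the corresponding indexed variant.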
1128 | static unsigned getIndexedOpc(unsigned LdStOpc) { |
1129 | switch (LdStOpc) { |
1130 | case TargetOpcode::G_LOAD: |
1131 | return TargetOpcode::G_INDEXED_LOAD; |
1132 | case TargetOpcode::G_STORE: |
1133 | return TargetOpcode::G_INDEXED_STORE; |
1134 | case TargetOpcode::G_ZEXTLOAD: |
1135 | return TargetOpcode::G_INDEXED_ZEXTLOAD; |
1136 | case TargetOpcode::G_SEXTLOAD: |
1137 | return TargetOpcode::G_INDEXED_SEXTLOAD; |
1138 | default: |
1139 | llvm_unreachable("Unexpected opcode" ); |
1140 | } |
1141 | } |
1142 | |
1143 | bool CombinerHelper::isIndexedLoadStoreLegal(GLoadStore &LdSt) const { |
1144 | // Check for legality. |
1145 | LLT PtrTy = MRI.getType(Reg: LdSt.getPointerReg()); |
1146 | LLT Ty = MRI.getType(Reg: LdSt.getReg(Idx: 0)); |
1147 | LLT MemTy = LdSt.getMMO().getMemoryType(); |
1148 | SmallVector<LegalityQuery::MemDesc, 2> MemDescrs( |
1149 | {{MemTy, MemTy.getSizeInBits().getKnownMinValue(), |
1150 | AtomicOrdering::NotAtomic}}); |
1151 | unsigned IndexedOpc = getIndexedOpc(LdStOpc: LdSt.getOpcode()); |
1152 | SmallVector<LLT> OpTys; |
1153 | if (IndexedOpc == TargetOpcode::G_INDEXED_STORE) |
1154 | OpTys = {PtrTy, Ty, Ty}; |
1155 | else |
1156 | OpTys = {Ty, PtrTy}; // For G_INDEXED_LOAD, G_INDEXED_[SZ]EXTLOAD |
1157 | |
1158 | LegalityQuery Q(IndexedOpc, OpTys, MemDescrs); |
1159 | return isLegal(Query: Q); |
1160 | } |
1161 | |
1162 | static cl::opt<unsigned> PostIndexUseThreshold( |
1163 | "post-index-use-threshold" , cl::Hidden, cl::init(Val: 32), |
1164 | cl::desc("Number of uses of a base pointer to check before it is no longer " |
1165 | "considered for post-indexing." )); |
1166 | |
1167 | bool CombinerHelper::findPostIndexCandidate(GLoadStore &LdSt, Register &Addr, |
1168 | Register &Base, Register &Offset, |
1169 | bool &RematOffset) { |
1170 | // We're looking for the following pattern, for either load or store: |
1171 | // %baseptr:_(p0) = ... |
1172 | // G_STORE %val(s64), %baseptr(p0) |
1173 | // %offset:_(s64) = G_CONSTANT i64 -256 |
1174 | // %new_addr:_(p0) = G_PTR_ADD %baseptr, %offset(s64) |
1175 | const auto &TLI = getTargetLowering(); |
1176 | |
1177 | Register Ptr = LdSt.getPointerReg(); |
1178 | // If the store is the only use, don't bother. |
1179 | if (MRI.hasOneNonDBGUse(RegNo: Ptr)) |
1180 | return false; |
1181 | |
1182 | if (!isIndexedLoadStoreLegal(LdSt)) |
1183 | return false; |
1184 | |
1185 | if (getOpcodeDef(Opcode: TargetOpcode::G_FRAME_INDEX, Reg: Ptr, MRI)) |
1186 | return false; |
1187 | |
1188 | MachineInstr *StoredValDef = getDefIgnoringCopies(Reg: LdSt.getReg(Idx: 0), MRI); |
1189 | auto *PtrDef = MRI.getVRegDef(Reg: Ptr); |
1190 | |
1191 | unsigned NumUsesChecked = 0; |
1192 | for (auto &Use : MRI.use_nodbg_instructions(Reg: Ptr)) { |
1193 | if (++NumUsesChecked > PostIndexUseThreshold) |
1194 | return false; // Try to avoid exploding compile time. |
1195 | |
1196 | auto *PtrAdd = dyn_cast<GPtrAdd>(Val: &Use); |
1197 | // The use itself might be dead. This can happen during combines if DCE |
1198 | // hasn't had a chance to run yet. Don't allow it to form an indexed op. |
1199 | if (!PtrAdd || MRI.use_nodbg_empty(RegNo: PtrAdd->getReg(Idx: 0))) |
1200 | continue; |
1201 | |
    // Check that the user of this isn't the store, otherwise we'd be
    // generating an indexed store defining its own use.
1204 | if (StoredValDef == &Use) |
1205 | continue; |
1206 | |
1207 | Offset = PtrAdd->getOffsetReg(); |
1208 | if (!ForceLegalIndexing && |
1209 | !TLI.isIndexingLegal(MI&: LdSt, Base: PtrAdd->getBaseReg(), Offset, |
1210 | /*IsPre*/ false, MRI)) |
1211 | continue; |
1212 | |
1213 | // Make sure the offset calculation is before the potentially indexed op. |
1214 | MachineInstr *OffsetDef = MRI.getVRegDef(Reg: Offset); |
1215 | RematOffset = false; |
1216 | if (!dominates(DefMI: *OffsetDef, UseMI: LdSt)) { |
1217 | // If the offset however is just a G_CONSTANT, we can always just |
1218 | // rematerialize it where we need it. |
1219 | if (OffsetDef->getOpcode() != TargetOpcode::G_CONSTANT) |
1220 | continue; |
1221 | RematOffset = true; |
1222 | } |
1223 | |
1224 | for (auto &BasePtrUse : MRI.use_nodbg_instructions(Reg: PtrAdd->getBaseReg())) { |
1225 | if (&BasePtrUse == PtrDef) |
1226 | continue; |
1227 | |
1228 | // If the user is a later load/store that can be post-indexed, then don't |
1229 | // combine this one. |
1230 | auto *BasePtrLdSt = dyn_cast<GLoadStore>(Val: &BasePtrUse); |
1231 | if (BasePtrLdSt && BasePtrLdSt != &LdSt && |
1232 | dominates(DefMI: LdSt, UseMI: *BasePtrLdSt) && |
1233 | isIndexedLoadStoreLegal(LdSt&: *BasePtrLdSt)) |
1234 | return false; |
1235 | |
1236 | // Now we're looking for the key G_PTR_ADD instruction, which contains |
1237 | // the offset add that we want to fold. |
1238 | if (auto *BasePtrUseDef = dyn_cast<GPtrAdd>(Val: &BasePtrUse)) { |
1239 | Register PtrAddDefReg = BasePtrUseDef->getReg(Idx: 0); |
1240 | for (auto &BaseUseUse : MRI.use_nodbg_instructions(Reg: PtrAddDefReg)) { |
1241 | // If the use is in a different block, then we may produce worse code |
1242 | // due to the extra register pressure. |
1243 | if (BaseUseUse.getParent() != LdSt.getParent()) |
1244 | return false; |
1245 | |
1246 | if (auto *UseUseLdSt = dyn_cast<GLoadStore>(Val: &BaseUseUse)) |
1247 | if (canFoldInAddressingMode(MI: UseUseLdSt, TLI, MRI)) |
1248 | return false; |
1249 | } |
1250 | if (!dominates(DefMI: LdSt, UseMI: BasePtrUse)) |
          return false; // All uses must be dominated by the load/store.
1252 | } |
1253 | } |
1254 | |
1255 | Addr = PtrAdd->getReg(Idx: 0); |
1256 | Base = PtrAdd->getBaseReg(); |
1257 | return true; |
1258 | } |
1259 | |
1260 | return false; |
1261 | } |
1262 | |
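// Look for a pre-indexed candidate: the access's pointer is itself a G_PTR_ADD
// whose result has other users, so folding the address computation into the
// access and producing the updated pointer as a side result may be profitable.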
1263 | bool CombinerHelper::findPreIndexCandidate(GLoadStore &LdSt, Register &Addr, |
1264 | Register &Base, Register &Offset) { |
1265 | auto &MF = *LdSt.getParent()->getParent(); |
1266 | const auto &TLI = *MF.getSubtarget().getTargetLowering(); |
1267 | |
1268 | Addr = LdSt.getPointerReg(); |
1269 | if (!mi_match(R: Addr, MRI, P: m_GPtrAdd(L: m_Reg(R&: Base), R: m_Reg(R&: Offset))) || |
1270 | MRI.hasOneNonDBGUse(RegNo: Addr)) |
1271 | return false; |
1272 | |
1273 | if (!ForceLegalIndexing && |
1274 | !TLI.isIndexingLegal(MI&: LdSt, Base, Offset, /*IsPre*/ true, MRI)) |
1275 | return false; |
1276 | |
1277 | if (!isIndexedLoadStoreLegal(LdSt)) |
1278 | return false; |
1279 | |
1280 | MachineInstr *BaseDef = getDefIgnoringCopies(Reg: Base, MRI); |
1281 | if (BaseDef->getOpcode() == TargetOpcode::G_FRAME_INDEX) |
1282 | return false; |
1283 | |
1284 | if (auto *St = dyn_cast<GStore>(Val: &LdSt)) { |
1285 | // Would require a copy. |
1286 | if (Base == St->getValueReg()) |
1287 | return false; |
1288 | |
1289 | // We're expecting one use of Addr in MI, but it could also be the |
1290 | // value stored, which isn't actually dominated by the instruction. |
1291 | if (St->getValueReg() == Addr) |
1292 | return false; |
1293 | } |
1294 | |
1295 | // Avoid increasing cross-block register pressure. |
1296 | for (auto &AddrUse : MRI.use_nodbg_instructions(Reg: Addr)) |
1297 | if (AddrUse.getParent() != LdSt.getParent()) |
1298 | return false; |
1299 | |
1300 | // FIXME: check whether all uses of the base pointer are constant PtrAdds. |
1301 | // That might allow us to end base's liveness here by adjusting the constant. |
1302 | bool RealUse = false; |
1303 | for (auto &AddrUse : MRI.use_nodbg_instructions(Reg: Addr)) { |
1304 | if (!dominates(DefMI: LdSt, UseMI: AddrUse)) |
      return false; // All uses must be dominated by the load/store.
1306 | |
1307 | // If Ptr may be folded in addressing mode of other use, then it's |
1308 | // not profitable to do this transformation. |
1309 | if (auto *UseLdSt = dyn_cast<GLoadStore>(Val: &AddrUse)) { |
1310 | if (!canFoldInAddressingMode(MI: UseLdSt, TLI, MRI)) |
1311 | RealUse = true; |
1312 | } else { |
1313 | RealUse = true; |
1314 | } |
1315 | } |
1316 | return RealUse; |
1317 | } |
1318 | |
bool CombinerHelper::matchCombineExtractedVectorLoad(MachineInstr &MI,
                                                     BuildFnTy &MatchInfo) {
1321 | assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT); |
1322 | |
1323 | // Check if there is a load that defines the vector being extracted from. |
1324 | auto *LoadMI = getOpcodeDef<GLoad>(Reg: MI.getOperand(i: 1).getReg(), MRI); |
1325 | if (!LoadMI) |
1326 | return false; |
1327 | |
1328 | Register Vector = MI.getOperand(i: 1).getReg(); |
1329 | LLT VecEltTy = MRI.getType(Reg: Vector).getElementType(); |
1330 | |
1331 | assert(MRI.getType(MI.getOperand(0).getReg()) == VecEltTy); |
1332 | |
1333 | // Checking whether we should reduce the load width. |
1334 | if (!MRI.hasOneNonDBGUse(RegNo: Vector)) |
1335 | return false; |
1336 | |
1337 | // Check if the defining load is simple. |
1338 | if (!LoadMI->isSimple()) |
1339 | return false; |
1340 | |
1341 | // If the vector element type is not a multiple of a byte then we are unable |
1342 | // to correctly compute an address to load only the extracted element as a |
1343 | // scalar. |
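  // For illustration (a hedged sketch; the vector type is made up): extracting
  // element 3 from a <8 x s1> vector would need a 3-bit offset from the base
  // pointer, which a byte-addressed scalar G_LOAD cannot express.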
1344 | if (!VecEltTy.isByteSized()) |
1345 | return false; |
1346 | |
1347 | // Check for load fold barriers between the extraction and the load. |
1348 | if (MI.getParent() != LoadMI->getParent()) |
1349 | return false; |
1350 | const unsigned MaxIter = 20; |
1351 | unsigned Iter = 0; |
1352 | for (auto II = LoadMI->getIterator(), IE = MI.getIterator(); II != IE; ++II) { |
1353 | if (II->isLoadFoldBarrier()) |
1354 | return false; |
1355 | if (Iter++ == MaxIter) |
1356 | return false; |
1357 | } |
1358 | |
1359 | // Check if the new load that we are going to create is legal |
1360 | // if we are in the post-legalization phase. |
1361 | MachineMemOperand MMO = LoadMI->getMMO(); |
1362 | Align Alignment = MMO.getAlign(); |
1363 | MachinePointerInfo PtrInfo; |
1364 | uint64_t Offset; |
1365 | |
  // Find the appropriate PtrInfo if the offset is a known constant. This is
  // required to create the memory operand for the narrowed load. This machine
  // memory operand object helps us infer the legality of the narrowed load
  // before we proceed to combine the instruction.
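  // Worked example (illustrative only): for a loaded <4 x s32> vector and a
  // constant element index of 2, the offset computed below is
  // 32 * 2 / 8 = 8 bytes, so the narrowed load reads 4 bytes at base + 8.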
1370 | if (auto CVal = getIConstantVRegVal(VReg: Vector, MRI)) { |
1371 | int Elt = CVal->getZExtValue(); |
1372 | // FIXME: should be (ABI size)*Elt. |
1373 | Offset = VecEltTy.getSizeInBits() * Elt / 8; |
1374 | PtrInfo = MMO.getPointerInfo().getWithOffset(O: Offset); |
1375 | } else { |
1376 | // Discard the pointer info except the address space because the memory |
1377 | // operand can't represent this new access since the offset is variable. |
1378 | Offset = VecEltTy.getSizeInBits() / 8; |
1379 | PtrInfo = MachinePointerInfo(MMO.getPointerInfo().getAddrSpace()); |
1380 | } |
1381 | |
1382 | Alignment = commonAlignment(A: Alignment, Offset); |
1383 | |
1384 | Register VecPtr = LoadMI->getPointerReg(); |
1385 | LLT PtrTy = MRI.getType(Reg: VecPtr); |
1386 | |
1387 | MachineFunction &MF = *MI.getMF(); |
1388 | auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Ty: VecEltTy); |
1389 | |
1390 | LegalityQuery::MemDesc MMDesc(*NewMMO); |
1391 | |
1392 | LegalityQuery Q = {TargetOpcode::G_LOAD, {VecEltTy, PtrTy}, {MMDesc}}; |
1393 | |
1394 | if (!isLegalOrBeforeLegalizer(Query: Q)) |
1395 | return false; |
1396 | |
1397 | // Load must be allowed and fast on the target. |
1398 | LLVMContext &C = MF.getFunction().getContext(); |
1399 | auto &DL = MF.getDataLayout(); |
1400 | unsigned Fast = 0; |
1401 | if (!getTargetLowering().allowsMemoryAccess(Context&: C, DL, Ty: VecEltTy, MMO: *NewMMO, |
1402 | Fast: &Fast) || |
1403 | !Fast) |
1404 | return false; |
1405 | |
1406 | Register Result = MI.getOperand(i: 0).getReg(); |
1407 | Register Index = MI.getOperand(i: 2).getReg(); |
1408 | |
1409 | MatchInfo = [=](MachineIRBuilder &B) { |
1410 | GISelObserverWrapper DummyObserver; |
1411 | LegalizerHelper Helper(B.getMF(), DummyObserver, B); |
    // Get pointer to the vector element.
1413 | Register finalPtr = Helper.getVectorElementPointer( |
1414 | VecPtr: LoadMI->getPointerReg(), VecTy: MRI.getType(Reg: LoadMI->getOperand(i: 0).getReg()), |
1415 | Index); |
1416 | // New G_LOAD instruction. |
1417 | B.buildLoad(Res: Result, Addr: finalPtr, PtrInfo, Alignment); |
    // Remove the original G_LOAD instruction.
1419 | LoadMI->eraseFromParent(); |
1420 | }; |
1421 | |
1422 | return true; |
1423 | } |
1424 | |
1425 | bool CombinerHelper::matchCombineIndexedLoadStore( |
1426 | MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) { |
1427 | auto &LdSt = cast<GLoadStore>(Val&: MI); |
1428 | |
1429 | if (LdSt.isAtomic()) |
1430 | return false; |
1431 | |
1432 | MatchInfo.IsPre = findPreIndexCandidate(LdSt, Addr&: MatchInfo.Addr, Base&: MatchInfo.Base, |
1433 | Offset&: MatchInfo.Offset); |
1434 | if (!MatchInfo.IsPre && |
1435 | !findPostIndexCandidate(LdSt, Addr&: MatchInfo.Addr, Base&: MatchInfo.Base, |
1436 | Offset&: MatchInfo.Offset, RematOffset&: MatchInfo.RematOffset)) |
1437 | return false; |
1438 | |
1439 | return true; |
1440 | } |
1441 | |
1442 | void CombinerHelper::applyCombineIndexedLoadStore( |
1443 | MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) { |
1444 | MachineInstr &AddrDef = *MRI.getUniqueVRegDef(Reg: MatchInfo.Addr); |
1445 | unsigned Opcode = MI.getOpcode(); |
1446 | bool IsStore = Opcode == TargetOpcode::G_STORE; |
1447 | unsigned NewOpcode = getIndexedOpc(LdStOpc: Opcode); |
1448 | |
1449 | // If the offset constant didn't happen to dominate the load/store, we can |
1450 | // just clone it as needed. |
1451 | if (MatchInfo.RematOffset) { |
1452 | auto *OldCst = MRI.getVRegDef(Reg: MatchInfo.Offset); |
1453 | auto NewCst = Builder.buildConstant(Res: MRI.getType(Reg: MatchInfo.Offset), |
1454 | Val: *OldCst->getOperand(i: 1).getCImm()); |
1455 | MatchInfo.Offset = NewCst.getReg(Idx: 0); |
1456 | } |
1457 | |
1458 | auto MIB = Builder.buildInstr(Opcode: NewOpcode); |
1459 | if (IsStore) { |
1460 | MIB.addDef(RegNo: MatchInfo.Addr); |
1461 | MIB.addUse(RegNo: MI.getOperand(i: 0).getReg()); |
1462 | } else { |
1463 | MIB.addDef(RegNo: MI.getOperand(i: 0).getReg()); |
1464 | MIB.addDef(RegNo: MatchInfo.Addr); |
1465 | } |
1466 | |
1467 | MIB.addUse(RegNo: MatchInfo.Base); |
1468 | MIB.addUse(RegNo: MatchInfo.Offset); |
1469 | MIB.addImm(Val: MatchInfo.IsPre); |
1470 | MIB->cloneMemRefs(MF&: *MI.getMF(), MI); |
1471 | MI.eraseFromParent(); |
1472 | AddrDef.eraseFromParent(); |
1473 | |
  LLVM_DEBUG(dbgs() << " Combined to indexed operation" );
1475 | } |
1476 | |
1477 | bool CombinerHelper::matchCombineDivRem(MachineInstr &MI, |
1478 | MachineInstr *&OtherMI) { |
1479 | unsigned Opcode = MI.getOpcode(); |
1480 | bool IsDiv, IsSigned; |
1481 | |
1482 | switch (Opcode) { |
1483 | default: |
1484 | llvm_unreachable("Unexpected opcode!" ); |
1485 | case TargetOpcode::G_SDIV: |
1486 | case TargetOpcode::G_UDIV: { |
1487 | IsDiv = true; |
1488 | IsSigned = Opcode == TargetOpcode::G_SDIV; |
1489 | break; |
1490 | } |
1491 | case TargetOpcode::G_SREM: |
1492 | case TargetOpcode::G_UREM: { |
1493 | IsDiv = false; |
1494 | IsSigned = Opcode == TargetOpcode::G_SREM; |
1495 | break; |
1496 | } |
1497 | } |
1498 | |
1499 | Register Src1 = MI.getOperand(i: 1).getReg(); |
1500 | unsigned DivOpcode, RemOpcode, DivremOpcode; |
1501 | if (IsSigned) { |
1502 | DivOpcode = TargetOpcode::G_SDIV; |
1503 | RemOpcode = TargetOpcode::G_SREM; |
1504 | DivremOpcode = TargetOpcode::G_SDIVREM; |
1505 | } else { |
1506 | DivOpcode = TargetOpcode::G_UDIV; |
1507 | RemOpcode = TargetOpcode::G_UREM; |
1508 | DivremOpcode = TargetOpcode::G_UDIVREM; |
1509 | } |
1510 | |
1511 | if (!isLegalOrBeforeLegalizer(Query: {DivremOpcode, {MRI.getType(Reg: Src1)}})) |
1512 | return false; |
1513 | |
1514 | // Combine: |
1515 | // %div:_ = G_[SU]DIV %src1:_, %src2:_ |
1516 | // %rem:_ = G_[SU]REM %src1:_, %src2:_ |
1517 | // into: |
1518 | // %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_ |
1519 | |
1520 | // Combine: |
1521 | // %rem:_ = G_[SU]REM %src1:_, %src2:_ |
1522 | // %div:_ = G_[SU]DIV %src1:_, %src2:_ |
1523 | // into: |
1524 | // %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_ |
1525 | |
1526 | for (auto &UseMI : MRI.use_nodbg_instructions(Reg: Src1)) { |
1527 | if (MI.getParent() == UseMI.getParent() && |
1528 | ((IsDiv && UseMI.getOpcode() == RemOpcode) || |
1529 | (!IsDiv && UseMI.getOpcode() == DivOpcode)) && |
1530 | matchEqualDefs(MOP1: MI.getOperand(i: 2), MOP2: UseMI.getOperand(i: 2)) && |
1531 | matchEqualDefs(MOP1: MI.getOperand(i: 1), MOP2: UseMI.getOperand(i: 1))) { |
1532 | OtherMI = &UseMI; |
1533 | return true; |
1534 | } |
1535 | } |
1536 | |
1537 | return false; |
1538 | } |
1539 | |
1540 | void CombinerHelper::applyCombineDivRem(MachineInstr &MI, |
1541 | MachineInstr *&OtherMI) { |
1542 | unsigned Opcode = MI.getOpcode(); |
1543 | assert(OtherMI && "OtherMI shouldn't be empty." ); |
1544 | |
1545 | Register DestDivReg, DestRemReg; |
1546 | if (Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_UDIV) { |
1547 | DestDivReg = MI.getOperand(i: 0).getReg(); |
1548 | DestRemReg = OtherMI->getOperand(i: 0).getReg(); |
1549 | } else { |
1550 | DestDivReg = OtherMI->getOperand(i: 0).getReg(); |
1551 | DestRemReg = MI.getOperand(i: 0).getReg(); |
1552 | } |
1553 | |
1554 | bool IsSigned = |
1555 | Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_SREM; |
1556 | |
  // Check which instruction is first in the block so we don't break def-use
  // deps by "moving" the instruction incorrectly. Also keep track of which
  // instruction is first so we pick its operands, avoiding use-before-def
  // bugs.
1561 | MachineInstr *FirstInst = dominates(DefMI: MI, UseMI: *OtherMI) ? &MI : OtherMI; |
1562 | Builder.setInstrAndDebugLoc(*FirstInst); |
1563 | |
1564 | Builder.buildInstr(Opc: IsSigned ? TargetOpcode::G_SDIVREM |
1565 | : TargetOpcode::G_UDIVREM, |
1566 | DstOps: {DestDivReg, DestRemReg}, |
                     SrcOps: {FirstInst->getOperand(i: 1), FirstInst->getOperand(i: 2)});
1568 | MI.eraseFromParent(); |
1569 | OtherMI->eraseFromParent(); |
1570 | } |
1571 | |
1572 | bool CombinerHelper::matchOptBrCondByInvertingCond(MachineInstr &MI, |
1573 | MachineInstr *&BrCond) { |
1574 | assert(MI.getOpcode() == TargetOpcode::G_BR); |
1575 | |
1576 | // Try to match the following: |
1577 | // bb1: |
1578 | // G_BRCOND %c1, %bb2 |
1579 | // G_BR %bb3 |
1580 | // bb2: |
1581 | // ... |
1582 | // bb3: |
1583 | |
  // The above pattern does not have a fallthrough to the successor bb2, always
  // resulting in a branch no matter which path is taken. Here we try to find
  // and replace that pattern with a conditional branch to bb3 and otherwise a
  // fallthrough to bb2. This is generally better for branch predictors.
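  // After the combine the block looks roughly like this (sketch only; the
  // inverted condition is materialized as an XOR with the target's true
  // value):
  //   bb1:
  //     %inv:_(s1) = G_XOR %c1, true
  //     G_BRCOND %inv, %bb3
  //     G_BR %bb2        ; now targets the layout successor
  //   bb2:
  //   ...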
1588 | |
1589 | MachineBasicBlock *MBB = MI.getParent(); |
1590 | MachineBasicBlock::iterator BrIt(MI); |
1591 | if (BrIt == MBB->begin()) |
1592 | return false; |
1593 | assert(std::next(BrIt) == MBB->end() && "expected G_BR to be a terminator" ); |
1594 | |
1595 | BrCond = &*std::prev(x: BrIt); |
1596 | if (BrCond->getOpcode() != TargetOpcode::G_BRCOND) |
1597 | return false; |
1598 | |
1599 | // Check that the next block is the conditional branch target. Also make sure |
1600 | // that it isn't the same as the G_BR's target (otherwise, this will loop.) |
1601 | MachineBasicBlock *BrCondTarget = BrCond->getOperand(i: 1).getMBB(); |
1602 | return BrCondTarget != MI.getOperand(i: 0).getMBB() && |
1603 | MBB->isLayoutSuccessor(MBB: BrCondTarget); |
1604 | } |
1605 | |
1606 | void CombinerHelper::applyOptBrCondByInvertingCond(MachineInstr &MI, |
1607 | MachineInstr *&BrCond) { |
1608 | MachineBasicBlock *BrTarget = MI.getOperand(i: 0).getMBB(); |
1609 | Builder.setInstrAndDebugLoc(*BrCond); |
1610 | LLT Ty = MRI.getType(Reg: BrCond->getOperand(i: 0).getReg()); |
1611 | // FIXME: Does int/fp matter for this? If so, we might need to restrict |
1612 | // this to i1 only since we might not know for sure what kind of |
1613 | // compare generated the condition value. |
1614 | auto True = Builder.buildConstant( |
1615 | Res: Ty, Val: getICmpTrueVal(TLI: getTargetLowering(), IsVector: false, IsFP: false)); |
1616 | auto Xor = Builder.buildXor(Dst: Ty, Src0: BrCond->getOperand(i: 0), Src1: True); |
1617 | |
1618 | auto *FallthroughBB = BrCond->getOperand(i: 1).getMBB(); |
1619 | Observer.changingInstr(MI); |
1620 | MI.getOperand(i: 0).setMBB(FallthroughBB); |
1621 | Observer.changedInstr(MI); |
1622 | |
1623 | // Change the conditional branch to use the inverted condition and |
1624 | // new target block. |
1625 | Observer.changingInstr(MI&: *BrCond); |
1626 | BrCond->getOperand(i: 0).setReg(Xor.getReg(Idx: 0)); |
1627 | BrCond->getOperand(i: 1).setMBB(BrTarget); |
1628 | Observer.changedInstr(MI&: *BrCond); |
1629 | } |
1630 | |
1631 | |
1632 | bool CombinerHelper::tryEmitMemcpyInline(MachineInstr &MI) { |
1633 | MachineIRBuilder HelperBuilder(MI); |
1634 | GISelObserverWrapper DummyObserver; |
1635 | LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder); |
1636 | return Helper.lowerMemcpyInline(MI) == |
1637 | LegalizerHelper::LegalizeResult::Legalized; |
1638 | } |
1639 | |
1640 | bool CombinerHelper::tryCombineMemCpyFamily(MachineInstr &MI, unsigned MaxLen) { |
1641 | MachineIRBuilder HelperBuilder(MI); |
1642 | GISelObserverWrapper DummyObserver; |
1643 | LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder); |
1644 | return Helper.lowerMemCpyFamily(MI, MaxLen) == |
1645 | LegalizerHelper::LegalizeResult::Legalized; |
1646 | } |
1647 | |
1648 | static APFloat constantFoldFpUnary(const MachineInstr &MI, |
1649 | const MachineRegisterInfo &MRI, |
1650 | const APFloat &Val) { |
1651 | APFloat Result(Val); |
1652 | switch (MI.getOpcode()) { |
1653 | default: |
1654 | llvm_unreachable("Unexpected opcode!" ); |
1655 | case TargetOpcode::G_FNEG: { |
1656 | Result.changeSign(); |
1657 | return Result; |
1658 | } |
1659 | case TargetOpcode::G_FABS: { |
1660 | Result.clearSign(); |
1661 | return Result; |
1662 | } |
1663 | case TargetOpcode::G_FPTRUNC: { |
1664 | bool Unused; |
1665 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
1666 | Result.convert(ToSemantics: getFltSemanticForLLT(Ty: DstTy), RM: APFloat::rmNearestTiesToEven, |
1667 | losesInfo: &Unused); |
1668 | return Result; |
1669 | } |
1670 | case TargetOpcode::G_FSQRT: { |
1671 | bool Unused; |
1672 | Result.convert(ToSemantics: APFloat::IEEEdouble(), RM: APFloat::rmNearestTiesToEven, |
1673 | losesInfo: &Unused); |
1674 | Result = APFloat(sqrt(x: Result.convertToDouble())); |
1675 | break; |
1676 | } |
1677 | case TargetOpcode::G_FLOG2: { |
1678 | bool Unused; |
1679 | Result.convert(ToSemantics: APFloat::IEEEdouble(), RM: APFloat::rmNearestTiesToEven, |
1680 | losesInfo: &Unused); |
1681 | Result = APFloat(log2(x: Result.convertToDouble())); |
1682 | break; |
1683 | } |
1684 | } |
  // Convert `APFloat` back to the original semantics (matching `DstTy`).
  // Otherwise, `buildFConstant` will assert on a size mismatch. Only `G_FSQRT`
  // and `G_FLOG2` reach here.
1688 | bool Unused; |
1689 | Result.convert(ToSemantics: Val.getSemantics(), RM: APFloat::rmNearestTiesToEven, losesInfo: &Unused); |
1690 | return Result; |
1691 | } |
1692 | |
1693 | void CombinerHelper::applyCombineConstantFoldFpUnary(MachineInstr &MI, |
1694 | const ConstantFP *Cst) { |
1695 | APFloat Folded = constantFoldFpUnary(MI, MRI, Val: Cst->getValue()); |
1696 | const ConstantFP *NewCst = ConstantFP::get(Context&: Builder.getContext(), V: Folded); |
1697 | Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: *NewCst); |
1698 | MI.eraseFromParent(); |
1699 | } |
1700 | |
1701 | bool CombinerHelper::matchPtrAddImmedChain(MachineInstr &MI, |
1702 | PtrAddChain &MatchInfo) { |
1703 | // We're trying to match the following pattern: |
1704 | // %t1 = G_PTR_ADD %base, G_CONSTANT imm1 |
1705 | // %root = G_PTR_ADD %t1, G_CONSTANT imm2 |
1706 | // --> |
1707 | // %root = G_PTR_ADD %base, G_CONSTANT (imm1 + imm2) |
1708 | |
1709 | if (MI.getOpcode() != TargetOpcode::G_PTR_ADD) |
1710 | return false; |
1711 | |
1712 | Register Add2 = MI.getOperand(i: 1).getReg(); |
1713 | Register Imm1 = MI.getOperand(i: 2).getReg(); |
1714 | auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: Imm1, MRI); |
1715 | if (!MaybeImmVal) |
1716 | return false; |
1717 | |
1718 | MachineInstr *Add2Def = MRI.getVRegDef(Reg: Add2); |
1719 | if (!Add2Def || Add2Def->getOpcode() != TargetOpcode::G_PTR_ADD) |
1720 | return false; |
1721 | |
1722 | Register Base = Add2Def->getOperand(i: 1).getReg(); |
1723 | Register Imm2 = Add2Def->getOperand(i: 2).getReg(); |
1724 | auto MaybeImm2Val = getIConstantVRegValWithLookThrough(VReg: Imm2, MRI); |
1725 | if (!MaybeImm2Val) |
1726 | return false; |
1727 | |
1728 | // Check if the new combined immediate forms an illegal addressing mode. |
1729 | // Do not combine if it was legal before but would get illegal. |
1730 | // To do so, we need to find a load/store user of the pointer to get |
1731 | // the access type. |
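  // For example (hypothetical target): with a 12-bit signed offset field
  // covering [-2048, 2047], an existing offset of 2000 is legal, but a
  // combined offset of 2000 + 100 = 2100 is not, so we keep the chain as-is.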
1732 | Type *AccessTy = nullptr; |
1733 | auto &MF = *MI.getMF(); |
1734 | for (auto &UseMI : MRI.use_nodbg_instructions(Reg: MI.getOperand(i: 0).getReg())) { |
1735 | if (auto *LdSt = dyn_cast<GLoadStore>(Val: &UseMI)) { |
1736 | AccessTy = getTypeForLLT(Ty: MRI.getType(Reg: LdSt->getReg(Idx: 0)), |
1737 | C&: MF.getFunction().getContext()); |
1738 | break; |
1739 | } |
1740 | } |
1741 | TargetLoweringBase::AddrMode AMNew; |
1742 | APInt CombinedImm = MaybeImmVal->Value + MaybeImm2Val->Value; |
1743 | AMNew.BaseOffs = CombinedImm.getSExtValue(); |
1744 | if (AccessTy) { |
1745 | AMNew.HasBaseReg = true; |
1746 | TargetLoweringBase::AddrMode AMOld; |
1747 | AMOld.BaseOffs = MaybeImmVal->Value.getSExtValue(); |
1748 | AMOld.HasBaseReg = true; |
1749 | unsigned AS = MRI.getType(Reg: Add2).getAddressSpace(); |
1750 | const auto &TLI = *MF.getSubtarget().getTargetLowering(); |
1751 | if (TLI.isLegalAddressingMode(DL: MF.getDataLayout(), AM: AMOld, Ty: AccessTy, AddrSpace: AS) && |
1752 | !TLI.isLegalAddressingMode(DL: MF.getDataLayout(), AM: AMNew, Ty: AccessTy, AddrSpace: AS)) |
1753 | return false; |
1754 | } |
1755 | |
1756 | // Pass the combined immediate to the apply function. |
1757 | MatchInfo.Imm = AMNew.BaseOffs; |
1758 | MatchInfo.Base = Base; |
1759 | MatchInfo.Bank = getRegBank(Reg: Imm2); |
1760 | return true; |
1761 | } |
1762 | |
1763 | void CombinerHelper::applyPtrAddImmedChain(MachineInstr &MI, |
1764 | PtrAddChain &MatchInfo) { |
1765 | assert(MI.getOpcode() == TargetOpcode::G_PTR_ADD && "Expected G_PTR_ADD" ); |
1766 | MachineIRBuilder MIB(MI); |
1767 | LLT OffsetTy = MRI.getType(Reg: MI.getOperand(i: 2).getReg()); |
1768 | auto NewOffset = MIB.buildConstant(Res: OffsetTy, Val: MatchInfo.Imm); |
1769 | setRegBank(Reg: NewOffset.getReg(Idx: 0), RegBank: MatchInfo.Bank); |
1770 | Observer.changingInstr(MI); |
1771 | MI.getOperand(i: 1).setReg(MatchInfo.Base); |
1772 | MI.getOperand(i: 2).setReg(NewOffset.getReg(Idx: 0)); |
1773 | Observer.changedInstr(MI); |
1774 | } |
1775 | |
1776 | bool CombinerHelper::matchShiftImmedChain(MachineInstr &MI, |
1777 | RegisterImmPair &MatchInfo) { |
1778 | // We're trying to match the following pattern with any of |
1779 | // G_SHL/G_ASHR/G_LSHR/G_SSHLSAT/G_USHLSAT shift instructions: |
1780 | // %t1 = SHIFT %base, G_CONSTANT imm1 |
1781 | // %root = SHIFT %t1, G_CONSTANT imm2 |
1782 | // --> |
1783 | // %root = SHIFT %base, G_CONSTANT (imm1 + imm2) |
1784 | |
1785 | unsigned Opcode = MI.getOpcode(); |
1786 | assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR || |
1787 | Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT || |
1788 | Opcode == TargetOpcode::G_USHLSAT) && |
1789 | "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT" ); |
1790 | |
1791 | Register Shl2 = MI.getOperand(i: 1).getReg(); |
1792 | Register Imm1 = MI.getOperand(i: 2).getReg(); |
1793 | auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: Imm1, MRI); |
1794 | if (!MaybeImmVal) |
1795 | return false; |
1796 | |
1797 | MachineInstr *Shl2Def = MRI.getUniqueVRegDef(Reg: Shl2); |
1798 | if (Shl2Def->getOpcode() != Opcode) |
1799 | return false; |
1800 | |
1801 | Register Base = Shl2Def->getOperand(i: 1).getReg(); |
1802 | Register Imm2 = Shl2Def->getOperand(i: 2).getReg(); |
1803 | auto MaybeImm2Val = getIConstantVRegValWithLookThrough(VReg: Imm2, MRI); |
1804 | if (!MaybeImm2Val) |
1805 | return false; |
1806 | |
1807 | // Pass the combined immediate to the apply function. |
1808 | MatchInfo.Imm = |
1809 | (MaybeImmVal->Value.getZExtValue() + MaybeImm2Val->Value).getZExtValue(); |
1810 | MatchInfo.Reg = Base; |
1811 | |
1812 | // There is no simple replacement for a saturating unsigned left shift that |
1813 | // exceeds the scalar size. |
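  // E.g. (illustrative): for s8, folding a G_USHLSAT by 5 followed by a
  // G_USHLSAT by 4 into a single shift by 9 has no in-range shift amount, and
  // clamping the amount would not preserve the saturating result.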
1814 | if (Opcode == TargetOpcode::G_USHLSAT && |
1815 | MatchInfo.Imm >= MRI.getType(Reg: Shl2).getScalarSizeInBits()) |
1816 | return false; |
1817 | |
1818 | return true; |
1819 | } |
1820 | |
1821 | void CombinerHelper::applyShiftImmedChain(MachineInstr &MI, |
1822 | RegisterImmPair &MatchInfo) { |
1823 | unsigned Opcode = MI.getOpcode(); |
1824 | assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR || |
1825 | Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT || |
1826 | Opcode == TargetOpcode::G_USHLSAT) && |
1827 | "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT" ); |
1828 | |
1829 | LLT Ty = MRI.getType(Reg: MI.getOperand(i: 1).getReg()); |
1830 | unsigned const ScalarSizeInBits = Ty.getScalarSizeInBits(); |
1831 | auto Imm = MatchInfo.Imm; |
1832 | |
1833 | if (Imm >= ScalarSizeInBits) { |
1834 | // Any logical shift that exceeds scalar size will produce zero. |
1835 | if (Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_LSHR) { |
1836 | Builder.buildConstant(Res: MI.getOperand(i: 0), Val: 0); |
1837 | MI.eraseFromParent(); |
1838 | return; |
1839 | } |
1840 | // Arithmetic shift and saturating signed left shift have no effect beyond |
1841 | // scalar size. |
1842 | Imm = ScalarSizeInBits - 1; |
1843 | } |
1844 | |
1845 | LLT ImmTy = MRI.getType(Reg: MI.getOperand(i: 2).getReg()); |
1846 | Register NewImm = Builder.buildConstant(Res: ImmTy, Val: Imm).getReg(Idx: 0); |
1847 | Observer.changingInstr(MI); |
1848 | MI.getOperand(i: 1).setReg(MatchInfo.Reg); |
1849 | MI.getOperand(i: 2).setReg(NewImm); |
1850 | Observer.changedInstr(MI); |
1851 | } |
1852 | |
1853 | bool CombinerHelper::matchShiftOfShiftedLogic(MachineInstr &MI, |
1854 | ShiftOfShiftedLogic &MatchInfo) { |
1855 | // We're trying to match the following pattern with any of |
1856 | // G_SHL/G_ASHR/G_LSHR/G_USHLSAT/G_SSHLSAT shift instructions in combination |
1857 | // with any of G_AND/G_OR/G_XOR logic instructions. |
1858 | // %t1 = SHIFT %X, G_CONSTANT C0 |
1859 | // %t2 = LOGIC %t1, %Y |
1860 | // %root = SHIFT %t2, G_CONSTANT C1 |
1861 | // --> |
1862 | // %t3 = SHIFT %X, G_CONSTANT (C0+C1) |
1863 | // %t4 = SHIFT %Y, G_CONSTANT C1 |
1864 | // %root = LOGIC %t3, %t4 |
1865 | unsigned ShiftOpcode = MI.getOpcode(); |
1866 | assert((ShiftOpcode == TargetOpcode::G_SHL || |
1867 | ShiftOpcode == TargetOpcode::G_ASHR || |
1868 | ShiftOpcode == TargetOpcode::G_LSHR || |
1869 | ShiftOpcode == TargetOpcode::G_USHLSAT || |
1870 | ShiftOpcode == TargetOpcode::G_SSHLSAT) && |
1871 | "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT and G_SSHLSAT" ); |
1872 | |
1873 | // Match a one-use bitwise logic op. |
1874 | Register LogicDest = MI.getOperand(i: 1).getReg(); |
1875 | if (!MRI.hasOneNonDBGUse(RegNo: LogicDest)) |
1876 | return false; |
1877 | |
1878 | MachineInstr *LogicMI = MRI.getUniqueVRegDef(Reg: LogicDest); |
1879 | unsigned LogicOpcode = LogicMI->getOpcode(); |
1880 | if (LogicOpcode != TargetOpcode::G_AND && LogicOpcode != TargetOpcode::G_OR && |
1881 | LogicOpcode != TargetOpcode::G_XOR) |
1882 | return false; |
1883 | |
1884 | // Find a matching one-use shift by constant. |
1885 | const Register C1 = MI.getOperand(i: 2).getReg(); |
1886 | auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: C1, MRI); |
1887 | if (!MaybeImmVal || MaybeImmVal->Value == 0) |
1888 | return false; |
1889 | |
1890 | const uint64_t C1Val = MaybeImmVal->Value.getZExtValue(); |
1891 | |
1892 | auto matchFirstShift = [&](const MachineInstr *MI, uint64_t &ShiftVal) { |
    // Shift should match the previous one and should have only one use.
1894 | if (MI->getOpcode() != ShiftOpcode || |
1895 | !MRI.hasOneNonDBGUse(RegNo: MI->getOperand(i: 0).getReg())) |
1896 | return false; |
1897 | |
1898 | // Must be a constant. |
1899 | auto MaybeImmVal = |
1900 | getIConstantVRegValWithLookThrough(VReg: MI->getOperand(i: 2).getReg(), MRI); |
1901 | if (!MaybeImmVal) |
1902 | return false; |
1903 | |
1904 | ShiftVal = MaybeImmVal->Value.getSExtValue(); |
1905 | return true; |
1906 | }; |
1907 | |
1908 | // Logic ops are commutative, so check each operand for a match. |
1909 | Register LogicMIReg1 = LogicMI->getOperand(i: 1).getReg(); |
1910 | MachineInstr *LogicMIOp1 = MRI.getUniqueVRegDef(Reg: LogicMIReg1); |
1911 | Register LogicMIReg2 = LogicMI->getOperand(i: 2).getReg(); |
1912 | MachineInstr *LogicMIOp2 = MRI.getUniqueVRegDef(Reg: LogicMIReg2); |
1913 | uint64_t C0Val; |
1914 | |
1915 | if (matchFirstShift(LogicMIOp1, C0Val)) { |
1916 | MatchInfo.LogicNonShiftReg = LogicMIReg2; |
1917 | MatchInfo.Shift2 = LogicMIOp1; |
1918 | } else if (matchFirstShift(LogicMIOp2, C0Val)) { |
1919 | MatchInfo.LogicNonShiftReg = LogicMIReg1; |
1920 | MatchInfo.Shift2 = LogicMIOp2; |
1921 | } else |
1922 | return false; |
1923 | |
1924 | MatchInfo.ValSum = C0Val + C1Val; |
1925 | |
1926 | // The fold is not valid if the sum of the shift values exceeds bitwidth. |
1927 | if (MatchInfo.ValSum >= MRI.getType(Reg: LogicDest).getScalarSizeInBits()) |
1928 | return false; |
1929 | |
1930 | MatchInfo.Logic = LogicMI; |
1931 | return true; |
1932 | } |
1933 | |
1934 | void CombinerHelper::applyShiftOfShiftedLogic(MachineInstr &MI, |
1935 | ShiftOfShiftedLogic &MatchInfo) { |
1936 | unsigned Opcode = MI.getOpcode(); |
1937 | assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR || |
1938 | Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_USHLSAT || |
1939 | Opcode == TargetOpcode::G_SSHLSAT) && |
1940 | "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT and G_SSHLSAT" ); |
1941 | |
1942 | LLT ShlType = MRI.getType(Reg: MI.getOperand(i: 2).getReg()); |
1943 | LLT DestType = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
1944 | |
1945 | Register Const = Builder.buildConstant(Res: ShlType, Val: MatchInfo.ValSum).getReg(Idx: 0); |
1946 | |
1947 | Register Shift1Base = MatchInfo.Shift2->getOperand(i: 1).getReg(); |
1948 | Register Shift1 = |
1949 | Builder.buildInstr(Opc: Opcode, DstOps: {DestType}, SrcOps: {Shift1Base, Const}).getReg(Idx: 0); |
1950 | |
  // If LogicNonShiftReg is the same as Shift1Base, and the shift1 constant is
  // the same as the MatchInfo.Shift2 constant, CSEMIRBuilder will reuse the
  // old shift1 when building shift2. So if we erase MatchInfo.Shift2 at the
  // end, we would actually remove the old shift1 and crash later. Erase it
  // earlier to avoid the crash.
1956 | MatchInfo.Shift2->eraseFromParent(); |
1957 | |
1958 | Register Shift2Const = MI.getOperand(i: 2).getReg(); |
1959 | Register Shift2 = Builder |
1960 | .buildInstr(Opc: Opcode, DstOps: {DestType}, |
1961 | SrcOps: {MatchInfo.LogicNonShiftReg, Shift2Const}) |
1962 | .getReg(Idx: 0); |
1963 | |
1964 | Register Dest = MI.getOperand(i: 0).getReg(); |
1965 | Builder.buildInstr(Opc: MatchInfo.Logic->getOpcode(), DstOps: {Dest}, SrcOps: {Shift1, Shift2}); |
1966 | |
1967 | // This was one use so it's safe to remove it. |
1968 | MatchInfo.Logic->eraseFromParent(); |
1969 | |
1970 | MI.eraseFromParent(); |
1971 | } |
1972 | |
1973 | bool CombinerHelper::matchCommuteShift(MachineInstr &MI, BuildFnTy &MatchInfo) { |
1974 | assert(MI.getOpcode() == TargetOpcode::G_SHL && "Expected G_SHL" ); |
1975 | // Combine (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2) |
1976 | // Combine (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2) |
1977 | auto &Shl = cast<GenericMachineInstr>(Val&: MI); |
1978 | Register DstReg = Shl.getReg(Idx: 0); |
1979 | Register SrcReg = Shl.getReg(Idx: 1); |
1980 | Register ShiftReg = Shl.getReg(Idx: 2); |
1981 | Register X, C1; |
1982 | |
1983 | if (!getTargetLowering().isDesirableToCommuteWithShift(MI, IsAfterLegal: !isPreLegalize())) |
1984 | return false; |
1985 | |
1986 | if (!mi_match(R: SrcReg, MRI, |
1987 | P: m_OneNonDBGUse(SP: m_any_of(preds: m_GAdd(L: m_Reg(R&: X), R: m_Reg(R&: C1)), |
1988 | preds: m_GOr(L: m_Reg(R&: X), R: m_Reg(R&: C1)))))) |
1989 | return false; |
1990 | |
1991 | APInt C1Val, C2Val; |
1992 | if (!mi_match(R: C1, MRI, P: m_ICstOrSplat(Cst&: C1Val)) || |
1993 | !mi_match(R: ShiftReg, MRI, P: m_ICstOrSplat(Cst&: C2Val))) |
1994 | return false; |
1995 | |
1996 | auto *SrcDef = MRI.getVRegDef(Reg: SrcReg); |
1997 | assert((SrcDef->getOpcode() == TargetOpcode::G_ADD || |
1998 | SrcDef->getOpcode() == TargetOpcode::G_OR) && "Unexpected op" ); |
1999 | LLT SrcTy = MRI.getType(Reg: SrcReg); |
2000 | MatchInfo = [=](MachineIRBuilder &B) { |
2001 | auto S1 = B.buildShl(Dst: SrcTy, Src0: X, Src1: ShiftReg); |
2002 | auto S2 = B.buildShl(Dst: SrcTy, Src0: C1, Src1: ShiftReg); |
2003 | B.buildInstr(Opc: SrcDef->getOpcode(), DstOps: {DstReg}, SrcOps: {S1, S2}); |
2004 | }; |
2005 | return true; |
2006 | } |
2007 | |
2008 | bool CombinerHelper::matchCombineMulToShl(MachineInstr &MI, |
2009 | unsigned &ShiftVal) { |
2010 | assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL" ); |
2011 | auto MaybeImmVal = |
2012 | getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI); |
2013 | if (!MaybeImmVal) |
2014 | return false; |
2015 | |
2016 | ShiftVal = MaybeImmVal->Value.exactLogBase2(); |
2017 | return (static_cast<int32_t>(ShiftVal) != -1); |
2018 | } |
2019 | |
2020 | void CombinerHelper::applyCombineMulToShl(MachineInstr &MI, |
2021 | unsigned &ShiftVal) { |
2022 | assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL" ); |
2023 | MachineIRBuilder MIB(MI); |
2024 | LLT ShiftTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
2025 | auto ShiftCst = MIB.buildConstant(Res: ShiftTy, Val: ShiftVal); |
2026 | Observer.changingInstr(MI); |
2027 | MI.setDesc(MIB.getTII().get(Opcode: TargetOpcode::G_SHL)); |
2028 | MI.getOperand(i: 2).setReg(ShiftCst.getReg(Idx: 0)); |
2029 | Observer.changedInstr(MI); |
2030 | } |
2031 | |
2032 | // shl ([sza]ext x), y => zext (shl x, y), if shift does not overflow source |
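// For instance (an illustrative sketch, not taken from any particular target):
// if x:s8 is known to have at least two leading zero bits, then
//   %e:_(s32) = G_ZEXT %x:_(s8)
//   %r:_(s32) = G_SHL %e, G_CONSTANT 2
// can become
//   %n:_(s8) = G_SHL %x, G_CONSTANT 2
//   %r:_(s32) = G_ZEXT %n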
2033 | bool CombinerHelper::matchCombineShlOfExtend(MachineInstr &MI, |
2034 | RegisterImmPair &MatchData) { |
2035 | assert(MI.getOpcode() == TargetOpcode::G_SHL && KB); |
2036 | if (!getTargetLowering().isDesirableToPullExtFromShl(MI)) |
2037 | return false; |
2038 | |
2039 | Register LHS = MI.getOperand(i: 1).getReg(); |
2040 | |
2041 | Register ExtSrc; |
2042 | if (!mi_match(R: LHS, MRI, P: m_GAnyExt(Src: m_Reg(R&: ExtSrc))) && |
2043 | !mi_match(R: LHS, MRI, P: m_GZExt(Src: m_Reg(R&: ExtSrc))) && |
2044 | !mi_match(R: LHS, MRI, P: m_GSExt(Src: m_Reg(R&: ExtSrc)))) |
2045 | return false; |
2046 | |
2047 | Register RHS = MI.getOperand(i: 2).getReg(); |
2048 | MachineInstr *MIShiftAmt = MRI.getVRegDef(Reg: RHS); |
2049 | auto MaybeShiftAmtVal = isConstantOrConstantSplatVector(MI&: *MIShiftAmt, MRI); |
2050 | if (!MaybeShiftAmtVal) |
2051 | return false; |
2052 | |
2053 | if (LI) { |
2054 | LLT SrcTy = MRI.getType(Reg: ExtSrc); |
2055 | |
    // We only really care about the legality with the shifted value. We can
    // pick any type for the constant shift amount, so ask the target what to
    // use. Otherwise we would have to guess and hope it is reported as legal.
2059 | LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: SrcTy); |
2060 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SHL, {SrcTy, ShiftAmtTy}})) |
2061 | return false; |
2062 | } |
2063 | |
2064 | int64_t ShiftAmt = MaybeShiftAmtVal->getSExtValue(); |
2065 | MatchData.Reg = ExtSrc; |
2066 | MatchData.Imm = ShiftAmt; |
2067 | |
2068 | unsigned MinLeadingZeros = KB->getKnownZeroes(R: ExtSrc).countl_one(); |
2069 | unsigned SrcTySize = MRI.getType(Reg: ExtSrc).getScalarSizeInBits(); |
2070 | return MinLeadingZeros >= ShiftAmt && ShiftAmt < SrcTySize; |
2071 | } |
2072 | |
2073 | void CombinerHelper::applyCombineShlOfExtend(MachineInstr &MI, |
2074 | const RegisterImmPair &MatchData) { |
2075 | Register ExtSrcReg = MatchData.Reg; |
2076 | int64_t ShiftAmtVal = MatchData.Imm; |
2077 | |
2078 | LLT ExtSrcTy = MRI.getType(Reg: ExtSrcReg); |
2079 | auto ShiftAmt = Builder.buildConstant(Res: ExtSrcTy, Val: ShiftAmtVal); |
2080 | auto NarrowShift = |
2081 | Builder.buildShl(Dst: ExtSrcTy, Src0: ExtSrcReg, Src1: ShiftAmt, Flags: MI.getFlags()); |
2082 | Builder.buildZExt(Res: MI.getOperand(i: 0), Op: NarrowShift); |
2083 | MI.eraseFromParent(); |
2084 | } |
2085 | |
2086 | bool CombinerHelper::matchCombineMergeUnmerge(MachineInstr &MI, |
2087 | Register &MatchInfo) { |
2088 | GMerge &Merge = cast<GMerge>(Val&: MI); |
2089 | SmallVector<Register, 16> MergedValues; |
2090 | for (unsigned I = 0; I < Merge.getNumSources(); ++I) |
2091 | MergedValues.emplace_back(Args: Merge.getSourceReg(I)); |
2092 | |
2093 | auto *Unmerge = getOpcodeDef<GUnmerge>(Reg: MergedValues[0], MRI); |
2094 | if (!Unmerge || Unmerge->getNumDefs() != Merge.getNumSources()) |
2095 | return false; |
2096 | |
2097 | for (unsigned I = 0; I < MergedValues.size(); ++I) |
2098 | if (MergedValues[I] != Unmerge->getReg(Idx: I)) |
2099 | return false; |
2100 | |
2101 | MatchInfo = Unmerge->getSourceReg(); |
2102 | return true; |
2103 | } |
2104 | |
2105 | static Register peekThroughBitcast(Register Reg, |
2106 | const MachineRegisterInfo &MRI) { |
2107 | while (mi_match(R: Reg, MRI, P: m_GBitcast(Src: m_Reg(R&: Reg)))) |
2108 | ; |
2109 | |
2110 | return Reg; |
2111 | } |
2112 | |
2113 | bool CombinerHelper::matchCombineUnmergeMergeToPlainValues( |
2114 | MachineInstr &MI, SmallVectorImpl<Register> &Operands) { |
2115 | assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && |
2116 | "Expected an unmerge" ); |
2117 | auto &Unmerge = cast<GUnmerge>(Val&: MI); |
2118 | Register SrcReg = peekThroughBitcast(Reg: Unmerge.getSourceReg(), MRI); |
2119 | |
2120 | auto *SrcInstr = getOpcodeDef<GMergeLikeInstr>(Reg: SrcReg, MRI); |
2121 | if (!SrcInstr) |
2122 | return false; |
2123 | |
2124 | // Check the source type of the merge. |
2125 | LLT SrcMergeTy = MRI.getType(Reg: SrcInstr->getSourceReg(I: 0)); |
2126 | LLT Dst0Ty = MRI.getType(Reg: Unmerge.getReg(Idx: 0)); |
2127 | bool SameSize = Dst0Ty.getSizeInBits() == SrcMergeTy.getSizeInBits(); |
2128 | if (SrcMergeTy != Dst0Ty && !SameSize) |
2129 | return false; |
2130 | // They are the same now (modulo a bitcast). |
2131 | // We can collect all the src registers. |
2132 | for (unsigned Idx = 0; Idx < SrcInstr->getNumSources(); ++Idx) |
2133 | Operands.push_back(Elt: SrcInstr->getSourceReg(I: Idx)); |
2134 | return true; |
2135 | } |
2136 | |
2137 | void CombinerHelper::applyCombineUnmergeMergeToPlainValues( |
2138 | MachineInstr &MI, SmallVectorImpl<Register> &Operands) { |
2139 | assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && |
2140 | "Expected an unmerge" ); |
2141 | assert((MI.getNumOperands() - 1 == Operands.size()) && |
2142 | "Not enough operands to replace all defs" ); |
2143 | unsigned NumElems = MI.getNumOperands() - 1; |
2144 | |
2145 | LLT SrcTy = MRI.getType(Reg: Operands[0]); |
2146 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
2147 | bool CanReuseInputDirectly = DstTy == SrcTy; |
2148 | for (unsigned Idx = 0; Idx < NumElems; ++Idx) { |
2149 | Register DstReg = MI.getOperand(i: Idx).getReg(); |
2150 | Register SrcReg = Operands[Idx]; |
2151 | |
2152 | // This combine may run after RegBankSelect, so we need to be aware of |
2153 | // register banks. |
2154 | const auto &DstCB = MRI.getRegClassOrRegBank(Reg: DstReg); |
2155 | if (!DstCB.isNull() && DstCB != MRI.getRegClassOrRegBank(Reg: SrcReg)) { |
2156 | SrcReg = Builder.buildCopy(Res: MRI.getType(Reg: SrcReg), Op: SrcReg).getReg(Idx: 0); |
2157 | MRI.setRegClassOrRegBank(Reg: SrcReg, RCOrRB: DstCB); |
2158 | } |
2159 | |
2160 | if (CanReuseInputDirectly) |
2161 | replaceRegWith(MRI, FromReg: DstReg, ToReg: SrcReg); |
2162 | else |
2163 | Builder.buildCast(Dst: DstReg, Src: SrcReg); |
2164 | } |
2165 | MI.eraseFromParent(); |
2166 | } |
2167 | |
2168 | bool CombinerHelper::matchCombineUnmergeConstant(MachineInstr &MI, |
2169 | SmallVectorImpl<APInt> &Csts) { |
2170 | unsigned SrcIdx = MI.getNumOperands() - 1; |
2171 | Register SrcReg = MI.getOperand(i: SrcIdx).getReg(); |
2172 | MachineInstr *SrcInstr = MRI.getVRegDef(Reg: SrcReg); |
2173 | if (SrcInstr->getOpcode() != TargetOpcode::G_CONSTANT && |
2174 | SrcInstr->getOpcode() != TargetOpcode::G_FCONSTANT) |
2175 | return false; |
  // Break down the big constant into smaller ones.
2177 | const MachineOperand &CstVal = SrcInstr->getOperand(i: 1); |
2178 | APInt Val = SrcInstr->getOpcode() == TargetOpcode::G_CONSTANT |
2179 | ? CstVal.getCImm()->getValue() |
2180 | : CstVal.getFPImm()->getValueAPF().bitcastToAPInt(); |
2181 | |
2182 | LLT Dst0Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
2183 | unsigned ShiftAmt = Dst0Ty.getSizeInBits(); |
2184 | // Unmerge a constant. |
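  // E.g. (illustration): unmerging the s32 constant 0xAABBCCDD into four s8
  // pieces yields 0xDD, 0xCC, 0xBB, 0xAA, least-significant piece first.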
2185 | for (unsigned Idx = 0; Idx != SrcIdx; ++Idx) { |
2186 | Csts.emplace_back(Args: Val.trunc(width: ShiftAmt)); |
2187 | Val = Val.lshr(shiftAmt: ShiftAmt); |
2188 | } |
2189 | |
2190 | return true; |
2191 | } |
2192 | |
2193 | void CombinerHelper::applyCombineUnmergeConstant(MachineInstr &MI, |
2194 | SmallVectorImpl<APInt> &Csts) { |
2195 | assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && |
2196 | "Expected an unmerge" ); |
2197 | assert((MI.getNumOperands() - 1 == Csts.size()) && |
2198 | "Not enough operands to replace all defs" ); |
2199 | unsigned NumElems = MI.getNumOperands() - 1; |
2200 | for (unsigned Idx = 0; Idx < NumElems; ++Idx) { |
2201 | Register DstReg = MI.getOperand(i: Idx).getReg(); |
2202 | Builder.buildConstant(Res: DstReg, Val: Csts[Idx]); |
2203 | } |
2204 | |
2205 | MI.eraseFromParent(); |
2206 | } |
2207 | |
2208 | bool CombinerHelper::matchCombineUnmergeUndef( |
2209 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
2210 | unsigned SrcIdx = MI.getNumOperands() - 1; |
2211 | Register SrcReg = MI.getOperand(i: SrcIdx).getReg(); |
2212 | MatchInfo = [&MI](MachineIRBuilder &B) { |
2213 | unsigned NumElems = MI.getNumOperands() - 1; |
2214 | for (unsigned Idx = 0; Idx < NumElems; ++Idx) { |
2215 | Register DstReg = MI.getOperand(i: Idx).getReg(); |
2216 | B.buildUndef(Res: DstReg); |
2217 | } |
2218 | }; |
2219 | return isa<GImplicitDef>(Val: MRI.getVRegDef(Reg: SrcReg)); |
2220 | } |
2221 | |
2222 | bool CombinerHelper::matchCombineUnmergeWithDeadLanesToTrunc(MachineInstr &MI) { |
2223 | assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && |
2224 | "Expected an unmerge" ); |
2225 | if (MRI.getType(Reg: MI.getOperand(i: 0).getReg()).isVector() || |
2226 | MRI.getType(Reg: MI.getOperand(i: MI.getNumDefs()).getReg()).isVector()) |
2227 | return false; |
2228 | // Check that all the lanes are dead except the first one. |
2229 | for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) { |
2230 | if (!MRI.use_nodbg_empty(RegNo: MI.getOperand(i: Idx).getReg())) |
2231 | return false; |
2232 | } |
2233 | return true; |
2234 | } |
2235 | |
2236 | void CombinerHelper::applyCombineUnmergeWithDeadLanesToTrunc(MachineInstr &MI) { |
2237 | Register SrcReg = MI.getOperand(i: MI.getNumDefs()).getReg(); |
2238 | Register Dst0Reg = MI.getOperand(i: 0).getReg(); |
2239 | Builder.buildTrunc(Res: Dst0Reg, Op: SrcReg); |
2240 | MI.eraseFromParent(); |
2241 | } |
2242 | |
2243 | bool CombinerHelper::matchCombineUnmergeZExtToZExt(MachineInstr &MI) { |
2244 | assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && |
2245 | "Expected an unmerge" ); |
2246 | Register Dst0Reg = MI.getOperand(i: 0).getReg(); |
2247 | LLT Dst0Ty = MRI.getType(Reg: Dst0Reg); |
  // G_ZEXT on a vector applies to each lane, so it will affect all
  // destinations. Therefore we won't be able to simplify the unmerge to just
  // the first definition.
2251 | if (Dst0Ty.isVector()) |
2252 | return false; |
2253 | Register SrcReg = MI.getOperand(i: MI.getNumDefs()).getReg(); |
2254 | LLT SrcTy = MRI.getType(Reg: SrcReg); |
2255 | if (SrcTy.isVector()) |
2256 | return false; |
2257 | |
2258 | Register ZExtSrcReg; |
2259 | if (!mi_match(R: SrcReg, MRI, P: m_GZExt(Src: m_Reg(R&: ZExtSrcReg)))) |
2260 | return false; |
2261 | |
2262 | // Finally we can replace the first definition with |
2263 | // a zext of the source if the definition is big enough to hold |
2264 | // all of ZExtSrc bits. |
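  // Illustrative sketch (register names are made up):
  //   %src:_(s64) = G_ZEXT %x:_(s16)
  //   %d0:_(s32), %d1:_(s32) = G_UNMERGE_VALUES %src
  // becomes
  //   %d0:_(s32) = G_ZEXT %x:_(s16)
  //   %d1:_(s32) = G_CONSTANT i32 0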
2265 | LLT ZExtSrcTy = MRI.getType(Reg: ZExtSrcReg); |
2266 | return ZExtSrcTy.getSizeInBits() <= Dst0Ty.getSizeInBits(); |
2267 | } |
2268 | |
2269 | void CombinerHelper::applyCombineUnmergeZExtToZExt(MachineInstr &MI) { |
2270 | assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && |
2271 | "Expected an unmerge" ); |
2272 | |
2273 | Register Dst0Reg = MI.getOperand(i: 0).getReg(); |
2274 | |
2275 | MachineInstr *ZExtInstr = |
2276 | MRI.getVRegDef(Reg: MI.getOperand(i: MI.getNumDefs()).getReg()); |
2277 | assert(ZExtInstr && ZExtInstr->getOpcode() == TargetOpcode::G_ZEXT && |
2278 | "Expecting a G_ZEXT" ); |
2279 | |
2280 | Register ZExtSrcReg = ZExtInstr->getOperand(i: 1).getReg(); |
2281 | LLT Dst0Ty = MRI.getType(Reg: Dst0Reg); |
2282 | LLT ZExtSrcTy = MRI.getType(Reg: ZExtSrcReg); |
2283 | |
2284 | if (Dst0Ty.getSizeInBits() > ZExtSrcTy.getSizeInBits()) { |
2285 | Builder.buildZExt(Res: Dst0Reg, Op: ZExtSrcReg); |
2286 | } else { |
2287 | assert(Dst0Ty.getSizeInBits() == ZExtSrcTy.getSizeInBits() && |
2288 | "ZExt src doesn't fit in destination" ); |
2289 | replaceRegWith(MRI, FromReg: Dst0Reg, ToReg: ZExtSrcReg); |
2290 | } |
2291 | |
2292 | Register ZeroReg; |
2293 | for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) { |
2294 | if (!ZeroReg) |
2295 | ZeroReg = Builder.buildConstant(Res: Dst0Ty, Val: 0).getReg(Idx: 0); |
2296 | replaceRegWith(MRI, FromReg: MI.getOperand(i: Idx).getReg(), ToReg: ZeroReg); |
2297 | } |
2298 | MI.eraseFromParent(); |
2299 | } |
2300 | |
2301 | bool CombinerHelper::matchCombineShiftToUnmerge(MachineInstr &MI, |
2302 | unsigned TargetShiftSize, |
2303 | unsigned &ShiftVal) { |
2304 | assert((MI.getOpcode() == TargetOpcode::G_SHL || |
2305 | MI.getOpcode() == TargetOpcode::G_LSHR || |
2306 | MI.getOpcode() == TargetOpcode::G_ASHR) && "Expected a shift" ); |
2307 | |
2308 | LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
2309 | if (Ty.isVector()) // TODO: |
2310 | return false; |
2311 | |
2312 | // Don't narrow further than the requested size. |
2313 | unsigned Size = Ty.getSizeInBits(); |
2314 | if (Size <= TargetShiftSize) |
2315 | return false; |
2316 | |
2317 | auto MaybeImmVal = |
2318 | getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI); |
2319 | if (!MaybeImmVal) |
2320 | return false; |
2321 | |
2322 | ShiftVal = MaybeImmVal->Value.getSExtValue(); |
2323 | return ShiftVal >= Size / 2 && ShiftVal < Size; |
2324 | } |
2325 | |
2326 | void CombinerHelper::applyCombineShiftToUnmerge(MachineInstr &MI, |
2327 | const unsigned &ShiftVal) { |
2328 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2329 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
2330 | LLT Ty = MRI.getType(Reg: SrcReg); |
2331 | unsigned Size = Ty.getSizeInBits(); |
2332 | unsigned HalfSize = Size / 2; |
2333 | assert(ShiftVal >= HalfSize); |
2334 | |
2335 | LLT HalfTy = LLT::scalar(SizeInBits: HalfSize); |
2336 | |
2337 | auto Unmerge = Builder.buildUnmerge(Res: HalfTy, Op: SrcReg); |
2338 | unsigned NarrowShiftAmt = ShiftVal - HalfSize; |
2339 | |
2340 | if (MI.getOpcode() == TargetOpcode::G_LSHR) { |
2341 | Register Narrowed = Unmerge.getReg(Idx: 1); |
2342 | |
2343 | // dst = G_LSHR s64:x, C for C >= 32 |
2344 | // => |
2345 | // lo, hi = G_UNMERGE_VALUES x |
2346 | // dst = G_MERGE_VALUES (G_LSHR hi, C - 32), 0 |
2347 | |
2348 | if (NarrowShiftAmt != 0) { |
2349 | Narrowed = Builder.buildLShr(Dst: HalfTy, Src0: Narrowed, |
2350 | Src1: Builder.buildConstant(Res: HalfTy, Val: NarrowShiftAmt)).getReg(Idx: 0); |
2351 | } |
2352 | |
2353 | auto Zero = Builder.buildConstant(Res: HalfTy, Val: 0); |
2354 | Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Narrowed, Zero}); |
2355 | } else if (MI.getOpcode() == TargetOpcode::G_SHL) { |
2356 | Register Narrowed = Unmerge.getReg(Idx: 0); |
2357 | // dst = G_SHL s64:x, C for C >= 32 |
2358 | // => |
2359 | // lo, hi = G_UNMERGE_VALUES x |
2360 | // dst = G_MERGE_VALUES 0, (G_SHL hi, C - 32) |
2361 | if (NarrowShiftAmt != 0) { |
2362 | Narrowed = Builder.buildShl(Dst: HalfTy, Src0: Narrowed, |
2363 | Src1: Builder.buildConstant(Res: HalfTy, Val: NarrowShiftAmt)).getReg(Idx: 0); |
2364 | } |
2365 | |
2366 | auto Zero = Builder.buildConstant(Res: HalfTy, Val: 0); |
2367 | Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Zero, Narrowed}); |
2368 | } else { |
2369 | assert(MI.getOpcode() == TargetOpcode::G_ASHR); |
2370 | auto Hi = Builder.buildAShr( |
2371 | Dst: HalfTy, Src0: Unmerge.getReg(Idx: 1), |
2372 | Src1: Builder.buildConstant(Res: HalfTy, Val: HalfSize - 1)); |
2373 | |
2374 | if (ShiftVal == HalfSize) { |
2375 | // (G_ASHR i64:x, 32) -> |
2376 | // G_MERGE_VALUES hi_32(x), (G_ASHR hi_32(x), 31) |
2377 | Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Unmerge.getReg(Idx: 1), Hi}); |
2378 | } else if (ShiftVal == Size - 1) { |
2379 | // Don't need a second shift. |
2380 | // (G_ASHR i64:x, 63) -> |
2381 | // %narrowed = (G_ASHR hi_32(x), 31) |
2382 | // G_MERGE_VALUES %narrowed, %narrowed |
2383 | Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Hi, Hi}); |
2384 | } else { |
2385 | auto Lo = Builder.buildAShr( |
2386 | Dst: HalfTy, Src0: Unmerge.getReg(Idx: 1), |
2387 | Src1: Builder.buildConstant(Res: HalfTy, Val: ShiftVal - HalfSize)); |
2388 | |
2389 | // (G_ASHR i64:x, C) ->, for C >= 32 |
2390 | // G_MERGE_VALUES (G_ASHR hi_32(x), C - 32), (G_ASHR hi_32(x), 31) |
2391 | Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Lo, Hi}); |
2392 | } |
2393 | } |
2394 | |
2395 | MI.eraseFromParent(); |
2396 | } |
2397 | |
2398 | bool CombinerHelper::tryCombineShiftToUnmerge(MachineInstr &MI, |
2399 | unsigned TargetShiftAmount) { |
2400 | unsigned ShiftAmt; |
2401 | if (matchCombineShiftToUnmerge(MI, TargetShiftSize: TargetShiftAmount, ShiftVal&: ShiftAmt)) { |
2402 | applyCombineShiftToUnmerge(MI, ShiftVal: ShiftAmt); |
2403 | return true; |
2404 | } |
2405 | |
2406 | return false; |
2407 | } |
2408 | |
2409 | bool CombinerHelper::matchCombineI2PToP2I(MachineInstr &MI, Register &Reg) { |
2410 | assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR" ); |
2411 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2412 | LLT DstTy = MRI.getType(Reg: DstReg); |
2413 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
2414 | return mi_match(R: SrcReg, MRI, |
2415 | P: m_GPtrToInt(Src: m_all_of(preds: m_SpecificType(Ty: DstTy), preds: m_Reg(R&: Reg)))); |
2416 | } |
2417 | |
2418 | void CombinerHelper::applyCombineI2PToP2I(MachineInstr &MI, Register &Reg) { |
2419 | assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR" ); |
2420 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2421 | Builder.buildCopy(Res: DstReg, Op: Reg); |
2422 | MI.eraseFromParent(); |
2423 | } |
2424 | |
2425 | void CombinerHelper::applyCombineP2IToI2P(MachineInstr &MI, Register &Reg) { |
2426 | assert(MI.getOpcode() == TargetOpcode::G_PTRTOINT && "Expected a G_PTRTOINT" ); |
2427 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2428 | Builder.buildZExtOrTrunc(Res: DstReg, Op: Reg); |
2429 | MI.eraseFromParent(); |
2430 | } |
2431 | |
2432 | bool CombinerHelper::matchCombineAddP2IToPtrAdd( |
2433 | MachineInstr &MI, std::pair<Register, bool> &PtrReg) { |
2434 | assert(MI.getOpcode() == TargetOpcode::G_ADD); |
2435 | Register LHS = MI.getOperand(i: 1).getReg(); |
2436 | Register RHS = MI.getOperand(i: 2).getReg(); |
2437 | LLT IntTy = MRI.getType(Reg: LHS); |
2438 | |
2439 | // G_PTR_ADD always has the pointer in the LHS, so we may need to commute the |
2440 | // instruction. |
2441 | PtrReg.second = false; |
2442 | for (Register SrcReg : {LHS, RHS}) { |
2443 | if (mi_match(R: SrcReg, MRI, P: m_GPtrToInt(Src: m_Reg(R&: PtrReg.first)))) { |
2444 | // Don't handle cases where the integer is implicitly converted to the |
2445 | // pointer width. |
2446 | LLT PtrTy = MRI.getType(Reg: PtrReg.first); |
2447 | if (PtrTy.getScalarSizeInBits() == IntTy.getScalarSizeInBits()) |
2448 | return true; |
2449 | } |
2450 | |
2451 | PtrReg.second = true; |
2452 | } |
2453 | |
2454 | return false; |
2455 | } |
2456 | |
2457 | void CombinerHelper::applyCombineAddP2IToPtrAdd( |
2458 | MachineInstr &MI, std::pair<Register, bool> &PtrReg) { |
2459 | Register Dst = MI.getOperand(i: 0).getReg(); |
2460 | Register LHS = MI.getOperand(i: 1).getReg(); |
2461 | Register RHS = MI.getOperand(i: 2).getReg(); |
2462 | |
2463 | const bool DoCommute = PtrReg.second; |
2464 | if (DoCommute) |
2465 | std::swap(a&: LHS, b&: RHS); |
2466 | LHS = PtrReg.first; |
2467 | |
2468 | LLT PtrTy = MRI.getType(Reg: LHS); |
2469 | |
2470 | auto PtrAdd = Builder.buildPtrAdd(Res: PtrTy, Op0: LHS, Op1: RHS); |
2471 | Builder.buildPtrToInt(Dst, Src: PtrAdd); |
2472 | MI.eraseFromParent(); |
2473 | } |
2474 | |
2475 | bool CombinerHelper::matchCombineConstPtrAddToI2P(MachineInstr &MI, |
2476 | APInt &NewCst) { |
2477 | auto &PtrAdd = cast<GPtrAdd>(Val&: MI); |
2478 | Register LHS = PtrAdd.getBaseReg(); |
2479 | Register RHS = PtrAdd.getOffsetReg(); |
2480 | MachineRegisterInfo &MRI = Builder.getMF().getRegInfo(); |
2481 | |
2482 | if (auto RHSCst = getIConstantVRegVal(VReg: RHS, MRI)) { |
2483 | APInt Cst; |
2484 | if (mi_match(R: LHS, MRI, P: m_GIntToPtr(Src: m_ICst(Cst)))) { |
2485 | auto DstTy = MRI.getType(Reg: PtrAdd.getReg(Idx: 0)); |
2486 | // G_INTTOPTR uses zero-extension |
2487 | NewCst = Cst.zextOrTrunc(width: DstTy.getSizeInBits()); |
2488 | NewCst += RHSCst->sextOrTrunc(width: DstTy.getSizeInBits()); |
2489 | return true; |
2490 | } |
2491 | } |
2492 | |
2493 | return false; |
2494 | } |
2495 | |
2496 | void CombinerHelper::applyCombineConstPtrAddToI2P(MachineInstr &MI, |
2497 | APInt &NewCst) { |
2498 | auto &PtrAdd = cast<GPtrAdd>(Val&: MI); |
2499 | Register Dst = PtrAdd.getReg(Idx: 0); |
2500 | |
2501 | Builder.buildConstant(Res: Dst, Val: NewCst); |
2502 | PtrAdd.eraseFromParent(); |
2503 | } |
2504 | |
2505 | bool CombinerHelper::matchCombineAnyExtTrunc(MachineInstr &MI, Register &Reg) { |
2506 | assert(MI.getOpcode() == TargetOpcode::G_ANYEXT && "Expected a G_ANYEXT" ); |
2507 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2508 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
2509 | Register OriginalSrcReg = getSrcRegIgnoringCopies(Reg: SrcReg, MRI); |
2510 | if (OriginalSrcReg.isValid()) |
2511 | SrcReg = OriginalSrcReg; |
2512 | LLT DstTy = MRI.getType(Reg: DstReg); |
2513 | return mi_match(R: SrcReg, MRI, |
2514 | P: m_GTrunc(Src: m_all_of(preds: m_Reg(R&: Reg), preds: m_SpecificType(Ty: DstTy)))); |
2515 | } |
2516 | |
2517 | bool CombinerHelper::matchCombineZextTrunc(MachineInstr &MI, Register &Reg) { |
2518 | assert(MI.getOpcode() == TargetOpcode::G_ZEXT && "Expected a G_ZEXT" ); |
2519 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2520 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
2521 | LLT DstTy = MRI.getType(Reg: DstReg); |
2522 | if (mi_match(R: SrcReg, MRI, |
2523 | P: m_GTrunc(Src: m_all_of(preds: m_Reg(R&: Reg), preds: m_SpecificType(Ty: DstTy))))) { |
2524 | unsigned DstSize = DstTy.getScalarSizeInBits(); |
2525 | unsigned SrcSize = MRI.getType(Reg: SrcReg).getScalarSizeInBits(); |
2526 | return KB->getKnownBits(R: Reg).countMinLeadingZeros() >= DstSize - SrcSize; |
2527 | } |
2528 | return false; |
2529 | } |
2530 | |
2531 | bool CombinerHelper::matchCombineExtOfExt( |
2532 | MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) { |
2533 | assert((MI.getOpcode() == TargetOpcode::G_ANYEXT || |
2534 | MI.getOpcode() == TargetOpcode::G_SEXT || |
2535 | MI.getOpcode() == TargetOpcode::G_ZEXT) && |
2536 | "Expected a G_[ASZ]EXT" ); |
2537 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
2538 | Register OriginalSrcReg = getSrcRegIgnoringCopies(Reg: SrcReg, MRI); |
2539 | if (OriginalSrcReg.isValid()) |
2540 | SrcReg = OriginalSrcReg; |
2541 | MachineInstr *SrcMI = MRI.getVRegDef(Reg: SrcReg); |
2542 | // Match exts with the same opcode, anyext([sz]ext) and sext(zext). |
2543 | unsigned Opc = MI.getOpcode(); |
2544 | unsigned SrcOpc = SrcMI->getOpcode(); |
2545 | if (Opc == SrcOpc || |
2546 | (Opc == TargetOpcode::G_ANYEXT && |
2547 | (SrcOpc == TargetOpcode::G_SEXT || SrcOpc == TargetOpcode::G_ZEXT)) || |
2548 | (Opc == TargetOpcode::G_SEXT && SrcOpc == TargetOpcode::G_ZEXT)) { |
2549 | MatchInfo = std::make_tuple(args: SrcMI->getOperand(i: 1).getReg(), args&: SrcOpc); |
2550 | return true; |
2551 | } |
2552 | return false; |
2553 | } |
2554 | |
2555 | void CombinerHelper::applyCombineExtOfExt( |
2556 | MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) { |
2557 | assert((MI.getOpcode() == TargetOpcode::G_ANYEXT || |
2558 | MI.getOpcode() == TargetOpcode::G_SEXT || |
2559 | MI.getOpcode() == TargetOpcode::G_ZEXT) && |
2560 | "Expected a G_[ASZ]EXT" ); |
2561 | |
2562 | Register Reg = std::get<0>(t&: MatchInfo); |
2563 | unsigned SrcExtOp = std::get<1>(t&: MatchInfo); |
2564 | |
2565 | // Combine exts with the same opcode. |
2566 | if (MI.getOpcode() == SrcExtOp) { |
2567 | Observer.changingInstr(MI); |
2568 | MI.getOperand(i: 1).setReg(Reg); |
2569 | Observer.changedInstr(MI); |
2570 | return; |
2571 | } |
2572 | |
2573 | // Combine: |
2574 | // - anyext([sz]ext x) to [sz]ext x |
2575 | // - sext(zext x) to zext x |
2576 | if (MI.getOpcode() == TargetOpcode::G_ANYEXT || |
2577 | (MI.getOpcode() == TargetOpcode::G_SEXT && |
2578 | SrcExtOp == TargetOpcode::G_ZEXT)) { |
2579 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2580 | Builder.buildInstr(Opc: SrcExtOp, DstOps: {DstReg}, SrcOps: {Reg}); |
2581 | MI.eraseFromParent(); |
2582 | } |
2583 | } |
2584 | |
2585 | bool CombinerHelper::matchCombineTruncOfExt( |
2586 | MachineInstr &MI, std::pair<Register, unsigned> &MatchInfo) { |
2587 | assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC" ); |
2588 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
2589 | MachineInstr *SrcMI = MRI.getVRegDef(Reg: SrcReg); |
2590 | unsigned SrcOpc = SrcMI->getOpcode(); |
2591 | if (SrcOpc == TargetOpcode::G_ANYEXT || SrcOpc == TargetOpcode::G_SEXT || |
2592 | SrcOpc == TargetOpcode::G_ZEXT) { |
2593 | MatchInfo = std::make_pair(x: SrcMI->getOperand(i: 1).getReg(), y&: SrcOpc); |
2594 | return true; |
2595 | } |
2596 | return false; |
2597 | } |
2598 | |
2599 | void CombinerHelper::applyCombineTruncOfExt( |
2600 | MachineInstr &MI, std::pair<Register, unsigned> &MatchInfo) { |
2601 | assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC" ); |
2602 | Register SrcReg = MatchInfo.first; |
2603 | unsigned SrcExtOp = MatchInfo.second; |
2604 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2605 | LLT SrcTy = MRI.getType(Reg: SrcReg); |
2606 | LLT DstTy = MRI.getType(Reg: DstReg); |
2607 | if (SrcTy == DstTy) { |
2608 | MI.eraseFromParent(); |
2609 | replaceRegWith(MRI, FromReg: DstReg, ToReg: SrcReg); |
2610 | return; |
2611 | } |
2612 | if (SrcTy.getSizeInBits() < DstTy.getSizeInBits()) |
2613 | Builder.buildInstr(Opc: SrcExtOp, DstOps: {DstReg}, SrcOps: {SrcReg}); |
2614 | else |
2615 | Builder.buildTrunc(Res: DstReg, Op: SrcReg); |
2616 | MI.eraseFromParent(); |
2617 | } |
2618 | |
2619 | static LLT getMidVTForTruncRightShiftCombine(LLT ShiftTy, LLT TruncTy) { |
2620 | const unsigned ShiftSize = ShiftTy.getScalarSizeInBits(); |
2621 | const unsigned TruncSize = TruncTy.getScalarSizeInBits(); |
2622 | |
2623 | // ShiftTy > 32 > TruncTy -> 32 |
2624 | if (ShiftSize > 32 && TruncSize < 32) |
2625 | return ShiftTy.changeElementSize(NewEltSize: 32); |
2626 | |
2627 | // TODO: We could also reduce to 16 bits, but that's more target-dependent. |
2628 | // Some targets like it, some don't, some only like it under certain |
2629 | // conditions/processor versions, etc. |
2630 | // A TL hook might be needed for this. |
2631 | |
2632 | // Don't combine |
2633 | return ShiftTy; |
2634 | } |
2635 | |
2636 | bool CombinerHelper::matchCombineTruncOfShift( |
2637 | MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) { |
2638 | assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC" ); |
2639 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2640 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
2641 | |
2642 | if (!MRI.hasOneNonDBGUse(RegNo: SrcReg)) |
2643 | return false; |
2644 | |
2645 | LLT SrcTy = MRI.getType(Reg: SrcReg); |
2646 | LLT DstTy = MRI.getType(Reg: DstReg); |
2647 | |
2648 | MachineInstr *SrcMI = getDefIgnoringCopies(Reg: SrcReg, MRI); |
2649 | const auto &TL = getTargetLowering(); |
2650 | |
2651 | LLT NewShiftTy; |
2652 | switch (SrcMI->getOpcode()) { |
2653 | default: |
2654 | return false; |
2655 | case TargetOpcode::G_SHL: { |
2656 | NewShiftTy = DstTy; |
2657 | |
2658 | // Make sure new shift amount is legal. |
2659 | KnownBits Known = KB->getKnownBits(R: SrcMI->getOperand(i: 2).getReg()); |
2660 | if (Known.getMaxValue().uge(RHS: NewShiftTy.getScalarSizeInBits())) |
2661 | return false; |
2662 | break; |
2663 | } |
2664 | case TargetOpcode::G_LSHR: |
2665 | case TargetOpcode::G_ASHR: { |
2666 | // For right shifts, we conservatively do not do the transform if the TRUNC |
2667 | // has any STORE users. The reason is that if we change the type of the |
2668 | // shift, we may break the truncstore combine. |
2669 | // |
2670 | // TODO: Fix truncstore combine to handle (trunc(lshr (trunc x), k)). |
2671 | for (auto &User : MRI.use_instructions(Reg: DstReg)) |
2672 | if (User.getOpcode() == TargetOpcode::G_STORE) |
2673 | return false; |
2674 | |
2675 | NewShiftTy = getMidVTForTruncRightShiftCombine(ShiftTy: SrcTy, TruncTy: DstTy); |
2676 | if (NewShiftTy == SrcTy) |
2677 | return false; |
2678 | |
2679 | // Make sure we won't lose information by truncating the high bits. |
2680 | KnownBits Known = KB->getKnownBits(R: SrcMI->getOperand(i: 2).getReg()); |
2681 | if (Known.getMaxValue().ugt(RHS: NewShiftTy.getScalarSizeInBits() - |
2682 | DstTy.getScalarSizeInBits())) |
2683 | return false; |
2684 | break; |
2685 | } |
2686 | } |
2687 | |
2688 | if (!isLegalOrBeforeLegalizer( |
2689 | Query: {SrcMI->getOpcode(), |
2690 | {NewShiftTy, TL.getPreferredShiftAmountTy(ShiftValueTy: NewShiftTy)}})) |
2691 | return false; |
2692 | |
2693 | MatchInfo = std::make_pair(x&: SrcMI, y&: NewShiftTy); |
2694 | return true; |
2695 | } |
2696 | |
2697 | void CombinerHelper::applyCombineTruncOfShift( |
2698 | MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) { |
2699 | MachineInstr *ShiftMI = MatchInfo.first; |
2700 | LLT NewShiftTy = MatchInfo.second; |
2701 | |
2702 | Register Dst = MI.getOperand(i: 0).getReg(); |
2703 | LLT DstTy = MRI.getType(Reg: Dst); |
2704 | |
2705 | Register ShiftAmt = ShiftMI->getOperand(i: 2).getReg(); |
2706 | Register ShiftSrc = ShiftMI->getOperand(i: 1).getReg(); |
2707 | ShiftSrc = Builder.buildTrunc(Res: NewShiftTy, Op: ShiftSrc).getReg(Idx: 0); |
2708 | |
2709 | Register NewShift = |
2710 | Builder |
2711 | .buildInstr(Opc: ShiftMI->getOpcode(), DstOps: {NewShiftTy}, SrcOps: {ShiftSrc, ShiftAmt}) |
2712 | .getReg(Idx: 0); |
2713 | |
2714 | if (NewShiftTy == DstTy) |
2715 | replaceRegWith(MRI, FromReg: Dst, ToReg: NewShift); |
2716 | else |
2717 | Builder.buildTrunc(Res: Dst, Op: NewShift); |
2718 | |
2719 | eraseInst(MI); |
2720 | } |
2721 | |
2722 | bool CombinerHelper::matchAnyExplicitUseIsUndef(MachineInstr &MI) { |
2723 | return any_of(Range: MI.explicit_uses(), P: [this](const MachineOperand &MO) { |
2724 | return MO.isReg() && |
2725 | getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI); |
2726 | }); |
2727 | } |
2728 | |
2729 | bool CombinerHelper::matchAllExplicitUsesAreUndef(MachineInstr &MI) { |
2730 | return all_of(Range: MI.explicit_uses(), P: [this](const MachineOperand &MO) { |
2731 | return !MO.isReg() || |
2732 | getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI); |
2733 | }); |
2734 | } |
2735 | |
2736 | bool CombinerHelper::matchUndefShuffleVectorMask(MachineInstr &MI) { |
2737 | assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR); |
2738 | ArrayRef<int> Mask = MI.getOperand(i: 3).getShuffleMask(); |
2739 | return all_of(Range&: Mask, P: [](int Elt) { return Elt < 0; }); |
2740 | } |
2741 | |
2742 | bool CombinerHelper::matchUndefStore(MachineInstr &MI) { |
2743 | assert(MI.getOpcode() == TargetOpcode::G_STORE); |
2744 | return getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MI.getOperand(i: 0).getReg(), |
2745 | MRI); |
2746 | } |
2747 | |
2748 | bool CombinerHelper::matchUndefSelectCmp(MachineInstr &MI) { |
2749 | assert(MI.getOpcode() == TargetOpcode::G_SELECT); |
2750 | return getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MI.getOperand(i: 1).getReg(), |
2751 | MRI); |
2752 | } |
2753 | |
bool CombinerHelper::matchInsertExtractVecEltOutOfBounds(MachineInstr &MI) {
2755 | assert((MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT || |
2756 | MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) && |
2757 | "Expected an insert/extract element op" ); |
2758 | LLT VecTy = MRI.getType(Reg: MI.getOperand(i: 1).getReg()); |
2759 | unsigned IdxIdx = |
2760 | MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT ? 2 : 3; |
2761 | auto Idx = getIConstantVRegVal(VReg: MI.getOperand(i: IdxIdx).getReg(), MRI); |
2762 | if (!Idx) |
2763 | return false; |
2764 | return Idx->getZExtValue() >= VecTy.getNumElements(); |
2765 | } |
2766 | |
2767 | bool CombinerHelper::matchConstantSelectCmp(MachineInstr &MI, unsigned &OpIdx) { |
2768 | GSelect &SelMI = cast<GSelect>(Val&: MI); |
2769 | auto Cst = |
2770 | isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: SelMI.getCondReg()), MRI); |
2771 | if (!Cst) |
2772 | return false; |
2773 | OpIdx = Cst->isZero() ? 3 : 2; |
2774 | return true; |
2775 | } |
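
// Illustration of the match above (hypothetical registers): with a non-zero
// condition the select folds to its true operand (operand index 2):
//   %c:_(s1)  = G_CONSTANT i1 1
//   %r:_(s32) = G_SELECT %c, %a, %b   ; --> %a
// A zero condition would instead pick operand index 3 (%b).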
2776 | |
2777 | void CombinerHelper::eraseInst(MachineInstr &MI) { MI.eraseFromParent(); } |
2778 | |
2779 | bool CombinerHelper::matchEqualDefs(const MachineOperand &MOP1, |
2780 | const MachineOperand &MOP2) { |
2781 | if (!MOP1.isReg() || !MOP2.isReg()) |
2782 | return false; |
2783 | auto InstAndDef1 = getDefSrcRegIgnoringCopies(Reg: MOP1.getReg(), MRI); |
2784 | if (!InstAndDef1) |
2785 | return false; |
2786 | auto InstAndDef2 = getDefSrcRegIgnoringCopies(Reg: MOP2.getReg(), MRI); |
2787 | if (!InstAndDef2) |
2788 | return false; |
2789 | MachineInstr *I1 = InstAndDef1->MI; |
2790 | MachineInstr *I2 = InstAndDef2->MI; |
2791 | |
2792 | // Handle a case like this: |
2793 | // |
2794 | // %0:_(s64), %1:_(s64) = G_UNMERGE_VALUES %2:_(<2 x s64>) |
2795 | // |
2796 | // Even though %0 and %1 are produced by the same instruction they are not |
2797 | // the same values. |
2798 | if (I1 == I2) |
2799 | return MOP1.getReg() == MOP2.getReg(); |
2800 | |
  // If we have an instruction which loads or stores, we can't guarantee that
  // two seemingly identical accesses produce the same value.
2803 | // |
2804 | // For example, we may have |
2805 | // |
2806 | // %x1 = G_LOAD %addr (load N from @somewhere) |
2807 | // ... |
2808 | // call @foo |
2809 | // ... |
2810 | // %x2 = G_LOAD %addr (load N from @somewhere) |
2811 | // ... |
2812 | // %or = G_OR %x1, %x2 |
2813 | // |
2814 | // It's possible that @foo will modify whatever lives at the address we're |
2815 | // loading from. To be safe, let's just assume that all loads and stores |
2816 | // are different (unless we have something which is guaranteed to not |
2817 | // change.) |
2818 | if (I1->mayLoadOrStore() && !I1->isDereferenceableInvariantLoad()) |
2819 | return false; |
2820 | |
2821 | // If both instructions are loads or stores, they are equal only if both |
2822 | // are dereferenceable invariant loads with the same number of bits. |
2823 | if (I1->mayLoadOrStore() && I2->mayLoadOrStore()) { |
2824 | GLoadStore *LS1 = dyn_cast<GLoadStore>(Val: I1); |
2825 | GLoadStore *LS2 = dyn_cast<GLoadStore>(Val: I2); |
2826 | if (!LS1 || !LS2) |
2827 | return false; |
2828 | |
2829 | if (!I2->isDereferenceableInvariantLoad() || |
2830 | (LS1->getMemSizeInBits() != LS2->getMemSizeInBits())) |
2831 | return false; |
2832 | } |
2833 | |
2834 | // Check for physical registers on the instructions first to avoid cases |
2835 | // like this: |
2836 | // |
2837 | // %a = COPY $physreg |
2838 | // ... |
2839 | // SOMETHING implicit-def $physreg |
2840 | // ... |
2841 | // %b = COPY $physreg |
2842 | // |
2843 | // These copies are not equivalent. |
2844 | if (any_of(Range: I1->uses(), P: [](const MachineOperand &MO) { |
2845 | return MO.isReg() && MO.getReg().isPhysical(); |
2846 | })) { |
2847 | // Check if we have a case like this: |
2848 | // |
2849 | // %a = COPY $physreg |
2850 | // %b = COPY %a |
2851 | // |
2852 | // In this case, I1 and I2 will both be equal to %a = COPY $physreg. |
2853 | // From that, we know that they must have the same value, since they must |
2854 | // have come from the same COPY. |
2855 | return I1->isIdenticalTo(Other: *I2); |
2856 | } |
2857 | |
2858 | // We don't have any physical registers, so we don't necessarily need the |
2859 | // same vreg defs. |
2860 | // |
2861 | // On the off-chance that there's some target instruction feeding into the |
2862 | // instruction, let's use produceSameValue instead of isIdenticalTo. |
2863 | if (Builder.getTII().produceSameValue(MI0: *I1, MI1: *I2, MRI: &MRI)) { |
    // Handle instructions with multiple defs that produce the same values. The
    // values match only for defs at the same operand index.
2866 | // %0:_(s8), %1:_(s8), %2:_(s8), %3:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>) |
2867 | // %5:_(s8), %6:_(s8), %7:_(s8), %8:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>) |
2868 | // I1 and I2 are different instructions but produce same values, |
2869 | // %1 and %6 are same, %1 and %7 are not the same value. |
2870 | return I1->findRegisterDefOperandIdx(Reg: InstAndDef1->Reg, /*TRI=*/nullptr) == |
2871 | I2->findRegisterDefOperandIdx(Reg: InstAndDef2->Reg, /*TRI=*/nullptr); |
2872 | } |
2873 | return false; |
2874 | } |
2875 | |
2876 | bool CombinerHelper::matchConstantOp(const MachineOperand &MOP, int64_t C) { |
2877 | if (!MOP.isReg()) |
2878 | return false; |
2879 | auto *MI = MRI.getVRegDef(Reg: MOP.getReg()); |
2880 | auto MaybeCst = isConstantOrConstantSplatVector(MI&: *MI, MRI); |
2881 | return MaybeCst && MaybeCst->getBitWidth() <= 64 && |
2882 | MaybeCst->getSExtValue() == C; |
2883 | } |
2884 | |
2885 | bool CombinerHelper::matchConstantFPOp(const MachineOperand &MOP, double C) { |
2886 | if (!MOP.isReg()) |
2887 | return false; |
2888 | std::optional<FPValueAndVReg> MaybeCst; |
2889 | if (!mi_match(R: MOP.getReg(), MRI, P: m_GFCstOrSplat(FPValReg&: MaybeCst))) |
2890 | return false; |
2891 | |
2892 | return MaybeCst->Value.isExactlyValue(V: C); |
2893 | } |
2894 | |
2895 | void CombinerHelper::replaceSingleDefInstWithOperand(MachineInstr &MI, |
2896 | unsigned OpIdx) { |
2897 | assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?" ); |
2898 | Register OldReg = MI.getOperand(i: 0).getReg(); |
2899 | Register Replacement = MI.getOperand(i: OpIdx).getReg(); |
2900 | assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?" ); |
2901 | MI.eraseFromParent(); |
2902 | replaceRegWith(MRI, FromReg: OldReg, ToReg: Replacement); |
2903 | } |
2904 | |
2905 | void CombinerHelper::replaceSingleDefInstWithReg(MachineInstr &MI, |
2906 | Register Replacement) { |
2907 | assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?" ); |
2908 | Register OldReg = MI.getOperand(i: 0).getReg(); |
2909 | assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?" ); |
2910 | MI.eraseFromParent(); |
2911 | replaceRegWith(MRI, FromReg: OldReg, ToReg: Replacement); |
2912 | } |
2913 | |
2914 | bool CombinerHelper::matchConstantLargerBitWidth(MachineInstr &MI, |
2915 | unsigned ConstIdx) { |
2916 | Register ConstReg = MI.getOperand(i: ConstIdx).getReg(); |
2917 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
2918 | |
2919 | // Get the shift amount |
2920 | auto VRegAndVal = getIConstantVRegValWithLookThrough(VReg: ConstReg, MRI); |
2921 | if (!VRegAndVal) |
2922 | return false; |
2923 | |
  // Return true if the shift amount is >= the bitwidth.
2925 | return (VRegAndVal->Value.uge(RHS: DstTy.getSizeInBits())); |
2926 | } |
2927 | |
2928 | void CombinerHelper::applyFunnelShiftConstantModulo(MachineInstr &MI) { |
2929 | assert((MI.getOpcode() == TargetOpcode::G_FSHL || |
2930 | MI.getOpcode() == TargetOpcode::G_FSHR) && |
2931 | "This is not a funnel shift operation" ); |
2932 | |
2933 | Register ConstReg = MI.getOperand(i: 3).getReg(); |
2934 | LLT ConstTy = MRI.getType(Reg: ConstReg); |
2935 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
2936 | |
2937 | auto VRegAndVal = getIConstantVRegValWithLookThrough(VReg: ConstReg, MRI); |
2938 | assert((VRegAndVal) && "Value is not a constant" ); |
2939 | |
2940 | // Calculate the new Shift Amount = Old Shift Amount % BitWidth |
2941 | APInt NewConst = VRegAndVal->Value.urem( |
2942 | RHS: APInt(ConstTy.getSizeInBits(), DstTy.getScalarSizeInBits())); |
2943 | |
2944 | auto NewConstInstr = Builder.buildConstant(Res: ConstTy, Val: NewConst.getZExtValue()); |
2945 | Builder.buildInstr( |
2946 | Opc: MI.getOpcode(), DstOps: {MI.getOperand(i: 0)}, |
2947 | SrcOps: {MI.getOperand(i: 1), MI.getOperand(i: 2), NewConstInstr.getReg(Idx: 0)}); |
2948 | |
2949 | MI.eraseFromParent(); |
2950 | } |
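
// Illustrative example of the modulo rewrite above for a 32-bit funnel shift
// (hypothetical registers):
//   %amt:_(s32) = G_CONSTANT i32 40
//   %r:_(s32)   = G_FSHL %x, %y, %amt
// becomes
//   %amt2:_(s32) = G_CONSTANT i32 8   ; 40 % 32
//   %r:_(s32)    = G_FSHL %x, %y, %amt2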
2951 | |
2952 | bool CombinerHelper::matchSelectSameVal(MachineInstr &MI) { |
2953 | assert(MI.getOpcode() == TargetOpcode::G_SELECT); |
2954 | // Match (cond ? x : x) |
2955 | return matchEqualDefs(MOP1: MI.getOperand(i: 2), MOP2: MI.getOperand(i: 3)) && |
2956 | canReplaceReg(DstReg: MI.getOperand(i: 0).getReg(), SrcReg: MI.getOperand(i: 2).getReg(), |
2957 | MRI); |
2958 | } |
2959 | |
2960 | bool CombinerHelper::matchBinOpSameVal(MachineInstr &MI) { |
2961 | return matchEqualDefs(MOP1: MI.getOperand(i: 1), MOP2: MI.getOperand(i: 2)) && |
2962 | canReplaceReg(DstReg: MI.getOperand(i: 0).getReg(), SrcReg: MI.getOperand(i: 1).getReg(), |
2963 | MRI); |
2964 | } |
2965 | |
2966 | bool CombinerHelper::matchOperandIsZero(MachineInstr &MI, unsigned OpIdx) { |
2967 | return matchConstantOp(MOP: MI.getOperand(i: OpIdx), C: 0) && |
2968 | canReplaceReg(DstReg: MI.getOperand(i: 0).getReg(), SrcReg: MI.getOperand(i: OpIdx).getReg(), |
2969 | MRI); |
2970 | } |
2971 | |
2972 | bool CombinerHelper::matchOperandIsUndef(MachineInstr &MI, unsigned OpIdx) { |
2973 | MachineOperand &MO = MI.getOperand(i: OpIdx); |
2974 | return MO.isReg() && |
2975 | getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI); |
2976 | } |
2977 | |
2978 | bool CombinerHelper::matchOperandIsKnownToBeAPowerOfTwo(MachineInstr &MI, |
2979 | unsigned OpIdx) { |
2980 | MachineOperand &MO = MI.getOperand(i: OpIdx); |
2981 | return isKnownToBeAPowerOfTwo(Val: MO.getReg(), MRI, KnownBits: KB); |
2982 | } |
2983 | |
2984 | void CombinerHelper::replaceInstWithFConstant(MachineInstr &MI, double C) { |
2985 | assert(MI.getNumDefs() == 1 && "Expected only one def?" ); |
2986 | Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: C); |
2987 | MI.eraseFromParent(); |
2988 | } |
2989 | |
2990 | void CombinerHelper::replaceInstWithConstant(MachineInstr &MI, int64_t C) { |
2991 | assert(MI.getNumDefs() == 1 && "Expected only one def?" ); |
2992 | Builder.buildConstant(Res: MI.getOperand(i: 0), Val: C); |
2993 | MI.eraseFromParent(); |
2994 | } |
2995 | |
2996 | void CombinerHelper::replaceInstWithConstant(MachineInstr &MI, APInt C) { |
2997 | assert(MI.getNumDefs() == 1 && "Expected only one def?" ); |
2998 | Builder.buildConstant(Res: MI.getOperand(i: 0), Val: C); |
2999 | MI.eraseFromParent(); |
3000 | } |
3001 | |
3002 | void CombinerHelper::replaceInstWithFConstant(MachineInstr &MI, |
3003 | ConstantFP *CFP) { |
3004 | assert(MI.getNumDefs() == 1 && "Expected only one def?" ); |
3005 | Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: CFP->getValueAPF()); |
3006 | MI.eraseFromParent(); |
3007 | } |
3008 | |
3009 | void CombinerHelper::replaceInstWithUndef(MachineInstr &MI) { |
3010 | assert(MI.getNumDefs() == 1 && "Expected only one def?" ); |
3011 | Builder.buildUndef(Res: MI.getOperand(i: 0)); |
3012 | MI.eraseFromParent(); |
3013 | } |
3014 | |
3015 | bool CombinerHelper::matchSimplifyAddToSub( |
3016 | MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) { |
3017 | Register LHS = MI.getOperand(i: 1).getReg(); |
3018 | Register RHS = MI.getOperand(i: 2).getReg(); |
3019 | Register &NewLHS = std::get<0>(t&: MatchInfo); |
3020 | Register &NewRHS = std::get<1>(t&: MatchInfo); |
3021 | |
3022 | // Helper lambda to check for opportunities for |
3023 | // ((0-A) + B) -> B - A |
3024 | // (A + (0-B)) -> A - B |
3025 | auto CheckFold = [&](Register &MaybeSub, Register &MaybeNewLHS) { |
3026 | if (!mi_match(R: MaybeSub, MRI, P: m_Neg(Src: m_Reg(R&: NewRHS)))) |
3027 | return false; |
3028 | NewLHS = MaybeNewLHS; |
3029 | return true; |
3030 | }; |
3031 | |
3032 | return CheckFold(LHS, RHS) || CheckFold(RHS, LHS); |
3033 | } |
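
// Illustration of the fold above (hypothetical registers):
//   %n:_(s32) = G_SUB %zero, %a       ; 0 - %a
//   %r:_(s32) = G_ADD %n, %b
// becomes
//   %r:_(s32) = G_SUB %b, %a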
3034 | |
3035 | bool CombinerHelper::matchCombineInsertVecElts( |
3036 | MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) { |
3037 | assert(MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT && |
3038 | "Invalid opcode" ); |
3039 | Register DstReg = MI.getOperand(i: 0).getReg(); |
3040 | LLT DstTy = MRI.getType(Reg: DstReg); |
3041 | assert(DstTy.isVector() && "Invalid G_INSERT_VECTOR_ELT?" ); |
3042 | unsigned NumElts = DstTy.getNumElements(); |
3043 | // If this MI is part of a sequence of insert_vec_elts, then |
3044 | // don't do the combine in the middle of the sequence. |
3045 | if (MRI.hasOneUse(RegNo: DstReg) && MRI.use_instr_begin(RegNo: DstReg)->getOpcode() == |
3046 | TargetOpcode::G_INSERT_VECTOR_ELT) |
3047 | return false; |
3048 | MachineInstr *CurrInst = &MI; |
3049 | MachineInstr *TmpInst; |
3050 | int64_t IntImm; |
3051 | Register TmpReg; |
3052 | MatchInfo.resize(N: NumElts); |
3053 | while (mi_match( |
3054 | R: CurrInst->getOperand(i: 0).getReg(), MRI, |
3055 | P: m_GInsertVecElt(Src0: m_MInstr(MI&: TmpInst), Src1: m_Reg(R&: TmpReg), Src2: m_ICst(Cst&: IntImm)))) { |
3056 | if (IntImm >= NumElts || IntImm < 0) |
3057 | return false; |
3058 | if (!MatchInfo[IntImm]) |
3059 | MatchInfo[IntImm] = TmpReg; |
3060 | CurrInst = TmpInst; |
3061 | } |
3062 | // Variable index. |
3063 | if (CurrInst->getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT) |
3064 | return false; |
3065 | if (TmpInst->getOpcode() == TargetOpcode::G_BUILD_VECTOR) { |
3066 | for (unsigned I = 1; I < TmpInst->getNumOperands(); ++I) { |
3067 | if (!MatchInfo[I - 1].isValid()) |
3068 | MatchInfo[I - 1] = TmpInst->getOperand(i: I).getReg(); |
3069 | } |
3070 | return true; |
3071 | } |
3072 | // If we didn't end in a G_IMPLICIT_DEF and the source is not fully |
3073 | // overwritten, bail out. |
3074 | return TmpInst->getOpcode() == TargetOpcode::G_IMPLICIT_DEF || |
3075 | all_of(Range&: MatchInfo, P: [](Register Reg) { return !!Reg; }); |
3076 | } |
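
// Illustrative example of the pattern matched above (hypothetical registers):
// a chain of constant-index inserts that fully defines the vector collapses
// into a single G_BUILD_VECTOR:
//   %u:_(<2 x s32>)  = G_IMPLICIT_DEF
//   %i0:_(<2 x s32>) = G_INSERT_VECTOR_ELT %u, %a(s32), %c0(s64)
//   %v:_(<2 x s32>)  = G_INSERT_VECTOR_ELT %i0, %b(s32), %c1(s64)
// becomes
//   %v:_(<2 x s32>)  = G_BUILD_VECTOR %a(s32), %b(s32)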
3077 | |
3078 | void CombinerHelper::applyCombineInsertVecElts( |
3079 | MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) { |
3080 | Register UndefReg; |
3081 | auto GetUndef = [&]() { |
3082 | if (UndefReg) |
3083 | return UndefReg; |
3084 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
3085 | UndefReg = Builder.buildUndef(Res: DstTy.getScalarType()).getReg(Idx: 0); |
3086 | return UndefReg; |
3087 | }; |
3088 | for (Register &Reg : MatchInfo) { |
3089 | if (!Reg) |
3090 | Reg = GetUndef(); |
3091 | } |
3092 | Builder.buildBuildVector(Res: MI.getOperand(i: 0).getReg(), Ops: MatchInfo); |
3093 | MI.eraseFromParent(); |
3094 | } |
3095 | |
3096 | void CombinerHelper::applySimplifyAddToSub( |
3097 | MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) { |
3098 | Register SubLHS, SubRHS; |
3099 | std::tie(args&: SubLHS, args&: SubRHS) = MatchInfo; |
3100 | Builder.buildSub(Dst: MI.getOperand(i: 0).getReg(), Src0: SubLHS, Src1: SubRHS); |
3101 | MI.eraseFromParent(); |
3102 | } |
3103 | |
3104 | bool CombinerHelper::matchHoistLogicOpWithSameOpcodeHands( |
3105 | MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) { |
3106 | // Matches: logic (hand x, ...), (hand y, ...) -> hand (logic x, y), ... |
3107 | // |
3108 | // Creates the new hand + logic instruction (but does not insert them.) |
3109 | // |
3110 | // On success, MatchInfo is populated with the new instructions. These are |
3111 | // inserted in applyHoistLogicOpWithSameOpcodeHands. |
3112 | unsigned LogicOpcode = MI.getOpcode(); |
3113 | assert(LogicOpcode == TargetOpcode::G_AND || |
3114 | LogicOpcode == TargetOpcode::G_OR || |
3115 | LogicOpcode == TargetOpcode::G_XOR); |
3116 | MachineIRBuilder MIB(MI); |
3117 | Register Dst = MI.getOperand(i: 0).getReg(); |
3118 | Register LHSReg = MI.getOperand(i: 1).getReg(); |
3119 | Register RHSReg = MI.getOperand(i: 2).getReg(); |
3120 | |
3121 | // Don't recompute anything. |
3122 | if (!MRI.hasOneNonDBGUse(RegNo: LHSReg) || !MRI.hasOneNonDBGUse(RegNo: RHSReg)) |
3123 | return false; |
3124 | |
3125 | // Make sure we have (hand x, ...), (hand y, ...) |
3126 | MachineInstr *LeftHandInst = getDefIgnoringCopies(Reg: LHSReg, MRI); |
3127 | MachineInstr *RightHandInst = getDefIgnoringCopies(Reg: RHSReg, MRI); |
3128 | if (!LeftHandInst || !RightHandInst) |
3129 | return false; |
3130 | unsigned HandOpcode = LeftHandInst->getOpcode(); |
3131 | if (HandOpcode != RightHandInst->getOpcode()) |
3132 | return false; |
3133 | if (!LeftHandInst->getOperand(i: 1).isReg() || |
3134 | !RightHandInst->getOperand(i: 1).isReg()) |
3135 | return false; |
3136 | |
3137 | // Make sure the types match up, and if we're doing this post-legalization, |
3138 | // we end up with legal types. |
3139 | Register X = LeftHandInst->getOperand(i: 1).getReg(); |
3140 | Register Y = RightHandInst->getOperand(i: 1).getReg(); |
3141 | LLT XTy = MRI.getType(Reg: X); |
3142 | LLT YTy = MRI.getType(Reg: Y); |
3143 | if (!XTy.isValid() || XTy != YTy) |
3144 | return false; |
3145 | |
3146 | // Optional extra source register. |
3147 | Register ExtraHandOpSrcReg; |
3148 | switch (HandOpcode) { |
3149 | default: |
3150 | return false; |
3151 | case TargetOpcode::G_ANYEXT: |
3152 | case TargetOpcode::G_SEXT: |
3153 | case TargetOpcode::G_ZEXT: { |
3154 | // Match: logic (ext X), (ext Y) --> ext (logic X, Y) |
3155 | break; |
3156 | } |
3157 | case TargetOpcode::G_TRUNC: { |
3158 | // Match: logic (trunc X), (trunc Y) -> trunc (logic X, Y) |
3159 | const MachineFunction *MF = MI.getMF(); |
3160 | const DataLayout &DL = MF->getDataLayout(); |
3161 | LLVMContext &Ctx = MF->getFunction().getContext(); |
3162 | |
3163 | LLT DstTy = MRI.getType(Reg: Dst); |
3164 | const TargetLowering &TLI = getTargetLowering(); |
3165 | |
3166 | // Be extra careful sinking truncate. If it's free, there's no benefit in |
3167 | // widening a binop. |
3168 | if (TLI.isZExtFree(FromTy: DstTy, ToTy: XTy, DL, Ctx) && |
3169 | TLI.isTruncateFree(FromTy: XTy, ToTy: DstTy, DL, Ctx)) |
3170 | return false; |
3171 | break; |
3172 | } |
3173 | case TargetOpcode::G_AND: |
3174 | case TargetOpcode::G_ASHR: |
3175 | case TargetOpcode::G_LSHR: |
3176 | case TargetOpcode::G_SHL: { |
3177 | // Match: logic (binop x, z), (binop y, z) -> binop (logic x, y), z |
3178 | MachineOperand &ZOp = LeftHandInst->getOperand(i: 2); |
3179 | if (!matchEqualDefs(MOP1: ZOp, MOP2: RightHandInst->getOperand(i: 2))) |
3180 | return false; |
3181 | ExtraHandOpSrcReg = ZOp.getReg(); |
3182 | break; |
3183 | } |
3184 | } |
3185 | |
3186 | if (!isLegalOrBeforeLegalizer(Query: {LogicOpcode, {XTy, YTy}})) |
3187 | return false; |
3188 | |
3189 | // Record the steps to build the new instructions. |
3190 | // |
3191 | // Steps to build (logic x, y) |
3192 | auto NewLogicDst = MRI.createGenericVirtualRegister(Ty: XTy); |
3193 | OperandBuildSteps LogicBuildSteps = { |
3194 | [=](MachineInstrBuilder &MIB) { MIB.addDef(RegNo: NewLogicDst); }, |
3195 | [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: X); }, |
3196 | [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: Y); }}; |
3197 | InstructionBuildSteps LogicSteps(LogicOpcode, LogicBuildSteps); |
3198 | |
3199 | // Steps to build hand (logic x, y), ...z |
3200 | OperandBuildSteps HandBuildSteps = { |
3201 | [=](MachineInstrBuilder &MIB) { MIB.addDef(RegNo: Dst); }, |
3202 | [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: NewLogicDst); }}; |
3203 | if (ExtraHandOpSrcReg.isValid()) |
3204 | HandBuildSteps.push_back( |
3205 | Elt: [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: ExtraHandOpSrcReg); }); |
3206 | InstructionBuildSteps HandSteps(HandOpcode, HandBuildSteps); |
3207 | |
3208 | MatchInfo = InstructionStepsMatchInfo({LogicSteps, HandSteps}); |
3209 | return true; |
3210 | } |
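
// Illustrative example of the hoisting prepared above, using G_TRUNC hands
// (hypothetical registers):
//   %lx:_(s32) = G_TRUNC %x:_(s64)
//   %ly:_(s32) = G_TRUNC %y:_(s64)
//   %r:_(s32)  = G_AND %lx, %ly
// may become
//   %a:_(s64)  = G_AND %x, %y
//   %r:_(s32)  = G_TRUNC %a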
3211 | |
3212 | void CombinerHelper::applyBuildInstructionSteps( |
3213 | MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) { |
3214 | assert(MatchInfo.InstrsToBuild.size() && |
3215 | "Expected at least one instr to build?" ); |
3216 | for (auto &InstrToBuild : MatchInfo.InstrsToBuild) { |
3217 | assert(InstrToBuild.Opcode && "Expected a valid opcode?" ); |
3218 | assert(InstrToBuild.OperandFns.size() && "Expected at least one operand?" ); |
3219 | MachineInstrBuilder Instr = Builder.buildInstr(Opcode: InstrToBuild.Opcode); |
3220 | for (auto &OperandFn : InstrToBuild.OperandFns) |
3221 | OperandFn(Instr); |
3222 | } |
3223 | MI.eraseFromParent(); |
3224 | } |
3225 | |
3226 | bool CombinerHelper::matchAshrShlToSextInreg( |
3227 | MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) { |
3228 | assert(MI.getOpcode() == TargetOpcode::G_ASHR); |
3229 | int64_t ShlCst, AshrCst; |
3230 | Register Src; |
3231 | if (!mi_match(R: MI.getOperand(i: 0).getReg(), MRI, |
3232 | P: m_GAShr(L: m_GShl(L: m_Reg(R&: Src), R: m_ICstOrSplat(Cst&: ShlCst)), |
3233 | R: m_ICstOrSplat(Cst&: AshrCst)))) |
3234 | return false; |
3235 | if (ShlCst != AshrCst) |
3236 | return false; |
3237 | if (!isLegalOrBeforeLegalizer( |
3238 | Query: {TargetOpcode::G_SEXT_INREG, {MRI.getType(Reg: Src)}})) |
3239 | return false; |
3240 | MatchInfo = std::make_tuple(args&: Src, args&: ShlCst); |
3241 | return true; |
3242 | } |
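
// Illustration of the match above on s32 with a shift amount of 24
// (hypothetical registers):
//   %s:_(s32) = G_SHL %x, %c24
//   %r:_(s32) = G_ASHR %s, %c24
// becomes
//   %r:_(s32) = G_SEXT_INREG %x, 8    ; 32 - 24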
3243 | |
3244 | void CombinerHelper::applyAshShlToSextInreg( |
3245 | MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) { |
3246 | assert(MI.getOpcode() == TargetOpcode::G_ASHR); |
3247 | Register Src; |
3248 | int64_t ShiftAmt; |
3249 | std::tie(args&: Src, args&: ShiftAmt) = MatchInfo; |
3250 | unsigned Size = MRI.getType(Reg: Src).getScalarSizeInBits(); |
3251 | Builder.buildSExtInReg(Res: MI.getOperand(i: 0).getReg(), Op: Src, ImmOp: Size - ShiftAmt); |
3252 | MI.eraseFromParent(); |
3253 | } |
3254 | |
3255 | /// and(and(x, C1), C2) -> C1&C2 ? and(x, C1&C2) : 0 |
3256 | bool CombinerHelper::matchOverlappingAnd( |
3257 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
3258 | assert(MI.getOpcode() == TargetOpcode::G_AND); |
3259 | |
3260 | Register Dst = MI.getOperand(i: 0).getReg(); |
3261 | LLT Ty = MRI.getType(Reg: Dst); |
3262 | |
3263 | Register R; |
3264 | int64_t C1; |
3265 | int64_t C2; |
3266 | if (!mi_match( |
3267 | R: Dst, MRI, |
3268 | P: m_GAnd(L: m_GAnd(L: m_Reg(R), R: m_ICst(Cst&: C1)), R: m_ICst(Cst&: C2)))) |
3269 | return false; |
3270 | |
3271 | MatchInfo = [=](MachineIRBuilder &B) { |
3272 | if (C1 & C2) { |
3273 | B.buildAnd(Dst, Src0: R, Src1: B.buildConstant(Res: Ty, Val: C1 & C2)); |
3274 | return; |
3275 | } |
3276 | auto Zero = B.buildConstant(Res: Ty, Val: 0); |
3277 | replaceRegWith(MRI, FromReg: Dst, ToReg: Zero->getOperand(i: 0).getReg()); |
3278 | }; |
3279 | return true; |
3280 | } |
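
// Illustrative example of the fold above (hypothetical registers):
//   %a:_(s32) = G_AND %x, %c255
//   %r:_(s32) = G_AND %a, %c15
// becomes
//   %r:_(s32) = G_AND %x, %c15        ; 255 & 15 == 15
// If the two masks had no bits in common (e.g. 0xF0 and 0x0F), %r would fold
// to the constant 0 instead.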
3281 | |
3282 | bool CombinerHelper::matchRedundantAnd(MachineInstr &MI, |
3283 | Register &Replacement) { |
3284 | // Given |
3285 | // |
3286 | // %y:_(sN) = G_SOMETHING |
3287 | // %x:_(sN) = G_SOMETHING |
3288 | // %res:_(sN) = G_AND %x, %y |
3289 | // |
3290 | // Eliminate the G_AND when it is known that x & y == x or x & y == y. |
3291 | // |
3292 | // Patterns like this can appear as a result of legalization. E.g. |
3293 | // |
3294 | // %cmp:_(s32) = G_ICMP intpred(pred), %x(s32), %y |
3295 | // %one:_(s32) = G_CONSTANT i32 1 |
3296 | // %and:_(s32) = G_AND %cmp, %one |
3297 | // |
3298 | // In this case, G_ICMP only produces a single bit, so x & 1 == x. |
3299 | assert(MI.getOpcode() == TargetOpcode::G_AND); |
3300 | if (!KB) |
3301 | return false; |
3302 | |
3303 | Register AndDst = MI.getOperand(i: 0).getReg(); |
3304 | Register LHS = MI.getOperand(i: 1).getReg(); |
3305 | Register RHS = MI.getOperand(i: 2).getReg(); |
3306 | |
3307 | // Check the RHS (maybe a constant) first, and if we have no KnownBits there, |
3308 | // we can't do anything. If we do, then it depends on whether we have |
3309 | // KnownBits on the LHS. |
3310 | KnownBits RHSBits = KB->getKnownBits(R: RHS); |
3311 | if (RHSBits.isUnknown()) |
3312 | return false; |
3313 | |
3314 | KnownBits LHSBits = KB->getKnownBits(R: LHS); |
3315 | |
3316 | // Check that x & Mask == x. |
3317 | // x & 1 == x, always |
3318 | // x & 0 == x, only if x is also 0 |
3319 | // Meaning Mask has no effect if every bit is either one in Mask or zero in x. |
3320 | // |
3321 | // Check if we can replace AndDst with the LHS of the G_AND |
3322 | if (canReplaceReg(DstReg: AndDst, SrcReg: LHS, MRI) && |
3323 | (LHSBits.Zero | RHSBits.One).isAllOnes()) { |
3324 | Replacement = LHS; |
3325 | return true; |
3326 | } |
3327 | |
3328 | // Check if we can replace AndDst with the RHS of the G_AND |
3329 | if (canReplaceReg(DstReg: AndDst, SrcReg: RHS, MRI) && |
3330 | (LHSBits.One | RHSBits.Zero).isAllOnes()) { |
3331 | Replacement = RHS; |
3332 | return true; |
3333 | } |
3334 | |
3335 | return false; |
3336 | } |
3337 | |
3338 | bool CombinerHelper::matchRedundantOr(MachineInstr &MI, Register &Replacement) { |
3339 | // Given |
3340 | // |
3341 | // %y:_(sN) = G_SOMETHING |
3342 | // %x:_(sN) = G_SOMETHING |
3343 | // %res:_(sN) = G_OR %x, %y |
3344 | // |
3345 | // Eliminate the G_OR when it is known that x | y == x or x | y == y. |
3346 | assert(MI.getOpcode() == TargetOpcode::G_OR); |
3347 | if (!KB) |
3348 | return false; |
3349 | |
3350 | Register OrDst = MI.getOperand(i: 0).getReg(); |
3351 | Register LHS = MI.getOperand(i: 1).getReg(); |
3352 | Register RHS = MI.getOperand(i: 2).getReg(); |
3353 | |
3354 | KnownBits LHSBits = KB->getKnownBits(R: LHS); |
3355 | KnownBits RHSBits = KB->getKnownBits(R: RHS); |
3356 | |
3357 | // Check that x | Mask == x. |
3358 | // x | 0 == x, always |
3359 | // x | 1 == x, only if x is also 1 |
3360 | // Meaning Mask has no effect if every bit is either zero in Mask or one in x. |
3361 | // |
3362 | // Check if we can replace OrDst with the LHS of the G_OR |
3363 | if (canReplaceReg(DstReg: OrDst, SrcReg: LHS, MRI) && |
3364 | (LHSBits.One | RHSBits.Zero).isAllOnes()) { |
3365 | Replacement = LHS; |
3366 | return true; |
3367 | } |
3368 | |
3369 | // Check if we can replace OrDst with the RHS of the G_OR |
3370 | if (canReplaceReg(DstReg: OrDst, SrcReg: RHS, MRI) && |
3371 | (LHSBits.Zero | RHSBits.One).isAllOnes()) { |
3372 | Replacement = RHS; |
3373 | return true; |
3374 | } |
3375 | |
3376 | return false; |
3377 | } |
3378 | |
3379 | bool CombinerHelper::matchRedundantSExtInReg(MachineInstr &MI) { |
3380 | // If the input is already sign extended, just drop the extension. |
3381 | Register Src = MI.getOperand(i: 1).getReg(); |
3382 | unsigned ExtBits = MI.getOperand(i: 2).getImm(); |
3383 | unsigned TypeSize = MRI.getType(Reg: Src).getScalarSizeInBits(); |
3384 | return KB->computeNumSignBits(R: Src) >= (TypeSize - ExtBits + 1); |
3385 | } |
3386 | |
3387 | static bool isConstValidTrue(const TargetLowering &TLI, unsigned ScalarSizeBits, |
3388 | int64_t Cst, bool IsVector, bool IsFP) { |
3389 | // For i1, Cst will always be -1 regardless of boolean contents. |
3390 | return (ScalarSizeBits == 1 && Cst == -1) || |
3391 | isConstTrueVal(TLI, Val: Cst, IsVector, IsFP); |
3392 | } |
3393 | |
3394 | bool CombinerHelper::matchNotCmp(MachineInstr &MI, |
3395 | SmallVectorImpl<Register> &RegsToNegate) { |
3396 | assert(MI.getOpcode() == TargetOpcode::G_XOR); |
3397 | LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
3398 | const auto &TLI = *Builder.getMF().getSubtarget().getTargetLowering(); |
3399 | Register XorSrc; |
3400 | Register CstReg; |
3401 | // We match xor(src, true) here. |
3402 | if (!mi_match(R: MI.getOperand(i: 0).getReg(), MRI, |
3403 | P: m_GXor(L: m_Reg(R&: XorSrc), R: m_Reg(R&: CstReg)))) |
3404 | return false; |
3405 | |
3406 | if (!MRI.hasOneNonDBGUse(RegNo: XorSrc)) |
3407 | return false; |
3408 | |
3409 | // Check that XorSrc is the root of a tree of comparisons combined with ANDs |
  // and ORs. The suffix of RegsToNegate starting from index I is used as a work
3411 | // list of tree nodes to visit. |
3412 | RegsToNegate.push_back(Elt: XorSrc); |
3413 | // Remember whether the comparisons are all integer or all floating point. |
3414 | bool IsInt = false; |
3415 | bool IsFP = false; |
3416 | for (unsigned I = 0; I < RegsToNegate.size(); ++I) { |
3417 | Register Reg = RegsToNegate[I]; |
3418 | if (!MRI.hasOneNonDBGUse(RegNo: Reg)) |
3419 | return false; |
3420 | MachineInstr *Def = MRI.getVRegDef(Reg); |
3421 | switch (Def->getOpcode()) { |
3422 | default: |
3423 | // Don't match if the tree contains anything other than ANDs, ORs and |
3424 | // comparisons. |
3425 | return false; |
3426 | case TargetOpcode::G_ICMP: |
3427 | if (IsFP) |
3428 | return false; |
3429 | IsInt = true; |
3430 | // When we apply the combine we will invert the predicate. |
3431 | break; |
3432 | case TargetOpcode::G_FCMP: |
3433 | if (IsInt) |
3434 | return false; |
3435 | IsFP = true; |
3436 | // When we apply the combine we will invert the predicate. |
3437 | break; |
3438 | case TargetOpcode::G_AND: |
3439 | case TargetOpcode::G_OR: |
3440 | // Implement De Morgan's laws: |
3441 | // ~(x & y) -> ~x | ~y |
3442 | // ~(x | y) -> ~x & ~y |
3443 | // When we apply the combine we will change the opcode and recursively |
3444 | // negate the operands. |
3445 | RegsToNegate.push_back(Elt: Def->getOperand(i: 1).getReg()); |
3446 | RegsToNegate.push_back(Elt: Def->getOperand(i: 2).getReg()); |
3447 | break; |
3448 | } |
3449 | } |
3450 | |
3451 | // Now we know whether the comparisons are integer or floating point, check |
3452 | // the constant in the xor. |
3453 | int64_t Cst; |
3454 | if (Ty.isVector()) { |
3455 | MachineInstr *CstDef = MRI.getVRegDef(Reg: CstReg); |
3456 | auto MaybeCst = getIConstantSplatSExtVal(MI: *CstDef, MRI); |
3457 | if (!MaybeCst) |
3458 | return false; |
3459 | if (!isConstValidTrue(TLI, ScalarSizeBits: Ty.getScalarSizeInBits(), Cst: *MaybeCst, IsVector: true, IsFP)) |
3460 | return false; |
3461 | } else { |
3462 | if (!mi_match(R: CstReg, MRI, P: m_ICst(Cst))) |
3463 | return false; |
3464 | if (!isConstValidTrue(TLI, ScalarSizeBits: Ty.getSizeInBits(), Cst, IsVector: false, IsFP)) |
3465 | return false; |
3466 | } |
3467 | |
3468 | return true; |
3469 | } |
3470 | |
3471 | void CombinerHelper::applyNotCmp(MachineInstr &MI, |
3472 | SmallVectorImpl<Register> &RegsToNegate) { |
3473 | for (Register Reg : RegsToNegate) { |
3474 | MachineInstr *Def = MRI.getVRegDef(Reg); |
3475 | Observer.changingInstr(MI&: *Def); |
3476 | // For each comparison, invert the opcode. For each AND and OR, change the |
3477 | // opcode. |
3478 | switch (Def->getOpcode()) { |
3479 | default: |
3480 | llvm_unreachable("Unexpected opcode" ); |
3481 | case TargetOpcode::G_ICMP: |
3482 | case TargetOpcode::G_FCMP: { |
3483 | MachineOperand &PredOp = Def->getOperand(i: 1); |
3484 | CmpInst::Predicate NewP = CmpInst::getInversePredicate( |
3485 | pred: (CmpInst::Predicate)PredOp.getPredicate()); |
3486 | PredOp.setPredicate(NewP); |
3487 | break; |
3488 | } |
3489 | case TargetOpcode::G_AND: |
3490 | Def->setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_OR)); |
3491 | break; |
3492 | case TargetOpcode::G_OR: |
3493 | Def->setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_AND)); |
3494 | break; |
3495 | } |
3496 | Observer.changedInstr(MI&: *Def); |
3497 | } |
3498 | |
3499 | replaceRegWith(MRI, FromReg: MI.getOperand(i: 0).getReg(), ToReg: MI.getOperand(i: 1).getReg()); |
3500 | MI.eraseFromParent(); |
3501 | } |
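
// Illustrative example of the De Morgan rewrite above (hypothetical
// registers; for s1 the "true" constant is matched as -1):
//   %c1:_(s1) = G_ICMP intpred(eq), %a, %b
//   %c2:_(s1) = G_ICMP intpred(ult), %c, %d
//   %o:_(s1)  = G_AND %c1, %c2
//   %r:_(s1)  = G_XOR %o, %true
// becomes
//   %c1:_(s1) = G_ICMP intpred(ne), %a, %b
//   %c2:_(s1) = G_ICMP intpred(uge), %c, %d
//   %r:_(s1)  = G_OR %c1, %c2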
3502 | |
3503 | bool CombinerHelper::matchXorOfAndWithSameReg( |
3504 | MachineInstr &MI, std::pair<Register, Register> &MatchInfo) { |
3505 | // Match (xor (and x, y), y) (or any of its commuted cases) |
3506 | assert(MI.getOpcode() == TargetOpcode::G_XOR); |
3507 | Register &X = MatchInfo.first; |
3508 | Register &Y = MatchInfo.second; |
3509 | Register AndReg = MI.getOperand(i: 1).getReg(); |
3510 | Register SharedReg = MI.getOperand(i: 2).getReg(); |
3511 | |
3512 | // Find a G_AND on either side of the G_XOR. |
3513 | // Look for one of |
3514 | // |
3515 | // (xor (and x, y), SharedReg) |
3516 | // (xor SharedReg, (and x, y)) |
3517 | if (!mi_match(R: AndReg, MRI, P: m_GAnd(L: m_Reg(R&: X), R: m_Reg(R&: Y)))) { |
3518 | std::swap(a&: AndReg, b&: SharedReg); |
3519 | if (!mi_match(R: AndReg, MRI, P: m_GAnd(L: m_Reg(R&: X), R: m_Reg(R&: Y)))) |
3520 | return false; |
3521 | } |
3522 | |
3523 | // Only do this if we'll eliminate the G_AND. |
3524 | if (!MRI.hasOneNonDBGUse(RegNo: AndReg)) |
3525 | return false; |
3526 | |
3527 | // We can combine if SharedReg is the same as either the LHS or RHS of the |
3528 | // G_AND. |
3529 | if (Y != SharedReg) |
3530 | std::swap(a&: X, b&: Y); |
3531 | return Y == SharedReg; |
3532 | } |
3533 | |
3534 | void CombinerHelper::applyXorOfAndWithSameReg( |
3535 | MachineInstr &MI, std::pair<Register, Register> &MatchInfo) { |
3536 | // Fold (xor (and x, y), y) -> (and (not x), y) |
3537 | Register X, Y; |
3538 | std::tie(args&: X, args&: Y) = MatchInfo; |
3539 | auto Not = Builder.buildNot(Dst: MRI.getType(Reg: X), Src0: X); |
3540 | Observer.changingInstr(MI); |
3541 | MI.setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_AND)); |
3542 | MI.getOperand(i: 1).setReg(Not->getOperand(i: 0).getReg()); |
3543 | MI.getOperand(i: 2).setReg(Y); |
3544 | Observer.changedInstr(MI); |
3545 | } |
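
// Illustration of the fold above (hypothetical registers; G_XOR with an
// all-ones constant is how buildNot materializes the NOT):
//   %a:_(s32) = G_AND %x, %y
//   %r:_(s32) = G_XOR %a, %y
// becomes
//   %n:_(s32) = G_XOR %x, %ones
//   %r:_(s32) = G_AND %n, %y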
3546 | |
3547 | bool CombinerHelper::matchPtrAddZero(MachineInstr &MI) { |
3548 | auto &PtrAdd = cast<GPtrAdd>(Val&: MI); |
3549 | Register DstReg = PtrAdd.getReg(Idx: 0); |
3550 | LLT Ty = MRI.getType(Reg: DstReg); |
3551 | const DataLayout &DL = Builder.getMF().getDataLayout(); |
3552 | |
3553 | if (DL.isNonIntegralAddressSpace(AddrSpace: Ty.getScalarType().getAddressSpace())) |
3554 | return false; |
3555 | |
3556 | if (Ty.isPointer()) { |
3557 | auto ConstVal = getIConstantVRegVal(VReg: PtrAdd.getBaseReg(), MRI); |
3558 | return ConstVal && *ConstVal == 0; |
3559 | } |
3560 | |
3561 | assert(Ty.isVector() && "Expecting a vector type" ); |
3562 | const MachineInstr *VecMI = MRI.getVRegDef(Reg: PtrAdd.getBaseReg()); |
3563 | return isBuildVectorAllZeros(MI: *VecMI, MRI); |
3564 | } |
3565 | |
3566 | void CombinerHelper::applyPtrAddZero(MachineInstr &MI) { |
3567 | auto &PtrAdd = cast<GPtrAdd>(Val&: MI); |
3568 | Builder.buildIntToPtr(Dst: PtrAdd.getReg(Idx: 0), Src: PtrAdd.getOffsetReg()); |
3569 | PtrAdd.eraseFromParent(); |
3570 | } |
3571 | |
3572 | /// The second source operand is known to be a power of 2. |
3573 | void CombinerHelper::applySimplifyURemByPow2(MachineInstr &MI) { |
3574 | Register DstReg = MI.getOperand(i: 0).getReg(); |
3575 | Register Src0 = MI.getOperand(i: 1).getReg(); |
3576 | Register Pow2Src1 = MI.getOperand(i: 2).getReg(); |
3577 | LLT Ty = MRI.getType(Reg: DstReg); |
3578 | |
3579 | // Fold (urem x, pow2) -> (and x, pow2-1) |
3580 | auto NegOne = Builder.buildConstant(Res: Ty, Val: -1); |
3581 | auto Add = Builder.buildAdd(Dst: Ty, Src0: Pow2Src1, Src1: NegOne); |
3582 | Builder.buildAnd(Dst: DstReg, Src0, Src1: Add); |
3583 | MI.eraseFromParent(); |
3584 | } |
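
// Illustrative example of the rewrite above (hypothetical registers, with
// %eight known to be a power of two):
//   %r:_(s32) = G_UREM %x, %eight
// becomes (with %negone = G_CONSTANT i32 -1)
//   %m:_(s32) = G_ADD %eight, %negone  ; pow2 - 1
//   %r:_(s32) = G_AND %x, %m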
3585 | |
3586 | bool CombinerHelper::matchFoldBinOpIntoSelect(MachineInstr &MI, |
3587 | unsigned &SelectOpNo) { |
3588 | Register LHS = MI.getOperand(i: 1).getReg(); |
3589 | Register RHS = MI.getOperand(i: 2).getReg(); |
3590 | |
3591 | Register OtherOperandReg = RHS; |
3592 | SelectOpNo = 1; |
3593 | MachineInstr *Select = MRI.getVRegDef(Reg: LHS); |
3594 | |
3595 | // Don't do this unless the old select is going away. We want to eliminate the |
3596 | // binary operator, not replace a binop with a select. |
3597 | if (Select->getOpcode() != TargetOpcode::G_SELECT || |
3598 | !MRI.hasOneNonDBGUse(RegNo: LHS)) { |
3599 | OtherOperandReg = LHS; |
3600 | SelectOpNo = 2; |
3601 | Select = MRI.getVRegDef(Reg: RHS); |
3602 | if (Select->getOpcode() != TargetOpcode::G_SELECT || |
3603 | !MRI.hasOneNonDBGUse(RegNo: RHS)) |
3604 | return false; |
3605 | } |
3606 | |
3607 | MachineInstr *SelectLHS = MRI.getVRegDef(Reg: Select->getOperand(i: 2).getReg()); |
3608 | MachineInstr *SelectRHS = MRI.getVRegDef(Reg: Select->getOperand(i: 3).getReg()); |
3609 | |
3610 | if (!isConstantOrConstantVector(MI: *SelectLHS, MRI, |
3611 | /*AllowFP*/ true, |
3612 | /*AllowOpaqueConstants*/ false)) |
3613 | return false; |
3614 | if (!isConstantOrConstantVector(MI: *SelectRHS, MRI, |
3615 | /*AllowFP*/ true, |
3616 | /*AllowOpaqueConstants*/ false)) |
3617 | return false; |
3618 | |
3619 | unsigned BinOpcode = MI.getOpcode(); |
3620 | |
3621 | // We know that one of the operands is a select of constants. Now verify that |
3622 | // the other binary operator operand is either a constant, or we can handle a |
3623 | // variable. |
3624 | bool CanFoldNonConst = |
3625 | (BinOpcode == TargetOpcode::G_AND || BinOpcode == TargetOpcode::G_OR) && |
3626 | (isNullOrNullSplat(MI: *SelectLHS, MRI) || |
3627 | isAllOnesOrAllOnesSplat(MI: *SelectLHS, MRI)) && |
3628 | (isNullOrNullSplat(MI: *SelectRHS, MRI) || |
3629 | isAllOnesOrAllOnesSplat(MI: *SelectRHS, MRI)); |
3630 | if (CanFoldNonConst) |
3631 | return true; |
3632 | |
3633 | return isConstantOrConstantVector(MI: *MRI.getVRegDef(Reg: OtherOperandReg), MRI, |
3634 | /*AllowFP*/ true, |
3635 | /*AllowOpaqueConstants*/ false); |
3636 | } |
3637 | |
3638 | /// \p SelectOperand is the operand in binary operator \p MI that is the select |
3639 | /// to fold. |
3640 | void CombinerHelper::applyFoldBinOpIntoSelect(MachineInstr &MI, |
3641 | const unsigned &SelectOperand) { |
3642 | Register Dst = MI.getOperand(i: 0).getReg(); |
3643 | Register LHS = MI.getOperand(i: 1).getReg(); |
3644 | Register RHS = MI.getOperand(i: 2).getReg(); |
3645 | MachineInstr *Select = MRI.getVRegDef(Reg: MI.getOperand(i: SelectOperand).getReg()); |
3646 | |
3647 | Register SelectCond = Select->getOperand(i: 1).getReg(); |
3648 | Register SelectTrue = Select->getOperand(i: 2).getReg(); |
3649 | Register SelectFalse = Select->getOperand(i: 3).getReg(); |
3650 | |
3651 | LLT Ty = MRI.getType(Reg: Dst); |
3652 | unsigned BinOpcode = MI.getOpcode(); |
3653 | |
3654 | Register FoldTrue, FoldFalse; |
3655 | |
3656 | // We have a select-of-constants followed by a binary operator with a |
3657 | // constant. Eliminate the binop by pulling the constant math into the select. |
3658 | // Example: add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO |
3659 | if (SelectOperand == 1) { |
3660 | // TODO: SelectionDAG verifies this actually constant folds before |
3661 | // committing to the combine. |
3662 | |
3663 | FoldTrue = Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {SelectTrue, RHS}).getReg(Idx: 0); |
3664 | FoldFalse = |
3665 | Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {SelectFalse, RHS}).getReg(Idx: 0); |
3666 | } else { |
3667 | FoldTrue = Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {LHS, SelectTrue}).getReg(Idx: 0); |
3668 | FoldFalse = |
3669 | Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {LHS, SelectFalse}).getReg(Idx: 0); |
3670 | } |
3671 | |
3672 | Builder.buildSelect(Res: Dst, Tst: SelectCond, Op0: FoldTrue, Op1: FoldFalse, Flags: MI.getFlags()); |
3673 | MI.eraseFromParent(); |
3674 | } |
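
// Illustration of the fold above for SelectOperand == 1 (hypothetical
// registers); the two new binops are expected to constant fold later:
//   %s:_(s32) = G_SELECT %cond, %ct, %cf
//   %r:_(s32) = G_ADD %s, %cbo
// becomes
//   %ft:_(s32) = G_ADD %ct, %cbo
//   %ff:_(s32) = G_ADD %cf, %cbo
//   %r:_(s32)  = G_SELECT %cond, %ft, %ff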
3675 | |
3676 | std::optional<SmallVector<Register, 8>> |
3677 | CombinerHelper::findCandidatesForLoadOrCombine(const MachineInstr *Root) const { |
3678 | assert(Root->getOpcode() == TargetOpcode::G_OR && "Expected G_OR only!" ); |
3679 | // We want to detect if Root is part of a tree which represents a bunch |
3680 | // of loads being merged into a larger load. We'll try to recognize patterns |
3681 | // like, for example: |
3682 | // |
3683 | // Reg Reg |
3684 | // \ / |
3685 | // OR_1 Reg |
3686 | // \ / |
3687 | // OR_2 |
3688 | // \ Reg |
3689 | // .. / |
3690 | // Root |
3691 | // |
3692 | // Reg Reg Reg Reg |
3693 | // \ / \ / |
3694 | // OR_1 OR_2 |
3695 | // \ / |
3696 | // \ / |
3697 | // ... |
3698 | // Root |
3699 | // |
3700 | // Each "Reg" may have been produced by a load + some arithmetic. This |
3701 | // function will save each of them. |
3702 | SmallVector<Register, 8> RegsToVisit; |
3703 | SmallVector<const MachineInstr *, 7> Ors = {Root}; |
3704 | |
3705 | // In the "worst" case, we're dealing with a load for each byte. So, there |
3706 | // are at most #bytes - 1 ORs. |
3707 | const unsigned MaxIter = |
3708 | MRI.getType(Reg: Root->getOperand(i: 0).getReg()).getSizeInBytes() - 1; |
3709 | for (unsigned Iter = 0; Iter < MaxIter; ++Iter) { |
3710 | if (Ors.empty()) |
3711 | break; |
3712 | const MachineInstr *Curr = Ors.pop_back_val(); |
3713 | Register OrLHS = Curr->getOperand(i: 1).getReg(); |
3714 | Register OrRHS = Curr->getOperand(i: 2).getReg(); |
3715 | |
    // In the combine, we want to eliminate the entire tree.
3717 | if (!MRI.hasOneNonDBGUse(RegNo: OrLHS) || !MRI.hasOneNonDBGUse(RegNo: OrRHS)) |
3718 | return std::nullopt; |
3719 | |
3720 | // If it's a G_OR, save it and continue to walk. If it's not, then it's |
3721 | // something that may be a load + arithmetic. |
3722 | if (const MachineInstr *Or = getOpcodeDef(Opcode: TargetOpcode::G_OR, Reg: OrLHS, MRI)) |
3723 | Ors.push_back(Elt: Or); |
3724 | else |
3725 | RegsToVisit.push_back(Elt: OrLHS); |
3726 | if (const MachineInstr *Or = getOpcodeDef(Opcode: TargetOpcode::G_OR, Reg: OrRHS, MRI)) |
3727 | Ors.push_back(Elt: Or); |
3728 | else |
3729 | RegsToVisit.push_back(Elt: OrRHS); |
3730 | } |
3731 | |
3732 | // We're going to try and merge each register into a wider power-of-2 type, |
3733 | // so we ought to have an even number of registers. |
3734 | if (RegsToVisit.empty() || RegsToVisit.size() % 2 != 0) |
3735 | return std::nullopt; |
3736 | return RegsToVisit; |
3737 | } |
3738 | |
3739 | /// Helper function for findLoadOffsetsForLoadOrCombine. |
3740 | /// |
3741 | /// Check if \p Reg is the result of loading a \p MemSizeInBits wide value, |
3742 | /// and then moving that value into a specific byte offset. |
3743 | /// |
3744 | /// e.g. x[i] << 24 |
3745 | /// |
3746 | /// \returns The load instruction and the byte offset it is moved into. |
3747 | static std::optional<std::pair<GZExtLoad *, int64_t>> |
3748 | matchLoadAndBytePosition(Register Reg, unsigned MemSizeInBits, |
3749 | const MachineRegisterInfo &MRI) { |
3750 | assert(MRI.hasOneNonDBGUse(Reg) && |
3751 | "Expected Reg to only have one non-debug use?" ); |
3752 | Register MaybeLoad; |
3753 | int64_t Shift; |
3754 | if (!mi_match(R: Reg, MRI, |
3755 | P: m_OneNonDBGUse(SP: m_GShl(L: m_Reg(R&: MaybeLoad), R: m_ICst(Cst&: Shift))))) { |
3756 | Shift = 0; |
3757 | MaybeLoad = Reg; |
3758 | } |
3759 | |
3760 | if (Shift % MemSizeInBits != 0) |
3761 | return std::nullopt; |
3762 | |
3763 | // TODO: Handle other types of loads. |
3764 | auto *Load = getOpcodeDef<GZExtLoad>(Reg: MaybeLoad, MRI); |
3765 | if (!Load) |
3766 | return std::nullopt; |
3767 | |
3768 | if (!Load->isUnordered() || Load->getMemSizeInBits() != MemSizeInBits) |
3769 | return std::nullopt; |
3770 | |
3771 | return std::make_pair(x&: Load, y: Shift / MemSizeInBits); |
3772 | } |
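
// Illustrative example of the helper above with an 8-bit memory access
// (hypothetical registers):
//   %l:_(s32) = G_ZEXTLOAD %p :: (load (s8))
//   %v:_(s32) = G_SHL %l, %c16
// yields the pair (load, 2), i.e. the loaded byte lands at byte offset 2 of
// the combined value.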
3773 | |
3774 | std::optional<std::tuple<GZExtLoad *, int64_t, GZExtLoad *>> |
3775 | CombinerHelper::findLoadOffsetsForLoadOrCombine( |
3776 | SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx, |
3777 | const SmallVector<Register, 8> &RegsToVisit, const unsigned MemSizeInBits) { |
3778 | |
3779 | // Each load found for the pattern. There should be one for each RegsToVisit. |
3780 | SmallSetVector<const MachineInstr *, 8> Loads; |
3781 | |
3782 | // The lowest index used in any load. (The lowest "i" for each x[i].) |
3783 | int64_t LowestIdx = INT64_MAX; |
3784 | |
3785 | // The load which uses the lowest index. |
3786 | GZExtLoad *LowestIdxLoad = nullptr; |
3787 | |
3788 | // Keeps track of the load indices we see. We shouldn't see any indices twice. |
3789 | SmallSet<int64_t, 8> SeenIdx; |
3790 | |
3791 | // Ensure each load is in the same MBB. |
3792 | // TODO: Support multiple MachineBasicBlocks. |
3793 | MachineBasicBlock *MBB = nullptr; |
3794 | const MachineMemOperand *MMO = nullptr; |
3795 | |
3796 | // Earliest instruction-order load in the pattern. |
3797 | GZExtLoad *EarliestLoad = nullptr; |
3798 | |
3799 | // Latest instruction-order load in the pattern. |
3800 | GZExtLoad *LatestLoad = nullptr; |
3801 | |
3802 | // Base pointer which every load should share. |
3803 | Register BasePtr; |
3804 | |
3805 | // We want to find a load for each register. Each load should have some |
3806 | // appropriate bit twiddling arithmetic. During this loop, we will also keep |
3807 | // track of the load which uses the lowest index. Later, we will check if we |
3808 | // can use its pointer in the final, combined load. |
3809 | for (auto Reg : RegsToVisit) { |
    // Find the load, and the byte position that its value will end up at in
    // the final (e.g. shifted) value.
3812 | auto LoadAndPos = matchLoadAndBytePosition(Reg, MemSizeInBits, MRI); |
3813 | if (!LoadAndPos) |
3814 | return std::nullopt; |
3815 | GZExtLoad *Load; |
3816 | int64_t DstPos; |
3817 | std::tie(args&: Load, args&: DstPos) = *LoadAndPos; |
3818 | |
3819 | // TODO: Handle multiple MachineBasicBlocks. Currently not handled because |
3820 | // it is difficult to check for stores/calls/etc between loads. |
3821 | MachineBasicBlock *LoadMBB = Load->getParent(); |
3822 | if (!MBB) |
3823 | MBB = LoadMBB; |
3824 | if (LoadMBB != MBB) |
3825 | return std::nullopt; |
3826 | |
3827 | // Make sure that the MachineMemOperands of every seen load are compatible. |
3828 | auto &LoadMMO = Load->getMMO(); |
3829 | if (!MMO) |
3830 | MMO = &LoadMMO; |
3831 | if (MMO->getAddrSpace() != LoadMMO.getAddrSpace()) |
3832 | return std::nullopt; |
3833 | |
3834 | // Find out what the base pointer and index for the load is. |
3835 | Register LoadPtr; |
3836 | int64_t Idx; |
3837 | if (!mi_match(R: Load->getOperand(i: 1).getReg(), MRI, |
3838 | P: m_GPtrAdd(L: m_Reg(R&: LoadPtr), R: m_ICst(Cst&: Idx)))) { |
3839 | LoadPtr = Load->getOperand(i: 1).getReg(); |
3840 | Idx = 0; |
3841 | } |
3842 | |
3843 | // Don't combine things like a[i], a[i] -> a bigger load. |
3844 | if (!SeenIdx.insert(V: Idx).second) |
3845 | return std::nullopt; |
3846 | |
3847 | // Every load must share the same base pointer; don't combine things like: |
3848 | // |
3849 | // a[i], b[i + 1] -> a bigger load. |
3850 | if (!BasePtr.isValid()) |
3851 | BasePtr = LoadPtr; |
3852 | if (BasePtr != LoadPtr) |
3853 | return std::nullopt; |
3854 | |
3855 | if (Idx < LowestIdx) { |
3856 | LowestIdx = Idx; |
3857 | LowestIdxLoad = Load; |
3858 | } |
3859 | |
3860 | // Keep track of the byte offset that this load ends up at. If we have seen |
3861 | // the byte offset, then stop here. We do not want to combine: |
3862 | // |
3863 | // a[i] << 16, a[i + k] << 16 -> a bigger load. |
3864 | if (!MemOffset2Idx.try_emplace(Key: DstPos, Args&: Idx).second) |
3865 | return std::nullopt; |
3866 | Loads.insert(X: Load); |
3867 | |
3868 | // Keep track of the position of the earliest/latest loads in the pattern. |
3869 | // We will check that there are no load fold barriers between them later |
3870 | // on. |
3871 | // |
3872 | // FIXME: Is there a better way to check for load fold barriers? |
3873 | if (!EarliestLoad || dominates(DefMI: *Load, UseMI: *EarliestLoad)) |
3874 | EarliestLoad = Load; |
3875 | if (!LatestLoad || dominates(DefMI: *LatestLoad, UseMI: *Load)) |
3876 | LatestLoad = Load; |
3877 | } |
3878 | |
3879 | // We found a load for each register. Let's check if each load satisfies the |
3880 | // pattern. |
3881 | assert(Loads.size() == RegsToVisit.size() && |
3882 | "Expected to find a load for each register?" ); |
3883 | assert(EarliestLoad != LatestLoad && EarliestLoad && |
3884 | LatestLoad && "Expected at least two loads?" ); |
3885 | |
3886 | // Check if there are any stores, calls, etc. between any of the loads. If |
3887 | // there are, then we can't safely perform the combine. |
3888 | // |
3889 | // MaxIter is chosen based off the (worst case) number of iterations it |
3890 | // typically takes to succeed in the LLVM test suite plus some padding. |
3891 | // |
3892 | // FIXME: Is there a better way to check for load fold barriers? |
3893 | const unsigned MaxIter = 20; |
3894 | unsigned Iter = 0; |
3895 | for (const auto &MI : instructionsWithoutDebug(It: EarliestLoad->getIterator(), |
3896 | End: LatestLoad->getIterator())) { |
3897 | if (Loads.count(key: &MI)) |
3898 | continue; |
3899 | if (MI.isLoadFoldBarrier()) |
3900 | return std::nullopt; |
3901 | if (Iter++ == MaxIter) |
3902 | return std::nullopt; |
3903 | } |
3904 | |
3905 | return std::make_tuple(args&: LowestIdxLoad, args&: LowestIdx, args&: LatestLoad); |
3906 | } |
3907 | |
3908 | bool CombinerHelper::matchLoadOrCombine( |
3909 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
3910 | assert(MI.getOpcode() == TargetOpcode::G_OR); |
3911 | MachineFunction &MF = *MI.getMF(); |
3912 | // Assuming a little-endian target, transform: |
3913 | // s8 *a = ... |
3914 | // s32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24) |
3915 | // => |
3916 | // s32 val = *((i32)a) |
3917 | // |
3918 | // s8 *a = ... |
3919 | // s32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3] |
3920 | // => |
3921 | // s32 val = BSWAP(*((s32)a)) |
3922 | Register Dst = MI.getOperand(i: 0).getReg(); |
3923 | LLT Ty = MRI.getType(Reg: Dst); |
3924 | if (Ty.isVector()) |
3925 | return false; |
3926 | |
3927 | // We need to combine at least two loads into this type. Since the smallest |
3928 | // possible load is into a byte, we need at least a 16-bit wide type. |
3929 | const unsigned WideMemSizeInBits = Ty.getSizeInBits(); |
3930 | if (WideMemSizeInBits < 16 || WideMemSizeInBits % 8 != 0) |
3931 | return false; |
3932 | |
3933 | // Match a collection of non-OR instructions in the pattern. |
3934 | auto RegsToVisit = findCandidatesForLoadOrCombine(Root: &MI); |
3935 | if (!RegsToVisit) |
3936 | return false; |
3937 | |
3938 | // We have a collection of non-OR instructions. Figure out how wide each of |
3939 | // the small loads should be based off of the number of potential loads we |
3940 | // found. |
3941 | const unsigned NarrowMemSizeInBits = WideMemSizeInBits / RegsToVisit->size(); |
3942 | if (NarrowMemSizeInBits % 8 != 0) |
3943 | return false; |
3944 | |
3945 | // Check if each register feeding into each OR is a load from the same |
3946 | // base pointer + some arithmetic. |
3947 | // |
3948 | // e.g. a[0], a[1] << 8, a[2] << 16, etc. |
3949 | // |
3950 | // Also verify that each of these ends up putting a[i] into the same memory |
3951 | // offset as a load into a wide type would. |
3952 | SmallDenseMap<int64_t, int64_t, 8> MemOffset2Idx; |
3953 | GZExtLoad *LowestIdxLoad, *LatestLoad; |
3954 | int64_t LowestIdx; |
3955 | auto MaybeLoadInfo = findLoadOffsetsForLoadOrCombine( |
3956 | MemOffset2Idx, RegsToVisit: *RegsToVisit, MemSizeInBits: NarrowMemSizeInBits); |
3957 | if (!MaybeLoadInfo) |
3958 | return false; |
3959 | std::tie(args&: LowestIdxLoad, args&: LowestIdx, args&: LatestLoad) = *MaybeLoadInfo; |
3960 | |
3961 | // We have a bunch of loads being OR'd together. Using the addresses + offsets |
3962 | // we found before, check if this corresponds to a big or little endian byte |
3963 | // pattern. If it does, then we can represent it using a load + possibly a |
3964 | // BSWAP. |
3965 | bool IsBigEndianTarget = MF.getDataLayout().isBigEndian(); |
3966 | std::optional<bool> IsBigEndian = isBigEndian(MemOffset2Idx, LowestIdx); |
3967 | if (!IsBigEndian) |
3968 | return false; |
3969 | bool NeedsBSwap = IsBigEndianTarget != *IsBigEndian; |
3970 | if (NeedsBSwap && !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_BSWAP, {Ty}})) |
3971 | return false; |
3972 | |
3973 | // Make sure that the load from the lowest index produces offset 0 in the |
3974 | // final value. |
3975 | // |
3976 | // This ensures that we won't combine something like this: |
3977 | // |
3978 | // load x[i] -> byte 2 |
3979 | // load x[i+1] -> byte 0 ---> wide_load x[i] |
3980 | // load x[i+2] -> byte 1 |
3981 | const unsigned NumLoadsInTy = WideMemSizeInBits / NarrowMemSizeInBits; |
3982 | const unsigned ZeroByteOffset = |
3983 | *IsBigEndian |
3984 | ? bigEndianByteAt(ByteWidth: NumLoadsInTy, I: 0) |
3985 | : littleEndianByteAt(ByteWidth: NumLoadsInTy, I: 0); |
3986 | auto ZeroOffsetIdx = MemOffset2Idx.find(Val: ZeroByteOffset); |
3987 | if (ZeroOffsetIdx == MemOffset2Idx.end() || |
3988 | ZeroOffsetIdx->second != LowestIdx) |
3989 | return false; |
3990 | |
  // We will reuse the pointer from the load which ends up at byte offset 0. It
3992 | // may not use index 0. |
3993 | Register Ptr = LowestIdxLoad->getPointerReg(); |
3994 | const MachineMemOperand &MMO = LowestIdxLoad->getMMO(); |
3995 | LegalityQuery::MemDesc MMDesc(MMO); |
3996 | MMDesc.MemoryTy = Ty; |
3997 | if (!isLegalOrBeforeLegalizer( |
3998 | Query: {TargetOpcode::G_LOAD, {Ty, MRI.getType(Reg: Ptr)}, {MMDesc}})) |
3999 | return false; |
4000 | auto PtrInfo = MMO.getPointerInfo(); |
4001 | auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Size: WideMemSizeInBits / 8); |
4002 | |
4003 | // Load must be allowed and fast on the target. |
4004 | LLVMContext &C = MF.getFunction().getContext(); |
4005 | auto &DL = MF.getDataLayout(); |
4006 | unsigned Fast = 0; |
4007 | if (!getTargetLowering().allowsMemoryAccess(Context&: C, DL, Ty, MMO: *NewMMO, Fast: &Fast) || |
4008 | !Fast) |
4009 | return false; |
4010 | |
4011 | MatchInfo = [=](MachineIRBuilder &MIB) { |
4012 | MIB.setInstrAndDebugLoc(*LatestLoad); |
4013 | Register LoadDst = NeedsBSwap ? MRI.cloneVirtualRegister(VReg: Dst) : Dst; |
4014 | MIB.buildLoad(Res: LoadDst, Addr: Ptr, MMO&: *NewMMO); |
4015 | if (NeedsBSwap) |
4016 | MIB.buildBSwap(Dst, Src0: LoadDst); |
4017 | }; |
4018 | return true; |
4019 | } |
4020 | |
4021 | bool CombinerHelper::matchExtendThroughPhis(MachineInstr &MI, |
4022 | MachineInstr *&ExtMI) { |
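  // Illustrative sketch of the transform this match/apply pair performs
  // (generic MIR; register names are for exposition only):
  //
  //   bb.1: %a:_(s8) = G_LOAD ...
  //   bb.2: %b:_(s8) = G_LOAD ...
  //   bb.3: %phi:_(s8) = G_PHI %a(s8), %bb.1, %b(s8), %bb.2
  //         %ext:_(s32) = G_SEXT %phi(s8)
  // =>
  //   bb.1: %a:_(s8) = G_LOAD ...
  //         %a_ext:_(s32) = G_SEXT %a(s8)
  //   bb.2: %b:_(s8) = G_LOAD ...
  //         %b_ext:_(s32) = G_SEXT %b(s8)
  //   bb.3: %ext:_(s32) = G_PHI %a_ext(s32), %bb.1, %b_ext(s32), %bb.2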
4023 | auto &PHI = cast<GPhi>(Val&: MI); |
4024 | Register DstReg = PHI.getReg(Idx: 0); |
4025 | |
4026 | // TODO: Extending a vector may be expensive, don't do this until heuristics |
4027 | // are better. |
4028 | if (MRI.getType(Reg: DstReg).isVector()) |
4029 | return false; |
4030 | |
  // Try to match a phi whose only use is an extend.
4032 | if (!MRI.hasOneNonDBGUse(RegNo: DstReg)) |
4033 | return false; |
4034 | ExtMI = &*MRI.use_instr_nodbg_begin(RegNo: DstReg); |
4035 | switch (ExtMI->getOpcode()) { |
4036 | case TargetOpcode::G_ANYEXT: |
4037 | return true; // G_ANYEXT is usually free. |
4038 | case TargetOpcode::G_ZEXT: |
4039 | case TargetOpcode::G_SEXT: |
4040 | break; |
4041 | default: |
4042 | return false; |
4043 | } |
4044 | |
4045 | // If the target is likely to fold this extend away, don't propagate. |
4046 | if (Builder.getTII().isExtendLikelyToBeFolded(ExtMI&: *ExtMI, MRI)) |
4047 | return false; |
4048 | |
4049 | // We don't want to propagate the extends unless there's a good chance that |
4050 | // they'll be optimized in some way. |
4051 | // Collect the unique incoming values. |
4052 | SmallPtrSet<MachineInstr *, 4> InSrcs; |
4053 | for (unsigned I = 0; I < PHI.getNumIncomingValues(); ++I) { |
4054 | auto *DefMI = getDefIgnoringCopies(Reg: PHI.getIncomingValue(I), MRI); |
4055 | switch (DefMI->getOpcode()) { |
4056 | case TargetOpcode::G_LOAD: |
4057 | case TargetOpcode::G_TRUNC: |
4058 | case TargetOpcode::G_SEXT: |
4059 | case TargetOpcode::G_ZEXT: |
4060 | case TargetOpcode::G_ANYEXT: |
4061 | case TargetOpcode::G_CONSTANT: |
4062 | InSrcs.insert(Ptr: DefMI); |
4063 | // Don't try to propagate if there are too many places to create new |
4064 | // extends, chances are it'll increase code size. |
4065 | if (InSrcs.size() > 2) |
4066 | return false; |
4067 | break; |
4068 | default: |
4069 | return false; |
4070 | } |
4071 | } |
4072 | return true; |
4073 | } |
4074 | |
4075 | void CombinerHelper::applyExtendThroughPhis(MachineInstr &MI, |
4076 | MachineInstr *&ExtMI) { |
4077 | auto &PHI = cast<GPhi>(Val&: MI); |
4078 | Register DstReg = ExtMI->getOperand(i: 0).getReg(); |
4079 | LLT ExtTy = MRI.getType(Reg: DstReg); |
4080 | |
  // Propagate the extension into each incoming reg's block.
4082 | // Use a SetVector here because PHIs can have duplicate edges, and we want |
4083 | // deterministic iteration order. |
4084 | SmallSetVector<MachineInstr *, 8> SrcMIs; |
4085 | SmallDenseMap<MachineInstr *, MachineInstr *, 8> OldToNewSrcMap; |
4086 | for (unsigned I = 0; I < PHI.getNumIncomingValues(); ++I) { |
4087 | auto SrcReg = PHI.getIncomingValue(I); |
4088 | auto *SrcMI = MRI.getVRegDef(Reg: SrcReg); |
4089 | if (!SrcMIs.insert(X: SrcMI)) |
4090 | continue; |
4091 | |
4092 | // Build an extend after each src inst. |
4093 | auto *MBB = SrcMI->getParent(); |
4094 | MachineBasicBlock::iterator InsertPt = ++SrcMI->getIterator(); |
4095 | if (InsertPt != MBB->end() && InsertPt->isPHI()) |
4096 | InsertPt = MBB->getFirstNonPHI(); |
4097 | |
4098 | Builder.setInsertPt(MBB&: *SrcMI->getParent(), II: InsertPt); |
4099 | Builder.setDebugLoc(MI.getDebugLoc()); |
4100 | auto NewExt = Builder.buildExtOrTrunc(ExtOpc: ExtMI->getOpcode(), Res: ExtTy, Op: SrcReg); |
4101 | OldToNewSrcMap[SrcMI] = NewExt; |
4102 | } |
4103 | |
4104 | // Create a new phi with the extended inputs. |
4105 | Builder.setInstrAndDebugLoc(MI); |
4106 | auto NewPhi = Builder.buildInstrNoInsert(Opcode: TargetOpcode::G_PHI); |
4107 | NewPhi.addDef(RegNo: DstReg); |
4108 | for (const MachineOperand &MO : llvm::drop_begin(RangeOrContainer: MI.operands())) { |
4109 | if (!MO.isReg()) { |
4110 | NewPhi.addMBB(MBB: MO.getMBB()); |
4111 | continue; |
4112 | } |
4113 | auto *NewSrc = OldToNewSrcMap[MRI.getVRegDef(Reg: MO.getReg())]; |
4114 | NewPhi.addUse(RegNo: NewSrc->getOperand(i: 0).getReg()); |
4115 | } |
4116 | Builder.insertInstr(MIB: NewPhi); |
4117 | ExtMI->eraseFromParent(); |
4118 | } |
4119 | |
bool CombinerHelper::matchExtractVecEltBuildVec(MachineInstr &MI,
                                                Register &Reg) {
4122 | assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT); |
4123 | // If we have a constant index, look for a G_BUILD_VECTOR source |
4124 | // and find the source register that the index maps to. |
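  //
  // e.g. %elt:_(s32) = G_EXTRACT_VECTOR_ELT %vec(<4 x s32>), 2
  //      where %vec = G_BUILD_VECTOR %a, %b, %c, %d
  //      can be replaced with %c.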
4125 | Register SrcVec = MI.getOperand(i: 1).getReg(); |
4126 | LLT SrcTy = MRI.getType(Reg: SrcVec); |
4127 | |
4128 | auto Cst = getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI); |
4129 | if (!Cst || Cst->Value.getZExtValue() >= SrcTy.getNumElements()) |
4130 | return false; |
4131 | |
4132 | unsigned VecIdx = Cst->Value.getZExtValue(); |
4133 | |
4134 | // Check if we have a build_vector or build_vector_trunc with an optional |
4135 | // trunc in front. |
4136 | MachineInstr *SrcVecMI = MRI.getVRegDef(Reg: SrcVec); |
4137 | if (SrcVecMI->getOpcode() == TargetOpcode::G_TRUNC) { |
4138 | SrcVecMI = MRI.getVRegDef(Reg: SrcVecMI->getOperand(i: 1).getReg()); |
4139 | } |
4140 | |
4141 | if (SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR && |
4142 | SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR_TRUNC) |
4143 | return false; |
4144 | |
4145 | EVT Ty(getMVTForLLT(Ty: SrcTy)); |
4146 | if (!MRI.hasOneNonDBGUse(RegNo: SrcVec) && |
4147 | !getTargetLowering().aggressivelyPreferBuildVectorSources(VecVT: Ty)) |
4148 | return false; |
4149 | |
4150 | Reg = SrcVecMI->getOperand(i: VecIdx + 1).getReg(); |
4151 | return true; |
4152 | } |
4153 | |
void CombinerHelper::applyExtractVecEltBuildVec(MachineInstr &MI,
                                                Register &Reg) {
4156 | // Check the type of the register, since it may have come from a |
4157 | // G_BUILD_VECTOR_TRUNC. |
4158 | LLT ScalarTy = MRI.getType(Reg); |
4159 | Register DstReg = MI.getOperand(i: 0).getReg(); |
4160 | LLT DstTy = MRI.getType(Reg: DstReg); |
4161 | |
4162 | if (ScalarTy != DstTy) { |
4163 | assert(ScalarTy.getSizeInBits() > DstTy.getSizeInBits()); |
4164 | Builder.buildTrunc(Res: DstReg, Op: Reg); |
4165 | MI.eraseFromParent(); |
4166 | return; |
4167 | } |
4168 | replaceSingleDefInstWithReg(MI, Replacement: Reg); |
4169 | } |
4170 | |
bool CombinerHelper::matchExtractAllEltsFromBuildVector(
4172 | MachineInstr &MI, |
4173 | SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) { |
4174 | assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR); |
  // This combine tries to find build_vectors which have every source element
  // extracted using G_EXTRACT_VECTOR_ELT. This can happen when transforms like
  // masked load scalarization are run late in the pipeline. There's already
  // a combine for a similar pattern starting from the extract, but that
  // doesn't attempt to do it if there are multiple uses of the build_vector,
  // which in this case is true. Starting the combine from the build_vector
  // feels more natural than trying to find sibling nodes of extracts.
4182 | // E.g. |
4183 | // %vec(<4 x s32>) = G_BUILD_VECTOR %s1(s32), %s2, %s3, %s4 |
4184 | // %ext1 = G_EXTRACT_VECTOR_ELT %vec, 0 |
4185 | // %ext2 = G_EXTRACT_VECTOR_ELT %vec, 1 |
4186 | // %ext3 = G_EXTRACT_VECTOR_ELT %vec, 2 |
4187 | // %ext4 = G_EXTRACT_VECTOR_ELT %vec, 3 |
4188 | // ==> |
4189 | // replace ext{1,2,3,4} with %s{1,2,3,4} |
4190 | |
4191 | Register DstReg = MI.getOperand(i: 0).getReg(); |
4192 | LLT DstTy = MRI.getType(Reg: DstReg); |
4193 | unsigned NumElts = DstTy.getNumElements(); |
4194 | |
  SmallBitVector ExtractedElts(NumElts);
4196 | for (MachineInstr &II : MRI.use_nodbg_instructions(Reg: DstReg)) { |
4197 | if (II.getOpcode() != TargetOpcode::G_EXTRACT_VECTOR_ELT) |
4198 | return false; |
4199 | auto Cst = getIConstantVRegVal(VReg: II.getOperand(i: 2).getReg(), MRI); |
4200 | if (!Cst) |
4201 | return false; |
4202 | unsigned Idx = Cst->getZExtValue(); |
4203 | if (Idx >= NumElts) |
4204 | return false; // Out of range. |
4205 | ExtractedElts.set(Idx); |
4206 | SrcDstPairs.emplace_back( |
4207 | Args: std::make_pair(x: MI.getOperand(i: Idx + 1).getReg(), y: &II)); |
4208 | } |
4209 | // Match if every element was extracted. |
4210 | return ExtractedElts.all(); |
4211 | } |
4212 | |
void CombinerHelper::applyExtractAllEltsFromBuildVector(
4214 | MachineInstr &MI, |
4215 | SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) { |
4216 | assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR); |
4217 | for (auto &Pair : SrcDstPairs) { |
4218 | auto *ExtMI = Pair.second; |
4219 | replaceRegWith(MRI, FromReg: ExtMI->getOperand(i: 0).getReg(), ToReg: Pair.first); |
4220 | ExtMI->eraseFromParent(); |
4221 | } |
4222 | MI.eraseFromParent(); |
4223 | } |
4224 | |
4225 | void CombinerHelper::applyBuildFn( |
4226 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
4227 | applyBuildFnNoErase(MI, MatchInfo); |
4228 | MI.eraseFromParent(); |
4229 | } |
4230 | |
4231 | void CombinerHelper::applyBuildFnNoErase( |
4232 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
4233 | MatchInfo(Builder); |
4234 | } |
4235 | |
4236 | bool CombinerHelper::matchOrShiftToFunnelShift(MachineInstr &MI, |
4237 | BuildFnTy &MatchInfo) { |
4238 | assert(MI.getOpcode() == TargetOpcode::G_OR); |
4239 | |
4240 | Register Dst = MI.getOperand(i: 0).getReg(); |
4241 | LLT Ty = MRI.getType(Reg: Dst); |
4242 | unsigned BitWidth = Ty.getScalarSizeInBits(); |
4243 | |
4244 | Register ShlSrc, ShlAmt, LShrSrc, LShrAmt, Amt; |
4245 | unsigned FshOpc = 0; |
4246 | |
4247 | // Match (or (shl ...), (lshr ...)). |
4248 | if (!mi_match(R: Dst, MRI, |
4249 | // m_GOr() handles the commuted version as well. |
4250 | P: m_GOr(L: m_GShl(L: m_Reg(R&: ShlSrc), R: m_Reg(R&: ShlAmt)), |
4251 | R: m_GLShr(L: m_Reg(R&: LShrSrc), R: m_Reg(R&: LShrAmt))))) |
4252 | return false; |
4253 | |
4254 | // Given constants C0 and C1 such that C0 + C1 is bit-width: |
4255 | // (or (shl x, C0), (lshr y, C1)) -> (fshl x, y, C0) or (fshr x, y, C1) |
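  //
  // e.g. for s32: (or (shl x, 8), (lshr y, 24)) -> (fshr x, y, 24),
  // since 8 + 24 == 32.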
4256 | int64_t CstShlAmt, CstLShrAmt; |
4257 | if (mi_match(R: ShlAmt, MRI, P: m_ICstOrSplat(Cst&: CstShlAmt)) && |
4258 | mi_match(R: LShrAmt, MRI, P: m_ICstOrSplat(Cst&: CstLShrAmt)) && |
4259 | CstShlAmt + CstLShrAmt == BitWidth) { |
4260 | FshOpc = TargetOpcode::G_FSHR; |
4261 | Amt = LShrAmt; |
4262 | |
4263 | } else if (mi_match(R: LShrAmt, MRI, |
4264 | P: m_GSub(L: m_SpecificICstOrSplat(RequestedValue: BitWidth), R: m_Reg(R&: Amt))) && |
4265 | ShlAmt == Amt) { |
4266 | // (or (shl x, amt), (lshr y, (sub bw, amt))) -> (fshl x, y, amt) |
4267 | FshOpc = TargetOpcode::G_FSHL; |
4268 | |
4269 | } else if (mi_match(R: ShlAmt, MRI, |
4270 | P: m_GSub(L: m_SpecificICstOrSplat(RequestedValue: BitWidth), R: m_Reg(R&: Amt))) && |
4271 | LShrAmt == Amt) { |
4272 | // (or (shl x, (sub bw, amt)), (lshr y, amt)) -> (fshr x, y, amt) |
4273 | FshOpc = TargetOpcode::G_FSHR; |
4274 | |
4275 | } else { |
4276 | return false; |
4277 | } |
4278 | |
4279 | LLT AmtTy = MRI.getType(Reg: Amt); |
4280 | if (!isLegalOrBeforeLegalizer(Query: {FshOpc, {Ty, AmtTy}})) |
4281 | return false; |
4282 | |
4283 | MatchInfo = [=](MachineIRBuilder &B) { |
4284 | B.buildInstr(Opc: FshOpc, DstOps: {Dst}, SrcOps: {ShlSrc, LShrSrc, Amt}); |
4285 | }; |
4286 | return true; |
4287 | } |
4288 | |
4289 | /// Match an FSHL or FSHR that can be combined to a ROTR or ROTL rotate. |
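/// e.g. (G_FSHL x, x, amt) -> (G_ROTL x, amt) when both value operands are
/// the same register.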
4290 | bool CombinerHelper::matchFunnelShiftToRotate(MachineInstr &MI) { |
4291 | unsigned Opc = MI.getOpcode(); |
4292 | assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR); |
4293 | Register X = MI.getOperand(i: 1).getReg(); |
4294 | Register Y = MI.getOperand(i: 2).getReg(); |
4295 | if (X != Y) |
4296 | return false; |
4297 | unsigned RotateOpc = |
4298 | Opc == TargetOpcode::G_FSHL ? TargetOpcode::G_ROTL : TargetOpcode::G_ROTR; |
4299 | return isLegalOrBeforeLegalizer(Query: {RotateOpc, {MRI.getType(Reg: X), MRI.getType(Reg: Y)}}); |
4300 | } |
4301 | |
4302 | void CombinerHelper::applyFunnelShiftToRotate(MachineInstr &MI) { |
4303 | unsigned Opc = MI.getOpcode(); |
4304 | assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR); |
4305 | bool IsFSHL = Opc == TargetOpcode::G_FSHL; |
4306 | Observer.changingInstr(MI); |
4307 | MI.setDesc(Builder.getTII().get(Opcode: IsFSHL ? TargetOpcode::G_ROTL |
4308 | : TargetOpcode::G_ROTR)); |
4309 | MI.removeOperand(OpNo: 2); |
4310 | Observer.changedInstr(MI); |
4311 | } |
4312 | |
4313 | // Fold (rot x, c) -> (rot x, c % BitSize) |
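// e.g. for s32: (rotl x, 35) -> (rotl x, 3), since rotating by the bit width
// is a no-op.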
4314 | bool CombinerHelper::matchRotateOutOfRange(MachineInstr &MI) { |
4315 | assert(MI.getOpcode() == TargetOpcode::G_ROTL || |
4316 | MI.getOpcode() == TargetOpcode::G_ROTR); |
4317 | unsigned Bitsize = |
4318 | MRI.getType(Reg: MI.getOperand(i: 0).getReg()).getScalarSizeInBits(); |
4319 | Register AmtReg = MI.getOperand(i: 2).getReg(); |
4320 | bool OutOfRange = false; |
4321 | auto MatchOutOfRange = [Bitsize, &OutOfRange](const Constant *C) { |
4322 | if (auto *CI = dyn_cast<ConstantInt>(Val: C)) |
4323 | OutOfRange |= CI->getValue().uge(RHS: Bitsize); |
4324 | return true; |
4325 | }; |
4326 | return matchUnaryPredicate(MRI, Reg: AmtReg, Match: MatchOutOfRange) && OutOfRange; |
4327 | } |
4328 | |
4329 | void CombinerHelper::applyRotateOutOfRange(MachineInstr &MI) { |
4330 | assert(MI.getOpcode() == TargetOpcode::G_ROTL || |
4331 | MI.getOpcode() == TargetOpcode::G_ROTR); |
4332 | unsigned Bitsize = |
4333 | MRI.getType(Reg: MI.getOperand(i: 0).getReg()).getScalarSizeInBits(); |
4334 | Register Amt = MI.getOperand(i: 2).getReg(); |
4335 | LLT AmtTy = MRI.getType(Reg: Amt); |
4336 | auto Bits = Builder.buildConstant(Res: AmtTy, Val: Bitsize); |
4337 | Amt = Builder.buildURem(Dst: AmtTy, Src0: MI.getOperand(i: 2).getReg(), Src1: Bits).getReg(Idx: 0); |
4338 | Observer.changingInstr(MI); |
4339 | MI.getOperand(i: 2).setReg(Amt); |
4340 | Observer.changedInstr(MI); |
4341 | } |
4342 | |
4343 | bool CombinerHelper::matchICmpToTrueFalseKnownBits(MachineInstr &MI, |
4344 | int64_t &MatchInfo) { |
4345 | assert(MI.getOpcode() == TargetOpcode::G_ICMP); |
4346 | auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(i: 1).getPredicate()); |
4347 | |
4348 | // We want to avoid calling KnownBits on the LHS if possible, as this combine |
4349 | // has no filter and runs on every G_ICMP instruction. We can avoid calling |
4350 | // KnownBits on the LHS in two cases: |
4351 | // |
4352 | // - The RHS is unknown: Constants are always on RHS. If the RHS is unknown |
4353 | // we cannot do any transforms so we can safely bail out early. |
4354 | // - The RHS is zero: we don't need to know the LHS to do unsigned <0 and |
4355 | // >=0. |
4356 | auto KnownRHS = KB->getKnownBits(R: MI.getOperand(i: 3).getReg()); |
4357 | if (KnownRHS.isUnknown()) |
4358 | return false; |
4359 | |
4360 | std::optional<bool> KnownVal; |
4361 | if (KnownRHS.isZero()) { |
4362 | // ? uge 0 -> always true |
4363 | // ? ult 0 -> always false |
4364 | if (Pred == CmpInst::ICMP_UGE) |
4365 | KnownVal = true; |
4366 | else if (Pred == CmpInst::ICMP_ULT) |
4367 | KnownVal = false; |
4368 | } |
4369 | |
4370 | if (!KnownVal) { |
4371 | auto KnownLHS = KB->getKnownBits(R: MI.getOperand(i: 2).getReg()); |
4372 | switch (Pred) { |
4373 | default: |
4374 | llvm_unreachable("Unexpected G_ICMP predicate?" ); |
4375 | case CmpInst::ICMP_EQ: |
4376 | KnownVal = KnownBits::eq(LHS: KnownLHS, RHS: KnownRHS); |
4377 | break; |
4378 | case CmpInst::ICMP_NE: |
4379 | KnownVal = KnownBits::ne(LHS: KnownLHS, RHS: KnownRHS); |
4380 | break; |
4381 | case CmpInst::ICMP_SGE: |
4382 | KnownVal = KnownBits::sge(LHS: KnownLHS, RHS: KnownRHS); |
4383 | break; |
4384 | case CmpInst::ICMP_SGT: |
4385 | KnownVal = KnownBits::sgt(LHS: KnownLHS, RHS: KnownRHS); |
4386 | break; |
4387 | case CmpInst::ICMP_SLE: |
4388 | KnownVal = KnownBits::sle(LHS: KnownLHS, RHS: KnownRHS); |
4389 | break; |
4390 | case CmpInst::ICMP_SLT: |
4391 | KnownVal = KnownBits::slt(LHS: KnownLHS, RHS: KnownRHS); |
4392 | break; |
4393 | case CmpInst::ICMP_UGE: |
4394 | KnownVal = KnownBits::uge(LHS: KnownLHS, RHS: KnownRHS); |
4395 | break; |
4396 | case CmpInst::ICMP_UGT: |
4397 | KnownVal = KnownBits::ugt(LHS: KnownLHS, RHS: KnownRHS); |
4398 | break; |
4399 | case CmpInst::ICMP_ULE: |
4400 | KnownVal = KnownBits::ule(LHS: KnownLHS, RHS: KnownRHS); |
4401 | break; |
4402 | case CmpInst::ICMP_ULT: |
4403 | KnownVal = KnownBits::ult(LHS: KnownLHS, RHS: KnownRHS); |
4404 | break; |
4405 | } |
4406 | } |
4407 | |
4408 | if (!KnownVal) |
4409 | return false; |
4410 | MatchInfo = |
4411 | *KnownVal |
4412 | ? getICmpTrueVal(TLI: getTargetLowering(), |
4413 | /*IsVector = */ |
4414 | MRI.getType(Reg: MI.getOperand(i: 0).getReg()).isVector(), |
4415 | /* IsFP = */ false) |
4416 | : 0; |
4417 | return true; |
4418 | } |
4419 | |
4420 | bool CombinerHelper::matchICmpToLHSKnownBits( |
4421 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
4422 | assert(MI.getOpcode() == TargetOpcode::G_ICMP); |
4423 | // Given: |
4424 | // |
4425 | // %x = G_WHATEVER (... x is known to be 0 or 1 ...) |
4426 | // %cmp = G_ICMP ne %x, 0 |
4427 | // |
4428 | // Or: |
4429 | // |
4430 | // %x = G_WHATEVER (... x is known to be 0 or 1 ...) |
4431 | // %cmp = G_ICMP eq %x, 1 |
4432 | // |
4433 | // We can replace %cmp with %x assuming true is 1 on the target. |
4434 | auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(i: 1).getPredicate()); |
4435 | if (!CmpInst::isEquality(pred: Pred)) |
4436 | return false; |
4437 | Register Dst = MI.getOperand(i: 0).getReg(); |
4438 | LLT DstTy = MRI.getType(Reg: Dst); |
4439 | if (getICmpTrueVal(TLI: getTargetLowering(), IsVector: DstTy.isVector(), |
4440 | /* IsFP = */ false) != 1) |
4441 | return false; |
4442 | int64_t OneOrZero = Pred == CmpInst::ICMP_EQ; |
4443 | if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICst(RequestedValue: OneOrZero))) |
4444 | return false; |
4445 | Register LHS = MI.getOperand(i: 2).getReg(); |
4446 | auto KnownLHS = KB->getKnownBits(R: LHS); |
4447 | if (KnownLHS.getMinValue() != 0 || KnownLHS.getMaxValue() != 1) |
4448 | return false; |
4449 | // Make sure replacing Dst with the LHS is a legal operation. |
4450 | LLT LHSTy = MRI.getType(Reg: LHS); |
4451 | unsigned LHSSize = LHSTy.getSizeInBits(); |
4452 | unsigned DstSize = DstTy.getSizeInBits(); |
4453 | unsigned Op = TargetOpcode::COPY; |
4454 | if (DstSize != LHSSize) |
4455 | Op = DstSize < LHSSize ? TargetOpcode::G_TRUNC : TargetOpcode::G_ZEXT; |
4456 | if (!isLegalOrBeforeLegalizer(Query: {Op, {DstTy, LHSTy}})) |
4457 | return false; |
4458 | MatchInfo = [=](MachineIRBuilder &B) { B.buildInstr(Opc: Op, DstOps: {Dst}, SrcOps: {LHS}); }; |
4459 | return true; |
4460 | } |
4461 | |
4462 | // Replace (and (or x, c1), c2) with (and x, c2) iff c1 & c2 == 0 |
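// e.g. (and (or x, 0xF0), 0x0F) -> (and x, 0x0F), since 0xF0 & 0x0F == 0.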
4463 | bool CombinerHelper::matchAndOrDisjointMask( |
4464 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
4465 | assert(MI.getOpcode() == TargetOpcode::G_AND); |
4466 | |
4467 | // Ignore vector types to simplify matching the two constants. |
4468 | // TODO: do this for vectors and scalars via a demanded bits analysis. |
4469 | LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
4470 | if (Ty.isVector()) |
4471 | return false; |
4472 | |
4473 | Register Src; |
4474 | Register AndMaskReg; |
4475 | int64_t AndMaskBits; |
4476 | int64_t OrMaskBits; |
4477 | if (!mi_match(MI, MRI, |
4478 | P: m_GAnd(L: m_GOr(L: m_Reg(R&: Src), R: m_ICst(Cst&: OrMaskBits)), |
4479 | R: m_all_of(preds: m_ICst(Cst&: AndMaskBits), preds: m_Reg(R&: AndMaskReg))))) |
4480 | return false; |
4481 | |
4482 | // Check if OrMask could turn on any bits in Src. |
4483 | if (AndMaskBits & OrMaskBits) |
4484 | return false; |
4485 | |
4486 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
4487 | Observer.changingInstr(MI); |
4488 | // Canonicalize the result to have the constant on the RHS. |
4489 | if (MI.getOperand(i: 1).getReg() == AndMaskReg) |
4490 | MI.getOperand(i: 2).setReg(AndMaskReg); |
4491 | MI.getOperand(i: 1).setReg(Src); |
4492 | Observer.changedInstr(MI); |
4493 | }; |
4494 | return true; |
4495 | } |
4496 | |
4497 | /// Form a G_SBFX from a G_SEXT_INREG fed by a right shift. |
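/// e.g. for s32: (sext_inreg (lshr x, 4), 8) -> (sbfx x, 4, 8), i.e. a signed
/// extract of 8 bits starting at bit 4.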
bool CombinerHelper::matchBitfieldExtractFromSExtInReg(
4499 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
4500 | assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG); |
4501 | Register Dst = MI.getOperand(i: 0).getReg(); |
4502 | Register Src = MI.getOperand(i: 1).getReg(); |
4503 | LLT Ty = MRI.getType(Reg: Src); |
  LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4505 | if (!LI || !LI->isLegalOrCustom(Query: {TargetOpcode::G_SBFX, {Ty, ExtractTy}})) |
4506 | return false; |
4507 | int64_t Width = MI.getOperand(i: 2).getImm(); |
4508 | Register ShiftSrc; |
4509 | int64_t ShiftImm; |
4510 | if (!mi_match( |
4511 | R: Src, MRI, |
4512 | P: m_OneNonDBGUse(SP: m_any_of(preds: m_GAShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: ShiftImm)), |
4513 | preds: m_GLShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: ShiftImm)))))) |
4514 | return false; |
4515 | if (ShiftImm < 0 || ShiftImm + Width > Ty.getScalarSizeInBits()) |
4516 | return false; |
4517 | |
4518 | MatchInfo = [=](MachineIRBuilder &B) { |
4519 | auto Cst1 = B.buildConstant(Res: ExtractTy, Val: ShiftImm); |
4520 | auto Cst2 = B.buildConstant(Res: ExtractTy, Val: Width); |
4521 | B.buildSbfx(Dst, Src: ShiftSrc, LSB: Cst1, Width: Cst2); |
4522 | }; |
4523 | return true; |
4524 | } |
4525 | |
4526 | /// Form a G_UBFX from "(a srl b) & mask", where b and mask are constants. |
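/// e.g. for s32: (and (lshr x, 4), 0xFF) -> (ubfx x, 4, 8), i.e. an unsigned
/// extract of 8 bits starting at bit 4.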
4527 | bool CombinerHelper::matchBitfieldExtractFromAnd(MachineInstr &MI, |
4528 | BuildFnTy &MatchInfo) { |
4529 | GAnd *And = cast<GAnd>(Val: &MI); |
4530 | Register Dst = And->getReg(Idx: 0); |
4531 | LLT Ty = MRI.getType(Reg: Dst); |
  LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4533 | // Note that isLegalOrBeforeLegalizer is stricter and does not take custom |
4534 | // into account. |
4535 | if (LI && !LI->isLegalOrCustom(Query: {TargetOpcode::G_UBFX, {Ty, ExtractTy}})) |
4536 | return false; |
4537 | |
4538 | int64_t AndImm, LSBImm; |
4539 | Register ShiftSrc; |
4540 | const unsigned Size = Ty.getScalarSizeInBits(); |
4541 | if (!mi_match(R: And->getReg(Idx: 0), MRI, |
4542 | P: m_GAnd(L: m_OneNonDBGUse(SP: m_GLShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: LSBImm))), |
4543 | R: m_ICst(Cst&: AndImm)))) |
4544 | return false; |
4545 | |
4546 | // The mask is a mask of the low bits iff imm & (imm+1) == 0. |
4547 | auto MaybeMask = static_cast<uint64_t>(AndImm); |
4548 | if (MaybeMask & (MaybeMask + 1)) |
4549 | return false; |
4550 | |
4551 | // LSB must fit within the register. |
4552 | if (static_cast<uint64_t>(LSBImm) >= Size) |
4553 | return false; |
4554 | |
4555 | uint64_t Width = APInt(Size, AndImm).countr_one(); |
4556 | MatchInfo = [=](MachineIRBuilder &B) { |
4557 | auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width); |
4558 | auto LSBCst = B.buildConstant(Res: ExtractTy, Val: LSBImm); |
4559 | B.buildInstr(Opc: TargetOpcode::G_UBFX, DstOps: {Dst}, SrcOps: {ShiftSrc, LSBCst, WidthCst}); |
4560 | }; |
4561 | return true; |
4562 | } |
4563 | |
bool CombinerHelper::matchBitfieldExtractFromShr(
4565 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
4566 | const unsigned Opcode = MI.getOpcode(); |
4567 | assert(Opcode == TargetOpcode::G_ASHR || Opcode == TargetOpcode::G_LSHR); |
4568 | |
4569 | const Register Dst = MI.getOperand(i: 0).getReg(); |
4570 | |
4571 | const unsigned ExtrOpcode = Opcode == TargetOpcode::G_ASHR |
4572 | ? TargetOpcode::G_SBFX |
4573 | : TargetOpcode::G_UBFX; |
4574 | |
4575 | // Check if the type we would use for the extract is legal |
4576 | LLT Ty = MRI.getType(Reg: Dst); |
  LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4578 | if (!LI || !LI->isLegalOrCustom(Query: {ExtrOpcode, {Ty, ExtractTy}})) |
4579 | return false; |
4580 | |
4581 | Register ShlSrc; |
4582 | int64_t ShrAmt; |
4583 | int64_t ShlAmt; |
4584 | const unsigned Size = Ty.getScalarSizeInBits(); |
4585 | |
4586 | // Try to match shr (shl x, c1), c2 |
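  // e.g. for s32: (lshr (shl x, 4), 8) -> (ubfx x, 4, 24), extracting the 24
  // bits of x starting at bit 4.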
4587 | if (!mi_match(R: Dst, MRI, |
4588 | P: m_BinOp(Opcode, |
4589 | L: m_OneNonDBGUse(SP: m_GShl(L: m_Reg(R&: ShlSrc), R: m_ICst(Cst&: ShlAmt))), |
4590 | R: m_ICst(Cst&: ShrAmt)))) |
4591 | return false; |
4592 | |
4593 | // Make sure that the shift sizes can fit a bitfield extract |
4594 | if (ShlAmt < 0 || ShlAmt > ShrAmt || ShrAmt >= Size) |
4595 | return false; |
4596 | |
4597 | // Skip this combine if the G_SEXT_INREG combine could handle it |
4598 | if (Opcode == TargetOpcode::G_ASHR && ShlAmt == ShrAmt) |
4599 | return false; |
4600 | |
4601 | // Calculate start position and width of the extract |
4602 | const int64_t Pos = ShrAmt - ShlAmt; |
4603 | const int64_t Width = Size - ShrAmt; |
4604 | |
4605 | MatchInfo = [=](MachineIRBuilder &B) { |
4606 | auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width); |
4607 | auto PosCst = B.buildConstant(Res: ExtractTy, Val: Pos); |
4608 | B.buildInstr(Opc: ExtrOpcode, DstOps: {Dst}, SrcOps: {ShlSrc, PosCst, WidthCst}); |
4609 | }; |
4610 | return true; |
4611 | } |
4612 | |
4613 | bool CombinerHelper::matchBitfieldExtractFromShrAnd( |
4614 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
4615 | const unsigned Opcode = MI.getOpcode(); |
4616 | assert(Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_ASHR); |
4617 | |
4618 | const Register Dst = MI.getOperand(i: 0).getReg(); |
4619 | LLT Ty = MRI.getType(Reg: Dst); |
  LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4621 | if (LI && !LI->isLegalOrCustom(Query: {TargetOpcode::G_UBFX, {Ty, ExtractTy}})) |
4622 | return false; |
4623 | |
4624 | // Try to match shr (and x, c1), c2 |
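  // e.g. for s32: (lshr (and x, 0xFF0), 4) -> (ubfx x, 4, 8).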
4625 | Register AndSrc; |
4626 | int64_t ShrAmt; |
4627 | int64_t SMask; |
4628 | if (!mi_match(R: Dst, MRI, |
4629 | P: m_BinOp(Opcode, |
4630 | L: m_OneNonDBGUse(SP: m_GAnd(L: m_Reg(R&: AndSrc), R: m_ICst(Cst&: SMask))), |
4631 | R: m_ICst(Cst&: ShrAmt)))) |
4632 | return false; |
4633 | |
4634 | const unsigned Size = Ty.getScalarSizeInBits(); |
4635 | if (ShrAmt < 0 || ShrAmt >= Size) |
4636 | return false; |
4637 | |
4638 | // If the shift subsumes the mask, emit the 0 directly. |
4639 | if (0 == (SMask >> ShrAmt)) { |
4640 | MatchInfo = [=](MachineIRBuilder &B) { |
4641 | B.buildConstant(Res: Dst, Val: 0); |
4642 | }; |
4643 | return true; |
4644 | } |
4645 | |
4646 | // Check that ubfx can do the extraction, with no holes in the mask. |
4647 | uint64_t UMask = SMask; |
4648 | UMask |= maskTrailingOnes<uint64_t>(N: ShrAmt); |
4649 | UMask &= maskTrailingOnes<uint64_t>(N: Size); |
4650 | if (!isMask_64(Value: UMask)) |
4651 | return false; |
4652 | |
4653 | // Calculate start position and width of the extract. |
4654 | const int64_t Pos = ShrAmt; |
4655 | const int64_t Width = llvm::countr_one(Value: UMask) - ShrAmt; |
4656 | |
4657 | // It's preferable to keep the shift, rather than form G_SBFX. |
4658 | // TODO: remove the G_AND via demanded bits analysis. |
4659 | if (Opcode == TargetOpcode::G_ASHR && Width + ShrAmt == Size) |
4660 | return false; |
4661 | |
4662 | MatchInfo = [=](MachineIRBuilder &B) { |
4663 | auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width); |
4664 | auto PosCst = B.buildConstant(Res: ExtractTy, Val: Pos); |
4665 | B.buildInstr(Opc: TargetOpcode::G_UBFX, DstOps: {Dst}, SrcOps: {AndSrc, PosCst, WidthCst}); |
4666 | }; |
4667 | return true; |
4668 | } |
4669 | |
4670 | bool CombinerHelper::reassociationCanBreakAddressingModePattern( |
4671 | MachineInstr &MI) { |
4672 | auto &PtrAdd = cast<GPtrAdd>(Val&: MI); |
4673 | |
4674 | Register Src1Reg = PtrAdd.getBaseReg(); |
4675 | auto *Src1Def = getOpcodeDef<GPtrAdd>(Reg: Src1Reg, MRI); |
4676 | if (!Src1Def) |
4677 | return false; |
4678 | |
4679 | Register Src2Reg = PtrAdd.getOffsetReg(); |
4680 | |
4681 | if (MRI.hasOneNonDBGUse(RegNo: Src1Reg)) |
4682 | return false; |
4683 | |
4684 | auto C1 = getIConstantVRegVal(VReg: Src1Def->getOffsetReg(), MRI); |
4685 | if (!C1) |
4686 | return false; |
4687 | auto C2 = getIConstantVRegVal(VReg: Src2Reg, MRI); |
4688 | if (!C2) |
4689 | return false; |
4690 | |
4691 | const APInt &C1APIntVal = *C1; |
4692 | const APInt &C2APIntVal = *C2; |
4693 | const int64_t CombinedValue = (C1APIntVal + C2APIntVal).getSExtValue(); |
4694 | |
4695 | for (auto &UseMI : MRI.use_nodbg_instructions(Reg: PtrAdd.getReg(Idx: 0))) { |
4696 | // This combine may end up running before ptrtoint/inttoptr combines |
4697 | // manage to eliminate redundant conversions, so try to look through them. |
4698 | MachineInstr *ConvUseMI = &UseMI; |
4699 | unsigned ConvUseOpc = ConvUseMI->getOpcode(); |
4700 | while (ConvUseOpc == TargetOpcode::G_INTTOPTR || |
4701 | ConvUseOpc == TargetOpcode::G_PTRTOINT) { |
4702 | Register DefReg = ConvUseMI->getOperand(i: 0).getReg(); |
4703 | if (!MRI.hasOneNonDBGUse(RegNo: DefReg)) |
4704 | break; |
4705 | ConvUseMI = &*MRI.use_instr_nodbg_begin(RegNo: DefReg); |
4706 | ConvUseOpc = ConvUseMI->getOpcode(); |
4707 | } |
4708 | auto *LdStMI = dyn_cast<GLoadStore>(Val: ConvUseMI); |
4709 | if (!LdStMI) |
4710 | continue; |
4711 | // Is x[offset2] already not a legal addressing mode? If so then |
4712 | // reassociating the constants breaks nothing (we test offset2 because |
4713 | // that's the one we hope to fold into the load or store). |
4714 | TargetLoweringBase::AddrMode AM; |
4715 | AM.HasBaseReg = true; |
4716 | AM.BaseOffs = C2APIntVal.getSExtValue(); |
4717 | unsigned AS = MRI.getType(Reg: LdStMI->getPointerReg()).getAddressSpace(); |
4718 | Type *AccessTy = getTypeForLLT(Ty: LdStMI->getMMO().getMemoryType(), |
4719 | C&: PtrAdd.getMF()->getFunction().getContext()); |
4720 | const auto &TLI = *PtrAdd.getMF()->getSubtarget().getTargetLowering(); |
4721 | if (!TLI.isLegalAddressingMode(DL: PtrAdd.getMF()->getDataLayout(), AM, |
4722 | Ty: AccessTy, AddrSpace: AS)) |
4723 | continue; |
4724 | |
4725 | // Would x[offset1+offset2] still be a legal addressing mode? |
4726 | AM.BaseOffs = CombinedValue; |
4727 | if (!TLI.isLegalAddressingMode(DL: PtrAdd.getMF()->getDataLayout(), AM, |
4728 | Ty: AccessTy, AddrSpace: AS)) |
4729 | return true; |
4730 | } |
4731 | |
4732 | return false; |
4733 | } |
4734 | |
4735 | bool CombinerHelper::matchReassocConstantInnerRHS(GPtrAdd &MI, |
4736 | MachineInstr *RHS, |
4737 | BuildFnTy &MatchInfo) { |
4738 | // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C) |
4739 | Register Src1Reg = MI.getOperand(i: 1).getReg(); |
4740 | if (RHS->getOpcode() != TargetOpcode::G_ADD) |
4741 | return false; |
4742 | auto C2 = getIConstantVRegVal(VReg: RHS->getOperand(i: 2).getReg(), MRI); |
4743 | if (!C2) |
4744 | return false; |
4745 | |
4746 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
4747 | LLT PtrTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
4748 | |
4749 | auto NewBase = |
4750 | Builder.buildPtrAdd(Res: PtrTy, Op0: Src1Reg, Op1: RHS->getOperand(i: 1).getReg()); |
4751 | Observer.changingInstr(MI); |
4752 | MI.getOperand(i: 1).setReg(NewBase.getReg(Idx: 0)); |
4753 | MI.getOperand(i: 2).setReg(RHS->getOperand(i: 2).getReg()); |
4754 | Observer.changedInstr(MI); |
4755 | }; |
4756 | return !reassociationCanBreakAddressingModePattern(MI); |
4757 | } |
4758 | |
4759 | bool CombinerHelper::matchReassocConstantInnerLHS(GPtrAdd &MI, |
4760 | MachineInstr *LHS, |
4761 | MachineInstr *RHS, |
4762 | BuildFnTy &MatchInfo) { |
4763 | // G_PTR_ADD (G_PTR_ADD X, C), Y) -> (G_PTR_ADD (G_PTR_ADD(X, Y), C) |
4764 | // if and only if (G_PTR_ADD X, C) has one use. |
4765 | Register LHSBase; |
4766 | std::optional<ValueAndVReg> LHSCstOff; |
4767 | if (!mi_match(R: MI.getBaseReg(), MRI, |
4768 | P: m_OneNonDBGUse(SP: m_GPtrAdd(L: m_Reg(R&: LHSBase), R: m_GCst(ValReg&: LHSCstOff))))) |
4769 | return false; |
4770 | |
4771 | auto *LHSPtrAdd = cast<GPtrAdd>(Val: LHS); |
4772 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
    // When we change LHSPtrAdd's offset register we might cause it to use a
    // reg before its def. Sink the instruction so that it sits right before
    // the outer PTR_ADD to ensure this doesn't happen.
4776 | LHSPtrAdd->moveBefore(MovePos: &MI); |
4777 | Register RHSReg = MI.getOffsetReg(); |
    // Reusing LHSCstOff's VReg directly could cause a type mismatch if it came
    // through an extend/trunc, so build a fresh constant of the RHS offset
    // type instead.
4779 | auto NewCst = B.buildConstant(Res: MRI.getType(Reg: RHSReg), Val: LHSCstOff->Value); |
4780 | Observer.changingInstr(MI); |
4781 | MI.getOperand(i: 2).setReg(NewCst.getReg(Idx: 0)); |
4782 | Observer.changedInstr(MI); |
4783 | Observer.changingInstr(MI&: *LHSPtrAdd); |
4784 | LHSPtrAdd->getOperand(i: 2).setReg(RHSReg); |
4785 | Observer.changedInstr(MI&: *LHSPtrAdd); |
4786 | }; |
4787 | return !reassociationCanBreakAddressingModePattern(MI); |
4788 | } |
4789 | |
4790 | bool CombinerHelper::matchReassocFoldConstantsInSubTree(GPtrAdd &MI, |
4791 | MachineInstr *LHS, |
4792 | MachineInstr *RHS, |
4793 | BuildFnTy &MatchInfo) { |
4794 | // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2) |
4795 | auto *LHSPtrAdd = dyn_cast<GPtrAdd>(Val: LHS); |
4796 | if (!LHSPtrAdd) |
4797 | return false; |
4798 | |
4799 | Register Src2Reg = MI.getOperand(i: 2).getReg(); |
4800 | Register LHSSrc1 = LHSPtrAdd->getBaseReg(); |
4801 | Register LHSSrc2 = LHSPtrAdd->getOffsetReg(); |
4802 | auto C1 = getIConstantVRegVal(VReg: LHSSrc2, MRI); |
4803 | if (!C1) |
4804 | return false; |
4805 | auto C2 = getIConstantVRegVal(VReg: Src2Reg, MRI); |
4806 | if (!C2) |
4807 | return false; |
4808 | |
4809 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
4810 | auto NewCst = B.buildConstant(Res: MRI.getType(Reg: Src2Reg), Val: *C1 + *C2); |
4811 | Observer.changingInstr(MI); |
4812 | MI.getOperand(i: 1).setReg(LHSSrc1); |
4813 | MI.getOperand(i: 2).setReg(NewCst.getReg(Idx: 0)); |
4814 | Observer.changedInstr(MI); |
4815 | }; |
4816 | return !reassociationCanBreakAddressingModePattern(MI); |
4817 | } |
4818 | |
4819 | bool CombinerHelper::matchReassocPtrAdd(MachineInstr &MI, |
4820 | BuildFnTy &MatchInfo) { |
4821 | auto &PtrAdd = cast<GPtrAdd>(Val&: MI); |
4822 | // We're trying to match a few pointer computation patterns here for |
4823 | // re-association opportunities. |
4824 | // 1) Isolating a constant operand to be on the RHS, e.g.: |
4825 | // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C) |
4826 | // |
4827 | // 2) Folding two constants in each sub-tree as long as such folding |
4828 | // doesn't break a legal addressing mode. |
4829 | // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2) |
4830 | // |
4831 | // 3) Move a constant from the LHS of an inner op to the RHS of the outer. |
4832 | // G_PTR_ADD (G_PTR_ADD X, C), Y) -> G_PTR_ADD (G_PTR_ADD(X, Y), C) |
  // iff (G_PTR_ADD X, C) has one use.
4834 | MachineInstr *LHS = MRI.getVRegDef(Reg: PtrAdd.getBaseReg()); |
4835 | MachineInstr *RHS = MRI.getVRegDef(Reg: PtrAdd.getOffsetReg()); |
4836 | |
4837 | // Try to match example 2. |
4838 | if (matchReassocFoldConstantsInSubTree(MI&: PtrAdd, LHS, RHS, MatchInfo)) |
4839 | return true; |
4840 | |
4841 | // Try to match example 3. |
4842 | if (matchReassocConstantInnerLHS(MI&: PtrAdd, LHS, RHS, MatchInfo)) |
4843 | return true; |
4844 | |
4845 | // Try to match example 1. |
4846 | if (matchReassocConstantInnerRHS(MI&: PtrAdd, RHS, MatchInfo)) |
4847 | return true; |
4848 | |
4849 | return false; |
4850 | } |

bool CombinerHelper::tryReassocBinOp(unsigned Opc, Register DstReg,
4852 | Register OpLHS, Register OpRHS, |
4853 | BuildFnTy &MatchInfo) { |
4854 | LLT OpRHSTy = MRI.getType(Reg: OpRHS); |
4855 | MachineInstr *OpLHSDef = MRI.getVRegDef(Reg: OpLHS); |
4856 | |
4857 | if (OpLHSDef->getOpcode() != Opc) |
4858 | return false; |
4859 | |
4860 | MachineInstr *OpRHSDef = MRI.getVRegDef(Reg: OpRHS); |
4861 | Register OpLHSLHS = OpLHSDef->getOperand(i: 1).getReg(); |
4862 | Register OpLHSRHS = OpLHSDef->getOperand(i: 2).getReg(); |
4863 | |
4864 | // If the inner op is (X op C), pull the constant out so it can be folded with |
4865 | // other constants in the expression tree. Folding is not guaranteed so we |
4866 | // might have (C1 op C2). In that case do not pull a constant out because it |
4867 | // won't help and can lead to infinite loops. |
4868 | if (isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: OpLHSRHS), MRI) && |
4869 | !isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: OpLHSLHS), MRI)) { |
4870 | if (isConstantOrConstantSplatVector(MI&: *OpRHSDef, MRI)) { |
4871 | // (Opc (Opc X, C1), C2) -> (Opc X, (Opc C1, C2)) |
4872 | MatchInfo = [=](MachineIRBuilder &B) { |
4873 | auto NewCst = B.buildInstr(Opc, DstOps: {OpRHSTy}, SrcOps: {OpLHSRHS, OpRHS}); |
4874 | B.buildInstr(Opc, DstOps: {DstReg}, SrcOps: {OpLHSLHS, NewCst}); |
4875 | }; |
4876 | return true; |
4877 | } |
4878 | if (getTargetLowering().isReassocProfitable(MRI, N0: OpLHS, N1: OpRHS)) { |
4879 | // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1) |
4880 | // iff (op x, c1) has one use |
4881 | MatchInfo = [=](MachineIRBuilder &B) { |
4882 | auto NewLHSLHS = B.buildInstr(Opc, DstOps: {OpRHSTy}, SrcOps: {OpLHSLHS, OpRHS}); |
4883 | B.buildInstr(Opc, DstOps: {DstReg}, SrcOps: {NewLHSLHS, OpLHSRHS}); |
4884 | }; |
4885 | return true; |
4886 | } |
4887 | } |
4888 | |
4889 | return false; |
4890 | } |
4891 | |
4892 | bool CombinerHelper::matchReassocCommBinOp(MachineInstr &MI, |
4893 | BuildFnTy &MatchInfo) { |
4894 | // We don't check if the reassociation will break a legal addressing mode |
4895 | // here since pointer arithmetic is handled by G_PTR_ADD. |
4896 | unsigned Opc = MI.getOpcode(); |
4897 | Register DstReg = MI.getOperand(i: 0).getReg(); |
4898 | Register LHSReg = MI.getOperand(i: 1).getReg(); |
4899 | Register RHSReg = MI.getOperand(i: 2).getReg(); |
4900 | |
4901 | if (tryReassocBinOp(Opc, DstReg, OpLHS: LHSReg, OpRHS: RHSReg, MatchInfo)) |
4902 | return true; |
4903 | if (tryReassocBinOp(Opc, DstReg, OpLHS: RHSReg, OpRHS: LHSReg, MatchInfo)) |
4904 | return true; |
4905 | return false; |
4906 | } |
4907 | |
bool CombinerHelper::matchConstantFoldCastOp(MachineInstr &MI,
                                             APInt &MatchInfo) {
4909 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
4910 | Register SrcOp = MI.getOperand(i: 1).getReg(); |
4911 | |
4912 | if (auto MaybeCst = ConstantFoldCastOp(Opcode: MI.getOpcode(), DstTy, Op0: SrcOp, MRI)) { |
4913 | MatchInfo = *MaybeCst; |
4914 | return true; |
4915 | } |
4916 | |
4917 | return false; |
4918 | } |
4919 | |
bool CombinerHelper::matchConstantFoldBinOp(MachineInstr &MI,
                                            APInt &MatchInfo) {
4921 | Register Op1 = MI.getOperand(i: 1).getReg(); |
4922 | Register Op2 = MI.getOperand(i: 2).getReg(); |
4923 | auto MaybeCst = ConstantFoldBinOp(Opcode: MI.getOpcode(), Op1, Op2, MRI); |
4924 | if (!MaybeCst) |
4925 | return false; |
4926 | MatchInfo = *MaybeCst; |
4927 | return true; |
4928 | } |
4929 | |
bool CombinerHelper::matchConstantFoldFPBinOp(MachineInstr &MI,
                                              ConstantFP *&MatchInfo) {
4931 | Register Op1 = MI.getOperand(i: 1).getReg(); |
4932 | Register Op2 = MI.getOperand(i: 2).getReg(); |
4933 | auto MaybeCst = ConstantFoldFPBinOp(Opcode: MI.getOpcode(), Op1, Op2, MRI); |
4934 | if (!MaybeCst) |
4935 | return false; |
4936 | MatchInfo = |
4937 | ConstantFP::get(Context&: MI.getMF()->getFunction().getContext(), V: *MaybeCst); |
4938 | return true; |
4939 | } |
4940 | |
4941 | bool CombinerHelper::matchConstantFoldFMA(MachineInstr &MI, |
4942 | ConstantFP *&MatchInfo) { |
4943 | assert(MI.getOpcode() == TargetOpcode::G_FMA || |
4944 | MI.getOpcode() == TargetOpcode::G_FMAD); |
4945 | auto [_, Op1, Op2, Op3] = MI.getFirst4Regs(); |
4946 | |
4947 | const ConstantFP *Op3Cst = getConstantFPVRegVal(VReg: Op3, MRI); |
4948 | if (!Op3Cst) |
4949 | return false; |
4950 | |
4951 | const ConstantFP *Op2Cst = getConstantFPVRegVal(VReg: Op2, MRI); |
4952 | if (!Op2Cst) |
4953 | return false; |
4954 | |
4955 | const ConstantFP *Op1Cst = getConstantFPVRegVal(VReg: Op1, MRI); |
4956 | if (!Op1Cst) |
4957 | return false; |
4958 | |
4959 | APFloat Op1F = Op1Cst->getValueAPF(); |
4960 | Op1F.fusedMultiplyAdd(Multiplicand: Op2Cst->getValueAPF(), Addend: Op3Cst->getValueAPF(), |
4961 | RM: APFloat::rmNearestTiesToEven); |
4962 | MatchInfo = ConstantFP::get(Context&: MI.getMF()->getFunction().getContext(), V: Op1F); |
4963 | return true; |
4964 | } |
4965 | |
4966 | bool CombinerHelper::matchNarrowBinopFeedingAnd( |
4967 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
4968 | // Look for a binop feeding into an AND with a mask: |
4969 | // |
4970 | // %add = G_ADD %lhs, %rhs |
4971 | // %and = G_AND %add, 000...11111111 |
4972 | // |
4973 | // Check if it's possible to perform the binop at a narrower width and zext |
4974 | // back to the original width like so: |
4975 | // |
4976 | // %narrow_lhs = G_TRUNC %lhs |
4977 | // %narrow_rhs = G_TRUNC %rhs |
4978 | // %narrow_add = G_ADD %narrow_lhs, %narrow_rhs |
4979 | // %new_add = G_ZEXT %narrow_add |
4980 | // %and = G_AND %new_add, 000...11111111 |
4981 | // |
4982 | // This can allow later combines to eliminate the G_AND if it turns out |
4983 | // that the mask is irrelevant. |
4984 | assert(MI.getOpcode() == TargetOpcode::G_AND); |
4985 | Register Dst = MI.getOperand(i: 0).getReg(); |
4986 | Register AndLHS = MI.getOperand(i: 1).getReg(); |
4987 | Register AndRHS = MI.getOperand(i: 2).getReg(); |
4988 | LLT WideTy = MRI.getType(Reg: Dst); |
4989 | |
4990 | // If the potential binop has more than one use, then it's possible that one |
4991 | // of those uses will need its full width. |
4992 | if (!WideTy.isScalar() || !MRI.hasOneNonDBGUse(RegNo: AndLHS)) |
4993 | return false; |
4994 | |
4995 | // Check if the LHS feeding the AND is impacted by the high bits that we're |
4996 | // masking out. |
4997 | // |
4998 | // e.g. for 64-bit x, y: |
4999 | // |
5000 | // add_64(x, y) & 65535 == zext(add_16(trunc(x), trunc(y))) & 65535 |
5001 | MachineInstr *LHSInst = getDefIgnoringCopies(Reg: AndLHS, MRI); |
5002 | if (!LHSInst) |
5003 | return false; |
5004 | unsigned LHSOpc = LHSInst->getOpcode(); |
5005 | switch (LHSOpc) { |
5006 | default: |
5007 | return false; |
5008 | case TargetOpcode::G_ADD: |
5009 | case TargetOpcode::G_SUB: |
5010 | case TargetOpcode::G_MUL: |
5011 | case TargetOpcode::G_AND: |
5012 | case TargetOpcode::G_OR: |
5013 | case TargetOpcode::G_XOR: |
5014 | break; |
5015 | } |
5016 | |
5017 | // Find the mask on the RHS. |
5018 | auto Cst = getIConstantVRegValWithLookThrough(VReg: AndRHS, MRI); |
5019 | if (!Cst) |
5020 | return false; |
5021 | auto Mask = Cst->Value; |
5022 | if (!Mask.isMask()) |
5023 | return false; |
5024 | |
5025 | // No point in combining if there's nothing to truncate. |
5026 | unsigned NarrowWidth = Mask.countr_one(); |
5027 | if (NarrowWidth == WideTy.getSizeInBits()) |
5028 | return false; |
5029 | LLT NarrowTy = LLT::scalar(SizeInBits: NarrowWidth); |
5030 | |
5031 | // Check if adding the zext + truncates could be harmful. |
5032 | auto &MF = *MI.getMF(); |
5033 | const auto &TLI = getTargetLowering(); |
5034 | LLVMContext &Ctx = MF.getFunction().getContext(); |
5035 | auto &DL = MF.getDataLayout(); |
5036 | if (!TLI.isTruncateFree(FromTy: WideTy, ToTy: NarrowTy, DL, Ctx) || |
5037 | !TLI.isZExtFree(FromTy: NarrowTy, ToTy: WideTy, DL, Ctx)) |
5038 | return false; |
5039 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_TRUNC, {NarrowTy, WideTy}}) || |
5040 | !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ZEXT, {WideTy, NarrowTy}})) |
5041 | return false; |
5042 | Register BinOpLHS = LHSInst->getOperand(i: 1).getReg(); |
5043 | Register BinOpRHS = LHSInst->getOperand(i: 2).getReg(); |
5044 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5045 | auto NarrowLHS = Builder.buildTrunc(Res: NarrowTy, Op: BinOpLHS); |
5046 | auto NarrowRHS = Builder.buildTrunc(Res: NarrowTy, Op: BinOpRHS); |
5047 | auto NarrowBinOp = |
5048 | Builder.buildInstr(Opc: LHSOpc, DstOps: {NarrowTy}, SrcOps: {NarrowLHS, NarrowRHS}); |
5049 | auto Ext = Builder.buildZExt(Res: WideTy, Op: NarrowBinOp); |
5050 | Observer.changingInstr(MI); |
5051 | MI.getOperand(i: 1).setReg(Ext.getReg(Idx: 0)); |
5052 | Observer.changedInstr(MI); |
5053 | }; |
5054 | return true; |
5055 | } |
5056 | |
5057 | bool CombinerHelper::matchMulOBy2(MachineInstr &MI, BuildFnTy &MatchInfo) { |
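  // (G_*MULO x, 2) -> (G_*ADDO x, x): doubling overflows exactly when the
  // addition does.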
5058 | unsigned Opc = MI.getOpcode(); |
5059 | assert(Opc == TargetOpcode::G_UMULO || Opc == TargetOpcode::G_SMULO); |
5060 | |
5061 | if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 2))) |
5062 | return false; |
5063 | |
5064 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5065 | Observer.changingInstr(MI); |
5066 | unsigned NewOpc = Opc == TargetOpcode::G_UMULO ? TargetOpcode::G_UADDO |
5067 | : TargetOpcode::G_SADDO; |
5068 | MI.setDesc(Builder.getTII().get(Opcode: NewOpc)); |
5069 | MI.getOperand(i: 3).setReg(MI.getOperand(i: 2).getReg()); |
5070 | Observer.changedInstr(MI); |
5071 | }; |
5072 | return true; |
5073 | } |
5074 | |
5075 | bool CombinerHelper::matchMulOBy0(MachineInstr &MI, BuildFnTy &MatchInfo) { |
5076 | // (G_*MULO x, 0) -> 0 + no carry out |
5077 | assert(MI.getOpcode() == TargetOpcode::G_UMULO || |
5078 | MI.getOpcode() == TargetOpcode::G_SMULO); |
5079 | if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 0))) |
5080 | return false; |
5081 | Register Dst = MI.getOperand(i: 0).getReg(); |
5082 | Register Carry = MI.getOperand(i: 1).getReg(); |
5083 | if (!isConstantLegalOrBeforeLegalizer(Ty: MRI.getType(Reg: Dst)) || |
5084 | !isConstantLegalOrBeforeLegalizer(Ty: MRI.getType(Reg: Carry))) |
5085 | return false; |
5086 | MatchInfo = [=](MachineIRBuilder &B) { |
5087 | B.buildConstant(Res: Dst, Val: 0); |
5088 | B.buildConstant(Res: Carry, Val: 0); |
5089 | }; |
5090 | return true; |
5091 | } |
5092 | |
5093 | bool CombinerHelper::matchAddEToAddO(MachineInstr &MI, BuildFnTy &MatchInfo) { |
5094 | // (G_*ADDE x, y, 0) -> (G_*ADDO x, y) |
5095 | // (G_*SUBE x, y, 0) -> (G_*SUBO x, y) |
5096 | assert(MI.getOpcode() == TargetOpcode::G_UADDE || |
5097 | MI.getOpcode() == TargetOpcode::G_SADDE || |
5098 | MI.getOpcode() == TargetOpcode::G_USUBE || |
5099 | MI.getOpcode() == TargetOpcode::G_SSUBE); |
5100 | if (!mi_match(R: MI.getOperand(i: 4).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 0))) |
5101 | return false; |
5102 | MatchInfo = [&](MachineIRBuilder &B) { |
5103 | unsigned NewOpcode; |
5104 | switch (MI.getOpcode()) { |
5105 | case TargetOpcode::G_UADDE: |
5106 | NewOpcode = TargetOpcode::G_UADDO; |
5107 | break; |
5108 | case TargetOpcode::G_SADDE: |
5109 | NewOpcode = TargetOpcode::G_SADDO; |
5110 | break; |
5111 | case TargetOpcode::G_USUBE: |
5112 | NewOpcode = TargetOpcode::G_USUBO; |
5113 | break; |
5114 | case TargetOpcode::G_SSUBE: |
5115 | NewOpcode = TargetOpcode::G_SSUBO; |
5116 | break; |
5117 | } |
5118 | Observer.changingInstr(MI); |
5119 | MI.setDesc(B.getTII().get(Opcode: NewOpcode)); |
5120 | MI.removeOperand(OpNo: 4); |
5121 | Observer.changedInstr(MI); |
5122 | }; |
5123 | return true; |
5124 | } |
5125 | |
5126 | bool CombinerHelper::matchSubAddSameReg(MachineInstr &MI, |
5127 | BuildFnTy &MatchInfo) { |
5128 | assert(MI.getOpcode() == TargetOpcode::G_SUB); |
5129 | Register Dst = MI.getOperand(i: 0).getReg(); |
5130 | // (x + y) - z -> x (if y == z) |
5131 | // (x + y) - z -> y (if x == z) |
5132 | Register X, Y, Z; |
5133 | if (mi_match(R: Dst, MRI, P: m_GSub(L: m_GAdd(L: m_Reg(R&: X), R: m_Reg(R&: Y)), R: m_Reg(R&: Z)))) { |
5134 | Register ReplaceReg; |
5135 | int64_t CstX, CstY; |
5136 | if (Y == Z || (mi_match(R: Y, MRI, P: m_ICstOrSplat(Cst&: CstY)) && |
5137 | mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstY)))) |
5138 | ReplaceReg = X; |
5139 | else if (X == Z || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) && |
5140 | mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX)))) |
5141 | ReplaceReg = Y; |
5142 | if (ReplaceReg) { |
5143 | MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Res: Dst, Op: ReplaceReg); }; |
5144 | return true; |
5145 | } |
5146 | } |
5147 | |
5148 | // x - (y + z) -> 0 - y (if x == z) |
5149 | // x - (y + z) -> 0 - z (if x == y) |
5150 | if (mi_match(R: Dst, MRI, P: m_GSub(L: m_Reg(R&: X), R: m_GAdd(L: m_Reg(R&: Y), R: m_Reg(R&: Z))))) { |
5151 | Register ReplaceReg; |
5152 | int64_t CstX; |
5153 | if (X == Z || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) && |
5154 | mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX)))) |
5155 | ReplaceReg = Y; |
5156 | else if (X == Y || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) && |
5157 | mi_match(R: Y, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX)))) |
5158 | ReplaceReg = Z; |
5159 | if (ReplaceReg) { |
5160 | MatchInfo = [=](MachineIRBuilder &B) { |
5161 | auto Zero = B.buildConstant(Res: MRI.getType(Reg: Dst), Val: 0); |
5162 | B.buildSub(Dst, Src0: Zero, Src1: ReplaceReg); |
5163 | }; |
5164 | return true; |
5165 | } |
5166 | } |
5167 | return false; |
5168 | } |
5169 | |
5170 | MachineInstr *CombinerHelper::buildUDivUsingMul(MachineInstr &MI) { |
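  // Expand an unsigned division by constant into a multiply by a fixed-point
  // reciprocal ("magic number") plus shifts. For illustration, dividing an s32
  // value by 10 typically ends up as roughly:
  //   %q:_(s32) = G_UMULH %x, 0xCCCCCCCD
  //   %res:_(s32) = G_LSHR %q, 3
  // (The exact constants and any pre-shift/NPQ fixup come from
  //  UnsignedDivisionByConstantInfo below.)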
5171 | assert(MI.getOpcode() == TargetOpcode::G_UDIV); |
5172 | auto &UDiv = cast<GenericMachineInstr>(Val&: MI); |
5173 | Register Dst = UDiv.getReg(Idx: 0); |
5174 | Register LHS = UDiv.getReg(Idx: 1); |
5175 | Register RHS = UDiv.getReg(Idx: 2); |
5176 | LLT Ty = MRI.getType(Reg: Dst); |
5177 | LLT ScalarTy = Ty.getScalarType(); |
5178 | const unsigned EltBits = ScalarTy.getScalarSizeInBits(); |
5179 | LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty); |
5180 | LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType(); |
5181 | |
5182 | auto &MIB = Builder; |
5183 | |
5184 | bool UseSRL = false; |
5185 | SmallVector<Register, 16> Shifts, Factors; |
5186 | auto *RHSDefInstr = cast<GenericMachineInstr>(Val: getDefIgnoringCopies(Reg: RHS, MRI)); |
5187 | bool IsSplat = getIConstantSplatVal(MI: *RHSDefInstr, MRI).has_value(); |
5188 | |
5189 | auto BuildExactUDIVPattern = [&](const Constant *C) { |
5190 | // Don't recompute inverses for each splat element. |
5191 | if (IsSplat && !Factors.empty()) { |
5192 | Shifts.push_back(Elt: Shifts[0]); |
5193 | Factors.push_back(Elt: Factors[0]); |
5194 | return true; |
5195 | } |
5196 | |
5197 | auto *CI = cast<ConstantInt>(Val: C); |
5198 | APInt Divisor = CI->getValue(); |
5199 | unsigned Shift = Divisor.countr_zero(); |
5200 | if (Shift) { |
5201 | Divisor.lshrInPlace(ShiftAmt: Shift); |
5202 | UseSRL = true; |
5203 | } |
5204 | |
5205 | // Calculate the multiplicative inverse modulo BW. |
5206 | APInt Factor = Divisor.multiplicativeInverse(); |
5207 | Shifts.push_back(Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: Shift).getReg(Idx: 0)); |
5208 | Factors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Factor).getReg(Idx: 0)); |
5209 | return true; |
5210 | }; |
5211 | |
5212 | if (MI.getFlag(Flag: MachineInstr::MIFlag::IsExact)) { |
5213 | // Collect all magic values from the build vector. |
5214 | if (!matchUnaryPredicate(MRI, Reg: RHS, Match: BuildExactUDIVPattern)) |
5215 | llvm_unreachable("Expected unary predicate match to succeed" ); |
5216 | |
5217 | Register Shift, Factor; |
5218 | if (Ty.isVector()) { |
5219 | Shift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: Shifts).getReg(Idx: 0); |
5220 | Factor = MIB.buildBuildVector(Res: Ty, Ops: Factors).getReg(Idx: 0); |
5221 | } else { |
5222 | Shift = Shifts[0]; |
5223 | Factor = Factors[0]; |
5224 | } |
5225 | |
5226 | Register Res = LHS; |
5227 | |
5228 | if (UseSRL) |
5229 | Res = MIB.buildLShr(Dst: Ty, Src0: Res, Src1: Shift, Flags: MachineInstr::IsExact).getReg(Idx: 0); |
5230 | |
5231 | return MIB.buildMul(Dst: Ty, Src0: Res, Src1: Factor); |
5232 | } |
5233 | |
5234 | unsigned KnownLeadingZeros = |
5235 | KB ? KB->getKnownBits(R: LHS).countMinLeadingZeros() : 0; |
5236 | |
5237 | bool UseNPQ = false; |
5238 | SmallVector<Register, 16> PreShifts, PostShifts, MagicFactors, NPQFactors; |
5239 | auto BuildUDIVPattern = [&](const Constant *C) { |
5240 | auto *CI = cast<ConstantInt>(Val: C); |
5241 | const APInt &Divisor = CI->getValue(); |
5242 | |
5243 | bool SelNPQ = false; |
5244 | APInt Magic(Divisor.getBitWidth(), 0); |
5245 | unsigned PreShift = 0, PostShift = 0; |
5246 | |
5247 | // Magic algorithm doesn't work for division by 1. We need to emit a select |
5248 | // at the end. |
5249 | // TODO: Use undef values for divisor of 1. |
5250 | if (!Divisor.isOne()) { |
5251 | |
5252 | // UnsignedDivisionByConstantInfo doesn't work correctly if leading zeros |
      // in the dividend exceed the leading zeros for the divisor.
5254 | UnsignedDivisionByConstantInfo magics = |
5255 | UnsignedDivisionByConstantInfo::get( |
5256 | D: Divisor, LeadingZeros: std::min(a: KnownLeadingZeros, b: Divisor.countl_zero())); |
5257 | |
5258 | Magic = std::move(magics.Magic); |
5259 | |
5260 | assert(magics.PreShift < Divisor.getBitWidth() && |
5261 | "We shouldn't generate an undefined shift!" ); |
5262 | assert(magics.PostShift < Divisor.getBitWidth() && |
5263 | "We shouldn't generate an undefined shift!" ); |
5264 | assert((!magics.IsAdd || magics.PreShift == 0) && "Unexpected pre-shift" ); |
5265 | PreShift = magics.PreShift; |
5266 | PostShift = magics.PostShift; |
5267 | SelNPQ = magics.IsAdd; |
5268 | } |
5269 | |
5270 | PreShifts.push_back( |
5271 | Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: PreShift).getReg(Idx: 0)); |
5272 | MagicFactors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Magic).getReg(Idx: 0)); |
5273 | NPQFactors.push_back( |
5274 | Elt: MIB.buildConstant(Res: ScalarTy, |
5275 | Val: SelNPQ ? APInt::getOneBitSet(numBits: EltBits, BitNo: EltBits - 1) |
5276 | : APInt::getZero(numBits: EltBits)) |
5277 | .getReg(Idx: 0)); |
5278 | PostShifts.push_back( |
5279 | Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: PostShift).getReg(Idx: 0)); |
5280 | UseNPQ |= SelNPQ; |
5281 | return true; |
5282 | }; |
5283 | |
5284 | // Collect the shifts/magic values from each element. |
5285 | bool Matched = matchUnaryPredicate(MRI, Reg: RHS, Match: BuildUDIVPattern); |
5286 | (void)Matched; |
5287 | assert(Matched && "Expected unary predicate match to succeed" ); |
5288 | |
5289 | Register PreShift, PostShift, MagicFactor, NPQFactor; |
5290 | auto *RHSDef = getOpcodeDef<GBuildVector>(Reg: RHS, MRI); |
5291 | if (RHSDef) { |
5292 | PreShift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: PreShifts).getReg(Idx: 0); |
5293 | MagicFactor = MIB.buildBuildVector(Res: Ty, Ops: MagicFactors).getReg(Idx: 0); |
5294 | NPQFactor = MIB.buildBuildVector(Res: Ty, Ops: NPQFactors).getReg(Idx: 0); |
5295 | PostShift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: PostShifts).getReg(Idx: 0); |
5296 | } else { |
5297 | assert(MRI.getType(RHS).isScalar() && |
5298 | "Non-build_vector operation should have been a scalar" ); |
5299 | PreShift = PreShifts[0]; |
5300 | MagicFactor = MagicFactors[0]; |
5301 | PostShift = PostShifts[0]; |
5302 | } |
5303 | |
5304 | Register Q = LHS; |
5305 | Q = MIB.buildLShr(Dst: Ty, Src0: Q, Src1: PreShift).getReg(Idx: 0); |
5306 | |
5307 | // Multiply the numerator (operand 0) by the magic value. |
5308 | Q = MIB.buildUMulH(Dst: Ty, Src0: Q, Src1: MagicFactor).getReg(Idx: 0); |
5309 | |
5310 | if (UseNPQ) { |
5311 | Register NPQ = MIB.buildSub(Dst: Ty, Src0: LHS, Src1: Q).getReg(Idx: 0); |
5312 | |
5313 | // For vectors we might have a mix of non-NPQ/NPQ paths, so use |
5314 | // G_UMULH to act as a SRL-by-1 for NPQ, else multiply by zero. |
5315 | if (Ty.isVector()) |
5316 | NPQ = MIB.buildUMulH(Dst: Ty, Src0: NPQ, Src1: NPQFactor).getReg(Idx: 0); |
5317 | else |
5318 | NPQ = MIB.buildLShr(Dst: Ty, Src0: NPQ, Src1: MIB.buildConstant(Res: ShiftAmtTy, Val: 1)).getReg(Idx: 0); |
5319 | |
5320 | Q = MIB.buildAdd(Dst: Ty, Src0: NPQ, Src1: Q).getReg(Idx: 0); |
5321 | } |
5322 | |
5323 | Q = MIB.buildLShr(Dst: Ty, Src0: Q, Src1: PostShift).getReg(Idx: 0); |
5324 | auto One = MIB.buildConstant(Res: Ty, Val: 1); |
5325 | auto IsOne = MIB.buildICmp( |
5326 | Pred: CmpInst::Predicate::ICMP_EQ, |
5327 | Res: Ty.isScalar() ? LLT::scalar(SizeInBits: 1) : Ty.changeElementSize(NewEltSize: 1), Op0: RHS, Op1: One); |
5328 | return MIB.buildSelect(Res: Ty, Tst: IsOne, Op0: LHS, Op1: Q); |
5329 | } |
5330 | |
5331 | bool CombinerHelper::matchUDivByConst(MachineInstr &MI) { |
5332 | assert(MI.getOpcode() == TargetOpcode::G_UDIV); |
5333 | Register Dst = MI.getOperand(i: 0).getReg(); |
5334 | Register RHS = MI.getOperand(i: 2).getReg(); |
5335 | LLT DstTy = MRI.getType(Reg: Dst); |
5336 | |
5337 | auto &MF = *MI.getMF(); |
5338 | AttributeList Attr = MF.getFunction().getAttributes(); |
5339 | const auto &TLI = getTargetLowering(); |
5340 | LLVMContext &Ctx = MF.getFunction().getContext(); |
5341 | auto &DL = MF.getDataLayout(); |
5342 | if (TLI.isIntDivCheap(VT: getApproximateEVTForLLT(Ty: DstTy, DL, Ctx), Attr)) |
5343 | return false; |
5344 | |
5345 | // Don't do this for minsize because the instruction sequence is usually |
5346 | // larger. |
5347 | if (MF.getFunction().hasMinSize()) |
5348 | return false; |
5349 | |
5350 | if (MI.getFlag(Flag: MachineInstr::MIFlag::IsExact)) { |
5351 | return matchUnaryPredicate( |
5352 | MRI, Reg: RHS, Match: [](const Constant *C) { return C && !C->isNullValue(); }); |
5353 | } |
5354 | |
5355 | auto *RHSDef = MRI.getVRegDef(Reg: RHS); |
5356 | if (!isConstantOrConstantVector(MI&: *RHSDef, MRI)) |
5357 | return false; |
5358 | |
5359 | // Don't do this if the types are not going to be legal. |
5360 | if (LI) { |
5361 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_MUL, {DstTy, DstTy}})) |
5362 | return false; |
5363 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_UMULH, {DstTy}})) |
5364 | return false; |
5365 | if (!isLegalOrBeforeLegalizer( |
5366 | Query: {TargetOpcode::G_ICMP, |
5367 | {DstTy.isVector() ? DstTy.changeElementSize(NewEltSize: 1) : LLT::scalar(SizeInBits: 1), |
5368 | DstTy}})) |
5369 | return false; |
5370 | } |
5371 | |
5372 | return matchUnaryPredicate( |
5373 | MRI, Reg: RHS, Match: [](const Constant *C) { return C && !C->isNullValue(); }); |
5374 | } |
5375 | |
5376 | void CombinerHelper::applyUDivByConst(MachineInstr &MI) { |
5377 | auto *NewMI = buildUDivUsingMul(MI); |
5378 | replaceSingleDefInstWithReg(MI, Replacement: NewMI->getOperand(i: 0).getReg()); |
5379 | } |
5380 | |
5381 | bool CombinerHelper::matchSDivByConst(MachineInstr &MI) { |
5382 | assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV" ); |
5383 | Register Dst = MI.getOperand(i: 0).getReg(); |
5384 | Register RHS = MI.getOperand(i: 2).getReg(); |
5385 | LLT DstTy = MRI.getType(Reg: Dst); |
5386 | |
5387 | auto &MF = *MI.getMF(); |
5388 | AttributeList Attr = MF.getFunction().getAttributes(); |
5389 | const auto &TLI = getTargetLowering(); |
5390 | LLVMContext &Ctx = MF.getFunction().getContext(); |
5391 | auto &DL = MF.getDataLayout(); |
5392 | if (TLI.isIntDivCheap(VT: getApproximateEVTForLLT(Ty: DstTy, DL, Ctx), Attr)) |
5393 | return false; |
5394 | |
5395 | // Don't do this for minsize because the instruction sequence is usually |
5396 | // larger. |
5397 | if (MF.getFunction().hasMinSize()) |
5398 | return false; |
5399 | |
5400 | // If the sdiv has an 'exact' flag we can use a simpler lowering. |
5401 | if (MI.getFlag(Flag: MachineInstr::MIFlag::IsExact)) { |
5402 | return matchUnaryPredicate( |
5403 | MRI, Reg: RHS, Match: [](const Constant *C) { return C && !C->isNullValue(); }); |
5404 | } |
5405 | |
5406 | // Don't support the general case for now. |
5407 | return false; |
5408 | } |
5409 | |
5410 | void CombinerHelper::applySDivByConst(MachineInstr &MI) { |
5411 | auto *NewMI = buildSDivUsingMul(MI); |
5412 | replaceSingleDefInstWithReg(MI, Replacement: NewMI->getOperand(i: 0).getReg()); |
5413 | } |
5414 | |
5415 | MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) { |
5416 | assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV" ); |
5417 | auto &SDiv = cast<GenericMachineInstr>(Val&: MI); |
5418 | Register Dst = SDiv.getReg(Idx: 0); |
5419 | Register LHS = SDiv.getReg(Idx: 1); |
5420 | Register RHS = SDiv.getReg(Idx: 2); |
5421 | LLT Ty = MRI.getType(Reg: Dst); |
5422 | LLT ScalarTy = Ty.getScalarType(); |
5423 | LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty); |
5424 | LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType(); |
5425 | auto &MIB = Builder; |
5426 | |
5427 | bool UseSRA = false; |
5428 | SmallVector<Register, 16> Shifts, Factors; |
5429 | |
5430 | auto *RHSDef = cast<GenericMachineInstr>(Val: getDefIgnoringCopies(Reg: RHS, MRI)); |
5431 | bool IsSplat = getIConstantSplatVal(MI: *RHSDef, MRI).has_value(); |
5432 | |
5433 | auto BuildSDIVPattern = [&](const Constant *C) { |
5434 | // Don't recompute inverses for each splat element. |
5435 | if (IsSplat && !Factors.empty()) { |
5436 | Shifts.push_back(Elt: Shifts[0]); |
5437 | Factors.push_back(Elt: Factors[0]); |
5438 | return true; |
5439 | } |
5440 | |
5441 | auto *CI = cast<ConstantInt>(Val: C); |
5442 | APInt Divisor = CI->getValue(); |
5443 | unsigned Shift = Divisor.countr_zero(); |
5444 | if (Shift) { |
5445 | Divisor.ashrInPlace(ShiftAmt: Shift); |
5446 | UseSRA = true; |
5447 | } |
5448 | |
// Calculate the multiplicative inverse of the now-odd divisor modulo 2^BW.
5451 | APInt Factor = Divisor.multiplicativeInverse(); |
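// Illustrative example: for an exact 32-bit sdiv by 6, Shift = 1 and the
// remaining odd divisor is 3, whose inverse modulo 2^32 is 0xAAAAAAAB
// (3 * 0xAAAAAAAB == 2^33 + 1). The expansion is therefore
//   %t   = exact G_ASHR %x, 1
//   %res = G_MUL %t, 0xAAAAAAAB
// e.g. x = 42: (42 >> 1) * 0xAAAAAAAB == 21 * 0xAAAAAAAB == 7 (mod 2^32).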
5452 | Shifts.push_back(Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: Shift).getReg(Idx: 0)); |
5453 | Factors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Factor).getReg(Idx: 0)); |
5454 | return true; |
5455 | }; |
5456 | |
// Collect the shift amounts and factors for each element.
5458 | bool Matched = matchUnaryPredicate(MRI, Reg: RHS, Match: BuildSDIVPattern); |
5459 | (void)Matched; |
5460 | assert(Matched && "Expected unary predicate match to succeed" ); |
5461 | |
5462 | Register Shift, Factor; |
5463 | if (Ty.isVector()) { |
5464 | Shift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: Shifts).getReg(Idx: 0); |
5465 | Factor = MIB.buildBuildVector(Res: Ty, Ops: Factors).getReg(Idx: 0); |
5466 | } else { |
5467 | Shift = Shifts[0]; |
5468 | Factor = Factors[0]; |
5469 | } |
5470 | |
5471 | Register Res = LHS; |
5472 | |
5473 | if (UseSRA) |
5474 | Res = MIB.buildAShr(Dst: Ty, Src0: Res, Src1: Shift, Flags: MachineInstr::IsExact).getReg(Idx: 0); |
5475 | |
5476 | return MIB.buildMul(Dst: Ty, Src0: Res, Src1: Factor); |
5477 | } |
5478 | |
5479 | bool CombinerHelper::matchDivByPow2(MachineInstr &MI, bool IsSigned) { |
5480 | assert((MI.getOpcode() == TargetOpcode::G_SDIV || |
5481 | MI.getOpcode() == TargetOpcode::G_UDIV) && |
5482 | "Expected SDIV or UDIV" ); |
5483 | auto &Div = cast<GenericMachineInstr>(Val&: MI); |
5484 | Register RHS = Div.getReg(Idx: 2); |
5485 | auto MatchPow2 = [&](const Constant *C) { |
5486 | auto *CI = dyn_cast<ConstantInt>(Val: C); |
5487 | return CI && (CI->getValue().isPowerOf2() || |
5488 | (IsSigned && CI->getValue().isNegatedPowerOf2())); |
5489 | }; |
5490 | return matchUnaryPredicate(MRI, Reg: RHS, Match: MatchPow2, /*AllowUndefs=*/false); |
5491 | } |
5492 | |
5493 | void CombinerHelper::applySDivByPow2(MachineInstr &MI) { |
5494 | assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV" ); |
5495 | auto &SDiv = cast<GenericMachineInstr>(Val&: MI); |
5496 | Register Dst = SDiv.getReg(Idx: 0); |
5497 | Register LHS = SDiv.getReg(Idx: 1); |
5498 | Register RHS = SDiv.getReg(Idx: 2); |
5499 | LLT Ty = MRI.getType(Reg: Dst); |
5500 | LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty); |
5501 | LLT CCVT = |
5502 | Ty.isVector() ? LLT::vector(EC: Ty.getElementCount(), ScalarSizeInBits: 1) : LLT::scalar(SizeInBits: 1); |
5503 | |
5504 | // Effectively we want to lower G_SDIV %lhs, %rhs, where %rhs is a power of 2, |
5505 | // to the following version: |
5506 | // |
5507 | // %c1 = G_CTTZ %rhs |
5508 | // %inexact = G_SUB $bitwidth, %c1 |
// %sign = G_ASHR %lhs, $(bitwidth - 1)
5510 | // %lshr = G_LSHR %sign, %inexact |
5511 | // %add = G_ADD %lhs, %lshr |
5512 | // %ashr = G_ASHR %add, %c1 |
// %ashr = G_SELECT %isoneorminusone, %lhs, %ashr
5514 | // %zero = G_CONSTANT $0 |
5515 | // %neg = G_NEG %ashr |
5516 | // %isneg = G_ICMP SLT %rhs, %zero |
5517 | // %res = G_SELECT %isneg, %neg, %ashr |
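//
// For example, with a 32-bit %lhs = -13 and %rhs = 8: %c1 = 3, %inexact = 29,
// %sign = -1, %lshr = 7, %add = -6 and %ashr = -6 >> 3 = -1, which matches
// -13 sdiv 8 == -1 (truncating division). %rhs is positive and not +/-1, so
// both selects keep %ashr.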
5518 | |
5519 | unsigned BitWidth = Ty.getScalarSizeInBits(); |
5520 | auto Zero = Builder.buildConstant(Res: Ty, Val: 0); |
5521 | |
5522 | auto Bits = Builder.buildConstant(Res: ShiftAmtTy, Val: BitWidth); |
5523 | auto C1 = Builder.buildCTTZ(Dst: ShiftAmtTy, Src0: RHS); |
5524 | auto Inexact = Builder.buildSub(Dst: ShiftAmtTy, Src0: Bits, Src1: C1); |
5525 | // Splat the sign bit into the register |
5526 | auto Sign = Builder.buildAShr( |
5527 | Dst: Ty, Src0: LHS, Src1: Builder.buildConstant(Res: ShiftAmtTy, Val: BitWidth - 1)); |
5528 | |
// Add (LHS < 0) ? 2^c1 - 1 : 0 to LHS, so the arithmetic shift below
// truncates toward zero.
5530 | auto LSrl = Builder.buildLShr(Dst: Ty, Src0: Sign, Src1: Inexact); |
5531 | auto Add = Builder.buildAdd(Dst: Ty, Src0: LHS, Src1: LSrl); |
5532 | auto AShr = Builder.buildAShr(Dst: Ty, Src0: Add, Src1: C1); |
5533 | |
5534 | // Special case: (sdiv X, 1) -> X |
// Special case: (sdiv X, -1) -> 0-X
5536 | auto One = Builder.buildConstant(Res: Ty, Val: 1); |
5537 | auto MinusOne = Builder.buildConstant(Res: Ty, Val: -1); |
5538 | auto IsOne = Builder.buildICmp(Pred: CmpInst::Predicate::ICMP_EQ, Res: CCVT, Op0: RHS, Op1: One); |
5539 | auto IsMinusOne = |
5540 | Builder.buildICmp(Pred: CmpInst::Predicate::ICMP_EQ, Res: CCVT, Op0: RHS, Op1: MinusOne); |
5541 | auto IsOneOrMinusOne = Builder.buildOr(Dst: CCVT, Src0: IsOne, Src1: IsMinusOne); |
5542 | AShr = Builder.buildSelect(Res: Ty, Tst: IsOneOrMinusOne, Op0: LHS, Op1: AShr); |
5543 | |
5544 | // If divided by a positive value, we're done. Otherwise, the result must be |
5545 | // negated. |
5546 | auto Neg = Builder.buildNeg(Dst: Ty, Src0: AShr); |
5547 | auto IsNeg = Builder.buildICmp(Pred: CmpInst::Predicate::ICMP_SLT, Res: CCVT, Op0: RHS, Op1: Zero); |
5548 | Builder.buildSelect(Res: MI.getOperand(i: 0).getReg(), Tst: IsNeg, Op0: Neg, Op1: AShr); |
5549 | MI.eraseFromParent(); |
5550 | } |
5551 | |
5552 | void CombinerHelper::applyUDivByPow2(MachineInstr &MI) { |
5553 | assert(MI.getOpcode() == TargetOpcode::G_UDIV && "Expected UDIV" ); |
5554 | auto &UDiv = cast<GenericMachineInstr>(Val&: MI); |
5555 | Register Dst = UDiv.getReg(Idx: 0); |
5556 | Register LHS = UDiv.getReg(Idx: 1); |
5557 | Register RHS = UDiv.getReg(Idx: 2); |
5558 | LLT Ty = MRI.getType(Reg: Dst); |
5559 | LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty); |
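// A udiv by a power of two is just a logical shift right by log2 of the
// divisor, e.g. x udiv 16 becomes G_LSHR %x, G_CTTZ(16) = G_LSHR %x, 4.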
5560 | |
5561 | auto C1 = Builder.buildCTTZ(Dst: ShiftAmtTy, Src0: RHS); |
5562 | Builder.buildLShr(Dst: MI.getOperand(i: 0).getReg(), Src0: LHS, Src1: C1); |
5563 | MI.eraseFromParent(); |
5564 | } |
5565 | |
5566 | bool CombinerHelper::matchUMulHToLShr(MachineInstr &MI) { |
5567 | assert(MI.getOpcode() == TargetOpcode::G_UMULH); |
5568 | Register RHS = MI.getOperand(i: 2).getReg(); |
5569 | Register Dst = MI.getOperand(i: 0).getReg(); |
5570 | LLT Ty = MRI.getType(Reg: Dst); |
5571 | LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty); |
5572 | auto MatchPow2ExceptOne = [&](const Constant *C) { |
5573 | if (auto *CI = dyn_cast<ConstantInt>(Val: C)) |
5574 | return CI->getValue().isPowerOf2() && !CI->getValue().isOne(); |
5575 | return false; |
5576 | }; |
5577 | if (!matchUnaryPredicate(MRI, Reg: RHS, Match: MatchPow2ExceptOne, AllowUndefs: false)) |
5578 | return false; |
5579 | return isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_LSHR, {Ty, ShiftAmtTy}}); |
5580 | } |
5581 | |
5582 | void CombinerHelper::applyUMulHToLShr(MachineInstr &MI) { |
5583 | Register LHS = MI.getOperand(i: 1).getReg(); |
5584 | Register RHS = MI.getOperand(i: 2).getReg(); |
5585 | Register Dst = MI.getOperand(i: 0).getReg(); |
5586 | LLT Ty = MRI.getType(Reg: Dst); |
5587 | LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty); |
5588 | unsigned NumEltBits = Ty.getScalarSizeInBits(); |
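// G_UMULH %x, (1 << k) is the high half of a (2 * BitWidth)-bit product,
// which is simply %x shifted right by (BitWidth - k). E.g. for 32 bits,
// umulh(x, 8) == x lshr 29.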
5589 | |
5590 | auto LogBase2 = buildLogBase2(V: RHS, MIB&: Builder); |
5591 | auto ShiftAmt = |
5592 | Builder.buildSub(Dst: Ty, Src0: Builder.buildConstant(Res: Ty, Val: NumEltBits), Src1: LogBase2); |
5593 | auto Trunc = Builder.buildZExtOrTrunc(Res: ShiftAmtTy, Op: ShiftAmt); |
5594 | Builder.buildLShr(Dst, Src0: LHS, Src1: Trunc); |
5595 | MI.eraseFromParent(); |
5596 | } |
5597 | |
5598 | bool CombinerHelper::matchRedundantNegOperands(MachineInstr &MI, |
5599 | BuildFnTy &MatchInfo) { |
5600 | unsigned Opc = MI.getOpcode(); |
5601 | assert(Opc == TargetOpcode::G_FADD || Opc == TargetOpcode::G_FSUB || |
5602 | Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV || |
5603 | Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA); |
5604 | |
5605 | Register Dst = MI.getOperand(i: 0).getReg(); |
5606 | Register X = MI.getOperand(i: 1).getReg(); |
5607 | Register Y = MI.getOperand(i: 2).getReg(); |
5608 | LLT Type = MRI.getType(Reg: Dst); |
5609 | |
5610 | // fold (fadd x, fneg(y)) -> (fsub x, y) |
5611 | // fold (fadd fneg(y), x) -> (fsub x, y) |
// G_FADD is commutative, so both cases are checked by m_GFAdd
5613 | if (mi_match(R: Dst, MRI, P: m_GFAdd(L: m_Reg(R&: X), R: m_GFNeg(Src: m_Reg(R&: Y)))) && |
5614 | isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FSUB, {Type}})) { |
5615 | Opc = TargetOpcode::G_FSUB; |
5616 | } |
// fold (fsub x, fneg(y)) -> (fadd x, y)
5618 | else if (mi_match(R: Dst, MRI, P: m_GFSub(L: m_Reg(R&: X), R: m_GFNeg(Src: m_Reg(R&: Y)))) && |
5619 | isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FADD, {Type}})) { |
5620 | Opc = TargetOpcode::G_FADD; |
5621 | } |
5622 | // fold (fmul fneg(x), fneg(y)) -> (fmul x, y) |
5623 | // fold (fdiv fneg(x), fneg(y)) -> (fdiv x, y) |
5624 | // fold (fmad fneg(x), fneg(y), z) -> (fmad x, y, z) |
5625 | // fold (fma fneg(x), fneg(y), z) -> (fma x, y, z) |
5626 | else if ((Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV || |
5627 | Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA) && |
5628 | mi_match(R: X, MRI, P: m_GFNeg(Src: m_Reg(R&: X))) && |
5629 | mi_match(R: Y, MRI, P: m_GFNeg(Src: m_Reg(R&: Y)))) { |
5630 | // no opcode change |
5631 | } else |
5632 | return false; |
5633 | |
5634 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5635 | Observer.changingInstr(MI); |
5636 | MI.setDesc(B.getTII().get(Opcode: Opc)); |
5637 | MI.getOperand(i: 1).setReg(X); |
5638 | MI.getOperand(i: 2).setReg(Y); |
5639 | Observer.changedInstr(MI); |
5640 | }; |
5641 | return true; |
5642 | } |
5643 | |
5644 | bool CombinerHelper::matchFsubToFneg(MachineInstr &MI, Register &MatchInfo) { |
5645 | assert(MI.getOpcode() == TargetOpcode::G_FSUB); |
5646 | |
5647 | Register LHS = MI.getOperand(i: 1).getReg(); |
5648 | MatchInfo = MI.getOperand(i: 2).getReg(); |
5649 | LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
5650 | |
5651 | const auto LHSCst = Ty.isVector() |
5652 | ? getFConstantSplat(VReg: LHS, MRI, /* allowUndef */ AllowUndef: true) |
5653 | : getFConstantVRegValWithLookThrough(VReg: LHS, MRI); |
5654 | if (!LHSCst) |
5655 | return false; |
5656 | |
5657 | // -0.0 is always allowed |
5658 | if (LHSCst->Value.isNegZero()) |
5659 | return true; |
5660 | |
5661 | // +0.0 is only allowed if nsz is set. |
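// ((fsub +0.0, +0.0) is +0.0, whereas (fneg +0.0) would be -0.0).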
5662 | if (LHSCst->Value.isPosZero()) |
5663 | return MI.getFlag(Flag: MachineInstr::FmNsz); |
5664 | |
5665 | return false; |
5666 | } |
5667 | |
5668 | void CombinerHelper::applyFsubToFneg(MachineInstr &MI, Register &MatchInfo) { |
5669 | Register Dst = MI.getOperand(i: 0).getReg(); |
5670 | Builder.buildFNeg( |
5671 | Dst, Src0: Builder.buildFCanonicalize(Dst: MRI.getType(Reg: Dst), Src0: MatchInfo).getReg(Idx: 0)); |
5672 | eraseInst(MI); |
5673 | } |
5674 | |
5675 | /// Checks if \p MI is TargetOpcode::G_FMUL and contractable either |
5676 | /// due to global flags or MachineInstr flags. |
5677 | static bool isContractableFMul(MachineInstr &MI, bool AllowFusionGlobally) { |
5678 | if (MI.getOpcode() != TargetOpcode::G_FMUL) |
5679 | return false; |
5680 | return AllowFusionGlobally || MI.getFlag(Flag: MachineInstr::MIFlag::FmContract); |
5681 | } |
5682 | |
5683 | static bool hasMoreUses(const MachineInstr &MI0, const MachineInstr &MI1, |
5684 | const MachineRegisterInfo &MRI) { |
5685 | return std::distance(first: MRI.use_instr_nodbg_begin(RegNo: MI0.getOperand(i: 0).getReg()), |
5686 | last: MRI.use_instr_nodbg_end()) > |
5687 | std::distance(first: MRI.use_instr_nodbg_begin(RegNo: MI1.getOperand(i: 0).getReg()), |
5688 | last: MRI.use_instr_nodbg_end()); |
5689 | } |
5690 | |
5691 | bool CombinerHelper::canCombineFMadOrFMA(MachineInstr &MI, |
5692 | bool &AllowFusionGlobally, |
5693 | bool &HasFMAD, bool &Aggressive, |
5694 | bool CanReassociate) { |
5695 | |
5696 | auto *MF = MI.getMF(); |
5697 | const auto &TLI = *MF->getSubtarget().getTargetLowering(); |
5698 | const TargetOptions &Options = MF->getTarget().Options; |
5699 | LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
5700 | |
5701 | if (CanReassociate && |
5702 | !(Options.UnsafeFPMath || MI.getFlag(Flag: MachineInstr::MIFlag::FmReassoc))) |
5703 | return false; |
5704 | |
5705 | // Floating-point multiply-add with intermediate rounding. |
5706 | HasFMAD = (!isPreLegalize() && TLI.isFMADLegal(MI, Ty: DstType)); |
5707 | // Floating-point multiply-add without intermediate rounding. |
5708 | bool HasFMA = TLI.isFMAFasterThanFMulAndFAdd(MF: *MF, DstType) && |
5709 | isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FMA, {DstType}}); |
5710 | // No valid opcode, do not combine. |
5711 | if (!HasFMAD && !HasFMA) |
5712 | return false; |
5713 | |
5714 | AllowFusionGlobally = Options.AllowFPOpFusion == FPOpFusion::Fast || |
5715 | Options.UnsafeFPMath || HasFMAD; |
5716 | // If the addition is not contractable, do not combine. |
5717 | if (!AllowFusionGlobally && !MI.getFlag(Flag: MachineInstr::MIFlag::FmContract)) |
5718 | return false; |
5719 | |
5720 | Aggressive = TLI.enableAggressiveFMAFusion(Ty: DstType); |
5721 | return true; |
5722 | } |
5723 | |
5724 | bool CombinerHelper::matchCombineFAddFMulToFMadOrFMA( |
5725 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
5726 | assert(MI.getOpcode() == TargetOpcode::G_FADD); |
5727 | |
5728 | bool AllowFusionGlobally, HasFMAD, Aggressive; |
5729 | if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) |
5730 | return false; |
5731 | |
5732 | Register Op1 = MI.getOperand(i: 1).getReg(); |
5733 | Register Op2 = MI.getOperand(i: 2).getReg(); |
5734 | DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1}; |
5735 | DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2}; |
5736 | unsigned PreferredFusedOpcode = |
5737 | HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; |
5738 | |
5739 | // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), |
5740 | // prefer to fold the multiply with fewer uses. |
5741 | if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) && |
5742 | isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) { |
5743 | if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI)) |
5744 | std::swap(a&: LHS, b&: RHS); |
5745 | } |
5746 | |
5747 | // fold (fadd (fmul x, y), z) -> (fma x, y, z) |
5748 | if (isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) && |
5749 | (Aggressive || MRI.hasOneNonDBGUse(RegNo: LHS.Reg))) { |
5750 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5751 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
5752 | SrcOps: {LHS.MI->getOperand(i: 1).getReg(), |
5753 | LHS.MI->getOperand(i: 2).getReg(), RHS.Reg}); |
5754 | }; |
5755 | return true; |
5756 | } |
5757 | |
5758 | // fold (fadd x, (fmul y, z)) -> (fma y, z, x) |
5759 | if (isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally) && |
5760 | (Aggressive || MRI.hasOneNonDBGUse(RegNo: RHS.Reg))) { |
5761 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5762 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
5763 | SrcOps: {RHS.MI->getOperand(i: 1).getReg(), |
5764 | RHS.MI->getOperand(i: 2).getReg(), LHS.Reg}); |
5765 | }; |
5766 | return true; |
5767 | } |
5768 | |
5769 | return false; |
5770 | } |
5771 | |
5772 | bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMA( |
5773 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
5774 | assert(MI.getOpcode() == TargetOpcode::G_FADD); |
5775 | |
5776 | bool AllowFusionGlobally, HasFMAD, Aggressive; |
5777 | if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) |
5778 | return false; |
5779 | |
5780 | const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering(); |
5781 | Register Op1 = MI.getOperand(i: 1).getReg(); |
5782 | Register Op2 = MI.getOperand(i: 2).getReg(); |
5783 | DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1}; |
5784 | DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2}; |
5785 | LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
5786 | |
5787 | unsigned PreferredFusedOpcode = |
5788 | HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; |
5789 | |
5790 | // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), |
5791 | // prefer to fold the multiply with fewer uses. |
5792 | if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) && |
5793 | isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) { |
5794 | if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI)) |
5795 | std::swap(a&: LHS, b&: RHS); |
5796 | } |
5797 | |
5798 | // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z) |
5799 | MachineInstr *FpExtSrc; |
5800 | if (mi_match(R: LHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FpExtSrc))) && |
5801 | isContractableFMul(MI&: *FpExtSrc, AllowFusionGlobally) && |
5802 | TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType, |
5803 | SrcTy: MRI.getType(Reg: FpExtSrc->getOperand(i: 1).getReg()))) { |
5804 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5805 | auto FpExtX = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 1).getReg()); |
5806 | auto FpExtY = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 2).getReg()); |
5807 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
5808 | SrcOps: {FpExtX.getReg(Idx: 0), FpExtY.getReg(Idx: 0), RHS.Reg}); |
5809 | }; |
5810 | return true; |
5811 | } |
5812 | |
5813 | // fold (fadd z, (fpext (fmul x, y))) -> (fma (fpext x), (fpext y), z) |
5814 | // Note: Commutes FADD operands. |
5815 | if (mi_match(R: RHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FpExtSrc))) && |
5816 | isContractableFMul(MI&: *FpExtSrc, AllowFusionGlobally) && |
5817 | TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType, |
5818 | SrcTy: MRI.getType(Reg: FpExtSrc->getOperand(i: 1).getReg()))) { |
5819 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5820 | auto FpExtX = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 1).getReg()); |
5821 | auto FpExtY = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 2).getReg()); |
5822 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
5823 | SrcOps: {FpExtX.getReg(Idx: 0), FpExtY.getReg(Idx: 0), LHS.Reg}); |
5824 | }; |
5825 | return true; |
5826 | } |
5827 | |
5828 | return false; |
5829 | } |
5830 | |
5831 | bool CombinerHelper::matchCombineFAddFMAFMulToFMadOrFMA( |
5832 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
5833 | assert(MI.getOpcode() == TargetOpcode::G_FADD); |
5834 | |
5835 | bool AllowFusionGlobally, HasFMAD, Aggressive; |
5836 | if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive, CanReassociate: true)) |
5837 | return false; |
5838 | |
5839 | Register Op1 = MI.getOperand(i: 1).getReg(); |
5840 | Register Op2 = MI.getOperand(i: 2).getReg(); |
5841 | DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1}; |
5842 | DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2}; |
5843 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
5844 | |
5845 | unsigned PreferredFusedOpcode = |
5846 | HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; |
5847 | |
5848 | // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), |
5849 | // prefer to fold the multiply with fewer uses. |
5850 | if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) && |
5851 | isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) { |
5852 | if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI)) |
5853 | std::swap(a&: LHS, b&: RHS); |
5854 | } |
5855 | |
5856 | MachineInstr *FMA = nullptr; |
5857 | Register Z; |
5858 | // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y, (fma u, v, z)) |
5859 | if (LHS.MI->getOpcode() == PreferredFusedOpcode && |
5860 | (MRI.getVRegDef(Reg: LHS.MI->getOperand(i: 3).getReg())->getOpcode() == |
5861 | TargetOpcode::G_FMUL) && |
5862 | MRI.hasOneNonDBGUse(RegNo: LHS.MI->getOperand(i: 0).getReg()) && |
5863 | MRI.hasOneNonDBGUse(RegNo: LHS.MI->getOperand(i: 3).getReg())) { |
5864 | FMA = LHS.MI; |
5865 | Z = RHS.Reg; |
5866 | } |
5867 | // fold (fadd z, (fma x, y, (fmul u, v))) -> (fma x, y, (fma u, v, z)) |
5868 | else if (RHS.MI->getOpcode() == PreferredFusedOpcode && |
5869 | (MRI.getVRegDef(Reg: RHS.MI->getOperand(i: 3).getReg())->getOpcode() == |
5870 | TargetOpcode::G_FMUL) && |
5871 | MRI.hasOneNonDBGUse(RegNo: RHS.MI->getOperand(i: 0).getReg()) && |
5872 | MRI.hasOneNonDBGUse(RegNo: RHS.MI->getOperand(i: 3).getReg())) { |
5873 | Z = LHS.Reg; |
5874 | FMA = RHS.MI; |
5875 | } |
5876 | |
5877 | if (FMA) { |
5878 | MachineInstr *FMulMI = MRI.getVRegDef(Reg: FMA->getOperand(i: 3).getReg()); |
5879 | Register X = FMA->getOperand(i: 1).getReg(); |
5880 | Register Y = FMA->getOperand(i: 2).getReg(); |
5881 | Register U = FMulMI->getOperand(i: 1).getReg(); |
5882 | Register V = FMulMI->getOperand(i: 2).getReg(); |
5883 | |
5884 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5885 | Register InnerFMA = MRI.createGenericVirtualRegister(Ty: DstTy); |
5886 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {InnerFMA}, SrcOps: {U, V, Z}); |
5887 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
5888 | SrcOps: {X, Y, InnerFMA}); |
5889 | }; |
5890 | return true; |
5891 | } |
5892 | |
5893 | return false; |
5894 | } |
5895 | |
5896 | bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMAAggressive( |
5897 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
5898 | assert(MI.getOpcode() == TargetOpcode::G_FADD); |
5899 | |
5900 | bool AllowFusionGlobally, HasFMAD, Aggressive; |
5901 | if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) |
5902 | return false; |
5903 | |
5904 | if (!Aggressive) |
5905 | return false; |
5906 | |
5907 | const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering(); |
5908 | LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
5909 | Register Op1 = MI.getOperand(i: 1).getReg(); |
5910 | Register Op2 = MI.getOperand(i: 2).getReg(); |
5911 | DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1}; |
5912 | DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2}; |
5913 | |
5914 | unsigned PreferredFusedOpcode = |
5915 | HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; |
5916 | |
5917 | // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), |
5918 | // prefer to fold the multiply with fewer uses. |
5919 | if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) && |
5920 | isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) { |
5921 | if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI)) |
5922 | std::swap(a&: LHS, b&: RHS); |
5923 | } |
5924 | |
5925 | // Builds: (fma x, y, (fma (fpext u), (fpext v), z)) |
5926 | auto buildMatchInfo = [=, &MI](Register U, Register V, Register Z, Register X, |
5927 | Register Y, MachineIRBuilder &B) { |
5928 | Register FpExtU = B.buildFPExt(Res: DstType, Op: U).getReg(Idx: 0); |
5929 | Register FpExtV = B.buildFPExt(Res: DstType, Op: V).getReg(Idx: 0); |
5930 | Register InnerFMA = |
5931 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {DstType}, SrcOps: {FpExtU, FpExtV, Z}) |
5932 | .getReg(Idx: 0); |
5933 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
5934 | SrcOps: {X, Y, InnerFMA}); |
5935 | }; |
5936 | |
5937 | MachineInstr *FMulMI, *FMAMI; |
5938 | // fold (fadd (fma x, y, (fpext (fmul u, v))), z) |
5939 | // -> (fma x, y, (fma (fpext u), (fpext v), z)) |
5940 | if (LHS.MI->getOpcode() == PreferredFusedOpcode && |
5941 | mi_match(R: LHS.MI->getOperand(i: 3).getReg(), MRI, |
5942 | P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) && |
5943 | isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) && |
5944 | TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType, |
5945 | SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) { |
5946 | MatchInfo = [=](MachineIRBuilder &B) { |
5947 | buildMatchInfo(FMulMI->getOperand(i: 1).getReg(), |
5948 | FMulMI->getOperand(i: 2).getReg(), RHS.Reg, |
5949 | LHS.MI->getOperand(i: 1).getReg(), |
5950 | LHS.MI->getOperand(i: 2).getReg(), B); |
5951 | }; |
5952 | return true; |
5953 | } |
5954 | |
5955 | // fold (fadd (fpext (fma x, y, (fmul u, v))), z) |
5956 | // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z)) |
5957 | // FIXME: This turns two single-precision and one double-precision |
5958 | // operation into two double-precision operations, which might not be |
5959 | // interesting for all targets, especially GPUs. |
5960 | if (mi_match(R: LHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMAMI))) && |
5961 | FMAMI->getOpcode() == PreferredFusedOpcode) { |
5962 | MachineInstr *FMulMI = MRI.getVRegDef(Reg: FMAMI->getOperand(i: 3).getReg()); |
5963 | if (isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) && |
5964 | TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType, |
5965 | SrcTy: MRI.getType(Reg: FMAMI->getOperand(i: 0).getReg()))) { |
5966 | MatchInfo = [=](MachineIRBuilder &B) { |
5967 | Register X = FMAMI->getOperand(i: 1).getReg(); |
5968 | Register Y = FMAMI->getOperand(i: 2).getReg(); |
5969 | X = B.buildFPExt(Res: DstType, Op: X).getReg(Idx: 0); |
5970 | Y = B.buildFPExt(Res: DstType, Op: Y).getReg(Idx: 0); |
5971 | buildMatchInfo(FMulMI->getOperand(i: 1).getReg(), |
5972 | FMulMI->getOperand(i: 2).getReg(), RHS.Reg, X, Y, B); |
5973 | }; |
5974 | |
5975 | return true; |
5976 | } |
5977 | } |
5978 | |
// fold (fadd z, (fma x, y, (fpext (fmul u, v))))
5980 | // -> (fma x, y, (fma (fpext u), (fpext v), z)) |
5981 | if (RHS.MI->getOpcode() == PreferredFusedOpcode && |
5982 | mi_match(R: RHS.MI->getOperand(i: 3).getReg(), MRI, |
5983 | P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) && |
5984 | isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) && |
5985 | TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType, |
5986 | SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) { |
5987 | MatchInfo = [=](MachineIRBuilder &B) { |
5988 | buildMatchInfo(FMulMI->getOperand(i: 1).getReg(), |
5989 | FMulMI->getOperand(i: 2).getReg(), LHS.Reg, |
5990 | RHS.MI->getOperand(i: 1).getReg(), |
5991 | RHS.MI->getOperand(i: 2).getReg(), B); |
5992 | }; |
5993 | return true; |
5994 | } |
5995 | |
// fold (fadd z, (fpext (fma x, y, (fmul u, v))))
5997 | // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z)) |
5998 | // FIXME: This turns two single-precision and one double-precision |
5999 | // operation into two double-precision operations, which might not be |
6000 | // interesting for all targets, especially GPUs. |
6001 | if (mi_match(R: RHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMAMI))) && |
6002 | FMAMI->getOpcode() == PreferredFusedOpcode) { |
6003 | MachineInstr *FMulMI = MRI.getVRegDef(Reg: FMAMI->getOperand(i: 3).getReg()); |
6004 | if (isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) && |
6005 | TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType, |
6006 | SrcTy: MRI.getType(Reg: FMAMI->getOperand(i: 0).getReg()))) { |
6007 | MatchInfo = [=](MachineIRBuilder &B) { |
6008 | Register X = FMAMI->getOperand(i: 1).getReg(); |
6009 | Register Y = FMAMI->getOperand(i: 2).getReg(); |
6010 | X = B.buildFPExt(Res: DstType, Op: X).getReg(Idx: 0); |
6011 | Y = B.buildFPExt(Res: DstType, Op: Y).getReg(Idx: 0); |
6012 | buildMatchInfo(FMulMI->getOperand(i: 1).getReg(), |
6013 | FMulMI->getOperand(i: 2).getReg(), LHS.Reg, X, Y, B); |
6014 | }; |
6015 | return true; |
6016 | } |
6017 | } |
6018 | |
6019 | return false; |
6020 | } |
6021 | |
6022 | bool CombinerHelper::matchCombineFSubFMulToFMadOrFMA( |
6023 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
6024 | assert(MI.getOpcode() == TargetOpcode::G_FSUB); |
6025 | |
6026 | bool AllowFusionGlobally, HasFMAD, Aggressive; |
6027 | if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) |
6028 | return false; |
6029 | |
6030 | Register Op1 = MI.getOperand(i: 1).getReg(); |
6031 | Register Op2 = MI.getOperand(i: 2).getReg(); |
6032 | DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1}; |
6033 | DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2}; |
6034 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
6035 | |
// If we have two choices trying to fold (fsub (fmul u, v), (fmul x, y)),
// prefer to fold the multiply with fewer uses.
bool FirstMulHasFewerUses = true;
6039 | if (isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) && |
6040 | isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally) && |
6041 | hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI)) |
6042 | FirstMulHasFewerUses = false; |
6043 | |
6044 | unsigned PreferredFusedOpcode = |
6045 | HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; |
6046 | |
6047 | // fold (fsub (fmul x, y), z) -> (fma x, y, -z) |
6048 | if (FirstMulHasFewerUses && |
6049 | (isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) && |
6050 | (Aggressive || MRI.hasOneNonDBGUse(RegNo: LHS.Reg)))) { |
6051 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
6052 | Register NegZ = B.buildFNeg(Dst: DstTy, Src0: RHS.Reg).getReg(Idx: 0); |
6053 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
6054 | SrcOps: {LHS.MI->getOperand(i: 1).getReg(), |
6055 | LHS.MI->getOperand(i: 2).getReg(), NegZ}); |
6056 | }; |
6057 | return true; |
6058 | } |
6059 | // fold (fsub x, (fmul y, z)) -> (fma -y, z, x) |
6060 | else if ((isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally) && |
6061 | (Aggressive || MRI.hasOneNonDBGUse(RegNo: RHS.Reg)))) { |
6062 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
6063 | Register NegY = |
6064 | B.buildFNeg(Dst: DstTy, Src0: RHS.MI->getOperand(i: 1).getReg()).getReg(Idx: 0); |
6065 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
6066 | SrcOps: {NegY, RHS.MI->getOperand(i: 2).getReg(), LHS.Reg}); |
6067 | }; |
6068 | return true; |
6069 | } |
6070 | |
6071 | return false; |
6072 | } |
6073 | |
6074 | bool CombinerHelper::matchCombineFSubFNegFMulToFMadOrFMA( |
6075 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
6076 | assert(MI.getOpcode() == TargetOpcode::G_FSUB); |
6077 | |
6078 | bool AllowFusionGlobally, HasFMAD, Aggressive; |
6079 | if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) |
6080 | return false; |
6081 | |
6082 | Register LHSReg = MI.getOperand(i: 1).getReg(); |
6083 | Register RHSReg = MI.getOperand(i: 2).getReg(); |
6084 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
6085 | |
6086 | unsigned PreferredFusedOpcode = |
6087 | HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; |
6088 | |
6089 | MachineInstr *FMulMI; |
6090 | // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z)) |
6091 | if (mi_match(R: LHSReg, MRI, P: m_GFNeg(Src: m_MInstr(MI&: FMulMI))) && |
6092 | (Aggressive || (MRI.hasOneNonDBGUse(RegNo: LHSReg) && |
6093 | MRI.hasOneNonDBGUse(RegNo: FMulMI->getOperand(i: 0).getReg()))) && |
6094 | isContractableFMul(MI&: *FMulMI, AllowFusionGlobally)) { |
6095 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
6096 | Register NegX = |
6097 | B.buildFNeg(Dst: DstTy, Src0: FMulMI->getOperand(i: 1).getReg()).getReg(Idx: 0); |
6098 | Register NegZ = B.buildFNeg(Dst: DstTy, Src0: RHSReg).getReg(Idx: 0); |
6099 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
6100 | SrcOps: {NegX, FMulMI->getOperand(i: 2).getReg(), NegZ}); |
6101 | }; |
6102 | return true; |
6103 | } |
6104 | |
// fold (fsub x, (fneg (fmul y, z))) -> (fma y, z, x)
6106 | if (mi_match(R: RHSReg, MRI, P: m_GFNeg(Src: m_MInstr(MI&: FMulMI))) && |
6107 | (Aggressive || (MRI.hasOneNonDBGUse(RegNo: RHSReg) && |
6108 | MRI.hasOneNonDBGUse(RegNo: FMulMI->getOperand(i: 0).getReg()))) && |
6109 | isContractableFMul(MI&: *FMulMI, AllowFusionGlobally)) { |
6110 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
6111 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
6112 | SrcOps: {FMulMI->getOperand(i: 1).getReg(), |
6113 | FMulMI->getOperand(i: 2).getReg(), LHSReg}); |
6114 | }; |
6115 | return true; |
6116 | } |
6117 | |
6118 | return false; |
6119 | } |
6120 | |
6121 | bool CombinerHelper::matchCombineFSubFpExtFMulToFMadOrFMA( |
6122 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
6123 | assert(MI.getOpcode() == TargetOpcode::G_FSUB); |
6124 | |
6125 | bool AllowFusionGlobally, HasFMAD, Aggressive; |
6126 | if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) |
6127 | return false; |
6128 | |
6129 | Register LHSReg = MI.getOperand(i: 1).getReg(); |
6130 | Register RHSReg = MI.getOperand(i: 2).getReg(); |
6131 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
6132 | |
6133 | unsigned PreferredFusedOpcode = |
6134 | HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; |
6135 | |
6136 | MachineInstr *FMulMI; |
6137 | // fold (fsub (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), (fneg z)) |
6138 | if (mi_match(R: LHSReg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) && |
6139 | isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) && |
6140 | (Aggressive || MRI.hasOneNonDBGUse(RegNo: LHSReg))) { |
6141 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
6142 | Register FpExtX = |
6143 | B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 1).getReg()).getReg(Idx: 0); |
6144 | Register FpExtY = |
6145 | B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 2).getReg()).getReg(Idx: 0); |
6146 | Register NegZ = B.buildFNeg(Dst: DstTy, Src0: RHSReg).getReg(Idx: 0); |
6147 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
6148 | SrcOps: {FpExtX, FpExtY, NegZ}); |
6149 | }; |
6150 | return true; |
6151 | } |
6152 | |
6153 | // fold (fsub x, (fpext (fmul y, z))) -> (fma (fneg (fpext y)), (fpext z), x) |
6154 | if (mi_match(R: RHSReg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) && |
6155 | isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) && |
6156 | (Aggressive || MRI.hasOneNonDBGUse(RegNo: RHSReg))) { |
6157 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
6158 | Register FpExtY = |
6159 | B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 1).getReg()).getReg(Idx: 0); |
6160 | Register NegY = B.buildFNeg(Dst: DstTy, Src0: FpExtY).getReg(Idx: 0); |
6161 | Register FpExtZ = |
6162 | B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 2).getReg()).getReg(Idx: 0); |
6163 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
6164 | SrcOps: {NegY, FpExtZ, LHSReg}); |
6165 | }; |
6166 | return true; |
6167 | } |
6168 | |
6169 | return false; |
6170 | } |
6171 | |
6172 | bool CombinerHelper::matchCombineFSubFpExtFNegFMulToFMadOrFMA( |
6173 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
6174 | assert(MI.getOpcode() == TargetOpcode::G_FSUB); |
6175 | |
6176 | bool AllowFusionGlobally, HasFMAD, Aggressive; |
6177 | if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) |
6178 | return false; |
6179 | |
6180 | const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering(); |
6181 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
6182 | Register LHSReg = MI.getOperand(i: 1).getReg(); |
6183 | Register RHSReg = MI.getOperand(i: 2).getReg(); |
6184 | |
6185 | unsigned PreferredFusedOpcode = |
6186 | HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; |
6187 | |
6188 | auto buildMatchInfo = [=](Register Dst, Register X, Register Y, Register Z, |
6189 | MachineIRBuilder &B) { |
6190 | Register FpExtX = B.buildFPExt(Res: DstTy, Op: X).getReg(Idx: 0); |
6191 | Register FpExtY = B.buildFPExt(Res: DstTy, Op: Y).getReg(Idx: 0); |
6192 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {Dst}, SrcOps: {FpExtX, FpExtY, Z}); |
6193 | }; |
6194 | |
6195 | MachineInstr *FMulMI; |
6196 | // fold (fsub (fpext (fneg (fmul x, y))), z) -> |
6197 | // (fneg (fma (fpext x), (fpext y), z)) |
6198 | // fold (fsub (fneg (fpext (fmul x, y))), z) -> |
6199 | // (fneg (fma (fpext x), (fpext y), z)) |
6200 | if ((mi_match(R: LHSReg, MRI, P: m_GFPExt(Src: m_GFNeg(Src: m_MInstr(MI&: FMulMI)))) || |
6201 | mi_match(R: LHSReg, MRI, P: m_GFNeg(Src: m_GFPExt(Src: m_MInstr(MI&: FMulMI))))) && |
6202 | isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) && |
6203 | TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstTy, |
6204 | SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) { |
6205 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
6206 | Register FMAReg = MRI.createGenericVirtualRegister(Ty: DstTy); |
6207 | buildMatchInfo(FMAReg, FMulMI->getOperand(i: 1).getReg(), |
6208 | FMulMI->getOperand(i: 2).getReg(), RHSReg, B); |
6209 | B.buildFNeg(Dst: MI.getOperand(i: 0).getReg(), Src0: FMAReg); |
6210 | }; |
6211 | return true; |
6212 | } |
6213 | |
6214 | // fold (fsub x, (fpext (fneg (fmul y, z)))) -> (fma (fpext y), (fpext z), x) |
6215 | // fold (fsub x, (fneg (fpext (fmul y, z)))) -> (fma (fpext y), (fpext z), x) |
6216 | if ((mi_match(R: RHSReg, MRI, P: m_GFPExt(Src: m_GFNeg(Src: m_MInstr(MI&: FMulMI)))) || |
6217 | mi_match(R: RHSReg, MRI, P: m_GFNeg(Src: m_GFPExt(Src: m_MInstr(MI&: FMulMI))))) && |
6218 | isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) && |
6219 | TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstTy, |
6220 | SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) { |
6221 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
6222 | buildMatchInfo(MI.getOperand(i: 0).getReg(), FMulMI->getOperand(i: 1).getReg(), |
6223 | FMulMI->getOperand(i: 2).getReg(), LHSReg, B); |
6224 | }; |
6225 | return true; |
6226 | } |
6227 | |
6228 | return false; |
6229 | } |
6230 | |
6231 | bool CombinerHelper::matchCombineFMinMaxNaN(MachineInstr &MI, |
6232 | unsigned &IdxToPropagate) { |
6233 | bool PropagateNaN; |
6234 | switch (MI.getOpcode()) { |
6235 | default: |
6236 | return false; |
6237 | case TargetOpcode::G_FMINNUM: |
6238 | case TargetOpcode::G_FMAXNUM: |
6239 | PropagateNaN = false; |
6240 | break; |
6241 | case TargetOpcode::G_FMINIMUM: |
6242 | case TargetOpcode::G_FMAXIMUM: |
6243 | PropagateNaN = true; |
6244 | break; |
6245 | } |
6246 | |
6247 | auto MatchNaN = [&](unsigned Idx) { |
6248 | Register MaybeNaNReg = MI.getOperand(i: Idx).getReg(); |
6249 | const ConstantFP *MaybeCst = getConstantFPVRegVal(VReg: MaybeNaNReg, MRI); |
6250 | if (!MaybeCst || !MaybeCst->getValueAPF().isNaN()) |
6251 | return false; |
6252 | IdxToPropagate = PropagateNaN ? Idx : (Idx == 1 ? 2 : 1); |
6253 | return true; |
6254 | }; |
6255 | |
6256 | return MatchNaN(1) || MatchNaN(2); |
6257 | } |
6258 | |
6259 | bool CombinerHelper::matchAddSubSameReg(MachineInstr &MI, Register &Src) { |
6260 | assert(MI.getOpcode() == TargetOpcode::G_ADD && "Expected a G_ADD" ); |
6261 | Register LHS = MI.getOperand(i: 1).getReg(); |
6262 | Register RHS = MI.getOperand(i: 2).getReg(); |
6263 | |
6264 | // Helper lambda to check for opportunities for |
6265 | // A + (B - A) -> B |
6266 | // (B - A) + A -> B |
6267 | auto CheckFold = [&](Register MaybeSub, Register MaybeSameReg) { |
6268 | Register Reg; |
6269 | return mi_match(R: MaybeSub, MRI, P: m_GSub(L: m_Reg(R&: Src), R: m_Reg(R&: Reg))) && |
6270 | Reg == MaybeSameReg; |
6271 | }; |
6272 | return CheckFold(LHS, RHS) || CheckFold(RHS, LHS); |
6273 | } |
6274 | |
6275 | bool CombinerHelper::matchBuildVectorIdentityFold(MachineInstr &MI, |
6276 | Register &MatchInfo) { |
6277 | // This combine folds the following patterns: |
6278 | // |
6279 | // G_BUILD_VECTOR_TRUNC (G_BITCAST(x), G_LSHR(G_BITCAST(x), k)) |
6280 | // G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), G_TRUNC(G_LSHR(G_BITCAST(x), k))) |
6281 | // into |
6282 | // x |
6283 | // if |
6284 | // k == sizeof(VecEltTy)/2 |
6285 | // type(x) == type(dst) |
6286 | // |
6287 | // G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), undef) |
6288 | // into |
6289 | // x |
6290 | // if |
6291 | // type(x) == type(dst) |
6292 | |
6293 | LLT DstVecTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
6294 | LLT DstEltTy = DstVecTy.getElementType(); |
6295 | |
6296 | Register Lo, Hi; |
6297 | |
6298 | if (mi_match( |
6299 | MI, MRI, |
6300 | P: m_GBuildVector(L: m_GTrunc(Src: m_GBitcast(Src: m_Reg(R&: Lo))), R: m_GImplicitDef()))) { |
6301 | MatchInfo = Lo; |
6302 | return MRI.getType(Reg: MatchInfo) == DstVecTy; |
6303 | } |
6304 | |
6305 | std::optional<ValueAndVReg> ShiftAmount; |
6306 | const auto LoPattern = m_GBitcast(Src: m_Reg(R&: Lo)); |
6307 | const auto HiPattern = m_GLShr(L: m_GBitcast(Src: m_Reg(R&: Hi)), R: m_GCst(ValReg&: ShiftAmount)); |
6308 | if (mi_match( |
6309 | MI, MRI, |
6310 | P: m_any_of(preds: m_GBuildVectorTrunc(L: LoPattern, R: HiPattern), |
6311 | preds: m_GBuildVector(L: m_GTrunc(Src: LoPattern), R: m_GTrunc(Src: HiPattern))))) { |
6312 | if (Lo == Hi && ShiftAmount->Value == DstEltTy.getSizeInBits()) { |
6313 | MatchInfo = Lo; |
6314 | return MRI.getType(Reg: MatchInfo) == DstVecTy; |
6315 | } |
6316 | } |
6317 | |
6318 | return false; |
6319 | } |
6320 | |
6321 | bool CombinerHelper::matchTruncBuildVectorFold(MachineInstr &MI, |
6322 | Register &MatchInfo) { |
// Replace (G_TRUNC (G_BITCAST (G_BUILD_VECTOR x, y))) with just x
6324 | // if type(x) == type(G_TRUNC) |
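// E.g. (G_TRUNC s32 (G_BITCAST s64 (G_BUILD_VECTOR x:s32, y:s32))) -> x,
// assuming the usual lane-0-in-the-low-bits layout of the bitcast value.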
6325 | if (!mi_match(R: MI.getOperand(i: 1).getReg(), MRI, |
6326 | P: m_GBitcast(Src: m_GBuildVector(L: m_Reg(R&: MatchInfo), R: m_Reg())))) |
6327 | return false; |
6328 | |
6329 | return MRI.getType(Reg: MatchInfo) == MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
6330 | } |
6331 | |
6332 | bool CombinerHelper::matchTruncLshrBuildVectorFold(MachineInstr &MI, |
6333 | Register &MatchInfo) { |
6334 | // Replace (G_TRUNC (G_LSHR (G_BITCAST (G_BUILD_VECTOR x, y)), K)) with |
6335 | // y if K == size of vector element type |
6336 | std::optional<ValueAndVReg> ShiftAmt; |
6337 | if (!mi_match(R: MI.getOperand(i: 1).getReg(), MRI, |
6338 | P: m_GLShr(L: m_GBitcast(Src: m_GBuildVector(L: m_Reg(), R: m_Reg(R&: MatchInfo))), |
6339 | R: m_GCst(ValReg&: ShiftAmt)))) |
6340 | return false; |
6341 | |
6342 | LLT MatchTy = MRI.getType(Reg: MatchInfo); |
6343 | return ShiftAmt->Value.getZExtValue() == MatchTy.getSizeInBits() && |
6344 | MatchTy == MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
6345 | } |
6346 | |
6347 | unsigned CombinerHelper::getFPMinMaxOpcForSelect( |
6348 | CmpInst::Predicate Pred, LLT DstTy, |
6349 | SelectPatternNaNBehaviour VsNaNRetVal) const { |
6350 | assert(VsNaNRetVal != SelectPatternNaNBehaviour::NOT_APPLICABLE && |
6351 | "Expected a NaN behaviour?" ); |
6352 | // Choose an opcode based off of legality or the behaviour when one of the |
6353 | // LHS/RHS may be NaN. |
6354 | switch (Pred) { |
6355 | default: |
6356 | return 0; |
6357 | case CmpInst::FCMP_UGT: |
6358 | case CmpInst::FCMP_UGE: |
6359 | case CmpInst::FCMP_OGT: |
6360 | case CmpInst::FCMP_OGE: |
6361 | if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER) |
6362 | return TargetOpcode::G_FMAXNUM; |
6363 | if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN) |
6364 | return TargetOpcode::G_FMAXIMUM; |
6365 | if (isLegal(Query: {TargetOpcode::G_FMAXNUM, {DstTy}})) |
6366 | return TargetOpcode::G_FMAXNUM; |
6367 | if (isLegal(Query: {TargetOpcode::G_FMAXIMUM, {DstTy}})) |
6368 | return TargetOpcode::G_FMAXIMUM; |
6369 | return 0; |
6370 | case CmpInst::FCMP_ULT: |
6371 | case CmpInst::FCMP_ULE: |
6372 | case CmpInst::FCMP_OLT: |
6373 | case CmpInst::FCMP_OLE: |
6374 | if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER) |
6375 | return TargetOpcode::G_FMINNUM; |
6376 | if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN) |
6377 | return TargetOpcode::G_FMINIMUM; |
6378 | if (isLegal(Query: {TargetOpcode::G_FMINNUM, {DstTy}})) |
6379 | return TargetOpcode::G_FMINNUM; |
if (isLegal(Query: {TargetOpcode::G_FMINIMUM, {DstTy}}))
return TargetOpcode::G_FMINIMUM;
return 0;
6383 | } |
6384 | } |
6385 | |
6386 | CombinerHelper::SelectPatternNaNBehaviour |
6387 | CombinerHelper::computeRetValAgainstNaN(Register LHS, Register RHS, |
6388 | bool IsOrderedComparison) const { |
6389 | bool LHSSafe = isKnownNeverNaN(Val: LHS, MRI); |
6390 | bool RHSSafe = isKnownNeverNaN(Val: RHS, MRI); |
6391 | // Completely unsafe. |
6392 | if (!LHSSafe && !RHSSafe) |
6393 | return SelectPatternNaNBehaviour::NOT_APPLICABLE; |
6394 | if (LHSSafe && RHSSafe) |
6395 | return SelectPatternNaNBehaviour::RETURNS_ANY; |
6396 | // An ordered comparison will return false when given a NaN, so it |
6397 | // returns the RHS. |
6398 | if (IsOrderedComparison) |
6399 | return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_NAN |
6400 | : SelectPatternNaNBehaviour::RETURNS_OTHER; |
6401 | // An unordered comparison will return true when given a NaN, so it |
6402 | // returns the LHS. |
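// For instance, in select (fcmp ult x, y), x, y where only x may be NaN, the
// unordered compare is true for a NaN x, the select yields x, and we report
// RETURNS_NAN.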
6403 | return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_OTHER |
6404 | : SelectPatternNaNBehaviour::RETURNS_NAN; |
6405 | } |
6406 | |
6407 | bool CombinerHelper::matchFPSelectToMinMax(Register Dst, Register Cond, |
6408 | Register TrueVal, Register FalseVal, |
6409 | BuildFnTy &MatchInfo) { |
6410 | // Match: select (fcmp cond x, y) x, y |
6411 | // select (fcmp cond x, y) y, x |
// And turn it into fminnum/fmaxnum or fminimum/fmaximum based on the condition.
6413 | LLT DstTy = MRI.getType(Reg: Dst); |
6414 | // Bail out early on pointers, since we'll never want to fold to a min/max. |
6415 | if (DstTy.isPointer()) |
6416 | return false; |
6417 | // Match a floating point compare with a less-than/greater-than predicate. |
6418 | // TODO: Allow multiple users of the compare if they are all selects. |
6419 | CmpInst::Predicate Pred; |
6420 | Register CmpLHS, CmpRHS; |
6421 | if (!mi_match(R: Cond, MRI, |
6422 | P: m_OneNonDBGUse( |
6423 | SP: m_GFCmp(P: m_Pred(P&: Pred), L: m_Reg(R&: CmpLHS), R: m_Reg(R&: CmpRHS)))) || |
6424 | CmpInst::isEquality(pred: Pred)) |
6425 | return false; |
6426 | SelectPatternNaNBehaviour ResWithKnownNaNInfo = |
6427 | computeRetValAgainstNaN(LHS: CmpLHS, RHS: CmpRHS, IsOrderedComparison: CmpInst::isOrdered(predicate: Pred)); |
6428 | if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::NOT_APPLICABLE) |
6429 | return false; |
6430 | if (TrueVal == CmpRHS && FalseVal == CmpLHS) { |
6431 | std::swap(a&: CmpLHS, b&: CmpRHS); |
6432 | Pred = CmpInst::getSwappedPredicate(pred: Pred); |
6433 | if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_NAN) |
6434 | ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_OTHER; |
6435 | else if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_OTHER) |
6436 | ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_NAN; |
6437 | } |
6438 | if (TrueVal != CmpLHS || FalseVal != CmpRHS) |
6439 | return false; |
6440 | // Decide what type of max/min this should be based off of the predicate. |
6441 | unsigned Opc = getFPMinMaxOpcForSelect(Pred, DstTy, VsNaNRetVal: ResWithKnownNaNInfo); |
6442 | if (!Opc || !isLegal(Query: {Opc, {DstTy}})) |
6443 | return false; |
6444 | // Comparisons between signed zero and zero may have different results... |
6445 | // unless we have fmaximum/fminimum. In that case, we know -0 < 0. |
6446 | if (Opc != TargetOpcode::G_FMAXIMUM && Opc != TargetOpcode::G_FMINIMUM) { |
6447 | // We don't know if a comparison between two 0s will give us a consistent |
6448 | // result. Be conservative and only proceed if at least one side is |
6449 | // non-zero. |
6450 | auto KnownNonZeroSide = getFConstantVRegValWithLookThrough(VReg: CmpLHS, MRI); |
6451 | if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero()) { |
6452 | KnownNonZeroSide = getFConstantVRegValWithLookThrough(VReg: CmpRHS, MRI); |
6453 | if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero()) |
6454 | return false; |
6455 | } |
6456 | } |
6457 | MatchInfo = [=](MachineIRBuilder &B) { |
6458 | B.buildInstr(Opc, DstOps: {Dst}, SrcOps: {CmpLHS, CmpRHS}); |
6459 | }; |
6460 | return true; |
6461 | } |
6462 | |
6463 | bool CombinerHelper::matchSimplifySelectToMinMax(MachineInstr &MI, |
6464 | BuildFnTy &MatchInfo) { |
6465 | // TODO: Handle integer cases. |
6466 | assert(MI.getOpcode() == TargetOpcode::G_SELECT); |
6467 | // Condition may be fed by a truncated compare. |
6468 | Register Cond = MI.getOperand(i: 1).getReg(); |
6469 | Register MaybeTrunc; |
6470 | if (mi_match(R: Cond, MRI, P: m_OneNonDBGUse(SP: m_GTrunc(Src: m_Reg(R&: MaybeTrunc))))) |
6471 | Cond = MaybeTrunc; |
6472 | Register Dst = MI.getOperand(i: 0).getReg(); |
6473 | Register TrueVal = MI.getOperand(i: 2).getReg(); |
6474 | Register FalseVal = MI.getOperand(i: 3).getReg(); |
6475 | return matchFPSelectToMinMax(Dst, Cond, TrueVal, FalseVal, MatchInfo); |
6476 | } |
6477 | |
6478 | bool CombinerHelper::matchRedundantBinOpInEquality(MachineInstr &MI, |
6479 | BuildFnTy &MatchInfo) { |
6480 | assert(MI.getOpcode() == TargetOpcode::G_ICMP); |
6481 | // (X + Y) == X --> Y == 0 |
6482 | // (X + Y) != X --> Y != 0 |
6483 | // (X - Y) == X --> Y == 0 |
6484 | // (X - Y) != X --> Y != 0 |
6485 | // (X ^ Y) == X --> Y == 0 |
6486 | // (X ^ Y) != X --> Y != 0 |
6487 | Register Dst = MI.getOperand(i: 0).getReg(); |
6488 | CmpInst::Predicate Pred; |
6489 | Register X, Y, OpLHS, OpRHS; |
6490 | bool MatchedSub = mi_match( |
6491 | R: Dst, MRI, |
6492 | P: m_c_GICmp(P: m_Pred(P&: Pred), L: m_Reg(R&: X), R: m_GSub(L: m_Reg(R&: OpLHS), R: m_Reg(R&: Y)))); |
6493 | if (MatchedSub && X != OpLHS) |
6494 | return false; |
6495 | if (!MatchedSub) { |
6496 | if (!mi_match(R: Dst, MRI, |
6497 | P: m_c_GICmp(P: m_Pred(P&: Pred), L: m_Reg(R&: X), |
6498 | R: m_any_of(preds: m_GAdd(L: m_Reg(R&: OpLHS), R: m_Reg(R&: OpRHS)), |
6499 | preds: m_GXor(L: m_Reg(R&: OpLHS), R: m_Reg(R&: OpRHS)))))) |
6500 | return false; |
6501 | Y = X == OpLHS ? OpRHS : X == OpRHS ? OpLHS : Register(); |
6502 | } |
6503 | MatchInfo = [=](MachineIRBuilder &B) { |
6504 | auto Zero = B.buildConstant(Res: MRI.getType(Reg: Y), Val: 0); |
6505 | B.buildICmp(Pred, Res: Dst, Op0: Y, Op1: Zero); |
6506 | }; |
6507 | return CmpInst::isEquality(pred: Pred) && Y.isValid(); |
6508 | } |
6509 | |
6510 | bool CombinerHelper::matchShiftsTooBig(MachineInstr &MI) { |
6511 | Register ShiftReg = MI.getOperand(i: 2).getReg(); |
6512 | LLT ResTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
6513 | auto IsShiftTooBig = [&](const Constant *C) { |
6514 | auto *CI = dyn_cast<ConstantInt>(Val: C); |
6515 | return CI && CI->uge(Num: ResTy.getScalarSizeInBits()); |
6516 | }; |
6517 | return matchUnaryPredicate(MRI, Reg: ShiftReg, Match: IsShiftTooBig); |
6518 | } |
6519 | |
6520 | bool CombinerHelper::matchCommuteConstantToRHS(MachineInstr &MI) { |
6521 | unsigned LHSOpndIdx = 1; |
6522 | unsigned RHSOpndIdx = 2; |
6523 | switch (MI.getOpcode()) { |
6524 | case TargetOpcode::G_UADDO: |
6525 | case TargetOpcode::G_SADDO: |
6526 | case TargetOpcode::G_UMULO: |
6527 | case TargetOpcode::G_SMULO: |
6528 | LHSOpndIdx = 2; |
6529 | RHSOpndIdx = 3; |
6530 | break; |
6531 | default: |
6532 | break; |
6533 | } |
6534 | Register LHS = MI.getOperand(i: LHSOpndIdx).getReg(); |
6535 | Register RHS = MI.getOperand(i: RHSOpndIdx).getReg(); |
6536 | if (!getIConstantVRegVal(VReg: LHS, MRI)) { |
6537 | // Skip commuting if LHS is not a constant. But, LHS may be a |
6538 | // G_CONSTANT_FOLD_BARRIER. If so we commute as long as we don't already |
6539 | // have a constant on the RHS. |
6540 | if (MRI.getVRegDef(Reg: LHS)->getOpcode() != |
6541 | TargetOpcode::G_CONSTANT_FOLD_BARRIER) |
6542 | return false; |
6543 | } |
6544 | // Commute as long as RHS is not a constant or G_CONSTANT_FOLD_BARRIER. |
6545 | return MRI.getVRegDef(Reg: RHS)->getOpcode() != |
6546 | TargetOpcode::G_CONSTANT_FOLD_BARRIER && |
6547 | !getIConstantVRegVal(VReg: RHS, MRI); |
6548 | } |
6549 | |
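/// Match a commutable FP binary op whose LHS is a floating-point constant or
/// splat and whose RHS is not, so the constant can be commuted to the RHS.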
6550 | bool CombinerHelper::matchCommuteFPConstantToRHS(MachineInstr &MI) { |
6551 | Register LHS = MI.getOperand(i: 1).getReg(); |
6552 | Register RHS = MI.getOperand(i: 2).getReg(); |
6553 | std::optional<FPValueAndVReg> ValAndVReg; |
6554 | if (!mi_match(R: LHS, MRI, P: m_GFCstOrSplat(FPValReg&: ValAndVReg))) |
6555 | return false; |
6556 | return !mi_match(R: RHS, MRI, P: m_GFCstOrSplat(FPValReg&: ValAndVReg)); |
6557 | } |
6558 | |
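/// Swap the two source operands of \p MI, accounting for the extra carry-out
/// def of the overflow opcodes (G_UADDO, G_SADDO, G_UMULO, G_SMULO).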
6559 | void CombinerHelper::applyCommuteBinOpOperands(MachineInstr &MI) { |
6560 | Observer.changingInstr(MI); |
6561 | unsigned LHSOpndIdx = 1; |
6562 | unsigned RHSOpndIdx = 2; |
6563 | switch (MI.getOpcode()) { |
6564 | case TargetOpcode::G_UADDO: |
6565 | case TargetOpcode::G_SADDO: |
6566 | case TargetOpcode::G_UMULO: |
6567 | case TargetOpcode::G_SMULO: |
6568 | LHSOpndIdx = 2; |
6569 | RHSOpndIdx = 3; |
6570 | break; |
6571 | default: |
6572 | break; |
6573 | } |
6574 | Register LHSReg = MI.getOperand(i: LHSOpndIdx).getReg(); |
6575 | Register RHSReg = MI.getOperand(i: RHSOpndIdx).getReg(); |
6576 | MI.getOperand(i: LHSOpndIdx).setReg(RHSReg); |
6577 | MI.getOperand(i: RHSOpndIdx).setReg(LHSReg); |
6578 | Observer.changedInstr(MI); |
6579 | } |
6580 | |
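/// Return true if \p Src is the scalar constant 1 or a fixed vector whose
/// elements are all 1. If \p AllowUndefs is set, undef values are accepted
/// as well. Scalable vectors are not handled.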
6581 | bool CombinerHelper::isOneOrOneSplat(Register Src, bool AllowUndefs) { |
6582 | LLT SrcTy = MRI.getType(Reg: Src); |
6583 | if (SrcTy.isFixedVector()) |
6584 | return isConstantSplatVector(Src, SplatValue: 1, AllowUndefs); |
6585 | if (SrcTy.isScalar()) { |
6586 | if (AllowUndefs && getOpcodeDef<GImplicitDef>(Reg: Src, MRI) != nullptr) |
6587 | return true; |
6588 | auto IConstant = getIConstantVRegValWithLookThrough(VReg: Src, MRI); |
6589 | return IConstant && IConstant->Value == 1; |
6590 | } |
6591 | return false; // scalable vector |
6592 | } |
6593 | |
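/// Return true if \p Src is the scalar constant 0 or a fixed vector whose
/// elements are all 0. If \p AllowUndefs is set, undef values are accepted
/// as well. Scalable vectors are not handled.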
6594 | bool CombinerHelper::isZeroOrZeroSplat(Register Src, bool AllowUndefs) { |
6595 | LLT SrcTy = MRI.getType(Reg: Src); |
6596 | if (SrcTy.isFixedVector()) |
6597 | return isConstantSplatVector(Src, SplatValue: 0, AllowUndefs); |
6598 | if (SrcTy.isScalar()) { |
6599 | if (AllowUndefs && getOpcodeDef<GImplicitDef>(Reg: Src, MRI) != nullptr) |
6600 | return true; |
6601 | auto IConstant = getIConstantVRegValWithLookThrough(VReg: Src, MRI); |
6602 | return IConstant && IConstant->Value == 0; |
6603 | } |
6604 | return false; // scalable vector |
6605 | } |
6606 | |
6607 | // Ignores COPYs during conformance checks. |
// FIXME: Handle scalable vectors.
6609 | bool CombinerHelper::isConstantSplatVector(Register Src, int64_t SplatValue, |
6610 | bool AllowUndefs) { |
6611 | GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Reg: Src, MRI); |
6612 | if (!BuildVector) |
6613 | return false; |
6614 | unsigned NumSources = BuildVector->getNumSources(); |
6615 | |
6616 | for (unsigned I = 0; I < NumSources; ++I) { |
6617 | GImplicitDef *ImplicitDef = |
6618 | getOpcodeDef<GImplicitDef>(Reg: BuildVector->getSourceReg(I), MRI); |
6619 | if (ImplicitDef && AllowUndefs) |
6620 | continue; |
6621 | if (ImplicitDef && !AllowUndefs) |
6622 | return false; |
6623 | std::optional<ValueAndVReg> IConstant = |
6624 | getIConstantVRegValWithLookThrough(VReg: BuildVector->getSourceReg(I), MRI); |
6625 | if (IConstant && IConstant->Value == SplatValue) |
6626 | continue; |
6627 | return false; |
6628 | } |
6629 | return true; |
6630 | } |
6631 | |
6632 | // Ignores COPYs during lookups. |
// FIXME: Handle scalable vectors.
6634 | std::optional<APInt> |
6635 | CombinerHelper::getConstantOrConstantSplatVector(Register Src) { |
6636 | auto IConstant = getIConstantVRegValWithLookThrough(VReg: Src, MRI); |
6637 | if (IConstant) |
6638 | return IConstant->Value; |
6639 | |
6640 | GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Reg: Src, MRI); |
6641 | if (!BuildVector) |
6642 | return std::nullopt; |
6643 | unsigned NumSources = BuildVector->getNumSources(); |
6644 | |
6645 | std::optional<APInt> Value = std::nullopt; |
6646 | for (unsigned I = 0; I < NumSources; ++I) { |
6647 | std::optional<ValueAndVReg> IConstant = |
6648 | getIConstantVRegValWithLookThrough(VReg: BuildVector->getSourceReg(I), MRI); |
6649 | if (!IConstant) |
6650 | return std::nullopt; |
6651 | if (!Value) |
6652 | Value = IConstant->Value; |
6653 | else if (*Value != IConstant->Value) |
6654 | return std::nullopt; |
6655 | } |
6656 | return Value; |
6657 | } |
6658 | |
// FIXME: Handle G_SPLAT_VECTOR.
6660 | bool CombinerHelper::isConstantOrConstantVectorI(Register Src) const { |
6661 | auto IConstant = getIConstantVRegValWithLookThrough(VReg: Src, MRI); |
6662 | if (IConstant) |
6663 | return true; |
6664 | |
6665 | GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Reg: Src, MRI); |
6666 | if (!BuildVector) |
6667 | return false; |
6668 | |
6669 | unsigned NumSources = BuildVector->getNumSources(); |
6670 | for (unsigned I = 0; I < NumSources; ++I) { |
6671 | std::optional<ValueAndVReg> IConstant = |
6672 | getIConstantVRegValWithLookThrough(VReg: BuildVector->getSourceReg(I), MRI); |
6673 | if (!IConstant) |
6674 | return false; |
6675 | } |
6676 | return true; |
6677 | } |
6678 | |
6679 | // TODO: use knownbits to determine zeros |
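/// Try to fold a G_SELECT with an s1 condition and two scalar constant
/// operands into cheaper arithmetic on the condition: zext/sext, not, add,
/// shl, or or, matching the cases documented inline below.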
6680 | bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select, |
6681 | BuildFnTy &MatchInfo) { |
6682 | uint32_t Flags = Select->getFlags(); |
6683 | Register Dest = Select->getReg(Idx: 0); |
6684 | Register Cond = Select->getCondReg(); |
6685 | Register True = Select->getTrueReg(); |
6686 | Register False = Select->getFalseReg(); |
6687 | LLT CondTy = MRI.getType(Reg: Select->getCondReg()); |
6688 | LLT TrueTy = MRI.getType(Reg: Select->getTrueReg()); |
6689 | |
6690 | // We only do this combine for scalar boolean conditions. |
6691 | if (CondTy != LLT::scalar(SizeInBits: 1)) |
6692 | return false; |
6693 | |
6694 | if (TrueTy.isPointer()) |
6695 | return false; |
6696 | |
6697 | // Both are scalars. |
6698 | std::optional<ValueAndVReg> TrueOpt = |
6699 | getIConstantVRegValWithLookThrough(VReg: True, MRI); |
6700 | std::optional<ValueAndVReg> FalseOpt = |
6701 | getIConstantVRegValWithLookThrough(VReg: False, MRI); |
6702 | |
6703 | if (!TrueOpt || !FalseOpt) |
6704 | return false; |
6705 | |
6706 | APInt TrueValue = TrueOpt->Value; |
6707 | APInt FalseValue = FalseOpt->Value; |
6708 | |
6709 | // select Cond, 1, 0 --> zext (Cond) |
6710 | if (TrueValue.isOne() && FalseValue.isZero()) { |
6711 | MatchInfo = [=](MachineIRBuilder &B) { |
6712 | B.setInstrAndDebugLoc(*Select); |
6713 | B.buildZExtOrTrunc(Res: Dest, Op: Cond); |
6714 | }; |
6715 | return true; |
6716 | } |
6717 | |
6718 | // select Cond, -1, 0 --> sext (Cond) |
6719 | if (TrueValue.isAllOnes() && FalseValue.isZero()) { |
6720 | MatchInfo = [=](MachineIRBuilder &B) { |
6721 | B.setInstrAndDebugLoc(*Select); |
6722 | B.buildSExtOrTrunc(Res: Dest, Op: Cond); |
6723 | }; |
6724 | return true; |
6725 | } |
6726 | |
6727 | // select Cond, 0, 1 --> zext (!Cond) |
6728 | if (TrueValue.isZero() && FalseValue.isOne()) { |
6729 | MatchInfo = [=](MachineIRBuilder &B) { |
6730 | B.setInstrAndDebugLoc(*Select); |
6731 | Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy); |
6732 | B.buildNot(Dst: Inner, Src0: Cond); |
6733 | B.buildZExtOrTrunc(Res: Dest, Op: Inner); |
6734 | }; |
6735 | return true; |
6736 | } |
6737 | |
6738 | // select Cond, 0, -1 --> sext (!Cond) |
6739 | if (TrueValue.isZero() && FalseValue.isAllOnes()) { |
6740 | MatchInfo = [=](MachineIRBuilder &B) { |
6741 | B.setInstrAndDebugLoc(*Select); |
6742 | Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy); |
6743 | B.buildNot(Dst: Inner, Src0: Cond); |
6744 | B.buildSExtOrTrunc(Res: Dest, Op: Inner); |
6745 | }; |
6746 | return true; |
6747 | } |
6748 | |
6749 | // select Cond, C1, C1-1 --> add (zext Cond), C1-1 |
6750 | if (TrueValue - 1 == FalseValue) { |
6751 | MatchInfo = [=](MachineIRBuilder &B) { |
6752 | B.setInstrAndDebugLoc(*Select); |
6753 | Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy); |
6754 | B.buildZExtOrTrunc(Res: Inner, Op: Cond); |
6755 | B.buildAdd(Dst: Dest, Src0: Inner, Src1: False); |
6756 | }; |
6757 | return true; |
6758 | } |
6759 | |
6760 | // select Cond, C1, C1+1 --> add (sext Cond), C1+1 |
6761 | if (TrueValue + 1 == FalseValue) { |
6762 | MatchInfo = [=](MachineIRBuilder &B) { |
6763 | B.setInstrAndDebugLoc(*Select); |
6764 | Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy); |
6765 | B.buildSExtOrTrunc(Res: Inner, Op: Cond); |
6766 | B.buildAdd(Dst: Dest, Src0: Inner, Src1: False); |
6767 | }; |
6768 | return true; |
6769 | } |
6770 | |
6771 | // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2) |
6772 | if (TrueValue.isPowerOf2() && FalseValue.isZero()) { |
6773 | MatchInfo = [=](MachineIRBuilder &B) { |
6774 | B.setInstrAndDebugLoc(*Select); |
6775 | Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy); |
6776 | B.buildZExtOrTrunc(Res: Inner, Op: Cond); |
6777 | // The shift amount must be scalar. |
6778 | LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy; |
6779 | auto ShAmtC = B.buildConstant(Res: ShiftTy, Val: TrueValue.exactLogBase2()); |
6780 | B.buildShl(Dst: Dest, Src0: Inner, Src1: ShAmtC, Flags); |
6781 | }; |
6782 | return true; |
6783 | } |
6784 | // select Cond, -1, C --> or (sext Cond), C |
6785 | if (TrueValue.isAllOnes()) { |
6786 | MatchInfo = [=](MachineIRBuilder &B) { |
6787 | B.setInstrAndDebugLoc(*Select); |
6788 | Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy); |
6789 | B.buildSExtOrTrunc(Res: Inner, Op: Cond); |
6790 | B.buildOr(Dst: Dest, Src0: Inner, Src1: False, Flags); |
6791 | }; |
6792 | return true; |
6793 | } |
6794 | |
6795 | // select Cond, C, -1 --> or (sext (not Cond)), C |
6796 | if (FalseValue.isAllOnes()) { |
6797 | MatchInfo = [=](MachineIRBuilder &B) { |
6798 | B.setInstrAndDebugLoc(*Select); |
6799 | Register Not = MRI.createGenericVirtualRegister(Ty: CondTy); |
6800 | B.buildNot(Dst: Not, Src0: Cond); |
6801 | Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy); |
6802 | B.buildSExtOrTrunc(Res: Inner, Op: Not); |
6803 | B.buildOr(Dst: Dest, Src0: Inner, Src1: True, Flags); |
6804 | }; |
6805 | return true; |
6806 | } |
6807 | |
6808 | return false; |
6809 | } |
6810 | |
6811 | // TODO: use knownbits to determine zeros |
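/// Try to fold a G_SELECT whose condition and operands are booleans (or
/// fixed vectors of booleans) of the same type into G_OR/G_AND logic on the
/// condition; the operand that is kept is frozen first.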
6812 | bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select, |
6813 | BuildFnTy &MatchInfo) { |
6814 | uint32_t Flags = Select->getFlags(); |
6815 | Register DstReg = Select->getReg(Idx: 0); |
6816 | Register Cond = Select->getCondReg(); |
6817 | Register True = Select->getTrueReg(); |
6818 | Register False = Select->getFalseReg(); |
6819 | LLT CondTy = MRI.getType(Reg: Select->getCondReg()); |
6820 | LLT TrueTy = MRI.getType(Reg: Select->getTrueReg()); |
6821 | |
6822 | // Boolean or fixed vector of booleans. |
6823 | if (CondTy.isScalableVector() || |
6824 | (CondTy.isFixedVector() && |
6825 | CondTy.getElementType().getScalarSizeInBits() != 1) || |
6826 | CondTy.getScalarSizeInBits() != 1) |
6827 | return false; |
6828 | |
6829 | if (CondTy != TrueTy) |
6830 | return false; |
6831 | |
6832 | // select Cond, Cond, F --> or Cond, F |
6833 | // select Cond, 1, F --> or Cond, F |
6834 | if ((Cond == True) || isOneOrOneSplat(Src: True, /* AllowUndefs */ true)) { |
6835 | MatchInfo = [=](MachineIRBuilder &B) { |
6836 | B.setInstrAndDebugLoc(*Select); |
6837 | Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy); |
6838 | B.buildZExtOrTrunc(Res: Ext, Op: Cond); |
6839 | auto FreezeFalse = B.buildFreeze(Dst: TrueTy, Src: False); |
6840 | B.buildOr(Dst: DstReg, Src0: Ext, Src1: FreezeFalse, Flags); |
6841 | }; |
6842 | return true; |
6843 | } |
6844 | |
6845 | // select Cond, T, Cond --> and Cond, T |
6846 | // select Cond, T, 0 --> and Cond, T |
6847 | if ((Cond == False) || isZeroOrZeroSplat(Src: False, /* AllowUndefs */ true)) { |
6848 | MatchInfo = [=](MachineIRBuilder &B) { |
6849 | B.setInstrAndDebugLoc(*Select); |
6850 | Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy); |
6851 | B.buildZExtOrTrunc(Res: Ext, Op: Cond); |
6852 | auto FreezeTrue = B.buildFreeze(Dst: TrueTy, Src: True); |
6853 | B.buildAnd(Dst: DstReg, Src0: Ext, Src1: FreezeTrue); |
6854 | }; |
6855 | return true; |
6856 | } |
6857 | |
6858 | // select Cond, T, 1 --> or (not Cond), T |
6859 | if (isOneOrOneSplat(Src: False, /* AllowUndefs */ true)) { |
6860 | MatchInfo = [=](MachineIRBuilder &B) { |
6861 | B.setInstrAndDebugLoc(*Select); |
6862 | // First the not. |
6863 | Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy); |
6864 | B.buildNot(Dst: Inner, Src0: Cond); |
6865 | // Then an ext to match the destination register. |
6866 | Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy); |
6867 | B.buildZExtOrTrunc(Res: Ext, Op: Inner); |
6868 | auto FreezeTrue = B.buildFreeze(Dst: TrueTy, Src: True); |
6869 | B.buildOr(Dst: DstReg, Src0: Ext, Src1: FreezeTrue, Flags); |
6870 | }; |
6871 | return true; |
6872 | } |
6873 | |
6874 | // select Cond, 0, F --> and (not Cond), F |
6875 | if (isZeroOrZeroSplat(Src: True, /* AllowUndefs */ true)) { |
6876 | MatchInfo = [=](MachineIRBuilder &B) { |
6877 | B.setInstrAndDebugLoc(*Select); |
6878 | // First the not. |
6879 | Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy); |
6880 | B.buildNot(Dst: Inner, Src0: Cond); |
6881 | // Then an ext to match the destination register. |
6882 | Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy); |
6883 | B.buildZExtOrTrunc(Res: Ext, Op: Inner); |
6884 | auto FreezeFalse = B.buildFreeze(Dst: TrueTy, Src: False); |
6885 | B.buildAnd(Dst: DstReg, Src0: Ext, Src1: FreezeFalse); |
6886 | }; |
6887 | return true; |
6888 | } |
6889 | |
6890 | return false; |
6891 | } |
6892 | |
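/// Match select(icmp(Pred, X, Y), X, Y) where the compare has a single use
/// and a relational predicate, and turn it into the corresponding
/// G_SMIN/G_SMAX/G_UMIN/G_UMAX if that opcode is legal or we are before the
/// legalizer.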
6893 | bool CombinerHelper::matchSelectIMinMax(const MachineOperand &MO, |
6894 | BuildFnTy &MatchInfo) { |
6895 | GSelect *Select = cast<GSelect>(Val: MRI.getVRegDef(Reg: MO.getReg())); |
6896 | GICmp *Cmp = cast<GICmp>(Val: MRI.getVRegDef(Reg: Select->getCondReg())); |
6897 | |
6898 | Register DstReg = Select->getReg(Idx: 0); |
6899 | Register True = Select->getTrueReg(); |
6900 | Register False = Select->getFalseReg(); |
6901 | LLT DstTy = MRI.getType(Reg: DstReg); |
6902 | |
6903 | if (DstTy.isPointer()) |
6904 | return false; |
6905 | |
6906 | // We want to fold the icmp and replace the select. |
6907 | if (!MRI.hasOneNonDBGUse(RegNo: Cmp->getReg(Idx: 0))) |
6908 | return false; |
6909 | |
6910 | CmpInst::Predicate Pred = Cmp->getCond(); |
6911 | // We need a larger or smaller predicate for |
6912 | // canonicalization. |
6913 | if (CmpInst::isEquality(pred: Pred)) |
6914 | return false; |
6915 | |
6916 | Register CmpLHS = Cmp->getLHSReg(); |
6917 | Register CmpRHS = Cmp->getRHSReg(); |
6918 | |
// We can swap CmpLHS and CmpRHS for a higher hit rate.
6920 | if (True == CmpRHS && False == CmpLHS) { |
6921 | std::swap(a&: CmpLHS, b&: CmpRHS); |
6922 | Pred = CmpInst::getSwappedPredicate(pred: Pred); |
6923 | } |
6924 | |
6925 | // (icmp X, Y) ? X : Y -> integer minmax. |
// See matchSelectPattern in ValueTracking.
// The legality of G_SELECT and of the integer min/max opcodes can differ.
6928 | if (True != CmpLHS || False != CmpRHS) |
6929 | return false; |
6930 | |
6931 | switch (Pred) { |
6932 | case ICmpInst::ICMP_UGT: |
6933 | case ICmpInst::ICMP_UGE: { |
6934 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_UMAX, DstTy})) |
6935 | return false; |
6936 | MatchInfo = [=](MachineIRBuilder &B) { B.buildUMax(Dst: DstReg, Src0: True, Src1: False); }; |
6937 | return true; |
6938 | } |
6939 | case ICmpInst::ICMP_SGT: |
6940 | case ICmpInst::ICMP_SGE: { |
6941 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SMAX, DstTy})) |
6942 | return false; |
6943 | MatchInfo = [=](MachineIRBuilder &B) { B.buildSMax(Dst: DstReg, Src0: True, Src1: False); }; |
6944 | return true; |
6945 | } |
6946 | case ICmpInst::ICMP_ULT: |
6947 | case ICmpInst::ICMP_ULE: { |
6948 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_UMIN, DstTy})) |
6949 | return false; |
6950 | MatchInfo = [=](MachineIRBuilder &B) { B.buildUMin(Dst: DstReg, Src0: True, Src1: False); }; |
6951 | return true; |
6952 | } |
6953 | case ICmpInst::ICMP_SLT: |
6954 | case ICmpInst::ICMP_SLE: { |
6955 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SMIN, DstTy})) |
6956 | return false; |
6957 | MatchInfo = [=](MachineIRBuilder &B) { B.buildSMin(Dst: DstReg, Src0: True, Src1: False); }; |
6958 | return true; |
6959 | } |
6960 | default: |
6961 | return false; |
6962 | } |
6963 | } |
6964 | |
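/// Top-level G_SELECT combine: first try the constant-operand folds, then
/// the boolean select-to-logic folds.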
6965 | bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) { |
6966 | GSelect *Select = cast<GSelect>(Val: &MI); |
6967 | |
6968 | if (tryFoldSelectOfConstants(Select, MatchInfo)) |
6969 | return true; |
6970 | |
6971 | if (tryFoldBoolSelectToLogic(Select, MatchInfo)) |
6972 | return true; |
6973 | |
6974 | return false; |
6975 | } |
6976 | |
6977 | /// Fold (icmp Pred1 V1, C1) && (icmp Pred2 V2, C2) |
6978 | /// or (icmp Pred1 V1, C1) || (icmp Pred2 V2, C2) |
6979 | /// into a single comparison using range-based reasoning. |
6980 | /// see InstCombinerImpl::foldAndOrOfICmpsUsingRanges. |
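/// For example, (X != 0) && (X != 1) on the same X can be rewritten as a
/// single unsigned compare such as X u>= 2.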
6981 | bool CombinerHelper::tryFoldAndOrOrICmpsUsingRanges(GLogicalBinOp *Logic, |
6982 | BuildFnTy &MatchInfo) { |
6983 | assert(Logic->getOpcode() != TargetOpcode::G_XOR && "unexpected xor" ); |
6984 | bool IsAnd = Logic->getOpcode() == TargetOpcode::G_AND; |
6985 | Register DstReg = Logic->getReg(Idx: 0); |
6986 | Register LHS = Logic->getLHSReg(); |
6987 | Register RHS = Logic->getRHSReg(); |
6988 | unsigned Flags = Logic->getFlags(); |
6989 | |
// We need a G_ICMP on the LHS register.
6991 | GICmp *Cmp1 = getOpcodeDef<GICmp>(Reg: LHS, MRI); |
6992 | if (!Cmp1) |
6993 | return false; |
6994 | |
// We need a G_ICMP on the RHS register.
6996 | GICmp *Cmp2 = getOpcodeDef<GICmp>(Reg: RHS, MRI); |
6997 | if (!Cmp2) |
6998 | return false; |
6999 | |
7000 | // We want to fold the icmps. |
7001 | if (!MRI.hasOneNonDBGUse(RegNo: Cmp1->getReg(Idx: 0)) || |
7002 | !MRI.hasOneNonDBGUse(RegNo: Cmp2->getReg(Idx: 0))) |
7003 | return false; |
7004 | |
7005 | APInt C1; |
7006 | APInt C2; |
7007 | std::optional<ValueAndVReg> MaybeC1 = |
7008 | getIConstantVRegValWithLookThrough(VReg: Cmp1->getRHSReg(), MRI); |
7009 | if (!MaybeC1) |
7010 | return false; |
7011 | C1 = MaybeC1->Value; |
7012 | |
7013 | std::optional<ValueAndVReg> MaybeC2 = |
7014 | getIConstantVRegValWithLookThrough(VReg: Cmp2->getRHSReg(), MRI); |
7015 | if (!MaybeC2) |
7016 | return false; |
7017 | C2 = MaybeC2->Value; |
7018 | |
7019 | Register R1 = Cmp1->getLHSReg(); |
7020 | Register R2 = Cmp2->getLHSReg(); |
7021 | CmpInst::Predicate Pred1 = Cmp1->getCond(); |
7022 | CmpInst::Predicate Pred2 = Cmp2->getCond(); |
7023 | LLT CmpTy = MRI.getType(Reg: Cmp1->getReg(Idx: 0)); |
7024 | LLT CmpOperandTy = MRI.getType(Reg: R1); |
7025 | |
7026 | if (CmpOperandTy.isPointer()) |
7027 | return false; |
7028 | |
7029 | // We build ands, adds, and constants of type CmpOperandTy. |
7030 | // They must be legal to build. |
7031 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_AND, CmpOperandTy}) || |
7032 | !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ADD, CmpOperandTy}) || |
7033 | !isConstantLegalOrBeforeLegalizer(Ty: CmpOperandTy)) |
7034 | return false; |
7035 | |
// Look through an add of a constant offset on R1, R2, or both operands.
// This lets us turn the "R + C' < C''" range idiom into a proper range.
7038 | std::optional<APInt> Offset1; |
7039 | std::optional<APInt> Offset2; |
7040 | if (R1 != R2) { |
7041 | if (GAdd *Add = getOpcodeDef<GAdd>(Reg: R1, MRI)) { |
7042 | std::optional<ValueAndVReg> MaybeOffset1 = |
7043 | getIConstantVRegValWithLookThrough(VReg: Add->getRHSReg(), MRI); |
7044 | if (MaybeOffset1) { |
7045 | R1 = Add->getLHSReg(); |
7046 | Offset1 = MaybeOffset1->Value; |
7047 | } |
7048 | } |
7049 | if (GAdd *Add = getOpcodeDef<GAdd>(Reg: R2, MRI)) { |
7050 | std::optional<ValueAndVReg> MaybeOffset2 = |
7051 | getIConstantVRegValWithLookThrough(VReg: Add->getRHSReg(), MRI); |
7052 | if (MaybeOffset2) { |
7053 | R2 = Add->getLHSReg(); |
7054 | Offset2 = MaybeOffset2->Value; |
7055 | } |
7056 | } |
7057 | } |
7058 | |
7059 | if (R1 != R2) |
7060 | return false; |
7061 | |
// We calculate the icmp ranges, taking any offsets into account.
7063 | ConstantRange CR1 = ConstantRange::makeExactICmpRegion( |
7064 | Pred: IsAnd ? ICmpInst::getInversePredicate(pred: Pred1) : Pred1, Other: C1); |
7065 | if (Offset1) |
7066 | CR1 = CR1.subtract(CI: *Offset1); |
7067 | |
7068 | ConstantRange CR2 = ConstantRange::makeExactICmpRegion( |
7069 | Pred: IsAnd ? ICmpInst::getInversePredicate(pred: Pred2) : Pred2, Other: C2); |
7070 | if (Offset2) |
7071 | CR2 = CR2.subtract(CI: *Offset2); |
7072 | |
7073 | bool CreateMask = false; |
7074 | APInt LowerDiff; |
7075 | std::optional<ConstantRange> CR = CR1.exactUnionWith(CR: CR2); |
7076 | if (!CR) { |
7077 | // We need non-wrapping ranges. |
7078 | if (CR1.isWrappedSet() || CR2.isWrappedSet()) |
7079 | return false; |
7080 | |
7081 | // Check whether we have equal-size ranges that only differ by one bit. |
7082 | // In that case we can apply a mask to map one range onto the other. |
7083 | LowerDiff = CR1.getLower() ^ CR2.getLower(); |
7084 | APInt UpperDiff = (CR1.getUpper() - 1) ^ (CR2.getUpper() - 1); |
7085 | APInt CR1Size = CR1.getUpper() - CR1.getLower(); |
7086 | if (!LowerDiff.isPowerOf2() || LowerDiff != UpperDiff || |
7087 | CR1Size != CR2.getUpper() - CR2.getLower()) |
7088 | return false; |
7089 | |
7090 | CR = CR1.getLower().ult(RHS: CR2.getLower()) ? CR1 : CR2; |
7091 | CreateMask = true; |
7092 | } |
7093 | |
7094 | if (IsAnd) |
7095 | CR = CR->inverse(); |
7096 | |
7097 | CmpInst::Predicate NewPred; |
7098 | APInt NewC, Offset; |
7099 | CR->getEquivalentICmp(Pred&: NewPred, RHS&: NewC, Offset); |
7100 | |
// We take the result type of one of the original icmps, CmpTy, for
// the icmp to be built. The operand type, CmpOperandTy, is used for
// the other instructions and constants to be built. The types of
// the parameters and output are the same for add and and. CmpTy
// and the type of DstReg might differ. That is why we zext or trunc
// the icmp into the destination register.
7107 | |
7108 | MatchInfo = [=](MachineIRBuilder &B) { |
7109 | if (CreateMask && Offset != 0) { |
7110 | auto TildeLowerDiff = B.buildConstant(Res: CmpOperandTy, Val: ~LowerDiff); |
7111 | auto And = B.buildAnd(Dst: CmpOperandTy, Src0: R1, Src1: TildeLowerDiff); // the mask. |
7112 | auto OffsetC = B.buildConstant(Res: CmpOperandTy, Val: Offset); |
7113 | auto Add = B.buildAdd(Dst: CmpOperandTy, Src0: And, Src1: OffsetC, Flags); |
7114 | auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC); |
7115 | auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: Add, Op1: NewCon); |
7116 | B.buildZExtOrTrunc(Res: DstReg, Op: ICmp); |
7117 | } else if (CreateMask && Offset == 0) { |
7118 | auto TildeLowerDiff = B.buildConstant(Res: CmpOperandTy, Val: ~LowerDiff); |
7119 | auto And = B.buildAnd(Dst: CmpOperandTy, Src0: R1, Src1: TildeLowerDiff); // the mask. |
7120 | auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC); |
7121 | auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: And, Op1: NewCon); |
7122 | B.buildZExtOrTrunc(Res: DstReg, Op: ICmp); |
7123 | } else if (!CreateMask && Offset != 0) { |
7124 | auto OffsetC = B.buildConstant(Res: CmpOperandTy, Val: Offset); |
7125 | auto Add = B.buildAdd(Dst: CmpOperandTy, Src0: R1, Src1: OffsetC, Flags); |
7126 | auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC); |
7127 | auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: Add, Op1: NewCon); |
7128 | B.buildZExtOrTrunc(Res: DstReg, Op: ICmp); |
7129 | } else if (!CreateMask && Offset == 0) { |
7130 | auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC); |
7131 | auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: R1, Op1: NewCon); |
7132 | B.buildZExtOrTrunc(Res: DstReg, Op: ICmp); |
7133 | } else { |
7134 | llvm_unreachable("unexpected configuration of CreateMask and Offset" ); |
7135 | } |
7136 | }; |
7137 | return true; |
7138 | } |
7139 | |
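/// Fold (fcmp Pred1 X, Y) && (fcmp Pred2 X, Y) or
/// (fcmp Pred1 X, Y) || (fcmp Pred2 X, Y) (operands possibly swapped) into a
/// single fcmp whose predicate combines Pred1 and Pred2, or into a constant
/// for the always-false/always-true cases.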
7140 | bool CombinerHelper::tryFoldLogicOfFCmps(GLogicalBinOp *Logic, |
7141 | BuildFnTy &MatchInfo) { |
assert(Logic->getOpcode() != TargetOpcode::G_XOR && "unexpected xor");
7143 | Register DestReg = Logic->getReg(Idx: 0); |
7144 | Register LHS = Logic->getLHSReg(); |
7145 | Register RHS = Logic->getRHSReg(); |
7146 | bool IsAnd = Logic->getOpcode() == TargetOpcode::G_AND; |
7147 | |
7148 | // We need a compare on the LHS register. |
7149 | GFCmp *Cmp1 = getOpcodeDef<GFCmp>(Reg: LHS, MRI); |
7150 | if (!Cmp1) |
7151 | return false; |
7152 | |
7153 | // We need a compare on the RHS register. |
7154 | GFCmp *Cmp2 = getOpcodeDef<GFCmp>(Reg: RHS, MRI); |
7155 | if (!Cmp2) |
7156 | return false; |
7157 | |
7158 | LLT CmpTy = MRI.getType(Reg: Cmp1->getReg(Idx: 0)); |
7159 | LLT CmpOperandTy = MRI.getType(Reg: Cmp1->getLHSReg()); |
7160 | |
// We build one fcmp, so it must be legal. To fold the fcmps and replace the
// logic op, each of them must have exactly one non-debug use, and both fcmps
// must compare operands of the same type.
7163 | if (!isLegalOrBeforeLegalizer( |
7164 | Query: {TargetOpcode::G_FCMP, {CmpTy, CmpOperandTy}}) || |
7165 | !MRI.hasOneNonDBGUse(RegNo: Logic->getReg(Idx: 0)) || |
7166 | !MRI.hasOneNonDBGUse(RegNo: Cmp1->getReg(Idx: 0)) || |
7167 | !MRI.hasOneNonDBGUse(RegNo: Cmp2->getReg(Idx: 0)) || |
7168 | MRI.getType(Reg: Cmp1->getLHSReg()) != MRI.getType(Reg: Cmp2->getLHSReg())) |
7169 | return false; |
7170 | |
7171 | CmpInst::Predicate PredL = Cmp1->getCond(); |
7172 | CmpInst::Predicate PredR = Cmp2->getCond(); |
7173 | Register LHS0 = Cmp1->getLHSReg(); |
7174 | Register LHS1 = Cmp1->getRHSReg(); |
7175 | Register RHS0 = Cmp2->getLHSReg(); |
7176 | Register RHS1 = Cmp2->getRHSReg(); |
7177 | |
7178 | if (LHS0 == RHS1 && LHS1 == RHS0) { |
7179 | // Swap RHS operands to match LHS. |
7180 | PredR = CmpInst::getSwappedPredicate(pred: PredR); |
7181 | std::swap(a&: RHS0, b&: RHS1); |
7182 | } |
7183 | |
7184 | if (LHS0 == RHS0 && LHS1 == RHS1) { |
7185 | // We determine the new predicate. |
7186 | unsigned CmpCodeL = getFCmpCode(CC: PredL); |
7187 | unsigned CmpCodeR = getFCmpCode(CC: PredR); |
7188 | unsigned NewPred = IsAnd ? CmpCodeL & CmpCodeR : CmpCodeL | CmpCodeR; |
7189 | unsigned Flags = Cmp1->getFlags() | Cmp2->getFlags(); |
7190 | MatchInfo = [=](MachineIRBuilder &B) { |
7191 | // The fcmp predicates fill the lower part of the enum. |
7192 | FCmpInst::Predicate Pred = static_cast<FCmpInst::Predicate>(NewPred); |
7193 | if (Pred == FCmpInst::FCMP_FALSE && |
7194 | isConstantLegalOrBeforeLegalizer(Ty: CmpTy)) { |
7195 | auto False = B.buildConstant(Res: CmpTy, Val: 0); |
7196 | B.buildZExtOrTrunc(Res: DestReg, Op: False); |
7197 | } else if (Pred == FCmpInst::FCMP_TRUE && |
7198 | isConstantLegalOrBeforeLegalizer(Ty: CmpTy)) { |
7199 | auto True = |
7200 | B.buildConstant(Res: CmpTy, Val: getICmpTrueVal(TLI: getTargetLowering(), |
7201 | IsVector: CmpTy.isVector() /*isVector*/, |
7202 | IsFP: true /*isFP*/)); |
7203 | B.buildZExtOrTrunc(Res: DestReg, Op: True); |
7204 | } else { // We take the predicate without predicate optimizations. |
7205 | auto Cmp = B.buildFCmp(Pred, Res: CmpTy, Op0: LHS0, Op1: LHS1, Flags); |
7206 | B.buildZExtOrTrunc(Res: DestReg, Op: Cmp); |
7207 | } |
7208 | }; |
7209 | return true; |
7210 | } |
7211 | |
7212 | return false; |
7213 | } |
7214 | |
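/// Top-level G_AND combine: try the range-based icmp fold, then the
/// logic-of-fcmps fold.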
7215 | bool CombinerHelper::matchAnd(MachineInstr &MI, BuildFnTy &MatchInfo) { |
7216 | GAnd *And = cast<GAnd>(Val: &MI); |
7217 | |
7218 | if (tryFoldAndOrOrICmpsUsingRanges(Logic: And, MatchInfo)) |
7219 | return true; |
7220 | |
7221 | if (tryFoldLogicOfFCmps(Logic: And, MatchInfo)) |
7222 | return true; |
7223 | |
7224 | return false; |
7225 | } |
7226 | |
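/// Top-level G_OR combine: try the range-based icmp fold, then the
/// logic-of-fcmps fold.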
7227 | bool CombinerHelper::matchOr(MachineInstr &MI, BuildFnTy &MatchInfo) { |
7228 | GOr *Or = cast<GOr>(Val: &MI); |
7229 | |
7230 | if (tryFoldAndOrOrICmpsUsingRanges(Logic: Or, MatchInfo)) |
7231 | return true; |
7232 | |
7233 | if (tryFoldLogicOfFCmps(Logic: Or, MatchInfo)) |
7234 | return true; |
7235 | |
7236 | return false; |
7237 | } |
7238 | |
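/// Combine G_UADDO/G_SADDO: drop a dead carry, canonicalize constants to the
/// RHS, constant-fold, fold away a zero addend, reassociate constant addends,
/// and use known bits / sign bits to prove whether the addition can overflow,
/// turning it into a plain G_ADD plus a constant carry.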
7239 | bool CombinerHelper::matchAddOverflow(MachineInstr &MI, BuildFnTy &MatchInfo) { |
7240 | GAddCarryOut *Add = cast<GAddCarryOut>(Val: &MI); |
7241 | |
7242 | // Addo has no flags |
7243 | Register Dst = Add->getReg(Idx: 0); |
7244 | Register Carry = Add->getReg(Idx: 1); |
7245 | Register LHS = Add->getLHSReg(); |
7246 | Register RHS = Add->getRHSReg(); |
7247 | bool IsSigned = Add->isSigned(); |
7248 | LLT DstTy = MRI.getType(Reg: Dst); |
7249 | LLT CarryTy = MRI.getType(Reg: Carry); |
7250 | |
7251 | // Fold addo, if the carry is dead -> add, undef. |
7252 | if (MRI.use_nodbg_empty(RegNo: Carry) && |
7253 | isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ADD, {DstTy}})) { |
7254 | MatchInfo = [=](MachineIRBuilder &B) { |
7255 | B.buildAdd(Dst, Src0: LHS, Src1: RHS); |
7256 | B.buildUndef(Res: Carry); |
7257 | }; |
7258 | return true; |
7259 | } |
7260 | |
7261 | // Canonicalize constant to RHS. |
7262 | if (isConstantOrConstantVectorI(Src: LHS) && !isConstantOrConstantVectorI(Src: RHS)) { |
7263 | if (IsSigned) { |
7264 | MatchInfo = [=](MachineIRBuilder &B) { |
7265 | B.buildSAddo(Res: Dst, CarryOut: Carry, Op0: RHS, Op1: LHS); |
7266 | }; |
7267 | return true; |
7268 | } |
7269 | // !IsSigned |
7270 | MatchInfo = [=](MachineIRBuilder &B) { |
7271 | B.buildUAddo(Res: Dst, CarryOut: Carry, Op0: RHS, Op1: LHS); |
7272 | }; |
7273 | return true; |
7274 | } |
7275 | |
7276 | std::optional<APInt> MaybeLHS = getConstantOrConstantSplatVector(Src: LHS); |
7277 | std::optional<APInt> MaybeRHS = getConstantOrConstantSplatVector(Src: RHS); |
7278 | |
7279 | // Fold addo(c1, c2) -> c3, carry. |
7280 | if (MaybeLHS && MaybeRHS && isConstantLegalOrBeforeLegalizer(Ty: DstTy) && |
7281 | isConstantLegalOrBeforeLegalizer(Ty: CarryTy)) { |
7282 | bool Overflow; |
7283 | APInt Result = IsSigned ? MaybeLHS->sadd_ov(RHS: *MaybeRHS, Overflow) |
7284 | : MaybeLHS->uadd_ov(RHS: *MaybeRHS, Overflow); |
7285 | MatchInfo = [=](MachineIRBuilder &B) { |
7286 | B.buildConstant(Res: Dst, Val: Result); |
7287 | B.buildConstant(Res: Carry, Val: Overflow); |
7288 | }; |
7289 | return true; |
7290 | } |
7291 | |
7292 | // Fold (addo x, 0) -> x, no carry |
7293 | if (MaybeRHS && *MaybeRHS == 0 && isConstantLegalOrBeforeLegalizer(Ty: CarryTy)) { |
7294 | MatchInfo = [=](MachineIRBuilder &B) { |
7295 | B.buildCopy(Res: Dst, Op: LHS); |
7296 | B.buildConstant(Res: Carry, Val: 0); |
7297 | }; |
7298 | return true; |
7299 | } |
7300 | |
7301 | // Given 2 constant operands whose sum does not overflow: |
7302 | // uaddo (X +nuw C0), C1 -> uaddo X, C0 + C1 |
7303 | // saddo (X +nsw C0), C1 -> saddo X, C0 + C1 |
7304 | GAdd *AddLHS = getOpcodeDef<GAdd>(Reg: LHS, MRI); |
7305 | if (MaybeRHS && AddLHS && MRI.hasOneNonDBGUse(RegNo: Add->getReg(Idx: 0)) && |
7306 | ((IsSigned && AddLHS->getFlag(Flag: MachineInstr::MIFlag::NoSWrap)) || |
7307 | (!IsSigned && AddLHS->getFlag(Flag: MachineInstr::MIFlag::NoUWrap)))) { |
7308 | std::optional<APInt> MaybeAddRHS = |
7309 | getConstantOrConstantSplatVector(Src: AddLHS->getRHSReg()); |
7310 | if (MaybeAddRHS) { |
7311 | bool Overflow; |
7312 | APInt NewC = IsSigned ? MaybeAddRHS->sadd_ov(RHS: *MaybeRHS, Overflow) |
7313 | : MaybeAddRHS->uadd_ov(RHS: *MaybeRHS, Overflow); |
7314 | if (!Overflow && isConstantLegalOrBeforeLegalizer(Ty: DstTy)) { |
7315 | if (IsSigned) { |
7316 | MatchInfo = [=](MachineIRBuilder &B) { |
7317 | auto ConstRHS = B.buildConstant(Res: DstTy, Val: NewC); |
7318 | B.buildSAddo(Res: Dst, CarryOut: Carry, Op0: AddLHS->getLHSReg(), Op1: ConstRHS); |
7319 | }; |
7320 | return true; |
7321 | } |
7322 | // !IsSigned |
7323 | MatchInfo = [=](MachineIRBuilder &B) { |
7324 | auto ConstRHS = B.buildConstant(Res: DstTy, Val: NewC); |
7325 | B.buildUAddo(Res: Dst, CarryOut: Carry, Op0: AddLHS->getLHSReg(), Op1: ConstRHS); |
7326 | }; |
7327 | return true; |
7328 | } |
7329 | } |
}
7331 | |
7332 | // We try to combine addo to non-overflowing add. |
7333 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ADD, {DstTy}}) || |
7334 | !isConstantLegalOrBeforeLegalizer(Ty: CarryTy)) |
7335 | return false; |
7336 | |
7337 | // We try to combine uaddo to non-overflowing add. |
7338 | if (!IsSigned) { |
7339 | ConstantRange CRLHS = |
7340 | ConstantRange::fromKnownBits(Known: KB->getKnownBits(R: LHS), /*IsSigned=*/false); |
7341 | ConstantRange CRRHS = |
7342 | ConstantRange::fromKnownBits(Known: KB->getKnownBits(R: RHS), /*IsSigned=*/false); |
7343 | |
7344 | switch (CRLHS.unsignedAddMayOverflow(Other: CRRHS)) { |
7345 | case ConstantRange::OverflowResult::MayOverflow: |
7346 | return false; |
7347 | case ConstantRange::OverflowResult::NeverOverflows: { |
7348 | MatchInfo = [=](MachineIRBuilder &B) { |
7349 | B.buildAdd(Dst, Src0: LHS, Src1: RHS, Flags: MachineInstr::MIFlag::NoUWrap); |
7350 | B.buildConstant(Res: Carry, Val: 0); |
7351 | }; |
7352 | return true; |
7353 | } |
7354 | case ConstantRange::OverflowResult::AlwaysOverflowsLow: |
7355 | case ConstantRange::OverflowResult::AlwaysOverflowsHigh: { |
7356 | MatchInfo = [=](MachineIRBuilder &B) { |
7357 | B.buildAdd(Dst, Src0: LHS, Src1: RHS); |
7358 | B.buildConstant(Res: Carry, Val: 1); |
7359 | }; |
7360 | return true; |
7361 | } |
7362 | } |
7363 | return false; |
7364 | } |
7365 | |
7366 | // We try to combine saddo to non-overflowing add. |
7367 | |
7368 | // If LHS and RHS each have at least two sign bits, then there is no signed |
7369 | // overflow. |
7370 | if (KB->computeNumSignBits(R: RHS) > 1 && KB->computeNumSignBits(R: LHS) > 1) { |
7371 | MatchInfo = [=](MachineIRBuilder &B) { |
7372 | B.buildAdd(Dst, Src0: LHS, Src1: RHS, Flags: MachineInstr::MIFlag::NoSWrap); |
7373 | B.buildConstant(Res: Carry, Val: 0); |
7374 | }; |
7375 | return true; |
7376 | } |
7377 | |
7378 | ConstantRange CRLHS = |
7379 | ConstantRange::fromKnownBits(Known: KB->getKnownBits(R: LHS), /*IsSigned=*/true); |
7380 | ConstantRange CRRHS = |
7381 | ConstantRange::fromKnownBits(Known: KB->getKnownBits(R: RHS), /*IsSigned=*/true); |
7382 | |
7383 | switch (CRLHS.signedAddMayOverflow(Other: CRRHS)) { |
7384 | case ConstantRange::OverflowResult::MayOverflow: |
7385 | return false; |
7386 | case ConstantRange::OverflowResult::NeverOverflows: { |
7387 | MatchInfo = [=](MachineIRBuilder &B) { |
7388 | B.buildAdd(Dst, Src0: LHS, Src1: RHS, Flags: MachineInstr::MIFlag::NoSWrap); |
7389 | B.buildConstant(Res: Carry, Val: 0); |
7390 | }; |
7391 | return true; |
7392 | } |
7393 | case ConstantRange::OverflowResult::AlwaysOverflowsLow: |
7394 | case ConstantRange::OverflowResult::AlwaysOverflowsHigh: { |
7395 | MatchInfo = [=](MachineIRBuilder &B) { |
7396 | B.buildAdd(Dst, Src0: LHS, Src1: RHS); |
7397 | B.buildConstant(Res: Carry, Val: 1); |
7398 | }; |
7399 | return true; |
7400 | } |
7401 | } |
7402 | |
7403 | return false; |
7404 | } |
7405 | |
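/// Build the matched replacement code and erase the instruction defining
/// \p MO (looking through copies).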
7406 | void CombinerHelper::applyBuildFnMO(const MachineOperand &MO, |
7407 | BuildFnTy &MatchInfo) { |
7408 | MachineInstr *Root = getDefIgnoringCopies(Reg: MO.getReg(), MRI); |
7409 | MatchInfo(Builder); |
7410 | Root->eraseFromParent(); |
7411 | } |
7412 | |
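/// Return true if the target considers it beneficial to expand G_FPOWI with
/// the constant exponent \p Exponent into a multiply sequence, taking
/// opt-for-size into account.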
7413 | bool CombinerHelper::matchFPowIExpansion(MachineInstr &MI, int64_t Exponent) { |
7414 | bool OptForSize = MI.getMF()->getFunction().hasOptSize(); |
7415 | return getTargetLowering().isBeneficialToExpandPowI(Exponent, OptForSize); |
7416 | } |
7417 | |
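/// Expand G_FPOWI with constant exponent \p Exponent into a repeated-squaring
/// multiply sequence, dividing 1.0 by the result for negative exponents.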
7418 | void CombinerHelper::applyExpandFPowI(MachineInstr &MI, int64_t Exponent) { |
7419 | auto [Dst, Base] = MI.getFirst2Regs(); |
7420 | LLT Ty = MRI.getType(Reg: Dst); |
7421 | int64_t ExpVal = Exponent; |
7422 | |
7423 | if (ExpVal == 0) { |
7424 | Builder.buildFConstant(Res: Dst, Val: 1.0); |
7425 | MI.removeFromParent(); |
7426 | return; |
7427 | } |
7428 | |
7429 | if (ExpVal < 0) |
7430 | ExpVal = -ExpVal; |
7431 | |
7432 | // We use the simple binary decomposition method from SelectionDAG ExpandPowI |
7433 | // to generate the multiply sequence. There are more optimal ways to do this |
7434 | // (for example, powi(x,15) generates one more multiply than it should), but |
7435 | // this has the benefit of being both really simple and much better than a |
7436 | // libcall. |
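// For example, for ExpVal == 13 (binary 1101) the loop below folds the
// running squares x, x^4, and x^8 into Res, yielding x^13.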
7437 | std::optional<SrcOp> Res; |
7438 | SrcOp CurSquare = Base; |
7439 | while (ExpVal > 0) { |
7440 | if (ExpVal & 1) { |
7441 | if (!Res) |
7442 | Res = CurSquare; |
7443 | else |
7444 | Res = Builder.buildFMul(Dst: Ty, Src0: *Res, Src1: CurSquare); |
7445 | } |
7446 | |
7447 | CurSquare = Builder.buildFMul(Dst: Ty, Src0: CurSquare, Src1: CurSquare); |
7448 | ExpVal >>= 1; |
7449 | } |
7450 | |
7451 | // If the original exponent was negative, invert the result, producing |
7452 | // 1/(x*x*x). |
7453 | if (Exponent < 0) |
7454 | Res = Builder.buildFDiv(Dst: Ty, Src0: Builder.buildFConstant(Res: Ty, Val: 1.0), Src1: *Res, |
7455 | Flags: MI.getFlags()); |
7456 | |
7457 | Builder.buildCopy(Res: Dst, Op: *Res); |
7458 | MI.eraseFromParent(); |
7459 | } |
7460 | |
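/// Combine a matched sext-of-trunc: forward the source when the types match,
/// otherwise emit a single nsw G_TRUNC or a single G_SEXT, whichever fits the
/// destination width.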
7461 | bool CombinerHelper::matchSextOfTrunc(const MachineOperand &MO, |
7462 | BuildFnTy &MatchInfo) { |
7463 | GSext *Sext = cast<GSext>(Val: getDefIgnoringCopies(Reg: MO.getReg(), MRI)); |
7464 | GTrunc *Trunc = cast<GTrunc>(Val: getDefIgnoringCopies(Reg: Sext->getSrcReg(), MRI)); |
7465 | |
7466 | Register Dst = Sext->getReg(Idx: 0); |
7467 | Register Src = Trunc->getSrcReg(); |
7468 | |
7469 | LLT DstTy = MRI.getType(Reg: Dst); |
7470 | LLT SrcTy = MRI.getType(Reg: Src); |
7471 | |
7472 | if (DstTy == SrcTy) { |
7473 | MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Res: Dst, Op: Src); }; |
7474 | return true; |
7475 | } |
7476 | |
7477 | if (DstTy.getScalarSizeInBits() < SrcTy.getScalarSizeInBits() && |
7478 | isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) { |
7479 | MatchInfo = [=](MachineIRBuilder &B) { |
7480 | B.buildTrunc(Res: Dst, Op: Src, Flags: MachineInstr::MIFlag::NoSWrap); |
7481 | }; |
7482 | return true; |
7483 | } |
7484 | |
7485 | if (DstTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits() && |
7486 | isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SEXT, {DstTy, SrcTy}})) { |
7487 | MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Res: Dst, Op: Src); }; |
7488 | return true; |
7489 | } |
7490 | |
7491 | return false; |
7492 | } |
7493 | |
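/// Combine a matched zext-of-trunc: forward the source when the types match,
/// otherwise emit a single nuw G_TRUNC or a single nneg G_ZEXT, whichever
/// fits the destination width.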
7494 | bool CombinerHelper::matchZextOfTrunc(const MachineOperand &MO, |
7495 | BuildFnTy &MatchInfo) { |
7496 | GZext *Zext = cast<GZext>(Val: getDefIgnoringCopies(Reg: MO.getReg(), MRI)); |
7497 | GTrunc *Trunc = cast<GTrunc>(Val: getDefIgnoringCopies(Reg: Zext->getSrcReg(), MRI)); |
7498 | |
7499 | Register Dst = Zext->getReg(Idx: 0); |
7500 | Register Src = Trunc->getSrcReg(); |
7501 | |
7502 | LLT DstTy = MRI.getType(Reg: Dst); |
7503 | LLT SrcTy = MRI.getType(Reg: Src); |
7504 | |
7505 | if (DstTy == SrcTy) { |
7506 | MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Res: Dst, Op: Src); }; |
7507 | return true; |
7508 | } |
7509 | |
7510 | if (DstTy.getScalarSizeInBits() < SrcTy.getScalarSizeInBits() && |
7511 | isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) { |
7512 | MatchInfo = [=](MachineIRBuilder &B) { |
7513 | B.buildTrunc(Res: Dst, Op: Src, Flags: MachineInstr::MIFlag::NoUWrap); |
7514 | }; |
7515 | return true; |
7516 | } |
7517 | |
7518 | if (DstTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits() && |
7519 | isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ZEXT, {DstTy, SrcTy}})) { |
7520 | MatchInfo = [=](MachineIRBuilder &B) { |
7521 | B.buildZExt(Res: Dst, Op: Src, Flags: MachineInstr::MIFlag::NonNeg); |
7522 | }; |
7523 | return true; |
7524 | } |
7525 | |
7526 | return false; |
7527 | } |
7528 | |
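/// Convert a matched non-negative G_ZEXT into a G_SEXT when G_SEXT is legal
/// (or we are before the legalizer) and the target reports that sign
/// extension is cheaper for these types.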
7529 | bool CombinerHelper::matchNonNegZext(const MachineOperand &MO, |
7530 | BuildFnTy &MatchInfo) { |
7531 | GZext *Zext = cast<GZext>(Val: MRI.getVRegDef(Reg: MO.getReg())); |
7532 | |
7533 | Register Dst = Zext->getReg(Idx: 0); |
7534 | Register Src = Zext->getSrcReg(); |
7535 | |
7536 | LLT DstTy = MRI.getType(Reg: Dst); |
7537 | LLT SrcTy = MRI.getType(Reg: Src); |
7538 | const auto &TLI = getTargetLowering(); |
7539 | |
7540 | // Convert zext nneg to sext if sext is the preferred form for the target. |
7541 | if (isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SEXT, {DstTy, SrcTy}}) && |
7542 | TLI.isSExtCheaperThanZExt(FromTy: getMVTForLLT(Ty: SrcTy), ToTy: getMVTForLLT(Ty: DstTy))) { |
7543 | MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Res: Dst, Op: Src); }; |
7544 | return true; |
7545 | } |
7546 | |
7547 | return false; |
7548 | } |
7549 | |