//===-- lib/CodeGen/GlobalISel/CombinerHelper.cpp ------------------------===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | #include "llvm/CodeGen/GlobalISel/CombinerHelper.h" |
9 | #include "llvm/ADT/APFloat.h" |
10 | #include "llvm/ADT/STLExtras.h" |
11 | #include "llvm/ADT/SetVector.h" |
12 | #include "llvm/ADT/SmallBitVector.h" |
13 | #include "llvm/Analysis/CmpInstAnalysis.h" |
14 | #include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h" |
15 | #include "llvm/CodeGen/GlobalISel/GISelValueTracking.h" |
16 | #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h" |
17 | #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" |
18 | #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h" |
19 | #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h" |
20 | #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" |
21 | #include "llvm/CodeGen/GlobalISel/Utils.h" |
22 | #include "llvm/CodeGen/LowLevelTypeUtils.h" |
23 | #include "llvm/CodeGen/MachineBasicBlock.h" |
24 | #include "llvm/CodeGen/MachineDominators.h" |
25 | #include "llvm/CodeGen/MachineInstr.h" |
26 | #include "llvm/CodeGen/MachineMemOperand.h" |
27 | #include "llvm/CodeGen/MachineRegisterInfo.h" |
28 | #include "llvm/CodeGen/Register.h" |
29 | #include "llvm/CodeGen/RegisterBankInfo.h" |
30 | #include "llvm/CodeGen/TargetInstrInfo.h" |
31 | #include "llvm/CodeGen/TargetLowering.h" |
32 | #include "llvm/CodeGen/TargetOpcodes.h" |
33 | #include "llvm/IR/ConstantRange.h" |
34 | #include "llvm/IR/DataLayout.h" |
35 | #include "llvm/IR/InstrTypes.h" |
36 | #include "llvm/Support/Casting.h" |
37 | #include "llvm/Support/DivisionByConstantInfo.h" |
38 | #include "llvm/Support/ErrorHandling.h" |
39 | #include "llvm/Support/MathExtras.h" |
40 | #include "llvm/Target/TargetMachine.h" |
41 | #include <cmath> |
42 | #include <optional> |
43 | #include <tuple> |
44 | |
45 | #define DEBUG_TYPE "gi-combiner" |
46 | |
47 | using namespace llvm; |
48 | using namespace MIPatternMatch; |
49 | |
50 | // Option to allow testing of the combiner while no targets know about indexed |
51 | // addressing. |
52 | static cl::opt<bool> |
53 | ForceLegalIndexing("force-legal-indexing" , cl::Hidden, cl::init(Val: false), |
54 | cl::desc("Force all indexed operations to be " |
55 | "legal for the GlobalISel combiner" )); |
56 | |
57 | CombinerHelper::CombinerHelper(GISelChangeObserver &Observer, |
58 | MachineIRBuilder &B, bool IsPreLegalize, |
59 | GISelValueTracking *VT, |
60 | MachineDominatorTree *MDT, |
61 | const LegalizerInfo *LI) |
62 | : Builder(B), MRI(Builder.getMF().getRegInfo()), Observer(Observer), VT(VT), |
63 | MDT(MDT), IsPreLegalize(IsPreLegalize), LI(LI), |
64 | RBI(Builder.getMF().getSubtarget().getRegBankInfo()), |
65 | TRI(Builder.getMF().getSubtarget().getRegisterInfo()) { |
66 | (void)this->VT; |
67 | } |
68 | |
69 | const TargetLowering &CombinerHelper::getTargetLowering() const { |
70 | return *Builder.getMF().getSubtarget().getTargetLowering(); |
71 | } |
72 | |
73 | const MachineFunction &CombinerHelper::getMachineFunction() const { |
74 | return Builder.getMF(); |
75 | } |
76 | |
77 | const DataLayout &CombinerHelper::getDataLayout() const { |
78 | return getMachineFunction().getDataLayout(); |
79 | } |
80 | |
81 | LLVMContext &CombinerHelper::getContext() const { return Builder.getContext(); } |
82 | |
83 | /// \returns The little endian in-memory byte position of byte \p I in a |
84 | /// \p ByteWidth bytes wide type. |
85 | /// |
86 | /// E.g. Given a 4-byte type x, x[0] -> byte 0 |
87 | static unsigned littleEndianByteAt(const unsigned ByteWidth, const unsigned I) { |
88 | assert(I < ByteWidth && "I must be in [0, ByteWidth)" ); |
89 | return I; |
90 | } |
91 | |
92 | /// Determines the LogBase2 value for a non-null input value using the |
93 | /// transform: LogBase2(V) = (EltBits - 1) - ctlz(V). |
94 | static Register buildLogBase2(Register V, MachineIRBuilder &MIB) { |
95 | auto &MRI = *MIB.getMRI(); |
96 | LLT Ty = MRI.getType(Reg: V); |
97 | auto Ctlz = MIB.buildCTLZ(Dst: Ty, Src0: V); |
98 | auto Base = MIB.buildConstant(Res: Ty, Val: Ty.getScalarSizeInBits() - 1); |
99 | return MIB.buildSub(Dst: Ty, Src0: Base, Src1: Ctlz).getReg(Idx: 0); |
100 | } |
101 | |
102 | /// \returns The big endian in-memory byte position of byte \p I in a |
103 | /// \p ByteWidth bytes wide type. |
104 | /// |
105 | /// E.g. Given a 4-byte type x, x[0] -> byte 3 |
106 | static unsigned bigEndianByteAt(const unsigned ByteWidth, const unsigned I) { |
107 | assert(I < ByteWidth && "I must be in [0, ByteWidth)" ); |
108 | return ByteWidth - I - 1; |
109 | } |
110 | |
111 | /// Given a map from byte offsets in memory to indices in a load/store, |
112 | /// determine if that map corresponds to a little or big endian byte pattern. |
113 | /// |
114 | /// \param MemOffset2Idx maps memory offsets to address offsets. |
115 | /// \param LowestIdx is the lowest index in \p MemOffset2Idx. |
116 | /// |
117 | /// \returns true if the map corresponds to a big endian byte pattern, false if |
118 | /// it corresponds to a little endian byte pattern, and std::nullopt otherwise. |
119 | /// |
120 | /// E.g. given a 32-bit type x, and x[AddrOffset], the in-memory byte patterns |
121 | /// are as follows: |
122 | /// |
123 | /// AddrOffset Little endian Big endian |
124 | /// 0 0 3 |
125 | /// 1 1 2 |
126 | /// 2 2 1 |
127 | /// 3 3 0 |
128 | static std::optional<bool> |
129 | isBigEndian(const SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx, |
130 | int64_t LowestIdx) { |
131 | // Need at least two byte positions to decide on endianness. |
132 | unsigned Width = MemOffset2Idx.size(); |
133 | if (Width < 2) |
134 | return std::nullopt; |
135 | bool BigEndian = true, LittleEndian = true; |
136 | for (unsigned MemOffset = 0; MemOffset < Width; ++ MemOffset) { |
137 | auto MemOffsetAndIdx = MemOffset2Idx.find(Val: MemOffset); |
138 | if (MemOffsetAndIdx == MemOffset2Idx.end()) |
139 | return std::nullopt; |
140 | const int64_t Idx = MemOffsetAndIdx->second - LowestIdx; |
141 | assert(Idx >= 0 && "Expected non-negative byte offset?" ); |
142 | LittleEndian &= Idx == littleEndianByteAt(ByteWidth: Width, I: MemOffset); |
143 | BigEndian &= Idx == bigEndianByteAt(ByteWidth: Width, I: MemOffset); |
144 | if (!BigEndian && !LittleEndian) |
145 | return std::nullopt; |
146 | } |
147 | |
148 | assert((BigEndian != LittleEndian) && |
149 | "Pattern cannot be both big and little endian!" ); |
150 | return BigEndian; |
151 | } |
152 | |
153 | bool CombinerHelper::isPreLegalize() const { return IsPreLegalize; } |
154 | |
155 | bool CombinerHelper::isLegal(const LegalityQuery &Query) const { |
156 | assert(LI && "Must have LegalizerInfo to query isLegal!" ); |
157 | return LI->getAction(Query).Action == LegalizeActions::Legal; |
158 | } |
159 | |
160 | bool CombinerHelper::isLegalOrBeforeLegalizer( |
161 | const LegalityQuery &Query) const { |
162 | return isPreLegalize() || isLegal(Query); |
163 | } |
164 | |
165 | bool CombinerHelper::isConstantLegalOrBeforeLegalizer(const LLT Ty) const { |
166 | if (!Ty.isVector()) |
167 | return isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_CONSTANT, {Ty}}); |
168 | // Vector constants are represented as a G_BUILD_VECTOR of scalar G_CONSTANTs. |
169 | if (isPreLegalize()) |
170 | return true; |
171 | LLT EltTy = Ty.getElementType(); |
172 | return isLegal(Query: {TargetOpcode::G_BUILD_VECTOR, {Ty, EltTy}}) && |
173 | isLegal(Query: {TargetOpcode::G_CONSTANT, {EltTy}}); |
174 | } |
175 | |
176 | void CombinerHelper::replaceRegWith(MachineRegisterInfo &MRI, Register FromReg, |
177 | Register ToReg) const { |
178 | Observer.changingAllUsesOfReg(MRI, Reg: FromReg); |
179 | |
180 | if (MRI.constrainRegAttrs(Reg: ToReg, ConstrainingReg: FromReg)) |
181 | MRI.replaceRegWith(FromReg, ToReg); |
182 | else |
183 | Builder.buildCopy(Res: FromReg, Op: ToReg); |
184 | |
185 | Observer.finishedChangingAllUsesOfReg(); |
186 | } |
187 | |
188 | void CombinerHelper::replaceRegOpWith(MachineRegisterInfo &MRI, |
189 | MachineOperand &FromRegOp, |
190 | Register ToReg) const { |
191 | assert(FromRegOp.getParent() && "Expected an operand in an MI" ); |
192 | Observer.changingInstr(MI&: *FromRegOp.getParent()); |
193 | |
194 | FromRegOp.setReg(ToReg); |
195 | |
196 | Observer.changedInstr(MI&: *FromRegOp.getParent()); |
197 | } |
198 | |
199 | void CombinerHelper::replaceOpcodeWith(MachineInstr &FromMI, |
200 | unsigned ToOpcode) const { |
201 | Observer.changingInstr(MI&: FromMI); |
202 | |
203 | FromMI.setDesc(Builder.getTII().get(Opcode: ToOpcode)); |
204 | |
205 | Observer.changedInstr(MI&: FromMI); |
206 | } |
207 | |
208 | const RegisterBank *CombinerHelper::getRegBank(Register Reg) const { |
209 | return RBI->getRegBank(Reg, MRI, TRI: *TRI); |
210 | } |
211 | |
212 | void CombinerHelper::setRegBank(Register Reg, |
213 | const RegisterBank *RegBank) const { |
214 | if (RegBank) |
215 | MRI.setRegBank(Reg, RegBank: *RegBank); |
216 | } |
217 | |
218 | bool CombinerHelper::tryCombineCopy(MachineInstr &MI) const { |
219 | if (matchCombineCopy(MI)) { |
220 | applyCombineCopy(MI); |
221 | return true; |
222 | } |
223 | return false; |
224 | } |
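
/// Match a COPY whose destination register can simply be replaced by its
/// source register (as determined by canReplaceReg).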
225 | bool CombinerHelper::matchCombineCopy(MachineInstr &MI) const { |
226 | if (MI.getOpcode() != TargetOpcode::COPY) |
227 | return false; |
228 | Register DstReg = MI.getOperand(i: 0).getReg(); |
229 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
230 | return canReplaceReg(DstReg, SrcReg, MRI); |
231 | } |
232 | void CombinerHelper::applyCombineCopy(MachineInstr &MI) const { |
233 | Register DstReg = MI.getOperand(i: 0).getReg(); |
234 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
235 | replaceRegWith(MRI, FromReg: DstReg, ToReg: SrcReg); |
236 | MI.eraseFromParent(); |
237 | } |
238 | |
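/// Try to push a G_FREEZE through the instruction defining its operand: if at
/// most one of that instruction's operands may be poison, freeze just that
/// operand instead (or drop the freeze entirely when no operand can be
/// poison).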
239 | bool CombinerHelper::matchFreezeOfSingleMaybePoisonOperand( |
240 | MachineInstr &MI, BuildFnTy &MatchInfo) const { |
241 | // Ported from InstCombinerImpl::pushFreezeToPreventPoisonFromPropagating. |
242 | Register DstOp = MI.getOperand(i: 0).getReg(); |
243 | Register OrigOp = MI.getOperand(i: 1).getReg(); |
244 | |
245 | if (!MRI.hasOneNonDBGUse(RegNo: OrigOp)) |
246 | return false; |
247 | |
248 | MachineInstr *OrigDef = MRI.getUniqueVRegDef(Reg: OrigOp); |
249 | // Even if only a single operand of the PHI is not guaranteed non-poison, |
250 | // moving freeze() backwards across a PHI can cause optimization issues for |
251 | // other users of that operand. |
252 | // |
253 | // Moving freeze() from one of the output registers of a G_UNMERGE_VALUES to |
254 | // the source register is unprofitable because it makes the freeze() more |
255 | // strict than is necessary (it would affect the whole register instead of |
256 | // just the subreg being frozen). |
257 | if (OrigDef->isPHI() || isa<GUnmerge>(Val: OrigDef)) |
258 | return false; |
259 | |
260 | if (canCreateUndefOrPoison(Reg: OrigOp, MRI, |
261 | /*ConsiderFlagsAndMetadata=*/false)) |
262 | return false; |
263 | |
264 | std::optional<MachineOperand> MaybePoisonOperand; |
265 | for (MachineOperand &Operand : OrigDef->uses()) { |
266 | if (!Operand.isReg()) |
267 | return false; |
268 | |
269 | if (isGuaranteedNotToBeUndefOrPoison(Reg: Operand.getReg(), MRI)) |
270 | continue; |
271 | |
272 | if (!MaybePoisonOperand) |
273 | MaybePoisonOperand = Operand; |
274 | else { |
275 | // We have more than one maybe-poison operand. Moving the freeze is |
276 | // unsafe. |
277 | return false; |
278 | } |
279 | } |
280 | |
281 | // Eliminate freeze if all operands are guaranteed non-poison. |
282 | if (!MaybePoisonOperand) { |
283 | MatchInfo = [=](MachineIRBuilder &B) { |
284 | Observer.changingInstr(MI&: *OrigDef); |
285 | cast<GenericMachineInstr>(Val: OrigDef)->dropPoisonGeneratingFlags(); |
286 | Observer.changedInstr(MI&: *OrigDef); |
287 | B.buildCopy(Res: DstOp, Op: OrigOp); |
288 | }; |
289 | return true; |
290 | } |
291 | |
292 | Register MaybePoisonOperandReg = MaybePoisonOperand->getReg(); |
293 | LLT MaybePoisonOperandRegTy = MRI.getType(Reg: MaybePoisonOperandReg); |
294 | |
295 | MatchInfo = [=](MachineIRBuilder &B) mutable { |
296 | Observer.changingInstr(MI&: *OrigDef); |
297 | cast<GenericMachineInstr>(Val: OrigDef)->dropPoisonGeneratingFlags(); |
298 | Observer.changedInstr(MI&: *OrigDef); |
299 | B.setInsertPt(MBB&: *OrigDef->getParent(), II: OrigDef->getIterator()); |
300 | auto Freeze = B.buildFreeze(Dst: MaybePoisonOperandRegTy, Src: MaybePoisonOperandReg); |
301 | replaceRegOpWith( |
302 | MRI, FromRegOp&: *OrigDef->findRegisterUseOperand(Reg: MaybePoisonOperandReg, TRI), |
303 | ToReg: Freeze.getReg(Idx: 0)); |
304 | replaceRegWith(MRI, FromReg: DstOp, ToReg: OrigOp); |
305 | }; |
306 | return true; |
307 | } |
308 | |
309 | bool CombinerHelper::matchCombineConcatVectors( |
310 | MachineInstr &MI, SmallVector<Register> &Ops) const { |
311 | assert(MI.getOpcode() == TargetOpcode::G_CONCAT_VECTORS && |
312 | "Invalid instruction" ); |
313 | bool IsUndef = true; |
314 | MachineInstr *Undef = nullptr; |
315 | |
316 | // Walk over all the operands of concat vectors and check if they are |
317 | // build_vector themselves or undef. |
318 | // Then collect their operands in Ops. |
319 | for (const MachineOperand &MO : MI.uses()) { |
320 | Register Reg = MO.getReg(); |
321 | MachineInstr *Def = MRI.getVRegDef(Reg); |
322 | assert(Def && "Operand not defined" ); |
323 | if (!MRI.hasOneNonDBGUse(RegNo: Reg)) |
324 | return false; |
325 | switch (Def->getOpcode()) { |
326 | case TargetOpcode::G_BUILD_VECTOR: |
327 | IsUndef = false; |
328 | // Remember the operands of the build_vector to fold |
329 | // them into the yet-to-build flattened concat vectors. |
330 | for (const MachineOperand &BuildVecMO : Def->uses()) |
331 | Ops.push_back(Elt: BuildVecMO.getReg()); |
332 | break; |
333 | case TargetOpcode::G_IMPLICIT_DEF: { |
334 | LLT OpType = MRI.getType(Reg); |
335 | // Keep one undef value for all the undef operands. |
336 | if (!Undef) { |
337 | Builder.setInsertPt(MBB&: *MI.getParent(), II: MI); |
338 | Undef = Builder.buildUndef(Res: OpType.getScalarType()); |
339 | } |
340 | assert(MRI.getType(Undef->getOperand(0).getReg()) == |
341 | OpType.getScalarType() && |
342 | "All undefs should have the same type" ); |
343 | // Break the undef vector in as many scalar elements as needed |
344 | // for the flattening. |
345 | for (unsigned EltIdx = 0, EltEnd = OpType.getNumElements(); |
346 | EltIdx != EltEnd; ++EltIdx) |
347 | Ops.push_back(Elt: Undef->getOperand(i: 0).getReg()); |
348 | break; |
349 | } |
350 | default: |
351 | return false; |
352 | } |
353 | } |
354 | |
355 | // Check if the combine is illegal |
356 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
357 | if (!isLegalOrBeforeLegalizer( |
358 | Query: {TargetOpcode::G_BUILD_VECTOR, {DstTy, MRI.getType(Reg: Ops[0])}})) { |
359 | return false; |
360 | } |
361 | |
362 | if (IsUndef) |
363 | Ops.clear(); |
364 | |
365 | return true; |
366 | } |
367 | void CombinerHelper::applyCombineConcatVectors( |
368 | MachineInstr &MI, SmallVector<Register> &Ops) const { |
  // We determined that the concat_vectors can be flattened.
370 | // Generate the flattened build_vector. |
371 | Register DstReg = MI.getOperand(i: 0).getReg(); |
372 | Builder.setInsertPt(MBB&: *MI.getParent(), II: MI); |
373 | Register NewDstReg = MRI.cloneVirtualRegister(VReg: DstReg); |
374 | |
  // Note: IsUndef is sort of redundant. We could have determined it by
  // checking that all Ops are undef. Alternatively, we could have
  // generated a build_vector of undefs and relied on another combine to
378 | // clean that up. For now, given we already gather this information |
379 | // in matchCombineConcatVectors, just save compile time and issue the |
380 | // right thing. |
381 | if (Ops.empty()) |
382 | Builder.buildUndef(Res: NewDstReg); |
383 | else |
384 | Builder.buildBuildVector(Res: NewDstReg, Ops); |
385 | replaceRegWith(MRI, FromReg: DstReg, ToReg: NewDstReg); |
386 | MI.eraseFromParent(); |
387 | } |
388 | |
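/// Match a G_SHUFFLE_VECTOR whose source operands are both vectors, so it can
/// be rewritten as an unmerge of the sources followed by a build_vector of
/// the selected (or undef) elements.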
389 | bool CombinerHelper::matchCombineShuffleToBuildVector(MachineInstr &MI) const { |
390 | assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR && |
391 | "Invalid instruction" ); |
392 | auto &Shuffle = cast<GShuffleVector>(Val&: MI); |
393 | |
394 | Register SrcVec1 = Shuffle.getSrc1Reg(); |
395 | Register SrcVec2 = Shuffle.getSrc2Reg(); |
396 | |
397 | LLT SrcVec1Type = MRI.getType(Reg: SrcVec1); |
398 | LLT SrcVec2Type = MRI.getType(Reg: SrcVec2); |
399 | return SrcVec1Type.isVector() && SrcVec2Type.isVector(); |
400 | } |
401 | |
402 | void CombinerHelper::applyCombineShuffleToBuildVector(MachineInstr &MI) const { |
403 | auto &Shuffle = cast<GShuffleVector>(Val&: MI); |
404 | |
405 | Register SrcVec1 = Shuffle.getSrc1Reg(); |
406 | Register SrcVec2 = Shuffle.getSrc2Reg(); |
407 | LLT EltTy = MRI.getType(Reg: SrcVec1).getElementType(); |
408 | int Width = MRI.getType(Reg: SrcVec1).getNumElements(); |
409 | |
410 | auto Unmerge1 = Builder.buildUnmerge(Res: EltTy, Op: SrcVec1); |
411 | auto Unmerge2 = Builder.buildUnmerge(Res: EltTy, Op: SrcVec2); |
412 | |
  SmallVector<Register> Extracts;
414 | // Select only applicable elements from unmerged values. |
415 | for (int Val : Shuffle.getMask()) { |
416 | if (Val == -1) |
417 | Extracts.push_back(Elt: Builder.buildUndef(Res: EltTy).getReg(Idx: 0)); |
418 | else if (Val < Width) |
419 | Extracts.push_back(Elt: Unmerge1.getReg(Idx: Val)); |
420 | else |
421 | Extracts.push_back(Elt: Unmerge2.getReg(Idx: Val - Width)); |
422 | } |
423 | assert(Extracts.size() > 0 && "Expected at least one element in the shuffle" ); |
424 | if (Extracts.size() == 1) |
425 | Builder.buildCopy(Res: MI.getOperand(i: 0).getReg(), Op: Extracts[0]); |
426 | else |
427 | Builder.buildBuildVector(Res: MI.getOperand(i: 0).getReg(), Ops: Extracts); |
428 | MI.eraseFromParent(); |
429 | } |
430 | |
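/// Try to rewrite a G_SHUFFLE_VECTOR of two G_CONCAT_VECTORS as a single
/// G_CONCAT_VECTORS of the original sources, e.g. with 2-element sources:
///   %c1:_(<4 x s32>) = G_CONCAT_VECTORS %a, %b
///   %c2:_(<4 x s32>) = G_CONCAT_VECTORS %c, %d
///   %sv:_(<4 x s32>) = G_SHUFFLE_VECTOR %c1, %c2, shufflemask(2, 3, 4, 5)
/// ==>
///   %sv:_(<4 x s32>) = G_CONCAT_VECTORS %b, %c
/// Mask chunks that are entirely undef are recorded as register 0 in \p Ops
/// and materialized as G_IMPLICIT_DEF in the apply step.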
431 | bool CombinerHelper::matchCombineShuffleConcat( |
432 | MachineInstr &MI, SmallVector<Register> &Ops) const { |
433 | ArrayRef<int> Mask = MI.getOperand(i: 3).getShuffleMask(); |
434 | auto ConcatMI1 = |
435 | dyn_cast<GConcatVectors>(Val: MRI.getVRegDef(Reg: MI.getOperand(i: 1).getReg())); |
436 | auto ConcatMI2 = |
437 | dyn_cast<GConcatVectors>(Val: MRI.getVRegDef(Reg: MI.getOperand(i: 2).getReg())); |
438 | if (!ConcatMI1 || !ConcatMI2) |
439 | return false; |
440 | |
441 | // Check that the sources of the Concat instructions have the same type |
442 | if (MRI.getType(Reg: ConcatMI1->getSourceReg(I: 0)) != |
443 | MRI.getType(Reg: ConcatMI2->getSourceReg(I: 0))) |
444 | return false; |
445 | |
446 | LLT ConcatSrcTy = MRI.getType(Reg: ConcatMI1->getReg(Idx: 1)); |
447 | LLT ShuffleSrcTy1 = MRI.getType(Reg: MI.getOperand(i: 1).getReg()); |
448 | unsigned ConcatSrcNumElt = ConcatSrcTy.getNumElements(); |
449 | for (unsigned i = 0; i < Mask.size(); i += ConcatSrcNumElt) { |
450 | // Check if the index takes a whole source register from G_CONCAT_VECTORS |
451 | // Assumes that all Sources of G_CONCAT_VECTORS are the same type |
452 | if (Mask[i] == -1) { |
453 | for (unsigned j = 1; j < ConcatSrcNumElt; j++) { |
454 | if (i + j >= Mask.size()) |
455 | return false; |
456 | if (Mask[i + j] != -1) |
457 | return false; |
458 | } |
459 | if (!isLegalOrBeforeLegalizer( |
460 | Query: {TargetOpcode::G_IMPLICIT_DEF, {ConcatSrcTy}})) |
461 | return false; |
462 | Ops.push_back(Elt: 0); |
463 | } else if (Mask[i] % ConcatSrcNumElt == 0) { |
464 | for (unsigned j = 1; j < ConcatSrcNumElt; j++) { |
465 | if (i + j >= Mask.size()) |
466 | return false; |
467 | if (Mask[i + j] != Mask[i] + static_cast<int>(j)) |
468 | return false; |
469 | } |
470 | // Retrieve the source register from its respective G_CONCAT_VECTORS |
471 | // instruction |
472 | if (Mask[i] < ShuffleSrcTy1.getNumElements()) { |
473 | Ops.push_back(Elt: ConcatMI1->getSourceReg(I: Mask[i] / ConcatSrcNumElt)); |
474 | } else { |
475 | Ops.push_back(Elt: ConcatMI2->getSourceReg(I: Mask[i] / ConcatSrcNumElt - |
476 | ConcatMI1->getNumSources())); |
477 | } |
478 | } else { |
479 | return false; |
480 | } |
481 | } |
482 | |
483 | if (!isLegalOrBeforeLegalizer( |
484 | Query: {TargetOpcode::G_CONCAT_VECTORS, |
485 | {MRI.getType(Reg: MI.getOperand(i: 0).getReg()), ConcatSrcTy}})) |
486 | return false; |
487 | |
488 | return !Ops.empty(); |
489 | } |
490 | |
491 | void CombinerHelper::applyCombineShuffleConcat( |
492 | MachineInstr &MI, SmallVector<Register> &Ops) const { |
493 | LLT SrcTy; |
494 | for (Register &Reg : Ops) { |
495 | if (Reg != 0) |
496 | SrcTy = MRI.getType(Reg); |
497 | } |
498 | assert(SrcTy.isValid() && "Unexpected full undef vector in concat combine" ); |
499 | |
500 | Register UndefReg = 0; |
501 | |
502 | for (Register &Reg : Ops) { |
503 | if (Reg == 0) { |
504 | if (UndefReg == 0) |
505 | UndefReg = Builder.buildUndef(Res: SrcTy).getReg(Idx: 0); |
506 | Reg = UndefReg; |
507 | } |
508 | } |
509 | |
510 | if (Ops.size() > 1) |
511 | Builder.buildConcatVectors(Res: MI.getOperand(i: 0).getReg(), Ops); |
512 | else |
513 | Builder.buildCopy(Res: MI.getOperand(i: 0).getReg(), Op: Ops[0]); |
514 | MI.eraseFromParent(); |
515 | } |
516 | |
517 | bool CombinerHelper::tryCombineShuffleVector(MachineInstr &MI) const { |
518 | SmallVector<Register, 4> Ops; |
519 | if (matchCombineShuffleVector(MI, Ops)) { |
520 | applyCombineShuffleVector(MI, Ops); |
521 | return true; |
522 | } |
523 | return false; |
524 | } |
525 | |
526 | bool CombinerHelper::matchCombineShuffleVector( |
527 | MachineInstr &MI, SmallVectorImpl<Register> &Ops) const { |
528 | assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR && |
529 | "Invalid instruction kind" ); |
530 | LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
531 | Register Src1 = MI.getOperand(i: 1).getReg(); |
532 | LLT SrcType = MRI.getType(Reg: Src1); |
533 | // As bizarre as it may look, shuffle vector can actually produce |
534 | // scalar! This is because at the IR level a <1 x ty> shuffle |
535 | // vector is perfectly valid. |
536 | unsigned DstNumElts = DstType.isVector() ? DstType.getNumElements() : 1; |
537 | unsigned SrcNumElts = SrcType.isVector() ? SrcType.getNumElements() : 1; |
538 | |
539 | // If the resulting vector is smaller than the size of the source |
540 | // vectors being concatenated, we won't be able to replace the |
541 | // shuffle vector into a concat_vectors. |
542 | // |
543 | // Note: We may still be able to produce a concat_vectors fed by |
544 | // extract_vector_elt and so on. It is less clear that would |
545 | // be better though, so don't bother for now. |
546 | // |
547 | // If the destination is a scalar, the size of the sources doesn't |
  // matter. We will lower the shuffle to a plain copy. This will
549 | // work only if the source and destination have the same size. But |
550 | // that's covered by the next condition. |
551 | // |
  // TODO: If the sizes of the source and destination don't match
  // we could still emit an extract vector element in that case.
554 | if (DstNumElts < 2 * SrcNumElts && DstNumElts != 1) |
555 | return false; |
556 | |
557 | // Check that the shuffle mask can be broken evenly between the |
558 | // different sources. |
559 | if (DstNumElts % SrcNumElts != 0) |
560 | return false; |
561 | |
562 | // Mask length is a multiple of the source vector length. |
563 | // Check if the shuffle is some kind of concatenation of the input |
564 | // vectors. |
565 | unsigned NumConcat = DstNumElts / SrcNumElts; |
566 | SmallVector<int, 8> ConcatSrcs(NumConcat, -1); |
567 | ArrayRef<int> Mask = MI.getOperand(i: 3).getShuffleMask(); |
568 | for (unsigned i = 0; i != DstNumElts; ++i) { |
569 | int Idx = Mask[i]; |
570 | // Undef value. |
571 | if (Idx < 0) |
572 | continue; |
573 | // Ensure the indices in each SrcType sized piece are sequential and that |
574 | // the same source is used for the whole piece. |
575 | if ((Idx % SrcNumElts != (i % SrcNumElts)) || |
576 | (ConcatSrcs[i / SrcNumElts] >= 0 && |
577 | ConcatSrcs[i / SrcNumElts] != (int)(Idx / SrcNumElts))) |
578 | return false; |
579 | // Remember which source this index came from. |
580 | ConcatSrcs[i / SrcNumElts] = Idx / SrcNumElts; |
581 | } |
582 | |
583 | // The shuffle is concatenating multiple vectors together. |
584 | // Collect the different operands for that. |
585 | Register UndefReg; |
586 | Register Src2 = MI.getOperand(i: 2).getReg(); |
587 | for (auto Src : ConcatSrcs) { |
588 | if (Src < 0) { |
589 | if (!UndefReg) { |
590 | Builder.setInsertPt(MBB&: *MI.getParent(), II: MI); |
591 | UndefReg = Builder.buildUndef(Res: SrcType).getReg(Idx: 0); |
592 | } |
593 | Ops.push_back(Elt: UndefReg); |
594 | } else if (Src == 0) |
595 | Ops.push_back(Elt: Src1); |
596 | else |
597 | Ops.push_back(Elt: Src2); |
598 | } |
599 | return true; |
600 | } |
601 | |
602 | void CombinerHelper::applyCombineShuffleVector( |
603 | MachineInstr &MI, const ArrayRef<Register> Ops) const { |
604 | Register DstReg = MI.getOperand(i: 0).getReg(); |
605 | Builder.setInsertPt(MBB&: *MI.getParent(), II: MI); |
606 | Register NewDstReg = MRI.cloneVirtualRegister(VReg: DstReg); |
607 | |
608 | if (Ops.size() == 1) |
609 | Builder.buildCopy(Res: NewDstReg, Op: Ops[0]); |
610 | else |
611 | Builder.buildMergeLikeInstr(Res: NewDstReg, Ops); |
612 | |
613 | replaceRegWith(MRI, FromReg: DstReg, ToReg: NewDstReg); |
614 | MI.eraseFromParent(); |
615 | } |
616 | |
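/// Match a G_SHUFFLE_VECTOR with a single-element mask; such a shuffle can be
/// lowered to an extract of one source element (or an undef, or a plain copy
/// for scalar sources).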
bool CombinerHelper::matchShuffleToExtract(MachineInstr &MI) const {
618 | assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR && |
619 | "Invalid instruction kind" ); |
620 | |
621 | ArrayRef<int> Mask = MI.getOperand(i: 3).getShuffleMask(); |
622 | return Mask.size() == 1; |
623 | } |
624 | |
void CombinerHelper::applyShuffleToExtract(MachineInstr &MI) const {
626 | Register DstReg = MI.getOperand(i: 0).getReg(); |
627 | Builder.setInsertPt(MBB&: *MI.getParent(), II: MI); |
628 | |
629 | int I = MI.getOperand(i: 3).getShuffleMask()[0]; |
630 | Register Src1 = MI.getOperand(i: 1).getReg(); |
631 | LLT Src1Ty = MRI.getType(Reg: Src1); |
632 | int Src1NumElts = Src1Ty.isVector() ? Src1Ty.getNumElements() : 1; |
633 | Register SrcReg; |
634 | if (I >= Src1NumElts) { |
635 | SrcReg = MI.getOperand(i: 2).getReg(); |
636 | I -= Src1NumElts; |
637 | } else if (I >= 0) |
638 | SrcReg = Src1; |
639 | |
640 | if (I < 0) |
641 | Builder.buildUndef(Res: DstReg); |
642 | else if (!MRI.getType(Reg: SrcReg).isVector()) |
643 | Builder.buildCopy(Res: DstReg, Op: SrcReg); |
644 | else |
645 | Builder.buildExtractVectorElementConstant(Res: DstReg, Val: SrcReg, Idx: I); |
646 | |
647 | MI.eraseFromParent(); |
648 | } |
649 | |
650 | namespace { |
651 | |
652 | /// Select a preference between two uses. CurrentUse is the current preference |
653 | /// while *ForCandidate is attributes of the candidate under consideration. |
654 | PreferredTuple ChoosePreferredUse(MachineInstr &LoadMI, |
655 | PreferredTuple &CurrentUse, |
656 | const LLT TyForCandidate, |
657 | unsigned OpcodeForCandidate, |
658 | MachineInstr *MIForCandidate) { |
659 | if (!CurrentUse.Ty.isValid()) { |
660 | if (CurrentUse.ExtendOpcode == OpcodeForCandidate || |
661 | CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT) |
662 | return {.Ty: TyForCandidate, .ExtendOpcode: OpcodeForCandidate, .MI: MIForCandidate}; |
663 | return CurrentUse; |
664 | } |
665 | |
666 | // We permit the extend to hoist through basic blocks but this is only |
667 | // sensible if the target has extending loads. If you end up lowering back |
668 | // into a load and extend during the legalizer then the end result is |
669 | // hoisting the extend up to the load. |
670 | |
671 | // Prefer defined extensions to undefined extensions as these are more |
672 | // likely to reduce the number of instructions. |
673 | if (OpcodeForCandidate == TargetOpcode::G_ANYEXT && |
674 | CurrentUse.ExtendOpcode != TargetOpcode::G_ANYEXT) |
675 | return CurrentUse; |
676 | else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT && |
677 | OpcodeForCandidate != TargetOpcode::G_ANYEXT) |
678 | return {.Ty: TyForCandidate, .ExtendOpcode: OpcodeForCandidate, .MI: MIForCandidate}; |
679 | |
  // Prefer sign extensions to zero extensions: a standalone sign extension
  // tends to be more expensive than a zero extension, so folding it into the
  // load saves more. Don't do this if the load is already a zero-extend load
682 | // though, otherwise we'll rewrite a zero-extend load into a sign-extend |
683 | // later. |
684 | if (!isa<GZExtLoad>(Val: LoadMI) && CurrentUse.Ty == TyForCandidate) { |
685 | if (CurrentUse.ExtendOpcode == TargetOpcode::G_SEXT && |
686 | OpcodeForCandidate == TargetOpcode::G_ZEXT) |
687 | return CurrentUse; |
688 | else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ZEXT && |
689 | OpcodeForCandidate == TargetOpcode::G_SEXT) |
690 | return {.Ty: TyForCandidate, .ExtendOpcode: OpcodeForCandidate, .MI: MIForCandidate}; |
691 | } |
692 | |
693 | // This is potentially target specific. We've chosen the largest type |
694 | // because G_TRUNC is usually free. One potential catch with this is that |
695 | // some targets have a reduced number of larger registers than smaller |
696 | // registers and this choice potentially increases the live-range for the |
697 | // larger value. |
698 | if (TyForCandidate.getSizeInBits() > CurrentUse.Ty.getSizeInBits()) { |
699 | return {.Ty: TyForCandidate, .ExtendOpcode: OpcodeForCandidate, .MI: MIForCandidate}; |
700 | } |
701 | return CurrentUse; |
702 | } |
703 | |
704 | /// Find a suitable place to insert some instructions and insert them. This |
705 | /// function accounts for special cases like inserting before a PHI node. |
706 | /// The current strategy for inserting before PHI's is to duplicate the |
707 | /// instructions for each predecessor. However, while that's ok for G_TRUNC |
708 | /// on most targets since it generally requires no code, other targets/cases may |
709 | /// want to try harder to find a dominating block. |
710 | static void InsertInsnsWithoutSideEffectsBeforeUse( |
711 | MachineIRBuilder &Builder, MachineInstr &DefMI, MachineOperand &UseMO, |
712 | std::function<void(MachineBasicBlock *, MachineBasicBlock::iterator, |
713 | MachineOperand &UseMO)> |
714 | Inserter) { |
715 | MachineInstr &UseMI = *UseMO.getParent(); |
716 | |
717 | MachineBasicBlock *InsertBB = UseMI.getParent(); |
718 | |
719 | // If the use is a PHI then we want the predecessor block instead. |
720 | if (UseMI.isPHI()) { |
721 | MachineOperand *PredBB = std::next(x: &UseMO); |
722 | InsertBB = PredBB->getMBB(); |
723 | } |
724 | |
725 | // If the block is the same block as the def then we want to insert just after |
726 | // the def instead of at the start of the block. |
727 | if (InsertBB == DefMI.getParent()) { |
728 | MachineBasicBlock::iterator InsertPt = &DefMI; |
729 | Inserter(InsertBB, std::next(x: InsertPt), UseMO); |
730 | return; |
731 | } |
732 | |
733 | // Otherwise we want the start of the BB |
734 | Inserter(InsertBB, InsertBB->getFirstNonPHI(), UseMO); |
735 | } |
736 | } // end anonymous namespace |
737 | |
738 | bool CombinerHelper::tryCombineExtendingLoads(MachineInstr &MI) const { |
739 | PreferredTuple Preferred; |
740 | if (matchCombineExtendingLoads(MI, MatchInfo&: Preferred)) { |
741 | applyCombineExtendingLoads(MI, MatchInfo&: Preferred); |
742 | return true; |
743 | } |
744 | return false; |
745 | } |
746 | |
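/// Return the extending-load opcode corresponding to an extend opcode
/// (G_ANYEXT -> G_LOAD, G_SEXT -> G_SEXTLOAD, G_ZEXT -> G_ZEXTLOAD).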
747 | static unsigned getExtLoadOpcForExtend(unsigned ExtOpc) { |
748 | unsigned CandidateLoadOpc; |
749 | switch (ExtOpc) { |
750 | case TargetOpcode::G_ANYEXT: |
751 | CandidateLoadOpc = TargetOpcode::G_LOAD; |
752 | break; |
753 | case TargetOpcode::G_SEXT: |
754 | CandidateLoadOpc = TargetOpcode::G_SEXTLOAD; |
755 | break; |
756 | case TargetOpcode::G_ZEXT: |
757 | CandidateLoadOpc = TargetOpcode::G_ZEXTLOAD; |
758 | break; |
759 | default: |
760 | llvm_unreachable("Unexpected extend opc" ); |
761 | } |
762 | return CandidateLoadOpc; |
763 | } |
764 | |
765 | bool CombinerHelper::matchCombineExtendingLoads( |
766 | MachineInstr &MI, PreferredTuple &Preferred) const { |
767 | // We match the loads and follow the uses to the extend instead of matching |
768 | // the extends and following the def to the load. This is because the load |
769 | // must remain in the same position for correctness (unless we also add code |
770 | // to find a safe place to sink it) whereas the extend is freely movable. |
771 | // It also prevents us from duplicating the load for the volatile case or just |
772 | // for performance. |
773 | GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(Val: &MI); |
774 | if (!LoadMI) |
775 | return false; |
776 | |
777 | Register LoadReg = LoadMI->getDstReg(); |
778 | |
779 | LLT LoadValueTy = MRI.getType(Reg: LoadReg); |
780 | if (!LoadValueTy.isScalar()) |
781 | return false; |
782 | |
783 | // Most architectures are going to legalize <s8 loads into at least a 1 byte |
784 | // load, and the MMOs can only describe memory accesses in multiples of bytes. |
785 | // If we try to perform extload combining on those, we can end up with |
786 | // %a(s8) = extload %ptr (load 1 byte from %ptr) |
787 | // ... which is an illegal extload instruction. |
788 | if (LoadValueTy.getSizeInBits() < 8) |
789 | return false; |
790 | |
791 | // For non power-of-2 types, they will very likely be legalized into multiple |
792 | // loads. Don't bother trying to match them into extending loads. |
793 | if (!llvm::has_single_bit<uint32_t>(Value: LoadValueTy.getSizeInBits())) |
794 | return false; |
795 | |
796 | // Find the preferred type aside from the any-extends (unless it's the only |
797 | // one) and non-extending ops. We'll emit an extending load to that type and |
  // emit a variant of (extend (trunc X)) for the others according to the
799 | // relative type sizes. At the same time, pick an extend to use based on the |
800 | // extend involved in the chosen type. |
801 | unsigned PreferredOpcode = |
802 | isa<GLoad>(Val: &MI) |
803 | ? TargetOpcode::G_ANYEXT |
804 | : isa<GSExtLoad>(Val: &MI) ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT; |
805 | Preferred = {.Ty: LLT(), .ExtendOpcode: PreferredOpcode, .MI: nullptr}; |
806 | for (auto &UseMI : MRI.use_nodbg_instructions(Reg: LoadReg)) { |
807 | if (UseMI.getOpcode() == TargetOpcode::G_SEXT || |
808 | UseMI.getOpcode() == TargetOpcode::G_ZEXT || |
809 | (UseMI.getOpcode() == TargetOpcode::G_ANYEXT)) { |
810 | const auto &MMO = LoadMI->getMMO(); |
811 | // Don't do anything for atomics. |
812 | if (MMO.isAtomic()) |
813 | continue; |
814 | // Check for legality. |
815 | if (!isPreLegalize()) { |
816 | LegalityQuery::MemDesc MMDesc(MMO); |
817 | unsigned CandidateLoadOpc = getExtLoadOpcForExtend(ExtOpc: UseMI.getOpcode()); |
818 | LLT UseTy = MRI.getType(Reg: UseMI.getOperand(i: 0).getReg()); |
819 | LLT SrcTy = MRI.getType(Reg: LoadMI->getPointerReg()); |
820 | if (LI->getAction(Query: {CandidateLoadOpc, {UseTy, SrcTy}, {MMDesc}}) |
821 | .Action != LegalizeActions::Legal) |
822 | continue; |
823 | } |
824 | Preferred = ChoosePreferredUse(LoadMI&: MI, CurrentUse&: Preferred, |
825 | TyForCandidate: MRI.getType(Reg: UseMI.getOperand(i: 0).getReg()), |
826 | OpcodeForCandidate: UseMI.getOpcode(), MIForCandidate: &UseMI); |
827 | } |
828 | } |
829 | |
830 | // There were no extends |
831 | if (!Preferred.MI) |
832 | return false; |
  // It should be impossible to choose an extend without selecting a different
834 | // type since by definition the result of an extend is larger. |
835 | assert(Preferred.Ty != LoadValueTy && "Extending to same type?" ); |
836 | |
837 | LLVM_DEBUG(dbgs() << "Preferred use is: " << *Preferred.MI); |
838 | return true; |
839 | } |
840 | |
841 | void CombinerHelper::applyCombineExtendingLoads( |
842 | MachineInstr &MI, PreferredTuple &Preferred) const { |
843 | // Rewrite the load to the chosen extending load. |
844 | Register ChosenDstReg = Preferred.MI->getOperand(i: 0).getReg(); |
845 | |
846 | // Inserter to insert a truncate back to the original type at a given point |
847 | // with some basic CSE to limit truncate duplication to one per BB. |
848 | DenseMap<MachineBasicBlock *, MachineInstr *> EmittedInsns; |
849 | auto InsertTruncAt = [&](MachineBasicBlock *InsertIntoBB, |
850 | MachineBasicBlock::iterator InsertBefore, |
851 | MachineOperand &UseMO) { |
852 | MachineInstr *PreviouslyEmitted = EmittedInsns.lookup(Val: InsertIntoBB); |
853 | if (PreviouslyEmitted) { |
854 | Observer.changingInstr(MI&: *UseMO.getParent()); |
855 | UseMO.setReg(PreviouslyEmitted->getOperand(i: 0).getReg()); |
856 | Observer.changedInstr(MI&: *UseMO.getParent()); |
857 | return; |
858 | } |
859 | |
860 | Builder.setInsertPt(MBB&: *InsertIntoBB, II: InsertBefore); |
861 | Register NewDstReg = MRI.cloneVirtualRegister(VReg: MI.getOperand(i: 0).getReg()); |
862 | MachineInstr *NewMI = Builder.buildTrunc(Res: NewDstReg, Op: ChosenDstReg); |
863 | EmittedInsns[InsertIntoBB] = NewMI; |
864 | replaceRegOpWith(MRI, FromRegOp&: UseMO, ToReg: NewDstReg); |
865 | }; |
866 | |
867 | Observer.changingInstr(MI); |
868 | unsigned LoadOpc = getExtLoadOpcForExtend(ExtOpc: Preferred.ExtendOpcode); |
869 | MI.setDesc(Builder.getTII().get(Opcode: LoadOpc)); |
870 | |
871 | // Rewrite all the uses to fix up the types. |
872 | auto &LoadValue = MI.getOperand(i: 0); |
873 | SmallVector<MachineOperand *, 4> Uses( |
874 | llvm::make_pointer_range(Range: MRI.use_operands(Reg: LoadValue.getReg()))); |
875 | |
876 | for (auto *UseMO : Uses) { |
877 | MachineInstr *UseMI = UseMO->getParent(); |
878 | |
879 | // If the extend is compatible with the preferred extend then we should fix |
880 | // up the type and extend so that it uses the preferred use. |
881 | if (UseMI->getOpcode() == Preferred.ExtendOpcode || |
882 | UseMI->getOpcode() == TargetOpcode::G_ANYEXT) { |
883 | Register UseDstReg = UseMI->getOperand(i: 0).getReg(); |
884 | MachineOperand &UseSrcMO = UseMI->getOperand(i: 1); |
885 | const LLT UseDstTy = MRI.getType(Reg: UseDstReg); |
886 | if (UseDstReg != ChosenDstReg) { |
887 | if (Preferred.Ty == UseDstTy) { |
888 | // If the use has the same type as the preferred use, then merge |
889 | // the vregs and erase the extend. For example: |
890 | // %1:_(s8) = G_LOAD ... |
891 | // %2:_(s32) = G_SEXT %1(s8) |
892 | // %3:_(s32) = G_ANYEXT %1(s8) |
893 | // ... = ... %3(s32) |
894 | // rewrites to: |
895 | // %2:_(s32) = G_SEXTLOAD ... |
896 | // ... = ... %2(s32) |
897 | replaceRegWith(MRI, FromReg: UseDstReg, ToReg: ChosenDstReg); |
898 | Observer.erasingInstr(MI&: *UseMO->getParent()); |
899 | UseMO->getParent()->eraseFromParent(); |
900 | } else if (Preferred.Ty.getSizeInBits() < UseDstTy.getSizeInBits()) { |
901 | // If the preferred size is smaller, then keep the extend but extend |
902 | // from the result of the extending load. For example: |
903 | // %1:_(s8) = G_LOAD ... |
904 | // %2:_(s32) = G_SEXT %1(s8) |
905 | // %3:_(s64) = G_ANYEXT %1(s8) |
906 | // ... = ... %3(s64) |
907 | /// rewrites to: |
908 | // %2:_(s32) = G_SEXTLOAD ... |
909 | // %3:_(s64) = G_ANYEXT %2:_(s32) |
910 | // ... = ... %3(s64) |
911 | replaceRegOpWith(MRI, FromRegOp&: UseSrcMO, ToReg: ChosenDstReg); |
912 | } else { |
          // If the preferred size is larger, then insert a truncate. For
914 | // example: |
915 | // %1:_(s8) = G_LOAD ... |
916 | // %2:_(s64) = G_SEXT %1(s8) |
917 | // %3:_(s32) = G_ZEXT %1(s8) |
918 | // ... = ... %3(s32) |
919 | /// rewrites to: |
          // %2:_(s64) = G_SEXTLOAD ...
          // %4:_(s8) = G_TRUNC %2(s64)
          // %3:_(s32) = G_ZEXT %4(s8)
          // ... = ... %3(s32)
924 | InsertInsnsWithoutSideEffectsBeforeUse(Builder, DefMI&: MI, UseMO&: *UseMO, |
925 | Inserter: InsertTruncAt); |
926 | } |
927 | continue; |
928 | } |
929 | // The use is (one of) the uses of the preferred use we chose earlier. |
930 | // We're going to update the load to def this value later so just erase |
931 | // the old extend. |
932 | Observer.erasingInstr(MI&: *UseMO->getParent()); |
933 | UseMO->getParent()->eraseFromParent(); |
934 | continue; |
935 | } |
936 | |
937 | // The use isn't an extend. Truncate back to the type we originally loaded. |
938 | // This is free on many targets. |
939 | InsertInsnsWithoutSideEffectsBeforeUse(Builder, DefMI&: MI, UseMO&: *UseMO, Inserter: InsertTruncAt); |
940 | } |
941 | |
942 | MI.getOperand(i: 0).setReg(ChosenDstReg); |
943 | Observer.changedInstr(MI); |
944 | } |
945 | |
946 | bool CombinerHelper::matchCombineLoadWithAndMask(MachineInstr &MI, |
947 | BuildFnTy &MatchInfo) const { |
948 | assert(MI.getOpcode() == TargetOpcode::G_AND); |
949 | |
950 | // If we have the following code: |
951 | // %mask = G_CONSTANT 255 |
952 | // %ld = G_LOAD %ptr, (load s16) |
953 | // %and = G_AND %ld, %mask |
954 | // |
955 | // Try to fold it into |
956 | // %ld = G_ZEXTLOAD %ptr, (load s8) |
957 | |
958 | Register Dst = MI.getOperand(i: 0).getReg(); |
959 | if (MRI.getType(Reg: Dst).isVector()) |
960 | return false; |
961 | |
962 | auto MaybeMask = |
963 | getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI); |
964 | if (!MaybeMask) |
965 | return false; |
966 | |
967 | APInt MaskVal = MaybeMask->Value; |
968 | |
969 | if (!MaskVal.isMask()) |
970 | return false; |
971 | |
972 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
973 | // Don't use getOpcodeDef() here since intermediate instructions may have |
974 | // multiple users. |
975 | GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(Val: MRI.getVRegDef(Reg: SrcReg)); |
976 | if (!LoadMI || !MRI.hasOneNonDBGUse(RegNo: LoadMI->getDstReg())) |
977 | return false; |
978 | |
979 | Register LoadReg = LoadMI->getDstReg(); |
980 | LLT RegTy = MRI.getType(Reg: LoadReg); |
981 | Register PtrReg = LoadMI->getPointerReg(); |
982 | unsigned RegSize = RegTy.getSizeInBits(); |
983 | LocationSize LoadSizeBits = LoadMI->getMemSizeInBits(); |
984 | unsigned MaskSizeBits = MaskVal.countr_one(); |
985 | |
986 | // The mask may not be larger than the in-memory type, as it might cover sign |
987 | // extended bits |
988 | if (MaskSizeBits > LoadSizeBits.getValue()) |
989 | return false; |
990 | |
991 | // If the mask covers the whole destination register, there's nothing to |
992 | // extend |
993 | if (MaskSizeBits >= RegSize) |
994 | return false; |
995 | |
996 | // Most targets cannot deal with loads of size < 8 and need to re-legalize to |
997 | // at least byte loads. Avoid creating such loads here |
998 | if (MaskSizeBits < 8 || !isPowerOf2_32(Value: MaskSizeBits)) |
999 | return false; |
1000 | |
1001 | const MachineMemOperand &MMO = LoadMI->getMMO(); |
1002 | LegalityQuery::MemDesc MemDesc(MMO); |
1003 | |
1004 | // Don't modify the memory access size if this is atomic/volatile, but we can |
1005 | // still adjust the opcode to indicate the high bit behavior. |
1006 | if (LoadMI->isSimple()) |
1007 | MemDesc.MemoryTy = LLT::scalar(SizeInBits: MaskSizeBits); |
1008 | else if (LoadSizeBits.getValue() > MaskSizeBits || |
1009 | LoadSizeBits.getValue() == RegSize) |
1010 | return false; |
1011 | |
1012 | // TODO: Could check if it's legal with the reduced or original memory size. |
1013 | if (!isLegalOrBeforeLegalizer( |
1014 | Query: {TargetOpcode::G_ZEXTLOAD, {RegTy, MRI.getType(Reg: PtrReg)}, {MemDesc}})) |
1015 | return false; |
1016 | |
1017 | MatchInfo = [=](MachineIRBuilder &B) { |
1018 | B.setInstrAndDebugLoc(*LoadMI); |
1019 | auto &MF = B.getMF(); |
1020 | auto PtrInfo = MMO.getPointerInfo(); |
1021 | auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Ty: MemDesc.MemoryTy); |
1022 | B.buildLoadInstr(Opcode: TargetOpcode::G_ZEXTLOAD, Res: Dst, Addr: PtrReg, MMO&: *NewMMO); |
1023 | LoadMI->eraseFromParent(); |
1024 | }; |
1025 | return true; |
1026 | } |
1027 | |
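/// Return true if \p DefMI precedes \p UseMI (or is the same instruction) in
/// their common basic block. Both instructions must belong to the same block.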
1028 | bool CombinerHelper::isPredecessor(const MachineInstr &DefMI, |
1029 | const MachineInstr &UseMI) const { |
1030 | assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() && |
1031 | "shouldn't consider debug uses" ); |
1032 | assert(DefMI.getParent() == UseMI.getParent()); |
1033 | if (&DefMI == &UseMI) |
1034 | return true; |
1035 | const MachineBasicBlock &MBB = *DefMI.getParent(); |
1036 | auto DefOrUse = find_if(Range: MBB, P: [&DefMI, &UseMI](const MachineInstr &MI) { |
1037 | return &MI == &DefMI || &MI == &UseMI; |
1038 | }); |
1039 | if (DefOrUse == MBB.end()) |
1040 | llvm_unreachable("Block must contain both DefMI and UseMI!" ); |
1041 | return &*DefOrUse == &DefMI; |
1042 | } |
1043 | |
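/// Return true if \p DefMI dominates \p UseMI, using the dominator tree when
/// it is available; without one, only same-block ordering is checked and
/// cross-block queries conservatively return false.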
1044 | bool CombinerHelper::dominates(const MachineInstr &DefMI, |
1045 | const MachineInstr &UseMI) const { |
1046 | assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() && |
1047 | "shouldn't consider debug uses" ); |
1048 | if (MDT) |
1049 | return MDT->dominates(A: &DefMI, B: &UseMI); |
1050 | else if (DefMI.getParent() != UseMI.getParent()) |
1051 | return false; |
1052 | |
1053 | return isPredecessor(DefMI, UseMI); |
1054 | } |
1055 | |
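/// Match a G_SEXT_INREG fed (possibly through a G_TRUNC) by a G_SEXTLOAD that
/// already sign-extended from the same bit width, e.g.:
///   %ld:_(s32) = G_SEXTLOAD %ptr(p0) :: (load (s8))
///   %ext:_(s32) = G_SEXT_INREG %ld, 8
/// The G_SEXT_INREG is redundant and is replaced by a copy of its input.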
1056 | bool CombinerHelper::matchSextTruncSextLoad(MachineInstr &MI) const { |
1057 | assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG); |
1058 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
1059 | Register LoadUser = SrcReg; |
1060 | |
1061 | if (MRI.getType(Reg: SrcReg).isVector()) |
1062 | return false; |
1063 | |
1064 | Register TruncSrc; |
1065 | if (mi_match(R: SrcReg, MRI, P: m_GTrunc(Src: m_Reg(R&: TruncSrc)))) |
1066 | LoadUser = TruncSrc; |
1067 | |
1068 | uint64_t SizeInBits = MI.getOperand(i: 2).getImm(); |
1069 | // If the source is a G_SEXTLOAD from the same bit width, then we don't |
1070 | // need any extend at all, just a truncate. |
1071 | if (auto *LoadMI = getOpcodeDef<GSExtLoad>(Reg: LoadUser, MRI)) { |
1072 | // If truncating more than the original extended value, abort. |
1073 | auto LoadSizeBits = LoadMI->getMemSizeInBits(); |
1074 | if (TruncSrc && |
1075 | MRI.getType(Reg: TruncSrc).getSizeInBits() < LoadSizeBits.getValue()) |
1076 | return false; |
1077 | if (LoadSizeBits == SizeInBits) |
1078 | return true; |
1079 | } |
1080 | return false; |
1081 | } |
1082 | |
1083 | void CombinerHelper::applySextTruncSextLoad(MachineInstr &MI) const { |
1084 | assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG); |
1085 | Builder.buildCopy(Res: MI.getOperand(i: 0).getReg(), Op: MI.getOperand(i: 1).getReg()); |
1086 | MI.eraseFromParent(); |
1087 | } |
1088 | |
1089 | bool CombinerHelper::matchSextInRegOfLoad( |
1090 | MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) const { |
1091 | assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG); |
1092 | |
1093 | Register DstReg = MI.getOperand(i: 0).getReg(); |
1094 | LLT RegTy = MRI.getType(Reg: DstReg); |
1095 | |
1096 | // Only supports scalars for now. |
1097 | if (RegTy.isVector()) |
1098 | return false; |
1099 | |
1100 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
1101 | auto *LoadDef = getOpcodeDef<GLoad>(Reg: SrcReg, MRI); |
1102 | if (!LoadDef || !MRI.hasOneNonDBGUse(RegNo: SrcReg)) |
1103 | return false; |
1104 | |
1105 | uint64_t MemBits = LoadDef->getMemSizeInBits().getValue(); |
1106 | |
1107 | // If the sign extend extends from a narrower width than the load's width, |
1108 | // then we can narrow the load width when we combine to a G_SEXTLOAD. |
1109 | // Avoid widening the load at all. |
1110 | unsigned NewSizeBits = std::min(a: (uint64_t)MI.getOperand(i: 2).getImm(), b: MemBits); |
1111 | |
1112 | // Don't generate G_SEXTLOADs with a < 1 byte width. |
1113 | if (NewSizeBits < 8) |
1114 | return false; |
1115 | // Don't bother creating a non-power-2 sextload, it will likely be broken up |
1116 | // anyway for most targets. |
1117 | if (!isPowerOf2_32(Value: NewSizeBits)) |
1118 | return false; |
1119 | |
1120 | const MachineMemOperand &MMO = LoadDef->getMMO(); |
1121 | LegalityQuery::MemDesc MMDesc(MMO); |
1122 | |
1123 | // Don't modify the memory access size if this is atomic/volatile, but we can |
1124 | // still adjust the opcode to indicate the high bit behavior. |
1125 | if (LoadDef->isSimple()) |
1126 | MMDesc.MemoryTy = LLT::scalar(SizeInBits: NewSizeBits); |
1127 | else if (MemBits > NewSizeBits || MemBits == RegTy.getSizeInBits()) |
1128 | return false; |
1129 | |
1130 | // TODO: Could check if it's legal with the reduced or original memory size. |
1131 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SEXTLOAD, |
1132 | {MRI.getType(Reg: LoadDef->getDstReg()), |
1133 | MRI.getType(Reg: LoadDef->getPointerReg())}, |
1134 | {MMDesc}})) |
1135 | return false; |
1136 | |
1137 | MatchInfo = std::make_tuple(args: LoadDef->getDstReg(), args&: NewSizeBits); |
1138 | return true; |
1139 | } |
1140 | |
1141 | void CombinerHelper::applySextInRegOfLoad( |
1142 | MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) const { |
1143 | assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG); |
1144 | Register LoadReg; |
1145 | unsigned ScalarSizeBits; |
1146 | std::tie(args&: LoadReg, args&: ScalarSizeBits) = MatchInfo; |
1147 | GLoad *LoadDef = cast<GLoad>(Val: MRI.getVRegDef(Reg: LoadReg)); |
1148 | |
1149 | // If we have the following: |
1150 | // %ld = G_LOAD %ptr, (load 2) |
1151 | // %ext = G_SEXT_INREG %ld, 8 |
1152 | // ==> |
1153 | // %ld = G_SEXTLOAD %ptr (load 1) |
1154 | |
1155 | auto &MMO = LoadDef->getMMO(); |
1156 | Builder.setInstrAndDebugLoc(*LoadDef); |
1157 | auto &MF = Builder.getMF(); |
1158 | auto PtrInfo = MMO.getPointerInfo(); |
1159 | auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Size: ScalarSizeBits / 8); |
1160 | Builder.buildLoadInstr(Opcode: TargetOpcode::G_SEXTLOAD, Res: MI.getOperand(i: 0).getReg(), |
1161 | Addr: LoadDef->getPointerReg(), MMO&: *NewMMO); |
1162 | MI.eraseFromParent(); |
1163 | |
1164 | // Not all loads can be deleted, so make sure the old one is removed. |
1165 | LoadDef->eraseFromParent(); |
1166 | } |
1167 | |
/// Return true if 'MI' is a load or a store that may fold its address
/// operand into the load / store addressing mode.
1170 | static bool canFoldInAddressingMode(GLoadStore *MI, const TargetLowering &TLI, |
1171 | MachineRegisterInfo &MRI) { |
1172 | TargetLowering::AddrMode AM; |
1173 | auto *MF = MI->getMF(); |
1174 | auto *Addr = getOpcodeDef<GPtrAdd>(Reg: MI->getPointerReg(), MRI); |
1175 | if (!Addr) |
1176 | return false; |
1177 | |
1178 | AM.HasBaseReg = true; |
1179 | if (auto CstOff = getIConstantVRegVal(VReg: Addr->getOffsetReg(), MRI)) |
1180 | AM.BaseOffs = CstOff->getSExtValue(); // [reg +/- imm] |
1181 | else |
1182 | AM.Scale = 1; // [reg +/- reg] |
1183 | |
1184 | return TLI.isLegalAddressingMode( |
1185 | DL: MF->getDataLayout(), AM, |
1186 | Ty: getTypeForLLT(Ty: MI->getMMO().getMemoryType(), |
1187 | C&: MF->getFunction().getContext()), |
1188 | AddrSpace: MI->getMMO().getAddrSpace()); |
1189 | } |
1190 | |
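/// Map a load/store opcode to its pre/post-indexed equivalent.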
1191 | static unsigned getIndexedOpc(unsigned LdStOpc) { |
1192 | switch (LdStOpc) { |
1193 | case TargetOpcode::G_LOAD: |
1194 | return TargetOpcode::G_INDEXED_LOAD; |
1195 | case TargetOpcode::G_STORE: |
1196 | return TargetOpcode::G_INDEXED_STORE; |
1197 | case TargetOpcode::G_ZEXTLOAD: |
1198 | return TargetOpcode::G_INDEXED_ZEXTLOAD; |
1199 | case TargetOpcode::G_SEXTLOAD: |
1200 | return TargetOpcode::G_INDEXED_SEXTLOAD; |
1201 | default: |
1202 | llvm_unreachable("Unexpected opcode" ); |
1203 | } |
1204 | } |
1205 | |
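/// Query whether the indexed form of \p LdSt is legal for its value, pointer
/// and memory types.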
1206 | bool CombinerHelper::isIndexedLoadStoreLegal(GLoadStore &LdSt) const { |
1207 | // Check for legality. |
1208 | LLT PtrTy = MRI.getType(Reg: LdSt.getPointerReg()); |
1209 | LLT Ty = MRI.getType(Reg: LdSt.getReg(Idx: 0)); |
1210 | LLT MemTy = LdSt.getMMO().getMemoryType(); |
1211 | SmallVector<LegalityQuery::MemDesc, 2> MemDescrs( |
1212 | {{MemTy, MemTy.getSizeInBits().getKnownMinValue(), |
1213 | AtomicOrdering::NotAtomic}}); |
1214 | unsigned IndexedOpc = getIndexedOpc(LdStOpc: LdSt.getOpcode()); |
1215 | SmallVector<LLT> OpTys; |
1216 | if (IndexedOpc == TargetOpcode::G_INDEXED_STORE) |
1217 | OpTys = {PtrTy, Ty, Ty}; |
1218 | else |
1219 | OpTys = {Ty, PtrTy}; // For G_INDEXED_LOAD, G_INDEXED_[SZ]EXTLOAD |
1220 | |
1221 | LegalityQuery Q(IndexedOpc, OpTys, MemDescrs); |
1222 | return isLegal(Query: Q); |
1223 | } |
1224 | |
1225 | static cl::opt<unsigned> PostIndexUseThreshold( |
1226 | "post-index-use-threshold" , cl::Hidden, cl::init(Val: 32), |
1227 | cl::desc("Number of uses of a base pointer to check before it is no longer " |
1228 | "considered for post-indexing." )); |
1229 | |
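/// Look for a G_PTR_ADD user of \p LdSt's pointer that could be folded into a
/// post-indexed access. On success, fill in the combined address, base and
/// offset registers, and report whether the offset constant must be
/// rematerialized next to \p LdSt because its definition doesn't dominate it.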
1230 | bool CombinerHelper::findPostIndexCandidate(GLoadStore &LdSt, Register &Addr, |
1231 | Register &Base, Register &Offset, |
1232 | bool &RematOffset) const { |
1233 | // We're looking for the following pattern, for either load or store: |
1234 | // %baseptr:_(p0) = ... |
1235 | // G_STORE %val(s64), %baseptr(p0) |
1236 | // %offset:_(s64) = G_CONSTANT i64 -256 |
1237 | // %new_addr:_(p0) = G_PTR_ADD %baseptr, %offset(s64) |
1238 | const auto &TLI = getTargetLowering(); |
1239 | |
1240 | Register Ptr = LdSt.getPointerReg(); |
1241 | // If the store is the only use, don't bother. |
1242 | if (MRI.hasOneNonDBGUse(RegNo: Ptr)) |
1243 | return false; |
1244 | |
1245 | if (!isIndexedLoadStoreLegal(LdSt)) |
1246 | return false; |
1247 | |
1248 | if (getOpcodeDef(Opcode: TargetOpcode::G_FRAME_INDEX, Reg: Ptr, MRI)) |
1249 | return false; |
1250 | |
1251 | MachineInstr *StoredValDef = getDefIgnoringCopies(Reg: LdSt.getReg(Idx: 0), MRI); |
1252 | auto *PtrDef = MRI.getVRegDef(Reg: Ptr); |
1253 | |
1254 | unsigned NumUsesChecked = 0; |
1255 | for (auto &Use : MRI.use_nodbg_instructions(Reg: Ptr)) { |
1256 | if (++NumUsesChecked > PostIndexUseThreshold) |
1257 | return false; // Try to avoid exploding compile time. |
1258 | |
1259 | auto *PtrAdd = dyn_cast<GPtrAdd>(Val: &Use); |
1260 | // The use itself might be dead. This can happen during combines if DCE |
1261 | // hasn't had a chance to run yet. Don't allow it to form an indexed op. |
1262 | if (!PtrAdd || MRI.use_nodbg_empty(RegNo: PtrAdd->getReg(Idx: 0))) |
1263 | continue; |
1264 | |
    // Check that the user of this isn't the store, otherwise we'd be
    // generating an indexed store defining its own use.
1267 | if (StoredValDef == &Use) |
1268 | continue; |
1269 | |
1270 | Offset = PtrAdd->getOffsetReg(); |
1271 | if (!ForceLegalIndexing && |
1272 | !TLI.isIndexingLegal(MI&: LdSt, Base: PtrAdd->getBaseReg(), Offset, |
1273 | /*IsPre*/ false, MRI)) |
1274 | continue; |
1275 | |
1276 | // Make sure the offset calculation is before the potentially indexed op. |
1277 | MachineInstr *OffsetDef = MRI.getVRegDef(Reg: Offset); |
1278 | RematOffset = false; |
1279 | if (!dominates(DefMI: *OffsetDef, UseMI: LdSt)) { |
1280 | // If the offset however is just a G_CONSTANT, we can always just |
1281 | // rematerialize it where we need it. |
1282 | if (OffsetDef->getOpcode() != TargetOpcode::G_CONSTANT) |
1283 | continue; |
1284 | RematOffset = true; |
1285 | } |
1286 | |
1287 | for (auto &BasePtrUse : MRI.use_nodbg_instructions(Reg: PtrAdd->getBaseReg())) { |
1288 | if (&BasePtrUse == PtrDef) |
1289 | continue; |
1290 | |
1291 | // If the user is a later load/store that can be post-indexed, then don't |
1292 | // combine this one. |
1293 | auto *BasePtrLdSt = dyn_cast<GLoadStore>(Val: &BasePtrUse); |
1294 | if (BasePtrLdSt && BasePtrLdSt != &LdSt && |
1295 | dominates(DefMI: LdSt, UseMI: *BasePtrLdSt) && |
1296 | isIndexedLoadStoreLegal(LdSt&: *BasePtrLdSt)) |
1297 | return false; |
1298 | |
1299 | // Now we're looking for the key G_PTR_ADD instruction, which contains |
1300 | // the offset add that we want to fold. |
1301 | if (auto *BasePtrUseDef = dyn_cast<GPtrAdd>(Val: &BasePtrUse)) { |
1302 | Register PtrAddDefReg = BasePtrUseDef->getReg(Idx: 0); |
1303 | for (auto &BaseUseUse : MRI.use_nodbg_instructions(Reg: PtrAddDefReg)) { |
1304 | // If the use is in a different block, then we may produce worse code |
1305 | // due to the extra register pressure. |
1306 | if (BaseUseUse.getParent() != LdSt.getParent()) |
1307 | return false; |
1308 | |
1309 | if (auto *UseUseLdSt = dyn_cast<GLoadStore>(Val: &BaseUseUse)) |
1310 | if (canFoldInAddressingMode(MI: UseUseLdSt, TLI, MRI)) |
1311 | return false; |
1312 | } |
        if (!dominates(LdSt, BasePtrUse))
          return false; // All uses must be dominated by the load/store.
1315 | } |
1316 | } |
1317 | |
1318 | Addr = PtrAdd->getReg(Idx: 0); |
1319 | Base = PtrAdd->getBaseReg(); |
1320 | return true; |
1321 | } |
1322 | |
1323 | return false; |
1324 | } |
1325 | |
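/// Check whether \p LdSt addresses through a G_PTR_ADD that could be folded
/// into a pre-indexed access which also produces the incremented pointer,
/// returning the base and offset registers on success.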
1326 | bool CombinerHelper::findPreIndexCandidate(GLoadStore &LdSt, Register &Addr, |
1327 | Register &Base, |
1328 | Register &Offset) const { |
1329 | auto &MF = *LdSt.getParent()->getParent(); |
1330 | const auto &TLI = *MF.getSubtarget().getTargetLowering(); |
1331 | |
1332 | Addr = LdSt.getPointerReg(); |
1333 | if (!mi_match(R: Addr, MRI, P: m_GPtrAdd(L: m_Reg(R&: Base), R: m_Reg(R&: Offset))) || |
1334 | MRI.hasOneNonDBGUse(RegNo: Addr)) |
1335 | return false; |
1336 | |
1337 | if (!ForceLegalIndexing && |
1338 | !TLI.isIndexingLegal(MI&: LdSt, Base, Offset, /*IsPre*/ true, MRI)) |
1339 | return false; |
1340 | |
1341 | if (!isIndexedLoadStoreLegal(LdSt)) |
1342 | return false; |
1343 | |
1344 | MachineInstr *BaseDef = getDefIgnoringCopies(Reg: Base, MRI); |
1345 | if (BaseDef->getOpcode() == TargetOpcode::G_FRAME_INDEX) |
1346 | return false; |
1347 | |
1348 | if (auto *St = dyn_cast<GStore>(Val: &LdSt)) { |
1349 | // Would require a copy. |
1350 | if (Base == St->getValueReg()) |
1351 | return false; |
1352 | |
1353 | // We're expecting one use of Addr in MI, but it could also be the |
1354 | // value stored, which isn't actually dominated by the instruction. |
1355 | if (St->getValueReg() == Addr) |
1356 | return false; |
1357 | } |
1358 | |
1359 | // Avoid increasing cross-block register pressure. |
1360 | for (auto &AddrUse : MRI.use_nodbg_instructions(Reg: Addr)) |
1361 | if (AddrUse.getParent() != LdSt.getParent()) |
1362 | return false; |
1363 | |
1364 | // FIXME: check whether all uses of the base pointer are constant PtrAdds. |
1365 | // That might allow us to end base's liveness here by adjusting the constant. |
1366 | bool RealUse = false; |
1367 | for (auto &AddrUse : MRI.use_nodbg_instructions(Reg: Addr)) { |
1368 | if (!dominates(DefMI: LdSt, UseMI: AddrUse)) |
1369 | return false; // All uses must be dominated by the load/store.
1370 | |
1371 | // If Ptr may be folded into the addressing mode of another use, then it's
1372 | // not profitable to do this transformation.
1373 | if (auto *UseLdSt = dyn_cast<GLoadStore>(Val: &AddrUse)) { |
1374 | if (!canFoldInAddressingMode(MI: UseLdSt, TLI, MRI)) |
1375 | RealUse = true; |
1376 | } else { |
1377 | RealUse = true; |
1378 | } |
1379 | } |
1380 | return RealUse; |
1381 | } |
1382 | |
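// Narrow a load feeding a G_EXTRACT_VECTOR_ELT so that only the extracted
// element is loaded. Illustrative MIR (types are examples only):
//   %vec:_(<4 x s32>) = G_LOAD %ptr(p0)
//   %elt:_(s32) = G_EXTRACT_VECTOR_ELT %vec, %idx
// becomes a scalar load from %ptr advanced by the element's offset:
//   %elt:_(s32) = G_LOAD %eltptr(p0)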
1383 | bool CombinerHelper::matchCombineExtractedVectorLoad(
1384 | MachineInstr &MI, BuildFnTy &MatchInfo) const { |
1385 | assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT); |
1386 | |
1387 | // Check if there is a load that defines the vector being extracted from. |
1388 | auto *LoadMI = getOpcodeDef<GLoad>(Reg: MI.getOperand(i: 1).getReg(), MRI); |
1389 | if (!LoadMI) |
1390 | return false; |
1391 | |
1392 | Register Vector = MI.getOperand(i: 1).getReg(); |
1393 | LLT VecEltTy = MRI.getType(Reg: Vector).getElementType(); |
1394 | |
1395 | assert(MRI.getType(MI.getOperand(0).getReg()) == VecEltTy); |
1396 | |
1397 | // Only reduce the load width when the extract is the sole user of the vector.
1398 | if (!MRI.hasOneNonDBGUse(RegNo: Vector)) |
1399 | return false; |
1400 | |
1401 | // Check if the defining load is simple. |
1402 | if (!LoadMI->isSimple()) |
1403 | return false; |
1404 | |
1405 | // If the vector element type is not a multiple of a byte then we are unable |
1406 | // to correctly compute an address to load only the extracted element as a |
1407 | // scalar. |
1408 | if (!VecEltTy.isByteSized()) |
1409 | return false; |
1410 | |
1411 | // Check for load fold barriers between the extraction and the load. |
1412 | if (MI.getParent() != LoadMI->getParent()) |
1413 | return false; |
1414 | const unsigned MaxIter = 20; |
1415 | unsigned Iter = 0; |
1416 | for (auto II = LoadMI->getIterator(), IE = MI.getIterator(); II != IE; ++II) { |
1417 | if (II->isLoadFoldBarrier()) |
1418 | return false; |
1419 | if (Iter++ == MaxIter) |
1420 | return false; |
1421 | } |
1422 | |
1423 | // Check if the new load that we are going to create is legal |
1424 | // if we are in the post-legalization phase. |
1425 | MachineMemOperand MMO = LoadMI->getMMO(); |
1426 | Align Alignment = MMO.getAlign(); |
1427 | MachinePointerInfo PtrInfo; |
1428 | uint64_t Offset; |
1429 | |
1430 | // Find the appropriate PtrInfo if the index, and therefore the offset, is a
1431 | // known constant. This is required to create the memory operand for the
1432 | // narrowed load. This machine memory operand object helps us infer legality
1433 | // before we proceed to combine the instruction.
1434 | if (auto CVal = getIConstantVRegVal(VReg: MI.getOperand(i: 2).getReg(), MRI)) {
1435 | int Elt = CVal->getZExtValue(); |
1436 | // FIXME: should be (ABI size)*Elt. |
1437 | Offset = VecEltTy.getSizeInBits() * Elt / 8; |
1438 | PtrInfo = MMO.getPointerInfo().getWithOffset(O: Offset); |
1439 | } else { |
1440 | // Discard the pointer info except the address space because the memory |
1441 | // operand can't represent this new access since the offset is variable. |
1442 | Offset = VecEltTy.getSizeInBits() / 8; |
1443 | PtrInfo = MachinePointerInfo(MMO.getPointerInfo().getAddrSpace()); |
1444 | } |
1445 | |
1446 | Alignment = commonAlignment(A: Alignment, Offset); |
1447 | |
1448 | Register VecPtr = LoadMI->getPointerReg(); |
1449 | LLT PtrTy = MRI.getType(Reg: VecPtr); |
1450 | |
1451 | MachineFunction &MF = *MI.getMF(); |
1452 | auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Ty: VecEltTy); |
1453 | |
1454 | LegalityQuery::MemDesc MMDesc(*NewMMO); |
1455 | |
1456 | if (!isLegalOrBeforeLegalizer( |
1457 | Query: {TargetOpcode::G_LOAD, {VecEltTy, PtrTy}, {MMDesc}})) |
1458 | return false; |
1459 | |
1460 | // Load must be allowed and fast on the target. |
1461 | LLVMContext &C = MF.getFunction().getContext(); |
1462 | auto &DL = MF.getDataLayout(); |
1463 | unsigned Fast = 0; |
1464 | if (!getTargetLowering().allowsMemoryAccess(Context&: C, DL, Ty: VecEltTy, MMO: *NewMMO, |
1465 | Fast: &Fast) || |
1466 | !Fast) |
1467 | return false; |
1468 | |
1469 | Register Result = MI.getOperand(i: 0).getReg(); |
1470 | Register Index = MI.getOperand(i: 2).getReg(); |
1471 | |
1472 | MatchInfo = [=](MachineIRBuilder &B) { |
1473 | GISelObserverWrapper DummyObserver; |
1474 | LegalizerHelper Helper(B.getMF(), DummyObserver, B); |
1475 | // Get a pointer to the vector element.
1476 | Register FinalPtr = Helper.getVectorElementPointer(
1477 | VecPtr: LoadMI->getPointerReg(), VecTy: MRI.getType(Reg: LoadMI->getOperand(i: 0).getReg()),
1478 | Index);
1479 | // Build the new scalar G_LOAD.
1480 | B.buildLoad(Res: Result, Addr: FinalPtr, PtrInfo, Alignment);
1481 | // Remove the original G_LOAD instruction.
1482 | LoadMI->eraseFromParent(); |
1483 | }; |
1484 | |
1485 | return true; |
1486 | } |
1487 | |
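// Try to form a pre- or post-indexed load/store, folding a pointer update into
// the memory operation itself. Roughly (illustrative MIR; the resulting
// G_INDEXED_* opcode and its legality are target-dependent):
//   %new:_(p0) = G_PTR_ADD %base, %offset
//   %val:_(s32) = G_LOAD %new(p0)          ; pre-indexed candidate
// or
//   %val:_(s32) = G_LOAD %base(p0)
//   %new:_(p0) = G_PTR_ADD %base, %offset  ; post-indexed candidate
// where the combined instruction produces both %val and the updated pointer.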
1488 | bool CombinerHelper::matchCombineIndexedLoadStore( |
1489 | MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) const { |
1490 | auto &LdSt = cast<GLoadStore>(Val&: MI); |
1491 | |
1492 | if (LdSt.isAtomic()) |
1493 | return false; |
1494 | |
1495 | MatchInfo.IsPre = findPreIndexCandidate(LdSt, Addr&: MatchInfo.Addr, Base&: MatchInfo.Base, |
1496 | Offset&: MatchInfo.Offset); |
1497 | if (!MatchInfo.IsPre && |
1498 | !findPostIndexCandidate(LdSt, Addr&: MatchInfo.Addr, Base&: MatchInfo.Base, |
1499 | Offset&: MatchInfo.Offset, RematOffset&: MatchInfo.RematOffset)) |
1500 | return false; |
1501 | |
1502 | return true; |
1503 | } |
1504 | |
1505 | void CombinerHelper::applyCombineIndexedLoadStore( |
1506 | MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) const { |
1507 | MachineInstr &AddrDef = *MRI.getUniqueVRegDef(Reg: MatchInfo.Addr); |
1508 | unsigned Opcode = MI.getOpcode(); |
1509 | bool IsStore = Opcode == TargetOpcode::G_STORE; |
1510 | unsigned NewOpcode = getIndexedOpc(LdStOpc: Opcode); |
1511 | |
1512 | // If the offset constant didn't happen to dominate the load/store, we can |
1513 | // just clone it as needed. |
1514 | if (MatchInfo.RematOffset) { |
1515 | auto *OldCst = MRI.getVRegDef(Reg: MatchInfo.Offset); |
1516 | auto NewCst = Builder.buildConstant(Res: MRI.getType(Reg: MatchInfo.Offset), |
1517 | Val: *OldCst->getOperand(i: 1).getCImm()); |
1518 | MatchInfo.Offset = NewCst.getReg(Idx: 0); |
1519 | } |
1520 | |
1521 | auto MIB = Builder.buildInstr(Opcode: NewOpcode); |
1522 | if (IsStore) { |
1523 | MIB.addDef(RegNo: MatchInfo.Addr); |
1524 | MIB.addUse(RegNo: MI.getOperand(i: 0).getReg()); |
1525 | } else { |
1526 | MIB.addDef(RegNo: MI.getOperand(i: 0).getReg()); |
1527 | MIB.addDef(RegNo: MatchInfo.Addr); |
1528 | } |
1529 | |
1530 | MIB.addUse(RegNo: MatchInfo.Base); |
1531 | MIB.addUse(RegNo: MatchInfo.Offset); |
1532 | MIB.addImm(Val: MatchInfo.IsPre); |
1533 | MIB->cloneMemRefs(MF&: *MI.getMF(), MI); |
1534 | MI.eraseFromParent(); |
1535 | AddrDef.eraseFromParent(); |
1536 | |
1537 | LLVM_DEBUG(dbgs() << "  Combined to indexed operation");
1538 | } |
1539 | |
1540 | bool CombinerHelper::matchCombineDivRem(MachineInstr &MI, |
1541 | MachineInstr *&OtherMI) const { |
1542 | unsigned Opcode = MI.getOpcode(); |
1543 | bool IsDiv, IsSigned; |
1544 | |
1545 | switch (Opcode) { |
1546 | default: |
1547 | llvm_unreachable("Unexpected opcode!" ); |
1548 | case TargetOpcode::G_SDIV: |
1549 | case TargetOpcode::G_UDIV: { |
1550 | IsDiv = true; |
1551 | IsSigned = Opcode == TargetOpcode::G_SDIV; |
1552 | break; |
1553 | } |
1554 | case TargetOpcode::G_SREM: |
1555 | case TargetOpcode::G_UREM: { |
1556 | IsDiv = false; |
1557 | IsSigned = Opcode == TargetOpcode::G_SREM; |
1558 | break; |
1559 | } |
1560 | } |
1561 | |
1562 | Register Src1 = MI.getOperand(i: 1).getReg(); |
1563 | unsigned DivOpcode, RemOpcode, DivremOpcode; |
1564 | if (IsSigned) { |
1565 | DivOpcode = TargetOpcode::G_SDIV; |
1566 | RemOpcode = TargetOpcode::G_SREM; |
1567 | DivremOpcode = TargetOpcode::G_SDIVREM; |
1568 | } else { |
1569 | DivOpcode = TargetOpcode::G_UDIV; |
1570 | RemOpcode = TargetOpcode::G_UREM; |
1571 | DivremOpcode = TargetOpcode::G_UDIVREM; |
1572 | } |
1573 | |
1574 | if (!isLegalOrBeforeLegalizer(Query: {DivremOpcode, {MRI.getType(Reg: Src1)}})) |
1575 | return false; |
1576 | |
1577 | // Combine: |
1578 | // %div:_ = G_[SU]DIV %src1:_, %src2:_ |
1579 | // %rem:_ = G_[SU]REM %src1:_, %src2:_ |
1580 | // into: |
1581 | // %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_ |
1582 | |
1583 | // Combine: |
1584 | // %rem:_ = G_[SU]REM %src1:_, %src2:_ |
1585 | // %div:_ = G_[SU]DIV %src1:_, %src2:_ |
1586 | // into: |
1587 | // %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_ |
1588 | |
1589 | for (auto &UseMI : MRI.use_nodbg_instructions(Reg: Src1)) { |
1590 | if (MI.getParent() == UseMI.getParent() && |
1591 | ((IsDiv && UseMI.getOpcode() == RemOpcode) || |
1592 | (!IsDiv && UseMI.getOpcode() == DivOpcode)) && |
1593 | matchEqualDefs(MOP1: MI.getOperand(i: 2), MOP2: UseMI.getOperand(i: 2)) && |
1594 | matchEqualDefs(MOP1: MI.getOperand(i: 1), MOP2: UseMI.getOperand(i: 1))) { |
1595 | OtherMI = &UseMI; |
1596 | return true; |
1597 | } |
1598 | } |
1599 | |
1600 | return false; |
1601 | } |
1602 | |
1603 | void CombinerHelper::applyCombineDivRem(MachineInstr &MI, |
1604 | MachineInstr *&OtherMI) const { |
1605 | unsigned Opcode = MI.getOpcode(); |
1606 | assert(OtherMI && "OtherMI shouldn't be empty." ); |
1607 | |
1608 | Register DestDivReg, DestRemReg; |
1609 | if (Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_UDIV) { |
1610 | DestDivReg = MI.getOperand(i: 0).getReg(); |
1611 | DestRemReg = OtherMI->getOperand(i: 0).getReg(); |
1612 | } else { |
1613 | DestDivReg = OtherMI->getOperand(i: 0).getReg(); |
1614 | DestRemReg = MI.getOperand(i: 0).getReg(); |
1615 | } |
1616 | |
1617 | bool IsSigned = |
1618 | Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_SREM; |
1619 | |
1620 | // Check which instruction is first in the block so we don't break def-use |
1621 | // deps by "moving" the instruction incorrectly. Also keep track of which |
1622 | // instruction is first so we pick its operands, avoiding use-before-def
1623 | // bugs. |
1624 | MachineInstr *FirstInst = dominates(DefMI: MI, UseMI: *OtherMI) ? &MI : OtherMI; |
1625 | Builder.setInstrAndDebugLoc(*FirstInst); |
1626 | |
1627 | Builder.buildInstr(Opc: IsSigned ? TargetOpcode::G_SDIVREM |
1628 | : TargetOpcode::G_UDIVREM, |
1629 | DstOps: {DestDivReg, DestRemReg}, |
1630 | SrcOps: {FirstInst->getOperand(i: 1), FirstInst->getOperand(i: 2)});
1631 | MI.eraseFromParent(); |
1632 | OtherMI->eraseFromParent(); |
1633 | } |
1634 | |
1635 | bool CombinerHelper::matchOptBrCondByInvertingCond( |
1636 | MachineInstr &MI, MachineInstr *&BrCond) const { |
1637 | assert(MI.getOpcode() == TargetOpcode::G_BR); |
1638 | |
1639 | // Try to match the following: |
1640 | // bb1: |
1641 | // G_BRCOND %c1, %bb2 |
1642 | // G_BR %bb3 |
1643 | // bb2: |
1644 | // ... |
1645 | // bb3: |
1646 | |
1647 | // The above pattern does not have a fallthrough to the successor bb2, always
1648 | // resulting in a branch no matter which path is taken. Here we try to find
1649 | // and replace that pattern with a conditional branch to bb3 and otherwise a
1650 | // fallthrough to bb2. This is generally better for branch predictors.
1651 | |
1652 | MachineBasicBlock *MBB = MI.getParent(); |
1653 | MachineBasicBlock::iterator BrIt(MI); |
1654 | if (BrIt == MBB->begin()) |
1655 | return false; |
1656 | assert(std::next(BrIt) == MBB->end() && "expected G_BR to be a terminator" ); |
1657 | |
1658 | BrCond = &*std::prev(x: BrIt); |
1659 | if (BrCond->getOpcode() != TargetOpcode::G_BRCOND) |
1660 | return false; |
1661 | |
1662 | // Check that the next block is the conditional branch target. Also make sure |
1663 | // that it isn't the same as the G_BR's target (otherwise, this will loop.) |
1664 | MachineBasicBlock *BrCondTarget = BrCond->getOperand(i: 1).getMBB(); |
1665 | return BrCondTarget != MI.getOperand(i: 0).getMBB() && |
1666 | MBB->isLayoutSuccessor(MBB: BrCondTarget); |
1667 | } |
1668 | |
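// The rewrite below inverts the condition by XOR'ing it with the target's
// integer-compare "true" value and swaps the two branch targets, e.g.
// (illustrative, with a true value of 1):
//   G_BRCOND %c, %bb2
//   G_BR %bb3
// becomes
//   %inv:_(s1) = G_XOR %c, 1
//   G_BRCOND %inv, %bb3
//   G_BR %bb2        ; bb2 is now the layout successor, reached by fallthrough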
1669 | void CombinerHelper::applyOptBrCondByInvertingCond( |
1670 | MachineInstr &MI, MachineInstr *&BrCond) const { |
1671 | MachineBasicBlock *BrTarget = MI.getOperand(i: 0).getMBB(); |
1672 | Builder.setInstrAndDebugLoc(*BrCond); |
1673 | LLT Ty = MRI.getType(Reg: BrCond->getOperand(i: 0).getReg()); |
1674 | // FIXME: Does int/fp matter for this? If so, we might need to restrict |
1675 | // this to i1 only since we might not know for sure what kind of |
1676 | // compare generated the condition value. |
1677 | auto True = Builder.buildConstant( |
1678 | Res: Ty, Val: getICmpTrueVal(TLI: getTargetLowering(), IsVector: false, IsFP: false)); |
1679 | auto Xor = Builder.buildXor(Dst: Ty, Src0: BrCond->getOperand(i: 0), Src1: True); |
1680 | |
1681 | auto *FallthroughBB = BrCond->getOperand(i: 1).getMBB(); |
1682 | Observer.changingInstr(MI); |
1683 | MI.getOperand(i: 0).setMBB(FallthroughBB); |
1684 | Observer.changedInstr(MI); |
1685 | |
1686 | // Change the conditional branch to use the inverted condition and |
1687 | // new target block. |
1688 | Observer.changingInstr(MI&: *BrCond); |
1689 | BrCond->getOperand(i: 0).setReg(Xor.getReg(Idx: 0)); |
1690 | BrCond->getOperand(i: 1).setMBB(BrTarget); |
1691 | Observer.changedInstr(MI&: *BrCond); |
1692 | } |
1693 | |
1694 | bool CombinerHelper::tryEmitMemcpyInline(MachineInstr &MI) const { |
1695 | MachineIRBuilder HelperBuilder(MI); |
1696 | GISelObserverWrapper DummyObserver; |
1697 | LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder); |
1698 | return Helper.lowerMemcpyInline(MI) == |
1699 | LegalizerHelper::LegalizeResult::Legalized; |
1700 | } |
1701 | |
1702 | bool CombinerHelper::tryCombineMemCpyFamily(MachineInstr &MI, |
1703 | unsigned MaxLen) const { |
1704 | MachineIRBuilder HelperBuilder(MI); |
1705 | GISelObserverWrapper DummyObserver; |
1706 | LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder); |
1707 | return Helper.lowerMemCpyFamily(MI, MaxLen) == |
1708 | LegalizerHelper::LegalizeResult::Legalized; |
1709 | } |
1710 | |
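// Constant-fold a unary FP operation on a constant operand. For example,
// G_FNEG of 2.0 folds to -2.0 and G_FABS of -2.0 folds to 2.0; G_FSQRT and
// G_FLOG2 are evaluated in double precision and then converted back to the
// original semantics.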
1711 | static APFloat constantFoldFpUnary(const MachineInstr &MI, |
1712 | const MachineRegisterInfo &MRI, |
1713 | const APFloat &Val) { |
1714 | APFloat Result(Val); |
1715 | switch (MI.getOpcode()) { |
1716 | default: |
1717 | llvm_unreachable("Unexpected opcode!" ); |
1718 | case TargetOpcode::G_FNEG: { |
1719 | Result.changeSign(); |
1720 | return Result; |
1721 | } |
1722 | case TargetOpcode::G_FABS: { |
1723 | Result.clearSign(); |
1724 | return Result; |
1725 | } |
1726 | case TargetOpcode::G_FPTRUNC: { |
1727 | bool Unused; |
1728 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
1729 | Result.convert(ToSemantics: getFltSemanticForLLT(Ty: DstTy), RM: APFloat::rmNearestTiesToEven, |
1730 | losesInfo: &Unused); |
1731 | return Result; |
1732 | } |
1733 | case TargetOpcode::G_FSQRT: { |
1734 | bool Unused; |
1735 | Result.convert(ToSemantics: APFloat::IEEEdouble(), RM: APFloat::rmNearestTiesToEven, |
1736 | losesInfo: &Unused); |
1737 | Result = APFloat(sqrt(x: Result.convertToDouble())); |
1738 | break; |
1739 | } |
1740 | case TargetOpcode::G_FLOG2: { |
1741 | bool Unused; |
1742 | Result.convert(ToSemantics: APFloat::IEEEdouble(), RM: APFloat::rmNearestTiesToEven, |
1743 | losesInfo: &Unused); |
1744 | Result = APFloat(log2(x: Result.convertToDouble())); |
1745 | break; |
1746 | } |
1747 | } |
1748 | // Convert the `APFloat` back to the semantics of the original value;
1749 | // otherwise `buildFConstant` will assert on a size mismatch. Only `G_FSQRT`
1750 | // and `G_FLOG2` reach here.
1751 | bool Unused; |
1752 | Result.convert(ToSemantics: Val.getSemantics(), RM: APFloat::rmNearestTiesToEven, losesInfo: &Unused); |
1753 | return Result; |
1754 | } |
1755 | |
1756 | void CombinerHelper::applyCombineConstantFoldFpUnary( |
1757 | MachineInstr &MI, const ConstantFP *Cst) const { |
1758 | APFloat Folded = constantFoldFpUnary(MI, MRI, Val: Cst->getValue()); |
1759 | const ConstantFP *NewCst = ConstantFP::get(Context&: Builder.getContext(), V: Folded); |
1760 | Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: *NewCst); |
1761 | MI.eraseFromParent(); |
1762 | } |
1763 | |
1764 | bool CombinerHelper::matchPtrAddImmedChain(MachineInstr &MI, |
1765 | PtrAddChain &MatchInfo) const { |
1766 | // We're trying to match the following pattern: |
1767 | // %t1 = G_PTR_ADD %base, G_CONSTANT imm1 |
1768 | // %root = G_PTR_ADD %t1, G_CONSTANT imm2 |
1769 | // --> |
1770 | // %root = G_PTR_ADD %base, G_CONSTANT (imm1 + imm2) |
1771 | |
1772 | if (MI.getOpcode() != TargetOpcode::G_PTR_ADD) |
1773 | return false; |
1774 | |
1775 | Register Add2 = MI.getOperand(i: 1).getReg(); |
1776 | Register Imm1 = MI.getOperand(i: 2).getReg(); |
1777 | auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: Imm1, MRI); |
1778 | if (!MaybeImmVal) |
1779 | return false; |
1780 | |
1781 | MachineInstr *Add2Def = MRI.getVRegDef(Reg: Add2); |
1782 | if (!Add2Def || Add2Def->getOpcode() != TargetOpcode::G_PTR_ADD) |
1783 | return false; |
1784 | |
1785 | Register Base = Add2Def->getOperand(i: 1).getReg(); |
1786 | Register Imm2 = Add2Def->getOperand(i: 2).getReg(); |
1787 | auto MaybeImm2Val = getIConstantVRegValWithLookThrough(VReg: Imm2, MRI); |
1788 | if (!MaybeImm2Val) |
1789 | return false; |
1790 | |
1791 | // Check if the new combined immediate forms an illegal addressing mode. |
1792 | // Do not combine if it was legal before but would get illegal. |
1793 | // To do so, we need to find a load/store user of the pointer to get |
1794 | // the access type. |
1795 | Type *AccessTy = nullptr; |
1796 | auto &MF = *MI.getMF(); |
1797 | for (auto &UseMI : MRI.use_nodbg_instructions(Reg: MI.getOperand(i: 0).getReg())) { |
1798 | if (auto *LdSt = dyn_cast<GLoadStore>(Val: &UseMI)) { |
1799 | AccessTy = getTypeForLLT(Ty: MRI.getType(Reg: LdSt->getReg(Idx: 0)), |
1800 | C&: MF.getFunction().getContext()); |
1801 | break; |
1802 | } |
1803 | } |
1804 | TargetLoweringBase::AddrMode AMNew; |
1805 | APInt CombinedImm = MaybeImmVal->Value + MaybeImm2Val->Value; |
1806 | AMNew.BaseOffs = CombinedImm.getSExtValue(); |
1807 | if (AccessTy) { |
1808 | AMNew.HasBaseReg = true; |
1809 | TargetLoweringBase::AddrMode AMOld; |
1810 | AMOld.BaseOffs = MaybeImmVal->Value.getSExtValue(); |
1811 | AMOld.HasBaseReg = true; |
1812 | unsigned AS = MRI.getType(Reg: Add2).getAddressSpace(); |
1813 | const auto &TLI = *MF.getSubtarget().getTargetLowering(); |
1814 | if (TLI.isLegalAddressingMode(DL: MF.getDataLayout(), AM: AMOld, Ty: AccessTy, AddrSpace: AS) && |
1815 | !TLI.isLegalAddressingMode(DL: MF.getDataLayout(), AM: AMNew, Ty: AccessTy, AddrSpace: AS)) |
1816 | return false; |
1817 | } |
1818 | |
1819 | // Pass the combined immediate to the apply function. |
1820 | MatchInfo.Imm = AMNew.BaseOffs; |
1821 | MatchInfo.Base = Base; |
1822 | MatchInfo.Bank = getRegBank(Reg: Imm2); |
1823 | return true; |
1824 | } |
1825 | |
1826 | void CombinerHelper::applyPtrAddImmedChain(MachineInstr &MI, |
1827 | PtrAddChain &MatchInfo) const { |
1828 | assert(MI.getOpcode() == TargetOpcode::G_PTR_ADD && "Expected G_PTR_ADD" ); |
1829 | MachineIRBuilder MIB(MI); |
1830 | LLT OffsetTy = MRI.getType(Reg: MI.getOperand(i: 2).getReg()); |
1831 | auto NewOffset = MIB.buildConstant(Res: OffsetTy, Val: MatchInfo.Imm); |
1832 | setRegBank(Reg: NewOffset.getReg(Idx: 0), RegBank: MatchInfo.Bank); |
1833 | Observer.changingInstr(MI); |
1834 | MI.getOperand(i: 1).setReg(MatchInfo.Base); |
1835 | MI.getOperand(i: 2).setReg(NewOffset.getReg(Idx: 0)); |
1836 | Observer.changedInstr(MI); |
1837 | } |
1838 | |
1839 | bool CombinerHelper::matchShiftImmedChain(MachineInstr &MI, |
1840 | RegisterImmPair &MatchInfo) const { |
1841 | // We're trying to match the following pattern with any of |
1842 | // G_SHL/G_ASHR/G_LSHR/G_SSHLSAT/G_USHLSAT shift instructions: |
1843 | // %t1 = SHIFT %base, G_CONSTANT imm1 |
1844 | // %root = SHIFT %t1, G_CONSTANT imm2 |
1845 | // --> |
1846 | // %root = SHIFT %base, G_CONSTANT (imm1 + imm2) |
1847 | |
1848 | unsigned Opcode = MI.getOpcode(); |
1849 | assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR || |
1850 | Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT || |
1851 | Opcode == TargetOpcode::G_USHLSAT) && |
1852 | "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT" ); |
1853 | |
1854 | Register Shl2 = MI.getOperand(i: 1).getReg(); |
1855 | Register Imm1 = MI.getOperand(i: 2).getReg(); |
1856 | auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: Imm1, MRI); |
1857 | if (!MaybeImmVal) |
1858 | return false; |
1859 | |
1860 | MachineInstr *Shl2Def = MRI.getUniqueVRegDef(Reg: Shl2); |
1861 | if (Shl2Def->getOpcode() != Opcode) |
1862 | return false; |
1863 | |
1864 | Register Base = Shl2Def->getOperand(i: 1).getReg(); |
1865 | Register Imm2 = Shl2Def->getOperand(i: 2).getReg(); |
1866 | auto MaybeImm2Val = getIConstantVRegValWithLookThrough(VReg: Imm2, MRI); |
1867 | if (!MaybeImm2Val) |
1868 | return false; |
1869 | |
1870 | // Pass the combined immediate to the apply function. |
1871 | MatchInfo.Imm = |
1872 | (MaybeImmVal->Value.getZExtValue() + MaybeImm2Val->Value).getZExtValue(); |
1873 | MatchInfo.Reg = Base; |
1874 | |
1875 | // There is no simple replacement for a saturating unsigned left shift that |
1876 | // exceeds the scalar size. |
1877 | if (Opcode == TargetOpcode::G_USHLSAT && |
1878 | MatchInfo.Imm >= MRI.getType(Reg: Shl2).getScalarSizeInBits()) |
1879 | return false; |
1880 | |
1881 | return true; |
1882 | } |
1883 | |
1884 | void CombinerHelper::applyShiftImmedChain(MachineInstr &MI, |
1885 | RegisterImmPair &MatchInfo) const { |
1886 | unsigned Opcode = MI.getOpcode(); |
1887 | assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR || |
1888 | Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT || |
1889 | Opcode == TargetOpcode::G_USHLSAT) && |
1890 | "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT" ); |
1891 | |
1892 | LLT Ty = MRI.getType(Reg: MI.getOperand(i: 1).getReg()); |
1893 | unsigned const ScalarSizeInBits = Ty.getScalarSizeInBits(); |
1894 | auto Imm = MatchInfo.Imm; |
1895 | |
1896 | if (Imm >= ScalarSizeInBits) { |
1897 | // Any logical shift that exceeds scalar size will produce zero. |
1898 | if (Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_LSHR) { |
1899 | Builder.buildConstant(Res: MI.getOperand(i: 0), Val: 0); |
1900 | MI.eraseFromParent(); |
1901 | return; |
1902 | } |
1903 | // Arithmetic shift and saturating signed left shift have no effect beyond |
1904 | // scalar size. |
1905 | Imm = ScalarSizeInBits - 1; |
1906 | } |
1907 | |
1908 | LLT ImmTy = MRI.getType(Reg: MI.getOperand(i: 2).getReg()); |
1909 | Register NewImm = Builder.buildConstant(Res: ImmTy, Val: Imm).getReg(Idx: 0); |
1910 | Observer.changingInstr(MI); |
1911 | MI.getOperand(i: 1).setReg(MatchInfo.Reg); |
1912 | MI.getOperand(i: 2).setReg(NewImm); |
1913 | Observer.changedInstr(MI); |
1914 | } |
1915 | |
1916 | bool CombinerHelper::matchShiftOfShiftedLogic( |
1917 | MachineInstr &MI, ShiftOfShiftedLogic &MatchInfo) const { |
1918 | // We're trying to match the following pattern with any of |
1919 | // G_SHL/G_ASHR/G_LSHR/G_USHLSAT/G_SSHLSAT shift instructions in combination |
1920 | // with any of G_AND/G_OR/G_XOR logic instructions. |
1921 | // %t1 = SHIFT %X, G_CONSTANT C0 |
1922 | // %t2 = LOGIC %t1, %Y |
1923 | // %root = SHIFT %t2, G_CONSTANT C1 |
1924 | // --> |
1925 | // %t3 = SHIFT %X, G_CONSTANT (C0+C1) |
1926 | // %t4 = SHIFT %Y, G_CONSTANT C1 |
1927 | // %root = LOGIC %t3, %t4 |
1928 | unsigned ShiftOpcode = MI.getOpcode(); |
1929 | assert((ShiftOpcode == TargetOpcode::G_SHL || |
1930 | ShiftOpcode == TargetOpcode::G_ASHR || |
1931 | ShiftOpcode == TargetOpcode::G_LSHR || |
1932 | ShiftOpcode == TargetOpcode::G_USHLSAT || |
1933 | ShiftOpcode == TargetOpcode::G_SSHLSAT) && |
1934 | "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT and G_SSHLSAT" ); |
1935 | |
1936 | // Match a one-use bitwise logic op. |
1937 | Register LogicDest = MI.getOperand(i: 1).getReg(); |
1938 | if (!MRI.hasOneNonDBGUse(RegNo: LogicDest)) |
1939 | return false; |
1940 | |
1941 | MachineInstr *LogicMI = MRI.getUniqueVRegDef(Reg: LogicDest); |
1942 | unsigned LogicOpcode = LogicMI->getOpcode(); |
1943 | if (LogicOpcode != TargetOpcode::G_AND && LogicOpcode != TargetOpcode::G_OR && |
1944 | LogicOpcode != TargetOpcode::G_XOR) |
1945 | return false; |
1946 | |
1947 | // Find a matching one-use shift by constant. |
1948 | const Register C1 = MI.getOperand(i: 2).getReg(); |
1949 | auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: C1, MRI); |
1950 | if (!MaybeImmVal || MaybeImmVal->Value == 0) |
1951 | return false; |
1952 | |
1953 | const uint64_t C1Val = MaybeImmVal->Value.getZExtValue(); |
1954 | |
1955 | auto matchFirstShift = [&](const MachineInstr *MI, uint64_t &ShiftVal) { |
1956 | // Shift should match previous one and should be a one-use. |
1957 | if (MI->getOpcode() != ShiftOpcode || |
1958 | !MRI.hasOneNonDBGUse(RegNo: MI->getOperand(i: 0).getReg())) |
1959 | return false; |
1960 | |
1961 | // Must be a constant. |
1962 | auto MaybeImmVal = |
1963 | getIConstantVRegValWithLookThrough(VReg: MI->getOperand(i: 2).getReg(), MRI); |
1964 | if (!MaybeImmVal) |
1965 | return false; |
1966 | |
1967 | ShiftVal = MaybeImmVal->Value.getSExtValue(); |
1968 | return true; |
1969 | }; |
1970 | |
1971 | // Logic ops are commutative, so check each operand for a match. |
1972 | Register LogicMIReg1 = LogicMI->getOperand(i: 1).getReg(); |
1973 | MachineInstr *LogicMIOp1 = MRI.getUniqueVRegDef(Reg: LogicMIReg1); |
1974 | Register LogicMIReg2 = LogicMI->getOperand(i: 2).getReg(); |
1975 | MachineInstr *LogicMIOp2 = MRI.getUniqueVRegDef(Reg: LogicMIReg2); |
1976 | uint64_t C0Val; |
1977 | |
1978 | if (matchFirstShift(LogicMIOp1, C0Val)) { |
1979 | MatchInfo.LogicNonShiftReg = LogicMIReg2; |
1980 | MatchInfo.Shift2 = LogicMIOp1; |
1981 | } else if (matchFirstShift(LogicMIOp2, C0Val)) { |
1982 | MatchInfo.LogicNonShiftReg = LogicMIReg1; |
1983 | MatchInfo.Shift2 = LogicMIOp2; |
1984 | } else |
1985 | return false; |
1986 | |
1987 | MatchInfo.ValSum = C0Val + C1Val; |
1988 | |
1989 | // The fold is not valid if the sum of the shift values exceeds bitwidth. |
1990 | if (MatchInfo.ValSum >= MRI.getType(Reg: LogicDest).getScalarSizeInBits()) |
1991 | return false; |
1992 | |
1993 | MatchInfo.Logic = LogicMI; |
1994 | return true; |
1995 | } |
1996 | |
1997 | void CombinerHelper::applyShiftOfShiftedLogic( |
1998 | MachineInstr &MI, ShiftOfShiftedLogic &MatchInfo) const { |
1999 | unsigned Opcode = MI.getOpcode(); |
2000 | assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR || |
2001 | Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_USHLSAT || |
2002 | Opcode == TargetOpcode::G_SSHLSAT) && |
2003 | "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT and G_SSHLSAT" ); |
2004 | |
2005 | LLT ShlType = MRI.getType(Reg: MI.getOperand(i: 2).getReg()); |
2006 | LLT DestType = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
2007 | |
2008 | Register Const = Builder.buildConstant(Res: ShlType, Val: MatchInfo.ValSum).getReg(Idx: 0); |
2009 | |
2010 | Register Shift1Base = MatchInfo.Shift2->getOperand(i: 1).getReg(); |
2011 | Register Shift1 = |
2012 | Builder.buildInstr(Opc: Opcode, DstOps: {DestType}, SrcOps: {Shift1Base, Const}).getReg(Idx: 0); |
2013 | |
2014 | // If LogicNonShiftReg is the same as Shift1Base, and the shift1 constant is
2015 | // the same as the MatchInfo.Shift2 constant, CSEMIRBuilder will reuse the old
2016 | // shift1 when building shift2. In that case, erasing MatchInfo.Shift2 at the
2017 | // end would actually remove the old shift1 and cause a crash later. So erase
2018 | // it earlier to avoid the crash.
2019 | MatchInfo.Shift2->eraseFromParent(); |
2020 | |
2021 | Register Shift2Const = MI.getOperand(i: 2).getReg(); |
2022 | Register Shift2 = Builder |
2023 | .buildInstr(Opc: Opcode, DstOps: {DestType}, |
2024 | SrcOps: {MatchInfo.LogicNonShiftReg, Shift2Const}) |
2025 | .getReg(Idx: 0); |
2026 | |
2027 | Register Dest = MI.getOperand(i: 0).getReg(); |
2028 | Builder.buildInstr(Opc: MatchInfo.Logic->getOpcode(), DstOps: {Dest}, SrcOps: {Shift1, Shift2}); |
2029 | |
2030 | // This was one use so it's safe to remove it. |
2031 | MatchInfo.Logic->eraseFromParent(); |
2032 | |
2033 | MI.eraseFromParent(); |
2034 | } |
2035 | |
2036 | bool CombinerHelper::matchCommuteShift(MachineInstr &MI, |
2037 | BuildFnTy &MatchInfo) const { |
2038 | assert(MI.getOpcode() == TargetOpcode::G_SHL && "Expected G_SHL" ); |
2039 | // Combine (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2) |
2040 | // Combine (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2) |
2041 | auto &Shl = cast<GenericMachineInstr>(Val&: MI); |
2042 | Register DstReg = Shl.getReg(Idx: 0); |
2043 | Register SrcReg = Shl.getReg(Idx: 1); |
2044 | Register ShiftReg = Shl.getReg(Idx: 2); |
2045 | Register X, C1; |
2046 | |
2047 | if (!getTargetLowering().isDesirableToCommuteWithShift(MI, IsAfterLegal: !isPreLegalize())) |
2048 | return false; |
2049 | |
2050 | if (!mi_match(R: SrcReg, MRI, |
2051 | P: m_OneNonDBGUse(SP: m_any_of(preds: m_GAdd(L: m_Reg(R&: X), R: m_Reg(R&: C1)), |
2052 | preds: m_GOr(L: m_Reg(R&: X), R: m_Reg(R&: C1)))))) |
2053 | return false; |
2054 | |
2055 | APInt C1Val, C2Val; |
2056 | if (!mi_match(R: C1, MRI, P: m_ICstOrSplat(Cst&: C1Val)) || |
2057 | !mi_match(R: ShiftReg, MRI, P: m_ICstOrSplat(Cst&: C2Val))) |
2058 | return false; |
2059 | |
2060 | auto *SrcDef = MRI.getVRegDef(Reg: SrcReg); |
2061 | assert((SrcDef->getOpcode() == TargetOpcode::G_ADD || |
2062 | SrcDef->getOpcode() == TargetOpcode::G_OR) && "Unexpected op" ); |
2063 | LLT SrcTy = MRI.getType(Reg: SrcReg); |
2064 | MatchInfo = [=](MachineIRBuilder &B) { |
2065 | auto S1 = B.buildShl(Dst: SrcTy, Src0: X, Src1: ShiftReg); |
2066 | auto S2 = B.buildShl(Dst: SrcTy, Src0: C1, Src1: ShiftReg); |
2067 | B.buildInstr(Opc: SrcDef->getOpcode(), DstOps: {DstReg}, SrcOps: {S1, S2}); |
2068 | }; |
2069 | return true; |
2070 | } |
2071 | |
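// Match a multiply by a power of two so it can be rewritten as a left shift,
// e.g. (illustrative types):
//   %r:_(s32) = G_MUL %x, 8   -->   %r:_(s32) = G_SHL %x, 3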
2072 | bool CombinerHelper::matchCombineMulToShl(MachineInstr &MI, |
2073 | unsigned &ShiftVal) const { |
2074 | assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL" ); |
2075 | auto MaybeImmVal = |
2076 | getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI); |
2077 | if (!MaybeImmVal) |
2078 | return false; |
2079 | |
2080 | ShiftVal = MaybeImmVal->Value.exactLogBase2(); |
2081 | return (static_cast<int32_t>(ShiftVal) != -1); |
2082 | } |
2083 | |
2084 | void CombinerHelper::applyCombineMulToShl(MachineInstr &MI, |
2085 | unsigned &ShiftVal) const { |
2086 | assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL" ); |
2087 | MachineIRBuilder MIB(MI); |
2088 | LLT ShiftTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
2089 | auto ShiftCst = MIB.buildConstant(Res: ShiftTy, Val: ShiftVal); |
2090 | Observer.changingInstr(MI); |
2091 | MI.setDesc(MIB.getTII().get(Opcode: TargetOpcode::G_SHL)); |
2092 | MI.getOperand(i: 2).setReg(ShiftCst.getReg(Idx: 0)); |
2093 | if (ShiftVal == ShiftTy.getScalarSizeInBits() - 1) |
2094 | MI.clearFlag(Flag: MachineInstr::MIFlag::NoSWrap); |
2095 | Observer.changedInstr(MI); |
2096 | } |
2097 | |
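// Rewrite a subtract of a constant as an add of the negated constant,
// e.g. (illustrative): %r = G_SUB %x, 5   -->   %r = G_ADD %x, -5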
2098 | bool CombinerHelper::matchCombineSubToAdd(MachineInstr &MI, |
2099 | BuildFnTy &MatchInfo) const { |
2100 | GSub &Sub = cast<GSub>(Val&: MI); |
2101 | |
2102 | LLT Ty = MRI.getType(Reg: Sub.getReg(Idx: 0)); |
2103 | |
2104 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ADD, {Ty}})) |
2105 | return false; |
2106 | |
2107 | if (!isConstantLegalOrBeforeLegalizer(Ty)) |
2108 | return false; |
2109 | |
2110 | APInt Imm = getIConstantFromReg(VReg: Sub.getRHSReg(), MRI); |
2111 | |
2112 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
2113 | auto NegCst = B.buildConstant(Res: Ty, Val: -Imm); |
2114 | Observer.changingInstr(MI); |
2115 | MI.setDesc(B.getTII().get(Opcode: TargetOpcode::G_ADD)); |
2116 | MI.getOperand(i: 2).setReg(NegCst.getReg(Idx: 0)); |
2117 | MI.clearFlag(Flag: MachineInstr::MIFlag::NoUWrap); |
2118 | if (Imm.isMinSignedValue()) |
2119 | MI.clearFlags(flags: MachineInstr::MIFlag::NoSWrap); |
2120 | Observer.changedInstr(MI); |
2121 | }; |
2122 | return true; |
2123 | } |
2124 | |
2125 | // shl ([sza]ext x), y => zext (shl x, y), if shift does not overflow source |
2126 | bool CombinerHelper::matchCombineShlOfExtend(MachineInstr &MI, |
2127 | RegisterImmPair &MatchData) const { |
2128 | assert(MI.getOpcode() == TargetOpcode::G_SHL && VT); |
2129 | if (!getTargetLowering().isDesirableToPullExtFromShl(MI)) |
2130 | return false; |
2131 | |
2132 | Register LHS = MI.getOperand(i: 1).getReg(); |
2133 | |
2134 | Register ExtSrc; |
2135 | if (!mi_match(R: LHS, MRI, P: m_GAnyExt(Src: m_Reg(R&: ExtSrc))) && |
2136 | !mi_match(R: LHS, MRI, P: m_GZExt(Src: m_Reg(R&: ExtSrc))) && |
2137 | !mi_match(R: LHS, MRI, P: m_GSExt(Src: m_Reg(R&: ExtSrc)))) |
2138 | return false; |
2139 | |
2140 | Register RHS = MI.getOperand(i: 2).getReg(); |
2141 | MachineInstr *MIShiftAmt = MRI.getVRegDef(Reg: RHS); |
2142 | auto MaybeShiftAmtVal = isConstantOrConstantSplatVector(MI&: *MIShiftAmt, MRI); |
2143 | if (!MaybeShiftAmtVal) |
2144 | return false; |
2145 | |
2146 | if (LI) { |
2147 | LLT SrcTy = MRI.getType(Reg: ExtSrc); |
2148 | |
2149 | // We only really care about the legality of the shifted value. We can
2150 | // pick any type for the constant shift amount, so ask the target what to
2151 | // use. Otherwise we would have to guess and hope it is reported as legal. |
2152 | LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: SrcTy); |
2153 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SHL, {SrcTy, ShiftAmtTy}})) |
2154 | return false; |
2155 | } |
2156 | |
2157 | int64_t ShiftAmt = MaybeShiftAmtVal->getSExtValue(); |
2158 | MatchData.Reg = ExtSrc; |
2159 | MatchData.Imm = ShiftAmt; |
2160 | |
2161 | unsigned MinLeadingZeros = VT->getKnownZeroes(R: ExtSrc).countl_one(); |
2162 | unsigned SrcTySize = MRI.getType(Reg: ExtSrc).getScalarSizeInBits(); |
2163 | return MinLeadingZeros >= ShiftAmt && ShiftAmt < SrcTySize; |
2164 | } |
2165 | |
2166 | void CombinerHelper::applyCombineShlOfExtend( |
2167 | MachineInstr &MI, const RegisterImmPair &MatchData) const { |
2168 | Register ExtSrcReg = MatchData.Reg; |
2169 | int64_t ShiftAmtVal = MatchData.Imm; |
2170 | |
2171 | LLT ExtSrcTy = MRI.getType(Reg: ExtSrcReg); |
2172 | auto ShiftAmt = Builder.buildConstant(Res: ExtSrcTy, Val: ShiftAmtVal); |
2173 | auto NarrowShift = |
2174 | Builder.buildShl(Dst: ExtSrcTy, Src0: ExtSrcReg, Src1: ShiftAmt, Flags: MI.getFlags()); |
2175 | Builder.buildZExt(Res: MI.getOperand(i: 0), Op: NarrowShift); |
2176 | MI.eraseFromParent(); |
2177 | } |
2178 | |
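// Fold a merge of all the results of an unmerge back into the unmerged
// register, e.g. (illustrative):
//   %a:_(s32), %b:_(s32) = G_UNMERGE_VALUES %x:_(s64)
//   %y:_(s64) = G_MERGE_VALUES %a, %b
// allows %y to be replaced by %x.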
2179 | bool CombinerHelper::matchCombineMergeUnmerge(MachineInstr &MI, |
2180 | Register &MatchInfo) const { |
2181 | GMerge &Merge = cast<GMerge>(Val&: MI); |
2182 | SmallVector<Register, 16> MergedValues; |
2183 | for (unsigned I = 0; I < Merge.getNumSources(); ++I) |
2184 | MergedValues.emplace_back(Args: Merge.getSourceReg(I)); |
2185 | |
2186 | auto *Unmerge = getOpcodeDef<GUnmerge>(Reg: MergedValues[0], MRI); |
2187 | if (!Unmerge || Unmerge->getNumDefs() != Merge.getNumSources()) |
2188 | return false; |
2189 | |
2190 | for (unsigned I = 0; I < MergedValues.size(); ++I) |
2191 | if (MergedValues[I] != Unmerge->getReg(Idx: I)) |
2192 | return false; |
2193 | |
2194 | MatchInfo = Unmerge->getSourceReg(); |
2195 | return true; |
2196 | } |
2197 | |
2198 | static Register peekThroughBitcast(Register Reg, |
2199 | const MachineRegisterInfo &MRI) { |
2200 | while (mi_match(R: Reg, MRI, P: m_GBitcast(Src: m_Reg(R&: Reg)))) |
2201 | ; |
2202 | |
2203 | return Reg; |
2204 | } |
2205 | |
2206 | bool CombinerHelper::matchCombineUnmergeMergeToPlainValues( |
2207 | MachineInstr &MI, SmallVectorImpl<Register> &Operands) const { |
2208 | assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && |
2209 | "Expected an unmerge" ); |
2210 | auto &Unmerge = cast<GUnmerge>(Val&: MI); |
2211 | Register SrcReg = peekThroughBitcast(Reg: Unmerge.getSourceReg(), MRI); |
2212 | |
2213 | auto *SrcInstr = getOpcodeDef<GMergeLikeInstr>(Reg: SrcReg, MRI); |
2214 | if (!SrcInstr) |
2215 | return false; |
2216 | |
2217 | // Check the source type of the merge. |
2218 | LLT SrcMergeTy = MRI.getType(Reg: SrcInstr->getSourceReg(I: 0)); |
2219 | LLT Dst0Ty = MRI.getType(Reg: Unmerge.getReg(Idx: 0)); |
2220 | bool SameSize = Dst0Ty.getSizeInBits() == SrcMergeTy.getSizeInBits(); |
2221 | if (SrcMergeTy != Dst0Ty && !SameSize) |
2222 | return false; |
2223 | // They are the same now (modulo a bitcast). |
2224 | // We can collect all the src registers. |
2225 | for (unsigned Idx = 0; Idx < SrcInstr->getNumSources(); ++Idx) |
2226 | Operands.push_back(Elt: SrcInstr->getSourceReg(I: Idx)); |
2227 | return true; |
2228 | } |
2229 | |
2230 | void CombinerHelper::applyCombineUnmergeMergeToPlainValues( |
2231 | MachineInstr &MI, SmallVectorImpl<Register> &Operands) const { |
2232 | assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && |
2233 | "Expected an unmerge" ); |
2234 | assert((MI.getNumOperands() - 1 == Operands.size()) && |
2235 | "Not enough operands to replace all defs" ); |
2236 | unsigned NumElems = MI.getNumOperands() - 1; |
2237 | |
2238 | LLT SrcTy = MRI.getType(Reg: Operands[0]); |
2239 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
2240 | bool CanReuseInputDirectly = DstTy == SrcTy; |
2241 | for (unsigned Idx = 0; Idx < NumElems; ++Idx) { |
2242 | Register DstReg = MI.getOperand(i: Idx).getReg(); |
2243 | Register SrcReg = Operands[Idx]; |
2244 | |
2245 | // This combine may run after RegBankSelect, so we need to be aware of |
2246 | // register banks. |
2247 | const auto &DstCB = MRI.getRegClassOrRegBank(Reg: DstReg); |
2248 | if (!DstCB.isNull() && DstCB != MRI.getRegClassOrRegBank(Reg: SrcReg)) { |
2249 | SrcReg = Builder.buildCopy(Res: MRI.getType(Reg: SrcReg), Op: SrcReg).getReg(Idx: 0); |
2250 | MRI.setRegClassOrRegBank(Reg: SrcReg, RCOrRB: DstCB); |
2251 | } |
2252 | |
2253 | if (CanReuseInputDirectly) |
2254 | replaceRegWith(MRI, FromReg: DstReg, ToReg: SrcReg); |
2255 | else |
2256 | Builder.buildCast(Dst: DstReg, Src: SrcReg); |
2257 | } |
2258 | MI.eraseFromParent(); |
2259 | } |
2260 | |
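// Replace an unmerge of a constant with per-piece constants, e.g.
// (illustrative):
//   %lo:_(s16), %hi:_(s16) = G_UNMERGE_VALUES %c:_(s32)
// with %c = G_CONSTANT i32 0x00020001 becomes
//   %lo = G_CONSTANT i16 1, %hi = G_CONSTANT i16 2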
2261 | bool CombinerHelper::matchCombineUnmergeConstant( |
2262 | MachineInstr &MI, SmallVectorImpl<APInt> &Csts) const { |
2263 | unsigned SrcIdx = MI.getNumOperands() - 1; |
2264 | Register SrcReg = MI.getOperand(i: SrcIdx).getReg(); |
2265 | MachineInstr *SrcInstr = MRI.getVRegDef(Reg: SrcReg); |
2266 | if (SrcInstr->getOpcode() != TargetOpcode::G_CONSTANT && |
2267 | SrcInstr->getOpcode() != TargetOpcode::G_FCONSTANT) |
2268 | return false; |
2269 | // Break down the big constant into smaller ones.
2270 | const MachineOperand &CstVal = SrcInstr->getOperand(i: 1); |
2271 | APInt Val = SrcInstr->getOpcode() == TargetOpcode::G_CONSTANT |
2272 | ? CstVal.getCImm()->getValue() |
2273 | : CstVal.getFPImm()->getValueAPF().bitcastToAPInt(); |
2274 | |
2275 | LLT Dst0Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
2276 | unsigned ShiftAmt = Dst0Ty.getSizeInBits(); |
2277 | // Unmerge a constant. |
2278 | for (unsigned Idx = 0; Idx != SrcIdx; ++Idx) { |
2279 | Csts.emplace_back(Args: Val.trunc(width: ShiftAmt)); |
2280 | Val = Val.lshr(shiftAmt: ShiftAmt); |
2281 | } |
2282 | |
2283 | return true; |
2284 | } |
2285 | |
2286 | void CombinerHelper::applyCombineUnmergeConstant( |
2287 | MachineInstr &MI, SmallVectorImpl<APInt> &Csts) const { |
2288 | assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && |
2289 | "Expected an unmerge" ); |
2290 | assert((MI.getNumOperands() - 1 == Csts.size()) && |
2291 | "Not enough operands to replace all defs" ); |
2292 | unsigned NumElems = MI.getNumOperands() - 1; |
2293 | for (unsigned Idx = 0; Idx < NumElems; ++Idx) { |
2294 | Register DstReg = MI.getOperand(i: Idx).getReg(); |
2295 | Builder.buildConstant(Res: DstReg, Val: Csts[Idx]); |
2296 | } |
2297 | |
2298 | MI.eraseFromParent(); |
2299 | } |
2300 | |
2301 | bool CombinerHelper::matchCombineUnmergeUndef( |
2302 | MachineInstr &MI, |
2303 | std::function<void(MachineIRBuilder &)> &MatchInfo) const { |
2304 | unsigned SrcIdx = MI.getNumOperands() - 1; |
2305 | Register SrcReg = MI.getOperand(i: SrcIdx).getReg(); |
2306 | MatchInfo = [&MI](MachineIRBuilder &B) { |
2307 | unsigned NumElems = MI.getNumOperands() - 1; |
2308 | for (unsigned Idx = 0; Idx < NumElems; ++Idx) { |
2309 | Register DstReg = MI.getOperand(i: Idx).getReg(); |
2310 | B.buildUndef(Res: DstReg); |
2311 | } |
2312 | }; |
2313 | return isa<GImplicitDef>(Val: MRI.getVRegDef(Reg: SrcReg)); |
2314 | } |
2315 | |
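// When only the first (lowest) result of a scalar unmerge is used, the
// unmerge can be replaced by a truncate of its source, e.g. (illustrative):
//   %lo:_(s32), %hi:_(s32) = G_UNMERGE_VALUES %x:_(s64)   ; %hi unused
// -->
//   %lo:_(s32) = G_TRUNC %x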
2316 | bool CombinerHelper::matchCombineUnmergeWithDeadLanesToTrunc( |
2317 | MachineInstr &MI) const { |
2318 | assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && |
2319 | "Expected an unmerge" ); |
2320 | if (MRI.getType(Reg: MI.getOperand(i: 0).getReg()).isVector() || |
2321 | MRI.getType(Reg: MI.getOperand(i: MI.getNumDefs()).getReg()).isVector()) |
2322 | return false; |
2323 | // Check that all the lanes are dead except the first one. |
2324 | for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) { |
2325 | if (!MRI.use_nodbg_empty(RegNo: MI.getOperand(i: Idx).getReg())) |
2326 | return false; |
2327 | } |
2328 | return true; |
2329 | } |
2330 | |
2331 | void CombinerHelper::applyCombineUnmergeWithDeadLanesToTrunc( |
2332 | MachineInstr &MI) const { |
2333 | Register SrcReg = MI.getOperand(i: MI.getNumDefs()).getReg(); |
2334 | Register Dst0Reg = MI.getOperand(i: 0).getReg(); |
2335 | Builder.buildTrunc(Res: Dst0Reg, Op: SrcReg); |
2336 | MI.eraseFromParent(); |
2337 | } |
2338 | |
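// For an unmerge of a zero-extended value whose source fits entirely in the
// first result, the first result can come straight from the zext source and
// the remaining results become zero, e.g. (illustrative):
//   %z:_(s64) = G_ZEXT %x:_(s32)
//   %lo:_(s32), %hi:_(s32) = G_UNMERGE_VALUES %z
// -->
//   %lo = %x, %hi = G_CONSTANT i32 0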
2339 | bool CombinerHelper::matchCombineUnmergeZExtToZExt(MachineInstr &MI) const { |
2340 | assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && |
2341 | "Expected an unmerge" ); |
2342 | Register Dst0Reg = MI.getOperand(i: 0).getReg(); |
2343 | LLT Dst0Ty = MRI.getType(Reg: Dst0Reg); |
2344 | // G_ZEXT on a vector applies to each lane, so it will
2345 | // affect all destinations. Therefore we won't be able |
2346 | // to simplify the unmerge to just the first definition. |
2347 | if (Dst0Ty.isVector()) |
2348 | return false; |
2349 | Register SrcReg = MI.getOperand(i: MI.getNumDefs()).getReg(); |
2350 | LLT SrcTy = MRI.getType(Reg: SrcReg); |
2351 | if (SrcTy.isVector()) |
2352 | return false; |
2353 | |
2354 | Register ZExtSrcReg; |
2355 | if (!mi_match(R: SrcReg, MRI, P: m_GZExt(Src: m_Reg(R&: ZExtSrcReg)))) |
2356 | return false; |
2357 | |
2358 | // Finally we can replace the first definition with |
2359 | // a zext of the source if the definition is big enough to hold |
2360 | // all of ZExtSrc's bits.
2361 | LLT ZExtSrcTy = MRI.getType(Reg: ZExtSrcReg); |
2362 | return ZExtSrcTy.getSizeInBits() <= Dst0Ty.getSizeInBits(); |
2363 | } |
2364 | |
2365 | void CombinerHelper::applyCombineUnmergeZExtToZExt(MachineInstr &MI) const { |
2366 | assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && |
2367 | "Expected an unmerge" ); |
2368 | |
2369 | Register Dst0Reg = MI.getOperand(i: 0).getReg(); |
2370 | |
2371 | MachineInstr *ZExtInstr = |
2372 | MRI.getVRegDef(Reg: MI.getOperand(i: MI.getNumDefs()).getReg()); |
2373 | assert(ZExtInstr && ZExtInstr->getOpcode() == TargetOpcode::G_ZEXT && |
2374 | "Expecting a G_ZEXT" ); |
2375 | |
2376 | Register ZExtSrcReg = ZExtInstr->getOperand(i: 1).getReg(); |
2377 | LLT Dst0Ty = MRI.getType(Reg: Dst0Reg); |
2378 | LLT ZExtSrcTy = MRI.getType(Reg: ZExtSrcReg); |
2379 | |
2380 | if (Dst0Ty.getSizeInBits() > ZExtSrcTy.getSizeInBits()) { |
2381 | Builder.buildZExt(Res: Dst0Reg, Op: ZExtSrcReg); |
2382 | } else { |
2383 | assert(Dst0Ty.getSizeInBits() == ZExtSrcTy.getSizeInBits() && |
2384 | "ZExt src doesn't fit in destination" ); |
2385 | replaceRegWith(MRI, FromReg: Dst0Reg, ToReg: ZExtSrcReg); |
2386 | } |
2387 | |
2388 | Register ZeroReg; |
2389 | for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) { |
2390 | if (!ZeroReg) |
2391 | ZeroReg = Builder.buildConstant(Res: Dst0Ty, Val: 0).getReg(Idx: 0); |
2392 | replaceRegWith(MRI, FromReg: MI.getOperand(i: Idx).getReg(), ToReg: ZeroReg); |
2393 | } |
2394 | MI.eraseFromParent(); |
2395 | } |
2396 | |
2397 | bool CombinerHelper::matchCombineShiftToUnmerge(MachineInstr &MI, |
2398 | unsigned TargetShiftSize, |
2399 | unsigned &ShiftVal) const { |
2400 | assert((MI.getOpcode() == TargetOpcode::G_SHL || |
2401 | MI.getOpcode() == TargetOpcode::G_LSHR || |
2402 | MI.getOpcode() == TargetOpcode::G_ASHR) && "Expected a shift" ); |
2403 | |
2404 | LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
2405 | if (Ty.isVector()) // TODO: Handle vector types.
2406 | return false; |
2407 | |
2408 | // Don't narrow further than the requested size. |
2409 | unsigned Size = Ty.getSizeInBits(); |
2410 | if (Size <= TargetShiftSize) |
2411 | return false; |
2412 | |
2413 | auto MaybeImmVal = |
2414 | getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI); |
2415 | if (!MaybeImmVal) |
2416 | return false; |
2417 | |
2418 | ShiftVal = MaybeImmVal->Value.getSExtValue(); |
2419 | return ShiftVal >= Size / 2 && ShiftVal < Size; |
2420 | } |
2421 | |
2422 | void CombinerHelper::applyCombineShiftToUnmerge( |
2423 | MachineInstr &MI, const unsigned &ShiftVal) const { |
2424 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2425 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
2426 | LLT Ty = MRI.getType(Reg: SrcReg); |
2427 | unsigned Size = Ty.getSizeInBits(); |
2428 | unsigned HalfSize = Size / 2; |
2429 | assert(ShiftVal >= HalfSize); |
2430 | |
2431 | LLT HalfTy = LLT::scalar(SizeInBits: HalfSize); |
2432 | |
2433 | auto Unmerge = Builder.buildUnmerge(Res: HalfTy, Op: SrcReg); |
2434 | unsigned NarrowShiftAmt = ShiftVal - HalfSize; |
2435 | |
2436 | if (MI.getOpcode() == TargetOpcode::G_LSHR) { |
2437 | Register Narrowed = Unmerge.getReg(Idx: 1); |
2438 | |
2439 | // dst = G_LSHR s64:x, C for C >= 32 |
2440 | // => |
2441 | // lo, hi = G_UNMERGE_VALUES x |
2442 | // dst = G_MERGE_VALUES (G_LSHR hi, C - 32), 0 |
2443 | |
2444 | if (NarrowShiftAmt != 0) { |
2445 | Narrowed = Builder.buildLShr(Dst: HalfTy, Src0: Narrowed, |
2446 | Src1: Builder.buildConstant(Res: HalfTy, Val: NarrowShiftAmt)).getReg(Idx: 0); |
2447 | } |
2448 | |
2449 | auto Zero = Builder.buildConstant(Res: HalfTy, Val: 0); |
2450 | Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Narrowed, Zero}); |
2451 | } else if (MI.getOpcode() == TargetOpcode::G_SHL) { |
2452 | Register Narrowed = Unmerge.getReg(Idx: 0); |
2453 | // dst = G_SHL s64:x, C for C >= 32 |
2454 | // => |
2455 | // lo, hi = G_UNMERGE_VALUES x |
2456 | // dst = G_MERGE_VALUES 0, (G_SHL hi, C - 32) |
2457 | if (NarrowShiftAmt != 0) { |
2458 | Narrowed = Builder.buildShl(Dst: HalfTy, Src0: Narrowed, |
2459 | Src1: Builder.buildConstant(Res: HalfTy, Val: NarrowShiftAmt)).getReg(Idx: 0); |
2460 | } |
2461 | |
2462 | auto Zero = Builder.buildConstant(Res: HalfTy, Val: 0); |
2463 | Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Zero, Narrowed}); |
2464 | } else { |
2465 | assert(MI.getOpcode() == TargetOpcode::G_ASHR); |
2466 | auto Hi = Builder.buildAShr( |
2467 | Dst: HalfTy, Src0: Unmerge.getReg(Idx: 1), |
2468 | Src1: Builder.buildConstant(Res: HalfTy, Val: HalfSize - 1)); |
2469 | |
2470 | if (ShiftVal == HalfSize) { |
2471 | // (G_ASHR i64:x, 32) -> |
2472 | // G_MERGE_VALUES hi_32(x), (G_ASHR hi_32(x), 31) |
2473 | Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Unmerge.getReg(Idx: 1), Hi}); |
2474 | } else if (ShiftVal == Size - 1) { |
2475 | // Don't need a second shift. |
2476 | // (G_ASHR i64:x, 63) -> |
2477 | // %narrowed = (G_ASHR hi_32(x), 31) |
2478 | // G_MERGE_VALUES %narrowed, %narrowed |
2479 | Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Hi, Hi}); |
2480 | } else { |
2481 | auto Lo = Builder.buildAShr( |
2482 | Dst: HalfTy, Src0: Unmerge.getReg(Idx: 1), |
2483 | Src1: Builder.buildConstant(Res: HalfTy, Val: ShiftVal - HalfSize)); |
2484 | |
2485 | // (G_ASHR i64:x, C) ->, for C >= 32 |
2486 | // G_MERGE_VALUES (G_ASHR hi_32(x), C - 32), (G_ASHR hi_32(x), 31) |
2487 | Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Lo, Hi}); |
2488 | } |
2489 | } |
2490 | |
2491 | MI.eraseFromParent(); |
2492 | } |
2493 | |
2494 | bool CombinerHelper::tryCombineShiftToUnmerge( |
2495 | MachineInstr &MI, unsigned TargetShiftAmount) const { |
2496 | unsigned ShiftAmt; |
2497 | if (matchCombineShiftToUnmerge(MI, TargetShiftSize: TargetShiftAmount, ShiftVal&: ShiftAmt)) { |
2498 | applyCombineShiftToUnmerge(MI, ShiftVal: ShiftAmt); |
2499 | return true; |
2500 | } |
2501 | |
2502 | return false; |
2503 | } |
2504 | |
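// Fold an inttoptr of a ptrtoint back to the original pointer when the source
// pointer type matches the destination, e.g. (illustrative):
//   %p2:_(p0) = G_INTTOPTR (G_PTRTOINT %p1:_(p0))
// -->
//   %p2:_(p0) = COPY %p1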
2505 | bool CombinerHelper::matchCombineI2PToP2I(MachineInstr &MI, |
2506 | Register &Reg) const { |
2507 | assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR" ); |
2508 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2509 | LLT DstTy = MRI.getType(Reg: DstReg); |
2510 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
2511 | return mi_match(R: SrcReg, MRI, |
2512 | P: m_GPtrToInt(Src: m_all_of(preds: m_SpecificType(Ty: DstTy), preds: m_Reg(R&: Reg)))); |
2513 | } |
2514 | |
2515 | void CombinerHelper::applyCombineI2PToP2I(MachineInstr &MI, |
2516 | Register &Reg) const { |
2517 | assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR" ); |
2518 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2519 | Builder.buildCopy(Res: DstReg, Op: Reg); |
2520 | MI.eraseFromParent(); |
2521 | } |
2522 | |
2523 | void CombinerHelper::applyCombineP2IToI2P(MachineInstr &MI, |
2524 | Register &Reg) const { |
2525 | assert(MI.getOpcode() == TargetOpcode::G_PTRTOINT && "Expected a G_PTRTOINT" ); |
2526 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2527 | Builder.buildZExtOrTrunc(Res: DstReg, Op: Reg); |
2528 | MI.eraseFromParent(); |
2529 | } |
2530 | |
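// Turn integer arithmetic on a ptrtoint into pointer arithmetic when the
// integer and pointer widths match, e.g. (illustrative):
//   %r:_(s64) = G_ADD (G_PTRTOINT %p:_(p0)), %x
// -->
//   %r:_(s64) = G_PTRTOINT (G_PTR_ADD %p, %x)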
2531 | bool CombinerHelper::matchCombineAddP2IToPtrAdd( |
2532 | MachineInstr &MI, std::pair<Register, bool> &PtrReg) const { |
2533 | assert(MI.getOpcode() == TargetOpcode::G_ADD); |
2534 | Register LHS = MI.getOperand(i: 1).getReg(); |
2535 | Register RHS = MI.getOperand(i: 2).getReg(); |
2536 | LLT IntTy = MRI.getType(Reg: LHS); |
2537 | |
2538 | // G_PTR_ADD always has the pointer in the LHS, so we may need to commute the |
2539 | // instruction. |
2540 | PtrReg.second = false; |
2541 | for (Register SrcReg : {LHS, RHS}) { |
2542 | if (mi_match(R: SrcReg, MRI, P: m_GPtrToInt(Src: m_Reg(R&: PtrReg.first)))) { |
2543 | // Don't handle cases where the integer is implicitly converted to the |
2544 | // pointer width. |
2545 | LLT PtrTy = MRI.getType(Reg: PtrReg.first); |
2546 | if (PtrTy.getScalarSizeInBits() == IntTy.getScalarSizeInBits()) |
2547 | return true; |
2548 | } |
2549 | |
2550 | PtrReg.second = true; |
2551 | } |
2552 | |
2553 | return false; |
2554 | } |
2555 | |
2556 | void CombinerHelper::applyCombineAddP2IToPtrAdd( |
2557 | MachineInstr &MI, std::pair<Register, bool> &PtrReg) const { |
2558 | Register Dst = MI.getOperand(i: 0).getReg(); |
2559 | Register LHS = MI.getOperand(i: 1).getReg(); |
2560 | Register RHS = MI.getOperand(i: 2).getReg(); |
2561 | |
2562 | const bool DoCommute = PtrReg.second; |
2563 | if (DoCommute) |
2564 | std::swap(a&: LHS, b&: RHS); |
2565 | LHS = PtrReg.first; |
2566 | |
2567 | LLT PtrTy = MRI.getType(Reg: LHS); |
2568 | |
2569 | auto PtrAdd = Builder.buildPtrAdd(Res: PtrTy, Op0: LHS, Op1: RHS); |
2570 | Builder.buildPtrToInt(Dst, Src: PtrAdd); |
2571 | MI.eraseFromParent(); |
2572 | } |
2573 | |
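// Fold a G_PTR_ADD of a constant G_INTTOPTR base and a constant offset into a
// single constant, e.g. (illustrative): a base of G_INTTOPTR 16 with a
// constant offset of 8 becomes a constant pointer value of 24.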
2574 | bool CombinerHelper::matchCombineConstPtrAddToI2P(MachineInstr &MI, |
2575 | APInt &NewCst) const { |
2576 | auto &PtrAdd = cast<GPtrAdd>(Val&: MI); |
2577 | Register LHS = PtrAdd.getBaseReg(); |
2578 | Register RHS = PtrAdd.getOffsetReg(); |
2579 | MachineRegisterInfo &MRI = Builder.getMF().getRegInfo(); |
2580 | |
2581 | if (auto RHSCst = getIConstantVRegVal(VReg: RHS, MRI)) { |
2582 | APInt Cst; |
2583 | if (mi_match(R: LHS, MRI, P: m_GIntToPtr(Src: m_ICst(Cst)))) { |
2584 | auto DstTy = MRI.getType(Reg: PtrAdd.getReg(Idx: 0)); |
2585 | // G_INTTOPTR uses zero-extension |
2586 | NewCst = Cst.zextOrTrunc(width: DstTy.getSizeInBits()); |
2587 | NewCst += RHSCst->sextOrTrunc(width: DstTy.getSizeInBits()); |
2588 | return true; |
2589 | } |
2590 | } |
2591 | |
2592 | return false; |
2593 | } |
2594 | |
2595 | void CombinerHelper::applyCombineConstPtrAddToI2P(MachineInstr &MI, |
2596 | APInt &NewCst) const { |
2597 | auto &PtrAdd = cast<GPtrAdd>(Val&: MI); |
2598 | Register Dst = PtrAdd.getReg(Idx: 0); |
2599 | |
2600 | Builder.buildConstant(Res: Dst, Val: NewCst); |
2601 | PtrAdd.eraseFromParent(); |
2602 | } |
2603 | |
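// Fold an anyext of a trunc back to the original register when the original
// type matches the destination, e.g. (illustrative):
//   %y:_(s64) = G_ANYEXT (G_TRUNC %x:_(s64))   -->   replace %y with %x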
2604 | bool CombinerHelper::matchCombineAnyExtTrunc(MachineInstr &MI, |
2605 | Register &Reg) const { |
2606 | assert(MI.getOpcode() == TargetOpcode::G_ANYEXT && "Expected a G_ANYEXT" ); |
2607 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2608 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
2609 | Register OriginalSrcReg = getSrcRegIgnoringCopies(Reg: SrcReg, MRI); |
2610 | if (OriginalSrcReg.isValid()) |
2611 | SrcReg = OriginalSrcReg; |
2612 | LLT DstTy = MRI.getType(Reg: DstReg); |
2613 | return mi_match(R: SrcReg, MRI, |
2614 | P: m_GTrunc(Src: m_all_of(preds: m_Reg(R&: Reg), preds: m_SpecificType(Ty: DstTy)))) && |
2615 | canReplaceReg(DstReg, SrcReg: Reg, MRI); |
2616 | } |
2617 | |
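// Fold a zext of a trunc back to the original register when the bits the
// trunc discarded are already known to be zero, e.g. (illustrative):
//   %y:_(s64) = G_ZEXT (G_TRUNC %x:_(s64)), high 32 bits of %x known zero
// -->
//   replace %y with %x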
2618 | bool CombinerHelper::matchCombineZextTrunc(MachineInstr &MI, |
2619 | Register &Reg) const { |
2620 | assert(MI.getOpcode() == TargetOpcode::G_ZEXT && "Expected a G_ZEXT" ); |
2621 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2622 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
2623 | LLT DstTy = MRI.getType(Reg: DstReg); |
2624 | if (mi_match(R: SrcReg, MRI, |
2625 | P: m_GTrunc(Src: m_all_of(preds: m_Reg(R&: Reg), preds: m_SpecificType(Ty: DstTy)))) && |
2626 | canReplaceReg(DstReg, SrcReg: Reg, MRI)) { |
2627 | unsigned DstSize = DstTy.getScalarSizeInBits(); |
2628 | unsigned SrcSize = MRI.getType(Reg: SrcReg).getScalarSizeInBits(); |
2629 | return VT->getKnownBits(R: Reg).countMinLeadingZeros() >= DstSize - SrcSize; |
2630 | } |
2631 | return false; |
2632 | } |
2633 | |
2634 | static LLT getMidVTForTruncRightShiftCombine(LLT ShiftTy, LLT TruncTy) { |
2635 | const unsigned ShiftSize = ShiftTy.getScalarSizeInBits(); |
2636 | const unsigned TruncSize = TruncTy.getScalarSizeInBits(); |
2637 | |
2638 | // ShiftTy > 32 > TruncTy -> 32 |
2639 | if (ShiftSize > 32 && TruncSize < 32) |
2640 | return ShiftTy.changeElementSize(NewEltSize: 32); |
2641 | |
2642 | // TODO: We could also reduce to 16 bits, but that's more target-dependent. |
2643 | // Some targets like it, some don't, some only like it under certain |
2644 | // conditions/processor versions, etc. |
2645 | // A TL hook might be needed for this. |
2646 | |
2647 | // Don't combine |
2648 | return ShiftTy; |
2649 | } |
2650 | |
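// Narrow a shift whose only user is a trunc, e.g. (illustrative types):
//   %t:_(s32) = G_TRUNC (G_SHL %x:_(s64), %k)
// -->
//   %t:_(s32) = G_SHL (G_TRUNC %x), %k
// For right shifts a wider intermediate type may be chosen so that the bits
// which end up in the truncated result are preserved.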
2651 | bool CombinerHelper::matchCombineTruncOfShift( |
2652 | MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) const { |
2653 | assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC" ); |
2654 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2655 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
2656 | |
2657 | if (!MRI.hasOneNonDBGUse(RegNo: SrcReg)) |
2658 | return false; |
2659 | |
2660 | LLT SrcTy = MRI.getType(Reg: SrcReg); |
2661 | LLT DstTy = MRI.getType(Reg: DstReg); |
2662 | |
2663 | MachineInstr *SrcMI = getDefIgnoringCopies(Reg: SrcReg, MRI); |
2664 | const auto &TL = getTargetLowering(); |
2665 | |
2666 | LLT NewShiftTy; |
2667 | switch (SrcMI->getOpcode()) { |
2668 | default: |
2669 | return false; |
2670 | case TargetOpcode::G_SHL: { |
2671 | NewShiftTy = DstTy; |
2672 | |
2673 | // Make sure new shift amount is legal. |
2674 | KnownBits Known = VT->getKnownBits(R: SrcMI->getOperand(i: 2).getReg()); |
2675 | if (Known.getMaxValue().uge(RHS: NewShiftTy.getScalarSizeInBits())) |
2676 | return false; |
2677 | break; |
2678 | } |
2679 | case TargetOpcode::G_LSHR: |
2680 | case TargetOpcode::G_ASHR: { |
2681 | // For right shifts, we conservatively do not do the transform if the TRUNC |
2682 | // has any STORE users. The reason is that if we change the type of the |
2683 | // shift, we may break the truncstore combine. |
2684 | // |
2685 | // TODO: Fix truncstore combine to handle (trunc(lshr (trunc x), k)). |
2686 | for (auto &User : MRI.use_instructions(Reg: DstReg)) |
2687 | if (User.getOpcode() == TargetOpcode::G_STORE) |
2688 | return false; |
2689 | |
2690 | NewShiftTy = getMidVTForTruncRightShiftCombine(ShiftTy: SrcTy, TruncTy: DstTy); |
2691 | if (NewShiftTy == SrcTy) |
2692 | return false; |
2693 | |
2694 | // Make sure we won't lose information by truncating the high bits. |
2695 | KnownBits Known = VT->getKnownBits(R: SrcMI->getOperand(i: 2).getReg()); |
2696 | if (Known.getMaxValue().ugt(RHS: NewShiftTy.getScalarSizeInBits() - |
2697 | DstTy.getScalarSizeInBits())) |
2698 | return false; |
2699 | break; |
2700 | } |
2701 | } |
2702 | |
2703 | if (!isLegalOrBeforeLegalizer( |
2704 | Query: {SrcMI->getOpcode(), |
2705 | {NewShiftTy, TL.getPreferredShiftAmountTy(ShiftValueTy: NewShiftTy)}})) |
2706 | return false; |
2707 | |
2708 | MatchInfo = std::make_pair(x&: SrcMI, y&: NewShiftTy); |
2709 | return true; |
2710 | } |
2711 | |
2712 | void CombinerHelper::applyCombineTruncOfShift( |
2713 | MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) const { |
2714 | MachineInstr *ShiftMI = MatchInfo.first; |
2715 | LLT NewShiftTy = MatchInfo.second; |
2716 | |
2717 | Register Dst = MI.getOperand(i: 0).getReg(); |
2718 | LLT DstTy = MRI.getType(Reg: Dst); |
2719 | |
2720 | Register ShiftAmt = ShiftMI->getOperand(i: 2).getReg(); |
2721 | Register ShiftSrc = ShiftMI->getOperand(i: 1).getReg(); |
2722 | ShiftSrc = Builder.buildTrunc(Res: NewShiftTy, Op: ShiftSrc).getReg(Idx: 0); |
2723 | |
2724 | Register NewShift = |
2725 | Builder |
2726 | .buildInstr(Opc: ShiftMI->getOpcode(), DstOps: {NewShiftTy}, SrcOps: {ShiftSrc, ShiftAmt}) |
2727 | .getReg(Idx: 0); |
2728 | |
2729 | if (NewShiftTy == DstTy) |
2730 | replaceRegWith(MRI, FromReg: Dst, ToReg: NewShift); |
2731 | else |
2732 | Builder.buildTrunc(Res: Dst, Op: NewShift); |
2733 | |
2734 | eraseInst(MI); |
2735 | } |
2736 | |
2737 | bool CombinerHelper::matchAnyExplicitUseIsUndef(MachineInstr &MI) const { |
2738 | return any_of(Range: MI.explicit_uses(), P: [this](const MachineOperand &MO) { |
2739 | return MO.isReg() && |
2740 | getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI); |
2741 | }); |
2742 | } |
2743 | |
2744 | bool CombinerHelper::matchAllExplicitUsesAreUndef(MachineInstr &MI) const { |
2745 | return all_of(Range: MI.explicit_uses(), P: [this](const MachineOperand &MO) { |
2746 | return !MO.isReg() || |
2747 | getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI); |
2748 | }); |
2749 | } |
2750 | |
2751 | bool CombinerHelper::matchUndefShuffleVectorMask(MachineInstr &MI) const { |
2752 | assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR); |
2753 | ArrayRef<int> Mask = MI.getOperand(i: 3).getShuffleMask(); |
2754 | return all_of(Range&: Mask, P: [](int Elt) { return Elt < 0; }); |
2755 | } |
2756 | |
2757 | bool CombinerHelper::matchUndefStore(MachineInstr &MI) const { |
2758 | assert(MI.getOpcode() == TargetOpcode::G_STORE); |
2759 | return getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MI.getOperand(i: 0).getReg(), |
2760 | MRI); |
2761 | } |
2762 | |
2763 | bool CombinerHelper::matchUndefSelectCmp(MachineInstr &MI) const { |
2764 | assert(MI.getOpcode() == TargetOpcode::G_SELECT); |
2765 | return getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MI.getOperand(i: 1).getReg(), |
2766 | MRI); |
2767 | } |
2768 | |
bool CombinerHelper::matchInsertExtractVecEltOutOfBounds(
    MachineInstr &MI) const {
2771 | assert((MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT || |
2772 | MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) && |
2773 | "Expected an insert/extract element op" ); |
2774 | LLT VecTy = MRI.getType(Reg: MI.getOperand(i: 1).getReg()); |
2775 | if (VecTy.isScalableVector()) |
2776 | return false; |
2777 | |
2778 | unsigned IdxIdx = |
2779 | MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT ? 2 : 3; |
2780 | auto Idx = getIConstantVRegVal(VReg: MI.getOperand(i: IdxIdx).getReg(), MRI); |
2781 | if (!Idx) |
2782 | return false; |
2783 | return Idx->getZExtValue() >= VecTy.getNumElements(); |
2784 | } |
2785 | |
2786 | bool CombinerHelper::matchConstantSelectCmp(MachineInstr &MI, |
2787 | unsigned &OpIdx) const { |
2788 | GSelect &SelMI = cast<GSelect>(Val&: MI); |
2789 | auto Cst = |
2790 | isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: SelMI.getCondReg()), MRI); |
2791 | if (!Cst) |
2792 | return false; |
2793 | OpIdx = Cst->isZero() ? 3 : 2; |
2794 | return true; |
2795 | } |
2796 | |
2797 | void CombinerHelper::eraseInst(MachineInstr &MI) const { MI.eraseFromParent(); } |
2798 | |
2799 | bool CombinerHelper::matchEqualDefs(const MachineOperand &MOP1, |
2800 | const MachineOperand &MOP2) const { |
2801 | if (!MOP1.isReg() || !MOP2.isReg()) |
2802 | return false; |
2803 | auto InstAndDef1 = getDefSrcRegIgnoringCopies(Reg: MOP1.getReg(), MRI); |
2804 | if (!InstAndDef1) |
2805 | return false; |
2806 | auto InstAndDef2 = getDefSrcRegIgnoringCopies(Reg: MOP2.getReg(), MRI); |
2807 | if (!InstAndDef2) |
2808 | return false; |
2809 | MachineInstr *I1 = InstAndDef1->MI; |
2810 | MachineInstr *I2 = InstAndDef2->MI; |
2811 | |
2812 | // Handle a case like this: |
2813 | // |
2814 | // %0:_(s64), %1:_(s64) = G_UNMERGE_VALUES %2:_(<2 x s64>) |
2815 | // |
2816 | // Even though %0 and %1 are produced by the same instruction they are not |
2817 | // the same values. |
2818 | if (I1 == I2) |
2819 | return MOP1.getReg() == MOP2.getReg(); |
2820 | |
2821 | // If we have an instruction which loads or stores, we can't guarantee that |
2822 | // it is identical. |
2823 | // |
2824 | // For example, we may have |
2825 | // |
2826 | // %x1 = G_LOAD %addr (load N from @somewhere) |
2827 | // ... |
2828 | // call @foo |
2829 | // ... |
2830 | // %x2 = G_LOAD %addr (load N from @somewhere) |
2831 | // ... |
2832 | // %or = G_OR %x1, %x2 |
2833 | // |
2834 | // It's possible that @foo will modify whatever lives at the address we're |
2835 | // loading from. To be safe, let's just assume that all loads and stores |
2836 | // are different (unless we have something which is guaranteed to not |
2837 | // change.) |
2838 | if (I1->mayLoadOrStore() && !I1->isDereferenceableInvariantLoad()) |
2839 | return false; |
2840 | |
2841 | // If both instructions are loads or stores, they are equal only if both |
2842 | // are dereferenceable invariant loads with the same number of bits. |
2843 | if (I1->mayLoadOrStore() && I2->mayLoadOrStore()) { |
2844 | GLoadStore *LS1 = dyn_cast<GLoadStore>(Val: I1); |
2845 | GLoadStore *LS2 = dyn_cast<GLoadStore>(Val: I2); |
2846 | if (!LS1 || !LS2) |
2847 | return false; |
2848 | |
2849 | if (!I2->isDereferenceableInvariantLoad() || |
2850 | (LS1->getMemSizeInBits() != LS2->getMemSizeInBits())) |
2851 | return false; |
2852 | } |
2853 | |
2854 | // Check for physical registers on the instructions first to avoid cases |
2855 | // like this: |
2856 | // |
2857 | // %a = COPY $physreg |
2858 | // ... |
2859 | // SOMETHING implicit-def $physreg |
2860 | // ... |
2861 | // %b = COPY $physreg |
2862 | // |
2863 | // These copies are not equivalent. |
2864 | if (any_of(Range: I1->uses(), P: [](const MachineOperand &MO) { |
2865 | return MO.isReg() && MO.getReg().isPhysical(); |
2866 | })) { |
2867 | // Check if we have a case like this: |
2868 | // |
2869 | // %a = COPY $physreg |
2870 | // %b = COPY %a |
2871 | // |
2872 | // In this case, I1 and I2 will both be equal to %a = COPY $physreg. |
2873 | // From that, we know that they must have the same value, since they must |
2874 | // have come from the same COPY. |
2875 | return I1->isIdenticalTo(Other: *I2); |
2876 | } |
2877 | |
2878 | // We don't have any physical registers, so we don't necessarily need the |
2879 | // same vreg defs. |
2880 | // |
2881 | // On the off-chance that there's some target instruction feeding into the |
2882 | // instruction, let's use produceSameValue instead of isIdenticalTo. |
2883 | if (Builder.getTII().produceSameValue(MI0: *I1, MI1: *I2, MRI: &MRI)) { |
// Handle instructions with multiple defs that produce the same values.
// Values are the same for operands with the same index.
// %0:_(s8), %1:_(s8), %2:_(s8), %3:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>)
// %5:_(s8), %6:_(s8), %7:_(s8), %8:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>)
// I1 and I2 are different instructions but produce the same values;
// %1 and %6 are the same value, while %1 and %7 are not.
2890 | return I1->findRegisterDefOperandIdx(Reg: InstAndDef1->Reg, /*TRI=*/nullptr) == |
2891 | I2->findRegisterDefOperandIdx(Reg: InstAndDef2->Reg, /*TRI=*/nullptr); |
2892 | } |
2893 | return false; |
2894 | } |
2895 | |
2896 | bool CombinerHelper::matchConstantOp(const MachineOperand &MOP, |
2897 | int64_t C) const { |
2898 | if (!MOP.isReg()) |
2899 | return false; |
2900 | auto *MI = MRI.getVRegDef(Reg: MOP.getReg()); |
2901 | auto MaybeCst = isConstantOrConstantSplatVector(MI&: *MI, MRI); |
2902 | return MaybeCst && MaybeCst->getBitWidth() <= 64 && |
2903 | MaybeCst->getSExtValue() == C; |
2904 | } |
2905 | |
2906 | bool CombinerHelper::matchConstantFPOp(const MachineOperand &MOP, |
2907 | double C) const { |
2908 | if (!MOP.isReg()) |
2909 | return false; |
2910 | std::optional<FPValueAndVReg> MaybeCst; |
2911 | if (!mi_match(R: MOP.getReg(), MRI, P: m_GFCstOrSplat(FPValReg&: MaybeCst))) |
2912 | return false; |
2913 | |
2914 | return MaybeCst->Value.isExactlyValue(V: C); |
2915 | } |
2916 | |
2917 | void CombinerHelper::replaceSingleDefInstWithOperand(MachineInstr &MI, |
2918 | unsigned OpIdx) const { |
2919 | assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?" ); |
2920 | Register OldReg = MI.getOperand(i: 0).getReg(); |
2921 | Register Replacement = MI.getOperand(i: OpIdx).getReg(); |
2922 | assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?" ); |
2923 | replaceRegWith(MRI, FromReg: OldReg, ToReg: Replacement); |
2924 | MI.eraseFromParent(); |
2925 | } |
2926 | |
2927 | void CombinerHelper::replaceSingleDefInstWithReg(MachineInstr &MI, |
2928 | Register Replacement) const { |
2929 | assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?" ); |
2930 | Register OldReg = MI.getOperand(i: 0).getReg(); |
2931 | assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?" ); |
2932 | replaceRegWith(MRI, FromReg: OldReg, ToReg: Replacement); |
2933 | MI.eraseFromParent(); |
2934 | } |
2935 | |
2936 | bool CombinerHelper::matchConstantLargerBitWidth(MachineInstr &MI, |
2937 | unsigned ConstIdx) const { |
2938 | Register ConstReg = MI.getOperand(i: ConstIdx).getReg(); |
2939 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
2940 | |
2941 | // Get the shift amount |
2942 | auto VRegAndVal = getIConstantVRegValWithLookThrough(VReg: ConstReg, MRI); |
2943 | if (!VRegAndVal) |
2944 | return false; |
2945 | |
// Return true if the shift amount is >= the bitwidth.
2947 | return (VRegAndVal->Value.uge(RHS: DstTy.getSizeInBits())); |
2948 | } |
2949 | |
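// Illustrative sketch (hypothetical MIR) of the rewrite applied below: a
// constant funnel-shift amount that is >= the bit width is reduced modulo
// the bit width, e.g. for s32:
//
//   %k:_(s32) = G_CONSTANT i32 37
//   %d:_(s32) = G_FSHL %a, %b, %k(s32)
// -->
//   %k2:_(s32) = G_CONSTANT i32 5      ; 37 % 32
//   %d:_(s32)  = G_FSHL %a, %b, %k2(s32)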
2950 | void CombinerHelper::applyFunnelShiftConstantModulo(MachineInstr &MI) const { |
2951 | assert((MI.getOpcode() == TargetOpcode::G_FSHL || |
2952 | MI.getOpcode() == TargetOpcode::G_FSHR) && |
2953 | "This is not a funnel shift operation" ); |
2954 | |
2955 | Register ConstReg = MI.getOperand(i: 3).getReg(); |
2956 | LLT ConstTy = MRI.getType(Reg: ConstReg); |
2957 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
2958 | |
2959 | auto VRegAndVal = getIConstantVRegValWithLookThrough(VReg: ConstReg, MRI); |
2960 | assert((VRegAndVal) && "Value is not a constant" ); |
2961 | |
2962 | // Calculate the new Shift Amount = Old Shift Amount % BitWidth |
2963 | APInt NewConst = VRegAndVal->Value.urem( |
2964 | RHS: APInt(ConstTy.getSizeInBits(), DstTy.getScalarSizeInBits())); |
2965 | |
2966 | auto NewConstInstr = Builder.buildConstant(Res: ConstTy, Val: NewConst.getZExtValue()); |
2967 | Builder.buildInstr( |
2968 | Opc: MI.getOpcode(), DstOps: {MI.getOperand(i: 0)}, |
2969 | SrcOps: {MI.getOperand(i: 1), MI.getOperand(i: 2), NewConstInstr.getReg(Idx: 0)}); |
2970 | |
2971 | MI.eraseFromParent(); |
2972 | } |
2973 | |
2974 | bool CombinerHelper::matchSelectSameVal(MachineInstr &MI) const { |
2975 | assert(MI.getOpcode() == TargetOpcode::G_SELECT); |
2976 | // Match (cond ? x : x) |
2977 | return matchEqualDefs(MOP1: MI.getOperand(i: 2), MOP2: MI.getOperand(i: 3)) && |
2978 | canReplaceReg(DstReg: MI.getOperand(i: 0).getReg(), SrcReg: MI.getOperand(i: 2).getReg(), |
2979 | MRI); |
2980 | } |
2981 | |
2982 | bool CombinerHelper::matchBinOpSameVal(MachineInstr &MI) const { |
2983 | return matchEqualDefs(MOP1: MI.getOperand(i: 1), MOP2: MI.getOperand(i: 2)) && |
2984 | canReplaceReg(DstReg: MI.getOperand(i: 0).getReg(), SrcReg: MI.getOperand(i: 1).getReg(), |
2985 | MRI); |
2986 | } |
2987 | |
2988 | bool CombinerHelper::matchOperandIsZero(MachineInstr &MI, |
2989 | unsigned OpIdx) const { |
2990 | return matchConstantOp(MOP: MI.getOperand(i: OpIdx), C: 0) && |
2991 | canReplaceReg(DstReg: MI.getOperand(i: 0).getReg(), SrcReg: MI.getOperand(i: OpIdx).getReg(), |
2992 | MRI); |
2993 | } |
2994 | |
2995 | bool CombinerHelper::matchOperandIsUndef(MachineInstr &MI, |
2996 | unsigned OpIdx) const { |
2997 | MachineOperand &MO = MI.getOperand(i: OpIdx); |
2998 | return MO.isReg() && |
2999 | getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI); |
3000 | } |
3001 | |
3002 | bool CombinerHelper::matchOperandIsKnownToBeAPowerOfTwo(MachineInstr &MI, |
3003 | unsigned OpIdx) const { |
3004 | MachineOperand &MO = MI.getOperand(i: OpIdx); |
3005 | return isKnownToBeAPowerOfTwo(Val: MO.getReg(), MRI, ValueTracking: VT); |
3006 | } |
3007 | |
3008 | void CombinerHelper::replaceInstWithFConstant(MachineInstr &MI, |
3009 | double C) const { |
3010 | assert(MI.getNumDefs() == 1 && "Expected only one def?" ); |
3011 | Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: C); |
3012 | MI.eraseFromParent(); |
3013 | } |
3014 | |
3015 | void CombinerHelper::replaceInstWithConstant(MachineInstr &MI, |
3016 | int64_t C) const { |
3017 | assert(MI.getNumDefs() == 1 && "Expected only one def?" ); |
3018 | Builder.buildConstant(Res: MI.getOperand(i: 0), Val: C); |
3019 | MI.eraseFromParent(); |
3020 | } |
3021 | |
3022 | void CombinerHelper::replaceInstWithConstant(MachineInstr &MI, APInt C) const { |
3023 | assert(MI.getNumDefs() == 1 && "Expected only one def?" ); |
3024 | Builder.buildConstant(Res: MI.getOperand(i: 0), Val: C); |
3025 | MI.eraseFromParent(); |
3026 | } |
3027 | |
3028 | void CombinerHelper::replaceInstWithFConstant(MachineInstr &MI, |
3029 | ConstantFP *CFP) const { |
3030 | assert(MI.getNumDefs() == 1 && "Expected only one def?" ); |
3031 | Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: CFP->getValueAPF()); |
3032 | MI.eraseFromParent(); |
3033 | } |
3034 | |
3035 | void CombinerHelper::replaceInstWithUndef(MachineInstr &MI) const { |
3036 | assert(MI.getNumDefs() == 1 && "Expected only one def?" ); |
3037 | Builder.buildUndef(Res: MI.getOperand(i: 0)); |
3038 | MI.eraseFromParent(); |
3039 | } |
3040 | |
3041 | bool CombinerHelper::matchSimplifyAddToSub( |
3042 | MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) const { |
3043 | Register LHS = MI.getOperand(i: 1).getReg(); |
3044 | Register RHS = MI.getOperand(i: 2).getReg(); |
3045 | Register &NewLHS = std::get<0>(t&: MatchInfo); |
3046 | Register &NewRHS = std::get<1>(t&: MatchInfo); |
3047 | |
3048 | // Helper lambda to check for opportunities for |
3049 | // ((0-A) + B) -> B - A |
3050 | // (A + (0-B)) -> A - B |
3051 | auto CheckFold = [&](Register &MaybeSub, Register &MaybeNewLHS) { |
3052 | if (!mi_match(R: MaybeSub, MRI, P: m_Neg(Src: m_Reg(R&: NewRHS)))) |
3053 | return false; |
3054 | NewLHS = MaybeNewLHS; |
3055 | return true; |
3056 | }; |
3057 | |
3058 | return CheckFold(LHS, RHS) || CheckFold(RHS, LHS); |
3059 | } |
3060 | |
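// Illustrative sketch (hypothetical MIR) of the combine matched below: a
// chain of constant-index G_INSERT_VECTOR_ELTs that starts from an undef
// (or fully overwrites its source) collapses into one G_BUILD_VECTOR.
//
//   %u:_(<2 x s32>)  = G_IMPLICIT_DEF
//   %v0:_(<2 x s32>) = G_INSERT_VECTOR_ELT %u, %a(s32), %c0(s64)   ; idx 0
//   %v1:_(<2 x s32>) = G_INSERT_VECTOR_ELT %v0, %b(s32), %c1(s64)  ; idx 1
// -->
//   %v1:_(<2 x s32>) = G_BUILD_VECTOR %a(s32), %b(s32)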
3061 | bool CombinerHelper::matchCombineInsertVecElts( |
3062 | MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) const { |
3063 | assert(MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT && |
3064 | "Invalid opcode" ); |
3065 | Register DstReg = MI.getOperand(i: 0).getReg(); |
3066 | LLT DstTy = MRI.getType(Reg: DstReg); |
3067 | assert(DstTy.isVector() && "Invalid G_INSERT_VECTOR_ELT?" ); |
3068 | |
3069 | if (DstTy.isScalableVector()) |
3070 | return false; |
3071 | |
3072 | unsigned NumElts = DstTy.getNumElements(); |
3073 | // If this MI is part of a sequence of insert_vec_elts, then |
3074 | // don't do the combine in the middle of the sequence. |
3075 | if (MRI.hasOneUse(RegNo: DstReg) && MRI.use_instr_begin(RegNo: DstReg)->getOpcode() == |
3076 | TargetOpcode::G_INSERT_VECTOR_ELT) |
3077 | return false; |
3078 | MachineInstr *CurrInst = &MI; |
3079 | MachineInstr *TmpInst; |
3080 | int64_t IntImm; |
3081 | Register TmpReg; |
3082 | MatchInfo.resize(N: NumElts); |
3083 | while (mi_match( |
3084 | R: CurrInst->getOperand(i: 0).getReg(), MRI, |
3085 | P: m_GInsertVecElt(Src0: m_MInstr(MI&: TmpInst), Src1: m_Reg(R&: TmpReg), Src2: m_ICst(Cst&: IntImm)))) { |
3086 | if (IntImm >= NumElts || IntImm < 0) |
3087 | return false; |
3088 | if (!MatchInfo[IntImm]) |
3089 | MatchInfo[IntImm] = TmpReg; |
3090 | CurrInst = TmpInst; |
3091 | } |
3092 | // Variable index. |
3093 | if (CurrInst->getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT) |
3094 | return false; |
3095 | if (TmpInst->getOpcode() == TargetOpcode::G_BUILD_VECTOR) { |
3096 | for (unsigned I = 1; I < TmpInst->getNumOperands(); ++I) { |
3097 | if (!MatchInfo[I - 1].isValid()) |
3098 | MatchInfo[I - 1] = TmpInst->getOperand(i: I).getReg(); |
3099 | } |
3100 | return true; |
3101 | } |
3102 | // If we didn't end in a G_IMPLICIT_DEF and the source is not fully |
3103 | // overwritten, bail out. |
3104 | return TmpInst->getOpcode() == TargetOpcode::G_IMPLICIT_DEF || |
3105 | all_of(Range&: MatchInfo, P: [](Register Reg) { return !!Reg; }); |
3106 | } |
3107 | |
3108 | void CombinerHelper::applyCombineInsertVecElts( |
3109 | MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) const { |
3110 | Register UndefReg; |
3111 | auto GetUndef = [&]() { |
3112 | if (UndefReg) |
3113 | return UndefReg; |
3114 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
3115 | UndefReg = Builder.buildUndef(Res: DstTy.getScalarType()).getReg(Idx: 0); |
3116 | return UndefReg; |
3117 | }; |
3118 | for (Register &Reg : MatchInfo) { |
3119 | if (!Reg) |
3120 | Reg = GetUndef(); |
3121 | } |
3122 | Builder.buildBuildVector(Res: MI.getOperand(i: 0).getReg(), Ops: MatchInfo); |
3123 | MI.eraseFromParent(); |
3124 | } |
3125 | |
3126 | void CombinerHelper::applySimplifyAddToSub( |
3127 | MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) const { |
3128 | Register SubLHS, SubRHS; |
3129 | std::tie(args&: SubLHS, args&: SubRHS) = MatchInfo; |
3130 | Builder.buildSub(Dst: MI.getOperand(i: 0).getReg(), Src0: SubLHS, Src1: SubRHS); |
3131 | MI.eraseFromParent(); |
3132 | } |
3133 | |
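// Illustrative sketch (hypothetical MIR) of the hoisting matched below,
// shown for the shift case: when both hands share the same opcode and the
// same extra operand, the logic op can be performed first under a single
// hand instruction.
//
//   %l:_(s32) = G_LSHR %x:_(s32), %z
//   %r:_(s32) = G_LSHR %y:_(s32), %z
//   %d:_(s32) = G_OR %l, %r
// -->
//   %o:_(s32) = G_OR %x, %y
//   %d:_(s32) = G_LSHR %o, %z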
3134 | bool CombinerHelper::matchHoistLogicOpWithSameOpcodeHands( |
3135 | MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) const { |
3136 | // Matches: logic (hand x, ...), (hand y, ...) -> hand (logic x, y), ... |
3137 | // |
3138 | // Creates the new hand + logic instruction (but does not insert them.) |
3139 | // |
3140 | // On success, MatchInfo is populated with the new instructions. These are |
3141 | // inserted in applyHoistLogicOpWithSameOpcodeHands. |
3142 | unsigned LogicOpcode = MI.getOpcode(); |
3143 | assert(LogicOpcode == TargetOpcode::G_AND || |
3144 | LogicOpcode == TargetOpcode::G_OR || |
3145 | LogicOpcode == TargetOpcode::G_XOR); |
3146 | MachineIRBuilder MIB(MI); |
3147 | Register Dst = MI.getOperand(i: 0).getReg(); |
3148 | Register LHSReg = MI.getOperand(i: 1).getReg(); |
3149 | Register RHSReg = MI.getOperand(i: 2).getReg(); |
3150 | |
3151 | // Don't recompute anything. |
3152 | if (!MRI.hasOneNonDBGUse(RegNo: LHSReg) || !MRI.hasOneNonDBGUse(RegNo: RHSReg)) |
3153 | return false; |
3154 | |
3155 | // Make sure we have (hand x, ...), (hand y, ...) |
3156 | MachineInstr *LeftHandInst = getDefIgnoringCopies(Reg: LHSReg, MRI); |
3157 | MachineInstr *RightHandInst = getDefIgnoringCopies(Reg: RHSReg, MRI); |
3158 | if (!LeftHandInst || !RightHandInst) |
3159 | return false; |
3160 | unsigned HandOpcode = LeftHandInst->getOpcode(); |
3161 | if (HandOpcode != RightHandInst->getOpcode()) |
3162 | return false; |
3163 | if (LeftHandInst->getNumOperands() < 2 || |
3164 | !LeftHandInst->getOperand(i: 1).isReg() || |
3165 | RightHandInst->getNumOperands() < 2 || |
3166 | !RightHandInst->getOperand(i: 1).isReg()) |
3167 | return false; |
3168 | |
3169 | // Make sure the types match up, and if we're doing this post-legalization, |
3170 | // we end up with legal types. |
3171 | Register X = LeftHandInst->getOperand(i: 1).getReg(); |
3172 | Register Y = RightHandInst->getOperand(i: 1).getReg(); |
3173 | LLT XTy = MRI.getType(Reg: X); |
3174 | LLT YTy = MRI.getType(Reg: Y); |
3175 | if (!XTy.isValid() || XTy != YTy) |
3176 | return false; |
3177 | |
3178 | // Optional extra source register. |
3179 | Register ExtraHandOpSrcReg; |
3180 | switch (HandOpcode) { |
3181 | default: |
3182 | return false; |
3183 | case TargetOpcode::G_ANYEXT: |
3184 | case TargetOpcode::G_SEXT: |
3185 | case TargetOpcode::G_ZEXT: { |
3186 | // Match: logic (ext X), (ext Y) --> ext (logic X, Y) |
3187 | break; |
3188 | } |
3189 | case TargetOpcode::G_TRUNC: { |
3190 | // Match: logic (trunc X), (trunc Y) -> trunc (logic X, Y) |
3191 | const MachineFunction *MF = MI.getMF(); |
3192 | LLVMContext &Ctx = MF->getFunction().getContext(); |
3193 | |
3194 | LLT DstTy = MRI.getType(Reg: Dst); |
3195 | const TargetLowering &TLI = getTargetLowering(); |
3196 | |
3197 | // Be extra careful sinking truncate. If it's free, there's no benefit in |
3198 | // widening a binop. |
3199 | if (TLI.isZExtFree(FromTy: DstTy, ToTy: XTy, Ctx) && TLI.isTruncateFree(FromTy: XTy, ToTy: DstTy, Ctx)) |
3200 | return false; |
3201 | break; |
3202 | } |
3203 | case TargetOpcode::G_AND: |
3204 | case TargetOpcode::G_ASHR: |
3205 | case TargetOpcode::G_LSHR: |
3206 | case TargetOpcode::G_SHL: { |
3207 | // Match: logic (binop x, z), (binop y, z) -> binop (logic x, y), z |
3208 | MachineOperand &ZOp = LeftHandInst->getOperand(i: 2); |
3209 | if (!matchEqualDefs(MOP1: ZOp, MOP2: RightHandInst->getOperand(i: 2))) |
3210 | return false; |
3211 | ExtraHandOpSrcReg = ZOp.getReg(); |
3212 | break; |
3213 | } |
3214 | } |
3215 | |
3216 | if (!isLegalOrBeforeLegalizer(Query: {LogicOpcode, {XTy, YTy}})) |
3217 | return false; |
3218 | |
3219 | // Record the steps to build the new instructions. |
3220 | // |
3221 | // Steps to build (logic x, y) |
3222 | auto NewLogicDst = MRI.createGenericVirtualRegister(Ty: XTy); |
3223 | OperandBuildSteps LogicBuildSteps = { |
3224 | [=](MachineInstrBuilder &MIB) { MIB.addDef(RegNo: NewLogicDst); }, |
3225 | [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: X); }, |
3226 | [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: Y); }}; |
3227 | InstructionBuildSteps LogicSteps(LogicOpcode, LogicBuildSteps); |
3228 | |
3229 | // Steps to build hand (logic x, y), ...z |
3230 | OperandBuildSteps HandBuildSteps = { |
3231 | [=](MachineInstrBuilder &MIB) { MIB.addDef(RegNo: Dst); }, |
3232 | [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: NewLogicDst); }}; |
3233 | if (ExtraHandOpSrcReg.isValid()) |
3234 | HandBuildSteps.push_back( |
3235 | Elt: [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: ExtraHandOpSrcReg); }); |
3236 | InstructionBuildSteps HandSteps(HandOpcode, HandBuildSteps); |
3237 | |
3238 | MatchInfo = InstructionStepsMatchInfo({LogicSteps, HandSteps}); |
3239 | return true; |
3240 | } |
3241 | |
3242 | void CombinerHelper::applyBuildInstructionSteps( |
3243 | MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) const { |
3244 | assert(MatchInfo.InstrsToBuild.size() && |
3245 | "Expected at least one instr to build?" ); |
3246 | for (auto &InstrToBuild : MatchInfo.InstrsToBuild) { |
3247 | assert(InstrToBuild.Opcode && "Expected a valid opcode?" ); |
3248 | assert(InstrToBuild.OperandFns.size() && "Expected at least one operand?" ); |
3249 | MachineInstrBuilder Instr = Builder.buildInstr(Opcode: InstrToBuild.Opcode); |
3250 | for (auto &OperandFn : InstrToBuild.OperandFns) |
3251 | OperandFn(Instr); |
3252 | } |
3253 | MI.eraseFromParent(); |
3254 | } |
3255 | |
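// Illustrative sketch (hypothetical MIR) of the fold matched below: a left
// shift followed by an arithmetic right shift by the same constant amount
// is a sign extension of the low (Size - ShiftAmt) bits.
//
//   %c:_(s32) = G_CONSTANT i32 24
//   %s:_(s32) = G_SHL %x:_(s32), %c(s32)
//   %d:_(s32) = G_ASHR %s, %c(s32)
// -->
//   %d:_(s32) = G_SEXT_INREG %x, 8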
3256 | bool CombinerHelper::matchAshrShlToSextInreg( |
3257 | MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) const { |
3258 | assert(MI.getOpcode() == TargetOpcode::G_ASHR); |
3259 | int64_t ShlCst, AshrCst; |
3260 | Register Src; |
3261 | if (!mi_match(R: MI.getOperand(i: 0).getReg(), MRI, |
3262 | P: m_GAShr(L: m_GShl(L: m_Reg(R&: Src), R: m_ICstOrSplat(Cst&: ShlCst)), |
3263 | R: m_ICstOrSplat(Cst&: AshrCst)))) |
3264 | return false; |
3265 | if (ShlCst != AshrCst) |
3266 | return false; |
3267 | if (!isLegalOrBeforeLegalizer( |
3268 | Query: {TargetOpcode::G_SEXT_INREG, {MRI.getType(Reg: Src)}})) |
3269 | return false; |
3270 | MatchInfo = std::make_tuple(args&: Src, args&: ShlCst); |
3271 | return true; |
3272 | } |
3273 | |
3274 | void CombinerHelper::applyAshShlToSextInreg( |
3275 | MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) const { |
3276 | assert(MI.getOpcode() == TargetOpcode::G_ASHR); |
3277 | Register Src; |
3278 | int64_t ShiftAmt; |
3279 | std::tie(args&: Src, args&: ShiftAmt) = MatchInfo; |
3280 | unsigned Size = MRI.getType(Reg: Src).getScalarSizeInBits(); |
3281 | Builder.buildSExtInReg(Res: MI.getOperand(i: 0).getReg(), Op: Src, ImmOp: Size - ShiftAmt); |
3282 | MI.eraseFromParent(); |
3283 | } |
3284 | |
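// Worked example (hypothetical masks) for the fold below: nested masks
// either overlap, in which case a single G_AND with their intersection
// suffices, or they are disjoint and the result is zero.
//
//   and(and(x, 0xF0), 0x3C) -> and(x, 0x30)   ; 0xF0 & 0x3C == 0x30
//   and(and(x, 0xF0), 0x0F) -> 0              ; masks do not overlap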
3285 | /// and(and(x, C1), C2) -> C1&C2 ? and(x, C1&C2) : 0 |
3286 | bool CombinerHelper::matchOverlappingAnd( |
3287 | MachineInstr &MI, |
3288 | std::function<void(MachineIRBuilder &)> &MatchInfo) const { |
3289 | assert(MI.getOpcode() == TargetOpcode::G_AND); |
3290 | |
3291 | Register Dst = MI.getOperand(i: 0).getReg(); |
3292 | LLT Ty = MRI.getType(Reg: Dst); |
3293 | |
3294 | Register R; |
3295 | int64_t C1; |
3296 | int64_t C2; |
3297 | if (!mi_match( |
3298 | R: Dst, MRI, |
3299 | P: m_GAnd(L: m_GAnd(L: m_Reg(R), R: m_ICst(Cst&: C1)), R: m_ICst(Cst&: C2)))) |
3300 | return false; |
3301 | |
3302 | MatchInfo = [=](MachineIRBuilder &B) { |
3303 | if (C1 & C2) { |
3304 | B.buildAnd(Dst, Src0: R, Src1: B.buildConstant(Res: Ty, Val: C1 & C2)); |
3305 | return; |
3306 | } |
3307 | auto Zero = B.buildConstant(Res: Ty, Val: 0); |
3308 | replaceRegWith(MRI, FromReg: Dst, ToReg: Zero->getOperand(i: 0).getReg()); |
3309 | }; |
3310 | return true; |
3311 | } |
3312 | |
3313 | bool CombinerHelper::matchRedundantAnd(MachineInstr &MI, |
3314 | Register &Replacement) const { |
3315 | // Given |
3316 | // |
3317 | // %y:_(sN) = G_SOMETHING |
3318 | // %x:_(sN) = G_SOMETHING |
3319 | // %res:_(sN) = G_AND %x, %y |
3320 | // |
3321 | // Eliminate the G_AND when it is known that x & y == x or x & y == y. |
3322 | // |
3323 | // Patterns like this can appear as a result of legalization. E.g. |
3324 | // |
3325 | // %cmp:_(s32) = G_ICMP intpred(pred), %x(s32), %y |
3326 | // %one:_(s32) = G_CONSTANT i32 1 |
3327 | // %and:_(s32) = G_AND %cmp, %one |
3328 | // |
3329 | // In this case, G_ICMP only produces a single bit, so x & 1 == x. |
3330 | assert(MI.getOpcode() == TargetOpcode::G_AND); |
3331 | if (!VT) |
3332 | return false; |
3333 | |
3334 | Register AndDst = MI.getOperand(i: 0).getReg(); |
3335 | Register LHS = MI.getOperand(i: 1).getReg(); |
3336 | Register RHS = MI.getOperand(i: 2).getReg(); |
3337 | |
3338 | // Check the RHS (maybe a constant) first, and if we have no KnownBits there, |
3339 | // we can't do anything. If we do, then it depends on whether we have |
3340 | // KnownBits on the LHS. |
3341 | KnownBits RHSBits = VT->getKnownBits(R: RHS); |
3342 | if (RHSBits.isUnknown()) |
3343 | return false; |
3344 | |
3345 | KnownBits LHSBits = VT->getKnownBits(R: LHS); |
3346 | |
3347 | // Check that x & Mask == x. |
3348 | // x & 1 == x, always |
3349 | // x & 0 == x, only if x is also 0 |
3350 | // Meaning Mask has no effect if every bit is either one in Mask or zero in x. |
3351 | // |
3352 | // Check if we can replace AndDst with the LHS of the G_AND |
3353 | if (canReplaceReg(DstReg: AndDst, SrcReg: LHS, MRI) && |
3354 | (LHSBits.Zero | RHSBits.One).isAllOnes()) { |
3355 | Replacement = LHS; |
3356 | return true; |
3357 | } |
3358 | |
3359 | // Check if we can replace AndDst with the RHS of the G_AND |
3360 | if (canReplaceReg(DstReg: AndDst, SrcReg: RHS, MRI) && |
3361 | (LHSBits.One | RHSBits.Zero).isAllOnes()) { |
3362 | Replacement = RHS; |
3363 | return true; |
3364 | } |
3365 | |
3366 | return false; |
3367 | } |
3368 | |
3369 | bool CombinerHelper::matchRedundantOr(MachineInstr &MI, |
3370 | Register &Replacement) const { |
3371 | // Given |
3372 | // |
3373 | // %y:_(sN) = G_SOMETHING |
3374 | // %x:_(sN) = G_SOMETHING |
3375 | // %res:_(sN) = G_OR %x, %y |
3376 | // |
3377 | // Eliminate the G_OR when it is known that x | y == x or x | y == y. |
3378 | assert(MI.getOpcode() == TargetOpcode::G_OR); |
3379 | if (!VT) |
3380 | return false; |
3381 | |
3382 | Register OrDst = MI.getOperand(i: 0).getReg(); |
3383 | Register LHS = MI.getOperand(i: 1).getReg(); |
3384 | Register RHS = MI.getOperand(i: 2).getReg(); |
3385 | |
3386 | KnownBits LHSBits = VT->getKnownBits(R: LHS); |
3387 | KnownBits RHSBits = VT->getKnownBits(R: RHS); |
3388 | |
3389 | // Check that x | Mask == x. |
3390 | // x | 0 == x, always |
3391 | // x | 1 == x, only if x is also 1 |
3392 | // Meaning Mask has no effect if every bit is either zero in Mask or one in x. |
3393 | // |
3394 | // Check if we can replace OrDst with the LHS of the G_OR |
3395 | if (canReplaceReg(DstReg: OrDst, SrcReg: LHS, MRI) && |
3396 | (LHSBits.One | RHSBits.Zero).isAllOnes()) { |
3397 | Replacement = LHS; |
3398 | return true; |
3399 | } |
3400 | |
3401 | // Check if we can replace OrDst with the RHS of the G_OR |
3402 | if (canReplaceReg(DstReg: OrDst, SrcReg: RHS, MRI) && |
3403 | (LHSBits.Zero | RHSBits.One).isAllOnes()) { |
3404 | Replacement = RHS; |
3405 | return true; |
3406 | } |
3407 | |
3408 | return false; |
3409 | } |
3410 | |
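// Illustrative sketch (hypothetical MIR) of the check below: for s32, a
// G_SEXT_INREG from bit width 8 is redundant if the source already has at
// least 32 - 8 + 1 = 25 known sign bits.
//
//   %x:_(s32) = G_SEXT_INREG %y:_(s32), 8   ; %x has >= 25 sign bits
//   %d:_(s32) = G_SEXT_INREG %x, 8
// -->
//   uses of %d are rewritten to use %x directly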
3411 | bool CombinerHelper::matchRedundantSExtInReg(MachineInstr &MI) const { |
3412 | // If the input is already sign extended, just drop the extension. |
3413 | Register Src = MI.getOperand(i: 1).getReg(); |
3414 | unsigned ExtBits = MI.getOperand(i: 2).getImm(); |
3415 | unsigned TypeSize = MRI.getType(Reg: Src).getScalarSizeInBits(); |
3416 | return VT->computeNumSignBits(R: Src) >= (TypeSize - ExtBits + 1); |
3417 | } |
3418 | |
3419 | static bool isConstValidTrue(const TargetLowering &TLI, unsigned ScalarSizeBits, |
3420 | int64_t Cst, bool IsVector, bool IsFP) { |
3421 | // For i1, Cst will always be -1 regardless of boolean contents. |
3422 | return (ScalarSizeBits == 1 && Cst == -1) || |
3423 | isConstTrueVal(TLI, Val: Cst, IsVector, IsFP); |
3424 | } |
3425 | |
3426 | // This combine tries to reduce the number of scalarised G_TRUNC instructions by |
3427 | // using vector truncates instead |
3428 | // |
3429 | // EXAMPLE: |
// %a(s32), %b(s32) = G_UNMERGE_VALUES %src(<2 x s32>)
// %T_a(s16) = G_TRUNC %a(s32)
// %T_b(s16) = G_TRUNC %b(s32)
// %Undef(s16) = G_IMPLICIT_DEF
// %dst(<4 x s16>) = G_BUILD_VECTOR %T_a(s16), %T_b(s16), %Undef(s16), %Undef(s16)
//
// ===>
// %Undef(<2 x s32>) = G_IMPLICIT_DEF
// %Mid(<4 x s32>) = G_CONCAT_VECTORS %src(<2 x s32>), %Undef(<2 x s32>)
// %dst(<4 x s16>) = G_TRUNC %Mid(<4 x s32>)
3440 | // |
3441 | // Only matches sources made up of G_TRUNCs followed by G_IMPLICIT_DEFs |
3442 | bool CombinerHelper::matchUseVectorTruncate(MachineInstr &MI, |
3443 | Register &MatchInfo) const { |
3444 | auto BuildMI = cast<GBuildVector>(Val: &MI); |
3445 | unsigned NumOperands = BuildMI->getNumSources(); |
3446 | LLT DstTy = MRI.getType(Reg: BuildMI->getReg(Idx: 0)); |
3447 | |
3448 | // Check the G_BUILD_VECTOR sources |
3449 | unsigned I; |
3450 | MachineInstr *UnmergeMI = nullptr; |
3451 | |
3452 | // Check all source TRUNCs come from the same UNMERGE instruction |
3453 | for (I = 0; I < NumOperands; ++I) { |
3454 | auto SrcMI = MRI.getVRegDef(Reg: BuildMI->getSourceReg(I)); |
3455 | auto SrcMIOpc = SrcMI->getOpcode(); |
3456 | |
3457 | // Check if the G_TRUNC instructions all come from the same MI |
3458 | if (SrcMIOpc == TargetOpcode::G_TRUNC) { |
3459 | if (!UnmergeMI) { |
3460 | UnmergeMI = MRI.getVRegDef(Reg: SrcMI->getOperand(i: 1).getReg()); |
3461 | if (UnmergeMI->getOpcode() != TargetOpcode::G_UNMERGE_VALUES) |
3462 | return false; |
3463 | } else { |
3464 | auto UnmergeSrcMI = MRI.getVRegDef(Reg: SrcMI->getOperand(i: 1).getReg()); |
3465 | if (UnmergeMI != UnmergeSrcMI) |
3466 | return false; |
3467 | } |
3468 | } else { |
3469 | break; |
3470 | } |
3471 | } |
3472 | if (I < 2) |
3473 | return false; |
3474 | |
3475 | // Check the remaining source elements are only G_IMPLICIT_DEF |
3476 | for (; I < NumOperands; ++I) { |
3477 | auto SrcMI = MRI.getVRegDef(Reg: BuildMI->getSourceReg(I)); |
3478 | auto SrcMIOpc = SrcMI->getOpcode(); |
3479 | |
3480 | if (SrcMIOpc != TargetOpcode::G_IMPLICIT_DEF) |
3481 | return false; |
3482 | } |
3483 | |
3484 | // Check the size of unmerge source |
3485 | MatchInfo = cast<GUnmerge>(Val: UnmergeMI)->getSourceReg(); |
3486 | LLT UnmergeSrcTy = MRI.getType(Reg: MatchInfo); |
3487 | if (!DstTy.getElementCount().isKnownMultipleOf(RHS: UnmergeSrcTy.getNumElements())) |
3488 | return false; |
3489 | |
3490 | // Check the unmerge source and destination element types match |
3491 | LLT UnmergeSrcEltTy = UnmergeSrcTy.getElementType(); |
3492 | Register UnmergeDstReg = UnmergeMI->getOperand(i: 0).getReg(); |
3493 | LLT UnmergeDstEltTy = MRI.getType(Reg: UnmergeDstReg); |
3494 | if (UnmergeSrcEltTy != UnmergeDstEltTy) |
3495 | return false; |
3496 | |
3497 | // Only generate legal instructions post-legalizer |
3498 | if (!IsPreLegalize) { |
3499 | LLT MidTy = DstTy.changeElementType(NewEltTy: UnmergeSrcTy.getScalarType()); |
3500 | |
3501 | if (DstTy.getElementCount() != UnmergeSrcTy.getElementCount() && |
3502 | !isLegal(Query: {TargetOpcode::G_CONCAT_VECTORS, {MidTy, UnmergeSrcTy}})) |
3503 | return false; |
3504 | |
3505 | if (!isLegal(Query: {TargetOpcode::G_TRUNC, {DstTy, MidTy}})) |
3506 | return false; |
3507 | } |
3508 | |
3509 | return true; |
3510 | } |
3511 | |
3512 | void CombinerHelper::applyUseVectorTruncate(MachineInstr &MI, |
3513 | Register &MatchInfo) const { |
3514 | Register MidReg; |
3515 | auto BuildMI = cast<GBuildVector>(Val: &MI); |
3516 | Register DstReg = BuildMI->getReg(Idx: 0); |
3517 | LLT DstTy = MRI.getType(Reg: DstReg); |
3518 | LLT UnmergeSrcTy = MRI.getType(Reg: MatchInfo); |
3519 | unsigned DstTyNumElt = DstTy.getNumElements(); |
3520 | unsigned UnmergeSrcTyNumElt = UnmergeSrcTy.getNumElements(); |
3521 | |
3522 | // No need to pad vector if only G_TRUNC is needed |
3523 | if (DstTyNumElt / UnmergeSrcTyNumElt == 1) { |
3524 | MidReg = MatchInfo; |
3525 | } else { |
3526 | Register UndefReg = Builder.buildUndef(Res: UnmergeSrcTy).getReg(Idx: 0); |
3527 | SmallVector<Register> ConcatRegs = {MatchInfo}; |
3528 | for (unsigned I = 1; I < DstTyNumElt / UnmergeSrcTyNumElt; ++I) |
3529 | ConcatRegs.push_back(Elt: UndefReg); |
3530 | |
3531 | auto MidTy = DstTy.changeElementType(NewEltTy: UnmergeSrcTy.getScalarType()); |
3532 | MidReg = Builder.buildConcatVectors(Res: MidTy, Ops: ConcatRegs).getReg(Idx: 0); |
3533 | } |
3534 | |
3535 | Builder.buildTrunc(Res: DstReg, Op: MidReg); |
3536 | MI.eraseFromParent(); |
3537 | } |
3538 | |
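// Illustrative sketch (hypothetical MIR) of the transform matched below: an
// xor-with-true of a tree of compares joined by G_AND/G_OR is folded by
// inverting every compare predicate and applying De Morgan's laws to the
// connectives.
//
//   %a:_(s1) = G_ICMP intpred(eq), %x, %y
//   %b:_(s1) = G_ICMP intpred(ult), %x, %z
//   %c:_(s1) = G_AND %a, %b
//   %t:_(s1) = G_CONSTANT i1 true
//   %d:_(s1) = G_XOR %c, %t
// -->
//   %a:_(s1) = G_ICMP intpred(ne), %x, %y
//   %b:_(s1) = G_ICMP intpred(uge), %x, %z
//   %c:_(s1) = G_OR %a, %b          ; uses of %d are rewritten to %c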
3539 | bool CombinerHelper::matchNotCmp( |
3540 | MachineInstr &MI, SmallVectorImpl<Register> &RegsToNegate) const { |
3541 | assert(MI.getOpcode() == TargetOpcode::G_XOR); |
3542 | LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
3543 | const auto &TLI = *Builder.getMF().getSubtarget().getTargetLowering(); |
3544 | Register XorSrc; |
3545 | Register CstReg; |
3546 | // We match xor(src, true) here. |
3547 | if (!mi_match(R: MI.getOperand(i: 0).getReg(), MRI, |
3548 | P: m_GXor(L: m_Reg(R&: XorSrc), R: m_Reg(R&: CstReg)))) |
3549 | return false; |
3550 | |
3551 | if (!MRI.hasOneNonDBGUse(RegNo: XorSrc)) |
3552 | return false; |
3553 | |
3554 | // Check that XorSrc is the root of a tree of comparisons combined with ANDs |
// and ORs. The suffix of RegsToNegate starting from index I is used as a
// work list of tree nodes to visit.
3557 | RegsToNegate.push_back(Elt: XorSrc); |
3558 | // Remember whether the comparisons are all integer or all floating point. |
3559 | bool IsInt = false; |
3560 | bool IsFP = false; |
3561 | for (unsigned I = 0; I < RegsToNegate.size(); ++I) { |
3562 | Register Reg = RegsToNegate[I]; |
3563 | if (!MRI.hasOneNonDBGUse(RegNo: Reg)) |
3564 | return false; |
3565 | MachineInstr *Def = MRI.getVRegDef(Reg); |
3566 | switch (Def->getOpcode()) { |
3567 | default: |
3568 | // Don't match if the tree contains anything other than ANDs, ORs and |
3569 | // comparisons. |
3570 | return false; |
3571 | case TargetOpcode::G_ICMP: |
3572 | if (IsFP) |
3573 | return false; |
3574 | IsInt = true; |
3575 | // When we apply the combine we will invert the predicate. |
3576 | break; |
3577 | case TargetOpcode::G_FCMP: |
3578 | if (IsInt) |
3579 | return false; |
3580 | IsFP = true; |
3581 | // When we apply the combine we will invert the predicate. |
3582 | break; |
3583 | case TargetOpcode::G_AND: |
3584 | case TargetOpcode::G_OR: |
3585 | // Implement De Morgan's laws: |
3586 | // ~(x & y) -> ~x | ~y |
3587 | // ~(x | y) -> ~x & ~y |
3588 | // When we apply the combine we will change the opcode and recursively |
3589 | // negate the operands. |
3590 | RegsToNegate.push_back(Elt: Def->getOperand(i: 1).getReg()); |
3591 | RegsToNegate.push_back(Elt: Def->getOperand(i: 2).getReg()); |
3592 | break; |
3593 | } |
3594 | } |
3595 | |
3596 | // Now we know whether the comparisons are integer or floating point, check |
3597 | // the constant in the xor. |
3598 | int64_t Cst; |
3599 | if (Ty.isVector()) { |
3600 | MachineInstr *CstDef = MRI.getVRegDef(Reg: CstReg); |
3601 | auto MaybeCst = getIConstantSplatSExtVal(MI: *CstDef, MRI); |
3602 | if (!MaybeCst) |
3603 | return false; |
3604 | if (!isConstValidTrue(TLI, ScalarSizeBits: Ty.getScalarSizeInBits(), Cst: *MaybeCst, IsVector: true, IsFP)) |
3605 | return false; |
3606 | } else { |
3607 | if (!mi_match(R: CstReg, MRI, P: m_ICst(Cst))) |
3608 | return false; |
3609 | if (!isConstValidTrue(TLI, ScalarSizeBits: Ty.getSizeInBits(), Cst, IsVector: false, IsFP)) |
3610 | return false; |
3611 | } |
3612 | |
3613 | return true; |
3614 | } |
3615 | |
3616 | void CombinerHelper::applyNotCmp( |
3617 | MachineInstr &MI, SmallVectorImpl<Register> &RegsToNegate) const { |
3618 | for (Register Reg : RegsToNegate) { |
3619 | MachineInstr *Def = MRI.getVRegDef(Reg); |
3620 | Observer.changingInstr(MI&: *Def); |
3621 | // For each comparison, invert the opcode. For each AND and OR, change the |
3622 | // opcode. |
3623 | switch (Def->getOpcode()) { |
3624 | default: |
3625 | llvm_unreachable("Unexpected opcode" ); |
3626 | case TargetOpcode::G_ICMP: |
3627 | case TargetOpcode::G_FCMP: { |
3628 | MachineOperand &PredOp = Def->getOperand(i: 1); |
3629 | CmpInst::Predicate NewP = CmpInst::getInversePredicate( |
3630 | pred: (CmpInst::Predicate)PredOp.getPredicate()); |
3631 | PredOp.setPredicate(NewP); |
3632 | break; |
3633 | } |
3634 | case TargetOpcode::G_AND: |
3635 | Def->setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_OR)); |
3636 | break; |
3637 | case TargetOpcode::G_OR: |
3638 | Def->setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_AND)); |
3639 | break; |
3640 | } |
3641 | Observer.changedInstr(MI&: *Def); |
3642 | } |
3643 | |
3644 | replaceRegWith(MRI, FromReg: MI.getOperand(i: 0).getReg(), ToReg: MI.getOperand(i: 1).getReg()); |
3645 | MI.eraseFromParent(); |
3646 | } |
3647 | |
3648 | bool CombinerHelper::matchXorOfAndWithSameReg( |
3649 | MachineInstr &MI, std::pair<Register, Register> &MatchInfo) const { |
3650 | // Match (xor (and x, y), y) (or any of its commuted cases) |
3651 | assert(MI.getOpcode() == TargetOpcode::G_XOR); |
3652 | Register &X = MatchInfo.first; |
3653 | Register &Y = MatchInfo.second; |
3654 | Register AndReg = MI.getOperand(i: 1).getReg(); |
3655 | Register SharedReg = MI.getOperand(i: 2).getReg(); |
3656 | |
3657 | // Find a G_AND on either side of the G_XOR. |
3658 | // Look for one of |
3659 | // |
3660 | // (xor (and x, y), SharedReg) |
3661 | // (xor SharedReg, (and x, y)) |
3662 | if (!mi_match(R: AndReg, MRI, P: m_GAnd(L: m_Reg(R&: X), R: m_Reg(R&: Y)))) { |
3663 | std::swap(a&: AndReg, b&: SharedReg); |
3664 | if (!mi_match(R: AndReg, MRI, P: m_GAnd(L: m_Reg(R&: X), R: m_Reg(R&: Y)))) |
3665 | return false; |
3666 | } |
3667 | |
3668 | // Only do this if we'll eliminate the G_AND. |
3669 | if (!MRI.hasOneNonDBGUse(RegNo: AndReg)) |
3670 | return false; |
3671 | |
3672 | // We can combine if SharedReg is the same as either the LHS or RHS of the |
3673 | // G_AND. |
3674 | if (Y != SharedReg) |
3675 | std::swap(a&: X, b&: Y); |
3676 | return Y == SharedReg; |
3677 | } |
3678 | |
3679 | void CombinerHelper::applyXorOfAndWithSameReg( |
3680 | MachineInstr &MI, std::pair<Register, Register> &MatchInfo) const { |
3681 | // Fold (xor (and x, y), y) -> (and (not x), y) |
3682 | Register X, Y; |
3683 | std::tie(args&: X, args&: Y) = MatchInfo; |
3684 | auto Not = Builder.buildNot(Dst: MRI.getType(Reg: X), Src0: X); |
3685 | Observer.changingInstr(MI); |
3686 | MI.setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_AND)); |
3687 | MI.getOperand(i: 1).setReg(Not->getOperand(i: 0).getReg()); |
3688 | MI.getOperand(i: 2).setReg(Y); |
3689 | Observer.changedInstr(MI); |
3690 | } |
3691 | |
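// Illustrative sketch (hypothetical MIR) of the fold handled below: adding
// an offset to a null base pointer (in an integral address space) is just
// an integer-to-pointer cast of the offset.
//
//   %null:_(p0) = G_CONSTANT i64 0
//   %p:_(p0)    = G_PTR_ADD %null, %off(s64)
// -->
//   %p:_(p0)    = G_INTTOPTR %off(s64)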
3692 | bool CombinerHelper::matchPtrAddZero(MachineInstr &MI) const { |
3693 | auto &PtrAdd = cast<GPtrAdd>(Val&: MI); |
3694 | Register DstReg = PtrAdd.getReg(Idx: 0); |
3695 | LLT Ty = MRI.getType(Reg: DstReg); |
3696 | const DataLayout &DL = Builder.getMF().getDataLayout(); |
3697 | |
3698 | if (DL.isNonIntegralAddressSpace(AddrSpace: Ty.getScalarType().getAddressSpace())) |
3699 | return false; |
3700 | |
3701 | if (Ty.isPointer()) { |
3702 | auto ConstVal = getIConstantVRegVal(VReg: PtrAdd.getBaseReg(), MRI); |
3703 | return ConstVal && *ConstVal == 0; |
3704 | } |
3705 | |
3706 | assert(Ty.isVector() && "Expecting a vector type" ); |
3707 | const MachineInstr *VecMI = MRI.getVRegDef(Reg: PtrAdd.getBaseReg()); |
3708 | return isBuildVectorAllZeros(MI: *VecMI, MRI); |
3709 | } |
3710 | |
3711 | void CombinerHelper::applyPtrAddZero(MachineInstr &MI) const { |
3712 | auto &PtrAdd = cast<GPtrAdd>(Val&: MI); |
3713 | Builder.buildIntToPtr(Dst: PtrAdd.getReg(Idx: 0), Src: PtrAdd.getOffsetReg()); |
3714 | PtrAdd.eraseFromParent(); |
3715 | } |
3716 | |
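// Worked example (hypothetical MIR) for the rewrite below, e.g. x % 8 with
// 8 known to be a power of two becomes x & 7:
//
//   %p:_(s32) = ...                  ; known power of two, e.g. 8
//   %r:_(s32) = G_UREM %x:_(s32), %p
// -->
//   %m1:_(s32) = G_CONSTANT i32 -1
//   %m:_(s32)  = G_ADD %p, %m1       ; p - 1, e.g. 7
//   %r:_(s32)  = G_AND %x, %m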
3717 | /// The second source operand is known to be a power of 2. |
3718 | void CombinerHelper::applySimplifyURemByPow2(MachineInstr &MI) const { |
3719 | Register DstReg = MI.getOperand(i: 0).getReg(); |
3720 | Register Src0 = MI.getOperand(i: 1).getReg(); |
3721 | Register Pow2Src1 = MI.getOperand(i: 2).getReg(); |
3722 | LLT Ty = MRI.getType(Reg: DstReg); |
3723 | |
3724 | // Fold (urem x, pow2) -> (and x, pow2-1) |
3725 | auto NegOne = Builder.buildConstant(Res: Ty, Val: -1); |
3726 | auto Add = Builder.buildAdd(Dst: Ty, Src0: Pow2Src1, Src1: NegOne); |
3727 | Builder.buildAnd(Dst: DstReg, Src0, Src1: Add); |
3728 | MI.eraseFromParent(); |
3729 | } |
3730 | |
3731 | bool CombinerHelper::matchFoldBinOpIntoSelect(MachineInstr &MI, |
3732 | unsigned &SelectOpNo) const { |
3733 | Register LHS = MI.getOperand(i: 1).getReg(); |
3734 | Register RHS = MI.getOperand(i: 2).getReg(); |
3735 | |
3736 | Register OtherOperandReg = RHS; |
3737 | SelectOpNo = 1; |
3738 | MachineInstr *Select = MRI.getVRegDef(Reg: LHS); |
3739 | |
3740 | // Don't do this unless the old select is going away. We want to eliminate the |
3741 | // binary operator, not replace a binop with a select. |
3742 | if (Select->getOpcode() != TargetOpcode::G_SELECT || |
3743 | !MRI.hasOneNonDBGUse(RegNo: LHS)) { |
3744 | OtherOperandReg = LHS; |
3745 | SelectOpNo = 2; |
3746 | Select = MRI.getVRegDef(Reg: RHS); |
3747 | if (Select->getOpcode() != TargetOpcode::G_SELECT || |
3748 | !MRI.hasOneNonDBGUse(RegNo: RHS)) |
3749 | return false; |
3750 | } |
3751 | |
3752 | MachineInstr *SelectLHS = MRI.getVRegDef(Reg: Select->getOperand(i: 2).getReg()); |
3753 | MachineInstr *SelectRHS = MRI.getVRegDef(Reg: Select->getOperand(i: 3).getReg()); |
3754 | |
3755 | if (!isConstantOrConstantVector(MI: *SelectLHS, MRI, |
3756 | /*AllowFP*/ true, |
3757 | /*AllowOpaqueConstants*/ false)) |
3758 | return false; |
3759 | if (!isConstantOrConstantVector(MI: *SelectRHS, MRI, |
3760 | /*AllowFP*/ true, |
3761 | /*AllowOpaqueConstants*/ false)) |
3762 | return false; |
3763 | |
3764 | unsigned BinOpcode = MI.getOpcode(); |
3765 | |
3766 | // We know that one of the operands is a select of constants. Now verify that |
3767 | // the other binary operator operand is either a constant, or we can handle a |
3768 | // variable. |
3769 | bool CanFoldNonConst = |
3770 | (BinOpcode == TargetOpcode::G_AND || BinOpcode == TargetOpcode::G_OR) && |
3771 | (isNullOrNullSplat(MI: *SelectLHS, MRI) || |
3772 | isAllOnesOrAllOnesSplat(MI: *SelectLHS, MRI)) && |
3773 | (isNullOrNullSplat(MI: *SelectRHS, MRI) || |
3774 | isAllOnesOrAllOnesSplat(MI: *SelectRHS, MRI)); |
3775 | if (CanFoldNonConst) |
3776 | return true; |
3777 | |
3778 | return isConstantOrConstantVector(MI: *MRI.getVRegDef(Reg: OtherOperandReg), MRI, |
3779 | /*AllowFP*/ true, |
3780 | /*AllowOpaqueConstants*/ false); |
3781 | } |
3782 | |
3783 | /// \p SelectOperand is the operand in binary operator \p MI that is the select |
3784 | /// to fold. |
3785 | void CombinerHelper::applyFoldBinOpIntoSelect( |
3786 | MachineInstr &MI, const unsigned &SelectOperand) const { |
3787 | Register Dst = MI.getOperand(i: 0).getReg(); |
3788 | Register LHS = MI.getOperand(i: 1).getReg(); |
3789 | Register RHS = MI.getOperand(i: 2).getReg(); |
3790 | MachineInstr *Select = MRI.getVRegDef(Reg: MI.getOperand(i: SelectOperand).getReg()); |
3791 | |
3792 | Register SelectCond = Select->getOperand(i: 1).getReg(); |
3793 | Register SelectTrue = Select->getOperand(i: 2).getReg(); |
3794 | Register SelectFalse = Select->getOperand(i: 3).getReg(); |
3795 | |
3796 | LLT Ty = MRI.getType(Reg: Dst); |
3797 | unsigned BinOpcode = MI.getOpcode(); |
3798 | |
3799 | Register FoldTrue, FoldFalse; |
3800 | |
3801 | // We have a select-of-constants followed by a binary operator with a |
3802 | // constant. Eliminate the binop by pulling the constant math into the select. |
3803 | // Example: add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO |
3804 | if (SelectOperand == 1) { |
3805 | // TODO: SelectionDAG verifies this actually constant folds before |
3806 | // committing to the combine. |
3807 | |
3808 | FoldTrue = Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {SelectTrue, RHS}).getReg(Idx: 0); |
3809 | FoldFalse = |
3810 | Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {SelectFalse, RHS}).getReg(Idx: 0); |
3811 | } else { |
3812 | FoldTrue = Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {LHS, SelectTrue}).getReg(Idx: 0); |
3813 | FoldFalse = |
3814 | Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {LHS, SelectFalse}).getReg(Idx: 0); |
3815 | } |
3816 | |
3817 | Builder.buildSelect(Res: Dst, Tst: SelectCond, Op0: FoldTrue, Op1: FoldFalse, Flags: MI.getFlags()); |
3818 | MI.eraseFromParent(); |
3819 | } |
3820 | |
3821 | std::optional<SmallVector<Register, 8>> |
3822 | CombinerHelper::findCandidatesForLoadOrCombine(const MachineInstr *Root) const { |
3823 | assert(Root->getOpcode() == TargetOpcode::G_OR && "Expected G_OR only!" ); |
3824 | // We want to detect if Root is part of a tree which represents a bunch |
3825 | // of loads being merged into a larger load. We'll try to recognize patterns |
3826 | // like, for example: |
3827 | // |
3828 | // Reg Reg |
3829 | // \ / |
3830 | // OR_1 Reg |
3831 | // \ / |
3832 | // OR_2 |
3833 | // \ Reg |
3834 | // .. / |
3835 | // Root |
3836 | // |
3837 | // Reg Reg Reg Reg |
3838 | // \ / \ / |
3839 | // OR_1 OR_2 |
3840 | // \ / |
3841 | // \ / |
3842 | // ... |
3843 | // Root |
3844 | // |
3845 | // Each "Reg" may have been produced by a load + some arithmetic. This |
3846 | // function will save each of them. |
3847 | SmallVector<Register, 8> RegsToVisit; |
3848 | SmallVector<const MachineInstr *, 7> Ors = {Root}; |
3849 | |
3850 | // In the "worst" case, we're dealing with a load for each byte. So, there |
3851 | // are at most #bytes - 1 ORs. |
3852 | const unsigned MaxIter = |
3853 | MRI.getType(Reg: Root->getOperand(i: 0).getReg()).getSizeInBytes() - 1; |
3854 | for (unsigned Iter = 0; Iter < MaxIter; ++Iter) { |
3855 | if (Ors.empty()) |
3856 | break; |
3857 | const MachineInstr *Curr = Ors.pop_back_val(); |
3858 | Register OrLHS = Curr->getOperand(i: 1).getReg(); |
3859 | Register OrRHS = Curr->getOperand(i: 2).getReg(); |
3860 | |
// In the combine, we want to eliminate the entire tree.
3862 | if (!MRI.hasOneNonDBGUse(RegNo: OrLHS) || !MRI.hasOneNonDBGUse(RegNo: OrRHS)) |
3863 | return std::nullopt; |
3864 | |
3865 | // If it's a G_OR, save it and continue to walk. If it's not, then it's |
3866 | // something that may be a load + arithmetic. |
3867 | if (const MachineInstr *Or = getOpcodeDef(Opcode: TargetOpcode::G_OR, Reg: OrLHS, MRI)) |
3868 | Ors.push_back(Elt: Or); |
3869 | else |
3870 | RegsToVisit.push_back(Elt: OrLHS); |
3871 | if (const MachineInstr *Or = getOpcodeDef(Opcode: TargetOpcode::G_OR, Reg: OrRHS, MRI)) |
3872 | Ors.push_back(Elt: Or); |
3873 | else |
3874 | RegsToVisit.push_back(Elt: OrRHS); |
3875 | } |
3876 | |
3877 | // We're going to try and merge each register into a wider power-of-2 type, |
3878 | // so we ought to have an even number of registers. |
3879 | if (RegsToVisit.empty() || RegsToVisit.size() % 2 != 0) |
3880 | return std::nullopt; |
3881 | return RegsToVisit; |
3882 | } |
3883 | |
3884 | /// Helper function for findLoadOffsetsForLoadOrCombine. |
3885 | /// |
3886 | /// Check if \p Reg is the result of loading a \p MemSizeInBits wide value, |
3887 | /// and then moving that value into a specific byte offset. |
3888 | /// |
3889 | /// e.g. x[i] << 24 |
3890 | /// |
3891 | /// \returns The load instruction and the byte offset it is moved into. |
3892 | static std::optional<std::pair<GZExtLoad *, int64_t>> |
3893 | matchLoadAndBytePosition(Register Reg, unsigned MemSizeInBits, |
3894 | const MachineRegisterInfo &MRI) { |
3895 | assert(MRI.hasOneNonDBGUse(Reg) && |
3896 | "Expected Reg to only have one non-debug use?" ); |
3897 | Register MaybeLoad; |
3898 | int64_t Shift; |
3899 | if (!mi_match(R: Reg, MRI, |
3900 | P: m_OneNonDBGUse(SP: m_GShl(L: m_Reg(R&: MaybeLoad), R: m_ICst(Cst&: Shift))))) { |
3901 | Shift = 0; |
3902 | MaybeLoad = Reg; |
3903 | } |
3904 | |
3905 | if (Shift % MemSizeInBits != 0) |
3906 | return std::nullopt; |
3907 | |
3908 | // TODO: Handle other types of loads. |
3909 | auto *Load = getOpcodeDef<GZExtLoad>(Reg: MaybeLoad, MRI); |
3910 | if (!Load) |
3911 | return std::nullopt; |
3912 | |
3913 | if (!Load->isUnordered() || Load->getMemSizeInBits() != MemSizeInBits) |
3914 | return std::nullopt; |
3915 | |
3916 | return std::make_pair(x&: Load, y: Shift / MemSizeInBits); |
3917 | } |
3918 | |
3919 | std::optional<std::tuple<GZExtLoad *, int64_t, GZExtLoad *>> |
3920 | CombinerHelper::findLoadOffsetsForLoadOrCombine( |
3921 | SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx, |
3922 | const SmallVector<Register, 8> &RegsToVisit, |
3923 | const unsigned MemSizeInBits) const { |
3924 | |
3925 | // Each load found for the pattern. There should be one for each RegsToVisit. |
3926 | SmallSetVector<const MachineInstr *, 8> Loads; |
3927 | |
3928 | // The lowest index used in any load. (The lowest "i" for each x[i].) |
3929 | int64_t LowestIdx = INT64_MAX; |
3930 | |
3931 | // The load which uses the lowest index. |
3932 | GZExtLoad *LowestIdxLoad = nullptr; |
3933 | |
3934 | // Keeps track of the load indices we see. We shouldn't see any indices twice. |
3935 | SmallSet<int64_t, 8> SeenIdx; |
3936 | |
3937 | // Ensure each load is in the same MBB. |
3938 | // TODO: Support multiple MachineBasicBlocks. |
3939 | MachineBasicBlock *MBB = nullptr; |
3940 | const MachineMemOperand *MMO = nullptr; |
3941 | |
3942 | // Earliest instruction-order load in the pattern. |
3943 | GZExtLoad *EarliestLoad = nullptr; |
3944 | |
3945 | // Latest instruction-order load in the pattern. |
3946 | GZExtLoad *LatestLoad = nullptr; |
3947 | |
3948 | // Base pointer which every load should share. |
3949 | Register BasePtr; |
3950 | |
3951 | // We want to find a load for each register. Each load should have some |
3952 | // appropriate bit twiddling arithmetic. During this loop, we will also keep |
3953 | // track of the load which uses the lowest index. Later, we will check if we |
3954 | // can use its pointer in the final, combined load. |
3955 | for (auto Reg : RegsToVisit) { |
// Find the load, and the byte position its value ends up at within the
// wider value (e.g. after a shift).
3958 | auto LoadAndPos = matchLoadAndBytePosition(Reg, MemSizeInBits, MRI); |
3959 | if (!LoadAndPos) |
3960 | return std::nullopt; |
3961 | GZExtLoad *Load; |
3962 | int64_t DstPos; |
3963 | std::tie(args&: Load, args&: DstPos) = *LoadAndPos; |
3964 | |
3965 | // TODO: Handle multiple MachineBasicBlocks. Currently not handled because |
3966 | // it is difficult to check for stores/calls/etc between loads. |
3967 | MachineBasicBlock *LoadMBB = Load->getParent(); |
3968 | if (!MBB) |
3969 | MBB = LoadMBB; |
3970 | if (LoadMBB != MBB) |
3971 | return std::nullopt; |
3972 | |
3973 | // Make sure that the MachineMemOperands of every seen load are compatible. |
3974 | auto &LoadMMO = Load->getMMO(); |
3975 | if (!MMO) |
3976 | MMO = &LoadMMO; |
3977 | if (MMO->getAddrSpace() != LoadMMO.getAddrSpace()) |
3978 | return std::nullopt; |
3979 | |
3980 | // Find out what the base pointer and index for the load is. |
3981 | Register LoadPtr; |
3982 | int64_t Idx; |
3983 | if (!mi_match(R: Load->getOperand(i: 1).getReg(), MRI, |
3984 | P: m_GPtrAdd(L: m_Reg(R&: LoadPtr), R: m_ICst(Cst&: Idx)))) { |
3985 | LoadPtr = Load->getOperand(i: 1).getReg(); |
3986 | Idx = 0; |
3987 | } |
3988 | |
3989 | // Don't combine things like a[i], a[i] -> a bigger load. |
3990 | if (!SeenIdx.insert(V: Idx).second) |
3991 | return std::nullopt; |
3992 | |
3993 | // Every load must share the same base pointer; don't combine things like: |
3994 | // |
3995 | // a[i], b[i + 1] -> a bigger load. |
3996 | if (!BasePtr.isValid()) |
3997 | BasePtr = LoadPtr; |
3998 | if (BasePtr != LoadPtr) |
3999 | return std::nullopt; |
4000 | |
4001 | if (Idx < LowestIdx) { |
4002 | LowestIdx = Idx; |
4003 | LowestIdxLoad = Load; |
4004 | } |
4005 | |
4006 | // Keep track of the byte offset that this load ends up at. If we have seen |
4007 | // the byte offset, then stop here. We do not want to combine: |
4008 | // |
4009 | // a[i] << 16, a[i + k] << 16 -> a bigger load. |
4010 | if (!MemOffset2Idx.try_emplace(Key: DstPos, Args&: Idx).second) |
4011 | return std::nullopt; |
4012 | Loads.insert(X: Load); |
4013 | |
4014 | // Keep track of the position of the earliest/latest loads in the pattern. |
4015 | // We will check that there are no load fold barriers between them later |
4016 | // on. |
4017 | // |
4018 | // FIXME: Is there a better way to check for load fold barriers? |
4019 | if (!EarliestLoad || dominates(DefMI: *Load, UseMI: *EarliestLoad)) |
4020 | EarliestLoad = Load; |
4021 | if (!LatestLoad || dominates(DefMI: *LatestLoad, UseMI: *Load)) |
4022 | LatestLoad = Load; |
4023 | } |
4024 | |
4025 | // We found a load for each register. Let's check if each load satisfies the |
4026 | // pattern. |
4027 | assert(Loads.size() == RegsToVisit.size() &&
4028 | "Expected to find a load for each register?");
4029 | assert(EarliestLoad != LatestLoad && EarliestLoad &&
4030 | LatestLoad && "Expected at least two loads?");
4031 | |
4032 | // Check if there are any stores, calls, etc. between any of the loads. If |
4033 | // there are, then we can't safely perform the combine. |
4034 | // |
4035 | // MaxIter is chosen based off the (worst case) number of iterations it |
4036 | // typically takes to succeed in the LLVM test suite plus some padding. |
4037 | // |
4038 | // FIXME: Is there a better way to check for load fold barriers? |
4039 | const unsigned MaxIter = 20; |
4040 | unsigned Iter = 0; |
4041 | for (const auto &MI : instructionsWithoutDebug(It: EarliestLoad->getIterator(), |
4042 | End: LatestLoad->getIterator())) { |
4043 | if (Loads.count(key: &MI)) |
4044 | continue; |
4045 | if (MI.isLoadFoldBarrier()) |
4046 | return std::nullopt; |
4047 | if (Iter++ == MaxIter) |
4048 | return std::nullopt; |
4049 | } |
4050 | |
4051 | return std::make_tuple(args&: LowestIdxLoad, args&: LowestIdx, args&: LatestLoad); |
4052 | } |
4053 | |
4054 | bool CombinerHelper::matchLoadOrCombine( |
4055 | MachineInstr &MI, |
4056 | std::function<void(MachineIRBuilder &)> &MatchInfo) const { |
4057 | assert(MI.getOpcode() == TargetOpcode::G_OR); |
4058 | MachineFunction &MF = *MI.getMF(); |
4059 | // Assuming a little-endian target, transform: |
4060 | // s8 *a = ... |
4061 | // s32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24) |
4062 | // => |
4063 | // s32 val = *((i32)a) |
4064 | // |
4065 | // s8 *a = ... |
4066 | // s32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3] |
4067 | // => |
4068 | // s32 val = BSWAP(*((s32)a)) |
4069 | Register Dst = MI.getOperand(i: 0).getReg(); |
4070 | LLT Ty = MRI.getType(Reg: Dst); |
4071 | if (Ty.isVector()) |
4072 | return false; |
4073 | |
4074 | // We need to combine at least two loads into this type. Since the smallest |
4075 | // possible load is into a byte, we need at least a 16-bit wide type. |
4076 | const unsigned WideMemSizeInBits = Ty.getSizeInBits(); |
4077 | if (WideMemSizeInBits < 16 || WideMemSizeInBits % 8 != 0) |
4078 | return false; |
4079 | |
4080 | // Match a collection of non-OR instructions in the pattern. |
4081 | auto RegsToVisit = findCandidatesForLoadOrCombine(Root: &MI); |
4082 | if (!RegsToVisit) |
4083 | return false; |
4084 | |
4085 | // We have a collection of non-OR instructions. Figure out how wide each of |
4086 | // the small loads should be based off of the number of potential loads we |
4087 | // found. |
4088 | const unsigned NarrowMemSizeInBits = WideMemSizeInBits / RegsToVisit->size(); |
4089 | if (NarrowMemSizeInBits % 8 != 0) |
4090 | return false; |
4091 | |
4092 | // Check if each register feeding into each OR is a load from the same |
4093 | // base pointer + some arithmetic. |
4094 | // |
4095 | // e.g. a[0], a[1] << 8, a[2] << 16, etc. |
4096 | // |
4097 | // Also verify that each of these ends up putting a[i] into the same memory |
4098 | // offset as a load into a wide type would. |
4099 | SmallDenseMap<int64_t, int64_t, 8> MemOffset2Idx; |
4100 | GZExtLoad *LowestIdxLoad, *LatestLoad; |
4101 | int64_t LowestIdx; |
4102 | auto MaybeLoadInfo = findLoadOffsetsForLoadOrCombine( |
4103 | MemOffset2Idx, RegsToVisit: *RegsToVisit, MemSizeInBits: NarrowMemSizeInBits); |
4104 | if (!MaybeLoadInfo) |
4105 | return false; |
4106 | std::tie(args&: LowestIdxLoad, args&: LowestIdx, args&: LatestLoad) = *MaybeLoadInfo; |
4107 | |
4108 | // We have a bunch of loads being OR'd together. Using the addresses + offsets |
4109 | // we found before, check if this corresponds to a big or little endian byte |
4110 | // pattern. If it does, then we can represent it using a load + possibly a |
4111 | // BSWAP. |
4112 | bool IsBigEndianTarget = MF.getDataLayout().isBigEndian(); |
4113 | std::optional<bool> IsBigEndian = isBigEndian(MemOffset2Idx, LowestIdx); |
4114 | if (!IsBigEndian) |
4115 | return false; |
4116 | bool NeedsBSwap = IsBigEndianTarget != *IsBigEndian; |
4117 | if (NeedsBSwap && !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_BSWAP, {Ty}})) |
4118 | return false; |
4119 | |
4120 | // Make sure that the load from the lowest index produces offset 0 in the |
4121 | // final value. |
4122 | // |
4123 | // This ensures that we won't combine something like this: |
4124 | // |
4125 | // load x[i] -> byte 2 |
4126 | // load x[i+1] -> byte 0 ---> wide_load x[i] |
4127 | // load x[i+2] -> byte 1 |
4128 | const unsigned NumLoadsInTy = WideMemSizeInBits / NarrowMemSizeInBits; |
4129 | const unsigned ZeroByteOffset = |
4130 | *IsBigEndian |
4131 | ? bigEndianByteAt(ByteWidth: NumLoadsInTy, I: 0) |
4132 | : littleEndianByteAt(ByteWidth: NumLoadsInTy, I: 0); |
4133 | auto ZeroOffsetIdx = MemOffset2Idx.find(Val: ZeroByteOffset); |
4134 | if (ZeroOffsetIdx == MemOffset2Idx.end() || |
4135 | ZeroOffsetIdx->second != LowestIdx) |
4136 | return false; |
4137 | |
4138 | // We will reuse the pointer from the load which ends up at byte offset 0. It
4139 | // may not use index 0. |
4140 | Register Ptr = LowestIdxLoad->getPointerReg(); |
4141 | const MachineMemOperand &MMO = LowestIdxLoad->getMMO(); |
4142 | LegalityQuery::MemDesc MMDesc(MMO); |
4143 | MMDesc.MemoryTy = Ty; |
4144 | if (!isLegalOrBeforeLegalizer( |
4145 | Query: {TargetOpcode::G_LOAD, {Ty, MRI.getType(Reg: Ptr)}, {MMDesc}})) |
4146 | return false; |
4147 | auto PtrInfo = MMO.getPointerInfo(); |
4148 | auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Size: WideMemSizeInBits / 8); |
4149 | |
4150 | // Load must be allowed and fast on the target. |
4151 | LLVMContext &C = MF.getFunction().getContext(); |
4152 | auto &DL = MF.getDataLayout(); |
4153 | unsigned Fast = 0; |
4154 | if (!getTargetLowering().allowsMemoryAccess(Context&: C, DL, Ty, MMO: *NewMMO, Fast: &Fast) || |
4155 | !Fast) |
4156 | return false; |
4157 | |
4158 | MatchInfo = [=](MachineIRBuilder &MIB) { |
4159 | MIB.setInstrAndDebugLoc(*LatestLoad); |
4160 | Register LoadDst = NeedsBSwap ? MRI.cloneVirtualRegister(VReg: Dst) : Dst; |
4161 | MIB.buildLoad(Res: LoadDst, Addr: Ptr, MMO&: *NewMMO); |
4162 | if (NeedsBSwap) |
4163 | MIB.buildBSwap(Dst, Src0: LoadDst); |
4164 | }; |
4165 | return true; |
4166 | } |
4167 | |
4168 | bool CombinerHelper::matchExtendThroughPhis(MachineInstr &MI, |
4169 | MachineInstr *&ExtMI) const { |
4170 | auto &PHI = cast<GPhi>(Val&: MI); |
4171 | Register DstReg = PHI.getReg(Idx: 0); |
4172 | |
4173 | // TODO: Extending a vector may be expensive, don't do this until heuristics |
4174 | // are better. |
4175 | if (MRI.getType(Reg: DstReg).isVector()) |
4176 | return false; |
4177 | |
4178 | // Try to match a phi, whose only use is an extend. |
4179 | if (!MRI.hasOneNonDBGUse(RegNo: DstReg)) |
4180 | return false; |
4181 | ExtMI = &*MRI.use_instr_nodbg_begin(RegNo: DstReg); |
4182 | switch (ExtMI->getOpcode()) { |
4183 | case TargetOpcode::G_ANYEXT: |
4184 | return true; // G_ANYEXT is usually free. |
4185 | case TargetOpcode::G_ZEXT: |
4186 | case TargetOpcode::G_SEXT: |
4187 | break; |
4188 | default: |
4189 | return false; |
4190 | } |
4191 | |
4192 | // If the target is likely to fold this extend away, don't propagate. |
4193 | if (Builder.getTII().isExtendLikelyToBeFolded(ExtMI&: *ExtMI, MRI)) |
4194 | return false; |
4195 | |
4196 | // We don't want to propagate the extends unless there's a good chance that |
4197 | // they'll be optimized in some way. |
4198 | // Collect the unique incoming values. |
4199 | SmallPtrSet<MachineInstr *, 4> InSrcs; |
4200 | for (unsigned I = 0; I < PHI.getNumIncomingValues(); ++I) { |
4201 | auto *DefMI = getDefIgnoringCopies(Reg: PHI.getIncomingValue(I), MRI); |
4202 | switch (DefMI->getOpcode()) { |
4203 | case TargetOpcode::G_LOAD: |
4204 | case TargetOpcode::G_TRUNC: |
4205 | case TargetOpcode::G_SEXT: |
4206 | case TargetOpcode::G_ZEXT: |
4207 | case TargetOpcode::G_ANYEXT: |
4208 | case TargetOpcode::G_CONSTANT: |
4209 | InSrcs.insert(Ptr: DefMI); |
4210 | // Don't try to propagate if there are too many places to create new |
4211 | // extends, chances are it'll increase code size. |
4212 | if (InSrcs.size() > 2) |
4213 | return false; |
4214 | break; |
4215 | default: |
4216 | return false; |
4217 | } |
4218 | } |
4219 | return true; |
4220 | } |
4221 | |
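// Rough sketch of the rewrite performed below (register names, types and
// blocks are made up for illustration):
//
//   %p:_(s8) = G_PHI %a(s8), %bb.1, %b(s8), %bb.2
//   %e:_(s32) = G_SEXT %p(s8)
// =>
//   (in %bb.1) %ea:_(s32) = G_SEXT %a(s8)
//   (in %bb.2) %eb:_(s32) = G_SEXT %b(s8)
//   %e:_(s32) = G_PHI %ea(s32), %bb.1, %eb(s32), %bb.2
//
// i.e. the extend is rebuilt after each incoming def and the PHI is recreated
// over the already-extended values.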
4222 | void CombinerHelper::applyExtendThroughPhis(MachineInstr &MI, |
4223 | MachineInstr *&ExtMI) const { |
4224 | auto &PHI = cast<GPhi>(Val&: MI); |
4225 | Register DstReg = ExtMI->getOperand(i: 0).getReg(); |
4226 | LLT ExtTy = MRI.getType(Reg: DstReg); |
4227 | |
4228 | // Propagate the extension into the block of each incoming register's def.
4229 | // Use a SetVector here because PHIs can have duplicate edges, and we want |
4230 | // deterministic iteration order. |
4231 | SmallSetVector<MachineInstr *, 8> SrcMIs; |
4232 | SmallDenseMap<MachineInstr *, MachineInstr *, 8> OldToNewSrcMap; |
4233 | for (unsigned I = 0; I < PHI.getNumIncomingValues(); ++I) { |
4234 | auto SrcReg = PHI.getIncomingValue(I); |
4235 | auto *SrcMI = MRI.getVRegDef(Reg: SrcReg); |
4236 | if (!SrcMIs.insert(X: SrcMI)) |
4237 | continue; |
4238 | |
4239 | // Build an extend after each src inst. |
4240 | auto *MBB = SrcMI->getParent(); |
4241 | MachineBasicBlock::iterator InsertPt = ++SrcMI->getIterator(); |
4242 | if (InsertPt != MBB->end() && InsertPt->isPHI()) |
4243 | InsertPt = MBB->getFirstNonPHI(); |
4244 | |
4245 | Builder.setInsertPt(MBB&: *SrcMI->getParent(), II: InsertPt); |
4246 | Builder.setDebugLoc(MI.getDebugLoc()); |
4247 | auto NewExt = Builder.buildExtOrTrunc(ExtOpc: ExtMI->getOpcode(), Res: ExtTy, Op: SrcReg); |
4248 | OldToNewSrcMap[SrcMI] = NewExt; |
4249 | } |
4250 | |
4251 | // Create a new phi with the extended inputs. |
4252 | Builder.setInstrAndDebugLoc(MI); |
4253 | auto NewPhi = Builder.buildInstrNoInsert(Opcode: TargetOpcode::G_PHI); |
4254 | NewPhi.addDef(RegNo: DstReg); |
4255 | for (const MachineOperand &MO : llvm::drop_begin(RangeOrContainer: MI.operands())) { |
4256 | if (!MO.isReg()) { |
4257 | NewPhi.addMBB(MBB: MO.getMBB()); |
4258 | continue; |
4259 | } |
4260 | auto *NewSrc = OldToNewSrcMap[MRI.getVRegDef(Reg: MO.getReg())]; |
4261 | NewPhi.addUse(RegNo: NewSrc->getOperand(i: 0).getReg()); |
4262 | } |
4263 | Builder.insertInstr(MIB: NewPhi); |
4264 | ExtMI->eraseFromParent(); |
4265 | } |
4266 | |
4267 | bool CombinerHelper::matchExtractVecEltBuildVec(MachineInstr &MI,
4268 | Register &Reg) const {
4269 | assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT); |
4270 | // If we have a constant index, look for a G_BUILD_VECTOR source |
4271 | // and find the source register that the index maps to. |
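// e.g. (illustrative):
//   %vec:_(<2 x s32>) = G_BUILD_VECTOR %a(s32), %b(s32)
//   %elt:_(s32) = G_EXTRACT_VECTOR_ELT %vec, 1
// can simply reuse %b.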
4272 | Register SrcVec = MI.getOperand(i: 1).getReg(); |
4273 | LLT SrcTy = MRI.getType(Reg: SrcVec); |
4274 | if (SrcTy.isScalableVector()) |
4275 | return false; |
4276 | |
4277 | auto Cst = getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI); |
4278 | if (!Cst || Cst->Value.getZExtValue() >= SrcTy.getNumElements()) |
4279 | return false; |
4280 | |
4281 | unsigned VecIdx = Cst->Value.getZExtValue(); |
4282 | |
4283 | // Check if we have a build_vector or build_vector_trunc with an optional |
4284 | // trunc in front. |
4285 | MachineInstr *SrcVecMI = MRI.getVRegDef(Reg: SrcVec); |
4286 | if (SrcVecMI->getOpcode() == TargetOpcode::G_TRUNC) { |
4287 | SrcVecMI = MRI.getVRegDef(Reg: SrcVecMI->getOperand(i: 1).getReg()); |
4288 | } |
4289 | |
4290 | if (SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR && |
4291 | SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR_TRUNC) |
4292 | return false; |
4293 | |
4294 | EVT Ty(getMVTForLLT(Ty: SrcTy)); |
4295 | if (!MRI.hasOneNonDBGUse(RegNo: SrcVec) && |
4296 | !getTargetLowering().aggressivelyPreferBuildVectorSources(VecVT: Ty)) |
4297 | return false; |
4298 | |
4299 | Reg = SrcVecMI->getOperand(i: VecIdx + 1).getReg(); |
4300 | return true; |
4301 | } |
4302 | |
4303 | void CombinerHelper::applyExtractVecEltBuildVec(MachineInstr &MI,
4304 | Register &Reg) const {
4305 | // Check the type of the register, since it may have come from a |
4306 | // G_BUILD_VECTOR_TRUNC. |
4307 | LLT ScalarTy = MRI.getType(Reg); |
4308 | Register DstReg = MI.getOperand(i: 0).getReg(); |
4309 | LLT DstTy = MRI.getType(Reg: DstReg); |
4310 | |
4311 | if (ScalarTy != DstTy) { |
4312 | assert(ScalarTy.getSizeInBits() > DstTy.getSizeInBits()); |
4313 | Builder.buildTrunc(Res: DstReg, Op: Reg); |
4314 | MI.eraseFromParent(); |
4315 | return; |
4316 | } |
4317 | replaceSingleDefInstWithReg(MI, Replacement: Reg); |
4318 | } |
4319 | |
4320 | bool CombinerHelper::matchExtractAllEltsFromBuildVector(
4321 | MachineInstr &MI, |
4322 | SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) const { |
4323 | assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR); |
4324 | // This combine tries to find build_vector's which have every source element |
4325 | // extracted using G_EXTRACT_VECTOR_ELT. This can happen when transforms like |
4326 | // the masked load scalarization is run late in the pipeline. There's already |
4327 | // a combine for a similar pattern starting from the extract, but that |
4328 | // doesn't attempt to do it if there are multiple uses of the build_vector, |
4329 | // which in this case is true. Starting the combine from the build_vector |
4330 | // feels more natural than trying to find sibling nodes of extracts. |
4331 | // E.g. |
4332 | // %vec(<4 x s32>) = G_BUILD_VECTOR %s1(s32), %s2, %s3, %s4 |
4333 | // %ext1 = G_EXTRACT_VECTOR_ELT %vec, 0 |
4334 | // %ext2 = G_EXTRACT_VECTOR_ELT %vec, 1 |
4335 | // %ext3 = G_EXTRACT_VECTOR_ELT %vec, 2 |
4336 | // %ext4 = G_EXTRACT_VECTOR_ELT %vec, 3 |
4337 | // ==> |
4338 | // replace ext{1,2,3,4} with %s{1,2,3,4} |
4339 | |
4340 | Register DstReg = MI.getOperand(i: 0).getReg(); |
4341 | LLT DstTy = MRI.getType(Reg: DstReg); |
4342 | unsigned NumElts = DstTy.getNumElements(); |
4343 | |
4344 | SmallBitVector ExtractedElts(NumElts);
4345 | for (MachineInstr &II : MRI.use_nodbg_instructions(Reg: DstReg)) { |
4346 | if (II.getOpcode() != TargetOpcode::G_EXTRACT_VECTOR_ELT) |
4347 | return false; |
4348 | auto Cst = getIConstantVRegVal(VReg: II.getOperand(i: 2).getReg(), MRI); |
4349 | if (!Cst) |
4350 | return false; |
4351 | unsigned Idx = Cst->getZExtValue(); |
4352 | if (Idx >= NumElts) |
4353 | return false; // Out of range. |
4354 | ExtractedElts.set(Idx); |
4355 | SrcDstPairs.emplace_back( |
4356 | Args: std::make_pair(x: MI.getOperand(i: Idx + 1).getReg(), y: &II)); |
4357 | } |
4358 | // Match if every element was extracted. |
4359 | return ExtractedElts.all(); |
4360 | } |
4361 | |
4362 | void CombinerHelper::applyExtractAllEltsFromBuildVector(
4363 | MachineInstr &MI, |
4364 | SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) const { |
4365 | assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR); |
4366 | for (auto &Pair : SrcDstPairs) { |
4367 | auto *ExtMI = Pair.second; |
4368 | replaceRegWith(MRI, FromReg: ExtMI->getOperand(i: 0).getReg(), ToReg: Pair.first); |
4369 | ExtMI->eraseFromParent(); |
4370 | } |
4371 | MI.eraseFromParent(); |
4372 | } |
4373 | |
4374 | void CombinerHelper::applyBuildFn( |
4375 | MachineInstr &MI, |
4376 | std::function<void(MachineIRBuilder &)> &MatchInfo) const { |
4377 | applyBuildFnNoErase(MI, MatchInfo); |
4378 | MI.eraseFromParent(); |
4379 | } |
4380 | |
4381 | void CombinerHelper::applyBuildFnNoErase( |
4382 | MachineInstr &MI, |
4383 | std::function<void(MachineIRBuilder &)> &MatchInfo) const { |
4384 | MatchInfo(Builder); |
4385 | } |
4386 | |
4387 | bool CombinerHelper::matchOrShiftToFunnelShift(MachineInstr &MI, |
4388 | BuildFnTy &MatchInfo) const { |
4389 | assert(MI.getOpcode() == TargetOpcode::G_OR); |
4390 | |
4391 | Register Dst = MI.getOperand(i: 0).getReg(); |
4392 | LLT Ty = MRI.getType(Reg: Dst); |
4393 | unsigned BitWidth = Ty.getScalarSizeInBits(); |
4394 | |
4395 | Register ShlSrc, ShlAmt, LShrSrc, LShrAmt, Amt; |
4396 | unsigned FshOpc = 0; |
4397 | |
4398 | // Match (or (shl ...), (lshr ...)). |
4399 | if (!mi_match(R: Dst, MRI, |
4400 | // m_GOr() handles the commuted version as well. |
4401 | P: m_GOr(L: m_GShl(L: m_Reg(R&: ShlSrc), R: m_Reg(R&: ShlAmt)), |
4402 | R: m_GLShr(L: m_Reg(R&: LShrSrc), R: m_Reg(R&: LShrAmt))))) |
4403 | return false; |
4404 | |
4405 | // Given constants C0 and C1 such that C0 + C1 is bit-width: |
4406 | // (or (shl x, C0), (lshr y, C1)) -> (fshl x, y, C0) or (fshr x, y, C1) |
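// For example (illustrative, 32-bit wide values):
//   (or (shl x, 8), (lshr y, 24)) -> (fshr x, y, 24)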
4407 | int64_t CstShlAmt, CstLShrAmt; |
4408 | if (mi_match(R: ShlAmt, MRI, P: m_ICstOrSplat(Cst&: CstShlAmt)) && |
4409 | mi_match(R: LShrAmt, MRI, P: m_ICstOrSplat(Cst&: CstLShrAmt)) && |
4410 | CstShlAmt + CstLShrAmt == BitWidth) { |
4411 | FshOpc = TargetOpcode::G_FSHR; |
4412 | Amt = LShrAmt; |
4413 | |
4414 | } else if (mi_match(R: LShrAmt, MRI, |
4415 | P: m_GSub(L: m_SpecificICstOrSplat(RequestedValue: BitWidth), R: m_Reg(R&: Amt))) && |
4416 | ShlAmt == Amt) { |
4417 | // (or (shl x, amt), (lshr y, (sub bw, amt))) -> (fshl x, y, amt) |
4418 | FshOpc = TargetOpcode::G_FSHL; |
4419 | |
4420 | } else if (mi_match(R: ShlAmt, MRI, |
4421 | P: m_GSub(L: m_SpecificICstOrSplat(RequestedValue: BitWidth), R: m_Reg(R&: Amt))) && |
4422 | LShrAmt == Amt) { |
4423 | // (or (shl x, (sub bw, amt)), (lshr y, amt)) -> (fshr x, y, amt) |
4424 | FshOpc = TargetOpcode::G_FSHR; |
4425 | |
4426 | } else { |
4427 | return false; |
4428 | } |
4429 | |
4430 | LLT AmtTy = MRI.getType(Reg: Amt); |
4431 | if (!isLegalOrBeforeLegalizer(Query: {FshOpc, {Ty, AmtTy}})) |
4432 | return false; |
4433 | |
4434 | MatchInfo = [=](MachineIRBuilder &B) { |
4435 | B.buildInstr(Opc: FshOpc, DstOps: {Dst}, SrcOps: {ShlSrc, LShrSrc, Amt}); |
4436 | }; |
4437 | return true; |
4438 | } |
4439 | |
4440 | /// Match an FSHL or FSHR that can be combined to a ROTR or ROTL rotate. |
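/// e.g. (illustrative) G_FSHL %x, %x, %amt has the same semantics as
/// G_ROTL %x, %amt, and G_FSHR %x, %x, %amt the same as G_ROTR %x, %amt.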
4441 | bool CombinerHelper::matchFunnelShiftToRotate(MachineInstr &MI) const { |
4442 | unsigned Opc = MI.getOpcode(); |
4443 | assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR); |
4444 | Register X = MI.getOperand(i: 1).getReg(); |
4445 | Register Y = MI.getOperand(i: 2).getReg(); |
4446 | if (X != Y) |
4447 | return false; |
4448 | unsigned RotateOpc = |
4449 | Opc == TargetOpcode::G_FSHL ? TargetOpcode::G_ROTL : TargetOpcode::G_ROTR; |
4450 | return isLegalOrBeforeLegalizer(Query: {RotateOpc, {MRI.getType(Reg: X), MRI.getType(Reg: Y)}}); |
4451 | } |
4452 | |
4453 | void CombinerHelper::applyFunnelShiftToRotate(MachineInstr &MI) const { |
4454 | unsigned Opc = MI.getOpcode(); |
4455 | assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR); |
4456 | bool IsFSHL = Opc == TargetOpcode::G_FSHL; |
4457 | Observer.changingInstr(MI); |
4458 | MI.setDesc(Builder.getTII().get(Opcode: IsFSHL ? TargetOpcode::G_ROTL |
4459 | : TargetOpcode::G_ROTR)); |
4460 | MI.removeOperand(OpNo: 2); |
4461 | Observer.changedInstr(MI); |
4462 | } |
4463 | |
4464 | // Fold (rot x, c) -> (rot x, c % BitSize) |
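// e.g. (illustrative, s32): (rotl x, 37) is rewritten to
// (rotl x, (urem 37, 32)), which a later constant fold reduces to (rotl x, 5).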
4465 | bool CombinerHelper::matchRotateOutOfRange(MachineInstr &MI) const { |
4466 | assert(MI.getOpcode() == TargetOpcode::G_ROTL || |
4467 | MI.getOpcode() == TargetOpcode::G_ROTR); |
4468 | unsigned Bitsize = |
4469 | MRI.getType(Reg: MI.getOperand(i: 0).getReg()).getScalarSizeInBits(); |
4470 | Register AmtReg = MI.getOperand(i: 2).getReg(); |
4471 | bool OutOfRange = false; |
4472 | auto MatchOutOfRange = [Bitsize, &OutOfRange](const Constant *C) { |
4473 | if (auto *CI = dyn_cast<ConstantInt>(Val: C)) |
4474 | OutOfRange |= CI->getValue().uge(RHS: Bitsize); |
4475 | return true; |
4476 | }; |
4477 | return matchUnaryPredicate(MRI, Reg: AmtReg, Match: MatchOutOfRange) && OutOfRange; |
4478 | } |
4479 | |
4480 | void CombinerHelper::applyRotateOutOfRange(MachineInstr &MI) const { |
4481 | assert(MI.getOpcode() == TargetOpcode::G_ROTL || |
4482 | MI.getOpcode() == TargetOpcode::G_ROTR); |
4483 | unsigned Bitsize = |
4484 | MRI.getType(Reg: MI.getOperand(i: 0).getReg()).getScalarSizeInBits(); |
4485 | Register Amt = MI.getOperand(i: 2).getReg(); |
4486 | LLT AmtTy = MRI.getType(Reg: Amt); |
4487 | auto Bits = Builder.buildConstant(Res: AmtTy, Val: Bitsize); |
4488 | Amt = Builder.buildURem(Dst: AmtTy, Src0: MI.getOperand(i: 2).getReg(), Src1: Bits).getReg(Idx: 0); |
4489 | Observer.changingInstr(MI); |
4490 | MI.getOperand(i: 2).setReg(Amt); |
4491 | Observer.changedInstr(MI); |
4492 | } |
4493 | |
4494 | bool CombinerHelper::matchICmpToTrueFalseKnownBits(MachineInstr &MI, |
4495 | int64_t &MatchInfo) const { |
4496 | assert(MI.getOpcode() == TargetOpcode::G_ICMP); |
4497 | auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(i: 1).getPredicate()); |
4498 | |
4499 | // We want to avoid calling KnownBits on the LHS if possible, as this combine |
4500 | // has no filter and runs on every G_ICMP instruction. We can avoid calling |
4501 | // KnownBits on the LHS in two cases: |
4502 | // |
4503 | // - The RHS is unknown: Constants are always on RHS. If the RHS is unknown |
4504 | // we cannot do any transforms so we can safely bail out early. |
4505 | // - The RHS is zero: we don't need to know the LHS to do unsigned <0 and |
4506 | // >=0. |
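//
// e.g. (illustrative) G_ICMP ult %x, 0 is always false and G_ICMP uge %x, 0
// is always true, regardless of what is known about %x.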
4507 | auto KnownRHS = VT->getKnownBits(R: MI.getOperand(i: 3).getReg()); |
4508 | if (KnownRHS.isUnknown()) |
4509 | return false; |
4510 | |
4511 | std::optional<bool> KnownVal; |
4512 | if (KnownRHS.isZero()) { |
4513 | // ? uge 0 -> always true |
4514 | // ? ult 0 -> always false |
4515 | if (Pred == CmpInst::ICMP_UGE) |
4516 | KnownVal = true; |
4517 | else if (Pred == CmpInst::ICMP_ULT) |
4518 | KnownVal = false; |
4519 | } |
4520 | |
4521 | if (!KnownVal) { |
4522 | auto KnownLHS = VT->getKnownBits(R: MI.getOperand(i: 2).getReg()); |
4523 | KnownVal = ICmpInst::compare(LHS: KnownLHS, RHS: KnownRHS, Pred); |
4524 | } |
4525 | |
4526 | if (!KnownVal) |
4527 | return false; |
4528 | MatchInfo = |
4529 | *KnownVal |
4530 | ? getICmpTrueVal(TLI: getTargetLowering(), |
4531 | /*IsVector = */ |
4532 | MRI.getType(Reg: MI.getOperand(i: 0).getReg()).isVector(), |
4533 | /* IsFP = */ false) |
4534 | : 0; |
4535 | return true; |
4536 | } |
4537 | |
4538 | bool CombinerHelper::matchICmpToLHSKnownBits( |
4539 | MachineInstr &MI, |
4540 | std::function<void(MachineIRBuilder &)> &MatchInfo) const { |
4541 | assert(MI.getOpcode() == TargetOpcode::G_ICMP); |
4542 | // Given: |
4543 | // |
4544 | // %x = G_WHATEVER (... x is known to be 0 or 1 ...) |
4545 | // %cmp = G_ICMP ne %x, 0 |
4546 | // |
4547 | // Or: |
4548 | // |
4549 | // %x = G_WHATEVER (... x is known to be 0 or 1 ...) |
4550 | // %cmp = G_ICMP eq %x, 1 |
4551 | // |
4552 | // We can replace %cmp with %x assuming true is 1 on the target. |
4553 | auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(i: 1).getPredicate()); |
4554 | if (!CmpInst::isEquality(pred: Pred)) |
4555 | return false; |
4556 | Register Dst = MI.getOperand(i: 0).getReg(); |
4557 | LLT DstTy = MRI.getType(Reg: Dst); |
4558 | if (getICmpTrueVal(TLI: getTargetLowering(), IsVector: DstTy.isVector(), |
4559 | /* IsFP = */ false) != 1) |
4560 | return false; |
4561 | int64_t OneOrZero = Pred == CmpInst::ICMP_EQ; |
4562 | if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICst(RequestedValue: OneOrZero))) |
4563 | return false; |
4564 | Register LHS = MI.getOperand(i: 2).getReg(); |
4565 | auto KnownLHS = VT->getKnownBits(R: LHS); |
4566 | if (KnownLHS.getMinValue() != 0 || KnownLHS.getMaxValue() != 1) |
4567 | return false; |
4568 | // Make sure replacing Dst with the LHS is a legal operation. |
4569 | LLT LHSTy = MRI.getType(Reg: LHS); |
4570 | unsigned LHSSize = LHSTy.getSizeInBits(); |
4571 | unsigned DstSize = DstTy.getSizeInBits(); |
4572 | unsigned Op = TargetOpcode::COPY; |
4573 | if (DstSize != LHSSize) |
4574 | Op = DstSize < LHSSize ? TargetOpcode::G_TRUNC : TargetOpcode::G_ZEXT; |
4575 | if (!isLegalOrBeforeLegalizer(Query: {Op, {DstTy, LHSTy}})) |
4576 | return false; |
4577 | MatchInfo = [=](MachineIRBuilder &B) { B.buildInstr(Opc: Op, DstOps: {Dst}, SrcOps: {LHS}); }; |
4578 | return true; |
4579 | } |
4580 | |
4581 | // Replace (and (or x, c1), c2) with (and x, c2) iff c1 & c2 == 0 |
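// e.g. (illustrative): (and (or x, 0xF0), 0x0F) -> (and x, 0x0F), since
// 0xF0 & 0x0F == 0 and so the OR cannot affect any bit kept by the mask.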
4582 | bool CombinerHelper::matchAndOrDisjointMask( |
4583 | MachineInstr &MI, |
4584 | std::function<void(MachineIRBuilder &)> &MatchInfo) const { |
4585 | assert(MI.getOpcode() == TargetOpcode::G_AND); |
4586 | |
4587 | // Ignore vector types to simplify matching the two constants. |
4588 | // TODO: do this for vectors and scalars via a demanded bits analysis. |
4589 | LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
4590 | if (Ty.isVector()) |
4591 | return false; |
4592 | |
4593 | Register Src; |
4594 | Register AndMaskReg; |
4595 | int64_t AndMaskBits; |
4596 | int64_t OrMaskBits; |
4597 | if (!mi_match(MI, MRI, |
4598 | P: m_GAnd(L: m_GOr(L: m_Reg(R&: Src), R: m_ICst(Cst&: OrMaskBits)), |
4599 | R: m_all_of(preds: m_ICst(Cst&: AndMaskBits), preds: m_Reg(R&: AndMaskReg))))) |
4600 | return false; |
4601 | |
4602 | // Check if OrMask could turn on any bits in Src. |
4603 | if (AndMaskBits & OrMaskBits) |
4604 | return false; |
4605 | |
4606 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
4607 | Observer.changingInstr(MI); |
4608 | // Canonicalize the result to have the constant on the RHS. |
4609 | if (MI.getOperand(i: 1).getReg() == AndMaskReg) |
4610 | MI.getOperand(i: 2).setReg(AndMaskReg); |
4611 | MI.getOperand(i: 1).setReg(Src); |
4612 | Observer.changedInstr(MI); |
4613 | }; |
4614 | return true; |
4615 | } |
4616 | |
4617 | /// Form a G_SBFX from a G_SEXT_INREG fed by a right shift. |
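/// e.g. (illustrative, s32):
///   %sh = G_ASHR %x, 4
///   %d = G_SEXT_INREG %sh, 8
/// becomes %d = G_SBFX %x, 4 (LSB), 8 (width).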
4618 | bool CombinerHelper::matchBitfieldExtractFromSExtInReg(
4619 | MachineInstr &MI, |
4620 | std::function<void(MachineIRBuilder &)> &MatchInfo) const { |
4621 | assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG); |
4622 | Register Dst = MI.getOperand(i: 0).getReg(); |
4623 | Register Src = MI.getOperand(i: 1).getReg(); |
4624 | LLT Ty = MRI.getType(Reg: Src); |
4625 | LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4626 | if (!LI || !LI->isLegalOrCustom(Query: {TargetOpcode::G_SBFX, {Ty, ExtractTy}})) |
4627 | return false; |
4628 | int64_t Width = MI.getOperand(i: 2).getImm(); |
4629 | Register ShiftSrc; |
4630 | int64_t ShiftImm; |
4631 | if (!mi_match( |
4632 | R: Src, MRI, |
4633 | P: m_OneNonDBGUse(SP: m_any_of(preds: m_GAShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: ShiftImm)), |
4634 | preds: m_GLShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: ShiftImm)))))) |
4635 | return false; |
4636 | if (ShiftImm < 0 || ShiftImm + Width > Ty.getScalarSizeInBits()) |
4637 | return false; |
4638 | |
4639 | MatchInfo = [=](MachineIRBuilder &B) { |
4640 | auto Cst1 = B.buildConstant(Res: ExtractTy, Val: ShiftImm); |
4641 | auto Cst2 = B.buildConstant(Res: ExtractTy, Val: Width); |
4642 | B.buildSbfx(Dst, Src: ShiftSrc, LSB: Cst1, Width: Cst2); |
4643 | }; |
4644 | return true; |
4645 | } |
4646 | |
4647 | /// Form a G_UBFX from "(a srl b) & mask", where b and mask are constants. |
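/// e.g. (illustrative, s32):
///   %sh = G_LSHR %x, 4
///   %d = G_AND %sh, 0xFF
/// becomes %d = G_UBFX %x, 4 (LSB), 8 (width).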
4648 | bool CombinerHelper::matchBitfieldExtractFromAnd(MachineInstr &MI, |
4649 | BuildFnTy &MatchInfo) const { |
4650 | GAnd *And = cast<GAnd>(Val: &MI); |
4651 | Register Dst = And->getReg(Idx: 0); |
4652 | LLT Ty = MRI.getType(Reg: Dst); |
4653 | LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4654 | // Note that isLegalOrBeforeLegalizer is stricter and does not take custom |
4655 | // into account. |
4656 | if (LI && !LI->isLegalOrCustom(Query: {TargetOpcode::G_UBFX, {Ty, ExtractTy}})) |
4657 | return false; |
4658 | |
4659 | int64_t AndImm, LSBImm; |
4660 | Register ShiftSrc; |
4661 | const unsigned Size = Ty.getScalarSizeInBits(); |
4662 | if (!mi_match(R: And->getReg(Idx: 0), MRI, |
4663 | P: m_GAnd(L: m_OneNonDBGUse(SP: m_GLShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: LSBImm))), |
4664 | R: m_ICst(Cst&: AndImm)))) |
4665 | return false; |
4666 | |
4667 | // The mask is a mask of the low bits iff imm & (imm+1) == 0. |
4668 | auto MaybeMask = static_cast<uint64_t>(AndImm); |
4669 | if (MaybeMask & (MaybeMask + 1)) |
4670 | return false; |
4671 | |
4672 | // LSB must fit within the register. |
4673 | if (static_cast<uint64_t>(LSBImm) >= Size) |
4674 | return false; |
4675 | |
4676 | uint64_t Width = APInt(Size, AndImm).countr_one(); |
4677 | MatchInfo = [=](MachineIRBuilder &B) { |
4678 | auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width); |
4679 | auto LSBCst = B.buildConstant(Res: ExtractTy, Val: LSBImm); |
4680 | B.buildInstr(Opc: TargetOpcode::G_UBFX, DstOps: {Dst}, SrcOps: {ShiftSrc, LSBCst, WidthCst}); |
4681 | }; |
4682 | return true; |
4683 | } |
4684 | |
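/// Form a G_SBFX/G_UBFX from a left shift followed by a right shift of the
/// same value, e.g. (illustrative, s32):
///   (ashr (shl x, 8), 16) -> (sbfx x, 8, 16)
///   (lshr (shl x, 8), 16) -> (ubfx x, 8, 16)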
4685 | bool CombinerHelper::matchBitfieldExtractFromShr(
4686 | MachineInstr &MI, |
4687 | std::function<void(MachineIRBuilder &)> &MatchInfo) const { |
4688 | const unsigned Opcode = MI.getOpcode(); |
4689 | assert(Opcode == TargetOpcode::G_ASHR || Opcode == TargetOpcode::G_LSHR); |
4690 | |
4691 | const Register Dst = MI.getOperand(i: 0).getReg(); |
4692 | |
4693 | const unsigned ExtrOpcode = Opcode == TargetOpcode::G_ASHR |
4694 | ? TargetOpcode::G_SBFX |
4695 | : TargetOpcode::G_UBFX; |
4696 | |
4697 | // Check if the type we would use for the extract is legal |
4698 | LLT Ty = MRI.getType(Reg: Dst); |
4699 | LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4700 | if (!LI || !LI->isLegalOrCustom(Query: {ExtrOpcode, {Ty, ExtractTy}})) |
4701 | return false; |
4702 | |
4703 | Register ShlSrc; |
4704 | int64_t ShrAmt; |
4705 | int64_t ShlAmt; |
4706 | const unsigned Size = Ty.getScalarSizeInBits(); |
4707 | |
4708 | // Try to match shr (shl x, c1), c2 |
4709 | if (!mi_match(R: Dst, MRI, |
4710 | P: m_BinOp(Opcode, |
4711 | L: m_OneNonDBGUse(SP: m_GShl(L: m_Reg(R&: ShlSrc), R: m_ICst(Cst&: ShlAmt))), |
4712 | R: m_ICst(Cst&: ShrAmt)))) |
4713 | return false; |
4714 | |
4715 | // Make sure that the shift sizes can fit a bitfield extract |
4716 | if (ShlAmt < 0 || ShlAmt > ShrAmt || ShrAmt >= Size) |
4717 | return false; |
4718 | |
4719 | // Skip this combine if the G_SEXT_INREG combine could handle it |
4720 | if (Opcode == TargetOpcode::G_ASHR && ShlAmt == ShrAmt) |
4721 | return false; |
4722 | |
4723 | // Calculate start position and width of the extract |
4724 | const int64_t Pos = ShrAmt - ShlAmt; |
4725 | const int64_t Width = Size - ShrAmt; |
4726 | |
4727 | MatchInfo = [=](MachineIRBuilder &B) { |
4728 | auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width); |
4729 | auto PosCst = B.buildConstant(Res: ExtractTy, Val: Pos); |
4730 | B.buildInstr(Opc: ExtrOpcode, DstOps: {Dst}, SrcOps: {ShlSrc, PosCst, WidthCst}); |
4731 | }; |
4732 | return true; |
4733 | } |
4734 | |
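/// Form a G_UBFX from a right shift of a masked value, e.g. (illustrative,
/// s32): (lshr (and x, 0x0FF0), 4) -> (ubfx x, 4, 8), provided the mask has
/// no holes above the shift amount.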
4735 | bool CombinerHelper::matchBitfieldExtractFromShrAnd( |
4736 | MachineInstr &MI, |
4737 | std::function<void(MachineIRBuilder &)> &MatchInfo) const { |
4738 | const unsigned Opcode = MI.getOpcode(); |
4739 | assert(Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_ASHR); |
4740 | |
4741 | const Register Dst = MI.getOperand(i: 0).getReg(); |
4742 | LLT Ty = MRI.getType(Reg: Dst); |
4743 | LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4744 | if (LI && !LI->isLegalOrCustom(Query: {TargetOpcode::G_UBFX, {Ty, ExtractTy}})) |
4745 | return false; |
4746 | |
4747 | // Try to match shr (and x, c1), c2 |
4748 | Register AndSrc; |
4749 | int64_t ShrAmt; |
4750 | int64_t SMask; |
4751 | if (!mi_match(R: Dst, MRI, |
4752 | P: m_BinOp(Opcode, |
4753 | L: m_OneNonDBGUse(SP: m_GAnd(L: m_Reg(R&: AndSrc), R: m_ICst(Cst&: SMask))), |
4754 | R: m_ICst(Cst&: ShrAmt)))) |
4755 | return false; |
4756 | |
4757 | const unsigned Size = Ty.getScalarSizeInBits(); |
4758 | if (ShrAmt < 0 || ShrAmt >= Size) |
4759 | return false; |
4760 | |
4761 | // If the shift subsumes the mask, emit the 0 directly. |
4762 | if (0 == (SMask >> ShrAmt)) { |
4763 | MatchInfo = [=](MachineIRBuilder &B) { |
4764 | B.buildConstant(Res: Dst, Val: 0); |
4765 | }; |
4766 | return true; |
4767 | } |
4768 | |
4769 | // Check that ubfx can do the extraction, with no holes in the mask. |
4770 | uint64_t UMask = SMask; |
4771 | UMask |= maskTrailingOnes<uint64_t>(N: ShrAmt); |
4772 | UMask &= maskTrailingOnes<uint64_t>(N: Size); |
4773 | if (!isMask_64(Value: UMask)) |
4774 | return false; |
4775 | |
4776 | // Calculate start position and width of the extract. |
4777 | const int64_t Pos = ShrAmt; |
4778 | const int64_t Width = llvm::countr_one(Value: UMask) - ShrAmt; |
4779 | |
4780 | // It's preferable to keep the shift, rather than form G_SBFX. |
4781 | // TODO: remove the G_AND via demanded bits analysis. |
4782 | if (Opcode == TargetOpcode::G_ASHR && Width + ShrAmt == Size) |
4783 | return false; |
4784 | |
4785 | MatchInfo = [=](MachineIRBuilder &B) { |
4786 | auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width); |
4787 | auto PosCst = B.buildConstant(Res: ExtractTy, Val: Pos); |
4788 | B.buildInstr(Opc: TargetOpcode::G_UBFX, DstOps: {Dst}, SrcOps: {AndSrc, PosCst, WidthCst}); |
4789 | }; |
4790 | return true; |
4791 | } |
4792 | |
4793 | bool CombinerHelper::reassociationCanBreakAddressingModePattern( |
4794 | MachineInstr &MI) const { |
4795 | auto &PtrAdd = cast<GPtrAdd>(Val&: MI); |
4796 | |
4797 | Register Src1Reg = PtrAdd.getBaseReg(); |
4798 | auto *Src1Def = getOpcodeDef<GPtrAdd>(Reg: Src1Reg, MRI); |
4799 | if (!Src1Def) |
4800 | return false; |
4801 | |
4802 | Register Src2Reg = PtrAdd.getOffsetReg(); |
4803 | |
4804 | if (MRI.hasOneNonDBGUse(RegNo: Src1Reg)) |
4805 | return false; |
4806 | |
4807 | auto C1 = getIConstantVRegVal(VReg: Src1Def->getOffsetReg(), MRI); |
4808 | if (!C1) |
4809 | return false; |
4810 | auto C2 = getIConstantVRegVal(VReg: Src2Reg, MRI); |
4811 | if (!C2) |
4812 | return false; |
4813 | |
4814 | const APInt &C1APIntVal = *C1; |
4815 | const APInt &C2APIntVal = *C2; |
4816 | const int64_t CombinedValue = (C1APIntVal + C2APIntVal).getSExtValue(); |
4817 | |
4818 | for (auto &UseMI : MRI.use_nodbg_instructions(Reg: PtrAdd.getReg(Idx: 0))) { |
4819 | // This combine may end up running before ptrtoint/inttoptr combines |
4820 | // manage to eliminate redundant conversions, so try to look through them. |
4821 | MachineInstr *ConvUseMI = &UseMI; |
4822 | unsigned ConvUseOpc = ConvUseMI->getOpcode(); |
4823 | while (ConvUseOpc == TargetOpcode::G_INTTOPTR || |
4824 | ConvUseOpc == TargetOpcode::G_PTRTOINT) { |
4825 | Register DefReg = ConvUseMI->getOperand(i: 0).getReg(); |
4826 | if (!MRI.hasOneNonDBGUse(RegNo: DefReg)) |
4827 | break; |
4828 | ConvUseMI = &*MRI.use_instr_nodbg_begin(RegNo: DefReg); |
4829 | ConvUseOpc = ConvUseMI->getOpcode(); |
4830 | } |
4831 | auto *LdStMI = dyn_cast<GLoadStore>(Val: ConvUseMI); |
4832 | if (!LdStMI) |
4833 | continue; |
4834 | // Is x[offset2] already not a legal addressing mode? If so then |
4835 | // reassociating the constants breaks nothing (we test offset2 because |
4836 | // that's the one we hope to fold into the load or store). |
4837 | TargetLoweringBase::AddrMode AM; |
4838 | AM.HasBaseReg = true; |
4839 | AM.BaseOffs = C2APIntVal.getSExtValue(); |
4840 | unsigned AS = MRI.getType(Reg: LdStMI->getPointerReg()).getAddressSpace(); |
4841 | Type *AccessTy = getTypeForLLT(Ty: LdStMI->getMMO().getMemoryType(), |
4842 | C&: PtrAdd.getMF()->getFunction().getContext()); |
4843 | const auto &TLI = *PtrAdd.getMF()->getSubtarget().getTargetLowering(); |
4844 | if (!TLI.isLegalAddressingMode(DL: PtrAdd.getMF()->getDataLayout(), AM, |
4845 | Ty: AccessTy, AddrSpace: AS)) |
4846 | continue; |
4847 | |
4848 | // Would x[offset1+offset2] still be a legal addressing mode? |
4849 | AM.BaseOffs = CombinedValue; |
4850 | if (!TLI.isLegalAddressingMode(DL: PtrAdd.getMF()->getDataLayout(), AM, |
4851 | Ty: AccessTy, AddrSpace: AS)) |
4852 | return true; |
4853 | } |
4854 | |
4855 | return false; |
4856 | } |
4857 | |
4858 | bool CombinerHelper::matchReassocConstantInnerRHS(GPtrAdd &MI, |
4859 | MachineInstr *RHS, |
4860 | BuildFnTy &MatchInfo) const { |
4861 | // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C) |
4862 | Register Src1Reg = MI.getOperand(i: 1).getReg(); |
4863 | if (RHS->getOpcode() != TargetOpcode::G_ADD) |
4864 | return false; |
4865 | auto C2 = getIConstantVRegVal(VReg: RHS->getOperand(i: 2).getReg(), MRI); |
4866 | if (!C2) |
4867 | return false; |
4868 | |
4869 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
4870 | LLT PtrTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
4871 | |
4872 | auto NewBase = |
4873 | Builder.buildPtrAdd(Res: PtrTy, Op0: Src1Reg, Op1: RHS->getOperand(i: 1).getReg()); |
4874 | Observer.changingInstr(MI); |
4875 | MI.getOperand(i: 1).setReg(NewBase.getReg(Idx: 0)); |
4876 | MI.getOperand(i: 2).setReg(RHS->getOperand(i: 2).getReg()); |
4877 | Observer.changedInstr(MI); |
4878 | }; |
4879 | return !reassociationCanBreakAddressingModePattern(MI); |
4880 | } |
4881 | |
4882 | bool CombinerHelper::matchReassocConstantInnerLHS(GPtrAdd &MI, |
4883 | MachineInstr *LHS, |
4884 | MachineInstr *RHS, |
4885 | BuildFnTy &MatchInfo) const { |
4886 | // (G_PTR_ADD (G_PTR_ADD X, C), Y) -> (G_PTR_ADD (G_PTR_ADD X, Y), C)
4887 | // if and only if (G_PTR_ADD X, C) has one use. |
4888 | Register LHSBase; |
4889 | std::optional<ValueAndVReg> LHSCstOff; |
4890 | if (!mi_match(R: MI.getBaseReg(), MRI, |
4891 | P: m_OneNonDBGUse(SP: m_GPtrAdd(L: m_Reg(R&: LHSBase), R: m_GCst(ValReg&: LHSCstOff))))) |
4892 | return false; |
4893 | |
4894 | auto *LHSPtrAdd = cast<GPtrAdd>(Val: LHS); |
4895 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
4896 | // When we change LHSPtrAdd's offset register we might cause it to use a reg |
4897 | // before its def. Sink the instruction to just before the outer PTR_ADD to
4898 | // ensure this doesn't happen.
4899 | LHSPtrAdd->moveBefore(MovePos: &MI); |
4900 | Register RHSReg = MI.getOffsetReg(); |
4901 | // Build a fresh constant rather than reusing the matched vreg, whose type may mismatch if it came through an extend/trunc.
4902 | auto NewCst = B.buildConstant(Res: MRI.getType(Reg: RHSReg), Val: LHSCstOff->Value); |
4903 | Observer.changingInstr(MI); |
4904 | MI.getOperand(i: 2).setReg(NewCst.getReg(Idx: 0)); |
4905 | Observer.changedInstr(MI); |
4906 | Observer.changingInstr(MI&: *LHSPtrAdd); |
4907 | LHSPtrAdd->getOperand(i: 2).setReg(RHSReg); |
4908 | Observer.changedInstr(MI&: *LHSPtrAdd); |
4909 | }; |
4910 | return !reassociationCanBreakAddressingModePattern(MI); |
4911 | } |
4912 | |
4913 | bool CombinerHelper::matchReassocFoldConstantsInSubTree( |
4914 | GPtrAdd &MI, MachineInstr *LHS, MachineInstr *RHS, |
4915 | BuildFnTy &MatchInfo) const { |
4916 | // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2) |
4917 | auto *LHSPtrAdd = dyn_cast<GPtrAdd>(Val: LHS); |
4918 | if (!LHSPtrAdd) |
4919 | return false; |
4920 | |
4921 | Register Src2Reg = MI.getOperand(i: 2).getReg(); |
4922 | Register LHSSrc1 = LHSPtrAdd->getBaseReg(); |
4923 | Register LHSSrc2 = LHSPtrAdd->getOffsetReg(); |
4924 | auto C1 = getIConstantVRegVal(VReg: LHSSrc2, MRI); |
4925 | if (!C1) |
4926 | return false; |
4927 | auto C2 = getIConstantVRegVal(VReg: Src2Reg, MRI); |
4928 | if (!C2) |
4929 | return false; |
4930 | |
4931 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
4932 | auto NewCst = B.buildConstant(Res: MRI.getType(Reg: Src2Reg), Val: *C1 + *C2); |
4933 | Observer.changingInstr(MI); |
4934 | MI.getOperand(i: 1).setReg(LHSSrc1); |
4935 | MI.getOperand(i: 2).setReg(NewCst.getReg(Idx: 0)); |
4936 | Observer.changedInstr(MI); |
4937 | }; |
4938 | return !reassociationCanBreakAddressingModePattern(MI); |
4939 | } |
4940 | |
4941 | bool CombinerHelper::matchReassocPtrAdd(MachineInstr &MI, |
4942 | BuildFnTy &MatchInfo) const { |
4943 | auto &PtrAdd = cast<GPtrAdd>(Val&: MI); |
4944 | // We're trying to match a few pointer computation patterns here for |
4945 | // re-association opportunities. |
4946 | // 1) Isolating a constant operand to be on the RHS, e.g.: |
4947 | // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C) |
4948 | // |
4949 | // 2) Folding two constants in each sub-tree as long as such folding |
4950 | // doesn't break a legal addressing mode. |
4951 | // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2) |
4952 | // |
4953 | // 3) Move a constant from the LHS of an inner op to the RHS of the outer. |
4954 | // (G_PTR_ADD (G_PTR_ADD X, C), Y) -> (G_PTR_ADD (G_PTR_ADD X, Y), C)
4955 | // iff (G_PTR_ADD X, C) has one use.
4956 | MachineInstr *LHS = MRI.getVRegDef(Reg: PtrAdd.getBaseReg()); |
4957 | MachineInstr *RHS = MRI.getVRegDef(Reg: PtrAdd.getOffsetReg()); |
4958 | |
4959 | // Try to match example 2. |
4960 | if (matchReassocFoldConstantsInSubTree(MI&: PtrAdd, LHS, RHS, MatchInfo)) |
4961 | return true; |
4962 | |
4963 | // Try to match example 3. |
4964 | if (matchReassocConstantInnerLHS(MI&: PtrAdd, LHS, RHS, MatchInfo)) |
4965 | return true; |
4966 | |
4967 | // Try to match example 1. |
4968 | if (matchReassocConstantInnerRHS(MI&: PtrAdd, RHS, MatchInfo)) |
4969 | return true; |
4970 | |
4971 | return false; |
4972 | } |
4973 | bool CombinerHelper::tryReassocBinOp(unsigned Opc, Register DstReg, |
4974 | Register OpLHS, Register OpRHS, |
4975 | BuildFnTy &MatchInfo) const { |
4976 | LLT OpRHSTy = MRI.getType(Reg: OpRHS); |
4977 | MachineInstr *OpLHSDef = MRI.getVRegDef(Reg: OpLHS); |
4978 | |
4979 | if (OpLHSDef->getOpcode() != Opc) |
4980 | return false; |
4981 | |
4982 | MachineInstr *OpRHSDef = MRI.getVRegDef(Reg: OpRHS); |
4983 | Register OpLHSLHS = OpLHSDef->getOperand(i: 1).getReg(); |
4984 | Register OpLHSRHS = OpLHSDef->getOperand(i: 2).getReg(); |
4985 | |
4986 | // If the inner op is (X op C), pull the constant out so it can be folded with |
4987 | // other constants in the expression tree. Folding is not guaranteed so we |
4988 | // might have (C1 op C2). In that case do not pull a constant out because it |
4989 | // won't help and can lead to infinite loops. |
4990 | if (isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: OpLHSRHS), MRI) && |
4991 | !isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: OpLHSLHS), MRI)) { |
4992 | if (isConstantOrConstantSplatVector(MI&: *OpRHSDef, MRI)) { |
4993 | // (Opc (Opc X, C1), C2) -> (Opc X, (Opc C1, C2)) |
4994 | MatchInfo = [=](MachineIRBuilder &B) { |
4995 | auto NewCst = B.buildInstr(Opc, DstOps: {OpRHSTy}, SrcOps: {OpLHSRHS, OpRHS}); |
4996 | B.buildInstr(Opc, DstOps: {DstReg}, SrcOps: {OpLHSLHS, NewCst}); |
4997 | }; |
4998 | return true; |
4999 | } |
5000 | if (getTargetLowering().isReassocProfitable(MRI, N0: OpLHS, N1: OpRHS)) { |
5001 | // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1) |
5002 | // iff (op x, c1) has one use |
5003 | MatchInfo = [=](MachineIRBuilder &B) { |
5004 | auto NewLHSLHS = B.buildInstr(Opc, DstOps: {OpRHSTy}, SrcOps: {OpLHSLHS, OpRHS}); |
5005 | B.buildInstr(Opc, DstOps: {DstReg}, SrcOps: {NewLHSLHS, OpLHSRHS}); |
5006 | }; |
5007 | return true; |
5008 | } |
5009 | } |
5010 | |
5011 | return false; |
5012 | } |
5013 | |
5014 | bool CombinerHelper::matchReassocCommBinOp(MachineInstr &MI, |
5015 | BuildFnTy &MatchInfo) const { |
5016 | // We don't check if the reassociation will break a legal addressing mode |
5017 | // here since pointer arithmetic is handled by G_PTR_ADD. |
5018 | unsigned Opc = MI.getOpcode(); |
5019 | Register DstReg = MI.getOperand(i: 0).getReg(); |
5020 | Register LHSReg = MI.getOperand(i: 1).getReg(); |
5021 | Register RHSReg = MI.getOperand(i: 2).getReg(); |
5022 | |
5023 | if (tryReassocBinOp(Opc, DstReg, OpLHS: LHSReg, OpRHS: RHSReg, MatchInfo)) |
5024 | return true; |
5025 | if (tryReassocBinOp(Opc, DstReg, OpLHS: RHSReg, OpRHS: LHSReg, MatchInfo)) |
5026 | return true; |
5027 | return false; |
5028 | } |
5029 | |
5030 | bool CombinerHelper::matchConstantFoldCastOp(MachineInstr &MI, |
5031 | APInt &MatchInfo) const { |
5032 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
5033 | Register SrcOp = MI.getOperand(i: 1).getReg(); |
5034 | |
5035 | if (auto MaybeCst = ConstantFoldCastOp(Opcode: MI.getOpcode(), DstTy, Op0: SrcOp, MRI)) { |
5036 | MatchInfo = *MaybeCst; |
5037 | return true; |
5038 | } |
5039 | |
5040 | return false; |
5041 | } |
5042 | |
5043 | bool CombinerHelper::matchConstantFoldBinOp(MachineInstr &MI, |
5044 | APInt &MatchInfo) const { |
5045 | Register Op1 = MI.getOperand(i: 1).getReg(); |
5046 | Register Op2 = MI.getOperand(i: 2).getReg(); |
5047 | auto MaybeCst = ConstantFoldBinOp(Opcode: MI.getOpcode(), Op1, Op2, MRI); |
5048 | if (!MaybeCst) |
5049 | return false; |
5050 | MatchInfo = *MaybeCst; |
5051 | return true; |
5052 | } |
5053 | |
5054 | bool CombinerHelper::matchConstantFoldFPBinOp(MachineInstr &MI, |
5055 | ConstantFP *&MatchInfo) const { |
5056 | Register Op1 = MI.getOperand(i: 1).getReg(); |
5057 | Register Op2 = MI.getOperand(i: 2).getReg(); |
5058 | auto MaybeCst = ConstantFoldFPBinOp(Opcode: MI.getOpcode(), Op1, Op2, MRI); |
5059 | if (!MaybeCst) |
5060 | return false; |
5061 | MatchInfo = |
5062 | ConstantFP::get(Context&: MI.getMF()->getFunction().getContext(), V: *MaybeCst); |
5063 | return true; |
5064 | } |
5065 | |
5066 | bool CombinerHelper::matchConstantFoldFMA(MachineInstr &MI, |
5067 | ConstantFP *&MatchInfo) const { |
5068 | assert(MI.getOpcode() == TargetOpcode::G_FMA || |
5069 | MI.getOpcode() == TargetOpcode::G_FMAD); |
5070 | auto [_, Op1, Op2, Op3] = MI.getFirst4Regs(); |
5071 | |
5072 | const ConstantFP *Op3Cst = getConstantFPVRegVal(VReg: Op3, MRI); |
5073 | if (!Op3Cst) |
5074 | return false; |
5075 | |
5076 | const ConstantFP *Op2Cst = getConstantFPVRegVal(VReg: Op2, MRI); |
5077 | if (!Op2Cst) |
5078 | return false; |
5079 | |
5080 | const ConstantFP *Op1Cst = getConstantFPVRegVal(VReg: Op1, MRI); |
5081 | if (!Op1Cst) |
5082 | return false; |
5083 | |
5084 | APFloat Op1F = Op1Cst->getValueAPF(); |
5085 | Op1F.fusedMultiplyAdd(Multiplicand: Op2Cst->getValueAPF(), Addend: Op3Cst->getValueAPF(), |
5086 | RM: APFloat::rmNearestTiesToEven); |
5087 | MatchInfo = ConstantFP::get(Context&: MI.getMF()->getFunction().getContext(), V: Op1F); |
5088 | return true; |
5089 | } |
5090 | |
5091 | bool CombinerHelper::matchNarrowBinopFeedingAnd( |
5092 | MachineInstr &MI, |
5093 | std::function<void(MachineIRBuilder &)> &MatchInfo) const { |
5094 | // Look for a binop feeding into an AND with a mask: |
5095 | // |
5096 | // %add = G_ADD %lhs, %rhs |
5097 | // %and = G_AND %add, 000...11111111 |
5098 | // |
5099 | // Check if it's possible to perform the binop at a narrower width and zext |
5100 | // back to the original width like so: |
5101 | // |
5102 | // %narrow_lhs = G_TRUNC %lhs |
5103 | // %narrow_rhs = G_TRUNC %rhs |
5104 | // %narrow_add = G_ADD %narrow_lhs, %narrow_rhs |
5105 | // %new_add = G_ZEXT %narrow_add |
5106 | // %and = G_AND %new_add, 000...11111111 |
5107 | // |
5108 | // This can allow later combines to eliminate the G_AND if it turns out |
5109 | // that the mask is irrelevant. |
5110 | assert(MI.getOpcode() == TargetOpcode::G_AND); |
5111 | Register Dst = MI.getOperand(i: 0).getReg(); |
5112 | Register AndLHS = MI.getOperand(i: 1).getReg(); |
5113 | Register AndRHS = MI.getOperand(i: 2).getReg(); |
5114 | LLT WideTy = MRI.getType(Reg: Dst); |
5115 | |
5116 | // If the potential binop has more than one use, then it's possible that one |
5117 | // of those uses will need its full width. |
5118 | if (!WideTy.isScalar() || !MRI.hasOneNonDBGUse(RegNo: AndLHS)) |
5119 | return false; |
5120 | |
5121 | // Check if the LHS feeding the AND is impacted by the high bits that we're |
5122 | // masking out. |
5123 | // |
5124 | // e.g. for 64-bit x, y: |
5125 | // |
5126 | // add_64(x, y) & 65535 == zext(add_16(trunc(x), trunc(y))) & 65535 |
5127 | MachineInstr *LHSInst = getDefIgnoringCopies(Reg: AndLHS, MRI); |
5128 | if (!LHSInst) |
5129 | return false; |
5130 | unsigned LHSOpc = LHSInst->getOpcode(); |
5131 | switch (LHSOpc) { |
5132 | default: |
5133 | return false; |
5134 | case TargetOpcode::G_ADD: |
5135 | case TargetOpcode::G_SUB: |
5136 | case TargetOpcode::G_MUL: |
5137 | case TargetOpcode::G_AND: |
5138 | case TargetOpcode::G_OR: |
5139 | case TargetOpcode::G_XOR: |
5140 | break; |
5141 | } |
5142 | |
5143 | // Find the mask on the RHS. |
5144 | auto Cst = getIConstantVRegValWithLookThrough(VReg: AndRHS, MRI); |
5145 | if (!Cst) |
5146 | return false; |
5147 | auto Mask = Cst->Value; |
5148 | if (!Mask.isMask()) |
5149 | return false; |
5150 | |
5151 | // No point in combining if there's nothing to truncate. |
5152 | unsigned NarrowWidth = Mask.countr_one(); |
5153 | if (NarrowWidth == WideTy.getSizeInBits()) |
5154 | return false; |
5155 | LLT NarrowTy = LLT::scalar(SizeInBits: NarrowWidth); |
5156 | |
5157 | // Check if adding the zext + truncates could be harmful. |
5158 | auto &MF = *MI.getMF(); |
5159 | const auto &TLI = getTargetLowering(); |
5160 | LLVMContext &Ctx = MF.getFunction().getContext(); |
5161 | if (!TLI.isTruncateFree(FromTy: WideTy, ToTy: NarrowTy, Ctx) || |
5162 | !TLI.isZExtFree(FromTy: NarrowTy, ToTy: WideTy, Ctx)) |
5163 | return false; |
5164 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_TRUNC, {NarrowTy, WideTy}}) || |
5165 | !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ZEXT, {WideTy, NarrowTy}})) |
5166 | return false; |
5167 | Register BinOpLHS = LHSInst->getOperand(i: 1).getReg(); |
5168 | Register BinOpRHS = LHSInst->getOperand(i: 2).getReg(); |
5169 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5170 | auto NarrowLHS = Builder.buildTrunc(Res: NarrowTy, Op: BinOpLHS); |
5171 | auto NarrowRHS = Builder.buildTrunc(Res: NarrowTy, Op: BinOpRHS); |
5172 | auto NarrowBinOp = |
5173 | Builder.buildInstr(Opc: LHSOpc, DstOps: {NarrowTy}, SrcOps: {NarrowLHS, NarrowRHS}); |
5174 | auto Ext = Builder.buildZExt(Res: WideTy, Op: NarrowBinOp); |
5175 | Observer.changingInstr(MI); |
5176 | MI.getOperand(i: 1).setReg(Ext.getReg(Idx: 0)); |
5177 | Observer.changedInstr(MI); |
5178 | }; |
5179 | return true; |
5180 | } |
5181 | |
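// Transform a multiply-with-overflow by 2 into an add-with-overflow of the
// operand with itself:
//   (G_UMULO x, 2) -> (G_UADDO x, x)
//   (G_SMULO x, 2) -> (G_SADDO x, x)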
5182 | bool CombinerHelper::matchMulOBy2(MachineInstr &MI, |
5183 | BuildFnTy &MatchInfo) const { |
5184 | unsigned Opc = MI.getOpcode(); |
5185 | assert(Opc == TargetOpcode::G_UMULO || Opc == TargetOpcode::G_SMULO); |
5186 | |
5187 | if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 2))) |
5188 | return false; |
5189 | |
5190 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5191 | Observer.changingInstr(MI); |
5192 | unsigned NewOpc = Opc == TargetOpcode::G_UMULO ? TargetOpcode::G_UADDO |
5193 | : TargetOpcode::G_SADDO; |
5194 | MI.setDesc(Builder.getTII().get(Opcode: NewOpc)); |
5195 | MI.getOperand(i: 3).setReg(MI.getOperand(i: 2).getReg()); |
5196 | Observer.changedInstr(MI); |
5197 | }; |
5198 | return true; |
5199 | } |
5200 | |
5201 | bool CombinerHelper::matchMulOBy0(MachineInstr &MI, |
5202 | BuildFnTy &MatchInfo) const { |
5203 | // (G_*MULO x, 0) -> 0 + no carry out |
5204 | assert(MI.getOpcode() == TargetOpcode::G_UMULO || |
5205 | MI.getOpcode() == TargetOpcode::G_SMULO); |
5206 | if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 0))) |
5207 | return false; |
5208 | Register Dst = MI.getOperand(i: 0).getReg(); |
5209 | Register Carry = MI.getOperand(i: 1).getReg(); |
5210 | if (!isConstantLegalOrBeforeLegalizer(Ty: MRI.getType(Reg: Dst)) || |
5211 | !isConstantLegalOrBeforeLegalizer(Ty: MRI.getType(Reg: Carry))) |
5212 | return false; |
5213 | MatchInfo = [=](MachineIRBuilder &B) { |
5214 | B.buildConstant(Res: Dst, Val: 0); |
5215 | B.buildConstant(Res: Carry, Val: 0); |
5216 | }; |
5217 | return true; |
5218 | } |
5219 | |
5220 | bool CombinerHelper::matchAddEToAddO(MachineInstr &MI, |
5221 | BuildFnTy &MatchInfo) const { |
5222 | // (G_*ADDE x, y, 0) -> (G_*ADDO x, y) |
5223 | // (G_*SUBE x, y, 0) -> (G_*SUBO x, y) |
5224 | assert(MI.getOpcode() == TargetOpcode::G_UADDE || |
5225 | MI.getOpcode() == TargetOpcode::G_SADDE || |
5226 | MI.getOpcode() == TargetOpcode::G_USUBE || |
5227 | MI.getOpcode() == TargetOpcode::G_SSUBE); |
5228 | if (!mi_match(R: MI.getOperand(i: 4).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 0))) |
5229 | return false; |
5230 | MatchInfo = [&](MachineIRBuilder &B) { |
5231 | unsigned NewOpcode; |
5232 | switch (MI.getOpcode()) { |
5233 | case TargetOpcode::G_UADDE: |
5234 | NewOpcode = TargetOpcode::G_UADDO; |
5235 | break; |
5236 | case TargetOpcode::G_SADDE: |
5237 | NewOpcode = TargetOpcode::G_SADDO; |
5238 | break; |
5239 | case TargetOpcode::G_USUBE: |
5240 | NewOpcode = TargetOpcode::G_USUBO; |
5241 | break; |
5242 | case TargetOpcode::G_SSUBE: |
5243 | NewOpcode = TargetOpcode::G_SSUBO; |
5244 | break; |
5245 | } |
5246 | Observer.changingInstr(MI); |
5247 | MI.setDesc(B.getTII().get(Opcode: NewOpcode)); |
5248 | MI.removeOperand(OpNo: 4); |
5249 | Observer.changedInstr(MI); |
5250 | }; |
5251 | return true; |
5252 | } |
5253 | |
5254 | bool CombinerHelper::matchSubAddSameReg(MachineInstr &MI, |
5255 | BuildFnTy &MatchInfo) const { |
5256 | assert(MI.getOpcode() == TargetOpcode::G_SUB); |
5257 | Register Dst = MI.getOperand(i: 0).getReg(); |
5258 | // (x + y) - z -> x (if y == z) |
5259 | // (x + y) - z -> y (if x == z) |
5260 | Register X, Y, Z; |
5261 | if (mi_match(R: Dst, MRI, P: m_GSub(L: m_GAdd(L: m_Reg(R&: X), R: m_Reg(R&: Y)), R: m_Reg(R&: Z)))) { |
5262 | Register ReplaceReg; |
5263 | int64_t CstX, CstY; |
5264 | if (Y == Z || (mi_match(R: Y, MRI, P: m_ICstOrSplat(Cst&: CstY)) && |
5265 | mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstY)))) |
5266 | ReplaceReg = X; |
5267 | else if (X == Z || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) && |
5268 | mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX)))) |
5269 | ReplaceReg = Y; |
5270 | if (ReplaceReg) { |
5271 | MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Res: Dst, Op: ReplaceReg); }; |
5272 | return true; |
5273 | } |
5274 | } |
5275 | |
5276 | // x - (y + z) -> 0 - y (if x == z) |
5277 | // x - (y + z) -> 0 - z (if x == y) |
5278 | if (mi_match(R: Dst, MRI, P: m_GSub(L: m_Reg(R&: X), R: m_GAdd(L: m_Reg(R&: Y), R: m_Reg(R&: Z))))) { |
5279 | Register ReplaceReg; |
5280 | int64_t CstX; |
5281 | if (X == Z || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) && |
5282 | mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX)))) |
5283 | ReplaceReg = Y; |
5284 | else if (X == Y || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) && |
5285 | mi_match(R: Y, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX)))) |
5286 | ReplaceReg = Z; |
5287 | if (ReplaceReg) { |
5288 | MatchInfo = [=](MachineIRBuilder &B) { |
5289 | auto Zero = B.buildConstant(Res: MRI.getType(Reg: Dst), Val: 0); |
5290 | B.buildSub(Dst, Src0: Zero, Src1: ReplaceReg); |
5291 | }; |
5292 | return true; |
5293 | } |
5294 | } |
5295 | return false; |
5296 | } |
5297 | |
5298 | MachineInstr *CombinerHelper::buildUDivorURemUsingMul(MachineInstr &MI) const { |
5299 | unsigned Opcode = MI.getOpcode(); |
5300 | assert(Opcode == TargetOpcode::G_UDIV || Opcode == TargetOpcode::G_UREM); |
5301 | auto &UDivorRem = cast<GenericMachineInstr>(Val&: MI); |
5302 | Register Dst = UDivorRem.getReg(Idx: 0); |
5303 | Register LHS = UDivorRem.getReg(Idx: 1); |
5304 | Register RHS = UDivorRem.getReg(Idx: 2); |
5305 | LLT Ty = MRI.getType(Reg: Dst); |
5306 | LLT ScalarTy = Ty.getScalarType(); |
5307 | const unsigned EltBits = ScalarTy.getScalarSizeInBits(); |
5308 | LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty); |
5309 | LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType(); |
5310 | |
5311 | auto &MIB = Builder; |
5312 | |
5313 | bool UseSRL = false; |
5314 | SmallVector<Register, 16> Shifts, Factors; |
5315 | auto *RHSDefInstr = cast<GenericMachineInstr>(Val: getDefIgnoringCopies(Reg: RHS, MRI)); |
5316 | bool IsSplat = getIConstantSplatVal(MI: *RHSDefInstr, MRI).has_value(); |
5317 | |
5318 | auto BuildExactUDIVPattern = [&](const Constant *C) { |
5319 | // Don't recompute inverses for each splat element. |
5320 | if (IsSplat && !Factors.empty()) { |
5321 | Shifts.push_back(Elt: Shifts[0]); |
5322 | Factors.push_back(Elt: Factors[0]); |
5323 | return true; |
5324 | } |
5325 | |
5326 | auto *CI = cast<ConstantInt>(Val: C); |
5327 | APInt Divisor = CI->getValue(); |
5328 | unsigned Shift = Divisor.countr_zero(); |
5329 | if (Shift) { |
5330 | Divisor.lshrInPlace(ShiftAmt: Shift); |
5331 | UseSRL = true; |
5332 | } |
5333 | |
5334 | // Calculate the multiplicative inverse modulo BW. |
5335 | APInt Factor = Divisor.multiplicativeInverse(); |
5336 | Shifts.push_back(Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: Shift).getReg(Idx: 0)); |
5337 | Factors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Factor).getReg(Idx: 0)); |
5338 | return true; |
5339 | }; |
5340 | |
5341 | if (MI.getFlag(Flag: MachineInstr::MIFlag::IsExact)) { |
// Collect the shift amount and multiplicative inverse for each divisor
// element.
5343 | if (!matchUnaryPredicate(MRI, Reg: RHS, Match: BuildExactUDIVPattern)) |
5344 | llvm_unreachable("Expected unary predicate match to succeed" ); |
5345 | |
5346 | Register Shift, Factor; |
5347 | if (Ty.isVector()) { |
5348 | Shift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: Shifts).getReg(Idx: 0); |
5349 | Factor = MIB.buildBuildVector(Res: Ty, Ops: Factors).getReg(Idx: 0); |
5350 | } else { |
5351 | Shift = Shifts[0]; |
5352 | Factor = Factors[0]; |
5353 | } |
5354 | |
5355 | Register Res = LHS; |
5356 | |
5357 | if (UseSRL) |
5358 | Res = MIB.buildLShr(Dst: Ty, Src0: Res, Src1: Shift, Flags: MachineInstr::IsExact).getReg(Idx: 0); |
5359 | |
5360 | return MIB.buildMul(Dst: Ty, Src0: Res, Src1: Factor); |
5361 | } |
5362 | |
5363 | unsigned KnownLeadingZeros = |
5364 | VT ? VT->getKnownBits(R: LHS).countMinLeadingZeros() : 0; |
5365 | |
5366 | bool UseNPQ = false; |
5367 | SmallVector<Register, 16> PreShifts, PostShifts, MagicFactors, NPQFactors; |
5368 | auto BuildUDIVPattern = [&](const Constant *C) { |
5369 | auto *CI = cast<ConstantInt>(Val: C); |
5370 | const APInt &Divisor = CI->getValue(); |
5371 | |
5372 | bool SelNPQ = false; |
5373 | APInt Magic(Divisor.getBitWidth(), 0); |
5374 | unsigned PreShift = 0, PostShift = 0; |
5375 | |
5376 | // Magic algorithm doesn't work for division by 1. We need to emit a select |
5377 | // at the end. |
5378 | // TODO: Use undef values for divisor of 1. |
5379 | if (!Divisor.isOne()) { |
5380 | |
// UnsignedDivisionByConstantInfo doesn't work correctly if the number of
// leading zeros in the dividend exceeds the divisor's leading zeros, so
// clamp the hint.
5383 | UnsignedDivisionByConstantInfo magics = |
5384 | UnsignedDivisionByConstantInfo::get( |
5385 | D: Divisor, LeadingZeros: std::min(a: KnownLeadingZeros, b: Divisor.countl_zero())); |
5386 | |
5387 | Magic = std::move(magics.Magic); |
5388 | |
5389 | assert(magics.PreShift < Divisor.getBitWidth() && |
5390 | "We shouldn't generate an undefined shift!" ); |
5391 | assert(magics.PostShift < Divisor.getBitWidth() && |
5392 | "We shouldn't generate an undefined shift!" ); |
5393 | assert((!magics.IsAdd || magics.PreShift == 0) && "Unexpected pre-shift" ); |
5394 | PreShift = magics.PreShift; |
5395 | PostShift = magics.PostShift; |
5396 | SelNPQ = magics.IsAdd; |
5397 | } |
5398 | |
5399 | PreShifts.push_back( |
5400 | Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: PreShift).getReg(Idx: 0)); |
5401 | MagicFactors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Magic).getReg(Idx: 0)); |
5402 | NPQFactors.push_back( |
5403 | Elt: MIB.buildConstant(Res: ScalarTy, |
5404 | Val: SelNPQ ? APInt::getOneBitSet(numBits: EltBits, BitNo: EltBits - 1) |
5405 | : APInt::getZero(numBits: EltBits)) |
5406 | .getReg(Idx: 0)); |
5407 | PostShifts.push_back( |
5408 | Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: PostShift).getReg(Idx: 0)); |
5409 | UseNPQ |= SelNPQ; |
5410 | return true; |
5411 | }; |
5412 | |
5413 | // Collect the shifts/magic values from each element. |
5414 | bool Matched = matchUnaryPredicate(MRI, Reg: RHS, Match: BuildUDIVPattern); |
5415 | (void)Matched; |
5416 | assert(Matched && "Expected unary predicate match to succeed" ); |
5417 | |
5418 | Register PreShift, PostShift, MagicFactor, NPQFactor; |
5419 | auto *RHSDef = getOpcodeDef<GBuildVector>(Reg: RHS, MRI); |
5420 | if (RHSDef) { |
5421 | PreShift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: PreShifts).getReg(Idx: 0); |
5422 | MagicFactor = MIB.buildBuildVector(Res: Ty, Ops: MagicFactors).getReg(Idx: 0); |
5423 | NPQFactor = MIB.buildBuildVector(Res: Ty, Ops: NPQFactors).getReg(Idx: 0); |
5424 | PostShift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: PostShifts).getReg(Idx: 0); |
5425 | } else { |
5426 | assert(MRI.getType(RHS).isScalar() && |
5427 | "Non-build_vector operation should have been a scalar" ); |
5428 | PreShift = PreShifts[0]; |
5429 | MagicFactor = MagicFactors[0]; |
5430 | PostShift = PostShifts[0]; |
5431 | } |
5432 | |
5433 | Register Q = LHS; |
5434 | Q = MIB.buildLShr(Dst: Ty, Src0: Q, Src1: PreShift).getReg(Idx: 0); |
5435 | |
5436 | // Multiply the numerator (operand 0) by the magic value. |
5437 | Q = MIB.buildUMulH(Dst: Ty, Src0: Q, Src1: MagicFactor).getReg(Idx: 0); |
5438 | |
5439 | if (UseNPQ) { |
5440 | Register NPQ = MIB.buildSub(Dst: Ty, Src0: LHS, Src1: Q).getReg(Idx: 0); |
5441 | |
// For vectors we might have a mix of NPQ and non-NPQ lanes, so use G_UMULH
// with a per-lane factor: 2^(EltBits - 1) acts as an SRL-by-1 for NPQ lanes,
// and a factor of zero yields zero for the others.
5444 | if (Ty.isVector()) |
5445 | NPQ = MIB.buildUMulH(Dst: Ty, Src0: NPQ, Src1: NPQFactor).getReg(Idx: 0); |
5446 | else |
5447 | NPQ = MIB.buildLShr(Dst: Ty, Src0: NPQ, Src1: MIB.buildConstant(Res: ShiftAmtTy, Val: 1)).getReg(Idx: 0); |
5448 | |
5449 | Q = MIB.buildAdd(Dst: Ty, Src0: NPQ, Src1: Q).getReg(Idx: 0); |
5450 | } |
5451 | |
5452 | Q = MIB.buildLShr(Dst: Ty, Src0: Q, Src1: PostShift).getReg(Idx: 0); |
5453 | auto One = MIB.buildConstant(Res: Ty, Val: 1); |
5454 | auto IsOne = MIB.buildICmp( |
5455 | Pred: CmpInst::Predicate::ICMP_EQ, |
5456 | Res: Ty.isScalar() ? LLT::scalar(SizeInBits: 1) : Ty.changeElementSize(NewEltSize: 1), Op0: RHS, Op1: One); |
5457 | auto ret = MIB.buildSelect(Res: Ty, Tst: IsOne, Op0: LHS, Op1: Q); |
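// For G_UREM, derive the remainder from the quotient computed above:
// rem = LHS - (LHS udiv RHS) * RHS. For illustration, 100 urem 7
// == 100 - 14 * 7 == 2.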
5458 | |
5459 | if (Opcode == TargetOpcode::G_UREM) { |
5460 | auto Prod = MIB.buildMul(Dst: Ty, Src0: ret, Src1: RHS); |
5461 | return MIB.buildSub(Dst: Ty, Src0: LHS, Src1: Prod); |
5462 | } |
5463 | return ret; |
5464 | } |
5465 | |
5466 | bool CombinerHelper::matchUDivorURemByConst(MachineInstr &MI) const { |
5467 | unsigned Opcode = MI.getOpcode(); |
5468 | assert(Opcode == TargetOpcode::G_UDIV || Opcode == TargetOpcode::G_UREM); |
5469 | Register Dst = MI.getOperand(i: 0).getReg(); |
5470 | Register RHS = MI.getOperand(i: 2).getReg(); |
5471 | LLT DstTy = MRI.getType(Reg: Dst); |
5472 | |
5473 | auto &MF = *MI.getMF(); |
5474 | AttributeList Attr = MF.getFunction().getAttributes(); |
5475 | const auto &TLI = getTargetLowering(); |
5476 | LLVMContext &Ctx = MF.getFunction().getContext(); |
5477 | if (TLI.isIntDivCheap(VT: getApproximateEVTForLLT(Ty: DstTy, Ctx), Attr)) |
5478 | return false; |
5479 | |
5480 | // Don't do this for minsize because the instruction sequence is usually |
5481 | // larger. |
5482 | if (MF.getFunction().hasMinSize()) |
5483 | return false; |
5484 | |
5485 | if (Opcode == TargetOpcode::G_UDIV && |
5486 | MI.getFlag(Flag: MachineInstr::MIFlag::IsExact)) { |
5487 | return matchUnaryPredicate( |
5488 | MRI, Reg: RHS, Match: [](const Constant *C) { return C && !C->isNullValue(); }); |
5489 | } |
5490 | |
5491 | auto *RHSDef = MRI.getVRegDef(Reg: RHS); |
5492 | if (!isConstantOrConstantVector(MI&: *RHSDef, MRI)) |
5493 | return false; |
5494 | |
5495 | // Don't do this if the types are not going to be legal. |
5496 | if (LI) { |
5497 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_MUL, {DstTy, DstTy}})) |
5498 | return false; |
5499 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_UMULH, {DstTy}})) |
5500 | return false; |
5501 | if (!isLegalOrBeforeLegalizer( |
5502 | Query: {TargetOpcode::G_ICMP, |
5503 | {DstTy.isVector() ? DstTy.changeElementSize(NewEltSize: 1) : LLT::scalar(SizeInBits: 1), |
5504 | DstTy}})) |
5505 | return false; |
5506 | if (Opcode == TargetOpcode::G_UREM && |
5507 | !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SUB, {DstTy, DstTy}})) |
5508 | return false; |
5509 | } |
5510 | |
5511 | return matchUnaryPredicate( |
5512 | MRI, Reg: RHS, Match: [](const Constant *C) { return C && !C->isNullValue(); }); |
5513 | } |
5514 | |
5515 | void CombinerHelper::applyUDivorURemByConst(MachineInstr &MI) const { |
5516 | auto *NewMI = buildUDivorURemUsingMul(MI); |
5517 | replaceSingleDefInstWithReg(MI, Replacement: NewMI->getOperand(i: 0).getReg()); |
5518 | } |
5519 | |
5520 | bool CombinerHelper::matchSDivByConst(MachineInstr &MI) const { |
5521 | assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV" ); |
5522 | Register Dst = MI.getOperand(i: 0).getReg(); |
5523 | Register RHS = MI.getOperand(i: 2).getReg(); |
5524 | LLT DstTy = MRI.getType(Reg: Dst); |
5525 | |
5526 | auto &MF = *MI.getMF(); |
5527 | AttributeList Attr = MF.getFunction().getAttributes(); |
5528 | const auto &TLI = getTargetLowering(); |
5529 | LLVMContext &Ctx = MF.getFunction().getContext(); |
5530 | if (TLI.isIntDivCheap(VT: getApproximateEVTForLLT(Ty: DstTy, Ctx), Attr)) |
5531 | return false; |
5532 | |
5533 | // Don't do this for minsize because the instruction sequence is usually |
5534 | // larger. |
5535 | if (MF.getFunction().hasMinSize()) |
5536 | return false; |
5537 | |
5538 | // If the sdiv has an 'exact' flag we can use a simpler lowering. |
5539 | if (MI.getFlag(Flag: MachineInstr::MIFlag::IsExact)) { |
5540 | return matchUnaryPredicate( |
5541 | MRI, Reg: RHS, Match: [](const Constant *C) { return C && !C->isNullValue(); }); |
5542 | } |
5543 | |
5544 | // Don't support the general case for now. |
5545 | return false; |
5546 | } |
5547 | |
5548 | void CombinerHelper::applySDivByConst(MachineInstr &MI) const { |
5549 | auto *NewMI = buildSDivUsingMul(MI); |
5550 | replaceSingleDefInstWithReg(MI, Replacement: NewMI->getOperand(i: 0).getReg()); |
5551 | } |
5552 | |
5553 | MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) const { |
5554 | assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV" ); |
5555 | auto &SDiv = cast<GenericMachineInstr>(Val&: MI); |
5556 | Register Dst = SDiv.getReg(Idx: 0); |
5557 | Register LHS = SDiv.getReg(Idx: 1); |
5558 | Register RHS = SDiv.getReg(Idx: 2); |
5559 | LLT Ty = MRI.getType(Reg: Dst); |
5560 | LLT ScalarTy = Ty.getScalarType(); |
5561 | LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty); |
5562 | LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType(); |
5563 | auto &MIB = Builder; |
5564 | |
5565 | bool UseSRA = false; |
5566 | SmallVector<Register, 16> Shifts, Factors; |
5567 | |
5568 | auto *RHSDef = cast<GenericMachineInstr>(Val: getDefIgnoringCopies(Reg: RHS, MRI)); |
5569 | bool IsSplat = getIConstantSplatVal(MI: *RHSDef, MRI).has_value(); |
5570 | |
5571 | auto BuildSDIVPattern = [&](const Constant *C) { |
5572 | // Don't recompute inverses for each splat element. |
5573 | if (IsSplat && !Factors.empty()) { |
5574 | Shifts.push_back(Elt: Shifts[0]); |
5575 | Factors.push_back(Elt: Factors[0]); |
5576 | return true; |
5577 | } |
5578 | |
5579 | auto *CI = cast<ConstantInt>(Val: C); |
5580 | APInt Divisor = CI->getValue(); |
5581 | unsigned Shift = Divisor.countr_zero(); |
5582 | if (Shift) { |
5583 | Divisor.ashrInPlace(ShiftAmt: Shift); |
5584 | UseSRA = true; |
5585 | } |
5586 | |
// Calculate the multiplicative inverse modulo BW.
5589 | APInt Factor = Divisor.multiplicativeInverse(); |
5590 | Shifts.push_back(Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: Shift).getReg(Idx: 0)); |
5591 | Factors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Factor).getReg(Idx: 0)); |
5592 | return true; |
5593 | }; |
5594 | |
// Collect the shift amount and multiplicative inverse for each divisor
// element.
5596 | bool Matched = matchUnaryPredicate(MRI, Reg: RHS, Match: BuildSDIVPattern); |
5597 | (void)Matched; |
5598 | assert(Matched && "Expected unary predicate match to succeed" ); |
5599 | |
5600 | Register Shift, Factor; |
5601 | if (Ty.isVector()) { |
5602 | Shift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: Shifts).getReg(Idx: 0); |
5603 | Factor = MIB.buildBuildVector(Res: Ty, Ops: Factors).getReg(Idx: 0); |
5604 | } else { |
5605 | Shift = Shifts[0]; |
5606 | Factor = Factors[0]; |
5607 | } |
5608 | |
5609 | Register Res = LHS; |
5610 | |
5611 | if (UseSRA) |
5612 | Res = MIB.buildAShr(Dst: Ty, Src0: Res, Src1: Shift, Flags: MachineInstr::IsExact).getReg(Idx: 0); |
5613 | |
5614 | return MIB.buildMul(Dst: Ty, Src0: Res, Src1: Factor); |
5615 | } |
5616 | |
5617 | bool CombinerHelper::matchDivByPow2(MachineInstr &MI, bool IsSigned) const { |
5618 | assert((MI.getOpcode() == TargetOpcode::G_SDIV || |
5619 | MI.getOpcode() == TargetOpcode::G_UDIV) && |
5620 | "Expected SDIV or UDIV" ); |
5621 | auto &Div = cast<GenericMachineInstr>(Val&: MI); |
5622 | Register RHS = Div.getReg(Idx: 2); |
5623 | auto MatchPow2 = [&](const Constant *C) { |
5624 | auto *CI = dyn_cast<ConstantInt>(Val: C); |
5625 | return CI && (CI->getValue().isPowerOf2() || |
5626 | (IsSigned && CI->getValue().isNegatedPowerOf2())); |
5627 | }; |
5628 | return matchUnaryPredicate(MRI, Reg: RHS, Match: MatchPow2, /*AllowUndefs=*/false); |
5629 | } |
5630 | |
5631 | void CombinerHelper::applySDivByPow2(MachineInstr &MI) const { |
5632 | assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV" ); |
5633 | auto &SDiv = cast<GenericMachineInstr>(Val&: MI); |
5634 | Register Dst = SDiv.getReg(Idx: 0); |
5635 | Register LHS = SDiv.getReg(Idx: 1); |
5636 | Register RHS = SDiv.getReg(Idx: 2); |
5637 | LLT Ty = MRI.getType(Reg: Dst); |
5638 | LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty); |
5639 | LLT CCVT = |
5640 | Ty.isVector() ? LLT::vector(EC: Ty.getElementCount(), ScalarSizeInBits: 1) : LLT::scalar(SizeInBits: 1); |
5641 | |
5642 | // Effectively we want to lower G_SDIV %lhs, %rhs, where %rhs is a power of 2, |
5643 | // to the following version: |
5644 | // |
5645 | // %c1 = G_CTTZ %rhs |
5646 | // %inexact = G_SUB $bitwidth, %c1 |
// %sign = G_ASHR %lhs, $(bitwidth - 1)
5648 | // %lshr = G_LSHR %sign, %inexact |
5649 | // %add = G_ADD %lhs, %lshr |
5650 | // %ashr = G_ASHR %add, %c1 |
// %ashr = G_SELECT %isoneorallones, %lhs, %ashr
5652 | // %zero = G_CONSTANT $0 |
5653 | // %neg = G_NEG %ashr |
5654 | // %isneg = G_ICMP SLT %rhs, %zero |
5655 | // %res = G_SELECT %isneg, %neg, %ashr |
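//
// For illustration (bitwidth 32, %lhs = -7, %rhs = 4): %c1 = 2, %sign = -1,
// %lshr = 3, %add = -4, %ashr = -1; %rhs is neither 1/-1 nor negative, so
// the selects return -1 == -7 sdiv 4.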
5656 | |
5657 | unsigned BitWidth = Ty.getScalarSizeInBits(); |
5658 | auto Zero = Builder.buildConstant(Res: Ty, Val: 0); |
5659 | |
5660 | auto Bits = Builder.buildConstant(Res: ShiftAmtTy, Val: BitWidth); |
5661 | auto C1 = Builder.buildCTTZ(Dst: ShiftAmtTy, Src0: RHS); |
5662 | auto Inexact = Builder.buildSub(Dst: ShiftAmtTy, Src0: Bits, Src1: C1); |
5663 | // Splat the sign bit into the register |
5664 | auto Sign = Builder.buildAShr( |
5665 | Dst: Ty, Src0: LHS, Src1: Builder.buildConstant(Res: ShiftAmtTy, Val: BitWidth - 1)); |
5666 | |
// Add (LHS < 0) ? abs(RHS) - 1 : 0 to LHS so the arithmetic shift below
// rounds toward zero.
5668 | auto LSrl = Builder.buildLShr(Dst: Ty, Src0: Sign, Src1: Inexact); |
5669 | auto Add = Builder.buildAdd(Dst: Ty, Src0: LHS, Src1: LSrl); |
5670 | auto AShr = Builder.buildAShr(Dst: Ty, Src0: Add, Src1: C1); |
5671 | |
5672 | // Special case: (sdiv X, 1) -> X |
// Special case: (sdiv X, -1) -> 0-X
5674 | auto One = Builder.buildConstant(Res: Ty, Val: 1); |
5675 | auto MinusOne = Builder.buildConstant(Res: Ty, Val: -1); |
5676 | auto IsOne = Builder.buildICmp(Pred: CmpInst::Predicate::ICMP_EQ, Res: CCVT, Op0: RHS, Op1: One); |
5677 | auto IsMinusOne = |
5678 | Builder.buildICmp(Pred: CmpInst::Predicate::ICMP_EQ, Res: CCVT, Op0: RHS, Op1: MinusOne); |
5679 | auto IsOneOrMinusOne = Builder.buildOr(Dst: CCVT, Src0: IsOne, Src1: IsMinusOne); |
5680 | AShr = Builder.buildSelect(Res: Ty, Tst: IsOneOrMinusOne, Op0: LHS, Op1: AShr); |
5681 | |
5682 | // If divided by a positive value, we're done. Otherwise, the result must be |
5683 | // negated. |
5684 | auto Neg = Builder.buildNeg(Dst: Ty, Src0: AShr); |
5685 | auto IsNeg = Builder.buildICmp(Pred: CmpInst::Predicate::ICMP_SLT, Res: CCVT, Op0: RHS, Op1: Zero); |
5686 | Builder.buildSelect(Res: MI.getOperand(i: 0).getReg(), Tst: IsNeg, Op0: Neg, Op1: AShr); |
5687 | MI.eraseFromParent(); |
5688 | } |
5689 | |
5690 | void CombinerHelper::applyUDivByPow2(MachineInstr &MI) const { |
5691 | assert(MI.getOpcode() == TargetOpcode::G_UDIV && "Expected UDIV" ); |
5692 | auto &UDiv = cast<GenericMachineInstr>(Val&: MI); |
5693 | Register Dst = UDiv.getReg(Idx: 0); |
5694 | Register LHS = UDiv.getReg(Idx: 1); |
5695 | Register RHS = UDiv.getReg(Idx: 2); |
5696 | LLT Ty = MRI.getType(Reg: Dst); |
5697 | LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty); |
5698 | |
5699 | auto C1 = Builder.buildCTTZ(Dst: ShiftAmtTy, Src0: RHS); |
5700 | Builder.buildLShr(Dst: MI.getOperand(i: 0).getReg(), Src0: LHS, Src1: C1); |
5701 | MI.eraseFromParent(); |
5702 | } |
5703 | |
5704 | bool CombinerHelper::matchUMulHToLShr(MachineInstr &MI) const { |
5705 | assert(MI.getOpcode() == TargetOpcode::G_UMULH); |
5706 | Register RHS = MI.getOperand(i: 2).getReg(); |
5707 | Register Dst = MI.getOperand(i: 0).getReg(); |
5708 | LLT Ty = MRI.getType(Reg: Dst); |
5709 | LLT RHSTy = MRI.getType(Reg: RHS); |
5710 | LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty); |
5711 | auto MatchPow2ExceptOne = [&](const Constant *C) { |
5712 | if (auto *CI = dyn_cast<ConstantInt>(Val: C)) |
5713 | return CI->getValue().isPowerOf2() && !CI->getValue().isOne(); |
5714 | return false; |
5715 | }; |
5716 | if (!matchUnaryPredicate(MRI, Reg: RHS, Match: MatchPow2ExceptOne, AllowUndefs: false)) |
5717 | return false; |
5718 | // We need to check both G_LSHR and G_CTLZ because the combine uses G_CTLZ to |
// get the log base 2, and G_CTLZ is not always legal on a target.
5720 | return isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_LSHR, {Ty, ShiftAmtTy}}) && |
5721 | isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_CTLZ, {RHSTy, RHSTy}}); |
5722 | } |
5723 | |
5724 | void CombinerHelper::applyUMulHToLShr(MachineInstr &MI) const { |
5725 | Register LHS = MI.getOperand(i: 1).getReg(); |
5726 | Register RHS = MI.getOperand(i: 2).getReg(); |
5727 | Register Dst = MI.getOperand(i: 0).getReg(); |
5728 | LLT Ty = MRI.getType(Reg: Dst); |
5729 | LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty); |
5730 | unsigned NumEltBits = Ty.getScalarSizeInBits(); |
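// umulh(x, 2^k) is the high half of x * 2^k, i.e. x >> (NumEltBits - k).
// For illustration, with a 32-bit type and RHS = 8 (k = 3) the result is
// x >> 29.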
5731 | |
5732 | auto LogBase2 = buildLogBase2(V: RHS, MIB&: Builder); |
5733 | auto ShiftAmt = |
5734 | Builder.buildSub(Dst: Ty, Src0: Builder.buildConstant(Res: Ty, Val: NumEltBits), Src1: LogBase2); |
5735 | auto Trunc = Builder.buildZExtOrTrunc(Res: ShiftAmtTy, Op: ShiftAmt); |
5736 | Builder.buildLShr(Dst, Src0: LHS, Src1: Trunc); |
5737 | MI.eraseFromParent(); |
5738 | } |
5739 | |
5740 | bool CombinerHelper::matchRedundantNegOperands(MachineInstr &MI, |
5741 | BuildFnTy &MatchInfo) const { |
5742 | unsigned Opc = MI.getOpcode(); |
5743 | assert(Opc == TargetOpcode::G_FADD || Opc == TargetOpcode::G_FSUB || |
5744 | Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV || |
5745 | Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA); |
5746 | |
5747 | Register Dst = MI.getOperand(i: 0).getReg(); |
5748 | Register X = MI.getOperand(i: 1).getReg(); |
5749 | Register Y = MI.getOperand(i: 2).getReg(); |
5750 | LLT Type = MRI.getType(Reg: Dst); |
5751 | |
5752 | // fold (fadd x, fneg(y)) -> (fsub x, y) |
5753 | // fold (fadd fneg(y), x) -> (fsub x, y) |
// G_FADD is commutative, so both cases are checked by m_GFAdd
5755 | if (mi_match(R: Dst, MRI, P: m_GFAdd(L: m_Reg(R&: X), R: m_GFNeg(Src: m_Reg(R&: Y)))) && |
5756 | isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FSUB, {Type}})) { |
5757 | Opc = TargetOpcode::G_FSUB; |
5758 | } |
// fold (fsub x, fneg(y)) -> (fadd x, y)
5760 | else if (mi_match(R: Dst, MRI, P: m_GFSub(L: m_Reg(R&: X), R: m_GFNeg(Src: m_Reg(R&: Y)))) && |
5761 | isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FADD, {Type}})) { |
5762 | Opc = TargetOpcode::G_FADD; |
5763 | } |
5764 | // fold (fmul fneg(x), fneg(y)) -> (fmul x, y) |
5765 | // fold (fdiv fneg(x), fneg(y)) -> (fdiv x, y) |
5766 | // fold (fmad fneg(x), fneg(y), z) -> (fmad x, y, z) |
5767 | // fold (fma fneg(x), fneg(y), z) -> (fma x, y, z) |
5768 | else if ((Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV || |
5769 | Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA) && |
5770 | mi_match(R: X, MRI, P: m_GFNeg(Src: m_Reg(R&: X))) && |
5771 | mi_match(R: Y, MRI, P: m_GFNeg(Src: m_Reg(R&: Y)))) { |
5772 | // no opcode change |
5773 | } else |
5774 | return false; |
5775 | |
5776 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5777 | Observer.changingInstr(MI); |
5778 | MI.setDesc(B.getTII().get(Opcode: Opc)); |
5779 | MI.getOperand(i: 1).setReg(X); |
5780 | MI.getOperand(i: 2).setReg(Y); |
5781 | Observer.changedInstr(MI); |
5782 | }; |
5783 | return true; |
5784 | } |
5785 | |
5786 | bool CombinerHelper::matchFsubToFneg(MachineInstr &MI, |
5787 | Register &MatchInfo) const { |
5788 | assert(MI.getOpcode() == TargetOpcode::G_FSUB); |
5789 | |
5790 | Register LHS = MI.getOperand(i: 1).getReg(); |
5791 | MatchInfo = MI.getOperand(i: 2).getReg(); |
5792 | LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
5793 | |
5794 | const auto LHSCst = Ty.isVector() |
5795 | ? getFConstantSplat(VReg: LHS, MRI, /* allowUndef */ AllowUndef: true) |
5796 | : getFConstantVRegValWithLookThrough(VReg: LHS, MRI); |
5797 | if (!LHSCst) |
5798 | return false; |
5799 | |
5800 | // -0.0 is always allowed |
5801 | if (LHSCst->Value.isNegZero()) |
5802 | return true; |
5803 | |
5804 | // +0.0 is only allowed if nsz is set. |
5805 | if (LHSCst->Value.isPosZero()) |
5806 | return MI.getFlag(Flag: MachineInstr::FmNsz); |
5807 | |
5808 | return false; |
5809 | } |
5810 | |
5811 | void CombinerHelper::applyFsubToFneg(MachineInstr &MI, |
5812 | Register &MatchInfo) const { |
5813 | Register Dst = MI.getOperand(i: 0).getReg(); |
5814 | Builder.buildFNeg( |
5815 | Dst, Src0: Builder.buildFCanonicalize(Dst: MRI.getType(Reg: Dst), Src0: MatchInfo).getReg(Idx: 0)); |
5816 | eraseInst(MI); |
5817 | } |
5818 | |
5819 | /// Checks if \p MI is TargetOpcode::G_FMUL and contractable either |
5820 | /// due to global flags or MachineInstr flags. |
5821 | static bool isContractableFMul(MachineInstr &MI, bool AllowFusionGlobally) { |
5822 | if (MI.getOpcode() != TargetOpcode::G_FMUL) |
5823 | return false; |
5824 | return AllowFusionGlobally || MI.getFlag(Flag: MachineInstr::MIFlag::FmContract); |
5825 | } |
5826 | |
5827 | static bool hasMoreUses(const MachineInstr &MI0, const MachineInstr &MI1, |
5828 | const MachineRegisterInfo &MRI) { |
5829 | return std::distance(first: MRI.use_instr_nodbg_begin(RegNo: MI0.getOperand(i: 0).getReg()), |
5830 | last: MRI.use_instr_nodbg_end()) > |
5831 | std::distance(first: MRI.use_instr_nodbg_begin(RegNo: MI1.getOperand(i: 0).getReg()), |
5832 | last: MRI.use_instr_nodbg_end()); |
5833 | } |
5834 | |
5835 | bool CombinerHelper::canCombineFMadOrFMA(MachineInstr &MI, |
5836 | bool &AllowFusionGlobally, |
5837 | bool &HasFMAD, bool &Aggressive, |
5838 | bool CanReassociate) const { |
5839 | |
5840 | auto *MF = MI.getMF(); |
5841 | const auto &TLI = *MF->getSubtarget().getTargetLowering(); |
5842 | const TargetOptions &Options = MF->getTarget().Options; |
5843 | LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
5844 | |
5845 | if (CanReassociate && |
5846 | !(Options.UnsafeFPMath || MI.getFlag(Flag: MachineInstr::MIFlag::FmReassoc))) |
5847 | return false; |
5848 | |
5849 | // Floating-point multiply-add with intermediate rounding. |
5850 | HasFMAD = (!isPreLegalize() && TLI.isFMADLegal(MI, Ty: DstType)); |
5851 | // Floating-point multiply-add without intermediate rounding. |
5852 | bool HasFMA = TLI.isFMAFasterThanFMulAndFAdd(MF: *MF, DstType) && |
5853 | isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FMA, {DstType}}); |
5854 | // No valid opcode, do not combine. |
5855 | if (!HasFMAD && !HasFMA) |
5856 | return false; |
5857 | |
5858 | AllowFusionGlobally = Options.AllowFPOpFusion == FPOpFusion::Fast || |
5859 | Options.UnsafeFPMath || HasFMAD; |
5860 | // If the addition is not contractable, do not combine. |
5861 | if (!AllowFusionGlobally && !MI.getFlag(Flag: MachineInstr::MIFlag::FmContract)) |
5862 | return false; |
5863 | |
5864 | Aggressive = TLI.enableAggressiveFMAFusion(Ty: DstType); |
5865 | return true; |
5866 | } |
5867 | |
5868 | bool CombinerHelper::matchCombineFAddFMulToFMadOrFMA( |
5869 | MachineInstr &MI, |
5870 | std::function<void(MachineIRBuilder &)> &MatchInfo) const { |
5871 | assert(MI.getOpcode() == TargetOpcode::G_FADD); |
5872 | |
5873 | bool AllowFusionGlobally, HasFMAD, Aggressive; |
5874 | if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) |
5875 | return false; |
5876 | |
5877 | Register Op1 = MI.getOperand(i: 1).getReg(); |
5878 | Register Op2 = MI.getOperand(i: 2).getReg(); |
5879 | DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1}; |
5880 | DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2}; |
5881 | unsigned PreferredFusedOpcode = |
5882 | HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; |
5883 | |
5884 | // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), |
5885 | // prefer to fold the multiply with fewer uses. |
5886 | if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) && |
5887 | isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) { |
5888 | if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI)) |
5889 | std::swap(a&: LHS, b&: RHS); |
5890 | } |
5891 | |
5892 | // fold (fadd (fmul x, y), z) -> (fma x, y, z) |
5893 | if (isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) && |
5894 | (Aggressive || MRI.hasOneNonDBGUse(RegNo: LHS.Reg))) { |
5895 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5896 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
5897 | SrcOps: {LHS.MI->getOperand(i: 1).getReg(), |
5898 | LHS.MI->getOperand(i: 2).getReg(), RHS.Reg}); |
5899 | }; |
5900 | return true; |
5901 | } |
5902 | |
5903 | // fold (fadd x, (fmul y, z)) -> (fma y, z, x) |
5904 | if (isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally) && |
5905 | (Aggressive || MRI.hasOneNonDBGUse(RegNo: RHS.Reg))) { |
5906 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5907 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
5908 | SrcOps: {RHS.MI->getOperand(i: 1).getReg(), |
5909 | RHS.MI->getOperand(i: 2).getReg(), LHS.Reg}); |
5910 | }; |
5911 | return true; |
5912 | } |
5913 | |
5914 | return false; |
5915 | } |
5916 | |
5917 | bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMA( |
5918 | MachineInstr &MI, |
5919 | std::function<void(MachineIRBuilder &)> &MatchInfo) const { |
5920 | assert(MI.getOpcode() == TargetOpcode::G_FADD); |
5921 | |
5922 | bool AllowFusionGlobally, HasFMAD, Aggressive; |
5923 | if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) |
5924 | return false; |
5925 | |
5926 | const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering(); |
5927 | Register Op1 = MI.getOperand(i: 1).getReg(); |
5928 | Register Op2 = MI.getOperand(i: 2).getReg(); |
5929 | DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1}; |
5930 | DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2}; |
5931 | LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
5932 | |
5933 | unsigned PreferredFusedOpcode = |
5934 | HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; |
5935 | |
5936 | // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), |
5937 | // prefer to fold the multiply with fewer uses. |
5938 | if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) && |
5939 | isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) { |
5940 | if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI)) |
5941 | std::swap(a&: LHS, b&: RHS); |
5942 | } |
5943 | |
5944 | // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z) |
5945 | MachineInstr *FpExtSrc; |
5946 | if (mi_match(R: LHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FpExtSrc))) && |
5947 | isContractableFMul(MI&: *FpExtSrc, AllowFusionGlobally) && |
5948 | TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType, |
5949 | SrcTy: MRI.getType(Reg: FpExtSrc->getOperand(i: 1).getReg()))) { |
5950 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5951 | auto FpExtX = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 1).getReg()); |
5952 | auto FpExtY = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 2).getReg()); |
5953 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
5954 | SrcOps: {FpExtX.getReg(Idx: 0), FpExtY.getReg(Idx: 0), RHS.Reg}); |
5955 | }; |
5956 | return true; |
5957 | } |
5958 | |
5959 | // fold (fadd z, (fpext (fmul x, y))) -> (fma (fpext x), (fpext y), z) |
5960 | // Note: Commutes FADD operands. |
5961 | if (mi_match(R: RHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FpExtSrc))) && |
5962 | isContractableFMul(MI&: *FpExtSrc, AllowFusionGlobally) && |
5963 | TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType, |
5964 | SrcTy: MRI.getType(Reg: FpExtSrc->getOperand(i: 1).getReg()))) { |
5965 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5966 | auto FpExtX = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 1).getReg()); |
5967 | auto FpExtY = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 2).getReg()); |
5968 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
5969 | SrcOps: {FpExtX.getReg(Idx: 0), FpExtY.getReg(Idx: 0), LHS.Reg}); |
5970 | }; |
5971 | return true; |
5972 | } |
5973 | |
5974 | return false; |
5975 | } |
5976 | |
5977 | bool CombinerHelper::matchCombineFAddFMAFMulToFMadOrFMA( |
5978 | MachineInstr &MI, |
5979 | std::function<void(MachineIRBuilder &)> &MatchInfo) const { |
5980 | assert(MI.getOpcode() == TargetOpcode::G_FADD); |
5981 | |
5982 | bool AllowFusionGlobally, HasFMAD, Aggressive; |
5983 | if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive, CanReassociate: true)) |
5984 | return false; |
5985 | |
5986 | Register Op1 = MI.getOperand(i: 1).getReg(); |
5987 | Register Op2 = MI.getOperand(i: 2).getReg(); |
5988 | DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1}; |
5989 | DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2}; |
5990 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
5991 | |
5992 | unsigned PreferredFusedOpcode = |
5993 | HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; |
5994 | |
5995 | // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), |
5996 | // prefer to fold the multiply with fewer uses. |
5997 | if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) && |
5998 | isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) { |
5999 | if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI)) |
6000 | std::swap(a&: LHS, b&: RHS); |
6001 | } |
6002 | |
6003 | MachineInstr *FMA = nullptr; |
6004 | Register Z; |
6005 | // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y, (fma u, v, z)) |
6006 | if (LHS.MI->getOpcode() == PreferredFusedOpcode && |
6007 | (MRI.getVRegDef(Reg: LHS.MI->getOperand(i: 3).getReg())->getOpcode() == |
6008 | TargetOpcode::G_FMUL) && |
6009 | MRI.hasOneNonDBGUse(RegNo: LHS.MI->getOperand(i: 0).getReg()) && |
6010 | MRI.hasOneNonDBGUse(RegNo: LHS.MI->getOperand(i: 3).getReg())) { |
6011 | FMA = LHS.MI; |
6012 | Z = RHS.Reg; |
6013 | } |
6014 | // fold (fadd z, (fma x, y, (fmul u, v))) -> (fma x, y, (fma u, v, z)) |
6015 | else if (RHS.MI->getOpcode() == PreferredFusedOpcode && |
6016 | (MRI.getVRegDef(Reg: RHS.MI->getOperand(i: 3).getReg())->getOpcode() == |
6017 | TargetOpcode::G_FMUL) && |
6018 | MRI.hasOneNonDBGUse(RegNo: RHS.MI->getOperand(i: 0).getReg()) && |
6019 | MRI.hasOneNonDBGUse(RegNo: RHS.MI->getOperand(i: 3).getReg())) { |
6020 | Z = LHS.Reg; |
6021 | FMA = RHS.MI; |
6022 | } |
6023 | |
6024 | if (FMA) { |
6025 | MachineInstr *FMulMI = MRI.getVRegDef(Reg: FMA->getOperand(i: 3).getReg()); |
6026 | Register X = FMA->getOperand(i: 1).getReg(); |
6027 | Register Y = FMA->getOperand(i: 2).getReg(); |
6028 | Register U = FMulMI->getOperand(i: 1).getReg(); |
6029 | Register V = FMulMI->getOperand(i: 2).getReg(); |
6030 | |
6031 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
6032 | Register InnerFMA = MRI.createGenericVirtualRegister(Ty: DstTy); |
6033 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {InnerFMA}, SrcOps: {U, V, Z}); |
6034 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
6035 | SrcOps: {X, Y, InnerFMA}); |
6036 | }; |
6037 | return true; |
6038 | } |
6039 | |
6040 | return false; |
6041 | } |
6042 | |
6043 | bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMAAggressive( |
6044 | MachineInstr &MI, |
6045 | std::function<void(MachineIRBuilder &)> &MatchInfo) const { |
6046 | assert(MI.getOpcode() == TargetOpcode::G_FADD); |
6047 | |
6048 | bool AllowFusionGlobally, HasFMAD, Aggressive; |
6049 | if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) |
6050 | return false; |
6051 | |
6052 | if (!Aggressive) |
6053 | return false; |
6054 | |
6055 | const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering(); |
6056 | LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
6057 | Register Op1 = MI.getOperand(i: 1).getReg(); |
6058 | Register Op2 = MI.getOperand(i: 2).getReg(); |
6059 | DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1}; |
6060 | DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2}; |
6061 | |
6062 | unsigned PreferredFusedOpcode = |
6063 | HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; |
6064 | |
6065 | // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), |
6066 | // prefer to fold the multiply with fewer uses. |
6067 | if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) && |
6068 | isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) { |
6069 | if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI)) |
6070 | std::swap(a&: LHS, b&: RHS); |
6071 | } |
6072 | |
6073 | // Builds: (fma x, y, (fma (fpext u), (fpext v), z)) |
6074 | auto buildMatchInfo = [=, &MI](Register U, Register V, Register Z, Register X, |
6075 | Register Y, MachineIRBuilder &B) { |
6076 | Register FpExtU = B.buildFPExt(Res: DstType, Op: U).getReg(Idx: 0); |
6077 | Register FpExtV = B.buildFPExt(Res: DstType, Op: V).getReg(Idx: 0); |
6078 | Register InnerFMA = |
6079 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {DstType}, SrcOps: {FpExtU, FpExtV, Z}) |
6080 | .getReg(Idx: 0); |
6081 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
6082 | SrcOps: {X, Y, InnerFMA}); |
6083 | }; |
6084 | |
6085 | MachineInstr *FMulMI, *FMAMI; |
6086 | // fold (fadd (fma x, y, (fpext (fmul u, v))), z) |
6087 | // -> (fma x, y, (fma (fpext u), (fpext v), z)) |
6088 | if (LHS.MI->getOpcode() == PreferredFusedOpcode && |
6089 | mi_match(R: LHS.MI->getOperand(i: 3).getReg(), MRI, |
6090 | P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) && |
6091 | isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) && |
6092 | TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType, |
6093 | SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) { |
6094 | MatchInfo = [=](MachineIRBuilder &B) { |
6095 | buildMatchInfo(FMulMI->getOperand(i: 1).getReg(), |
6096 | FMulMI->getOperand(i: 2).getReg(), RHS.Reg, |
6097 | LHS.MI->getOperand(i: 1).getReg(), |
6098 | LHS.MI->getOperand(i: 2).getReg(), B); |
6099 | }; |
6100 | return true; |
6101 | } |
6102 | |
6103 | // fold (fadd (fpext (fma x, y, (fmul u, v))), z) |
6104 | // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z)) |
6105 | // FIXME: This turns two single-precision and one double-precision |
6106 | // operation into two double-precision operations, which might not be |
6107 | // interesting for all targets, especially GPUs. |
6108 | if (mi_match(R: LHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMAMI))) && |
6109 | FMAMI->getOpcode() == PreferredFusedOpcode) { |
6110 | MachineInstr *FMulMI = MRI.getVRegDef(Reg: FMAMI->getOperand(i: 3).getReg()); |
6111 | if (isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) && |
6112 | TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType, |
6113 | SrcTy: MRI.getType(Reg: FMAMI->getOperand(i: 0).getReg()))) { |
6114 | MatchInfo = [=](MachineIRBuilder &B) { |
6115 | Register X = FMAMI->getOperand(i: 1).getReg(); |
6116 | Register Y = FMAMI->getOperand(i: 2).getReg(); |
6117 | X = B.buildFPExt(Res: DstType, Op: X).getReg(Idx: 0); |
6118 | Y = B.buildFPExt(Res: DstType, Op: Y).getReg(Idx: 0); |
6119 | buildMatchInfo(FMulMI->getOperand(i: 1).getReg(), |
6120 | FMulMI->getOperand(i: 2).getReg(), RHS.Reg, X, Y, B); |
6121 | }; |
6122 | |
6123 | return true; |
6124 | } |
6125 | } |
6126 | |
// fold (fadd z, (fma x, y, (fpext (fmul u, v))))
6128 | // -> (fma x, y, (fma (fpext u), (fpext v), z)) |
6129 | if (RHS.MI->getOpcode() == PreferredFusedOpcode && |
6130 | mi_match(R: RHS.MI->getOperand(i: 3).getReg(), MRI, |
6131 | P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) && |
6132 | isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) && |
6133 | TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType, |
6134 | SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) { |
6135 | MatchInfo = [=](MachineIRBuilder &B) { |
6136 | buildMatchInfo(FMulMI->getOperand(i: 1).getReg(), |
6137 | FMulMI->getOperand(i: 2).getReg(), LHS.Reg, |
6138 | RHS.MI->getOperand(i: 1).getReg(), |
6139 | RHS.MI->getOperand(i: 2).getReg(), B); |
6140 | }; |
6141 | return true; |
6142 | } |
6143 | |
// fold (fadd z, (fpext (fma x, y, (fmul u, v))))
6145 | // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z)) |
6146 | // FIXME: This turns two single-precision and one double-precision |
6147 | // operation into two double-precision operations, which might not be |
6148 | // interesting for all targets, especially GPUs. |
6149 | if (mi_match(R: RHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMAMI))) && |
6150 | FMAMI->getOpcode() == PreferredFusedOpcode) { |
6151 | MachineInstr *FMulMI = MRI.getVRegDef(Reg: FMAMI->getOperand(i: 3).getReg()); |
6152 | if (isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) && |
6153 | TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType, |
6154 | SrcTy: MRI.getType(Reg: FMAMI->getOperand(i: 0).getReg()))) { |
6155 | MatchInfo = [=](MachineIRBuilder &B) { |
6156 | Register X = FMAMI->getOperand(i: 1).getReg(); |
6157 | Register Y = FMAMI->getOperand(i: 2).getReg(); |
6158 | X = B.buildFPExt(Res: DstType, Op: X).getReg(Idx: 0); |
6159 | Y = B.buildFPExt(Res: DstType, Op: Y).getReg(Idx: 0); |
6160 | buildMatchInfo(FMulMI->getOperand(i: 1).getReg(), |
6161 | FMulMI->getOperand(i: 2).getReg(), LHS.Reg, X, Y, B); |
6162 | }; |
6163 | return true; |
6164 | } |
6165 | } |
6166 | |
6167 | return false; |
6168 | } |
6169 | |
6170 | bool CombinerHelper::matchCombineFSubFMulToFMadOrFMA( |
6171 | MachineInstr &MI, |
6172 | std::function<void(MachineIRBuilder &)> &MatchInfo) const { |
6173 | assert(MI.getOpcode() == TargetOpcode::G_FSUB); |
6174 | |
6175 | bool AllowFusionGlobally, HasFMAD, Aggressive; |
6176 | if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) |
6177 | return false; |
6178 | |
6179 | Register Op1 = MI.getOperand(i: 1).getReg(); |
6180 | Register Op2 = MI.getOperand(i: 2).getReg(); |
6181 | DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1}; |
6182 | DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2}; |
6183 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
6184 | |
// If we have two choices trying to fold (fsub (fmul u, v), (fmul x, y)),
// prefer to fold the multiply with fewer uses.
bool FirstMulHasFewerUses = true;
6188 | if (isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) && |
6189 | isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally) && |
6190 | hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI)) |
6191 | FirstMulHasFewerUses = false; |
6192 | |
6193 | unsigned PreferredFusedOpcode = |
6194 | HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; |
6195 | |
6196 | // fold (fsub (fmul x, y), z) -> (fma x, y, -z) |
6197 | if (FirstMulHasFewerUses && |
6198 | (isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) && |
6199 | (Aggressive || MRI.hasOneNonDBGUse(RegNo: LHS.Reg)))) { |
6200 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
6201 | Register NegZ = B.buildFNeg(Dst: DstTy, Src0: RHS.Reg).getReg(Idx: 0); |
6202 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
6203 | SrcOps: {LHS.MI->getOperand(i: 1).getReg(), |
6204 | LHS.MI->getOperand(i: 2).getReg(), NegZ}); |
6205 | }; |
6206 | return true; |
6207 | } |
6208 | // fold (fsub x, (fmul y, z)) -> (fma -y, z, x) |
6209 | else if ((isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally) && |
6210 | (Aggressive || MRI.hasOneNonDBGUse(RegNo: RHS.Reg)))) { |
6211 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
6212 | Register NegY = |
6213 | B.buildFNeg(Dst: DstTy, Src0: RHS.MI->getOperand(i: 1).getReg()).getReg(Idx: 0); |
6214 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
6215 | SrcOps: {NegY, RHS.MI->getOperand(i: 2).getReg(), LHS.Reg}); |
6216 | }; |
6217 | return true; |
6218 | } |
6219 | |
6220 | return false; |
6221 | } |
6222 | |
6223 | bool CombinerHelper::matchCombineFSubFNegFMulToFMadOrFMA( |
6224 | MachineInstr &MI, |
6225 | std::function<void(MachineIRBuilder &)> &MatchInfo) const { |
6226 | assert(MI.getOpcode() == TargetOpcode::G_FSUB); |
6227 | |
6228 | bool AllowFusionGlobally, HasFMAD, Aggressive; |
6229 | if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) |
6230 | return false; |
6231 | |
6232 | Register LHSReg = MI.getOperand(i: 1).getReg(); |
6233 | Register RHSReg = MI.getOperand(i: 2).getReg(); |
6234 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
6235 | |
6236 | unsigned PreferredFusedOpcode = |
6237 | HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; |
6238 | |
6239 | MachineInstr *FMulMI; |
6240 | // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z)) |
6241 | if (mi_match(R: LHSReg, MRI, P: m_GFNeg(Src: m_MInstr(MI&: FMulMI))) && |
6242 | (Aggressive || (MRI.hasOneNonDBGUse(RegNo: LHSReg) && |
6243 | MRI.hasOneNonDBGUse(RegNo: FMulMI->getOperand(i: 0).getReg()))) && |
6244 | isContractableFMul(MI&: *FMulMI, AllowFusionGlobally)) { |
6245 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
6246 | Register NegX = |
6247 | B.buildFNeg(Dst: DstTy, Src0: FMulMI->getOperand(i: 1).getReg()).getReg(Idx: 0); |
6248 | Register NegZ = B.buildFNeg(Dst: DstTy, Src0: RHSReg).getReg(Idx: 0); |
6249 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
6250 | SrcOps: {NegX, FMulMI->getOperand(i: 2).getReg(), NegZ}); |
6251 | }; |
6252 | return true; |
6253 | } |
6254 | |
// fold (fsub x, (fneg (fmul y, z))) -> (fma y, z, x)
6256 | if (mi_match(R: RHSReg, MRI, P: m_GFNeg(Src: m_MInstr(MI&: FMulMI))) && |
6257 | (Aggressive || (MRI.hasOneNonDBGUse(RegNo: RHSReg) && |
6258 | MRI.hasOneNonDBGUse(RegNo: FMulMI->getOperand(i: 0).getReg()))) && |
6259 | isContractableFMul(MI&: *FMulMI, AllowFusionGlobally)) { |
6260 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
6261 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
6262 | SrcOps: {FMulMI->getOperand(i: 1).getReg(), |
6263 | FMulMI->getOperand(i: 2).getReg(), LHSReg}); |
6264 | }; |
6265 | return true; |
6266 | } |
6267 | |
6268 | return false; |
6269 | } |
6270 | |
6271 | bool CombinerHelper::matchCombineFSubFpExtFMulToFMadOrFMA( |
6272 | MachineInstr &MI, |
6273 | std::function<void(MachineIRBuilder &)> &MatchInfo) const { |
6274 | assert(MI.getOpcode() == TargetOpcode::G_FSUB); |
6275 | |
6276 | bool AllowFusionGlobally, HasFMAD, Aggressive; |
6277 | if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) |
6278 | return false; |
6279 | |
6280 | Register LHSReg = MI.getOperand(i: 1).getReg(); |
6281 | Register RHSReg = MI.getOperand(i: 2).getReg(); |
6282 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
6283 | |
6284 | unsigned PreferredFusedOpcode = |
6285 | HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; |
6286 | |
6287 | MachineInstr *FMulMI; |
6288 | // fold (fsub (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), (fneg z)) |
6289 | if (mi_match(R: LHSReg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) && |
6290 | isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) && |
6291 | (Aggressive || MRI.hasOneNonDBGUse(RegNo: LHSReg))) { |
6292 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
6293 | Register FpExtX = |
6294 | B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 1).getReg()).getReg(Idx: 0); |
6295 | Register FpExtY = |
6296 | B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 2).getReg()).getReg(Idx: 0); |
6297 | Register NegZ = B.buildFNeg(Dst: DstTy, Src0: RHSReg).getReg(Idx: 0); |
6298 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
6299 | SrcOps: {FpExtX, FpExtY, NegZ}); |
6300 | }; |
6301 | return true; |
6302 | } |
6303 | |
6304 | // fold (fsub x, (fpext (fmul y, z))) -> (fma (fneg (fpext y)), (fpext z), x) |
6305 | if (mi_match(R: RHSReg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) && |
6306 | isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) && |
6307 | (Aggressive || MRI.hasOneNonDBGUse(RegNo: RHSReg))) { |
6308 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
6309 | Register FpExtY = |
6310 | B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 1).getReg()).getReg(Idx: 0); |
6311 | Register NegY = B.buildFNeg(Dst: DstTy, Src0: FpExtY).getReg(Idx: 0); |
6312 | Register FpExtZ = |
6313 | B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 2).getReg()).getReg(Idx: 0); |
6314 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
6315 | SrcOps: {NegY, FpExtZ, LHSReg}); |
6316 | }; |
6317 | return true; |
6318 | } |
6319 | |
6320 | return false; |
6321 | } |
6322 | |
6323 | bool CombinerHelper::matchCombineFSubFpExtFNegFMulToFMadOrFMA( |
6324 | MachineInstr &MI, |
6325 | std::function<void(MachineIRBuilder &)> &MatchInfo) const { |
6326 | assert(MI.getOpcode() == TargetOpcode::G_FSUB); |
6327 | |
6328 | bool AllowFusionGlobally, HasFMAD, Aggressive; |
6329 | if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) |
6330 | return false; |
6331 | |
6332 | const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering(); |
6333 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
6334 | Register LHSReg = MI.getOperand(i: 1).getReg(); |
6335 | Register RHSReg = MI.getOperand(i: 2).getReg(); |
6336 | |
6337 | unsigned PreferredFusedOpcode = |
6338 | HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; |
6339 | |
6340 | auto buildMatchInfo = [=](Register Dst, Register X, Register Y, Register Z, |
6341 | MachineIRBuilder &B) { |
6342 | Register FpExtX = B.buildFPExt(Res: DstTy, Op: X).getReg(Idx: 0); |
6343 | Register FpExtY = B.buildFPExt(Res: DstTy, Op: Y).getReg(Idx: 0); |
6344 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {Dst}, SrcOps: {FpExtX, FpExtY, Z}); |
6345 | }; |
6346 | |
6347 | MachineInstr *FMulMI; |
6348 | // fold (fsub (fpext (fneg (fmul x, y))), z) -> |
6349 | // (fneg (fma (fpext x), (fpext y), z)) |
6350 | // fold (fsub (fneg (fpext (fmul x, y))), z) -> |
6351 | // (fneg (fma (fpext x), (fpext y), z)) |
6352 | if ((mi_match(R: LHSReg, MRI, P: m_GFPExt(Src: m_GFNeg(Src: m_MInstr(MI&: FMulMI)))) || |
6353 | mi_match(R: LHSReg, MRI, P: m_GFNeg(Src: m_GFPExt(Src: m_MInstr(MI&: FMulMI))))) && |
6354 | isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) && |
6355 | TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstTy, |
6356 | SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) { |
6357 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
6358 | Register FMAReg = MRI.createGenericVirtualRegister(Ty: DstTy); |
6359 | buildMatchInfo(FMAReg, FMulMI->getOperand(i: 1).getReg(), |
6360 | FMulMI->getOperand(i: 2).getReg(), RHSReg, B); |
6361 | B.buildFNeg(Dst: MI.getOperand(i: 0).getReg(), Src0: FMAReg); |
6362 | }; |
6363 | return true; |
6364 | } |
6365 | |
6366 | // fold (fsub x, (fpext (fneg (fmul y, z)))) -> (fma (fpext y), (fpext z), x) |
6367 | // fold (fsub x, (fneg (fpext (fmul y, z)))) -> (fma (fpext y), (fpext z), x) |
6368 | if ((mi_match(R: RHSReg, MRI, P: m_GFPExt(Src: m_GFNeg(Src: m_MInstr(MI&: FMulMI)))) || |
6369 | mi_match(R: RHSReg, MRI, P: m_GFNeg(Src: m_GFPExt(Src: m_MInstr(MI&: FMulMI))))) && |
6370 | isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) && |
6371 | TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstTy, |
6372 | SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) { |
6373 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
6374 | buildMatchInfo(MI.getOperand(i: 0).getReg(), FMulMI->getOperand(i: 1).getReg(), |
6375 | FMulMI->getOperand(i: 2).getReg(), LHSReg, B); |
6376 | }; |
6377 | return true; |
6378 | } |
6379 | |
6380 | return false; |
6381 | } |
6382 | |
6383 | bool CombinerHelper::matchCombineFMinMaxNaN(MachineInstr &MI, |
6384 | unsigned &IdxToPropagate) const { |
6385 | bool PropagateNaN; |
6386 | switch (MI.getOpcode()) { |
6387 | default: |
6388 | return false; |
6389 | case TargetOpcode::G_FMINNUM: |
6390 | case TargetOpcode::G_FMAXNUM: |
6391 | PropagateNaN = false; |
6392 | break; |
6393 | case TargetOpcode::G_FMINIMUM: |
6394 | case TargetOpcode::G_FMAXIMUM: |
6395 | PropagateNaN = true; |
6396 | break; |
6397 | } |
6398 | |
6399 | auto MatchNaN = [&](unsigned Idx) { |
6400 | Register MaybeNaNReg = MI.getOperand(i: Idx).getReg(); |
6401 | const ConstantFP *MaybeCst = getConstantFPVRegVal(VReg: MaybeNaNReg, MRI); |
6402 | if (!MaybeCst || !MaybeCst->getValueAPF().isNaN()) |
6403 | return false; |
6404 | IdxToPropagate = PropagateNaN ? Idx : (Idx == 1 ? 2 : 1); |
6405 | return true; |
6406 | }; |
6407 | |
6408 | return MatchNaN(1) || MatchNaN(2); |
6409 | } |
6410 | |
6411 | bool CombinerHelper::matchAddSubSameReg(MachineInstr &MI, Register &Src) const { |
6412 | assert(MI.getOpcode() == TargetOpcode::G_ADD && "Expected a G_ADD" ); |
6413 | Register LHS = MI.getOperand(i: 1).getReg(); |
6414 | Register RHS = MI.getOperand(i: 2).getReg(); |
6415 | |
6416 | // Helper lambda to check for opportunities for |
6417 | // A + (B - A) -> B |
6418 | // (B - A) + A -> B |
6419 | auto CheckFold = [&](Register MaybeSub, Register MaybeSameReg) { |
6420 | Register Reg; |
6421 | return mi_match(R: MaybeSub, MRI, P: m_GSub(L: m_Reg(R&: Src), R: m_Reg(R&: Reg))) && |
6422 | Reg == MaybeSameReg; |
6423 | }; |
6424 | return CheckFold(LHS, RHS) || CheckFold(RHS, LHS); |
6425 | } |
6426 | |
6427 | bool CombinerHelper::matchBuildVectorIdentityFold(MachineInstr &MI, |
6428 | Register &MatchInfo) const { |
6429 | // This combine folds the following patterns: |
6430 | // |
6431 | // G_BUILD_VECTOR_TRUNC (G_BITCAST(x), G_LSHR(G_BITCAST(x), k)) |
6432 | // G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), G_TRUNC(G_LSHR(G_BITCAST(x), k))) |
6433 | // into |
6434 | // x |
6435 | // if |
// k == number of bits in dst's element type
6437 | // type(x) == type(dst) |
6438 | // |
6439 | // G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), undef) |
6440 | // into |
6441 | // x |
6442 | // if |
6443 | // type(x) == type(dst) |
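//
// For illustration (assuming both dst and x are <2 x s16>):
//   %cast:_(s32) = G_BITCAST %x(<2 x s16>)
//   %hi:_(s32) = G_LSHR %cast, 16
//   %dst:_(<2 x s16>) = G_BUILD_VECTOR_TRUNC %cast, %hi
// folds to %x, since k == 16 == number of bits in dst's element type.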
6444 | |
6445 | LLT DstVecTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
6446 | LLT DstEltTy = DstVecTy.getElementType(); |
6447 | |
6448 | Register Lo, Hi; |
6449 | |
6450 | if (mi_match( |
6451 | MI, MRI, |
6452 | P: m_GBuildVector(L: m_GTrunc(Src: m_GBitcast(Src: m_Reg(R&: Lo))), R: m_GImplicitDef()))) { |
6453 | MatchInfo = Lo; |
6454 | return MRI.getType(Reg: MatchInfo) == DstVecTy; |
6455 | } |
6456 | |
6457 | std::optional<ValueAndVReg> ShiftAmount; |
6458 | const auto LoPattern = m_GBitcast(Src: m_Reg(R&: Lo)); |
6459 | const auto HiPattern = m_GLShr(L: m_GBitcast(Src: m_Reg(R&: Hi)), R: m_GCst(ValReg&: ShiftAmount)); |
6460 | if (mi_match( |
6461 | MI, MRI, |
6462 | P: m_any_of(preds: m_GBuildVectorTrunc(L: LoPattern, R: HiPattern), |
6463 | preds: m_GBuildVector(L: m_GTrunc(Src: LoPattern), R: m_GTrunc(Src: HiPattern))))) { |
6464 | if (Lo == Hi && ShiftAmount->Value == DstEltTy.getSizeInBits()) { |
6465 | MatchInfo = Lo; |
6466 | return MRI.getType(Reg: MatchInfo) == DstVecTy; |
6467 | } |
6468 | } |
6469 | |
6470 | return false; |
6471 | } |
6472 | |
6473 | bool CombinerHelper::matchTruncBuildVectorFold(MachineInstr &MI, |
6474 | Register &MatchInfo) const { |
// Replace (G_TRUNC (G_BITCAST (G_BUILD_VECTOR x, y))) with just x
6476 | // if type(x) == type(G_TRUNC) |
6477 | if (!mi_match(R: MI.getOperand(i: 1).getReg(), MRI, |
6478 | P: m_GBitcast(Src: m_GBuildVector(L: m_Reg(R&: MatchInfo), R: m_Reg())))) |
6479 | return false; |
6480 | |
6481 | return MRI.getType(Reg: MatchInfo) == MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
6482 | } |
6483 | |
6484 | bool CombinerHelper::matchTruncLshrBuildVectorFold(MachineInstr &MI, |
6485 | Register &MatchInfo) const { |
6486 | // Replace (G_TRUNC (G_LSHR (G_BITCAST (G_BUILD_VECTOR x, y)), K)) with |
6487 | // y if K == size of vector element type |
6488 | std::optional<ValueAndVReg> ShiftAmt; |
6489 | if (!mi_match(R: MI.getOperand(i: 1).getReg(), MRI, |
6490 | P: m_GLShr(L: m_GBitcast(Src: m_GBuildVector(L: m_Reg(), R: m_Reg(R&: MatchInfo))), |
6491 | R: m_GCst(ValReg&: ShiftAmt)))) |
6492 | return false; |
6493 | |
6494 | LLT MatchTy = MRI.getType(Reg: MatchInfo); |
6495 | return ShiftAmt->Value.getZExtValue() == MatchTy.getSizeInBits() && |
6496 | MatchTy == MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
6497 | } |
6498 | |
6499 | unsigned CombinerHelper::getFPMinMaxOpcForSelect( |
6500 | CmpInst::Predicate Pred, LLT DstTy, |
6501 | SelectPatternNaNBehaviour VsNaNRetVal) const { |
6502 | assert(VsNaNRetVal != SelectPatternNaNBehaviour::NOT_APPLICABLE && |
6503 | "Expected a NaN behaviour?" ); |
// Choose an opcode based on legality or the behaviour when one of the
6505 | // LHS/RHS may be NaN. |
6506 | switch (Pred) { |
6507 | default: |
6508 | return 0; |
6509 | case CmpInst::FCMP_UGT: |
6510 | case CmpInst::FCMP_UGE: |
6511 | case CmpInst::FCMP_OGT: |
6512 | case CmpInst::FCMP_OGE: |
6513 | if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER) |
6514 | return TargetOpcode::G_FMAXNUM; |
6515 | if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN) |
6516 | return TargetOpcode::G_FMAXIMUM; |
6517 | if (isLegal(Query: {TargetOpcode::G_FMAXNUM, {DstTy}})) |
6518 | return TargetOpcode::G_FMAXNUM; |
6519 | if (isLegal(Query: {TargetOpcode::G_FMAXIMUM, {DstTy}})) |
6520 | return TargetOpcode::G_FMAXIMUM; |
6521 | return 0; |
6522 | case CmpInst::FCMP_ULT: |
6523 | case CmpInst::FCMP_ULE: |
6524 | case CmpInst::FCMP_OLT: |
6525 | case CmpInst::FCMP_OLE: |
6526 | if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER) |
6527 | return TargetOpcode::G_FMINNUM; |
6528 | if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN) |
6529 | return TargetOpcode::G_FMINIMUM; |
6530 | if (isLegal(Query: {TargetOpcode::G_FMINNUM, {DstTy}})) |
6531 | return TargetOpcode::G_FMINNUM; |
6532 | if (!isLegal(Query: {TargetOpcode::G_FMINIMUM, {DstTy}})) |
6533 | return 0; |
6534 | return TargetOpcode::G_FMINIMUM; |
6535 | } |
6536 | } |
6537 | |
6538 | CombinerHelper::SelectPatternNaNBehaviour |
6539 | CombinerHelper::computeRetValAgainstNaN(Register LHS, Register RHS, |
6540 | bool IsOrderedComparison) const { |
6541 | bool LHSSafe = isKnownNeverNaN(Val: LHS, MRI); |
6542 | bool RHSSafe = isKnownNeverNaN(Val: RHS, MRI); |
6543 | // Completely unsafe. |
6544 | if (!LHSSafe && !RHSSafe) |
6545 | return SelectPatternNaNBehaviour::NOT_APPLICABLE; |
6546 | if (LHSSafe && RHSSafe) |
6547 | return SelectPatternNaNBehaviour::RETURNS_ANY; |
6548 | // An ordered comparison will return false when given a NaN, so it |
6549 | // returns the RHS. |
6550 | if (IsOrderedComparison) |
6551 | return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_NAN |
6552 | : SelectPatternNaNBehaviour::RETURNS_OTHER; |
6553 | // An unordered comparison will return true when given a NaN, so it |
6554 | // returns the LHS. |
6555 | return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_OTHER |
6556 | : SelectPatternNaNBehaviour::RETURNS_NAN; |
6557 | } |
6558 | |
6559 | bool CombinerHelper::matchFPSelectToMinMax(Register Dst, Register Cond, |
6560 | Register TrueVal, Register FalseVal, |
6561 | BuildFnTy &MatchInfo) const { |
6562 | // Match: select (fcmp cond x, y) x, y |
6563 | // select (fcmp cond x, y) y, x |
// And turn it into fminnum/fmaxnum or fminimum/fmaximum based on the
// predicate.
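//
// E.g. (illustrative MIR, invented register names):
//   %c:_(s1) = G_FCMP floatpred(olt), %x:_(s32), %y:_(s32)
//   %d:_(s32) = G_SELECT %c, %x, %y
// can become %d:_(s32) = G_FMINNUM %x, %y (or G_FMINIMUM), depending on what
// is known about NaNs and on which opcode is legal.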
6565 | LLT DstTy = MRI.getType(Reg: Dst); |
6566 | // Bail out early on pointers, since we'll never want to fold to a min/max. |
6567 | if (DstTy.isPointer()) |
6568 | return false; |
6569 | // Match a floating point compare with a less-than/greater-than predicate. |
6570 | // TODO: Allow multiple users of the compare if they are all selects. |
6571 | CmpInst::Predicate Pred; |
6572 | Register CmpLHS, CmpRHS; |
6573 | if (!mi_match(R: Cond, MRI, |
6574 | P: m_OneNonDBGUse( |
6575 | SP: m_GFCmp(P: m_Pred(P&: Pred), L: m_Reg(R&: CmpLHS), R: m_Reg(R&: CmpRHS)))) || |
6576 | CmpInst::isEquality(pred: Pred)) |
6577 | return false; |
6578 | SelectPatternNaNBehaviour ResWithKnownNaNInfo = |
6579 | computeRetValAgainstNaN(LHS: CmpLHS, RHS: CmpRHS, IsOrderedComparison: CmpInst::isOrdered(predicate: Pred)); |
6580 | if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::NOT_APPLICABLE) |
6581 | return false; |
6582 | if (TrueVal == CmpRHS && FalseVal == CmpLHS) { |
6583 | std::swap(a&: CmpLHS, b&: CmpRHS); |
6584 | Pred = CmpInst::getSwappedPredicate(pred: Pred); |
6585 | if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_NAN) |
6586 | ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_OTHER; |
6587 | else if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_OTHER) |
6588 | ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_NAN; |
6589 | } |
6590 | if (TrueVal != CmpLHS || FalseVal != CmpRHS) |
6591 | return false; |
// Decide what type of max/min this should be based on the predicate.
6593 | unsigned Opc = getFPMinMaxOpcForSelect(Pred, DstTy, VsNaNRetVal: ResWithKnownNaNInfo); |
6594 | if (!Opc || !isLegal(Query: {Opc, {DstTy}})) |
6595 | return false; |
6596 | // Comparisons between signed zero and zero may have different results... |
6597 | // unless we have fmaximum/fminimum. In that case, we know -0 < 0. |
6598 | if (Opc != TargetOpcode::G_FMAXIMUM && Opc != TargetOpcode::G_FMINIMUM) { |
6599 | // We don't know if a comparison between two 0s will give us a consistent |
6600 | // result. Be conservative and only proceed if at least one side is |
6601 | // non-zero. |
6602 | auto KnownNonZeroSide = getFConstantVRegValWithLookThrough(VReg: CmpLHS, MRI); |
6603 | if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero()) { |
6604 | KnownNonZeroSide = getFConstantVRegValWithLookThrough(VReg: CmpRHS, MRI); |
6605 | if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero()) |
6606 | return false; |
6607 | } |
6608 | } |
6609 | MatchInfo = [=](MachineIRBuilder &B) { |
6610 | B.buildInstr(Opc, DstOps: {Dst}, SrcOps: {CmpLHS, CmpRHS}); |
6611 | }; |
6612 | return true; |
6613 | } |
6614 | |
6615 | bool CombinerHelper::matchSimplifySelectToMinMax(MachineInstr &MI, |
6616 | BuildFnTy &MatchInfo) const { |
6617 | // TODO: Handle integer cases. |
6618 | assert(MI.getOpcode() == TargetOpcode::G_SELECT); |
6619 | // Condition may be fed by a truncated compare. |
6620 | Register Cond = MI.getOperand(i: 1).getReg(); |
6621 | Register MaybeTrunc; |
6622 | if (mi_match(R: Cond, MRI, P: m_OneNonDBGUse(SP: m_GTrunc(Src: m_Reg(R&: MaybeTrunc))))) |
6623 | Cond = MaybeTrunc; |
6624 | Register Dst = MI.getOperand(i: 0).getReg(); |
6625 | Register TrueVal = MI.getOperand(i: 2).getReg(); |
6626 | Register FalseVal = MI.getOperand(i: 3).getReg(); |
6627 | return matchFPSelectToMinMax(Dst, Cond, TrueVal, FalseVal, MatchInfo); |
6628 | } |
6629 | |
6630 | bool CombinerHelper::matchRedundantBinOpInEquality(MachineInstr &MI, |
6631 | BuildFnTy &MatchInfo) const { |
6632 | assert(MI.getOpcode() == TargetOpcode::G_ICMP); |
6633 | // (X + Y) == X --> Y == 0 |
6634 | // (X + Y) != X --> Y != 0 |
6635 | // (X - Y) == X --> Y == 0 |
6636 | // (X - Y) != X --> Y != 0 |
6637 | // (X ^ Y) == X --> Y == 0 |
6638 | // (X ^ Y) != X --> Y != 0 |
6639 | Register Dst = MI.getOperand(i: 0).getReg(); |
6640 | CmpInst::Predicate Pred; |
6641 | Register X, Y, OpLHS, OpRHS; |
6642 | bool MatchedSub = mi_match( |
6643 | R: Dst, MRI, |
6644 | P: m_c_GICmp(P: m_Pred(P&: Pred), L: m_Reg(R&: X), R: m_GSub(L: m_Reg(R&: OpLHS), R: m_Reg(R&: Y)))); |
6645 | if (MatchedSub && X != OpLHS) |
6646 | return false; |
6647 | if (!MatchedSub) { |
6648 | if (!mi_match(R: Dst, MRI, |
6649 | P: m_c_GICmp(P: m_Pred(P&: Pred), L: m_Reg(R&: X), |
6650 | R: m_any_of(preds: m_GAdd(L: m_Reg(R&: OpLHS), R: m_Reg(R&: OpRHS)), |
6651 | preds: m_GXor(L: m_Reg(R&: OpLHS), R: m_Reg(R&: OpRHS)))))) |
6652 | return false; |
6653 | Y = X == OpLHS ? OpRHS : X == OpRHS ? OpLHS : Register(); |
6654 | } |
6655 | MatchInfo = [=](MachineIRBuilder &B) { |
6656 | auto Zero = B.buildConstant(Res: MRI.getType(Reg: Y), Val: 0); |
6657 | B.buildICmp(Pred, Res: Dst, Op0: Y, Op1: Zero); |
6658 | }; |
6659 | return CmpInst::isEquality(pred: Pred) && Y.isValid(); |
6660 | } |
6661 | |
/// Return the minimum useless shift amount, i.e. the smallest shift amount
/// that results in complete loss of the source value. The constant the shift
/// folds to is reported via \p Result, which is set to std::nullopt when that
/// value cannot be determined.
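/// For example, an s32 G_LSHR whose source has 24 known leading zero bits is
/// useless for any shift amount >= 8: every remaining significant bit is
/// shifted out and the result folds to 0.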
6664 | static std::optional<unsigned> |
6665 | getMinUselessShift(KnownBits ValueKB, unsigned Opcode, |
6666 | std::optional<int64_t> &Result) { |
6667 | assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_LSHR || |
6668 | Opcode == TargetOpcode::G_ASHR) && |
6669 | "Expect G_SHL, G_LSHR or G_ASHR." ); |
unsigned SignificantBits = 0;
6671 | switch (Opcode) { |
6672 | case TargetOpcode::G_SHL: |
6673 | SignificantBits = ValueKB.countMinTrailingZeros(); |
6674 | Result = 0; |
6675 | break; |
6676 | case TargetOpcode::G_LSHR: |
6677 | Result = 0; |
6678 | SignificantBits = ValueKB.countMinLeadingZeros(); |
6679 | break; |
6680 | case TargetOpcode::G_ASHR: |
6681 | if (ValueKB.isNonNegative()) { |
6682 | SignificantBits = ValueKB.countMinLeadingZeros(); |
6683 | Result = 0; |
6684 | } else if (ValueKB.isNegative()) { |
6685 | SignificantBits = ValueKB.countMinLeadingOnes(); |
6686 | Result = -1; |
6687 | } else { |
6688 | // Cannot determine shift result. |
6689 | Result = std::nullopt; |
6690 | } |
6691 | break; |
6692 | default: |
6693 | break; |
6694 | } |
6695 | return ValueKB.getBitWidth() - SignificantBits; |
6696 | } |
6697 | |
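// Match shifts whose result is fully determined by the shift amount alone:
// either the amount is >= the bit width (MatchInfo is std::nullopt), or it is
// large enough to shift out every significant bit of the value, in which case
// MatchInfo holds the constant (0 or -1) the shift folds to.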
6698 | bool CombinerHelper::matchShiftsTooBig( |
6699 | MachineInstr &MI, std::optional<int64_t> &MatchInfo) const { |
6700 | Register ShiftVal = MI.getOperand(i: 1).getReg(); |
6701 | Register ShiftReg = MI.getOperand(i: 2).getReg(); |
6702 | LLT ResTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
6703 | auto IsShiftTooBig = [&](const Constant *C) { |
6704 | auto *CI = dyn_cast<ConstantInt>(Val: C); |
6705 | if (!CI) |
6706 | return false; |
6707 | if (CI->uge(Num: ResTy.getScalarSizeInBits())) { |
6708 | MatchInfo = std::nullopt; |
6709 | return true; |
6710 | } |
6711 | auto OptMaxUsefulShift = getMinUselessShift(ValueKB: VT->getKnownBits(R: ShiftVal), |
6712 | Opcode: MI.getOpcode(), Result&: MatchInfo); |
6713 | return OptMaxUsefulShift && CI->uge(Num: *OptMaxUsefulShift); |
6714 | }; |
6715 | return matchUnaryPredicate(MRI, Reg: ShiftReg, Match: IsShiftTooBig); |
6716 | } |
6717 | |
6718 | bool CombinerHelper::matchCommuteConstantToRHS(MachineInstr &MI) const { |
6719 | unsigned LHSOpndIdx = 1; |
6720 | unsigned RHSOpndIdx = 2; |
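// The overflow ops define both a result and a carry-out, so their source
// operands start at index 2 rather than 1.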
6721 | switch (MI.getOpcode()) { |
6722 | case TargetOpcode::G_UADDO: |
6723 | case TargetOpcode::G_SADDO: |
6724 | case TargetOpcode::G_UMULO: |
6725 | case TargetOpcode::G_SMULO: |
6726 | LHSOpndIdx = 2; |
6727 | RHSOpndIdx = 3; |
6728 | break; |
6729 | default: |
6730 | break; |
6731 | } |
6732 | Register LHS = MI.getOperand(i: LHSOpndIdx).getReg(); |
6733 | Register RHS = MI.getOperand(i: RHSOpndIdx).getReg(); |
6734 | if (!getIConstantVRegVal(VReg: LHS, MRI)) { |
6735 | // Skip commuting if LHS is not a constant. But, LHS may be a |
6736 | // G_CONSTANT_FOLD_BARRIER. If so we commute as long as we don't already |
6737 | // have a constant on the RHS. |
6738 | if (MRI.getVRegDef(Reg: LHS)->getOpcode() != |
6739 | TargetOpcode::G_CONSTANT_FOLD_BARRIER) |
6740 | return false; |
6741 | } |
6742 | // Commute as long as RHS is not a constant or G_CONSTANT_FOLD_BARRIER. |
6743 | return MRI.getVRegDef(Reg: RHS)->getOpcode() != |
6744 | TargetOpcode::G_CONSTANT_FOLD_BARRIER && |
6745 | !getIConstantVRegVal(VReg: RHS, MRI); |
6746 | } |
6747 | |
6748 | bool CombinerHelper::matchCommuteFPConstantToRHS(MachineInstr &MI) const { |
6749 | Register LHS = MI.getOperand(i: 1).getReg(); |
6750 | Register RHS = MI.getOperand(i: 2).getReg(); |
6751 | std::optional<FPValueAndVReg> ValAndVReg; |
6752 | if (!mi_match(R: LHS, MRI, P: m_GFCstOrSplat(FPValReg&: ValAndVReg))) |
6753 | return false; |
6754 | return !mi_match(R: RHS, MRI, P: m_GFCstOrSplat(FPValReg&: ValAndVReg)); |
6755 | } |
6756 | |
6757 | void CombinerHelper::applyCommuteBinOpOperands(MachineInstr &MI) const { |
6758 | Observer.changingInstr(MI); |
6759 | unsigned LHSOpndIdx = 1; |
6760 | unsigned RHSOpndIdx = 2; |
6761 | switch (MI.getOpcode()) { |
6762 | case TargetOpcode::G_UADDO: |
6763 | case TargetOpcode::G_SADDO: |
6764 | case TargetOpcode::G_UMULO: |
6765 | case TargetOpcode::G_SMULO: |
6766 | LHSOpndIdx = 2; |
6767 | RHSOpndIdx = 3; |
6768 | break; |
6769 | default: |
6770 | break; |
6771 | } |
6772 | Register LHSReg = MI.getOperand(i: LHSOpndIdx).getReg(); |
6773 | Register RHSReg = MI.getOperand(i: RHSOpndIdx).getReg(); |
6774 | MI.getOperand(i: LHSOpndIdx).setReg(RHSReg); |
6775 | MI.getOperand(i: RHSOpndIdx).setReg(LHSReg); |
6776 | Observer.changedInstr(MI); |
6777 | } |
6778 | |
6779 | bool CombinerHelper::isOneOrOneSplat(Register Src, bool AllowUndefs) const { |
6780 | LLT SrcTy = MRI.getType(Reg: Src); |
6781 | if (SrcTy.isFixedVector()) |
6782 | return isConstantSplatVector(Src, SplatValue: 1, AllowUndefs); |
6783 | if (SrcTy.isScalar()) { |
6784 | if (AllowUndefs && getOpcodeDef<GImplicitDef>(Reg: Src, MRI) != nullptr) |
6785 | return true; |
6786 | auto IConstant = getIConstantVRegValWithLookThrough(VReg: Src, MRI); |
6787 | return IConstant && IConstant->Value == 1; |
6788 | } |
6789 | return false; // scalable vector |
6790 | } |
6791 | |
6792 | bool CombinerHelper::isZeroOrZeroSplat(Register Src, bool AllowUndefs) const { |
6793 | LLT SrcTy = MRI.getType(Reg: Src); |
6794 | if (SrcTy.isFixedVector()) |
6795 | return isConstantSplatVector(Src, SplatValue: 0, AllowUndefs); |
6796 | if (SrcTy.isScalar()) { |
6797 | if (AllowUndefs && getOpcodeDef<GImplicitDef>(Reg: Src, MRI) != nullptr) |
6798 | return true; |
6799 | auto IConstant = getIConstantVRegValWithLookThrough(VReg: Src, MRI); |
6800 | return IConstant && IConstant->Value == 0; |
6801 | } |
6802 | return false; // scalable vector |
6803 | } |
6804 | |
6805 | // Ignores COPYs during conformance checks. |
6806 | // FIXME scalable vectors. |
6807 | bool CombinerHelper::isConstantSplatVector(Register Src, int64_t SplatValue, |
6808 | bool AllowUndefs) const { |
6809 | GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Reg: Src, MRI); |
6810 | if (!BuildVector) |
6811 | return false; |
6812 | unsigned NumSources = BuildVector->getNumSources(); |
6813 | |
6814 | for (unsigned I = 0; I < NumSources; ++I) { |
6815 | GImplicitDef *ImplicitDef = |
6816 | getOpcodeDef<GImplicitDef>(Reg: BuildVector->getSourceReg(I), MRI); |
6817 | if (ImplicitDef && AllowUndefs) |
6818 | continue; |
6819 | if (ImplicitDef && !AllowUndefs) |
6820 | return false; |
6821 | std::optional<ValueAndVReg> IConstant = |
6822 | getIConstantVRegValWithLookThrough(VReg: BuildVector->getSourceReg(I), MRI); |
6823 | if (IConstant && IConstant->Value == SplatValue) |
6824 | continue; |
6825 | return false; |
6826 | } |
6827 | return true; |
6828 | } |
6829 | |
6830 | // Ignores COPYs during lookups. |
6831 | // FIXME scalable vectors |
6832 | std::optional<APInt> |
6833 | CombinerHelper::getConstantOrConstantSplatVector(Register Src) const { |
6834 | auto IConstant = getIConstantVRegValWithLookThrough(VReg: Src, MRI); |
6835 | if (IConstant) |
6836 | return IConstant->Value; |
6837 | |
6838 | GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Reg: Src, MRI); |
6839 | if (!BuildVector) |
6840 | return std::nullopt; |
6841 | unsigned NumSources = BuildVector->getNumSources(); |
6842 | |
6843 | std::optional<APInt> Value = std::nullopt; |
6844 | for (unsigned I = 0; I < NumSources; ++I) { |
6845 | std::optional<ValueAndVReg> IConstant = |
6846 | getIConstantVRegValWithLookThrough(VReg: BuildVector->getSourceReg(I), MRI); |
6847 | if (!IConstant) |
6848 | return std::nullopt; |
6849 | if (!Value) |
6850 | Value = IConstant->Value; |
6851 | else if (*Value != IConstant->Value) |
6852 | return std::nullopt; |
6853 | } |
6854 | return Value; |
6855 | } |
6856 | |
6857 | // FIXME G_SPLAT_VECTOR |
6858 | bool CombinerHelper::isConstantOrConstantVectorI(Register Src) const { |
6859 | auto IConstant = getIConstantVRegValWithLookThrough(VReg: Src, MRI); |
6860 | if (IConstant) |
6861 | return true; |
6862 | |
6863 | GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Reg: Src, MRI); |
6864 | if (!BuildVector) |
6865 | return false; |
6866 | |
6867 | unsigned NumSources = BuildVector->getNumSources(); |
6868 | for (unsigned I = 0; I < NumSources; ++I) { |
6869 | std::optional<ValueAndVReg> IConstant = |
6870 | getIConstantVRegValWithLookThrough(VReg: BuildVector->getSourceReg(I), MRI); |
6871 | if (!IConstant) |
6872 | return false; |
6873 | } |
6874 | return true; |
6875 | } |
6876 | |
6877 | // TODO: use knownbits to determine zeros |
6878 | bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select, |
6879 | BuildFnTy &MatchInfo) const { |
6880 | uint32_t Flags = Select->getFlags(); |
6881 | Register Dest = Select->getReg(Idx: 0); |
6882 | Register Cond = Select->getCondReg(); |
6883 | Register True = Select->getTrueReg(); |
6884 | Register False = Select->getFalseReg(); |
6885 | LLT CondTy = MRI.getType(Reg: Select->getCondReg()); |
6886 | LLT TrueTy = MRI.getType(Reg: Select->getTrueReg()); |
6887 | |
6888 | // We only do this combine for scalar boolean conditions. |
6889 | if (CondTy != LLT::scalar(SizeInBits: 1)) |
6890 | return false; |
6891 | |
6892 | if (TrueTy.isPointer()) |
6893 | return false; |
6894 | |
6895 | // Both are scalars. |
6896 | std::optional<ValueAndVReg> TrueOpt = |
6897 | getIConstantVRegValWithLookThrough(VReg: True, MRI); |
6898 | std::optional<ValueAndVReg> FalseOpt = |
6899 | getIConstantVRegValWithLookThrough(VReg: False, MRI); |
6900 | |
6901 | if (!TrueOpt || !FalseOpt) |
6902 | return false; |
6903 | |
6904 | APInt TrueValue = TrueOpt->Value; |
6905 | APInt FalseValue = FalseOpt->Value; |
6906 | |
6907 | // select Cond, 1, 0 --> zext (Cond) |
6908 | if (TrueValue.isOne() && FalseValue.isZero()) { |
6909 | MatchInfo = [=](MachineIRBuilder &B) { |
6910 | B.setInstrAndDebugLoc(*Select); |
6911 | B.buildZExtOrTrunc(Res: Dest, Op: Cond); |
6912 | }; |
6913 | return true; |
6914 | } |
6915 | |
6916 | // select Cond, -1, 0 --> sext (Cond) |
6917 | if (TrueValue.isAllOnes() && FalseValue.isZero()) { |
6918 | MatchInfo = [=](MachineIRBuilder &B) { |
6919 | B.setInstrAndDebugLoc(*Select); |
6920 | B.buildSExtOrTrunc(Res: Dest, Op: Cond); |
6921 | }; |
6922 | return true; |
6923 | } |
6924 | |
6925 | // select Cond, 0, 1 --> zext (!Cond) |
6926 | if (TrueValue.isZero() && FalseValue.isOne()) { |
6927 | MatchInfo = [=](MachineIRBuilder &B) { |
6928 | B.setInstrAndDebugLoc(*Select); |
6929 | Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy); |
6930 | B.buildNot(Dst: Inner, Src0: Cond); |
6931 | B.buildZExtOrTrunc(Res: Dest, Op: Inner); |
6932 | }; |
6933 | return true; |
6934 | } |
6935 | |
6936 | // select Cond, 0, -1 --> sext (!Cond) |
6937 | if (TrueValue.isZero() && FalseValue.isAllOnes()) { |
6938 | MatchInfo = [=](MachineIRBuilder &B) { |
6939 | B.setInstrAndDebugLoc(*Select); |
6940 | Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy); |
6941 | B.buildNot(Dst: Inner, Src0: Cond); |
6942 | B.buildSExtOrTrunc(Res: Dest, Op: Inner); |
6943 | }; |
6944 | return true; |
6945 | } |
6946 | |
6947 | // select Cond, C1, C1-1 --> add (zext Cond), C1-1 |
6948 | if (TrueValue - 1 == FalseValue) { |
6949 | MatchInfo = [=](MachineIRBuilder &B) { |
6950 | B.setInstrAndDebugLoc(*Select); |
6951 | Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy); |
6952 | B.buildZExtOrTrunc(Res: Inner, Op: Cond); |
6953 | B.buildAdd(Dst: Dest, Src0: Inner, Src1: False); |
6954 | }; |
6955 | return true; |
6956 | } |
6957 | |
6958 | // select Cond, C1, C1+1 --> add (sext Cond), C1+1 |
6959 | if (TrueValue + 1 == FalseValue) { |
6960 | MatchInfo = [=](MachineIRBuilder &B) { |
6961 | B.setInstrAndDebugLoc(*Select); |
6962 | Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy); |
6963 | B.buildSExtOrTrunc(Res: Inner, Op: Cond); |
6964 | B.buildAdd(Dst: Dest, Src0: Inner, Src1: False); |
6965 | }; |
6966 | return true; |
6967 | } |
6968 | |
6969 | // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2) |
6970 | if (TrueValue.isPowerOf2() && FalseValue.isZero()) { |
6971 | MatchInfo = [=](MachineIRBuilder &B) { |
6972 | B.setInstrAndDebugLoc(*Select); |
6973 | Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy); |
6974 | B.buildZExtOrTrunc(Res: Inner, Op: Cond); |
6975 | // The shift amount must be scalar. |
6976 | LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy; |
6977 | auto ShAmtC = B.buildConstant(Res: ShiftTy, Val: TrueValue.exactLogBase2()); |
6978 | B.buildShl(Dst: Dest, Src0: Inner, Src1: ShAmtC, Flags); |
6979 | }; |
6980 | return true; |
6981 | } |
6982 | |
6983 | // select Cond, 0, Pow2 --> (zext (!Cond)) << log2(Pow2) |
6984 | if (FalseValue.isPowerOf2() && TrueValue.isZero()) { |
6985 | MatchInfo = [=](MachineIRBuilder &B) { |
6986 | B.setInstrAndDebugLoc(*Select); |
6987 | Register Not = MRI.createGenericVirtualRegister(Ty: CondTy); |
6988 | B.buildNot(Dst: Not, Src0: Cond); |
6989 | Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy); |
6990 | B.buildZExtOrTrunc(Res: Inner, Op: Not); |
6991 | // The shift amount must be scalar. |
6992 | LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy; |
6993 | auto ShAmtC = B.buildConstant(Res: ShiftTy, Val: FalseValue.exactLogBase2()); |
6994 | B.buildShl(Dst: Dest, Src0: Inner, Src1: ShAmtC, Flags); |
6995 | }; |
6996 | return true; |
6997 | } |
6998 | |
6999 | // select Cond, -1, C --> or (sext Cond), C |
7000 | if (TrueValue.isAllOnes()) { |
7001 | MatchInfo = [=](MachineIRBuilder &B) { |
7002 | B.setInstrAndDebugLoc(*Select); |
7003 | Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy); |
7004 | B.buildSExtOrTrunc(Res: Inner, Op: Cond); |
7005 | B.buildOr(Dst: Dest, Src0: Inner, Src1: False, Flags); |
7006 | }; |
7007 | return true; |
7008 | } |
7009 | |
7010 | // select Cond, C, -1 --> or (sext (not Cond)), C |
7011 | if (FalseValue.isAllOnes()) { |
7012 | MatchInfo = [=](MachineIRBuilder &B) { |
7013 | B.setInstrAndDebugLoc(*Select); |
7014 | Register Not = MRI.createGenericVirtualRegister(Ty: CondTy); |
7015 | B.buildNot(Dst: Not, Src0: Cond); |
7016 | Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy); |
7017 | B.buildSExtOrTrunc(Res: Inner, Op: Not); |
7018 | B.buildOr(Dst: Dest, Src0: Inner, Src1: True, Flags); |
7019 | }; |
7020 | return true; |
7021 | } |
7022 | |
7023 | return false; |
7024 | } |
7025 | |
7026 | // TODO: use knownbits to determine zeros |
7027 | bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select, |
7028 | BuildFnTy &MatchInfo) const { |
7029 | uint32_t Flags = Select->getFlags(); |
7030 | Register DstReg = Select->getReg(Idx: 0); |
7031 | Register Cond = Select->getCondReg(); |
7032 | Register True = Select->getTrueReg(); |
7033 | Register False = Select->getFalseReg(); |
7034 | LLT CondTy = MRI.getType(Reg: Select->getCondReg()); |
7035 | LLT TrueTy = MRI.getType(Reg: Select->getTrueReg()); |
7036 | |
7037 | // Boolean or fixed vector of booleans. |
7038 | if (CondTy.isScalableVector() || |
7039 | (CondTy.isFixedVector() && |
7040 | CondTy.getElementType().getScalarSizeInBits() != 1) || |
7041 | CondTy.getScalarSizeInBits() != 1) |
7042 | return false; |
7043 | |
7044 | if (CondTy != TrueTy) |
7045 | return false; |
7046 | |
7047 | // select Cond, Cond, F --> or Cond, F |
7048 | // select Cond, 1, F --> or Cond, F |
7049 | if ((Cond == True) || isOneOrOneSplat(Src: True, /* AllowUndefs */ true)) { |
7050 | MatchInfo = [=](MachineIRBuilder &B) { |
7051 | B.setInstrAndDebugLoc(*Select); |
7052 | Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy); |
7053 | B.buildZExtOrTrunc(Res: Ext, Op: Cond); |
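// Freeze False: when Cond is true the select never observes False, but the
// OR below would still propagate poison from it.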
7054 | auto FreezeFalse = B.buildFreeze(Dst: TrueTy, Src: False); |
7055 | B.buildOr(Dst: DstReg, Src0: Ext, Src1: FreezeFalse, Flags); |
7056 | }; |
7057 | return true; |
7058 | } |
7059 | |
7060 | // select Cond, T, Cond --> and Cond, T |
7061 | // select Cond, T, 0 --> and Cond, T |
7062 | if ((Cond == False) || isZeroOrZeroSplat(Src: False, /* AllowUndefs */ true)) { |
7063 | MatchInfo = [=](MachineIRBuilder &B) { |
7064 | B.setInstrAndDebugLoc(*Select); |
7065 | Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy); |
7066 | B.buildZExtOrTrunc(Res: Ext, Op: Cond); |
7067 | auto FreezeTrue = B.buildFreeze(Dst: TrueTy, Src: True); |
7068 | B.buildAnd(Dst: DstReg, Src0: Ext, Src1: FreezeTrue); |
7069 | }; |
7070 | return true; |
7071 | } |
7072 | |
7073 | // select Cond, T, 1 --> or (not Cond), T |
7074 | if (isOneOrOneSplat(Src: False, /* AllowUndefs */ true)) { |
7075 | MatchInfo = [=](MachineIRBuilder &B) { |
7076 | B.setInstrAndDebugLoc(*Select); |
7077 | // First the not. |
7078 | Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy); |
7079 | B.buildNot(Dst: Inner, Src0: Cond); |
7080 | // Then an ext to match the destination register. |
7081 | Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy); |
7082 | B.buildZExtOrTrunc(Res: Ext, Op: Inner); |
7083 | auto FreezeTrue = B.buildFreeze(Dst: TrueTy, Src: True); |
7084 | B.buildOr(Dst: DstReg, Src0: Ext, Src1: FreezeTrue, Flags); |
7085 | }; |
7086 | return true; |
7087 | } |
7088 | |
7089 | // select Cond, 0, F --> and (not Cond), F |
7090 | if (isZeroOrZeroSplat(Src: True, /* AllowUndefs */ true)) { |
7091 | MatchInfo = [=](MachineIRBuilder &B) { |
7092 | B.setInstrAndDebugLoc(*Select); |
7093 | // First the not. |
7094 | Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy); |
7095 | B.buildNot(Dst: Inner, Src0: Cond); |
7096 | // Then an ext to match the destination register. |
7097 | Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy); |
7098 | B.buildZExtOrTrunc(Res: Ext, Op: Inner); |
7099 | auto FreezeFalse = B.buildFreeze(Dst: TrueTy, Src: False); |
7100 | B.buildAnd(Dst: DstReg, Src0: Ext, Src1: FreezeFalse); |
7101 | }; |
7102 | return true; |
7103 | } |
7104 | |
7105 | return false; |
7106 | } |
7107 | |
7108 | bool CombinerHelper::matchSelectIMinMax(const MachineOperand &MO, |
7109 | BuildFnTy &MatchInfo) const { |
7110 | GSelect *Select = cast<GSelect>(Val: MRI.getVRegDef(Reg: MO.getReg())); |
7111 | GICmp *Cmp = cast<GICmp>(Val: MRI.getVRegDef(Reg: Select->getCondReg())); |
7112 | |
7113 | Register DstReg = Select->getReg(Idx: 0); |
7114 | Register True = Select->getTrueReg(); |
7115 | Register False = Select->getFalseReg(); |
7116 | LLT DstTy = MRI.getType(Reg: DstReg); |
7117 | |
7118 | if (DstTy.isPointer()) |
7119 | return false; |
7120 | |
7121 | // We want to fold the icmp and replace the select. |
7122 | if (!MRI.hasOneNonDBGUse(RegNo: Cmp->getReg(Idx: 0))) |
7123 | return false; |
7124 | |
7125 | CmpInst::Predicate Pred = Cmp->getCond(); |
// We need an ordering (less-than/greater-than) predicate; an equality
// predicate cannot form an integer min/max.
7128 | if (CmpInst::isEquality(pred: Pred)) |
7129 | return false; |
7130 | |
7131 | Register CmpLHS = Cmp->getLHSReg(); |
7132 | Register CmpRHS = Cmp->getRHSReg(); |
7133 | |
// We can swap CmpLHS and CmpRHS for a higher hit rate.
7135 | if (True == CmpRHS && False == CmpLHS) { |
7136 | std::swap(a&: CmpLHS, b&: CmpRHS); |
7137 | Pred = CmpInst::getSwappedPredicate(pred: Pred); |
7138 | } |
7139 | |
7140 | // (icmp X, Y) ? X : Y -> integer minmax. |
7141 | // see matchSelectPattern in ValueTracking. |
7142 | // Legality between G_SELECT and integer minmax can differ. |
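//
// E.g. (illustrative MIR, invented register names):
//   %c:_(s1) = G_ICMP intpred(ult), %a:_(s32), %b:_(s32)
//   %d:_(s32) = G_SELECT %c, %a, %b
// becomes %d:_(s32) = G_UMIN %a, %b.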
7143 | if (True != CmpLHS || False != CmpRHS) |
7144 | return false; |
7145 | |
7146 | switch (Pred) { |
7147 | case ICmpInst::ICMP_UGT: |
7148 | case ICmpInst::ICMP_UGE: { |
7149 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_UMAX, DstTy})) |
7150 | return false; |
7151 | MatchInfo = [=](MachineIRBuilder &B) { B.buildUMax(Dst: DstReg, Src0: True, Src1: False); }; |
7152 | return true; |
7153 | } |
7154 | case ICmpInst::ICMP_SGT: |
7155 | case ICmpInst::ICMP_SGE: { |
7156 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SMAX, DstTy})) |
7157 | return false; |
7158 | MatchInfo = [=](MachineIRBuilder &B) { B.buildSMax(Dst: DstReg, Src0: True, Src1: False); }; |
7159 | return true; |
7160 | } |
7161 | case ICmpInst::ICMP_ULT: |
7162 | case ICmpInst::ICMP_ULE: { |
7163 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_UMIN, DstTy})) |
7164 | return false; |
7165 | MatchInfo = [=](MachineIRBuilder &B) { B.buildUMin(Dst: DstReg, Src0: True, Src1: False); }; |
7166 | return true; |
7167 | } |
7168 | case ICmpInst::ICMP_SLT: |
7169 | case ICmpInst::ICMP_SLE: { |
7170 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SMIN, DstTy})) |
7171 | return false; |
7172 | MatchInfo = [=](MachineIRBuilder &B) { B.buildSMin(Dst: DstReg, Src0: True, Src1: False); }; |
7173 | return true; |
7174 | } |
7175 | default: |
7176 | return false; |
7177 | } |
7178 | } |
7179 | |
7180 | // (neg (min/max x, (neg x))) --> (max/min x, (neg x)) |
7181 | bool CombinerHelper::matchSimplifyNegMinMax(MachineInstr &MI, |
7182 | BuildFnTy &MatchInfo) const { |
7183 | assert(MI.getOpcode() == TargetOpcode::G_SUB); |
7184 | Register DestReg = MI.getOperand(i: 0).getReg(); |
7185 | LLT DestTy = MRI.getType(Reg: DestReg); |
7186 | |
7187 | Register X; |
7188 | Register Sub0; |
7189 | auto NegPattern = m_all_of(preds: m_Neg(Src: m_DeferredReg(R&: X)), preds: m_Reg(R&: Sub0)); |
7190 | if (mi_match(R: DestReg, MRI, |
7191 | P: m_Neg(Src: m_OneUse(SP: m_any_of(preds: m_GSMin(L: m_Reg(R&: X), R: NegPattern), |
7192 | preds: m_GSMax(L: m_Reg(R&: X), R: NegPattern), |
7193 | preds: m_GUMin(L: m_Reg(R&: X), R: NegPattern), |
7194 | preds: m_GUMax(L: m_Reg(R&: X), R: NegPattern)))))) { |
7195 | MachineInstr *MinMaxMI = MRI.getVRegDef(Reg: MI.getOperand(i: 2).getReg()); |
7196 | unsigned NewOpc = getInverseGMinMaxOpcode(MinMaxOpc: MinMaxMI->getOpcode()); |
7197 | if (isLegal(Query: {NewOpc, {DestTy}})) { |
7198 | MatchInfo = [=](MachineIRBuilder &B) { |
7199 | B.buildInstr(Opc: NewOpc, DstOps: {DestReg}, SrcOps: {X, Sub0}); |
7200 | }; |
7201 | return true; |
7202 | } |
7203 | } |
7204 | |
7205 | return false; |
7206 | } |
7207 | |
7208 | bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) const { |
7209 | GSelect *Select = cast<GSelect>(Val: &MI); |
7210 | |
7211 | if (tryFoldSelectOfConstants(Select, MatchInfo)) |
7212 | return true; |
7213 | |
7214 | if (tryFoldBoolSelectToLogic(Select, MatchInfo)) |
7215 | return true; |
7216 | |
7217 | return false; |
7218 | } |
7219 | |
7220 | /// Fold (icmp Pred1 V1, C1) && (icmp Pred2 V2, C2) |
7221 | /// or (icmp Pred1 V1, C1) || (icmp Pred2 V2, C2) |
7222 | /// into a single comparison using range-based reasoning. |
7223 | /// see InstCombinerImpl::foldAndOrOfICmpsUsingRanges. |
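///
/// E.g. (illustrative, unsigned values):
///   (x u>= 4) && (x u< 12)  ->  (x - 4) u< 8
/// The conjunction describes the range [4, 12), which a single compare can
/// test after subtracting the lower bound.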
7224 | bool CombinerHelper::tryFoldAndOrOrICmpsUsingRanges( |
7225 | GLogicalBinOp *Logic, BuildFnTy &MatchInfo) const { |
assert(Logic->getOpcode() != TargetOpcode::G_XOR && "unexpected xor");
7227 | bool IsAnd = Logic->getOpcode() == TargetOpcode::G_AND; |
7228 | Register DstReg = Logic->getReg(Idx: 0); |
7229 | Register LHS = Logic->getLHSReg(); |
7230 | Register RHS = Logic->getRHSReg(); |
7231 | unsigned Flags = Logic->getFlags(); |
7232 | |
// We need a G_ICMP on the LHS register.
7234 | GICmp *Cmp1 = getOpcodeDef<GICmp>(Reg: LHS, MRI); |
7235 | if (!Cmp1) |
7236 | return false; |
7237 | |
// We need a G_ICMP on the RHS register.
7239 | GICmp *Cmp2 = getOpcodeDef<GICmp>(Reg: RHS, MRI); |
7240 | if (!Cmp2) |
7241 | return false; |
7242 | |
7243 | // We want to fold the icmps. |
7244 | if (!MRI.hasOneNonDBGUse(RegNo: Cmp1->getReg(Idx: 0)) || |
7245 | !MRI.hasOneNonDBGUse(RegNo: Cmp2->getReg(Idx: 0))) |
7246 | return false; |
7247 | |
7248 | APInt C1; |
7249 | APInt C2; |
7250 | std::optional<ValueAndVReg> MaybeC1 = |
7251 | getIConstantVRegValWithLookThrough(VReg: Cmp1->getRHSReg(), MRI); |
7252 | if (!MaybeC1) |
7253 | return false; |
7254 | C1 = MaybeC1->Value; |
7255 | |
7256 | std::optional<ValueAndVReg> MaybeC2 = |
7257 | getIConstantVRegValWithLookThrough(VReg: Cmp2->getRHSReg(), MRI); |
7258 | if (!MaybeC2) |
7259 | return false; |
7260 | C2 = MaybeC2->Value; |
7261 | |
7262 | Register R1 = Cmp1->getLHSReg(); |
7263 | Register R2 = Cmp2->getLHSReg(); |
7264 | CmpInst::Predicate Pred1 = Cmp1->getCond(); |
7265 | CmpInst::Predicate Pred2 = Cmp2->getCond(); |
7266 | LLT CmpTy = MRI.getType(Reg: Cmp1->getReg(Idx: 0)); |
7267 | LLT CmpOperandTy = MRI.getType(Reg: R1); |
7268 | |
7269 | if (CmpOperandTy.isPointer()) |
7270 | return false; |
7271 | |
7272 | // We build ands, adds, and constants of type CmpOperandTy. |
7273 | // They must be legal to build. |
7274 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_AND, CmpOperandTy}) || |
7275 | !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ADD, CmpOperandTy}) || |
7276 | !isConstantLegalOrBeforeLegalizer(Ty: CmpOperandTy)) |
7277 | return false; |
7278 | |
// Look through an add of a constant offset on R1, R2, or both operands. This
// allows us to turn the "R + C' < C''" range idiom into a proper range.
7281 | std::optional<APInt> Offset1; |
7282 | std::optional<APInt> Offset2; |
7283 | if (R1 != R2) { |
7284 | if (GAdd *Add = getOpcodeDef<GAdd>(Reg: R1, MRI)) { |
7285 | std::optional<ValueAndVReg> MaybeOffset1 = |
7286 | getIConstantVRegValWithLookThrough(VReg: Add->getRHSReg(), MRI); |
7287 | if (MaybeOffset1) { |
7288 | R1 = Add->getLHSReg(); |
7289 | Offset1 = MaybeOffset1->Value; |
7290 | } |
7291 | } |
7292 | if (GAdd *Add = getOpcodeDef<GAdd>(Reg: R2, MRI)) { |
7293 | std::optional<ValueAndVReg> MaybeOffset2 = |
7294 | getIConstantVRegValWithLookThrough(VReg: Add->getRHSReg(), MRI); |
7295 | if (MaybeOffset2) { |
7296 | R2 = Add->getLHSReg(); |
7297 | Offset2 = MaybeOffset2->Value; |
7298 | } |
7299 | } |
7300 | } |
7301 | |
7302 | if (R1 != R2) |
7303 | return false; |
7304 | |
7305 | // We calculate the icmp ranges including maybe offsets. |
7306 | ConstantRange CR1 = ConstantRange::makeExactICmpRegion( |
7307 | Pred: IsAnd ? ICmpInst::getInversePredicate(pred: Pred1) : Pred1, Other: C1); |
7308 | if (Offset1) |
7309 | CR1 = CR1.subtract(CI: *Offset1); |
7310 | |
7311 | ConstantRange CR2 = ConstantRange::makeExactICmpRegion( |
7312 | Pred: IsAnd ? ICmpInst::getInversePredicate(pred: Pred2) : Pred2, Other: C2); |
7313 | if (Offset2) |
7314 | CR2 = CR2.subtract(CI: *Offset2); |
7315 | |
7316 | bool CreateMask = false; |
7317 | APInt LowerDiff; |
7318 | std::optional<ConstantRange> CR = CR1.exactUnionWith(CR: CR2); |
7319 | if (!CR) { |
7320 | // We need non-wrapping ranges. |
7321 | if (CR1.isWrappedSet() || CR2.isWrappedSet()) |
7322 | return false; |
7323 | |
7324 | // Check whether we have equal-size ranges that only differ by one bit. |
7325 | // In that case we can apply a mask to map one range onto the other. |
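// E.g. the ranges [0, 4) and [8, 12) differ only in bit 3; clearing that bit
// with "x & ~8" maps both onto [0, 4), which a single compare can then test.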
7326 | LowerDiff = CR1.getLower() ^ CR2.getLower(); |
7327 | APInt UpperDiff = (CR1.getUpper() - 1) ^ (CR2.getUpper() - 1); |
7328 | APInt CR1Size = CR1.getUpper() - CR1.getLower(); |
7329 | if (!LowerDiff.isPowerOf2() || LowerDiff != UpperDiff || |
7330 | CR1Size != CR2.getUpper() - CR2.getLower()) |
7331 | return false; |
7332 | |
7333 | CR = CR1.getLower().ult(RHS: CR2.getLower()) ? CR1 : CR2; |
7334 | CreateMask = true; |
7335 | } |
7336 | |
7337 | if (IsAnd) |
7338 | CR = CR->inverse(); |
7339 | |
7340 | CmpInst::Predicate NewPred; |
7341 | APInt NewC, Offset; |
7342 | CR->getEquivalentICmp(Pred&: NewPred, RHS&: NewC, Offset); |
7343 | |
// We take the result type of one of the original icmps, CmpTy, for the
// icmp to be built. The operand type, CmpOperandTy, is used for the other
// instructions and constants to be built. The parameter and result types
// are the same for add and and. CmpTy and the type of DstReg might differ,
// which is why we zext or trunc the icmp into the destination register.
7350 | |
7351 | MatchInfo = [=](MachineIRBuilder &B) { |
7352 | if (CreateMask && Offset != 0) { |
7353 | auto TildeLowerDiff = B.buildConstant(Res: CmpOperandTy, Val: ~LowerDiff); |
7354 | auto And = B.buildAnd(Dst: CmpOperandTy, Src0: R1, Src1: TildeLowerDiff); // the mask. |
7355 | auto OffsetC = B.buildConstant(Res: CmpOperandTy, Val: Offset); |
7356 | auto Add = B.buildAdd(Dst: CmpOperandTy, Src0: And, Src1: OffsetC, Flags); |
7357 | auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC); |
7358 | auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: Add, Op1: NewCon); |
7359 | B.buildZExtOrTrunc(Res: DstReg, Op: ICmp); |
7360 | } else if (CreateMask && Offset == 0) { |
7361 | auto TildeLowerDiff = B.buildConstant(Res: CmpOperandTy, Val: ~LowerDiff); |
7362 | auto And = B.buildAnd(Dst: CmpOperandTy, Src0: R1, Src1: TildeLowerDiff); // the mask. |
7363 | auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC); |
7364 | auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: And, Op1: NewCon); |
7365 | B.buildZExtOrTrunc(Res: DstReg, Op: ICmp); |
7366 | } else if (!CreateMask && Offset != 0) { |
7367 | auto OffsetC = B.buildConstant(Res: CmpOperandTy, Val: Offset); |
7368 | auto Add = B.buildAdd(Dst: CmpOperandTy, Src0: R1, Src1: OffsetC, Flags); |
7369 | auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC); |
7370 | auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: Add, Op1: NewCon); |
7371 | B.buildZExtOrTrunc(Res: DstReg, Op: ICmp); |
7372 | } else if (!CreateMask && Offset == 0) { |
7373 | auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC); |
7374 | auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: R1, Op1: NewCon); |
7375 | B.buildZExtOrTrunc(Res: DstReg, Op: ICmp); |
7376 | } else { |
7377 | llvm_unreachable("unexpected configuration of CreateMask and Offset" ); |
7378 | } |
7379 | }; |
7380 | return true; |
7381 | } |
7382 | |
7383 | bool CombinerHelper::tryFoldLogicOfFCmps(GLogicalBinOp *Logic, |
7384 | BuildFnTy &MatchInfo) const { |
assert(Logic->getOpcode() != TargetOpcode::G_XOR && "unexpected xor");
7386 | Register DestReg = Logic->getReg(Idx: 0); |
7387 | Register LHS = Logic->getLHSReg(); |
7388 | Register RHS = Logic->getRHSReg(); |
7389 | bool IsAnd = Logic->getOpcode() == TargetOpcode::G_AND; |
7390 | |
7391 | // We need a compare on the LHS register. |
7392 | GFCmp *Cmp1 = getOpcodeDef<GFCmp>(Reg: LHS, MRI); |
7393 | if (!Cmp1) |
7394 | return false; |
7395 | |
7396 | // We need a compare on the RHS register. |
7397 | GFCmp *Cmp2 = getOpcodeDef<GFCmp>(Reg: RHS, MRI); |
7398 | if (!Cmp2) |
7399 | return false; |
7400 | |
7401 | LLT CmpTy = MRI.getType(Reg: Cmp1->getReg(Idx: 0)); |
7402 | LLT CmpOperandTy = MRI.getType(Reg: Cmp1->getLHSReg()); |
7403 | |
7404 | // We build one fcmp, want to fold the fcmps, replace the logic op, |
7405 | // and the fcmps must have the same shape. |
7406 | if (!isLegalOrBeforeLegalizer( |
7407 | Query: {TargetOpcode::G_FCMP, {CmpTy, CmpOperandTy}}) || |
7408 | !MRI.hasOneNonDBGUse(RegNo: Logic->getReg(Idx: 0)) || |
7409 | !MRI.hasOneNonDBGUse(RegNo: Cmp1->getReg(Idx: 0)) || |
7410 | !MRI.hasOneNonDBGUse(RegNo: Cmp2->getReg(Idx: 0)) || |
7411 | MRI.getType(Reg: Cmp1->getLHSReg()) != MRI.getType(Reg: Cmp2->getLHSReg())) |
7412 | return false; |
7413 | |
7414 | CmpInst::Predicate PredL = Cmp1->getCond(); |
7415 | CmpInst::Predicate PredR = Cmp2->getCond(); |
7416 | Register LHS0 = Cmp1->getLHSReg(); |
7417 | Register LHS1 = Cmp1->getRHSReg(); |
7418 | Register RHS0 = Cmp2->getLHSReg(); |
7419 | Register RHS1 = Cmp2->getRHSReg(); |
7420 | |
7421 | if (LHS0 == RHS1 && LHS1 == RHS0) { |
7422 | // Swap RHS operands to match LHS. |
7423 | PredR = CmpInst::getSwappedPredicate(pred: PredR); |
7424 | std::swap(a&: RHS0, b&: RHS1); |
7425 | } |
7426 | |
7427 | if (LHS0 == RHS0 && LHS1 == RHS1) { |
7428 | // We determine the new predicate. |
7429 | unsigned CmpCodeL = getFCmpCode(CC: PredL); |
7430 | unsigned CmpCodeR = getFCmpCode(CC: PredR); |
7431 | unsigned NewPred = IsAnd ? CmpCodeL & CmpCodeR : CmpCodeL | CmpCodeR; |
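// E.g. the codes of FCMP_OLT (0b0100) and FCMP_OGT (0b0010) combine to
// FCMP_FALSE (0b0000) for G_AND and to FCMP_ONE (0b0110) for G_OR.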
7432 | unsigned Flags = Cmp1->getFlags() | Cmp2->getFlags(); |
7433 | MatchInfo = [=](MachineIRBuilder &B) { |
7434 | // The fcmp predicates fill the lower part of the enum. |
7435 | FCmpInst::Predicate Pred = static_cast<FCmpInst::Predicate>(NewPred); |
7436 | if (Pred == FCmpInst::FCMP_FALSE && |
7437 | isConstantLegalOrBeforeLegalizer(Ty: CmpTy)) { |
7438 | auto False = B.buildConstant(Res: CmpTy, Val: 0); |
7439 | B.buildZExtOrTrunc(Res: DestReg, Op: False); |
7440 | } else if (Pred == FCmpInst::FCMP_TRUE && |
7441 | isConstantLegalOrBeforeLegalizer(Ty: CmpTy)) { |
7442 | auto True = |
7443 | B.buildConstant(Res: CmpTy, Val: getICmpTrueVal(TLI: getTargetLowering(), |
7444 | IsVector: CmpTy.isVector() /*isVector*/, |
7445 | IsFP: true /*isFP*/)); |
7446 | B.buildZExtOrTrunc(Res: DestReg, Op: True); |
7447 | } else { // We take the predicate without predicate optimizations. |
7448 | auto Cmp = B.buildFCmp(Pred, Res: CmpTy, Op0: LHS0, Op1: LHS1, Flags); |
7449 | B.buildZExtOrTrunc(Res: DestReg, Op: Cmp); |
7450 | } |
7451 | }; |
7452 | return true; |
7453 | } |
7454 | |
7455 | return false; |
7456 | } |
7457 | |
7458 | bool CombinerHelper::matchAnd(MachineInstr &MI, BuildFnTy &MatchInfo) const { |
7459 | GAnd *And = cast<GAnd>(Val: &MI); |
7460 | |
7461 | if (tryFoldAndOrOrICmpsUsingRanges(Logic: And, MatchInfo)) |
7462 | return true; |
7463 | |
7464 | if (tryFoldLogicOfFCmps(Logic: And, MatchInfo)) |
7465 | return true; |
7466 | |
7467 | return false; |
7468 | } |
7469 | |
7470 | bool CombinerHelper::matchOr(MachineInstr &MI, BuildFnTy &MatchInfo) const { |
7471 | GOr *Or = cast<GOr>(Val: &MI); |
7472 | |
7473 | if (tryFoldAndOrOrICmpsUsingRanges(Logic: Or, MatchInfo)) |
7474 | return true; |
7475 | |
7476 | if (tryFoldLogicOfFCmps(Logic: Or, MatchInfo)) |
7477 | return true; |
7478 | |
7479 | return false; |
7480 | } |
7481 | |
7482 | bool CombinerHelper::matchAddOverflow(MachineInstr &MI, |
7483 | BuildFnTy &MatchInfo) const { |
7484 | GAddCarryOut *Add = cast<GAddCarryOut>(Val: &MI); |
7485 | |
7486 | // Addo has no flags |
7487 | Register Dst = Add->getReg(Idx: 0); |
7488 | Register Carry = Add->getReg(Idx: 1); |
7489 | Register LHS = Add->getLHSReg(); |
7490 | Register RHS = Add->getRHSReg(); |
7491 | bool IsSigned = Add->isSigned(); |
7492 | LLT DstTy = MRI.getType(Reg: Dst); |
7493 | LLT CarryTy = MRI.getType(Reg: Carry); |
7494 | |
7495 | // Fold addo, if the carry is dead -> add, undef. |
7496 | if (MRI.use_nodbg_empty(RegNo: Carry) && |
7497 | isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ADD, {DstTy}})) { |
7498 | MatchInfo = [=](MachineIRBuilder &B) { |
7499 | B.buildAdd(Dst, Src0: LHS, Src1: RHS); |
7500 | B.buildUndef(Res: Carry); |
7501 | }; |
7502 | return true; |
7503 | } |
7504 | |
7505 | // Canonicalize constant to RHS. |
7506 | if (isConstantOrConstantVectorI(Src: LHS) && !isConstantOrConstantVectorI(Src: RHS)) { |
7507 | if (IsSigned) { |
7508 | MatchInfo = [=](MachineIRBuilder &B) { |
7509 | B.buildSAddo(Res: Dst, CarryOut: Carry, Op0: RHS, Op1: LHS); |
7510 | }; |
7511 | return true; |
7512 | } |
7513 | // !IsSigned |
7514 | MatchInfo = [=](MachineIRBuilder &B) { |
7515 | B.buildUAddo(Res: Dst, CarryOut: Carry, Op0: RHS, Op1: LHS); |
7516 | }; |
7517 | return true; |
7518 | } |
7519 | |
7520 | std::optional<APInt> MaybeLHS = getConstantOrConstantSplatVector(Src: LHS); |
7521 | std::optional<APInt> MaybeRHS = getConstantOrConstantSplatVector(Src: RHS); |
7522 | |
7523 | // Fold addo(c1, c2) -> c3, carry. |
7524 | if (MaybeLHS && MaybeRHS && isConstantLegalOrBeforeLegalizer(Ty: DstTy) && |
7525 | isConstantLegalOrBeforeLegalizer(Ty: CarryTy)) { |
7526 | bool Overflow; |
7527 | APInt Result = IsSigned ? MaybeLHS->sadd_ov(RHS: *MaybeRHS, Overflow) |
7528 | : MaybeLHS->uadd_ov(RHS: *MaybeRHS, Overflow); |
7529 | MatchInfo = [=](MachineIRBuilder &B) { |
7530 | B.buildConstant(Res: Dst, Val: Result); |
7531 | B.buildConstant(Res: Carry, Val: Overflow); |
7532 | }; |
7533 | return true; |
7534 | } |
7535 | |
7536 | // Fold (addo x, 0) -> x, no carry |
7537 | if (MaybeRHS && *MaybeRHS == 0 && isConstantLegalOrBeforeLegalizer(Ty: CarryTy)) { |
7538 | MatchInfo = [=](MachineIRBuilder &B) { |
7539 | B.buildCopy(Res: Dst, Op: LHS); |
7540 | B.buildConstant(Res: Carry, Val: 0); |
7541 | }; |
7542 | return true; |
7543 | } |
7544 | |
7545 | // Given 2 constant operands whose sum does not overflow: |
7546 | // uaddo (X +nuw C0), C1 -> uaddo X, C0 + C1 |
7547 | // saddo (X +nsw C0), C1 -> saddo X, C0 + C1 |
7548 | GAdd *AddLHS = getOpcodeDef<GAdd>(Reg: LHS, MRI); |
7549 | if (MaybeRHS && AddLHS && MRI.hasOneNonDBGUse(RegNo: Add->getReg(Idx: 0)) && |
7550 | ((IsSigned && AddLHS->getFlag(Flag: MachineInstr::MIFlag::NoSWrap)) || |
7551 | (!IsSigned && AddLHS->getFlag(Flag: MachineInstr::MIFlag::NoUWrap)))) { |
7552 | std::optional<APInt> MaybeAddRHS = |
7553 | getConstantOrConstantSplatVector(Src: AddLHS->getRHSReg()); |
7554 | if (MaybeAddRHS) { |
7555 | bool Overflow; |
7556 | APInt NewC = IsSigned ? MaybeAddRHS->sadd_ov(RHS: *MaybeRHS, Overflow) |
7557 | : MaybeAddRHS->uadd_ov(RHS: *MaybeRHS, Overflow); |
7558 | if (!Overflow && isConstantLegalOrBeforeLegalizer(Ty: DstTy)) { |
7559 | if (IsSigned) { |
7560 | MatchInfo = [=](MachineIRBuilder &B) { |
7561 | auto ConstRHS = B.buildConstant(Res: DstTy, Val: NewC); |
7562 | B.buildSAddo(Res: Dst, CarryOut: Carry, Op0: AddLHS->getLHSReg(), Op1: ConstRHS); |
7563 | }; |
7564 | return true; |
7565 | } |
7566 | // !IsSigned |
7567 | MatchInfo = [=](MachineIRBuilder &B) { |
7568 | auto ConstRHS = B.buildConstant(Res: DstTy, Val: NewC); |
7569 | B.buildUAddo(Res: Dst, CarryOut: Carry, Op0: AddLHS->getLHSReg(), Op1: ConstRHS); |
7570 | }; |
7571 | return true; |
7572 | } |
7573 | } |
7574 | }; |
7575 | |
7576 | // We try to combine addo to non-overflowing add. |
7577 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ADD, {DstTy}}) || |
7578 | !isConstantLegalOrBeforeLegalizer(Ty: CarryTy)) |
7579 | return false; |
7580 | |
7581 | // We try to combine uaddo to non-overflowing add. |
7582 | if (!IsSigned) { |
7583 | ConstantRange CRLHS = |
7584 | ConstantRange::fromKnownBits(Known: VT->getKnownBits(R: LHS), /*IsSigned=*/false); |
7585 | ConstantRange CRRHS = |
7586 | ConstantRange::fromKnownBits(Known: VT->getKnownBits(R: RHS), /*IsSigned=*/false); |
7587 | |
7588 | switch (CRLHS.unsignedAddMayOverflow(Other: CRRHS)) { |
7589 | case ConstantRange::OverflowResult::MayOverflow: |
7590 | return false; |
7591 | case ConstantRange::OverflowResult::NeverOverflows: { |
7592 | MatchInfo = [=](MachineIRBuilder &B) { |
7593 | B.buildAdd(Dst, Src0: LHS, Src1: RHS, Flags: MachineInstr::MIFlag::NoUWrap); |
7594 | B.buildConstant(Res: Carry, Val: 0); |
7595 | }; |
7596 | return true; |
7597 | } |
7598 | case ConstantRange::OverflowResult::AlwaysOverflowsLow: |
7599 | case ConstantRange::OverflowResult::AlwaysOverflowsHigh: { |
7600 | MatchInfo = [=](MachineIRBuilder &B) { |
7601 | B.buildAdd(Dst, Src0: LHS, Src1: RHS); |
7602 | B.buildConstant(Res: Carry, Val: 1); |
7603 | }; |
7604 | return true; |
7605 | } |
7606 | } |
7607 | return false; |
7608 | } |
7609 | |
7610 | // We try to combine saddo to non-overflowing add. |
7611 | |
7612 | // If LHS and RHS each have at least two sign bits, then there is no signed |
7613 | // overflow. |
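// Each operand then fits in (BitWidth - 1) signed bits, so their sum fits in
// BitWidth signed bits and cannot wrap.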
7614 | if (VT->computeNumSignBits(R: RHS) > 1 && VT->computeNumSignBits(R: LHS) > 1) { |
7615 | MatchInfo = [=](MachineIRBuilder &B) { |
7616 | B.buildAdd(Dst, Src0: LHS, Src1: RHS, Flags: MachineInstr::MIFlag::NoSWrap); |
7617 | B.buildConstant(Res: Carry, Val: 0); |
7618 | }; |
7619 | return true; |
7620 | } |
7621 | |
7622 | ConstantRange CRLHS = |
7623 | ConstantRange::fromKnownBits(Known: VT->getKnownBits(R: LHS), /*IsSigned=*/true); |
7624 | ConstantRange CRRHS = |
7625 | ConstantRange::fromKnownBits(Known: VT->getKnownBits(R: RHS), /*IsSigned=*/true); |
7626 | |
7627 | switch (CRLHS.signedAddMayOverflow(Other: CRRHS)) { |
7628 | case ConstantRange::OverflowResult::MayOverflow: |
7629 | return false; |
7630 | case ConstantRange::OverflowResult::NeverOverflows: { |
7631 | MatchInfo = [=](MachineIRBuilder &B) { |
7632 | B.buildAdd(Dst, Src0: LHS, Src1: RHS, Flags: MachineInstr::MIFlag::NoSWrap); |
7633 | B.buildConstant(Res: Carry, Val: 0); |
7634 | }; |
7635 | return true; |
7636 | } |
7637 | case ConstantRange::OverflowResult::AlwaysOverflowsLow: |
7638 | case ConstantRange::OverflowResult::AlwaysOverflowsHigh: { |
7639 | MatchInfo = [=](MachineIRBuilder &B) { |
7640 | B.buildAdd(Dst, Src0: LHS, Src1: RHS); |
7641 | B.buildConstant(Res: Carry, Val: 1); |
7642 | }; |
7643 | return true; |
7644 | } |
7645 | } |
7646 | |
7647 | return false; |
7648 | } |
7649 | |
7650 | void CombinerHelper::applyBuildFnMO(const MachineOperand &MO, |
7651 | BuildFnTy &MatchInfo) const { |
7652 | MachineInstr *Root = getDefIgnoringCopies(Reg: MO.getReg(), MRI); |
7653 | MatchInfo(Builder); |
7654 | Root->eraseFromParent(); |
7655 | } |
7656 | |
7657 | bool CombinerHelper::matchFPowIExpansion(MachineInstr &MI, |
7658 | int64_t Exponent) const { |
7659 | bool OptForSize = MI.getMF()->getFunction().hasOptSize(); |
7660 | return getTargetLowering().isBeneficialToExpandPowI(Exponent, OptForSize); |
7661 | } |
7662 | |
7663 | void CombinerHelper::applyExpandFPowI(MachineInstr &MI, |
7664 | int64_t Exponent) const { |
7665 | auto [Dst, Base] = MI.getFirst2Regs(); |
7666 | LLT Ty = MRI.getType(Reg: Dst); |
7667 | int64_t ExpVal = Exponent; |
7668 | |
7669 | if (ExpVal == 0) { |
7670 | Builder.buildFConstant(Res: Dst, Val: 1.0); |
7671 | MI.removeFromParent(); |
7672 | return; |
7673 | } |
7674 | |
7675 | if (ExpVal < 0) |
7676 | ExpVal = -ExpVal; |
7677 | |
7678 | // We use the simple binary decomposition method from SelectionDAG ExpandPowI |
7679 | // to generate the multiply sequence. There are more optimal ways to do this |
7680 | // (for example, powi(x,15) generates one more multiply than it should), but |
7681 | // this has the benefit of being both really simple and much better than a |
7682 | // libcall. |
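//
// E.g. for powi(x, 13) (13 = 0b1101) the loop squares CurSquare on every
// iteration and multiplies the squares selected by the set exponent bits into
// the result: Res = x * x^4 * x^8 = x^13.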
7683 | std::optional<SrcOp> Res; |
7684 | SrcOp CurSquare = Base; |
7685 | while (ExpVal > 0) { |
7686 | if (ExpVal & 1) { |
7687 | if (!Res) |
7688 | Res = CurSquare; |
7689 | else |
7690 | Res = Builder.buildFMul(Dst: Ty, Src0: *Res, Src1: CurSquare); |
7691 | } |
7692 | |
7693 | CurSquare = Builder.buildFMul(Dst: Ty, Src0: CurSquare, Src1: CurSquare); |
7694 | ExpVal >>= 1; |
7695 | } |
7696 | |
7697 | // If the original exponent was negative, invert the result, producing |
7698 | // 1/(x*x*x). |
7699 | if (Exponent < 0) |
7700 | Res = Builder.buildFDiv(Dst: Ty, Src0: Builder.buildFConstant(Res: Ty, Val: 1.0), Src1: *Res, |
7701 | Flags: MI.getFlags()); |
7702 | |
7703 | Builder.buildCopy(Res: Dst, Op: *Res); |
7704 | MI.eraseFromParent(); |
7705 | } |
7706 | |
7707 | bool CombinerHelper::matchFoldAPlusC1MinusC2(const MachineInstr &MI, |
7708 | BuildFnTy &MatchInfo) const { |
7709 | // fold (A+C1)-C2 -> A+(C1-C2) |
7710 | const GSub *Sub = cast<GSub>(Val: &MI); |
7711 | GAdd *Add = cast<GAdd>(Val: MRI.getVRegDef(Reg: Sub->getLHSReg())); |
7712 | |
7713 | if (!MRI.hasOneNonDBGUse(RegNo: Add->getReg(Idx: 0))) |
7714 | return false; |
7715 | |
7716 | APInt C2 = getIConstantFromReg(VReg: Sub->getRHSReg(), MRI); |
7717 | APInt C1 = getIConstantFromReg(VReg: Add->getRHSReg(), MRI); |
7718 | |
7719 | Register Dst = Sub->getReg(Idx: 0); |
7720 | LLT DstTy = MRI.getType(Reg: Dst); |
7721 | |
7722 | MatchInfo = [=](MachineIRBuilder &B) { |
7723 | auto Const = B.buildConstant(Res: DstTy, Val: C1 - C2); |
7724 | B.buildAdd(Dst, Src0: Add->getLHSReg(), Src1: Const); |
7725 | }; |
7726 | |
7727 | return true; |
7728 | } |
7729 | |
7730 | bool CombinerHelper::matchFoldC2MinusAPlusC1(const MachineInstr &MI, |
7731 | BuildFnTy &MatchInfo) const { |
7732 | // fold C2-(A+C1) -> (C2-C1)-A |
7733 | const GSub *Sub = cast<GSub>(Val: &MI); |
7734 | GAdd *Add = cast<GAdd>(Val: MRI.getVRegDef(Reg: Sub->getRHSReg())); |
7735 | |
7736 | if (!MRI.hasOneNonDBGUse(RegNo: Add->getReg(Idx: 0))) |
7737 | return false; |
7738 | |
7739 | APInt C2 = getIConstantFromReg(VReg: Sub->getLHSReg(), MRI); |
7740 | APInt C1 = getIConstantFromReg(VReg: Add->getRHSReg(), MRI); |
7741 | |
7742 | Register Dst = Sub->getReg(Idx: 0); |
7743 | LLT DstTy = MRI.getType(Reg: Dst); |
7744 | |
7745 | MatchInfo = [=](MachineIRBuilder &B) { |
7746 | auto Const = B.buildConstant(Res: DstTy, Val: C2 - C1); |
7747 | B.buildSub(Dst, Src0: Const, Src1: Add->getLHSReg()); |
7748 | }; |
7749 | |
7750 | return true; |
7751 | } |
7752 | |
7753 | bool CombinerHelper::matchFoldAMinusC1MinusC2(const MachineInstr &MI, |
7754 | BuildFnTy &MatchInfo) const { |
7755 | // fold (A-C1)-C2 -> A-(C1+C2) |
7756 | const GSub *Sub1 = cast<GSub>(Val: &MI); |
7757 | GSub *Sub2 = cast<GSub>(Val: MRI.getVRegDef(Reg: Sub1->getLHSReg())); |
7758 | |
7759 | if (!MRI.hasOneNonDBGUse(RegNo: Sub2->getReg(Idx: 0))) |
7760 | return false; |
7761 | |
7762 | APInt C2 = getIConstantFromReg(VReg: Sub1->getRHSReg(), MRI); |
7763 | APInt C1 = getIConstantFromReg(VReg: Sub2->getRHSReg(), MRI); |
7764 | |
7765 | Register Dst = Sub1->getReg(Idx: 0); |
7766 | LLT DstTy = MRI.getType(Reg: Dst); |
7767 | |
7768 | MatchInfo = [=](MachineIRBuilder &B) { |
7769 | auto Const = B.buildConstant(Res: DstTy, Val: C1 + C2); |
7770 | B.buildSub(Dst, Src0: Sub2->getLHSReg(), Src1: Const); |
7771 | }; |
7772 | |
7773 | return true; |
7774 | } |
7775 | |
7776 | bool CombinerHelper::matchFoldC1Minus2MinusC2(const MachineInstr &MI, |
7777 | BuildFnTy &MatchInfo) const { |
7778 | // fold (C1-A)-C2 -> (C1-C2)-A |
7779 | const GSub *Sub1 = cast<GSub>(Val: &MI); |
7780 | GSub *Sub2 = cast<GSub>(Val: MRI.getVRegDef(Reg: Sub1->getLHSReg())); |
7781 | |
7782 | if (!MRI.hasOneNonDBGUse(RegNo: Sub2->getReg(Idx: 0))) |
7783 | return false; |
7784 | |
7785 | APInt C2 = getIConstantFromReg(VReg: Sub1->getRHSReg(), MRI); |
7786 | APInt C1 = getIConstantFromReg(VReg: Sub2->getLHSReg(), MRI); |
7787 | |
7788 | Register Dst = Sub1->getReg(Idx: 0); |
7789 | LLT DstTy = MRI.getType(Reg: Dst); |
7790 | |
7791 | MatchInfo = [=](MachineIRBuilder &B) { |
7792 | auto Const = B.buildConstant(Res: DstTy, Val: C1 - C2); |
7793 | B.buildSub(Dst, Src0: Const, Src1: Sub2->getRHSReg()); |
7794 | }; |
7795 | |
7796 | return true; |
7797 | } |
7798 | |
7799 | bool CombinerHelper::matchFoldAMinusC1PlusC2(const MachineInstr &MI, |
7800 | BuildFnTy &MatchInfo) const { |
7801 | // fold ((A-C1)+C2) -> (A+(C2-C1)) |
7802 | const GAdd *Add = cast<GAdd>(Val: &MI); |
7803 | GSub *Sub = cast<GSub>(Val: MRI.getVRegDef(Reg: Add->getLHSReg())); |
7804 | |
7805 | if (!MRI.hasOneNonDBGUse(RegNo: Sub->getReg(Idx: 0))) |
7806 | return false; |
7807 | |
7808 | APInt C2 = getIConstantFromReg(VReg: Add->getRHSReg(), MRI); |
7809 | APInt C1 = getIConstantFromReg(VReg: Sub->getRHSReg(), MRI); |
7810 | |
7811 | Register Dst = Add->getReg(Idx: 0); |
7812 | LLT DstTy = MRI.getType(Reg: Dst); |
7813 | |
7814 | MatchInfo = [=](MachineIRBuilder &B) { |
7815 | auto Const = B.buildConstant(Res: DstTy, Val: C2 - C1); |
7816 | B.buildAdd(Dst, Src0: Sub->getLHSReg(), Src1: Const); |
7817 | }; |
7818 | |
7819 | return true; |
7820 | } |
7821 | |
7822 | bool CombinerHelper::matchUnmergeValuesAnyExtBuildVector( |
7823 | const MachineInstr &MI, BuildFnTy &MatchInfo) const { |
7824 | const GUnmerge *Unmerge = cast<GUnmerge>(Val: &MI); |
7825 | |
7826 | if (!MRI.hasOneNonDBGUse(RegNo: Unmerge->getSourceReg())) |
7827 | return false; |
7828 | |
7829 | const MachineInstr *Source = MRI.getVRegDef(Reg: Unmerge->getSourceReg()); |
7830 | |
7831 | LLT DstTy = MRI.getType(Reg: Unmerge->getReg(Idx: 0)); |
7832 | |
7833 | // $bv:_(<8 x s8>) = G_BUILD_VECTOR .... |
7834 | // $any:_(<8 x s16>) = G_ANYEXT $bv |
7835 | // $uv:_(<4 x s16>), $uv1:_(<4 x s16>) = G_UNMERGE_VALUES $any |
7836 | // |
7837 | // -> |
7838 | // |
7839 | // $any:_(s16) = G_ANYEXT $bv[0] |
7840 | // $any1:_(s16) = G_ANYEXT $bv[1] |
7841 | // $any2:_(s16) = G_ANYEXT $bv[2] |
7842 | // $any3:_(s16) = G_ANYEXT $bv[3] |
7843 | // $any4:_(s16) = G_ANYEXT $bv[4] |
7844 | // $any5:_(s16) = G_ANYEXT $bv[5] |
7845 | // $any6:_(s16) = G_ANYEXT $bv[6] |
7846 | // $any7:_(s16) = G_ANYEXT $bv[7] |
7847 | // $uv:_(<4 x s16>) = G_BUILD_VECTOR $any, $any1, $any2, $any3 |
7848 | // $uv1:_(<4 x s16>) = G_BUILD_VECTOR $any4, $any5, $any6, $any7 |
7849 | |
7850 | // We want to unmerge into vectors. |
7851 | if (!DstTy.isFixedVector()) |
7852 | return false; |
7853 | |
7854 | const GAnyExt *Any = dyn_cast<GAnyExt>(Val: Source); |
7855 | if (!Any) |
7856 | return false; |
7857 | |
7858 | const MachineInstr *NextSource = MRI.getVRegDef(Reg: Any->getSrcReg()); |
7859 | |
7860 | if (const GBuildVector *BV = dyn_cast<GBuildVector>(Val: NextSource)) { |
7861 | // G_UNMERGE_VALUES G_ANYEXT G_BUILD_VECTOR |
7862 | |
7863 | if (!MRI.hasOneNonDBGUse(RegNo: BV->getReg(Idx: 0))) |
7864 | return false; |
7865 | |
7866 | // FIXME: check element types? |
7867 | if (BV->getNumSources() % Unmerge->getNumDefs() != 0) |
7868 | return false; |
7869 | |
7870 | LLT BigBvTy = MRI.getType(Reg: BV->getReg(Idx: 0)); |
7871 | LLT SmallBvTy = DstTy; |
7872 | LLT SmallBvElemenTy = SmallBvTy.getElementType(); |
7873 | |
7874 | if (!isLegalOrBeforeLegalizer( |
7875 | Query: {TargetOpcode::G_BUILD_VECTOR, {SmallBvTy, SmallBvElemenTy}})) |
7876 | return false; |
7877 | |
7878 | // We check the legality of scalar anyext. |
7879 | if (!isLegalOrBeforeLegalizer( |
7880 | Query: {TargetOpcode::G_ANYEXT, |
7881 | {SmallBvElemenTy, BigBvTy.getElementType()}})) |
7882 | return false; |
7883 | |
7884 | MatchInfo = [=](MachineIRBuilder &B) { |
7885 | // Build into each G_UNMERGE_VALUES def |
7886 | // a small build vector with anyext from the source build vector. |
7887 | for (unsigned I = 0; I < Unmerge->getNumDefs(); ++I) { |
7888 | SmallVector<Register> Ops; |
7889 | for (unsigned J = 0; J < SmallBvTy.getNumElements(); ++J) { |
7890 | Register SourceArray = |
7891 | BV->getSourceReg(I: I * SmallBvTy.getNumElements() + J); |
7892 | auto AnyExt = B.buildAnyExt(Res: SmallBvElemenTy, Op: SourceArray); |
7893 | Ops.push_back(Elt: AnyExt.getReg(Idx: 0)); |
7894 | } |
7895 | B.buildBuildVector(Res: Unmerge->getOperand(i: I).getReg(), Ops); |
7896 | }; |
7897 | }; |
7898 | return true; |
7899 | }; |
7900 | |
7901 | return false; |
7902 | } |
7903 | |
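// Rewrite mask indices that select from the second shuffle source, which the
// combine's pattern expects to be undef, to -1 (undef lane).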
7904 | bool CombinerHelper::matchShuffleUndefRHS(MachineInstr &MI, |
7905 | BuildFnTy &MatchInfo) const { |
7906 | |
7907 | bool Changed = false; |
7908 | auto &Shuffle = cast<GShuffleVector>(Val&: MI); |
7909 | ArrayRef<int> OrigMask = Shuffle.getMask(); |
7910 | SmallVector<int, 16> NewMask; |
7911 | const LLT SrcTy = MRI.getType(Reg: Shuffle.getSrc1Reg()); |
7912 | const unsigned NumSrcElems = SrcTy.isVector() ? SrcTy.getNumElements() : 1; |
7913 | const unsigned NumDstElts = OrigMask.size(); |
7914 | for (unsigned i = 0; i != NumDstElts; ++i) { |
7915 | int Idx = OrigMask[i]; |
7916 | if (Idx >= (int)NumSrcElems) { |
7917 | Idx = -1; |
7918 | Changed = true; |
7919 | } |
7920 | NewMask.push_back(Elt: Idx); |
7921 | } |
7922 | |
7923 | if (!Changed) |
7924 | return false; |
7925 | |
7926 | MatchInfo = [&, NewMask = std::move(NewMask)](MachineIRBuilder &B) { |
7927 | B.buildShuffleVector(Res: MI.getOperand(i: 0), Src1: MI.getOperand(i: 1), Src2: MI.getOperand(i: 2), |
7928 | Mask: std::move(NewMask)); |
7929 | }; |
7930 | |
7931 | return true; |
7932 | } |
7933 | |
static void commuteMask(MutableArrayRef<int> Mask, const unsigned NumElems) {
  const unsigned MaskSize = Mask.size();
  for (unsigned I = 0; I < MaskSize; ++I) {
    int Idx = Mask[I];
    if (Idx < 0)
      continue;

    if (Idx < (int)NumElems)
      Mask[I] = Idx + NumElems;
    else
      Mask[I] = Idx - NumElems;
  }
}

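// If a shuffle only reads from one of its two sources, rewrite it so the live
// source is the first operand and the second operand is undef. Illustrative
// example (register names invented):
//   %d:_(<4 x s32>) = G_SHUFFLE_VECTOR %a(<4 x s32>), %b,
//                                      shufflemask(4, 6, 5, 7)
// only reads %b, so it can be rewritten as
//   %u:_(<4 x s32>) = G_IMPLICIT_DEF
//   %d:_(<4 x s32>) = G_SHUFFLE_VECTOR %b(<4 x s32>), %u,
//                                      shufflemask(0, 2, 1, 3)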
bool CombinerHelper::matchShuffleDisjointMask(MachineInstr &MI,
                                              BuildFnTy &MatchInfo) const {

  auto &Shuffle = cast<GShuffleVector>(MI);
  // If either input is already undef, don't check the mask again to prevent an
  // infinite loop.
  if (getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, Shuffle.getSrc1Reg(), MRI))
    return false;

  if (getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, Shuffle.getSrc2Reg(), MRI))
    return false;

  const LLT DstTy = MRI.getType(Shuffle.getReg(0));
  const LLT Src1Ty = MRI.getType(Shuffle.getSrc1Reg());
  if (!isLegalOrBeforeLegalizer(
          {TargetOpcode::G_SHUFFLE_VECTOR, {DstTy, Src1Ty}}))
    return false;

  ArrayRef<int> Mask = Shuffle.getMask();
  const unsigned NumSrcElems = Src1Ty.isVector() ? Src1Ty.getNumElements() : 1;

  bool TouchesSrc1 = false;
  bool TouchesSrc2 = false;
  const unsigned NumElems = Mask.size();
  for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
    if (Mask[Idx] < 0)
      continue;

    if (Mask[Idx] < (int)NumSrcElems)
      TouchesSrc1 = true;
    else
      TouchesSrc2 = true;
  }

  if (TouchesSrc1 == TouchesSrc2)
    return false;

  Register NewSrc1 = Shuffle.getSrc1Reg();
  SmallVector<int, 16> NewMask(Mask);
  if (TouchesSrc2) {
    NewSrc1 = Shuffle.getSrc2Reg();
    commuteMask(NewMask, NumSrcElems);
  }

  MatchInfo = [=, &Shuffle](MachineIRBuilder &B) {
    auto Undef = B.buildUndef(Src1Ty);
    B.buildShuffleVector(Shuffle.getReg(0), NewSrc1, Undef, NewMask);
  };

  return true;
}

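// Fold G_USUBO/G_SSUBO when value tracking can decide the carry-out bit.
// Illustrative example (assuming known bits prove %x u>= %y; register names
// invented):
//   %d:_(s32), %c:_(s1) = G_USUBO %x, %y
// folds to
//   %d:_(s32) = nuw G_SUB %x, %y
//   %c:_(s1) = G_CONSTANT i1 0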
bool CombinerHelper::matchSuboCarryOut(const MachineInstr &MI,
                                       BuildFnTy &MatchInfo) const {
  const GSubCarryOut *Subo = cast<GSubCarryOut>(&MI);

  Register Dst = Subo->getReg(0);
  Register LHS = Subo->getLHSReg();
  Register RHS = Subo->getRHSReg();
  Register Carry = Subo->getCarryOutReg();
  LLT DstTy = MRI.getType(Dst);
  LLT CarryTy = MRI.getType(Carry);

  // Check legality before querying known bits.
  if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SUB, {DstTy}}) ||
      !isConstantLegalOrBeforeLegalizer(CarryTy))
    return false;

  ConstantRange KBLHS =
      ConstantRange::fromKnownBits(VT->getKnownBits(LHS),
                                   /*IsSigned=*/Subo->isSigned());
  ConstantRange KBRHS =
      ConstantRange::fromKnownBits(VT->getKnownBits(RHS),
                                   /*IsSigned=*/Subo->isSigned());

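  // Sketch of the reasoning (values illustrative): if the known bits give
  // LHS in [16, 31] and RHS in [0, 15], unsigned subtraction can never wrap,
  // so the carry-out is known to be 0; conversely, if the ranges guarantee a
  // wrap, the carry-out is known to be the target's "true" value.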
  if (Subo->isSigned()) {
    // G_SSUBO
    switch (KBLHS.signedSubMayOverflow(KBRHS)) {
    case ConstantRange::OverflowResult::MayOverflow:
      return false;
    case ConstantRange::OverflowResult::NeverOverflows: {
      MatchInfo = [=](MachineIRBuilder &B) {
        B.buildSub(Dst, LHS, RHS, MachineInstr::MIFlag::NoSWrap);
        B.buildConstant(Carry, 0);
      };
      return true;
    }
    case ConstantRange::OverflowResult::AlwaysOverflowsLow:
    case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
      MatchInfo = [=](MachineIRBuilder &B) {
        B.buildSub(Dst, LHS, RHS);
        B.buildConstant(Carry, getICmpTrueVal(getTargetLowering(),
                                              /*isVector=*/CarryTy.isVector(),
                                              /*isFP=*/false));
      };
      return true;
    }
    }
    return false;
  }

  // G_USUBO
  switch (KBLHS.unsignedSubMayOverflow(KBRHS)) {
  case ConstantRange::OverflowResult::MayOverflow:
    return false;
  case ConstantRange::OverflowResult::NeverOverflows: {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.buildSub(Dst, LHS, RHS, MachineInstr::MIFlag::NoUWrap);
      B.buildConstant(Carry, 0);
    };
    return true;
  }
  case ConstantRange::OverflowResult::AlwaysOverflowsLow:
  case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.buildSub(Dst, LHS, RHS);
      B.buildConstant(Carry, getICmpTrueVal(getTargetLowering(),
                                            /*isVector=*/CarryTy.isVector(),
                                            /*isFP=*/false));
    };
    return true;
  }
  }

  return false;
}
