1 | //===- CombinerHelperCasts.cpp---------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements CombinerHelper for G_ANYEXT, G_SEXT, G_TRUNC, and |
10 | // G_ZEXT |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | #include "llvm/CodeGen/GlobalISel/CombinerHelper.h" |
14 | #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" |
15 | #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h" |
16 | #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" |
17 | #include "llvm/CodeGen/GlobalISel/Utils.h" |
18 | #include "llvm/CodeGen/LowLevelTypeUtils.h" |
19 | #include "llvm/CodeGen/MachineOperand.h" |
20 | #include "llvm/CodeGen/MachineRegisterInfo.h" |
21 | #include "llvm/CodeGen/TargetOpcodes.h" |
22 | #include "llvm/Support/Casting.h" |
23 | |
24 | #define DEBUG_TYPE "gi-combiner" |
25 | |
26 | using namespace llvm; |
27 | |
28 | bool CombinerHelper::matchSextOfTrunc(const MachineOperand &MO, |
29 | BuildFnTy &MatchInfo) const { |
30 | GSext *Sext = cast<GSext>(Val: getDefIgnoringCopies(Reg: MO.getReg(), MRI)); |
31 | GTrunc *Trunc = cast<GTrunc>(Val: getDefIgnoringCopies(Reg: Sext->getSrcReg(), MRI)); |
32 | |
33 | Register Dst = Sext->getReg(Idx: 0); |
34 | Register Src = Trunc->getSrcReg(); |
35 | |
36 | LLT DstTy = MRI.getType(Reg: Dst); |
37 | LLT SrcTy = MRI.getType(Reg: Src); |
38 | |
39 | // Combines without nsw trunc. |
40 | if (!Trunc->getFlag(Flag: MachineInstr::NoSWrap)) { |
41 | if (DstTy != SrcTy || |
42 | !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SEXT_INREG, {DstTy, SrcTy}})) |
43 | return false; |
44 | |
45 | // Do this for 8 bit values and up. We don't want to do it for e.g. G_TRUNC |
46 | // to i1. |
47 | unsigned TruncWidth = MRI.getType(Reg: Trunc->getReg(Idx: 0)).getScalarSizeInBits(); |
48 | if (TruncWidth < 8) |
49 | return false; |
50 | |
51 | MatchInfo = [=](MachineIRBuilder &B) { |
52 | B.buildSExtInReg(Res: Dst, Op: Src, ImmOp: TruncWidth); |
53 | }; |
54 | return true; |
55 | } |
56 | |
57 | // Combines for nsw trunc. |
58 | |
59 | if (DstTy == SrcTy) { |
60 | MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Res: Dst, Op: Src); }; |
61 | return true; |
62 | } |
63 | |
64 | if (DstTy.getScalarSizeInBits() < SrcTy.getScalarSizeInBits() && |
65 | isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) { |
66 | MatchInfo = [=](MachineIRBuilder &B) { |
67 | B.buildTrunc(Res: Dst, Op: Src, Flags: MachineInstr::MIFlag::NoSWrap); |
68 | }; |
69 | return true; |
70 | } |
71 | |
72 | if (DstTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits() && |
73 | isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SEXT, {DstTy, SrcTy}})) { |
74 | MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Res: Dst, Op: Src); }; |
75 | return true; |
76 | } |
77 | |
78 | return false; |
79 | } |
80 | |
81 | bool CombinerHelper::matchZextOfTrunc(const MachineOperand &MO, |
82 | BuildFnTy &MatchInfo) const { |
83 | GZext *Zext = cast<GZext>(Val: getDefIgnoringCopies(Reg: MO.getReg(), MRI)); |
84 | GTrunc *Trunc = cast<GTrunc>(Val: getDefIgnoringCopies(Reg: Zext->getSrcReg(), MRI)); |
85 | |
86 | Register Dst = Zext->getReg(Idx: 0); |
87 | Register Src = Trunc->getSrcReg(); |
88 | |
89 | LLT DstTy = MRI.getType(Reg: Dst); |
90 | LLT SrcTy = MRI.getType(Reg: Src); |
91 | |
92 | if (DstTy == SrcTy) { |
93 | MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Res: Dst, Op: Src); }; |
94 | return true; |
95 | } |
96 | |
97 | if (DstTy.getScalarSizeInBits() < SrcTy.getScalarSizeInBits() && |
98 | isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) { |
99 | MatchInfo = [=](MachineIRBuilder &B) { |
100 | B.buildTrunc(Res: Dst, Op: Src, Flags: MachineInstr::MIFlag::NoUWrap); |
101 | }; |
102 | return true; |
103 | } |
104 | |
105 | if (DstTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits() && |
106 | isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ZEXT, {DstTy, SrcTy}})) { |
107 | MatchInfo = [=](MachineIRBuilder &B) { |
108 | B.buildZExt(Res: Dst, Op: Src, Flags: MachineInstr::MIFlag::NonNeg); |
109 | }; |
110 | return true; |
111 | } |
112 | |
113 | return false; |
114 | } |
115 | |
116 | bool CombinerHelper::matchNonNegZext(const MachineOperand &MO, |
117 | BuildFnTy &MatchInfo) const { |
118 | GZext *Zext = cast<GZext>(Val: MRI.getVRegDef(Reg: MO.getReg())); |
119 | |
120 | Register Dst = Zext->getReg(Idx: 0); |
121 | Register Src = Zext->getSrcReg(); |
122 | |
123 | LLT DstTy = MRI.getType(Reg: Dst); |
124 | LLT SrcTy = MRI.getType(Reg: Src); |
125 | const auto &TLI = getTargetLowering(); |
126 | |
127 | // Convert zext nneg to sext if sext is the preferred form for the target. |
128 | if (isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SEXT, {DstTy, SrcTy}}) && |
129 | TLI.isSExtCheaperThanZExt(FromTy: getMVTForLLT(Ty: SrcTy), ToTy: getMVTForLLT(Ty: DstTy))) { |
130 | MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Res: Dst, Op: Src); }; |
131 | return true; |
132 | } |
133 | |
134 | return false; |
135 | } |
136 | |
137 | bool CombinerHelper::matchTruncateOfExt(const MachineInstr &Root, |
138 | const MachineInstr &ExtMI, |
139 | BuildFnTy &MatchInfo) const { |
140 | const GTrunc *Trunc = cast<GTrunc>(Val: &Root); |
141 | const GExtOp *Ext = cast<GExtOp>(Val: &ExtMI); |
142 | |
143 | if (!MRI.hasOneNonDBGUse(RegNo: Ext->getReg(Idx: 0))) |
144 | return false; |
145 | |
146 | Register Dst = Trunc->getReg(Idx: 0); |
147 | Register Src = Ext->getSrcReg(); |
148 | LLT DstTy = MRI.getType(Reg: Dst); |
149 | LLT SrcTy = MRI.getType(Reg: Src); |
150 | |
151 | if (SrcTy == DstTy) { |
152 | // The source and the destination are equally sized. We need to copy. |
153 | MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Res: Dst, Op: Src); }; |
154 | |
155 | return true; |
156 | } |
157 | |
158 | if (SrcTy.getScalarSizeInBits() < DstTy.getScalarSizeInBits()) { |
159 | // If the source is smaller than the destination, we need to extend. |
160 | |
161 | if (!isLegalOrBeforeLegalizer(Query: {Ext->getOpcode(), {DstTy, SrcTy}})) |
162 | return false; |
163 | |
164 | MatchInfo = [=](MachineIRBuilder &B) { |
165 | B.buildInstr(Opc: Ext->getOpcode(), DstOps: {Dst}, SrcOps: {Src}); |
166 | }; |
167 | |
168 | return true; |
169 | } |
170 | |
171 | if (SrcTy.getScalarSizeInBits() > DstTy.getScalarSizeInBits()) { |
172 | // If the source is larger than the destination, then we need to truncate. |
173 | |
174 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) |
175 | return false; |
176 | |
177 | MatchInfo = [=](MachineIRBuilder &B) { B.buildTrunc(Res: Dst, Op: Src); }; |
178 | |
179 | return true; |
180 | } |
181 | |
182 | return false; |
183 | } |
184 | |
185 | bool CombinerHelper::isCastFree(unsigned Opcode, LLT ToTy, LLT FromTy) const { |
186 | const TargetLowering &TLI = getTargetLowering(); |
187 | LLVMContext &Ctx = getContext(); |
188 | |
189 | switch (Opcode) { |
190 | case TargetOpcode::G_ANYEXT: |
191 | case TargetOpcode::G_ZEXT: |
192 | return TLI.isZExtFree(FromTy, ToTy, Ctx); |
193 | case TargetOpcode::G_TRUNC: |
194 | return TLI.isTruncateFree(FromTy, ToTy, Ctx); |
195 | default: |
196 | return false; |
197 | } |
198 | } |
199 | |
200 | bool CombinerHelper::matchCastOfSelect(const MachineInstr &CastMI, |
201 | const MachineInstr &SelectMI, |
202 | BuildFnTy &MatchInfo) const { |
203 | const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(Val: &CastMI); |
204 | const GSelect *Select = cast<GSelect>(Val: &SelectMI); |
205 | |
206 | if (!MRI.hasOneNonDBGUse(RegNo: Select->getReg(Idx: 0))) |
207 | return false; |
208 | |
209 | Register Dst = Cast->getReg(Idx: 0); |
210 | LLT DstTy = MRI.getType(Reg: Dst); |
211 | LLT CondTy = MRI.getType(Reg: Select->getCondReg()); |
212 | Register TrueReg = Select->getTrueReg(); |
213 | Register FalseReg = Select->getFalseReg(); |
214 | LLT SrcTy = MRI.getType(Reg: TrueReg); |
215 | Register Cond = Select->getCondReg(); |
216 | |
217 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SELECT, {DstTy, CondTy}})) |
218 | return false; |
219 | |
220 | if (!isCastFree(Opcode: Cast->getOpcode(), ToTy: DstTy, FromTy: SrcTy)) |
221 | return false; |
222 | |
223 | MatchInfo = [=](MachineIRBuilder &B) { |
224 | auto True = B.buildInstr(Opc: Cast->getOpcode(), DstOps: {DstTy}, SrcOps: {TrueReg}); |
225 | auto False = B.buildInstr(Opc: Cast->getOpcode(), DstOps: {DstTy}, SrcOps: {FalseReg}); |
226 | B.buildSelect(Res: Dst, Tst: Cond, Op0: True, Op1: False); |
227 | }; |
228 | |
229 | return true; |
230 | } |
231 | |
232 | bool CombinerHelper::matchExtOfExt(const MachineInstr &FirstMI, |
233 | const MachineInstr &SecondMI, |
234 | BuildFnTy &MatchInfo) const { |
235 | const GExtOp *First = cast<GExtOp>(Val: &FirstMI); |
236 | const GExtOp *Second = cast<GExtOp>(Val: &SecondMI); |
237 | |
238 | Register Dst = First->getReg(Idx: 0); |
239 | Register Src = Second->getSrcReg(); |
240 | LLT DstTy = MRI.getType(Reg: Dst); |
241 | LLT SrcTy = MRI.getType(Reg: Src); |
242 | |
243 | if (!MRI.hasOneNonDBGUse(RegNo: Second->getReg(Idx: 0))) |
244 | return false; |
245 | |
246 | // ext of ext -> later ext |
247 | if (First->getOpcode() == Second->getOpcode() && |
248 | isLegalOrBeforeLegalizer(Query: {Second->getOpcode(), {DstTy, SrcTy}})) { |
249 | if (Second->getOpcode() == TargetOpcode::G_ZEXT) { |
250 | MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags; |
251 | if (Second->getFlag(Flag: MachineInstr::MIFlag::NonNeg)) |
252 | Flag = MachineInstr::MIFlag::NonNeg; |
253 | MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Res: Dst, Op: Src, Flags: Flag); }; |
254 | return true; |
255 | } |
256 | // not zext -> no flags |
257 | MatchInfo = [=](MachineIRBuilder &B) { |
258 | B.buildInstr(Opc: Second->getOpcode(), DstOps: {Dst}, SrcOps: {Src}); |
259 | }; |
260 | return true; |
261 | } |
262 | |
263 | // anyext of sext/zext -> sext/zext |
264 | // -> pick anyext as second ext, then ext of ext |
265 | if (First->getOpcode() == TargetOpcode::G_ANYEXT && |
266 | isLegalOrBeforeLegalizer(Query: {Second->getOpcode(), {DstTy, SrcTy}})) { |
267 | if (Second->getOpcode() == TargetOpcode::G_ZEXT) { |
268 | MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags; |
269 | if (Second->getFlag(Flag: MachineInstr::MIFlag::NonNeg)) |
270 | Flag = MachineInstr::MIFlag::NonNeg; |
271 | MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Res: Dst, Op: Src, Flags: Flag); }; |
272 | return true; |
273 | } |
274 | MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Res: Dst, Op: Src); }; |
275 | return true; |
276 | } |
277 | |
278 | // sext/zext of anyext -> sext/zext |
279 | // -> pick anyext as first ext, then ext of ext |
280 | if (Second->getOpcode() == TargetOpcode::G_ANYEXT && |
281 | isLegalOrBeforeLegalizer(Query: {First->getOpcode(), {DstTy, SrcTy}})) { |
282 | if (First->getOpcode() == TargetOpcode::G_ZEXT) { |
283 | MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags; |
284 | if (First->getFlag(Flag: MachineInstr::MIFlag::NonNeg)) |
285 | Flag = MachineInstr::MIFlag::NonNeg; |
286 | MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Res: Dst, Op: Src, Flags: Flag); }; |
287 | return true; |
288 | } |
289 | MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Res: Dst, Op: Src); }; |
290 | return true; |
291 | } |
292 | |
293 | return false; |
294 | } |
295 | |
296 | bool CombinerHelper::matchCastOfBuildVector(const MachineInstr &CastMI, |
297 | const MachineInstr &BVMI, |
298 | BuildFnTy &MatchInfo) const { |
299 | const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(Val: &CastMI); |
300 | const GBuildVector *BV = cast<GBuildVector>(Val: &BVMI); |
301 | |
302 | if (!MRI.hasOneNonDBGUse(RegNo: BV->getReg(Idx: 0))) |
303 | return false; |
304 | |
305 | Register Dst = Cast->getReg(Idx: 0); |
306 | // The type of the new build vector. |
307 | LLT DstTy = MRI.getType(Reg: Dst); |
308 | // The scalar or element type of the new build vector. |
309 | LLT ElemTy = DstTy.getScalarType(); |
310 | // The scalar or element type of the old build vector. |
311 | LLT InputElemTy = MRI.getType(Reg: BV->getReg(Idx: 0)).getElementType(); |
312 | |
313 | // Check legality of new build vector, the scalar casts, and profitability of |
314 | // the many casts. |
315 | if (!isLegalOrBeforeLegalizer( |
316 | Query: {TargetOpcode::G_BUILD_VECTOR, {DstTy, ElemTy}}) || |
317 | !isLegalOrBeforeLegalizer(Query: {Cast->getOpcode(), {ElemTy, InputElemTy}}) || |
318 | !isCastFree(Opcode: Cast->getOpcode(), ToTy: ElemTy, FromTy: InputElemTy)) |
319 | return false; |
320 | |
321 | MatchInfo = [=](MachineIRBuilder &B) { |
322 | SmallVector<Register> Casts; |
323 | unsigned Elements = BV->getNumSources(); |
324 | for (unsigned I = 0; I < Elements; ++I) { |
325 | auto CastI = |
326 | B.buildInstr(Opc: Cast->getOpcode(), DstOps: {ElemTy}, SrcOps: {BV->getSourceReg(I)}); |
327 | Casts.push_back(Elt: CastI.getReg(Idx: 0)); |
328 | } |
329 | |
330 | B.buildBuildVector(Res: Dst, Ops: Casts); |
331 | }; |
332 | |
333 | return true; |
334 | } |
335 | |
336 | bool CombinerHelper::matchNarrowBinop(const MachineInstr &TruncMI, |
337 | const MachineInstr &BinopMI, |
338 | BuildFnTy &MatchInfo) const { |
339 | const GTrunc *Trunc = cast<GTrunc>(Val: &TruncMI); |
340 | const GBinOp *BinOp = cast<GBinOp>(Val: &BinopMI); |
341 | |
342 | if (!MRI.hasOneNonDBGUse(RegNo: BinOp->getReg(Idx: 0))) |
343 | return false; |
344 | |
345 | Register Dst = Trunc->getReg(Idx: 0); |
346 | LLT DstTy = MRI.getType(Reg: Dst); |
347 | |
348 | // Is narrow binop legal? |
349 | if (!isLegalOrBeforeLegalizer(Query: {BinOp->getOpcode(), {DstTy}})) |
350 | return false; |
351 | |
352 | MatchInfo = [=](MachineIRBuilder &B) { |
353 | auto LHS = B.buildTrunc(Res: DstTy, Op: BinOp->getLHSReg()); |
354 | auto RHS = B.buildTrunc(Res: DstTy, Op: BinOp->getRHSReg()); |
355 | B.buildInstr(Opc: BinOp->getOpcode(), DstOps: {Dst}, SrcOps: {LHS, RHS}); |
356 | }; |
357 | |
358 | return true; |
359 | } |
360 | |
361 | bool CombinerHelper::matchCastOfInteger(const MachineInstr &CastMI, |
362 | APInt &MatchInfo) const { |
363 | const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(Val: &CastMI); |
364 | |
365 | APInt Input = getIConstantFromReg(VReg: Cast->getSrcReg(), MRI); |
366 | |
367 | LLT DstTy = MRI.getType(Reg: Cast->getReg(Idx: 0)); |
368 | |
369 | if (!isConstantLegalOrBeforeLegalizer(Ty: DstTy)) |
370 | return false; |
371 | |
372 | switch (Cast->getOpcode()) { |
373 | case TargetOpcode::G_TRUNC: { |
374 | MatchInfo = Input.trunc(width: DstTy.getScalarSizeInBits()); |
375 | return true; |
376 | } |
377 | default: |
378 | return false; |
379 | } |
380 | } |
381 | |
382 | bool CombinerHelper::matchRedundantSextInReg(MachineInstr &Root, |
383 | MachineInstr &Other, |
384 | BuildFnTy &MatchInfo) const { |
385 | assert(Root.getOpcode() == TargetOpcode::G_SEXT_INREG && |
386 | Other.getOpcode() == TargetOpcode::G_SEXT_INREG); |
387 | |
388 | unsigned RootWidth = Root.getOperand(i: 2).getImm(); |
389 | unsigned OtherWidth = Other.getOperand(i: 2).getImm(); |
390 | |
391 | Register Dst = Root.getOperand(i: 0).getReg(); |
392 | Register OtherDst = Other.getOperand(i: 0).getReg(); |
393 | Register Src = Other.getOperand(i: 1).getReg(); |
394 | |
395 | if (RootWidth >= OtherWidth) { |
396 | // The root sext_inreg is entirely redundant because the other one |
397 | // is narrower. |
398 | if (!canReplaceReg(DstReg: Dst, SrcReg: OtherDst, MRI)) |
399 | return false; |
400 | |
401 | MatchInfo = [=](MachineIRBuilder &B) { |
402 | Observer.changingAllUsesOfReg(MRI, Reg: Dst); |
403 | MRI.replaceRegWith(FromReg: Dst, ToReg: OtherDst); |
404 | Observer.finishedChangingAllUsesOfReg(); |
405 | }; |
406 | } else { |
407 | // RootWidth < OtherWidth, rewrite this G_SEXT_INREG with the source of the |
408 | // other G_SEXT_INREG. |
409 | MatchInfo = [=](MachineIRBuilder &B) { |
410 | B.buildSExtInReg(Res: Dst, Op: Src, ImmOp: RootWidth); |
411 | }; |
412 | } |
413 | |
414 | return true; |
415 | } |
416 | |