1//===- CombinerHelperCasts.cpp---------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
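// This file implements CombinerHelper combines for the cast opcodes G_ANYEXT,
// G_SEXT, G_TRUNC and G_ZEXT, including casts of G_SELECT, G_BUILD_VECTOR,
// binary operations and integer constants, as well as redundant
// G_SEXT_INREG chains.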
11//
12//===----------------------------------------------------------------------===//
13#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
14#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
15#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
16#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
17#include "llvm/CodeGen/GlobalISel/Utils.h"
18#include "llvm/CodeGen/LowLevelTypeUtils.h"
19#include "llvm/CodeGen/MachineOperand.h"
20#include "llvm/CodeGen/MachineRegisterInfo.h"
21#include "llvm/CodeGen/TargetOpcodes.h"
22#include "llvm/Support/Casting.h"
23
24#define DEBUG_TYPE "gi-combiner"
25
26using namespace llvm;
27
28bool CombinerHelper::matchSextOfTrunc(const MachineOperand &MO,
29 BuildFnTy &MatchInfo) const {
30 GSext *Sext = cast<GSext>(Val: getDefIgnoringCopies(Reg: MO.getReg(), MRI));
31 GTrunc *Trunc = cast<GTrunc>(Val: getDefIgnoringCopies(Reg: Sext->getSrcReg(), MRI));
32
33 Register Dst = Sext->getReg(Idx: 0);
34 Register Src = Trunc->getSrcReg();
35
36 LLT DstTy = MRI.getType(Reg: Dst);
37 LLT SrcTy = MRI.getType(Reg: Src);
38
39 // Combines without nsw trunc.
40 if (!Trunc->getFlag(Flag: MachineInstr::NoSWrap)) {
41 if (DstTy != SrcTy ||
42 !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SEXT_INREG, {DstTy, SrcTy}}))
43 return false;
44
45 // Do this for 8 bit values and up. We don't want to do it for e.g. G_TRUNC
46 // to i1.
47 unsigned TruncWidth = MRI.getType(Reg: Trunc->getReg(Idx: 0)).getScalarSizeInBits();
48 if (TruncWidth < 8)
49 return false;
50
51 MatchInfo = [=](MachineIRBuilder &B) {
52 B.buildSExtInReg(Res: Dst, Op: Src, ImmOp: TruncWidth);
53 };
54 return true;
55 }
56
57 // Combines for nsw trunc.
58
59 if (DstTy == SrcTy) {
60 MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Res: Dst, Op: Src); };
61 return true;
62 }
63
64 if (DstTy.getScalarSizeInBits() < SrcTy.getScalarSizeInBits() &&
65 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) {
66 MatchInfo = [=](MachineIRBuilder &B) {
67 B.buildTrunc(Res: Dst, Op: Src, Flags: MachineInstr::MIFlag::NoSWrap);
68 };
69 return true;
70 }
71
72 if (DstTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits() &&
73 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SEXT, {DstTy, SrcTy}})) {
74 MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Res: Dst, Op: Src); };
75 return true;
76 }
77
78 return false;
79}
80
81bool CombinerHelper::matchZextOfTrunc(const MachineOperand &MO,
82 BuildFnTy &MatchInfo) const {
83 GZext *Zext = cast<GZext>(Val: getDefIgnoringCopies(Reg: MO.getReg(), MRI));
84 GTrunc *Trunc = cast<GTrunc>(Val: getDefIgnoringCopies(Reg: Zext->getSrcReg(), MRI));
85
86 Register Dst = Zext->getReg(Idx: 0);
87 Register Src = Trunc->getSrcReg();
88
89 LLT DstTy = MRI.getType(Reg: Dst);
90 LLT SrcTy = MRI.getType(Reg: Src);
91
92 if (DstTy == SrcTy) {
93 MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Res: Dst, Op: Src); };
94 return true;
95 }
96
97 if (DstTy.getScalarSizeInBits() < SrcTy.getScalarSizeInBits() &&
98 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) {
99 MatchInfo = [=](MachineIRBuilder &B) {
100 B.buildTrunc(Res: Dst, Op: Src, Flags: MachineInstr::MIFlag::NoUWrap);
101 };
102 return true;
103 }
104
105 if (DstTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits() &&
106 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ZEXT, {DstTy, SrcTy}})) {
107 MatchInfo = [=](MachineIRBuilder &B) {
108 B.buildZExt(Res: Dst, Op: Src, Flags: MachineInstr::MIFlag::NonNeg);
109 };
110 return true;
111 }
112
113 return false;
114}
115
116bool CombinerHelper::matchNonNegZext(const MachineOperand &MO,
117 BuildFnTy &MatchInfo) const {
118 GZext *Zext = cast<GZext>(Val: MRI.getVRegDef(Reg: MO.getReg()));
119
120 Register Dst = Zext->getReg(Idx: 0);
121 Register Src = Zext->getSrcReg();
122
123 LLT DstTy = MRI.getType(Reg: Dst);
124 LLT SrcTy = MRI.getType(Reg: Src);
125 const auto &TLI = getTargetLowering();
126
127 // Convert zext nneg to sext if sext is the preferred form for the target.
128 if (isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SEXT, {DstTy, SrcTy}}) &&
129 TLI.isSExtCheaperThanZExt(FromTy: getMVTForLLT(Ty: SrcTy), ToTy: getMVTForLLT(Ty: DstTy))) {
130 MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Res: Dst, Op: Src); };
131 return true;
132 }
133
134 return false;
135}
136
137bool CombinerHelper::matchTruncateOfExt(const MachineInstr &Root,
138 const MachineInstr &ExtMI,
139 BuildFnTy &MatchInfo) const {
140 const GTrunc *Trunc = cast<GTrunc>(Val: &Root);
141 const GExtOp *Ext = cast<GExtOp>(Val: &ExtMI);
142
143 if (!MRI.hasOneNonDBGUse(RegNo: Ext->getReg(Idx: 0)))
144 return false;
145
146 Register Dst = Trunc->getReg(Idx: 0);
147 Register Src = Ext->getSrcReg();
148 LLT DstTy = MRI.getType(Reg: Dst);
149 LLT SrcTy = MRI.getType(Reg: Src);
150
151 if (SrcTy == DstTy) {
152 // The source and the destination are equally sized. We need to copy.
153 MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Res: Dst, Op: Src); };
154
155 return true;
156 }
157
158 if (SrcTy.getScalarSizeInBits() < DstTy.getScalarSizeInBits()) {
159 // If the source is smaller than the destination, we need to extend.
160
161 if (!isLegalOrBeforeLegalizer(Query: {Ext->getOpcode(), {DstTy, SrcTy}}))
162 return false;
163
164 MatchInfo = [=](MachineIRBuilder &B) {
165 B.buildInstr(Opc: Ext->getOpcode(), DstOps: {Dst}, SrcOps: {Src});
166 };
167
168 return true;
169 }
170
171 if (SrcTy.getScalarSizeInBits() > DstTy.getScalarSizeInBits()) {
172 // If the source is larger than the destination, then we need to truncate.
173
174 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_TRUNC, {DstTy, SrcTy}}))
175 return false;
176
177 MatchInfo = [=](MachineIRBuilder &B) { B.buildTrunc(Res: Dst, Op: Src); };
178
179 return true;
180 }
181
182 return false;
183}
184
185bool CombinerHelper::isCastFree(unsigned Opcode, LLT ToTy, LLT FromTy) const {
186 const TargetLowering &TLI = getTargetLowering();
187 LLVMContext &Ctx = getContext();
188
189 switch (Opcode) {
190 case TargetOpcode::G_ANYEXT:
191 case TargetOpcode::G_ZEXT:
192 return TLI.isZExtFree(FromTy, ToTy, Ctx);
193 case TargetOpcode::G_TRUNC:
194 return TLI.isTruncateFree(FromTy, ToTy, Ctx);
195 default:
196 return false;
197 }
198}
199
200bool CombinerHelper::matchCastOfSelect(const MachineInstr &CastMI,
201 const MachineInstr &SelectMI,
202 BuildFnTy &MatchInfo) const {
203 const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(Val: &CastMI);
204 const GSelect *Select = cast<GSelect>(Val: &SelectMI);
205
206 if (!MRI.hasOneNonDBGUse(RegNo: Select->getReg(Idx: 0)))
207 return false;
208
209 Register Dst = Cast->getReg(Idx: 0);
210 LLT DstTy = MRI.getType(Reg: Dst);
211 LLT CondTy = MRI.getType(Reg: Select->getCondReg());
212 Register TrueReg = Select->getTrueReg();
213 Register FalseReg = Select->getFalseReg();
214 LLT SrcTy = MRI.getType(Reg: TrueReg);
215 Register Cond = Select->getCondReg();
216
217 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SELECT, {DstTy, CondTy}}))
218 return false;
219
220 if (!isCastFree(Opcode: Cast->getOpcode(), ToTy: DstTy, FromTy: SrcTy))
221 return false;
222
223 MatchInfo = [=](MachineIRBuilder &B) {
224 auto True = B.buildInstr(Opc: Cast->getOpcode(), DstOps: {DstTy}, SrcOps: {TrueReg});
225 auto False = B.buildInstr(Opc: Cast->getOpcode(), DstOps: {DstTy}, SrcOps: {FalseReg});
226 B.buildSelect(Res: Dst, Tst: Cond, Op0: True, Op1: False);
227 };
228
229 return true;
230}
231
232bool CombinerHelper::matchExtOfExt(const MachineInstr &FirstMI,
233 const MachineInstr &SecondMI,
234 BuildFnTy &MatchInfo) const {
235 const GExtOp *First = cast<GExtOp>(Val: &FirstMI);
236 const GExtOp *Second = cast<GExtOp>(Val: &SecondMI);
237
238 Register Dst = First->getReg(Idx: 0);
239 Register Src = Second->getSrcReg();
240 LLT DstTy = MRI.getType(Reg: Dst);
241 LLT SrcTy = MRI.getType(Reg: Src);
242
243 if (!MRI.hasOneNonDBGUse(RegNo: Second->getReg(Idx: 0)))
244 return false;
245
246 // ext of ext -> later ext
247 if (First->getOpcode() == Second->getOpcode() &&
248 isLegalOrBeforeLegalizer(Query: {Second->getOpcode(), {DstTy, SrcTy}})) {
249 if (Second->getOpcode() == TargetOpcode::G_ZEXT) {
250 MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags;
251 if (Second->getFlag(Flag: MachineInstr::MIFlag::NonNeg))
252 Flag = MachineInstr::MIFlag::NonNeg;
253 MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Res: Dst, Op: Src, Flags: Flag); };
254 return true;
255 }
256 // not zext -> no flags
257 MatchInfo = [=](MachineIRBuilder &B) {
258 B.buildInstr(Opc: Second->getOpcode(), DstOps: {Dst}, SrcOps: {Src});
259 };
260 return true;
261 }
262
263 // anyext of sext/zext -> sext/zext
264 // -> pick anyext as second ext, then ext of ext
265 if (First->getOpcode() == TargetOpcode::G_ANYEXT &&
266 isLegalOrBeforeLegalizer(Query: {Second->getOpcode(), {DstTy, SrcTy}})) {
267 if (Second->getOpcode() == TargetOpcode::G_ZEXT) {
268 MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags;
269 if (Second->getFlag(Flag: MachineInstr::MIFlag::NonNeg))
270 Flag = MachineInstr::MIFlag::NonNeg;
271 MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Res: Dst, Op: Src, Flags: Flag); };
272 return true;
273 }
274 MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Res: Dst, Op: Src); };
275 return true;
276 }
277
278 // sext/zext of anyext -> sext/zext
279 // -> pick anyext as first ext, then ext of ext
280 if (Second->getOpcode() == TargetOpcode::G_ANYEXT &&
281 isLegalOrBeforeLegalizer(Query: {First->getOpcode(), {DstTy, SrcTy}})) {
282 if (First->getOpcode() == TargetOpcode::G_ZEXT) {
283 MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags;
284 if (First->getFlag(Flag: MachineInstr::MIFlag::NonNeg))
285 Flag = MachineInstr::MIFlag::NonNeg;
286 MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Res: Dst, Op: Src, Flags: Flag); };
287 return true;
288 }
289 MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Res: Dst, Op: Src); };
290 return true;
291 }
292
293 return false;
294}
295
296bool CombinerHelper::matchCastOfBuildVector(const MachineInstr &CastMI,
297 const MachineInstr &BVMI,
298 BuildFnTy &MatchInfo) const {
299 const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(Val: &CastMI);
300 const GBuildVector *BV = cast<GBuildVector>(Val: &BVMI);
301
302 if (!MRI.hasOneNonDBGUse(RegNo: BV->getReg(Idx: 0)))
303 return false;
304
305 Register Dst = Cast->getReg(Idx: 0);
306 // The type of the new build vector.
307 LLT DstTy = MRI.getType(Reg: Dst);
308 // The scalar or element type of the new build vector.
309 LLT ElemTy = DstTy.getScalarType();
310 // The scalar or element type of the old build vector.
311 LLT InputElemTy = MRI.getType(Reg: BV->getReg(Idx: 0)).getElementType();
312
313 // Check legality of new build vector, the scalar casts, and profitability of
314 // the many casts.
315 if (!isLegalOrBeforeLegalizer(
316 Query: {TargetOpcode::G_BUILD_VECTOR, {DstTy, ElemTy}}) ||
317 !isLegalOrBeforeLegalizer(Query: {Cast->getOpcode(), {ElemTy, InputElemTy}}) ||
318 !isCastFree(Opcode: Cast->getOpcode(), ToTy: ElemTy, FromTy: InputElemTy))
319 return false;
320
321 MatchInfo = [=](MachineIRBuilder &B) {
322 SmallVector<Register> Casts;
323 unsigned Elements = BV->getNumSources();
324 for (unsigned I = 0; I < Elements; ++I) {
325 auto CastI =
326 B.buildInstr(Opc: Cast->getOpcode(), DstOps: {ElemTy}, SrcOps: {BV->getSourceReg(I)});
327 Casts.push_back(Elt: CastI.getReg(Idx: 0));
328 }
329
330 B.buildBuildVector(Res: Dst, Ops: Casts);
331 };
332
333 return true;
334}
335
336bool CombinerHelper::matchNarrowBinop(const MachineInstr &TruncMI,
337 const MachineInstr &BinopMI,
338 BuildFnTy &MatchInfo) const {
339 const GTrunc *Trunc = cast<GTrunc>(Val: &TruncMI);
340 const GBinOp *BinOp = cast<GBinOp>(Val: &BinopMI);
341
342 if (!MRI.hasOneNonDBGUse(RegNo: BinOp->getReg(Idx: 0)))
343 return false;
344
345 Register Dst = Trunc->getReg(Idx: 0);
346 LLT DstTy = MRI.getType(Reg: Dst);
347
348 // Is narrow binop legal?
349 if (!isLegalOrBeforeLegalizer(Query: {BinOp->getOpcode(), {DstTy}}))
350 return false;
351
352 MatchInfo = [=](MachineIRBuilder &B) {
353 auto LHS = B.buildTrunc(Res: DstTy, Op: BinOp->getLHSReg());
354 auto RHS = B.buildTrunc(Res: DstTy, Op: BinOp->getRHSReg());
355 B.buildInstr(Opc: BinOp->getOpcode(), DstOps: {Dst}, SrcOps: {LHS, RHS});
356 };
357
358 return true;
359}
360
361bool CombinerHelper::matchCastOfInteger(const MachineInstr &CastMI,
362 APInt &MatchInfo) const {
363 const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(Val: &CastMI);
364
365 APInt Input = getIConstantFromReg(VReg: Cast->getSrcReg(), MRI);
366
367 LLT DstTy = MRI.getType(Reg: Cast->getReg(Idx: 0));
368
369 if (!isConstantLegalOrBeforeLegalizer(Ty: DstTy))
370 return false;
371
372 switch (Cast->getOpcode()) {
373 case TargetOpcode::G_TRUNC: {
374 MatchInfo = Input.trunc(width: DstTy.getScalarSizeInBits());
375 return true;
376 }
377 default:
378 return false;
379 }
380}
381
382bool CombinerHelper::matchRedundantSextInReg(MachineInstr &Root,
383 MachineInstr &Other,
384 BuildFnTy &MatchInfo) const {
385 assert(Root.getOpcode() == TargetOpcode::G_SEXT_INREG &&
386 Other.getOpcode() == TargetOpcode::G_SEXT_INREG);
387
388 unsigned RootWidth = Root.getOperand(i: 2).getImm();
389 unsigned OtherWidth = Other.getOperand(i: 2).getImm();
390
391 Register Dst = Root.getOperand(i: 0).getReg();
392 Register OtherDst = Other.getOperand(i: 0).getReg();
393 Register Src = Other.getOperand(i: 1).getReg();
394
395 if (RootWidth >= OtherWidth) {
396 // The root sext_inreg is entirely redundant because the other one
397 // is narrower.
398 if (!canReplaceReg(DstReg: Dst, SrcReg: OtherDst, MRI))
399 return false;
400
401 MatchInfo = [=](MachineIRBuilder &B) {
402 Observer.changingAllUsesOfReg(MRI, Reg: Dst);
403 MRI.replaceRegWith(FromReg: Dst, ToReg: OtherDst);
404 Observer.finishedChangingAllUsesOfReg();
405 };
406 } else {
407 // RootWidth < OtherWidth, rewrite this G_SEXT_INREG with the source of the
408 // other G_SEXT_INREG.
409 MatchInfo = [=](MachineIRBuilder &B) {
410 B.buildSExtInReg(Res: Dst, Op: Src, ImmOp: RootWidth);
411 };
412 }
413
414 return true;
415}
416