1//===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the targeting of the MachineLegalizer class for SPIR-V.
10//
11//===----------------------------------------------------------------------===//
12
13#include "SPIRVLegalizerInfo.h"
14#include "SPIRV.h"
15#include "SPIRVGlobalRegistry.h"
16#include "SPIRVSubtarget.h"
17#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
18#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
19#include "llvm/CodeGen/MachineInstr.h"
20#include "llvm/CodeGen/MachineRegisterInfo.h"
21#include "llvm/CodeGen/TargetOpcodes.h"
22
23using namespace llvm;
24using namespace llvm::LegalizeActions;
25using namespace llvm::LegalityPredicates;
26
27LegalityPredicate typeOfExtendedScalars(unsigned TypeIdx, bool IsExtendedInts) {
28 return [IsExtendedInts, TypeIdx](const LegalityQuery &Query) {
29 const LLT Ty = Query.Types[TypeIdx];
30 return IsExtendedInts && Ty.isValid() && Ty.isScalar();
31 };
32}
33
34SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
35 using namespace TargetOpcode;
36
37 this->ST = &ST;
38 GR = ST.getSPIRVGlobalRegistry();
39
40 const LLT s1 = LLT::scalar(SizeInBits: 1);
41 const LLT s8 = LLT::scalar(SizeInBits: 8);
42 const LLT s16 = LLT::scalar(SizeInBits: 16);
43 const LLT s32 = LLT::scalar(SizeInBits: 32);
44 const LLT s64 = LLT::scalar(SizeInBits: 64);
45
46 const LLT v16s64 = LLT::fixed_vector(NumElements: 16, ScalarSizeInBits: 64);
47 const LLT v16s32 = LLT::fixed_vector(NumElements: 16, ScalarSizeInBits: 32);
48 const LLT v16s16 = LLT::fixed_vector(NumElements: 16, ScalarSizeInBits: 16);
49 const LLT v16s8 = LLT::fixed_vector(NumElements: 16, ScalarSizeInBits: 8);
50 const LLT v16s1 = LLT::fixed_vector(NumElements: 16, ScalarSizeInBits: 1);
51
52 const LLT v8s64 = LLT::fixed_vector(NumElements: 8, ScalarSizeInBits: 64);
53 const LLT v8s32 = LLT::fixed_vector(NumElements: 8, ScalarSizeInBits: 32);
54 const LLT v8s16 = LLT::fixed_vector(NumElements: 8, ScalarSizeInBits: 16);
55 const LLT v8s8 = LLT::fixed_vector(NumElements: 8, ScalarSizeInBits: 8);
56 const LLT v8s1 = LLT::fixed_vector(NumElements: 8, ScalarSizeInBits: 1);
57
58 const LLT v4s64 = LLT::fixed_vector(NumElements: 4, ScalarSizeInBits: 64);
59 const LLT v4s32 = LLT::fixed_vector(NumElements: 4, ScalarSizeInBits: 32);
60 const LLT v4s16 = LLT::fixed_vector(NumElements: 4, ScalarSizeInBits: 16);
61 const LLT v4s8 = LLT::fixed_vector(NumElements: 4, ScalarSizeInBits: 8);
62 const LLT v4s1 = LLT::fixed_vector(NumElements: 4, ScalarSizeInBits: 1);
63
64 const LLT v3s64 = LLT::fixed_vector(NumElements: 3, ScalarSizeInBits: 64);
65 const LLT v3s32 = LLT::fixed_vector(NumElements: 3, ScalarSizeInBits: 32);
66 const LLT v3s16 = LLT::fixed_vector(NumElements: 3, ScalarSizeInBits: 16);
67 const LLT v3s8 = LLT::fixed_vector(NumElements: 3, ScalarSizeInBits: 8);
68 const LLT v3s1 = LLT::fixed_vector(NumElements: 3, ScalarSizeInBits: 1);
69
70 const LLT v2s64 = LLT::fixed_vector(NumElements: 2, ScalarSizeInBits: 64);
71 const LLT v2s32 = LLT::fixed_vector(NumElements: 2, ScalarSizeInBits: 32);
72 const LLT v2s16 = LLT::fixed_vector(NumElements: 2, ScalarSizeInBits: 16);
73 const LLT v2s8 = LLT::fixed_vector(NumElements: 2, ScalarSizeInBits: 8);
74 const LLT v2s1 = LLT::fixed_vector(NumElements: 2, ScalarSizeInBits: 1);
75
76 const unsigned PSize = ST.getPointerSize();
77 const LLT p0 = LLT::pointer(AddressSpace: 0, SizeInBits: PSize); // Function
78 const LLT p1 = LLT::pointer(AddressSpace: 1, SizeInBits: PSize); // CrossWorkgroup
79 const LLT p2 = LLT::pointer(AddressSpace: 2, SizeInBits: PSize); // UniformConstant
80 const LLT p3 = LLT::pointer(AddressSpace: 3, SizeInBits: PSize); // Workgroup
81 const LLT p4 = LLT::pointer(AddressSpace: 4, SizeInBits: PSize); // Generic
82 const LLT p5 =
83 LLT::pointer(AddressSpace: 5, SizeInBits: PSize); // Input, SPV_INTEL_usm_storage_classes (Device)
84 const LLT p6 = LLT::pointer(AddressSpace: 6, SizeInBits: PSize); // SPV_INTEL_usm_storage_classes (Host)
85 const LLT p7 = LLT::pointer(AddressSpace: 7, SizeInBits: PSize); // Input
86 const LLT p8 = LLT::pointer(AddressSpace: 8, SizeInBits: PSize); // Output
87 const LLT p10 = LLT::pointer(AddressSpace: 10, SizeInBits: PSize); // Private
88 const LLT p11 = LLT::pointer(AddressSpace: 11, SizeInBits: PSize); // StorageBuffer
89 const LLT p12 = LLT::pointer(AddressSpace: 12, SizeInBits: PSize); // Uniform
90
91 // TODO: remove copy-pasting here by using concatenation in some way.
92 auto allPtrsScalarsAndVectors = {
93 p0, p1, p2, p3, p4, p5, p6, p7, p8,
94 p10, p11, p12, s1, s8, s16, s32, s64, v2s1,
95 v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, v3s32, v3s64,
96 v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16, v8s32,
97 v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
98
99 auto allVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8,
100 v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32,
101 v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1,
102 v16s8, v16s16, v16s32, v16s64};
103
104 auto allScalarsAndVectors = {
105 s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
106 v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
107 v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
108
109 auto allIntScalarsAndVectors = {s8, s16, s32, s64, v2s8, v2s16,
110 v2s32, v2s64, v3s8, v3s16, v3s32, v3s64,
111 v4s8, v4s16, v4s32, v4s64, v8s8, v8s16,
112 v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
113
114 auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
115
116 auto allIntScalars = {s8, s16, s32, s64};
117
118 auto allFloatScalars = {s16, s32, s64};
119
120 auto allFloatScalarsAndVectors = {
121 s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64,
122 v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
123
124 auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1, p2, p3,
125 p4, p5, p6, p7, p8, p10, p11, p12};
126
127 auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p10, p11, p12};
128
129 bool IsExtendedInts =
130 ST.canUseExtension(
131 E: SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
132 ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_bit_instructions) ||
133 ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_int4);
134 auto extendedScalarsAndVectors =
135 [IsExtendedInts](const LegalityQuery &Query) {
136 const LLT Ty = Query.Types[0];
137 return IsExtendedInts && Ty.isValid() && !Ty.isPointerOrPointerVector();
138 };
139 auto extendedScalarsAndVectorsProduct = [IsExtendedInts](
140 const LegalityQuery &Query) {
141 const LLT Ty1 = Query.Types[0], Ty2 = Query.Types[1];
142 return IsExtendedInts && Ty1.isValid() && Ty2.isValid() &&
143 !Ty1.isPointerOrPointerVector() && !Ty2.isPointerOrPointerVector();
144 };
145 auto extendedPtrsScalarsAndVectors =
146 [IsExtendedInts](const LegalityQuery &Query) {
147 const LLT Ty = Query.Types[0];
148 return IsExtendedInts && Ty.isValid();
149 };
150
151 for (auto Opc : getTypeFoldingSupportedOpcodes())
152 getActionDefinitionsBuilder(Opcode: Opc).custom();
153
154 getActionDefinitionsBuilder(Opcode: G_GLOBAL_VALUE).alwaysLegal();
155
156 // TODO: add proper rules for vectors legalization.
157 getActionDefinitionsBuilder(
158 Opcodes: {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR})
159 .alwaysLegal();
160
161 // Vector Reduction Operations
162 getActionDefinitionsBuilder(
163 Opcodes: {G_VECREDUCE_SMIN, G_VECREDUCE_SMAX, G_VECREDUCE_UMIN, G_VECREDUCE_UMAX,
164 G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
165 G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
166 G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
167 .legalFor(Types: allVectors)
168 .scalarize(TypeIdx: 1)
169 .lower();
170
171 getActionDefinitionsBuilder(Opcodes: {G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL})
172 .scalarize(TypeIdx: 2)
173 .lower();
174
175 // Merge/Unmerge
176 // TODO: add proper legalization rules.
177 getActionDefinitionsBuilder(Opcode: G_UNMERGE_VALUES).alwaysLegal();
178
179 getActionDefinitionsBuilder(Opcodes: {G_MEMCPY, G_MEMMOVE})
180 .legalIf(Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allPtrs), P1: typeInSet(TypeIdx: 1, TypesInit: allPtrs)));
181
182 getActionDefinitionsBuilder(Opcode: G_MEMSET).legalIf(
183 Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allPtrs), P1: typeInSet(TypeIdx: 1, TypesInit: allIntScalars)));
184
185 getActionDefinitionsBuilder(Opcode: G_ADDRSPACE_CAST)
186 .legalForCartesianProduct(Types0: allPtrs, Types1: allPtrs);
187
188 getActionDefinitionsBuilder(Opcodes: {G_LOAD, G_STORE}).legalIf(Predicate: typeInSet(TypeIdx: 1, TypesInit: allPtrs));
189
190 getActionDefinitionsBuilder(Opcodes: {G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS,
191 G_BITREVERSE, G_SADDSAT, G_UADDSAT, G_SSUBSAT,
192 G_USUBSAT, G_SCMP, G_UCMP})
193 .legalFor(Types: allIntScalarsAndVectors)
194 .legalIf(Predicate: extendedScalarsAndVectors);
195
196 getActionDefinitionsBuilder(Opcodes: {G_FMA, G_STRICT_FMA})
197 .legalFor(Types: allFloatScalarsAndVectors);
198
199 getActionDefinitionsBuilder(Opcode: G_STRICT_FLDEXP)
200 .legalForCartesianProduct(Types0: allFloatScalarsAndVectors, Types1: allIntScalars);
201
202 getActionDefinitionsBuilder(Opcodes: {G_FPTOSI, G_FPTOUI})
203 .legalForCartesianProduct(Types0: allIntScalarsAndVectors,
204 Types1: allFloatScalarsAndVectors);
205
206 getActionDefinitionsBuilder(Opcodes: {G_SITOFP, G_UITOFP})
207 .legalForCartesianProduct(Types0: allFloatScalarsAndVectors,
208 Types1: allScalarsAndVectors);
209
210 getActionDefinitionsBuilder(Opcode: G_CTPOP)
211 .legalForCartesianProduct(Types: allIntScalarsAndVectors)
212 .legalIf(Predicate: extendedScalarsAndVectorsProduct);
213
214 // Extensions.
215 getActionDefinitionsBuilder(Opcodes: {G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
216 .legalForCartesianProduct(Types: allScalarsAndVectors)
217 .legalIf(Predicate: extendedScalarsAndVectorsProduct);
218
219 getActionDefinitionsBuilder(Opcode: G_PHI)
220 .legalFor(Types: allPtrsScalarsAndVectors)
221 .legalIf(Predicate: extendedPtrsScalarsAndVectors);
222
223 getActionDefinitionsBuilder(Opcode: G_BITCAST).legalIf(
224 Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allPtrsScalarsAndVectors),
225 P1: typeInSet(TypeIdx: 1, TypesInit: allPtrsScalarsAndVectors)));
226
227 getActionDefinitionsBuilder(Opcodes: {G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal();
228
229 getActionDefinitionsBuilder(Opcodes: {G_STACKSAVE, G_STACKRESTORE}).alwaysLegal();
230
231 getActionDefinitionsBuilder(Opcode: G_INTTOPTR)
232 .legalForCartesianProduct(Types0: allPtrs, Types1: allIntScalars)
233 .legalIf(
234 Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allPtrs), P1: typeOfExtendedScalars(TypeIdx: 1, IsExtendedInts)));
235 getActionDefinitionsBuilder(Opcode: G_PTRTOINT)
236 .legalForCartesianProduct(Types0: allIntScalars, Types1: allPtrs)
237 .legalIf(
238 Predicate: all(P0: typeOfExtendedScalars(TypeIdx: 0, IsExtendedInts), P1: typeInSet(TypeIdx: 1, TypesInit: allPtrs)));
239 getActionDefinitionsBuilder(Opcode: G_PTR_ADD)
240 .legalForCartesianProduct(Types0: allPtrs, Types1: allIntScalars)
241 .legalIf(
242 Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allPtrs), P1: typeOfExtendedScalars(TypeIdx: 1, IsExtendedInts)));
243
244 // ST.canDirectlyComparePointers() for pointer args is supported in
245 // legalizeCustom().
246 getActionDefinitionsBuilder(Opcode: G_ICMP).customIf(
247 Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allBoolScalarsAndVectors),
248 P1: typeInSet(TypeIdx: 1, TypesInit: allPtrsScalarsAndVectors)));
249
250 getActionDefinitionsBuilder(Opcode: G_FCMP).legalIf(
251 Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allBoolScalarsAndVectors),
252 P1: typeInSet(TypeIdx: 1, TypesInit: allFloatScalarsAndVectors)));
253
254 getActionDefinitionsBuilder(Opcodes: {G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
255 G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
256 G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
257 G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
258 .legalForCartesianProduct(Types0: allIntScalars, Types1: allPtrs);
259
260 getActionDefinitionsBuilder(
261 Opcodes: {G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX})
262 .legalForCartesianProduct(Types0: allFloatScalars, Types1: allPtrs);
263
264 getActionDefinitionsBuilder(Opcode: G_ATOMICRMW_XCHG)
265 .legalForCartesianProduct(Types0: allFloatAndIntScalarsAndPtrs, Types1: allPtrs);
266
267 getActionDefinitionsBuilder(Opcode: G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
268 // TODO: add proper legalization rules.
269 getActionDefinitionsBuilder(Opcode: G_ATOMIC_CMPXCHG).alwaysLegal();
270
271 getActionDefinitionsBuilder(
272 Opcodes: {G_UADDO, G_SADDO, G_USUBO, G_SSUBO, G_UMULO, G_SMULO})
273 .alwaysLegal();
274
275 // FP conversions.
276 getActionDefinitionsBuilder(Opcodes: {G_FPTRUNC, G_FPEXT})
277 .legalForCartesianProduct(Types: allFloatScalarsAndVectors);
278
279 // Pointer-handling.
280 getActionDefinitionsBuilder(Opcode: G_FRAME_INDEX).legalFor(Types: {p0});
281
282 // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
283 getActionDefinitionsBuilder(Opcode: G_BRCOND).legalFor(Types: {s1, s32});
284
285 // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to
286 // tighten these requirements. Many of these math functions are only legal on
287 // specific bitwidths, so they are not selectable for
288 // allFloatScalarsAndVectors.
289 getActionDefinitionsBuilder(Opcodes: {G_STRICT_FSQRT,
290 G_FPOW,
291 G_FEXP,
292 G_FEXP2,
293 G_FLOG,
294 G_FLOG2,
295 G_FLOG10,
296 G_FABS,
297 G_FMINNUM,
298 G_FMAXNUM,
299 G_FCEIL,
300 G_FCOS,
301 G_FSIN,
302 G_FTAN,
303 G_FACOS,
304 G_FASIN,
305 G_FATAN,
306 G_FATAN2,
307 G_FCOSH,
308 G_FSINH,
309 G_FTANH,
310 G_FSQRT,
311 G_FFLOOR,
312 G_FRINT,
313 G_FNEARBYINT,
314 G_INTRINSIC_ROUND,
315 G_INTRINSIC_TRUNC,
316 G_FMINIMUM,
317 G_FMAXIMUM,
318 G_INTRINSIC_ROUNDEVEN})
319 .legalFor(Types: allFloatScalarsAndVectors);
320
321 getActionDefinitionsBuilder(Opcode: G_FCOPYSIGN)
322 .legalForCartesianProduct(Types0: allFloatScalarsAndVectors,
323 Types1: allFloatScalarsAndVectors);
324
325 getActionDefinitionsBuilder(Opcode: G_FPOWI).legalForCartesianProduct(
326 Types0: allFloatScalarsAndVectors, Types1: allIntScalarsAndVectors);
327
328 if (ST.canUseExtInstSet(E: SPIRV::InstructionSet::OpenCL_std)) {
329 getActionDefinitionsBuilder(
330 Opcodes: {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
331 .legalForCartesianProduct(Types0: allIntScalarsAndVectors,
332 Types1: allIntScalarsAndVectors);
333
334 // Struct return types become a single scalar, so cannot easily legalize.
335 getActionDefinitionsBuilder(Opcodes: {G_SMULH, G_UMULH}).alwaysLegal();
336 }
337
338 getLegacyLegalizerInfo().computeTables();
339 verify(MII: *ST.getInstrInfo());
340}
341
342static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType,
343 LegalizerHelper &Helper,
344 MachineRegisterInfo &MRI,
345 SPIRVGlobalRegistry *GR) {
346 Register ConvReg = MRI.createGenericVirtualRegister(Ty: ConvTy);
347 MRI.setRegClass(Reg: ConvReg, RC: GR->getRegClass(SpvType));
348 GR->assignSPIRVTypeToVReg(Type: SpvType, VReg: ConvReg, MF: Helper.MIRBuilder.getMF());
349 Helper.MIRBuilder.buildInstr(Opcode: TargetOpcode::G_PTRTOINT)
350 .addDef(RegNo: ConvReg)
351 .addUse(RegNo: Reg);
352 return ConvReg;
353}
354
355bool SPIRVLegalizerInfo::legalizeCustom(
356 LegalizerHelper &Helper, MachineInstr &MI,
357 LostDebugLocObserver &LocObserver) const {
358 auto Opc = MI.getOpcode();
359 MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
360 if (Opc == TargetOpcode::G_ICMP) {
361 assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
362 auto &Op0 = MI.getOperand(i: 2);
363 auto &Op1 = MI.getOperand(i: 3);
364 Register Reg0 = Op0.getReg();
365 Register Reg1 = Op1.getReg();
366 CmpInst::Predicate Cond =
367 static_cast<CmpInst::Predicate>(MI.getOperand(i: 1).getPredicate());
368 if ((!ST->canDirectlyComparePointers() ||
369 (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) &&
370 MRI.getType(Reg: Reg0).isPointer() && MRI.getType(Reg: Reg1).isPointer()) {
371 LLT ConvT = LLT::scalar(SizeInBits: ST->getPointerSize());
372 Type *LLVMTy = IntegerType::get(C&: MI.getMF()->getFunction().getContext(),
373 NumBits: ST->getPointerSize());
374 SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(
375 Type: LLVMTy, MIRBuilder&: Helper.MIRBuilder, AQ: SPIRV::AccessQualifier::ReadWrite, EmitIR: true);
376 Op0.setReg(convertPtrToInt(Reg: Reg0, ConvTy: ConvT, SpvType: SpirvTy, Helper, MRI, GR));
377 Op1.setReg(convertPtrToInt(Reg: Reg1, ConvTy: ConvT, SpvType: SpirvTy, Helper, MRI, GR));
378 }
379 return true;
380 }
381 // TODO: implement legalization for other opcodes.
382 return true;
383}
384