//===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the targeting of the MachineLegalizer class for
// SPIR-V.
//
//===----------------------------------------------------------------------===//

#include "SPIRVLegalizerInfo.h"
#include "SPIRV.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVSubtarget.h"
#include "SPIRVUtils.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"

using namespace llvm;
using namespace llvm::LegalizeActions;
using namespace llvm::LegalityPredicates;

#define DEBUG_TYPE "spirv-legalizer"

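// Returns a predicate that holds when the subtarget's extended-integer
// extensions are enabled and the type at the given index is a scalar of
// arbitrary width.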
static LegalityPredicate typeOfExtendedScalars(unsigned TypeIdx,
                                               bool IsExtendedInts) {
  return [IsExtendedInts, TypeIdx](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    return IsExtendedInts && Ty.isValid() && Ty.isScalar();
  };
}

SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
  using namespace TargetOpcode;

  this->ST = &ST;
  GR = ST.getSPIRVGlobalRegistry();

  const LLT s1 = LLT::scalar(1);
  const LLT s8 = LLT::scalar(8);
  const LLT s16 = LLT::scalar(16);
  const LLT s32 = LLT::scalar(32);
  const LLT s64 = LLT::scalar(64);
  const LLT s128 = LLT::scalar(128);

  const LLT v16s64 = LLT::fixed_vector(16, 64);
  const LLT v16s32 = LLT::fixed_vector(16, 32);
  const LLT v16s16 = LLT::fixed_vector(16, 16);
  const LLT v16s8 = LLT::fixed_vector(16, 8);
  const LLT v16s1 = LLT::fixed_vector(16, 1);

  const LLT v8s64 = LLT::fixed_vector(8, 64);
  const LLT v8s32 = LLT::fixed_vector(8, 32);
  const LLT v8s16 = LLT::fixed_vector(8, 16);
  const LLT v8s8 = LLT::fixed_vector(8, 8);
  const LLT v8s1 = LLT::fixed_vector(8, 1);

  const LLT v4s64 = LLT::fixed_vector(4, 64);
  const LLT v4s32 = LLT::fixed_vector(4, 32);
  const LLT v4s16 = LLT::fixed_vector(4, 16);
  const LLT v4s8 = LLT::fixed_vector(4, 8);
  const LLT v4s1 = LLT::fixed_vector(4, 1);

  const LLT v3s64 = LLT::fixed_vector(3, 64);
  const LLT v3s32 = LLT::fixed_vector(3, 32);
  const LLT v3s16 = LLT::fixed_vector(3, 16);
  const LLT v3s8 = LLT::fixed_vector(3, 8);
  const LLT v3s1 = LLT::fixed_vector(3, 1);

  const LLT v2s64 = LLT::fixed_vector(2, 64);
  const LLT v2s32 = LLT::fixed_vector(2, 32);
  const LLT v2s16 = LLT::fixed_vector(2, 16);
  const LLT v2s8 = LLT::fixed_vector(2, 8);
  const LLT v2s1 = LLT::fixed_vector(2, 1);

  const unsigned PSize = ST.getPointerSize();
  const LLT p0 = LLT::pointer(0, PSize); // Function
  const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
  const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
  const LLT p3 = LLT::pointer(3, PSize); // Workgroup
  const LLT p4 = LLT::pointer(4, PSize); // Generic
  const LLT p5 =
      LLT::pointer(5, PSize); // Input, SPV_INTEL_usm_storage_classes (Device)
  const LLT p6 = LLT::pointer(6, PSize); // SPV_INTEL_usm_storage_classes (Host)
  const LLT p7 = LLT::pointer(7, PSize); // Input
  const LLT p8 = LLT::pointer(8, PSize); // Output
  const LLT p9 =
      LLT::pointer(9, PSize); // CodeSectionINTEL, SPV_INTEL_function_pointers
  const LLT p10 = LLT::pointer(10, PSize); // Private
  const LLT p11 = LLT::pointer(11, PSize); // StorageBuffer
  const LLT p12 = LLT::pointer(12, PSize); // Uniform
  const LLT p13 = LLT::pointer(13, PSize); // PushConstant

  // TODO: remove copy-pasting here by using concatenation in some way.
  auto allPtrsScalarsAndVectors = {
      p0,    p1,    p2,    p3,    p4,    p5,     p6,     p7,    p8,
      p9,    p10,   p11,   p12,   p13,   s1,     s8,     s16,   s32,
      s64,   v2s1,  v2s8,  v2s16, v2s32, v2s64,  v3s1,   v3s8,  v3s16,
      v3s32, v3s64, v4s1,  v4s8,  v4s16, v4s32,  v4s64,  v8s1,  v8s8,
      v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};

  auto allVectors = {v2s1,  v2s8,   v2s16,  v2s32, v2s64, v3s1,  v3s8,
                     v3s16, v3s32,  v3s64,  v4s1,  v4s8,  v4s16, v4s32,
                     v4s64, v8s1,   v8s8,   v8s16, v8s32, v8s64, v16s1,
                     v16s8, v16s16, v16s32, v16s64};

  auto allShaderVectors = {v2s1, v2s8, v2s16, v2s32, v2s64,
                           v3s1, v3s8, v3s16, v3s32, v3s64,
                           v4s1, v4s8, v4s16, v4s32, v4s64};

  auto allScalars = {s1, s8, s16, s32, s64};

  auto allScalarsAndVectors = {
      s1,    s8,    s16,   s32,   s64,    s128,   v2s1,  v2s8,
      v2s16, v2s32, v2s64, v3s1,  v3s8,   v3s16,  v3s32, v3s64,
      v4s1,  v4s8,  v4s16, v4s32, v4s64,  v8s1,   v8s8,  v8s16,
      v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};

  auto allIntScalarsAndVectors = {
      s8,    s16,   s32,   s64,   s128,   v2s8,   v2s16, v2s32, v2s64,
      v3s8,  v3s16, v3s32, v3s64, v4s8,   v4s16,  v4s32, v4s64, v8s8,
      v8s16, v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};

  auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};

  auto allIntScalars = {s8, s16, s32, s64, s128};

  auto allFloatScalarsAndF16Vector2AndVector4s = {s16, s32, s64, v2s16, v4s16};

  auto allFloatScalarsAndVectors = {
      s16,   s32,   s64,   v2s16, v2s32, v2s64, v3s16,  v3s32,  v3s64,
      v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};

  auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0,  p1,
                                       p2, p3,  p4,  p5,  p6,  p7,
                                       p8, p9,  p10, p11, p12, p13};

  auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11, p12, p13};

  auto &allowedVectorTypes = ST.isShader() ? allShaderVectors : allVectors;

  bool IsExtendedInts =
      ST.canUseExtension(
          SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers) ||
      ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions) ||
      ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4);
  auto extendedScalarsAndVectors =
      [IsExtendedInts](const LegalityQuery &Query) {
        const LLT Ty = Query.Types[0];
        return IsExtendedInts && Ty.isValid() && !Ty.isPointerOrPointerVector();
      };
  auto extendedScalarsAndVectorsProduct = [IsExtendedInts](
                                              const LegalityQuery &Query) {
    const LLT Ty1 = Query.Types[0], Ty2 = Query.Types[1];
    return IsExtendedInts && Ty1.isValid() && Ty2.isValid() &&
           !Ty1.isPointerOrPointerVector() && !Ty2.isPointerOrPointerVector();
  };
  auto extendedPtrsScalarsAndVectors =
      [IsExtendedInts](const LegalityQuery &Query) {
        const LLT Ty = Query.Types[0];
        return IsExtendedInts && Ty.isValid();
      };

  // The universal validation rules in the SPIR-V specification state that
  // vector sizes are typically limited to 2, 3, or 4. However, larger vector
  // sizes (8 and 16) are enabled when the Kernel capability is present. For
  // shader execution models, vector sizes are strictly limited to 4. In
  // non-shader contexts, vector sizes of 8 and 16 are also permitted, but
  // arbitrary sizes (e.g., 6 or 11) are not.
  uint32_t MaxVectorSize = ST.isShader() ? 4 : 16;
  LLVM_DEBUG(dbgs() << "MaxVectorSize: " << MaxVectorSize << "\n");

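  // Give every type-folding opcode a custom rule that widens odd vector
  // sizes to the next power of two and clamps them to MaxVectorSize.
  // Division and remainder are excluded here and handled just below:
  // widening them with extra lanes could (presumably) introduce divisions
  // on undefined elements, so they scalarize non-power-of-two vectors
  // instead.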
  for (auto Opc : getTypeFoldingSupportedOpcodes()) {
    switch (Opc) {
    case G_EXTRACT_VECTOR_ELT:
    case G_UREM:
    case G_SREM:
    case G_UDIV:
    case G_SDIV:
    case G_FREM:
      break;
    default:
      getActionDefinitionsBuilder(Opc)
          .customFor(allScalars)
          .customFor(allowedVectorTypes)
          .moreElementsToNextPow2(0)
          .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
                           LegalizeMutations::changeElementCountTo(
                               0, ElementCount::getFixed(MaxVectorSize)))
          .custom();
      break;
    }
  }

  getActionDefinitionsBuilder({G_UREM, G_SREM, G_SDIV, G_UDIV, G_FREM})
      .customFor(allScalars)
      .customFor(allowedVectorTypes)
      .scalarizeIf(numElementsNotPow2(0), 0)
      .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
                       LegalizeMutations::changeElementCountTo(
                           0, ElementCount::getFixed(MaxVectorSize)))
      .custom();

  getActionDefinitionsBuilder({G_FMA, G_STRICT_FMA})
      .legalFor(allScalars)
      .legalFor(allowedVectorTypes)
      .moreElementsToNextPow2(0)
      .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
                       LegalizeMutations::changeElementCountTo(
                           0, ElementCount::getFixed(MaxVectorSize)))
      .alwaysLegal();

  getActionDefinitionsBuilder(G_INTRINSIC_W_SIDE_EFFECTS).custom();

  getActionDefinitionsBuilder(G_SHUFFLE_VECTOR)
      .legalForCartesianProduct(allowedVectorTypes, allowedVectorTypes)
      .moreElementsToNextPow2(0)
      .lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
      .moreElementsToNextPow2(1)
      .lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize));

  getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
      .moreElementsToNextPow2(1)
      .fewerElementsIf(vectorElementCountIsGreaterThan(1, MaxVectorSize),
                       LegalizeMutations::changeElementCountTo(
                           1, ElementCount::getFixed(MaxVectorSize)))
      .custom();

  getActionDefinitionsBuilder(G_INSERT_VECTOR_ELT)
      .moreElementsToNextPow2(0)
      .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
                       LegalizeMutations::changeElementCountTo(
                           0, ElementCount::getFixed(MaxVectorSize)))
      .custom();

  // Illegal G_BUILD_VECTOR instructions should be handled
  // during the combine phase.
  getActionDefinitionsBuilder(G_BUILD_VECTOR)
      .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize));

  // When entering the legalizer, there should be no G_BITCAST instructions.
  // They should all be calls to the `spv_bitcast` intrinsic. The call to
  // the intrinsic will be converted to a G_BITCAST during legalization if
  // the vectors are not legal. After using the rules to legalize a G_BITCAST,
  // we turn it back into a call to the intrinsic with a custom rule to avoid
  // potential machine verifier failures.
  getActionDefinitionsBuilder(G_BITCAST)
      .moreElementsToNextPow2(0)
      .moreElementsToNextPow2(1)
      .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
                       LegalizeMutations::changeElementCountTo(
                           0, ElementCount::getFixed(MaxVectorSize)))
      .lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
      .custom();

  // If the result is still illegal, the combiner should be able to remove it.
  getActionDefinitionsBuilder(G_CONCAT_VECTORS)
      .legalForCartesianProduct(allowedVectorTypes, allowedVectorTypes);

  getActionDefinitionsBuilder(G_SPLAT_VECTOR)
      .legalFor(allowedVectorTypes)
      .moreElementsToNextPow2(0)
      .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
                       LegalizeMutations::changeElementCountTo(
                           0, ElementCount::getFixed(MaxVectorSize)))
      .alwaysLegal();

  // Vector Reduction Operations
  getActionDefinitionsBuilder(
      {G_VECREDUCE_SMIN, G_VECREDUCE_SMAX, G_VECREDUCE_UMIN, G_VECREDUCE_UMAX,
       G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
       G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
       G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
      .legalFor(allowedVectorTypes)
      .scalarize(1)
      .lower();

  getActionDefinitionsBuilder({G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL})
      .scalarize(2)
      .lower();

  // Illegal G_UNMERGE_VALUES instructions should be handled
  // during the combine phase.
  getActionDefinitionsBuilder(G_UNMERGE_VALUES)
      .legalIf(vectorElementCountIsLessThanOrEqualTo(1, MaxVectorSize));

  getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
      .unsupportedIf(LegalityPredicates::any(typeIs(0, p9), typeIs(1, p9)))
      .legalIf(all(typeInSet(0, allPtrs), typeInSet(1, allPtrs)));

  getActionDefinitionsBuilder(G_MEMSET)
      .unsupportedIf(typeIs(0, p9))
      .legalIf(all(typeInSet(0, allPtrs), typeInSet(1, allIntScalars)));

  getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
      .unsupportedIf(
          LegalityPredicates::any(all(typeIs(0, p9), typeIsNot(1, p9)),
                                  all(typeIsNot(0, p9), typeIs(1, p9))))
      .legalForCartesianProduct(allPtrs, allPtrs);

  // Should we be legalizing bad scalar sizes like s5 here instead
  // of handling them in the instruction selector?
  getActionDefinitionsBuilder({G_LOAD, G_STORE})
      .unsupportedIf(typeIs(1, p9))
      .legalForCartesianProduct(allowedVectorTypes, allPtrs)
      .legalForCartesianProduct(allPtrs, allPtrs)
      .legalIf(isScalar(0))
      .custom();

  getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS,
                               G_BITREVERSE, G_SADDSAT, G_UADDSAT, G_SSUBSAT,
                               G_USUBSAT, G_SCMP, G_UCMP})
      .legalFor(allIntScalarsAndVectors)
      .legalIf(extendedScalarsAndVectors);

  getActionDefinitionsBuilder(G_STRICT_FLDEXP)
      .legalForCartesianProduct(allFloatScalarsAndVectors, allIntScalars);

  getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
      .legalForCartesianProduct(allIntScalarsAndVectors,
                                allFloatScalarsAndVectors);

  getActionDefinitionsBuilder({G_FPTOSI_SAT, G_FPTOUI_SAT})
      .legalForCartesianProduct(allIntScalarsAndVectors,
                                allFloatScalarsAndVectors);

  getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
      .legalForCartesianProduct(allFloatScalarsAndVectors,
                                allScalarsAndVectors);

  getActionDefinitionsBuilder(G_CTPOP)
      .legalForCartesianProduct(allIntScalarsAndVectors)
      .legalIf(extendedScalarsAndVectorsProduct);

  // Extensions.
  getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
      .legalForCartesianProduct(allScalarsAndVectors)
      .legalIf(extendedScalarsAndVectorsProduct);

  getActionDefinitionsBuilder(G_PHI)
      .legalFor(allPtrsScalarsAndVectors)
      .legalIf(extendedPtrsScalarsAndVectors);

  getActionDefinitionsBuilder(G_BITCAST).legalIf(
      all(typeInSet(0, allPtrsScalarsAndVectors),
          typeInSet(1, allPtrsScalarsAndVectors)));

  getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE})
      .legalFor({s1, s128})
      .legalFor(allFloatAndIntScalarsAndPtrs)
      .legalFor(allowedVectorTypes)
      .moreElementsToNextPow2(0)
      .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
                       LegalizeMutations::changeElementCountTo(
                           0, ElementCount::getFixed(MaxVectorSize)));

  getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).alwaysLegal();

  getActionDefinitionsBuilder(G_INTTOPTR)
      .legalForCartesianProduct(allPtrs, allIntScalars)
      .legalIf(
          all(typeInSet(0, allPtrs), typeOfExtendedScalars(1, IsExtendedInts)));
  getActionDefinitionsBuilder(G_PTRTOINT)
      .legalForCartesianProduct(allIntScalars, allPtrs)
      .legalIf(
          all(typeOfExtendedScalars(0, IsExtendedInts), typeInSet(1, allPtrs)));
  getActionDefinitionsBuilder(G_PTR_ADD)
      .legalForCartesianProduct(allPtrs, allIntScalars)
      .legalIf(
          all(typeInSet(0, allPtrs), typeOfExtendedScalars(1, IsExtendedInts)));

  // ST.canDirectlyComparePointers() for pointer args is supported in
  // legalizeCustom().
  getActionDefinitionsBuilder(G_ICMP)
      .unsupportedIf(LegalityPredicates::any(
          all(typeIs(0, p9), typeInSet(1, allPtrs), typeIsNot(1, p9)),
          all(typeInSet(0, allPtrs), typeIsNot(0, p9), typeIs(1, p9))))
      .customIf(all(typeInSet(0, allBoolScalarsAndVectors),
                    typeInSet(1, allPtrsScalarsAndVectors)));

  getActionDefinitionsBuilder(G_FCMP).legalIf(
      all(typeInSet(0, allBoolScalarsAndVectors),
          typeInSet(1, allFloatScalarsAndVectors)));

  getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
                               G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
                               G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
                               G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
      .legalForCartesianProduct(allIntScalars, allPtrs);

  getActionDefinitionsBuilder(
      {G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX})
      .legalForCartesianProduct(allFloatScalarsAndF16Vector2AndVector4s,
                                allPtrs);

  getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
      .legalForCartesianProduct(allFloatAndIntScalarsAndPtrs, allPtrs);

  getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
  // TODO: add proper legalization rules.
  getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();

  getActionDefinitionsBuilder(
      {G_UADDO, G_SADDO, G_USUBO, G_SSUBO, G_UMULO, G_SMULO})
      .alwaysLegal();

  getActionDefinitionsBuilder({G_LROUND, G_LLROUND})
      .legalForCartesianProduct(allFloatScalarsAndVectors,
                                allIntScalarsAndVectors);

  // FP conversions.
  getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
      .legalForCartesianProduct(allFloatScalarsAndVectors);

  // Pointer-handling.
  getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});

  getActionDefinitionsBuilder(G_GLOBAL_VALUE).legalFor(allPtrs);

  // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
  getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});

  getActionDefinitionsBuilder(G_FFREXP).legalForCartesianProduct(
      allFloatScalarsAndVectors, {s32, v2s32, v3s32, v4s32, v8s32, v16s32});

  // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to
  // tighten these requirements. Many of these math functions are only legal on
  // specific bitwidths, so they are not selectable for
  // allFloatScalarsAndVectors.
  getActionDefinitionsBuilder({G_STRICT_FSQRT,
                               G_FPOW,
                               G_FEXP,
                               G_FMODF,
                               G_FEXP2,
                               G_FLOG,
                               G_FLOG2,
                               G_FLOG10,
                               G_FABS,
                               G_FMINNUM,
                               G_FMAXNUM,
                               G_FCEIL,
                               G_FCOS,
                               G_FSIN,
                               G_FTAN,
                               G_FACOS,
                               G_FASIN,
                               G_FATAN,
                               G_FATAN2,
                               G_FCOSH,
                               G_FSINH,
                               G_FTANH,
                               G_FSQRT,
                               G_FFLOOR,
                               G_FRINT,
                               G_FNEARBYINT,
                               G_INTRINSIC_ROUND,
                               G_INTRINSIC_TRUNC,
                               G_FMINIMUM,
                               G_FMAXIMUM,
                               G_INTRINSIC_ROUNDEVEN})
      .legalFor(allFloatScalarsAndVectors);

  getActionDefinitionsBuilder(G_FCOPYSIGN)
      .legalForCartesianProduct(allFloatScalarsAndVectors,
                                allFloatScalarsAndVectors);

  getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct(
      allFloatScalarsAndVectors, allIntScalarsAndVectors);

  if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
    getActionDefinitionsBuilder(
        {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
        .legalForCartesianProduct(allIntScalarsAndVectors,
                                  allIntScalarsAndVectors);

    // Struct return types become a single scalar, so cannot easily legalize.
    getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal();
  }

  getActionDefinitionsBuilder(G_IS_FPCLASS).custom();

  getLegacyLegalizerInfo().computeTables();
  verify(*ST.getInstrInfo());
}

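// Rewrites G_EXTRACT_VECTOR_ELT into a call to the spv_extractelt intrinsic
// so the SPIR-V instruction selector can handle it directly.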
static bool legalizeExtractVectorElt(LegalizerHelper &Helper, MachineInstr &MI,
                                     SPIRVGlobalRegistry *GR) {
  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
  Register DstReg = MI.getOperand(0).getReg();
  Register SrcReg = MI.getOperand(1).getReg();
  Register IdxReg = MI.getOperand(2).getReg();

  MIRBuilder
      .buildIntrinsic(Intrinsic::spv_extractelt, ArrayRef<Register>{DstReg})
      .addUse(SrcReg)
      .addUse(IdxReg);
  MI.eraseFromParent();
  return true;
}

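// Rewrites G_INSERT_VECTOR_ELT into a call to the spv_insertelt intrinsic
// so the SPIR-V instruction selector can handle it directly.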
static bool legalizeInsertVectorElt(LegalizerHelper &Helper, MachineInstr &MI,
                                    SPIRVGlobalRegistry *GR) {
  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
  Register DstReg = MI.getOperand(0).getReg();
  Register SrcReg = MI.getOperand(1).getReg();
  Register ValReg = MI.getOperand(2).getReg();
  Register IdxReg = MI.getOperand(3).getReg();

  MIRBuilder
      .buildIntrinsic(Intrinsic::spv_insertelt, ArrayRef<Register>{DstReg})
      .addUse(SrcReg)
      .addUse(ValReg)
      .addUse(IdxReg);
  MI.eraseFromParent();
  return true;
}

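// Emits a G_PTRTOINT of Reg into a fresh virtual register of type ConvTy,
// registering the given SPIR-V type for the result so later stages can look
// it up.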
static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType,
                                LegalizerHelper &Helper,
                                MachineRegisterInfo &MRI,
                                SPIRVGlobalRegistry *GR) {
  Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
  MRI.setRegClass(ConvReg, GR->getRegClass(SpvType));
  GR->assignSPIRVTypeToVReg(SpvType, ConvReg, Helper.MIRBuilder.getMF());
  Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
      .addDef(ConvReg)
      .addUse(Reg);
  return ConvReg;
}

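// Returns true if the vector type is not directly representable in SPIR-V:
// either its element count exceeds the subtarget's maximum, or it is a
// non-power-of-two count greater than 4 (e.g. 6 or 11).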
static bool needsVectorLegalization(const LLT &Ty, const SPIRVSubtarget &ST) {
  if (!Ty.isVector())
    return false;
  unsigned NumElements = Ty.getNumElements();
  unsigned MaxVectorSize = ST.isShader() ? 4 : 16;
  return (NumElements > 4 && !isPowerOf2_32(NumElements)) ||
         NumElements > MaxVectorSize;
}

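// Splits a vector G_LOAD whose result type is too wide for SPIR-V into one
// scalar load per element (addressed via spv_gep) followed by a
// G_BUILD_VECTOR of the loaded elements.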
static bool legalizeLoad(LegalizerHelper &Helper, MachineInstr &MI,
                         SPIRVGlobalRegistry *GR) {
  MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
  Register DstReg = MI.getOperand(0).getReg();
  Register PtrReg = MI.getOperand(1).getReg();
  LLT DstTy = MRI.getType(DstReg);

  if (!DstTy.isVector())
    return true;

  const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();
  if (!needsVectorLegalization(DstTy, ST))
    return true;

  SmallVector<Register, 8> SplitRegs;
  LLT EltTy = DstTy.getElementType();
  unsigned NumElts = DstTy.getNumElements();

  LLT PtrTy = MRI.getType(PtrReg);
  auto Zero = MIRBuilder.buildConstant(LLT::scalar(32), 0);

  for (unsigned i = 0; i < NumElts; ++i) {
    auto Idx = MIRBuilder.buildConstant(LLT::scalar(32), i);
    Register EltPtr = MRI.createGenericVirtualRegister(PtrTy);

    MIRBuilder.buildIntrinsic(Intrinsic::spv_gep, ArrayRef<Register>{EltPtr})
        .addImm(1) // InBounds
        .addUse(PtrReg)
        .addUse(Zero.getReg(0))
        .addUse(Idx.getReg(0));

    MachinePointerInfo EltPtrInfo;
    Align EltAlign = Align(1);
    if (!MI.memoperands_empty()) {
      MachineMemOperand *MMO = *MI.memoperands_begin();
      EltPtrInfo =
          MMO->getPointerInfo().getWithOffset(i * EltTy.getSizeInBytes());
      EltAlign = commonAlignment(MMO->getAlign(), i * EltTy.getSizeInBytes());
    }

    Register EltReg = MRI.createGenericVirtualRegister(EltTy);
    MIRBuilder.buildLoad(EltReg, EltPtr, EltPtrInfo, EltAlign);
    SplitRegs.push_back(EltReg);
  }

  MIRBuilder.buildBuildVector(DstReg, SplitRegs);
  MI.eraseFromParent();
  return true;
}

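// Splits a vector G_STORE whose value type is too wide for SPIR-V into a
// G_UNMERGE_VALUES of the source followed by one scalar store per element
// (addressed via spv_gep).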
static bool legalizeStore(LegalizerHelper &Helper, MachineInstr &MI,
                          SPIRVGlobalRegistry *GR) {
  MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
  Register ValReg = MI.getOperand(0).getReg();
  Register PtrReg = MI.getOperand(1).getReg();
  LLT ValTy = MRI.getType(ValReg);

  assert(ValTy.isVector() && "Expected vector store");

  SmallVector<Register, 8> SplitRegs;
  LLT EltTy = ValTy.getElementType();
  unsigned NumElts = ValTy.getNumElements();

  for (unsigned i = 0; i < NumElts; ++i)
    SplitRegs.push_back(MRI.createGenericVirtualRegister(EltTy));

  MIRBuilder.buildUnmerge(SplitRegs, ValReg);

  LLT PtrTy = MRI.getType(PtrReg);
  auto Zero = MIRBuilder.buildConstant(LLT::scalar(32), 0);

  for (unsigned i = 0; i < NumElts; ++i) {
    auto Idx = MIRBuilder.buildConstant(LLT::scalar(32), i);
    Register EltPtr = MRI.createGenericVirtualRegister(PtrTy);

    MIRBuilder.buildIntrinsic(Intrinsic::spv_gep, ArrayRef<Register>{EltPtr})
        .addImm(1) // InBounds
        .addUse(PtrReg)
        .addUse(Zero.getReg(0))
        .addUse(Idx.getReg(0));

    MachinePointerInfo EltPtrInfo;
    Align EltAlign = Align(1);
    if (!MI.memoperands_empty()) {
      MachineMemOperand *MMO = *MI.memoperands_begin();
      EltPtrInfo =
          MMO->getPointerInfo().getWithOffset(i * EltTy.getSizeInBytes());
      EltAlign = commonAlignment(MMO->getAlign(), i * EltTy.getSizeInBytes());
    }

    MIRBuilder.buildStore(SplitRegs[i], EltPtr, EltPtrInfo, EltAlign);
  }

  MI.eraseFromParent();
  return true;
}

bool SPIRVLegalizerInfo::legalizeCustom(
    LegalizerHelper &Helper, MachineInstr &MI,
    LostDebugLocObserver &LocObserver) const {
  MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
  switch (MI.getOpcode()) {
  default:
    // TODO: implement legalization for other opcodes.
    return true;
  case TargetOpcode::G_BITCAST:
    return legalizeBitcast(Helper, MI);
  case TargetOpcode::G_EXTRACT_VECTOR_ELT:
    return legalizeExtractVectorElt(Helper, MI, GR);
  case TargetOpcode::G_INSERT_VECTOR_ELT:
    return legalizeInsertVectorElt(Helper, MI, GR);
  case TargetOpcode::G_INTRINSIC:
  case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
    return legalizeIntrinsic(Helper, MI);
  case TargetOpcode::G_IS_FPCLASS:
    return legalizeIsFPClass(Helper, MI, LocObserver);
  case TargetOpcode::G_ICMP: {
    assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
    auto &Op0 = MI.getOperand(2);
    auto &Op1 = MI.getOperand(3);
    Register Reg0 = Op0.getReg();
    Register Reg1 = Op1.getReg();
    CmpInst::Predicate Cond =
        static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
    if ((!ST->canDirectlyComparePointers() ||
         (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) &&
        MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
      LLT ConvT = LLT::scalar(ST->getPointerSize());
      Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
                                      ST->getPointerSize());
      SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(
          LLVMTy, Helper.MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
      Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
      Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
    }
    return true;
  }
  case TargetOpcode::G_LOAD:
    return legalizeLoad(Helper, MI, GR);
  case TargetOpcode::G_STORE:
    return legalizeStore(Helper, MI, GR);
  }
}

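// Creates a stack temporary large enough to spill SrcReg's vector, typing the
// resulting pointer as a pointer to an array of the element type in the
// Function storage class so that spv_gep can index into it.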
static MachineInstrBuilder
createStackTemporaryForVector(LegalizerHelper &Helper, SPIRVGlobalRegistry *GR,
                              Register SrcReg, LLT SrcTy,
                              MachinePointerInfo &PtrInfo, Align &VecAlign) {
  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
  MachineRegisterInfo &MRI = *MIRBuilder.getMRI();

  VecAlign = Helper.getStackTemporaryAlignment(SrcTy);
  auto StackTemp = Helper.createStackTemporary(
      TypeSize::getFixed(SrcTy.getSizeInBytes()), VecAlign, PtrInfo);

  // Set the type of StackTemp to a pointer to an array of the element type.
  SPIRVType *SpvSrcTy = GR->getSPIRVTypeForVReg(SrcReg);
  SPIRVType *EltSpvTy = GR->getScalarOrVectorComponentType(SpvSrcTy);
  const Type *LLVMEltTy = GR->getTypeForSPIRVType(EltSpvTy);
  const Type *LLVMArrTy =
      ArrayType::get(const_cast<Type *>(LLVMEltTy), SrcTy.getNumElements());
  SPIRVType *ArrSpvTy = GR->getOrCreateSPIRVType(
      LLVMArrTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
  SPIRVType *PtrToArrSpvTy = GR->getOrCreateSPIRVPointerType(
      ArrSpvTy, MIRBuilder, SPIRV::StorageClass::Function);

  Register StackReg = StackTemp.getReg(0);
  MRI.setRegClass(StackReg, GR->getRegClass(PtrToArrSpvTy));
  GR->assignSPIRVTypeToVReg(PtrToArrSpvTy, StackReg, MIRBuilder.getMF());

  return StackTemp;
}

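// If an spv_bitcast involves vector types SPIR-V cannot represent, replaces
// it with a G_BITCAST so the generic legalization rules apply; otherwise
// leaves the intrinsic alone.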
static bool legalizeSpvBitcast(LegalizerHelper &Helper, MachineInstr &MI,
                               SPIRVGlobalRegistry *GR) {
  LLVM_DEBUG(dbgs() << "Found a bitcast instruction\n");
  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
  MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
  const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();

  Register DstReg = MI.getOperand(0).getReg();
  Register SrcReg = MI.getOperand(2).getReg();
  LLT DstTy = MRI.getType(DstReg);
  LLT SrcTy = MRI.getType(SrcReg);

  // If an spv_bitcast needs to be legalized, we convert it to G_BITCAST to
  // allow using the generic legalization rules.
  if (needsVectorLegalization(DstTy, ST) ||
      needsVectorLegalization(SrcTy, ST)) {
    LLVM_DEBUG(dbgs() << "Replacing with a G_BITCAST\n");
    MIRBuilder.buildBitcast(DstReg, SrcReg);
    MI.eraseFromParent();
  }
  return true;
}

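// Legalizes an spv_insertelt on a vector that is too wide for SPIR-V. With a
// constant in-range index, the vector is unmerged, the element replaced, and
// the result rebuilt; otherwise the vector is spilled to a stack temporary,
// the element stored through an spv_gep, and the vector reloaded.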
static bool legalizeSpvInsertElt(LegalizerHelper &Helper, MachineInstr &MI,
                                 SPIRVGlobalRegistry *GR) {
  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
  MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
  const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();

  Register DstReg = MI.getOperand(0).getReg();
  LLT DstTy = MRI.getType(DstReg);

  if (needsVectorLegalization(DstTy, ST)) {
    Register SrcReg = MI.getOperand(2).getReg();
    Register ValReg = MI.getOperand(3).getReg();
    LLT SrcTy = MRI.getType(SrcReg);
    MachineOperand &IdxOperand = MI.getOperand(4);

    if (getImm(IdxOperand, &MRI)) {
      uint64_t IdxVal = foldImm(IdxOperand, &MRI);
      if (IdxVal < SrcTy.getNumElements()) {
        SmallVector<Register, 8> Regs;
        SPIRVType *ElementType = GR->getScalarOrVectorComponentType(
            GR->getSPIRVTypeForVReg(DstReg));
        LLT ElementLLTTy = GR->getRegType(ElementType);
        for (unsigned I = 0, E = SrcTy.getNumElements(); I < E; ++I) {
          Register Reg = MRI.createGenericVirtualRegister(ElementLLTTy);
          MRI.setRegClass(Reg, GR->getRegClass(ElementType));
          GR->assignSPIRVTypeToVReg(ElementType, Reg, *MI.getMF());
          Regs.push_back(Reg);
        }
        MIRBuilder.buildUnmerge(Regs, SrcReg);
        Regs[IdxVal] = ValReg;
        MIRBuilder.buildBuildVector(DstReg, Regs);
        MI.eraseFromParent();
        return true;
      }
    }

    LLT EltTy = SrcTy.getElementType();
    Align VecAlign;
    MachinePointerInfo PtrInfo;
    auto StackTemp = createStackTemporaryForVector(Helper, GR, SrcReg, SrcTy,
                                                   PtrInfo, VecAlign);

    MIRBuilder.buildStore(SrcReg, StackTemp, PtrInfo, VecAlign);

    Register IdxReg = IdxOperand.getReg();
    LLT PtrTy = MRI.getType(StackTemp.getReg(0));
    Register EltPtr = MRI.createGenericVirtualRegister(PtrTy);
    auto Zero = MIRBuilder.buildConstant(LLT::scalar(32), 0);

    MIRBuilder.buildIntrinsic(Intrinsic::spv_gep, ArrayRef<Register>{EltPtr})
        .addImm(1) // InBounds
        .addUse(StackTemp.getReg(0))
        .addUse(Zero.getReg(0))
        .addUse(IdxReg);

    MachinePointerInfo EltPtrInfo = MachinePointerInfo(PtrTy.getAddressSpace());
    Align EltAlign = Helper.getStackTemporaryAlignment(EltTy);
    MIRBuilder.buildStore(ValReg, EltPtr, EltPtrInfo, EltAlign);

    MIRBuilder.buildLoad(DstReg, StackTemp, PtrInfo, VecAlign);
    MI.eraseFromParent();
    return true;
  }
  return true;
}

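// Legalizes an spv_extractelt on a vector that is too wide for SPIR-V. With
// a constant in-range index, the vector is unmerged directly into the
// destination register; otherwise the vector is spilled to a stack temporary
// and the element loaded back through an spv_gep.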
static bool legalizeSpvExtractElt(LegalizerHelper &Helper, MachineInstr &MI,
                                  SPIRVGlobalRegistry *GR) {
  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
  MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
  const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();

  Register SrcReg = MI.getOperand(2).getReg();
  LLT SrcTy = MRI.getType(SrcReg);

  if (needsVectorLegalization(SrcTy, ST)) {
    Register DstReg = MI.getOperand(0).getReg();
    MachineOperand &IdxOperand = MI.getOperand(3);

    if (getImm(IdxOperand, &MRI)) {
      uint64_t IdxVal = foldImm(IdxOperand, &MRI);
      if (IdxVal < SrcTy.getNumElements()) {
        LLT DstTy = MRI.getType(DstReg);
        SmallVector<Register, 8> Regs;
        SPIRVType *DstSpvTy = GR->getSPIRVTypeForVReg(DstReg);
        for (unsigned I = 0, E = SrcTy.getNumElements(); I < E; ++I) {
          if (I == IdxVal) {
            Regs.push_back(DstReg);
          } else {
            Register Reg = MRI.createGenericVirtualRegister(DstTy);
            MRI.setRegClass(Reg, GR->getRegClass(DstSpvTy));
            GR->assignSPIRVTypeToVReg(DstSpvTy, Reg, *MI.getMF());
            Regs.push_back(Reg);
          }
        }
        MIRBuilder.buildUnmerge(Regs, SrcReg);
        MI.eraseFromParent();
        return true;
      }
    }

    LLT EltTy = SrcTy.getElementType();
    Align VecAlign;
    MachinePointerInfo PtrInfo;
    auto StackTemp = createStackTemporaryForVector(Helper, GR, SrcReg, SrcTy,
                                                   PtrInfo, VecAlign);

    MIRBuilder.buildStore(SrcReg, StackTemp, PtrInfo, VecAlign);

    Register IdxReg = IdxOperand.getReg();
    LLT PtrTy = MRI.getType(StackTemp.getReg(0));
    Register EltPtr = MRI.createGenericVirtualRegister(PtrTy);
    auto Zero = MIRBuilder.buildConstant(LLT::scalar(32), 0);

    MIRBuilder.buildIntrinsic(Intrinsic::spv_gep, ArrayRef<Register>{EltPtr})
        .addImm(1) // InBounds
        .addUse(StackTemp.getReg(0))
        .addUse(Zero.getReg(0))
        .addUse(IdxReg);

    MachinePointerInfo EltPtrInfo = MachinePointerInfo(PtrTy.getAddressSpace());
    Align EltAlign = Helper.getStackTemporaryAlignment(EltTy);
    MIRBuilder.buildLoad(DstReg, EltPtr, EltPtrInfo, EltAlign);

    MI.eraseFromParent();
    return true;
  }
  return true;
}

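// Expands an spv_const_composite whose vector type is too wide for SPIR-V
// into a G_BUILD_VECTOR, splatting zero when the composite carries no values
// (the "null" constant case).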
static bool legalizeSpvConstComposite(LegalizerHelper &Helper, MachineInstr &MI,
                                      SPIRVGlobalRegistry *GR) {
  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
  MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
  const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();

  Register DstReg = MI.getOperand(0).getReg();
  LLT DstTy = MRI.getType(DstReg);

  if (!needsVectorLegalization(DstTy, ST))
    return true;

  SmallVector<Register, 8> SrcRegs;
  if (MI.getNumOperands() == 2) {
    // The "null" case: no values are attached.
    LLT EltTy = DstTy.getElementType();
    auto Zero = MIRBuilder.buildConstant(EltTy, 0);
    SPIRVType *SpvDstTy = GR->getSPIRVTypeForVReg(DstReg);
    SPIRVType *SpvEltTy = GR->getScalarOrVectorComponentType(SpvDstTy);
    GR->assignSPIRVTypeToVReg(SpvEltTy, Zero.getReg(0), MIRBuilder.getMF());
    for (unsigned i = 0; i < DstTy.getNumElements(); ++i)
      SrcRegs.push_back(Zero.getReg(0));
  } else {
    for (unsigned i = 2; i < MI.getNumOperands(); ++i) {
      SrcRegs.push_back(MI.getOperand(i).getReg());
    }
  }
  MIRBuilder.buildBuildVector(DstReg, SrcRegs);
  MI.eraseFromParent();
  return true;
}

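// Dispatches the SPIR-V intrinsics whose vector operands may need to be
// legalized; all other intrinsics are left untouched.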
bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
                                           MachineInstr &MI) const {
  LLVM_DEBUG(dbgs() << "legalizeIntrinsic: " << MI);
  auto IntrinsicID = cast<GIntrinsic>(MI).getIntrinsicID();
  switch (IntrinsicID) {
  case Intrinsic::spv_bitcast:
    return legalizeSpvBitcast(Helper, MI, GR);
  case Intrinsic::spv_insertelt:
    return legalizeSpvInsertElt(Helper, MI, GR);
  case Intrinsic::spv_extractelt:
    return legalizeSpvExtractElt(Helper, MI, GR);
  case Intrinsic::spv_const_composite:
    return legalizeSpvConstComposite(Helper, MI, GR);
  }
  return true;
}

bool SPIRVLegalizerInfo::legalizeBitcast(LegalizerHelper &Helper,
                                         MachineInstr &MI) const {
  // Once the G_BITCAST is using vectors that are allowed, we turn it back into
  // an spv_bitcast to avoid verifier problems when the register types are the
  // same for the source and the result. Note that the SPIR-V types associated
  // with the bitcast can be different even if the register types are the same.
  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
  Register DstReg = MI.getOperand(0).getReg();
  Register SrcReg = MI.getOperand(1).getReg();
  SmallVector<Register, 1> DstRegs = {DstReg};
  MIRBuilder.buildIntrinsic(Intrinsic::spv_bitcast, DstRegs).addUse(SrcReg);
  MI.eraseFromParent();
  return true;
}

// Note this code was copied from LegalizerHelper::lowerISFPCLASS and adjusted
// to ensure that all instructions created during the lowering have SPIR-V
// types assigned to them.
bool SPIRVLegalizerInfo::legalizeIsFPClass(
    LegalizerHelper &Helper, MachineInstr &MI,
    LostDebugLocObserver &LocObserver) const {
  auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
  FPClassTest Mask = static_cast<FPClassTest>(MI.getOperand(2).getImm());

  auto &MIRBuilder = Helper.MIRBuilder;
  auto &MF = MIRBuilder.getMF();
  MachineRegisterInfo &MRI = MF.getRegInfo();

  Type *LLVMDstTy =
      IntegerType::get(MIRBuilder.getContext(), DstTy.getScalarSizeInBits());
  if (DstTy.isVector())
    LLVMDstTy = VectorType::get(LLVMDstTy, DstTy.getElementCount());
  SPIRVType *SPIRVDstTy = GR->getOrCreateSPIRVType(
      LLVMDstTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite,
      /*EmitIR*/ true);

  unsigned BitSize = SrcTy.getScalarSizeInBits();
  const fltSemantics &Semantics = getFltSemanticForLLT(SrcTy.getScalarType());

  LLT IntTy = LLT::scalar(BitSize);
  Type *LLVMIntTy = IntegerType::get(MIRBuilder.getContext(), BitSize);
  if (SrcTy.isVector()) {
    IntTy = LLT::vector(SrcTy.getElementCount(), IntTy);
    LLVMIntTy = VectorType::get(LLVMIntTy, SrcTy.getElementCount());
  }
  SPIRVType *SPIRVIntTy = GR->getOrCreateSPIRVType(
      LLVMIntTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite,
      /*EmitIR*/ true);

  // Clang doesn't support capture of structured bindings:
  LLT DstTyCopy = DstTy;
  const auto assignSPIRVTy = [&](MachineInstrBuilder &&MI) {
    // Assign this MI's (assumed only) destination to one of the two types we
    // expect: either the G_IS_FPCLASS's destination type, or the integer type
    // bitcast from the source type.
    LLT MITy = MRI.getType(MI.getReg(0));
    assert((MITy == IntTy || MITy == DstTyCopy) &&
           "Unexpected LLT type while lowering G_IS_FPCLASS");
    auto *SPVTy = MITy == IntTy ? SPIRVIntTy : SPIRVDstTy;
    GR->assignSPIRVTypeToVReg(SPVTy, MI.getReg(0), MF);
    return MI;
  };

  // Helper to build and assign a constant in one go.
  const auto buildSPIRVConstant = [&](LLT Ty, auto &&C) -> MachineInstrBuilder {
    if (!Ty.isFixedVector())
      return assignSPIRVTy(MIRBuilder.buildConstant(Ty, C));
    auto ScalarC = MIRBuilder.buildConstant(Ty.getScalarType(), C);
    assert((Ty == IntTy || Ty == DstTyCopy) &&
           "Unexpected LLT type while lowering constant for G_IS_FPCLASS");
    SPIRVType *VecEltTy = GR->getOrCreateSPIRVType(
        (Ty == IntTy ? LLVMIntTy : LLVMDstTy)->getScalarType(), MIRBuilder,
        SPIRV::AccessQualifier::ReadWrite,
        /*EmitIR*/ true);
    GR->assignSPIRVTypeToVReg(VecEltTy, ScalarC.getReg(0), MF);
    return assignSPIRVTy(MIRBuilder.buildSplatBuildVector(Ty, ScalarC));
  };

  if (Mask == fcNone) {
    MIRBuilder.buildCopy(DstReg, buildSPIRVConstant(DstTy, 0));
    MI.eraseFromParent();
    return true;
  }
  if (Mask == fcAllFlags) {
    MIRBuilder.buildCopy(DstReg, buildSPIRVConstant(DstTy, 1));
    MI.eraseFromParent();
    return true;
  }

  // Note that rather than creating a COPY here (between a floating-point and
  // integer type of the same size) we create a SPIR-V bitcast immediately. We
  // can't create a G_BITCAST because the LLTs are the same, and we can't seem
  // to correctly lower COPYs to SPIR-V bitcasts at this moment.
  Register ResVReg = MRI.createGenericVirtualRegister(IntTy);
  MRI.setRegClass(ResVReg, GR->getRegClass(SPIRVIntTy));
  GR->assignSPIRVTypeToVReg(SPIRVIntTy, ResVReg, Helper.MIRBuilder.getMF());
  auto AsInt = MIRBuilder.buildInstr(SPIRV::OpBitcast)
                   .addDef(ResVReg)
                   .addUse(GR->getSPIRVTypeID(SPIRVIntTy))
                   .addUse(SrcReg);
  AsInt = assignSPIRVTy(std::move(AsInt));

  // Various masks.
  APInt SignBit = APInt::getSignMask(BitSize);
  APInt ValueMask = APInt::getSignedMaxValue(BitSize);     // All bits but sign.
  APInt Inf = APFloat::getInf(Semantics).bitcastToAPInt(); // Exp and int bit.
  APInt ExpMask = Inf;
  APInt AllOneMantissa = APFloat::getLargest(Semantics).bitcastToAPInt() & ~Inf;
  APInt QNaNBitMask =
      APInt::getOneBitSet(BitSize, AllOneMantissa.getActiveBits() - 1);
  APInt InversionMask = APInt::getAllOnes(DstTy.getScalarSizeInBits());

  auto SignBitC = buildSPIRVConstant(IntTy, SignBit);
  auto ValueMaskC = buildSPIRVConstant(IntTy, ValueMask);
  auto InfC = buildSPIRVConstant(IntTy, Inf);
  auto ExpMaskC = buildSPIRVConstant(IntTy, ExpMask);
  auto ZeroC = buildSPIRVConstant(IntTy, 0);

  auto Abs = assignSPIRVTy(MIRBuilder.buildAnd(IntTy, AsInt, ValueMaskC));
  auto Sign = assignSPIRVTy(
      MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_NE, DstTy, AsInt, Abs));

  auto Res = buildSPIRVConstant(DstTy, 0);

  const auto appendToRes = [&](MachineInstrBuilder &&ToAppend) {
    Res = assignSPIRVTy(
        MIRBuilder.buildOr(DstTyCopy, Res, assignSPIRVTy(std::move(ToAppend))));
  };

  // Tests that involve more than one class should be processed first.
  if ((Mask & fcFinite) == fcFinite) {
    // finite(V) ==> abs(V) u< exp_mask
    appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, Abs,
                                     ExpMaskC));
    Mask &= ~fcFinite;
  } else if ((Mask & fcFinite) == fcPosFinite) {
    // finite(V) && V > 0 ==> V u< exp_mask
    appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, AsInt,
                                     ExpMaskC));
    Mask &= ~fcPosFinite;
  } else if ((Mask & fcFinite) == fcNegFinite) {
    // finite(V) && V < 0 ==> abs(V) u< exp_mask && signbit == 1
    auto Cmp = assignSPIRVTy(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT,
                                                  DstTy, Abs, ExpMaskC));
    appendToRes(MIRBuilder.buildAnd(DstTy, Cmp, Sign));
    Mask &= ~fcNegFinite;
  }

  if (FPClassTest PartialCheck = Mask & (fcZero | fcSubnormal)) {
    // fcZero | fcSubnormal => test all exponent bits are 0
    // TODO: Handle sign bit specific cases
    // TODO: Handle inverted case
    if (PartialCheck == (fcZero | fcSubnormal)) {
      auto ExpBits = assignSPIRVTy(MIRBuilder.buildAnd(IntTy, AsInt, ExpMaskC));
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
                                       ExpBits, ZeroC));
      Mask &= ~PartialCheck;
    }
  }

  // Check for individual classes.
  if (FPClassTest PartialCheck = Mask & fcZero) {
    if (PartialCheck == fcPosZero)
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
                                       AsInt, ZeroC));
    else if (PartialCheck == fcZero)
      appendToRes(
          MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, Abs, ZeroC));
    else // fcNegZero
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
                                       AsInt, SignBitC));
  }

  if (FPClassTest PartialCheck = Mask & fcSubnormal) {
    // issubnormal(V) ==> unsigned(abs(V) - 1) u< (all mantissa bits set)
    // issubnormal(V) && V>0 ==> unsigned(V - 1) u< (all mantissa bits set)
    auto V = (PartialCheck == fcPosSubnormal) ? AsInt : Abs;
    auto OneC = buildSPIRVConstant(IntTy, 1);
    auto VMinusOne = MIRBuilder.buildSub(IntTy, V, OneC);
    auto SubnormalRes = assignSPIRVTy(
        MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, VMinusOne,
                             buildSPIRVConstant(IntTy, AllOneMantissa)));
    if (PartialCheck == fcNegSubnormal)
      SubnormalRes = MIRBuilder.buildAnd(DstTy, SubnormalRes, Sign);
    appendToRes(std::move(SubnormalRes));
  }

  if (FPClassTest PartialCheck = Mask & fcInf) {
    if (PartialCheck == fcPosInf)
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
                                       AsInt, InfC));
    else if (PartialCheck == fcInf)
      appendToRes(
          MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, Abs, InfC));
    else { // fcNegInf
      APInt NegInf = APFloat::getInf(Semantics, true).bitcastToAPInt();
      auto NegInfC = buildSPIRVConstant(IntTy, NegInf);
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
                                       AsInt, NegInfC));
    }
  }

  if (FPClassTest PartialCheck = Mask & fcNan) {
    auto InfWithQnanBitC =
        buildSPIRVConstant(IntTy, std::move(Inf) | QNaNBitMask);
    if (PartialCheck == fcNan) {
      // isnan(V) ==> abs(V) u> int(inf)
      appendToRes(
          MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC));
    } else if (PartialCheck == fcQNan) {
      // isquiet(V) ==> abs(V) u>= (unsigned(Inf) | quiet_bit)
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGE, DstTy, Abs,
                                       InfWithQnanBitC));
    } else { // fcSNan
      // issignaling(V) ==> abs(V) u> unsigned(Inf) &&
      //                    abs(V) u< (unsigned(Inf) | quiet_bit)
      auto IsNan = assignSPIRVTy(
          MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC));
      auto IsNotQnan = assignSPIRVTy(MIRBuilder.buildICmp(
          CmpInst::Predicate::ICMP_ULT, DstTy, Abs, InfWithQnanBitC));
      appendToRes(MIRBuilder.buildAnd(DstTy, IsNan, IsNotQnan));
    }
  }

  if (FPClassTest PartialCheck = Mask & fcNormal) {
    // isnormal(V) ==> (0 u< exp u< max_exp) ==> (unsigned(exp-1) u<
    // (max_exp-1))
    APInt ExpLSB = ExpMask & ~(ExpMask.shl(1));
    auto ExpMinusOne = assignSPIRVTy(
        MIRBuilder.buildSub(IntTy, Abs, buildSPIRVConstant(IntTy, ExpLSB)));
    APInt MaxExpMinusOne = std::move(ExpMask) - ExpLSB;
    auto NormalRes = assignSPIRVTy(
        MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, ExpMinusOne,
                             buildSPIRVConstant(IntTy, MaxExpMinusOne)));
    if (PartialCheck == fcNegNormal)
      NormalRes = MIRBuilder.buildAnd(DstTy, NormalRes, Sign);
    else if (PartialCheck == fcPosNormal) {
      auto PosSign = assignSPIRVTy(MIRBuilder.buildXor(
          DstTy, Sign, buildSPIRVConstant(DstTy, InversionMask)));
      NormalRes = MIRBuilder.buildAnd(DstTy, NormalRes, PosSign);
    }
    appendToRes(std::move(NormalRes));
  }

  MIRBuilder.buildCopy(DstReg, Res);
  MI.eraseFromParent();
  return true;
}