1//===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the targeting of the Machinelegalizer class for SPIR-V.
10//
11//===----------------------------------------------------------------------===//
12
13#include "SPIRVLegalizerInfo.h"
14#include "SPIRV.h"
15#include "SPIRVGlobalRegistry.h"
16#include "SPIRVSubtarget.h"
17#include "SPIRVUtils.h"
18#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
19#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
20#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
21#include "llvm/CodeGen/MachineInstr.h"
22#include "llvm/CodeGen/MachineRegisterInfo.h"
23#include "llvm/CodeGen/TargetOpcodes.h"
24#include "llvm/IR/IntrinsicsSPIRV.h"
25#include "llvm/Support/Debug.h"
26#include "llvm/Support/MathExtras.h"
27
28using namespace llvm;
29using namespace llvm::LegalizeActions;
30using namespace llvm::LegalityPredicates;
31
32#define DEBUG_TYPE "spirv-legalizer"
33
34LegalityPredicate typeOfExtendedScalars(unsigned TypeIdx, bool IsExtendedInts) {
35 return [IsExtendedInts, TypeIdx](const LegalityQuery &Query) {
36 const LLT Ty = Query.Types[TypeIdx];
37 return IsExtendedInts && Ty.isValid() && Ty.isScalar();
38 };
39}
40
41SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
42 using namespace TargetOpcode;
43
44 this->ST = &ST;
45 GR = ST.getSPIRVGlobalRegistry();
46
47 const LLT s1 = LLT::scalar(SizeInBits: 1);
48 const LLT s8 = LLT::scalar(SizeInBits: 8);
49 const LLT s16 = LLT::scalar(SizeInBits: 16);
50 const LLT s32 = LLT::scalar(SizeInBits: 32);
51 const LLT s64 = LLT::scalar(SizeInBits: 64);
52 const LLT s128 = LLT::scalar(SizeInBits: 128);
53
54 const LLT v16s64 = LLT::fixed_vector(NumElements: 16, ScalarSizeInBits: 64);
55 const LLT v16s32 = LLT::fixed_vector(NumElements: 16, ScalarSizeInBits: 32);
56 const LLT v16s16 = LLT::fixed_vector(NumElements: 16, ScalarSizeInBits: 16);
57 const LLT v16s8 = LLT::fixed_vector(NumElements: 16, ScalarSizeInBits: 8);
58 const LLT v16s1 = LLT::fixed_vector(NumElements: 16, ScalarSizeInBits: 1);
59
60 const LLT v8s64 = LLT::fixed_vector(NumElements: 8, ScalarSizeInBits: 64);
61 const LLT v8s32 = LLT::fixed_vector(NumElements: 8, ScalarSizeInBits: 32);
62 const LLT v8s16 = LLT::fixed_vector(NumElements: 8, ScalarSizeInBits: 16);
63 const LLT v8s8 = LLT::fixed_vector(NumElements: 8, ScalarSizeInBits: 8);
64 const LLT v8s1 = LLT::fixed_vector(NumElements: 8, ScalarSizeInBits: 1);
65
66 const LLT v4s64 = LLT::fixed_vector(NumElements: 4, ScalarSizeInBits: 64);
67 const LLT v4s32 = LLT::fixed_vector(NumElements: 4, ScalarSizeInBits: 32);
68 const LLT v4s16 = LLT::fixed_vector(NumElements: 4, ScalarSizeInBits: 16);
69 const LLT v4s8 = LLT::fixed_vector(NumElements: 4, ScalarSizeInBits: 8);
70 const LLT v4s1 = LLT::fixed_vector(NumElements: 4, ScalarSizeInBits: 1);
71
72 const LLT v3s64 = LLT::fixed_vector(NumElements: 3, ScalarSizeInBits: 64);
73 const LLT v3s32 = LLT::fixed_vector(NumElements: 3, ScalarSizeInBits: 32);
74 const LLT v3s16 = LLT::fixed_vector(NumElements: 3, ScalarSizeInBits: 16);
75 const LLT v3s8 = LLT::fixed_vector(NumElements: 3, ScalarSizeInBits: 8);
76 const LLT v3s1 = LLT::fixed_vector(NumElements: 3, ScalarSizeInBits: 1);
77
78 const LLT v2s64 = LLT::fixed_vector(NumElements: 2, ScalarSizeInBits: 64);
79 const LLT v2s32 = LLT::fixed_vector(NumElements: 2, ScalarSizeInBits: 32);
80 const LLT v2s16 = LLT::fixed_vector(NumElements: 2, ScalarSizeInBits: 16);
81 const LLT v2s8 = LLT::fixed_vector(NumElements: 2, ScalarSizeInBits: 8);
82 const LLT v2s1 = LLT::fixed_vector(NumElements: 2, ScalarSizeInBits: 1);
83
84 const unsigned PSize = ST.getPointerSize();
85 const LLT p0 = LLT::pointer(AddressSpace: 0, SizeInBits: PSize); // Function
86 const LLT p1 = LLT::pointer(AddressSpace: 1, SizeInBits: PSize); // CrossWorkgroup
87 const LLT p2 = LLT::pointer(AddressSpace: 2, SizeInBits: PSize); // UniformConstant
88 const LLT p3 = LLT::pointer(AddressSpace: 3, SizeInBits: PSize); // Workgroup
89 const LLT p4 = LLT::pointer(AddressSpace: 4, SizeInBits: PSize); // Generic
90 const LLT p5 =
91 LLT::pointer(AddressSpace: 5, SizeInBits: PSize); // Input, SPV_INTEL_usm_storage_classes (Device)
92 const LLT p6 = LLT::pointer(AddressSpace: 6, SizeInBits: PSize); // SPV_INTEL_usm_storage_classes (Host)
93 const LLT p7 = LLT::pointer(AddressSpace: 7, SizeInBits: PSize); // Input
94 const LLT p8 = LLT::pointer(AddressSpace: 8, SizeInBits: PSize); // Output
95 const LLT p9 =
96 LLT::pointer(AddressSpace: 9, SizeInBits: PSize); // CodeSectionINTEL, SPV_INTEL_function_pointers
97 const LLT p10 = LLT::pointer(AddressSpace: 10, SizeInBits: PSize); // Private
98 const LLT p11 = LLT::pointer(AddressSpace: 11, SizeInBits: PSize); // StorageBuffer
99 const LLT p12 = LLT::pointer(AddressSpace: 12, SizeInBits: PSize); // Uniform
100 const LLT p13 = LLT::pointer(AddressSpace: 13, SizeInBits: PSize); // PushConstant
101
102 // TODO: remove copy-pasting here by using concatenation in some way.
103 auto allPtrsScalarsAndVectors = {
104 p0, p1, p2, p3, p4, p5, p6, p7, p8,
105 p9, p10, p11, p12, p13, s1, s8, s16, s32,
106 s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16,
107 v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8,
108 v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
109
110 auto allVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8,
111 v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32,
112 v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1,
113 v16s8, v16s16, v16s32, v16s64};
114
115 auto allShaderVectors = {v2s1, v2s8, v2s16, v2s32, v2s64,
116 v3s1, v3s8, v3s16, v3s32, v3s64,
117 v4s1, v4s8, v4s16, v4s32, v4s64};
118
119 auto allScalars = {s1, s8, s16, s32, s64};
120
121 auto allScalarsAndVectors = {
122 s1, s8, s16, s32, s64, s128, v2s1, v2s8,
123 v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, v3s32, v3s64,
124 v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16,
125 v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
126
127 auto allIntScalarsAndVectors = {
128 s8, s16, s32, s64, s128, v2s8, v2s16, v2s32, v2s64,
129 v3s8, v3s16, v3s32, v3s64, v4s8, v4s16, v4s32, v4s64, v8s8,
130 v8s16, v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
131
132 auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
133
134 auto allIntScalars = {s8, s16, s32, s64, s128};
135
136 auto allFloatScalarsAndF16Vector2AndVector4s = {s16, s32, s64, v2s16, v4s16};
137
138 auto allFloatScalarsAndVectors = {
139 s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64,
140 v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
141
142 auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1,
143 p2, p3, p4, p5, p6, p7,
144 p8, p9, p10, p11, p12, p13};
145
146 auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11, p12, p13};
147
148 auto &allowedVectorTypes = ST.isShader() ? allShaderVectors : allVectors;
149
150 bool IsExtendedInts =
151 ST.canUseExtension(
152 E: SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers) ||
153 ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_bit_instructions) ||
154 ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_int4);
155 auto extendedScalarsAndVectors =
156 [IsExtendedInts](const LegalityQuery &Query) {
157 const LLT Ty = Query.Types[0];
158 return IsExtendedInts && Ty.isValid() && !Ty.isPointerOrPointerVector();
159 };
160 auto extendedScalarsAndVectorsProduct = [IsExtendedInts](
161 const LegalityQuery &Query) {
162 const LLT Ty1 = Query.Types[0], Ty2 = Query.Types[1];
163 return IsExtendedInts && Ty1.isValid() && Ty2.isValid() &&
164 !Ty1.isPointerOrPointerVector() && !Ty2.isPointerOrPointerVector();
165 };
166 auto extendedPtrsScalarsAndVectors =
167 [IsExtendedInts](const LegalityQuery &Query) {
168 const LLT Ty = Query.Types[0];
169 return IsExtendedInts && Ty.isValid();
170 };
171
172 // The universal validation rules in the SPIR-V specification state that
173 // vector sizes are typically limited to 2, 3, or 4. However, larger vector
174 // sizes (8 and 16) are enabled when the Kernel capability is present. For
175 // shader execution models, vector sizes are strictly limited to 4. In
176 // non-shader contexts, vector sizes of 8 and 16 are also permitted, but
177 // arbitrary sizes (e.g., 6 or 11) are not.
178 uint32_t MaxVectorSize = ST.isShader() ? 4 : 16;
179 LLVM_DEBUG(dbgs() << "MaxVectorSize: " << MaxVectorSize << "\n");
180
181 for (auto Opc : getTypeFoldingSupportedOpcodes()) {
182 switch (Opc) {
183 case G_EXTRACT_VECTOR_ELT:
184 case G_UREM:
185 case G_SREM:
186 case G_UDIV:
187 case G_SDIV:
188 case G_FREM:
189 break;
190 default:
191 getActionDefinitionsBuilder(Opcode: Opc)
192 .customFor(Types: allScalars)
193 .customFor(Types: allowedVectorTypes)
194 .moreElementsToNextPow2(TypeIdx: 0)
195 .fewerElementsIf(Predicate: vectorElementCountIsGreaterThan(TypeIdx: 0, Size: MaxVectorSize),
196 Mutation: LegalizeMutations::changeElementCountTo(
197 TypeIdx: 0, EC: ElementCount::getFixed(MinVal: MaxVectorSize)))
198 .custom();
199 break;
200 }
201 }
202
203 getActionDefinitionsBuilder(Opcodes: {G_UREM, G_SREM, G_SDIV, G_UDIV, G_FREM})
204 .customFor(Types: allScalars)
205 .customFor(Types: allowedVectorTypes)
206 .scalarizeIf(Predicate: numElementsNotPow2(TypeIdx: 0), TypeIdx: 0)
207 .fewerElementsIf(Predicate: vectorElementCountIsGreaterThan(TypeIdx: 0, Size: MaxVectorSize),
208 Mutation: LegalizeMutations::changeElementCountTo(
209 TypeIdx: 0, EC: ElementCount::getFixed(MinVal: MaxVectorSize)))
210 .custom();
211
212 getActionDefinitionsBuilder(Opcodes: {G_FMA, G_STRICT_FMA})
213 .legalFor(Types: allScalars)
214 .legalFor(Types: allowedVectorTypes)
215 .moreElementsToNextPow2(TypeIdx: 0)
216 .fewerElementsIf(Predicate: vectorElementCountIsGreaterThan(TypeIdx: 0, Size: MaxVectorSize),
217 Mutation: LegalizeMutations::changeElementCountTo(
218 TypeIdx: 0, EC: ElementCount::getFixed(MinVal: MaxVectorSize)))
219 .alwaysLegal();
220
221 getActionDefinitionsBuilder(Opcode: G_INTRINSIC_W_SIDE_EFFECTS).custom();
222
223 getActionDefinitionsBuilder(Opcode: G_SHUFFLE_VECTOR)
224 .legalForCartesianProduct(Types0: allowedVectorTypes, Types1: allowedVectorTypes)
225 .moreElementsToNextPow2(TypeIdx: 0)
226 .lowerIf(Predicate: vectorElementCountIsGreaterThan(TypeIdx: 0, Size: MaxVectorSize))
227 .moreElementsToNextPow2(TypeIdx: 1)
228 .lowerIf(Predicate: vectorElementCountIsGreaterThan(TypeIdx: 1, Size: MaxVectorSize));
229
230 getActionDefinitionsBuilder(Opcode: G_EXTRACT_VECTOR_ELT)
231 .moreElementsToNextPow2(TypeIdx: 1)
232 .fewerElementsIf(Predicate: vectorElementCountIsGreaterThan(TypeIdx: 1, Size: MaxVectorSize),
233 Mutation: LegalizeMutations::changeElementCountTo(
234 TypeIdx: 1, EC: ElementCount::getFixed(MinVal: MaxVectorSize)))
235 .custom();
236
237 getActionDefinitionsBuilder(Opcode: G_INSERT_VECTOR_ELT)
238 .moreElementsToNextPow2(TypeIdx: 0)
239 .fewerElementsIf(Predicate: vectorElementCountIsGreaterThan(TypeIdx: 0, Size: MaxVectorSize),
240 Mutation: LegalizeMutations::changeElementCountTo(
241 TypeIdx: 0, EC: ElementCount::getFixed(MinVal: MaxVectorSize)))
242 .custom();
243
244 // Illegal G_UNMERGE_VALUES instructions should be handled
245 // during the combine phase.
246 getActionDefinitionsBuilder(Opcode: G_BUILD_VECTOR)
247 .legalIf(Predicate: vectorElementCountIsLessThanOrEqualTo(TypeIdx: 0, Size: MaxVectorSize));
248
249 // When entering the legalizer, there should be no G_BITCAST instructions.
250 // They should all be calls to the `spv_bitcast` intrinsic. The call to
251 // the intrinsic will be converted to a G_BITCAST during legalization if
252 // the vectors are not legal. After using the rules to legalize a G_BITCAST,
253 // we turn it back into a call to the intrinsic with a custom rule to avoid
254 // potential machine verifier failures.
255 getActionDefinitionsBuilder(Opcode: G_BITCAST)
256 .moreElementsToNextPow2(TypeIdx: 0)
257 .moreElementsToNextPow2(TypeIdx: 1)
258 .fewerElementsIf(Predicate: vectorElementCountIsGreaterThan(TypeIdx: 0, Size: MaxVectorSize),
259 Mutation: LegalizeMutations::changeElementCountTo(
260 TypeIdx: 0, EC: ElementCount::getFixed(MinVal: MaxVectorSize)))
261 .lowerIf(Predicate: vectorElementCountIsGreaterThan(TypeIdx: 1, Size: MaxVectorSize))
262 .custom();
263
264 // If the result is still illegal, the combiner should be able to remove it.
265 getActionDefinitionsBuilder(Opcode: G_CONCAT_VECTORS)
266 .legalForCartesianProduct(Types0: allowedVectorTypes, Types1: allowedVectorTypes);
267
268 getActionDefinitionsBuilder(Opcode: G_SPLAT_VECTOR)
269 .legalFor(Types: allowedVectorTypes)
270 .moreElementsToNextPow2(TypeIdx: 0)
271 .fewerElementsIf(Predicate: vectorElementCountIsGreaterThan(TypeIdx: 0, Size: MaxVectorSize),
272 Mutation: LegalizeMutations::changeElementSizeTo(TypeIdx: 0, FromTypeIdx: MaxVectorSize))
273 .alwaysLegal();
274
275 // Vector Reduction Operations
276 getActionDefinitionsBuilder(
277 Opcodes: {G_VECREDUCE_SMIN, G_VECREDUCE_SMAX, G_VECREDUCE_UMIN, G_VECREDUCE_UMAX,
278 G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
279 G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
280 G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
281 .legalFor(Types: allowedVectorTypes)
282 .scalarize(TypeIdx: 1)
283 .lower();
284
285 getActionDefinitionsBuilder(Opcodes: {G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL})
286 .scalarize(TypeIdx: 2)
287 .lower();
288
289 // Illegal G_UNMERGE_VALUES instructions should be handled
290 // during the combine phase.
291 getActionDefinitionsBuilder(Opcode: G_UNMERGE_VALUES)
292 .legalIf(Predicate: vectorElementCountIsLessThanOrEqualTo(TypeIdx: 1, Size: MaxVectorSize));
293
294 getActionDefinitionsBuilder(Opcodes: {G_MEMCPY, G_MEMMOVE})
295 .unsupportedIf(Predicate: LegalityPredicates::any(P0: typeIs(TypeIdx: 0, TypesInit: p9), P1: typeIs(TypeIdx: 1, TypesInit: p9)))
296 .legalIf(Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allPtrs), P1: typeInSet(TypeIdx: 1, TypesInit: allPtrs)));
297
298 getActionDefinitionsBuilder(Opcode: G_MEMSET)
299 .unsupportedIf(Predicate: typeIs(TypeIdx: 0, TypesInit: p9))
300 .legalIf(Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allPtrs), P1: typeInSet(TypeIdx: 1, TypesInit: allIntScalars)));
301
302 getActionDefinitionsBuilder(Opcode: G_ADDRSPACE_CAST)
303 .unsupportedIf(
304 Predicate: LegalityPredicates::any(P0: all(P0: typeIs(TypeIdx: 0, TypesInit: p9), P1: typeIsNot(TypeIdx: 1, Type: p9)),
305 P1: all(P0: typeIsNot(TypeIdx: 0, Type: p9), P1: typeIs(TypeIdx: 1, TypesInit: p9))))
306 .legalForCartesianProduct(Types0: allPtrs, Types1: allPtrs);
307
308 // Should we be legalizing bad scalar sizes like s5 here instead
309 // of handling them in the instruction selector?
310 getActionDefinitionsBuilder(Opcodes: {G_LOAD, G_STORE})
311 .unsupportedIf(Predicate: typeIs(TypeIdx: 1, TypesInit: p9))
312 .legalForCartesianProduct(Types0: allowedVectorTypes, Types1: allPtrs)
313 .legalForCartesianProduct(Types0: allPtrs, Types1: allPtrs)
314 .legalIf(Predicate: isScalar(TypeIdx: 0))
315 .custom();
316
317 getActionDefinitionsBuilder(Opcodes: {G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS,
318 G_BITREVERSE, G_SADDSAT, G_UADDSAT, G_SSUBSAT,
319 G_USUBSAT, G_SCMP, G_UCMP})
320 .legalFor(Types: allIntScalarsAndVectors)
321 .legalIf(Predicate: extendedScalarsAndVectors);
322
323 getActionDefinitionsBuilder(Opcode: G_STRICT_FLDEXP)
324 .legalForCartesianProduct(Types0: allFloatScalarsAndVectors, Types1: allIntScalars);
325
326 getActionDefinitionsBuilder(Opcodes: {G_FPTOSI, G_FPTOUI})
327 .legalForCartesianProduct(Types0: allIntScalarsAndVectors,
328 Types1: allFloatScalarsAndVectors);
329
330 getActionDefinitionsBuilder(Opcodes: {G_FPTOSI_SAT, G_FPTOUI_SAT})
331 .legalForCartesianProduct(Types0: allIntScalarsAndVectors,
332 Types1: allFloatScalarsAndVectors);
333
334 getActionDefinitionsBuilder(Opcodes: {G_SITOFP, G_UITOFP})
335 .legalForCartesianProduct(Types0: allFloatScalarsAndVectors,
336 Types1: allScalarsAndVectors);
337
338 getActionDefinitionsBuilder(Opcode: G_CTPOP)
339 .legalForCartesianProduct(Types: allIntScalarsAndVectors)
340 .legalIf(Predicate: extendedScalarsAndVectorsProduct);
341
342 // Extensions.
343 getActionDefinitionsBuilder(Opcodes: {G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
344 .legalForCartesianProduct(Types: allScalarsAndVectors)
345 .legalIf(Predicate: extendedScalarsAndVectorsProduct);
346
347 getActionDefinitionsBuilder(Opcode: G_PHI)
348 .legalFor(Types: allPtrsScalarsAndVectors)
349 .legalIf(Predicate: extendedPtrsScalarsAndVectors);
350
351 getActionDefinitionsBuilder(Opcode: G_BITCAST).legalIf(
352 Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allPtrsScalarsAndVectors),
353 P1: typeInSet(TypeIdx: 1, TypesInit: allPtrsScalarsAndVectors)));
354
355 getActionDefinitionsBuilder(Opcodes: {G_IMPLICIT_DEF, G_FREEZE})
356 .legalFor(Types: {s1, s128})
357 .legalFor(Types: allFloatAndIntScalarsAndPtrs)
358 .legalFor(Types: allowedVectorTypes)
359 .moreElementsToNextPow2(TypeIdx: 0)
360 .fewerElementsIf(Predicate: vectorElementCountIsGreaterThan(TypeIdx: 0, Size: MaxVectorSize),
361 Mutation: LegalizeMutations::changeElementCountTo(
362 TypeIdx: 0, EC: ElementCount::getFixed(MinVal: MaxVectorSize)));
363
364 getActionDefinitionsBuilder(Opcodes: {G_STACKSAVE, G_STACKRESTORE}).alwaysLegal();
365
366 getActionDefinitionsBuilder(Opcode: G_INTTOPTR)
367 .legalForCartesianProduct(Types0: allPtrs, Types1: allIntScalars)
368 .legalIf(
369 Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allPtrs), P1: typeOfExtendedScalars(TypeIdx: 1, IsExtendedInts)));
370 getActionDefinitionsBuilder(Opcode: G_PTRTOINT)
371 .legalForCartesianProduct(Types0: allIntScalars, Types1: allPtrs)
372 .legalIf(
373 Predicate: all(P0: typeOfExtendedScalars(TypeIdx: 0, IsExtendedInts), P1: typeInSet(TypeIdx: 1, TypesInit: allPtrs)));
374 getActionDefinitionsBuilder(Opcode: G_PTR_ADD)
375 .legalForCartesianProduct(Types0: allPtrs, Types1: allIntScalars)
376 .legalIf(
377 Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allPtrs), P1: typeOfExtendedScalars(TypeIdx: 1, IsExtendedInts)));
378
379 // ST.canDirectlyComparePointers() for pointer args is supported in
380 // legalizeCustom().
381 getActionDefinitionsBuilder(Opcode: G_ICMP)
382 .unsupportedIf(Predicate: LegalityPredicates::any(
383 P0: all(P0: typeIs(TypeIdx: 0, TypesInit: p9), P1: typeInSet(TypeIdx: 1, TypesInit: allPtrs), args: typeIsNot(TypeIdx: 1, Type: p9)),
384 P1: all(P0: typeInSet(TypeIdx: 0, TypesInit: allPtrs), P1: typeIsNot(TypeIdx: 0, Type: p9), args: typeIs(TypeIdx: 1, TypesInit: p9))))
385 .legalIf(Predicate: [IsExtendedInts](const LegalityQuery &Query) {
386 const LLT Ty = Query.Types[1];
387 return IsExtendedInts && Ty.isValid() && !Ty.isPointerOrPointerVector();
388 })
389 .customIf(Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allBoolScalarsAndVectors),
390 P1: typeInSet(TypeIdx: 1, TypesInit: allPtrsScalarsAndVectors)));
391
392 getActionDefinitionsBuilder(Opcode: G_FCMP).legalIf(
393 Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allBoolScalarsAndVectors),
394 P1: typeInSet(TypeIdx: 1, TypesInit: allFloatScalarsAndVectors)));
395
396 getActionDefinitionsBuilder(Opcodes: {G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
397 G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
398 G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
399 G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
400 .legalForCartesianProduct(Types0: allIntScalars, Types1: allPtrs);
401
402 getActionDefinitionsBuilder(
403 Opcodes: {G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX})
404 .legalForCartesianProduct(Types0: allFloatScalarsAndF16Vector2AndVector4s,
405 Types1: allPtrs);
406
407 getActionDefinitionsBuilder(Opcode: G_ATOMICRMW_XCHG)
408 .legalForCartesianProduct(Types0: allFloatAndIntScalarsAndPtrs, Types1: allPtrs);
409
410 getActionDefinitionsBuilder(Opcode: G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
411 // TODO: add proper legalization rules.
412 getActionDefinitionsBuilder(Opcode: G_ATOMIC_CMPXCHG).alwaysLegal();
413
414 getActionDefinitionsBuilder(
415 Opcodes: {G_UADDO, G_SADDO, G_USUBO, G_SSUBO, G_UMULO, G_SMULO})
416 .alwaysLegal();
417
418 getActionDefinitionsBuilder(Opcodes: {G_LROUND, G_LLROUND})
419 .legalForCartesianProduct(Types0: allFloatScalarsAndVectors,
420 Types1: allIntScalarsAndVectors);
421
422 // FP conversions.
423 getActionDefinitionsBuilder(Opcodes: {G_FPTRUNC, G_FPEXT})
424 .legalForCartesianProduct(Types: allFloatScalarsAndVectors);
425
426 // Pointer-handling.
427 getActionDefinitionsBuilder(Opcode: G_FRAME_INDEX).legalFor(Types: {p0});
428
429 getActionDefinitionsBuilder(Opcode: G_GLOBAL_VALUE).legalFor(Types: allPtrs);
430
431 // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
432 getActionDefinitionsBuilder(Opcode: G_BRCOND).legalFor(Types: {s1, s32});
433
434 getActionDefinitionsBuilder(Opcode: G_FFREXP).legalForCartesianProduct(
435 Types0: allFloatScalarsAndVectors, Types1: {s32, v2s32, v3s32, v4s32, v8s32, v16s32});
436
437 // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to
438 // tighten these requirements. Many of these math functions are only legal on
439 // specific bitwidths, so they are not selectable for
440 // allFloatScalarsAndVectors.
441 // clang-format off
442 getActionDefinitionsBuilder(Opcodes: {G_STRICT_FSQRT,
443 G_FPOW,
444 G_FEXP,
445 G_FMODF,
446 G_FSINCOS,
447 G_FEXP2,
448 G_FLOG,
449 G_FLOG2,
450 G_FLOG10,
451 G_FABS,
452 G_FMINNUM,
453 G_FMAXNUM,
454 G_FCEIL,
455 G_FCOS,
456 G_FSIN,
457 G_FTAN,
458 G_FACOS,
459 G_FASIN,
460 G_FATAN,
461 G_FATAN2,
462 G_FCOSH,
463 G_FSINH,
464 G_FTANH,
465 G_FSQRT,
466 G_FFLOOR,
467 G_FRINT,
468 G_FNEARBYINT,
469 G_INTRINSIC_ROUND,
470 G_INTRINSIC_TRUNC,
471 G_FMINIMUM,
472 G_FMAXIMUM,
473 G_INTRINSIC_ROUNDEVEN})
474 .legalFor(Types: allFloatScalarsAndVectors);
475 // clang-format on
476
477 getActionDefinitionsBuilder(Opcode: G_FCOPYSIGN)
478 .legalForCartesianProduct(Types0: allFloatScalarsAndVectors,
479 Types1: allFloatScalarsAndVectors);
480
481 getActionDefinitionsBuilder(Opcode: G_FPOWI).legalForCartesianProduct(
482 Types0: allFloatScalarsAndVectors, Types1: allIntScalarsAndVectors);
483
484 if (ST.canUseExtInstSet(E: SPIRV::InstructionSet::OpenCL_std)) {
485 getActionDefinitionsBuilder(
486 Opcodes: {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
487 .legalForCartesianProduct(Types0: allIntScalarsAndVectors,
488 Types1: allIntScalarsAndVectors);
489
490 // Struct return types become a single scalar, so cannot easily legalize.
491 getActionDefinitionsBuilder(Opcodes: {G_SMULH, G_UMULH}).alwaysLegal();
492 }
493
494 getActionDefinitionsBuilder(Opcode: G_IS_FPCLASS).custom();
495
496 getLegacyLegalizerInfo().computeTables();
497 verify(MII: *ST.getInstrInfo());
498}
499
500static bool legalizeExtractVectorElt(LegalizerHelper &Helper, MachineInstr &MI,
501 SPIRVGlobalRegistry *GR) {
502 MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
503 Register DstReg = MI.getOperand(i: 0).getReg();
504 Register SrcReg = MI.getOperand(i: 1).getReg();
505 Register IdxReg = MI.getOperand(i: 2).getReg();
506
507 MIRBuilder
508 .buildIntrinsic(ID: Intrinsic::spv_extractelt, Res: ArrayRef<Register>{DstReg})
509 .addUse(RegNo: SrcReg)
510 .addUse(RegNo: IdxReg);
511 MI.eraseFromParent();
512 return true;
513}
514
515static bool legalizeInsertVectorElt(LegalizerHelper &Helper, MachineInstr &MI,
516 SPIRVGlobalRegistry *GR) {
517 MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
518 Register DstReg = MI.getOperand(i: 0).getReg();
519 Register SrcReg = MI.getOperand(i: 1).getReg();
520 Register ValReg = MI.getOperand(i: 2).getReg();
521 Register IdxReg = MI.getOperand(i: 3).getReg();
522
523 MIRBuilder
524 .buildIntrinsic(ID: Intrinsic::spv_insertelt, Res: ArrayRef<Register>{DstReg})
525 .addUse(RegNo: SrcReg)
526 .addUse(RegNo: ValReg)
527 .addUse(RegNo: IdxReg);
528 MI.eraseFromParent();
529 return true;
530}
531
532static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVTypeInst SpvType,
533 LegalizerHelper &Helper,
534 MachineRegisterInfo &MRI,
535 SPIRVGlobalRegistry *GR) {
536 Register ConvReg = MRI.createGenericVirtualRegister(Ty: ConvTy);
537 MRI.setRegClass(Reg: ConvReg, RC: GR->getRegClass(SpvType));
538 GR->assignSPIRVTypeToVReg(Type: SpvType, VReg: ConvReg, MF: Helper.MIRBuilder.getMF());
539 Helper.MIRBuilder.buildInstr(Opcode: TargetOpcode::G_PTRTOINT)
540 .addDef(RegNo: ConvReg)
541 .addUse(RegNo: Reg);
542 return ConvReg;
543}
544
545static bool needsVectorLegalization(const LLT &Ty, const SPIRVSubtarget &ST) {
546 if (!Ty.isVector())
547 return false;
548 unsigned NumElements = Ty.getNumElements();
549 unsigned MaxVectorSize = ST.isShader() ? 4 : 16;
550 return (NumElements > 4 && !isPowerOf2_32(Value: NumElements)) ||
551 NumElements > MaxVectorSize;
552}
553
// Custom-legalizes a G_LOAD. Scalar and representable vector loads are left
// untouched (returns true without modifying MI). A vector load whose element
// count the target cannot represent is split into per-element loads: each
// element address is formed with the spv_gep intrinsic, and the results are
// reassembled into the destination with G_BUILD_VECTOR.
static bool legalizeLoad(LegalizerHelper &Helper, MachineInstr &MI,
                         SPIRVGlobalRegistry *GR) {
  MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
  Register DstReg = MI.getOperand(i: 0).getReg();
  Register PtrReg = MI.getOperand(i: 1).getReg();
  LLT DstTy = MRI.getType(Reg: DstReg);

  // Scalar loads are already legal per the rules in the constructor.
  if (!DstTy.isVector())
    return true;

  const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();
  if (!needsVectorLegalization(Ty: DstTy, ST))
    return true;

  SmallVector<Register, 8> SplitRegs;
  LLT EltTy = DstTy.getElementType();
  unsigned NumElts = DstTy.getNumElements();

  LLT PtrTy = MRI.getType(Reg: PtrReg);
  // First GEP index (0) steps through the pointer itself; the per-element
  // index is appended below.
  auto Zero = MIRBuilder.buildConstant(Res: LLT::scalar(SizeInBits: 32), Val: 0);

  for (unsigned i = 0; i < NumElts; ++i) {
    auto Idx = MIRBuilder.buildConstant(Res: LLT::scalar(SizeInBits: 32), Val: i);
    Register EltPtr = MRI.createGenericVirtualRegister(Ty: PtrTy);

    MIRBuilder.buildIntrinsic(ID: Intrinsic::spv_gep, Res: ArrayRef<Register>{EltPtr})
        .addImm(Val: 1) // InBounds
        .addUse(RegNo: PtrReg)
        .addUse(RegNo: Zero.getReg(Idx: 0))
        .addUse(RegNo: Idx.getReg(Idx: 0));

    // Derive each element's memory operand from the original load's MMO so
    // alias info and alignment stay accurate after splitting.
    MachinePointerInfo EltPtrInfo;
    Align EltAlign = Align(1);
    if (!MI.memoperands_empty()) {
      MachineMemOperand *MMO = *MI.memoperands_begin();
      EltPtrInfo =
          MMO->getPointerInfo().getWithOffset(O: i * EltTy.getSizeInBytes());
      EltAlign = commonAlignment(A: MMO->getAlign(), Offset: i * EltTy.getSizeInBytes());
    }

    Register EltReg = MRI.createGenericVirtualRegister(Ty: EltTy);
    MIRBuilder.buildLoad(Res: EltReg, Addr: EltPtr, PtrInfo: EltPtrInfo, Alignment: EltAlign);
    SplitRegs.push_back(Elt: EltReg);
  }

  MIRBuilder.buildBuildVector(Res: DstReg, Ops: SplitRegs);
  MI.eraseFromParent();
  return true;
}
604
// Custom-legalizes a vector G_STORE (the custom rule only reaches here for
// vectors; scalar/pointer stores are legal). The stored vector is unmerged
// into scalar registers, and each element is written through its own
// spv_gep-computed address with a scalar store.
static bool legalizeStore(LegalizerHelper &Helper, MachineInstr &MI,
                          SPIRVGlobalRegistry *GR) {
  MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
  Register ValReg = MI.getOperand(i: 0).getReg();
  Register PtrReg = MI.getOperand(i: 1).getReg();
  LLT ValTy = MRI.getType(Reg: ValReg);

  assert(ValTy.isVector() && "Expected vector store");

  SmallVector<Register, 8> SplitRegs;
  LLT EltTy = ValTy.getElementType();
  unsigned NumElts = ValTy.getNumElements();

  // Pre-create one scalar vreg per element, then unmerge the vector into
  // them in a single instruction.
  for (unsigned i = 0; i < NumElts; ++i)
    SplitRegs.push_back(Elt: MRI.createGenericVirtualRegister(Ty: EltTy));

  MIRBuilder.buildUnmerge(Res: SplitRegs, Op: ValReg);

  LLT PtrTy = MRI.getType(Reg: PtrReg);
  // First GEP index (0) steps through the pointer itself; the per-element
  // index is appended below.
  auto Zero = MIRBuilder.buildConstant(Res: LLT::scalar(SizeInBits: 32), Val: 0);

  for (unsigned i = 0; i < NumElts; ++i) {
    auto Idx = MIRBuilder.buildConstant(Res: LLT::scalar(SizeInBits: 32), Val: i);
    Register EltPtr = MRI.createGenericVirtualRegister(Ty: PtrTy);

    MIRBuilder.buildIntrinsic(ID: Intrinsic::spv_gep, Res: ArrayRef<Register>{EltPtr})
        .addImm(Val: 1) // InBounds
        .addUse(RegNo: PtrReg)
        .addUse(RegNo: Zero.getReg(Idx: 0))
        .addUse(RegNo: Idx.getReg(Idx: 0));

    // Derive each element's memory operand from the original store's MMO so
    // alias info and alignment stay accurate after splitting.
    MachinePointerInfo EltPtrInfo;
    Align EltAlign = Align(1);
    if (!MI.memoperands_empty()) {
      MachineMemOperand *MMO = *MI.memoperands_begin();
      EltPtrInfo =
          MMO->getPointerInfo().getWithOffset(O: i * EltTy.getSizeInBytes());
      EltAlign = commonAlignment(A: MMO->getAlign(), Offset: i * EltTy.getSizeInBytes());
    }

    MIRBuilder.buildStore(Val: SplitRegs[i], Addr: EltPtr, PtrInfo: EltPtrInfo, Alignment: EltAlign);
  }

  MI.eraseFromParent();
  return true;
}
652
// Dispatches the "custom" legalization actions registered in the constructor.
// Returns true on success (including the default "treat as legal" case for
// opcodes without a dedicated handler yet).
bool SPIRVLegalizerInfo::legalizeCustom(
    LegalizerHelper &Helper, MachineInstr &MI,
    LostDebugLocObserver &LocObserver) const {
  MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
  switch (MI.getOpcode()) {
  default:
    // TODO: implement legalization for other opcodes.
    return true;
  case TargetOpcode::G_BITCAST:
    return legalizeBitcast(Helper, MI);
  case TargetOpcode::G_EXTRACT_VECTOR_ELT:
    return legalizeExtractVectorElt(Helper, MI, GR);
  case TargetOpcode::G_INSERT_VECTOR_ELT:
    return legalizeInsertVectorElt(Helper, MI, GR);
  case TargetOpcode::G_INTRINSIC:
  case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
    return legalizeIntrinsic(Helper, MI);
  case TargetOpcode::G_IS_FPCLASS:
    return legalizeIsFPClass(Helper, MI, LocObserver);
  case TargetOpcode::G_ICMP: {
    assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
    auto &Op0 = MI.getOperand(i: 2);
    auto &Op1 = MI.getOperand(i: 3);
    Register Reg0 = Op0.getReg();
    Register Reg1 = Op1.getReg();
    CmpInst::Predicate Cond =
        static_cast<CmpInst::Predicate>(MI.getOperand(i: 1).getPredicate());
    // Pointer comparisons: only eq/ne can be done directly on pointers, and
    // only when the target supports it. Otherwise both operands are converted
    // to pointer-sized integers (G_PTRTOINT) and compared as integers.
    if ((!ST->canDirectlyComparePointers() ||
         (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) &&
        MRI.getType(Reg: Reg0).isPointer() && MRI.getType(Reg: Reg1).isPointer()) {
      LLT ConvT = LLT::scalar(SizeInBits: ST->getPointerSize());
      Type *LLVMTy = IntegerType::get(C&: MI.getMF()->getFunction().getContext(),
                                      NumBits: ST->getPointerSize());
      SPIRVTypeInst SpirvTy = GR->getOrCreateSPIRVType(
          Type: LLVMTy, MIRBuilder&: Helper.MIRBuilder, AQ: SPIRV::AccessQualifier::ReadWrite, EmitIR: true);
      Op0.setReg(convertPtrToInt(Reg: Reg0, ConvTy: ConvT, SpvType: SpirvTy, Helper, MRI, GR));
      Op1.setReg(convertPtrToInt(Reg: Reg1, ConvTy: ConvT, SpvType: SpirvTy, Helper, MRI, GR));
    }
    return true;
  }
  case TargetOpcode::G_LOAD:
    return legalizeLoad(Helper, MI, GR);
  case TargetOpcode::G_STORE:
    return legalizeStore(Helper, MI, GR);
  }
}
699
// Allocates a stack slot big enough to spill the vector in SrcReg and returns
// the G_FRAME_INDEX-style instruction producing its address. The slot's
// pointer vreg is typed in the SPIR-V registry as a Function-storage-class
// pointer to an array of the vector's element type (SPIR-V has no pointer-to-
// oversized-vector type, so an array is used instead). PtrInfo and VecAlign
// are out-parameters describing the slot.
static MachineInstrBuilder
createStackTemporaryForVector(LegalizerHelper &Helper, SPIRVGlobalRegistry *GR,
                              Register SrcReg, LLT SrcTy,
                              MachinePointerInfo &PtrInfo, Align &VecAlign) {
  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
  MachineRegisterInfo &MRI = *MIRBuilder.getMRI();

  VecAlign = Helper.getStackTemporaryAlignment(Type: SrcTy);
  auto StackTemp = Helper.createStackTemporary(
      Bytes: TypeSize::getFixed(ExactSize: SrcTy.getSizeInBytes()), Alignment: VecAlign, PtrInfo);

  // Set the type of StackTemp to a pointer to an array of the element type.
  SPIRVTypeInst SpvSrcTy = GR->getSPIRVTypeForVReg(VReg: SrcReg);
  SPIRVTypeInst EltSpvTy = GR->getScalarOrVectorComponentType(Type: SpvSrcTy);
  const Type *LLVMEltTy = GR->getTypeForSPIRVType(Ty: EltSpvTy);
  const Type *LLVMArrTy =
      ArrayType::get(ElementType: const_cast<Type *>(LLVMEltTy), NumElements: SrcTy.getNumElements());
  SPIRVTypeInst ArrSpvTy = GR->getOrCreateSPIRVType(
      Type: LLVMArrTy, MIRBuilder, AQ: SPIRV::AccessQualifier::ReadWrite, EmitIR: true);
  SPIRVTypeInst PtrToArrSpvTy = GR->getOrCreateSPIRVPointerType(
      BaseType: ArrSpvTy, MIRBuilder, SC: SPIRV::StorageClass::Function);

  // Register the slot's address vreg with the SPIR-V type system so later
  // selection can resolve it.
  Register StackReg = StackTemp.getReg(Idx: 0);
  MRI.setRegClass(Reg: StackReg, RC: GR->getRegClass(SpvType: PtrToArrSpvTy));
  GR->assignSPIRVTypeToVReg(Type: PtrToArrSpvTy, VReg: StackReg, MF: MIRBuilder.getMF());

  return StackTemp;
}
728
729static bool legalizeSpvBitcast(LegalizerHelper &Helper, MachineInstr &MI,
730 SPIRVGlobalRegistry *GR) {
731 LLVM_DEBUG(dbgs() << "Found a bitcast instruction\n");
732 MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
733 MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
734 const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();
735
736 Register DstReg = MI.getOperand(i: 0).getReg();
737 Register SrcReg = MI.getOperand(i: 2).getReg();
738 LLT DstTy = MRI.getType(Reg: DstReg);
739 LLT SrcTy = MRI.getType(Reg: SrcReg);
740
741 // If an spv_bitcast needs to be legalized, we convert it to G_BITCAST to
742 // allow using the generic legalization rules.
743 if (needsVectorLegalization(Ty: DstTy, ST) ||
744 needsVectorLegalization(Ty: SrcTy, ST)) {
745 LLVM_DEBUG(dbgs() << "Replacing with a G_BITCAST\n");
746 MIRBuilder.buildBitcast(Dst: DstReg, Src: SrcReg);
747 MI.eraseFromParent();
748 }
749 return true;
750}
751
// Legalizes an spv_insertelt intrinsic whose destination vector type cannot be
// represented directly by the target. Two strategies:
//  * constant, in-range index: unmerge the source vector into scalar vregs,
//    overwrite the selected lane, and rebuild the vector;
//  * dynamic index (or a constant index that is out of range): spill the
//    vector to a stack temporary typed as an array, store the new element
//    through an spv_gep, and reload the whole vector.
// Vectors that are already legal are left untouched. Always returns true.
static bool legalizeSpvInsertElt(LegalizerHelper &Helper, MachineInstr &MI,
                                 SPIRVGlobalRegistry *GR) {
  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
  MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
  const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();

  Register DstReg = MI.getOperand(i: 0).getReg();
  LLT DstTy = MRI.getType(Reg: DstReg);

  if (needsVectorLegalization(Ty: DstTy, ST)) {
    // Operand layout: 0 = result, 1 = intrinsic ID, 2 = source vector,
    // 3 = value to insert, 4 = lane index.
    Register SrcReg = MI.getOperand(i: 2).getReg();
    Register ValReg = MI.getOperand(i: 3).getReg();
    LLT SrcTy = MRI.getType(Reg: SrcReg);
    MachineOperand &IdxOperand = MI.getOperand(i: 4);

    // Fast path: index is a known constant within the vector bounds.
    if (getImm(MO: IdxOperand, MRI: &MRI)) {
      uint64_t IdxVal = foldImm(MO: IdxOperand, MRI: &MRI);
      if (IdxVal < SrcTy.getNumElements()) {
        SmallVector<Register, 8> Regs;
        SPIRVTypeInst ElementType =
            GR->getScalarOrVectorComponentType(Type: GR->getSPIRVTypeForVReg(VReg: DstReg));
        LLT ElementLLTTy = GR->getRegType(SpvType: ElementType);
        // One fully-annotated scalar vreg per lane of the source vector.
        for (unsigned I = 0, E = SrcTy.getNumElements(); I < E; ++I) {
          Register Reg = MRI.createGenericVirtualRegister(Ty: ElementLLTTy);
          MRI.setRegClass(Reg, RC: GR->getRegClass(SpvType: ElementType));
          GR->assignSPIRVTypeToVReg(Type: ElementType, VReg: Reg, MF: *MI.getMF());
          Regs.push_back(Elt: Reg);
        }
        MIRBuilder.buildUnmerge(Res: Regs, Op: SrcReg);
        // Replace the targeted lane with the inserted value, then reassemble.
        Regs[IdxVal] = ValReg;
        MIRBuilder.buildBuildVector(Res: DstReg, Ops: Regs);
        MI.eraseFromParent();
        return true;
      }
    }

    // Slow path: go through memory so a runtime index can be used.
    LLT EltTy = SrcTy.getElementType();
    Align VecAlign;
    MachinePointerInfo PtrInfo;
    auto StackTemp = createStackTemporaryForVector(Helper, GR, SrcReg, SrcTy,
                                                   PtrInfo, VecAlign);

    // Spill the whole source vector to the stack slot.
    MIRBuilder.buildStore(Val: SrcReg, Addr: StackTemp, PtrInfo, Alignment: VecAlign);

    Register IdxReg = IdxOperand.getReg();
    LLT PtrTy = MRI.getType(Reg: StackTemp.getReg(Idx: 0));
    Register EltPtr = MRI.createGenericVirtualRegister(Ty: PtrTy);
    auto Zero = MIRBuilder.buildConstant(Res: LLT::scalar(SizeInBits: 32), Val: 0);

    // EltPtr = &stack_array[0][IdxReg] (the leading 0 steps through the
    // pointer-to-array, the dynamic index selects the lane).
    MIRBuilder.buildIntrinsic(ID: Intrinsic::spv_gep, Res: ArrayRef<Register>{EltPtr})
        .addImm(Val: 1) // InBounds
        .addUse(RegNo: StackTemp.getReg(Idx: 0))
        .addUse(RegNo: Zero.getReg(Idx: 0))
        .addUse(RegNo: IdxReg);

    // Overwrite the selected lane in memory, then reload the updated vector.
    MachinePointerInfo EltPtrInfo = MachinePointerInfo(PtrTy.getAddressSpace());
    Align EltAlign = Helper.getStackTemporaryAlignment(Type: EltTy);
    MIRBuilder.buildStore(Val: ValReg, Addr: EltPtr, PtrInfo: EltPtrInfo, Alignment: EltAlign);

    MIRBuilder.buildLoad(Res: DstReg, Addr: StackTemp, PtrInfo, Alignment: VecAlign);
    MI.eraseFromParent();
    return true;
  }
  return true;
}
817
// Legalizes an spv_extractelt intrinsic whose source vector type cannot be
// represented directly by the target. Mirrors legalizeSpvInsertElt:
//  * constant, in-range index: unmerge the source vector so the destination
//    register directly receives the selected lane (all other lanes go to
//    throw-away vregs);
//  * dynamic index (or out-of-range constant): spill the vector to an
//    array-typed stack temporary and load the lane through an spv_gep.
// Legal vectors are left untouched. Always returns true.
static bool legalizeSpvExtractElt(LegalizerHelper &Helper, MachineInstr &MI,
                                  SPIRVGlobalRegistry *GR) {
  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
  MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
  const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();

  // Operand layout: 0 = result, 1 = intrinsic ID, 2 = source vector,
  // 3 = lane index.
  Register SrcReg = MI.getOperand(i: 2).getReg();
  LLT SrcTy = MRI.getType(Reg: SrcReg);

  if (needsVectorLegalization(Ty: SrcTy, ST)) {
    Register DstReg = MI.getOperand(i: 0).getReg();
    MachineOperand &IdxOperand = MI.getOperand(i: 3);

    // Fast path: index is a known constant within the vector bounds.
    if (getImm(MO: IdxOperand, MRI: &MRI)) {
      uint64_t IdxVal = foldImm(MO: IdxOperand, MRI: &MRI);
      if (IdxVal < SrcTy.getNumElements()) {
        LLT DstTy = MRI.getType(Reg: DstReg);
        SmallVector<Register, 8> Regs;
        SPIRVTypeInst DstSpvTy = GR->getSPIRVTypeForVReg(VReg: DstReg);
        // Build the unmerge result list so that lane IdxVal is DstReg itself;
        // the remaining lanes land in fresh (unused) annotated vregs.
        for (unsigned I = 0, E = SrcTy.getNumElements(); I < E; ++I) {
          if (I == IdxVal) {
            Regs.push_back(Elt: DstReg);
          } else {
            Register Reg = MRI.createGenericVirtualRegister(Ty: DstTy);
            MRI.setRegClass(Reg, RC: GR->getRegClass(SpvType: DstSpvTy));
            GR->assignSPIRVTypeToVReg(Type: DstSpvTy, VReg: Reg, MF: *MI.getMF());
            Regs.push_back(Elt: Reg);
          }
        }
        MIRBuilder.buildUnmerge(Res: Regs, Op: SrcReg);
        MI.eraseFromParent();
        return true;
      }
    }

    // Slow path: go through memory so a runtime index can be used.
    LLT EltTy = SrcTy.getElementType();
    Align VecAlign;
    MachinePointerInfo PtrInfo;
    auto StackTemp = createStackTemporaryForVector(Helper, GR, SrcReg, SrcTy,
                                                   PtrInfo, VecAlign);

    // Spill the whole source vector to the stack slot.
    MIRBuilder.buildStore(Val: SrcReg, Addr: StackTemp, PtrInfo, Alignment: VecAlign);

    Register IdxReg = IdxOperand.getReg();
    LLT PtrTy = MRI.getType(Reg: StackTemp.getReg(Idx: 0));
    Register EltPtr = MRI.createGenericVirtualRegister(Ty: PtrTy);
    auto Zero = MIRBuilder.buildConstant(Res: LLT::scalar(SizeInBits: 32), Val: 0);

    // EltPtr = &stack_array[0][IdxReg].
    MIRBuilder.buildIntrinsic(ID: Intrinsic::spv_gep, Res: ArrayRef<Register>{EltPtr})
        .addImm(Val: 1) // InBounds
        .addUse(RegNo: StackTemp.getReg(Idx: 0))
        .addUse(RegNo: Zero.getReg(Idx: 0))
        .addUse(RegNo: IdxReg);

    // Load just the selected lane into the destination register.
    MachinePointerInfo EltPtrInfo = MachinePointerInfo(PtrTy.getAddressSpace());
    Align EltAlign = Helper.getStackTemporaryAlignment(Type: EltTy);
    MIRBuilder.buildLoad(Res: DstReg, Addr: EltPtr, PtrInfo: EltPtrInfo, Alignment: EltAlign);

    MI.eraseFromParent();
    return true;
  }
  return true;
}
881
882static bool legalizeSpvConstComposite(LegalizerHelper &Helper, MachineInstr &MI,
883 SPIRVGlobalRegistry *GR) {
884 MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
885 MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
886 const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();
887
888 Register DstReg = MI.getOperand(i: 0).getReg();
889 LLT DstTy = MRI.getType(Reg: DstReg);
890
891 if (!needsVectorLegalization(Ty: DstTy, ST))
892 return true;
893
894 SmallVector<Register, 8> SrcRegs;
895 if (MI.getNumOperands() == 2) {
896 // The "null" case: no values are attached.
897 LLT EltTy = DstTy.getElementType();
898 auto Zero = MIRBuilder.buildConstant(Res: EltTy, Val: 0);
899 SPIRVTypeInst SpvDstTy = GR->getSPIRVTypeForVReg(VReg: DstReg);
900 SPIRVTypeInst SpvEltTy = GR->getScalarOrVectorComponentType(Type: SpvDstTy);
901 GR->assignSPIRVTypeToVReg(Type: SpvEltTy, VReg: Zero.getReg(Idx: 0), MF: MIRBuilder.getMF());
902 for (unsigned i = 0; i < DstTy.getNumElements(); ++i)
903 SrcRegs.push_back(Elt: Zero.getReg(Idx: 0));
904 } else {
905 for (unsigned i = 2; i < MI.getNumOperands(); ++i) {
906 SrcRegs.push_back(Elt: MI.getOperand(i).getReg());
907 }
908 }
909 MIRBuilder.buildBuildVector(Res: DstReg, Ops: SrcRegs);
910 MI.eraseFromParent();
911 return true;
912}
913
914bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
915 MachineInstr &MI) const {
916 LLVM_DEBUG(dbgs() << "legalizeIntrinsic: " << MI);
917 auto IntrinsicID = cast<GIntrinsic>(Val&: MI).getIntrinsicID();
918 switch (IntrinsicID) {
919 case Intrinsic::spv_bitcast:
920 return legalizeSpvBitcast(Helper, MI, GR);
921 case Intrinsic::spv_insertelt:
922 return legalizeSpvInsertElt(Helper, MI, GR);
923 case Intrinsic::spv_extractelt:
924 return legalizeSpvExtractElt(Helper, MI, GR);
925 case Intrinsic::spv_const_composite:
926 return legalizeSpvConstComposite(Helper, MI, GR);
927 }
928 return true;
929}
930
931bool SPIRVLegalizerInfo::legalizeBitcast(LegalizerHelper &Helper,
932 MachineInstr &MI) const {
933 // Once the G_BITCAST is using vectors that are allowed, we turn it back into
934 // an spv_bitcast to avoid verifier problems when the register types are the
935 // same for the source and the result. Note that the SPIR-V types associated
936 // with the bitcast can be different even if the register types are the same.
937 MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
938 Register DstReg = MI.getOperand(i: 0).getReg();
939 Register SrcReg = MI.getOperand(i: 1).getReg();
940 SmallVector<Register, 1> DstRegs = {DstReg};
941 MIRBuilder.buildIntrinsic(ID: Intrinsic::spv_bitcast, Res: DstRegs).addUse(RegNo: SrcReg);
942 MI.eraseFromParent();
943 return true;
944}
945
// Custom lowering for G_IS_FPCLASS: decomposes the requested FPClassTest mask
// into unsigned integer comparisons on the bit pattern of the source value
// (sign bit stripped via ValueMask, exponent/mantissa tested against the
// masks derived from the float semantics), OR-ing the partial results into
// DstReg. Every intermediate vreg gets a SPIR-V type registered with GR.
//
// Note this code was copied from LegalizerHelper::lowerISFPCLASS and adjusted
// to ensure that all instructions created during the lowering have SPIR-V types
// assigned to them.
bool SPIRVLegalizerInfo::legalizeIsFPClass(
    LegalizerHelper &Helper, MachineInstr &MI,
    LostDebugLocObserver &LocObserver) const {
  auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
  FPClassTest Mask = static_cast<FPClassTest>(MI.getOperand(i: 2).getImm());

  auto &MIRBuilder = Helper.MIRBuilder;
  auto &MF = MIRBuilder.getMF();
  MachineRegisterInfo &MRI = MF.getRegInfo();

  // SPIR-V type matching the (possibly vector) boolean-ish destination.
  Type *LLVMDstTy =
      IntegerType::get(C&: MIRBuilder.getContext(), NumBits: DstTy.getScalarSizeInBits());
  if (DstTy.isVector())
    LLVMDstTy = VectorType::get(ElementType: LLVMDstTy, EC: DstTy.getElementCount());
  SPIRVTypeInst SPIRVDstTy = GR->getOrCreateSPIRVType(
      Type: LLVMDstTy, MIRBuilder, AQ: SPIRV::AccessQualifier::ReadWrite,
      /*EmitIR*/ true);

  unsigned BitSize = SrcTy.getScalarSizeInBits();
  const fltSemantics &Semantics = getFltSemanticForLLT(Ty: SrcTy.getScalarType());

  // Integer type of the same width/shape as the FP source, used for all the
  // bit-pattern comparisons below.
  LLT IntTy = LLT::scalar(SizeInBits: BitSize);
  Type *LLVMIntTy = IntegerType::get(C&: MIRBuilder.getContext(), NumBits: BitSize);
  if (SrcTy.isVector()) {
    IntTy = LLT::vector(EC: SrcTy.getElementCount(), ScalarTy: IntTy);
    LLVMIntTy = VectorType::get(ElementType: LLVMIntTy, EC: SrcTy.getElementCount());
  }
  SPIRVTypeInst SPIRVIntTy = GR->getOrCreateSPIRVType(
      Type: LLVMIntTy, MIRBuilder, AQ: SPIRV::AccessQualifier::ReadWrite,
      /*EmitIR*/ true);

  // Clang doesn't support capture of structured bindings:
  LLT DstTyCopy = DstTy;
  const auto assignSPIRVTy = [&](MachineInstrBuilder &&MI) {
    // Assign this MI's (assumed only) destination to one of the two types we
    // expect: either the G_IS_FPCLASS's destination type, or the integer type
    // bitcast from the source type.
    LLT MITy = MRI.getType(Reg: MI.getReg(Idx: 0));
    assert((MITy == IntTy || MITy == DstTyCopy) &&
           "Unexpected LLT type while lowering G_IS_FPCLASS");
    SPIRVTypeInst SPVTy = MITy == IntTy ? SPIRVIntTy : SPIRVDstTy;
    GR->assignSPIRVTypeToVReg(Type: SPVTy, VReg: MI.getReg(Idx: 0), MF);
    return MI;
  };

  // Helper to build and assign a constant in one go
  // (for vectors: a typed scalar constant splatted across all lanes).
  const auto buildSPIRVConstant = [&](LLT Ty, auto &&C) -> MachineInstrBuilder {
    if (!Ty.isFixedVector())
      return assignSPIRVTy(MIRBuilder.buildConstant(Ty, C));
    auto ScalarC = MIRBuilder.buildConstant(Ty.getScalarType(), C);
    assert((Ty == IntTy || Ty == DstTyCopy) &&
           "Unexpected LLT type while lowering constant for G_IS_FPCLASS");
    SPIRVTypeInst VecEltTy = GR->getOrCreateSPIRVType(
        Type: (Ty == IntTy ? LLVMIntTy : LLVMDstTy)->getScalarType(), MIRBuilder,
        AQ: SPIRV::AccessQualifier::ReadWrite,
        /*EmitIR*/ true);
    GR->assignSPIRVTypeToVReg(Type: VecEltTy, VReg: ScalarC.getReg(0), MF);
    return assignSPIRVTy(MIRBuilder.buildSplatBuildVector(Res: Ty, Src: ScalarC));
  };

  // Trivial masks: no class requested => constant false, all classes
  // requested => constant true.
  if (Mask == fcNone) {
    MIRBuilder.buildCopy(Res: DstReg, Op: buildSPIRVConstant(DstTy, 0));
    MI.eraseFromParent();
    return true;
  }
  if (Mask == fcAllFlags) {
    MIRBuilder.buildCopy(Res: DstReg, Op: buildSPIRVConstant(DstTy, 1));
    MI.eraseFromParent();
    return true;
  }

  // Note that rather than creating a COPY here (between a floating-point and
  // integer type of the same size) we create a SPIR-V bitcast immediately. We
  // can't create a G_BITCAST because the LLTs are the same, and we can't seem
  // to correctly lower COPYs to SPIR-V bitcasts at this moment.
  Register ResVReg = MRI.createGenericVirtualRegister(Ty: IntTy);
  MRI.setRegClass(Reg: ResVReg, RC: GR->getRegClass(SpvType: SPIRVIntTy));
  GR->assignSPIRVTypeToVReg(Type: SPIRVIntTy, VReg: ResVReg, MF: Helper.MIRBuilder.getMF());
  auto AsInt = MIRBuilder.buildInstr(Opcode: SPIRV::OpBitcast)
                   .addDef(RegNo: ResVReg)
                   .addUse(RegNo: GR->getSPIRVTypeID(SpirvType: SPIRVIntTy))
                   .addUse(RegNo: SrcReg);
  AsInt = assignSPIRVTy(std::move(AsInt));

  // Various masks.
  APInt SignBit = APInt::getSignMask(BitWidth: BitSize);
  APInt ValueMask = APInt::getSignedMaxValue(numBits: BitSize); // All bits but sign.
  APInt Inf = APFloat::getInf(Sem: Semantics).bitcastToAPInt(); // Exp and int bit.
  APInt ExpMask = Inf;
  APInt AllOneMantissa = APFloat::getLargest(Sem: Semantics).bitcastToAPInt() & ~Inf;
  APInt QNaNBitMask =
      APInt::getOneBitSet(numBits: BitSize, BitNo: AllOneMantissa.getActiveBits() - 1);
  APInt InversionMask = APInt::getAllOnes(numBits: DstTy.getScalarSizeInBits());

  auto SignBitC = buildSPIRVConstant(IntTy, SignBit);
  auto ValueMaskC = buildSPIRVConstant(IntTy, ValueMask);
  auto InfC = buildSPIRVConstant(IntTy, Inf);
  auto ExpMaskC = buildSPIRVConstant(IntTy, ExpMask);
  auto ZeroC = buildSPIRVConstant(IntTy, 0);

  // Abs = bit pattern with the sign bit cleared; Sign = true iff the sign
  // bit was set (AsInt != Abs).
  auto Abs = assignSPIRVTy(MIRBuilder.buildAnd(Dst: IntTy, Src0: AsInt, Src1: ValueMaskC));
  auto Sign = assignSPIRVTy(
      MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_NE, Res: DstTy, Op0: AsInt, Op1: Abs));

  // Accumulator for the OR of all partial class checks.
  auto Res = buildSPIRVConstant(DstTy, 0);

  const auto appendToRes = [&](MachineInstrBuilder &&ToAppend) {
    Res = assignSPIRVTy(
        MIRBuilder.buildOr(Dst: DstTyCopy, Src0: Res, Src1: assignSPIRVTy(std::move(ToAppend))));
  };

  // Tests that involve more than one class should be processed first.
  if ((Mask & fcFinite) == fcFinite) {
    // finite(V) ==> abs(V) u< exp_mask
    appendToRes(MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_ULT, Res: DstTy, Op0: Abs,
                                     Op1: ExpMaskC));
    Mask &= ~fcFinite;
  } else if ((Mask & fcFinite) == fcPosFinite) {
    // finite(V) && V > 0 ==> V u< exp_mask
    appendToRes(MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_ULT, Res: DstTy, Op0: AsInt,
                                     Op1: ExpMaskC));
    Mask &= ~fcPosFinite;
  } else if ((Mask & fcFinite) == fcNegFinite) {
    // finite(V) && V < 0 ==> abs(V) u< exp_mask && signbit == 1
    auto Cmp = assignSPIRVTy(MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_ULT,
                                                  Res: DstTy, Op0: Abs, Op1: ExpMaskC));
    appendToRes(MIRBuilder.buildAnd(Dst: DstTy, Src0: Cmp, Src1: Sign));
    Mask &= ~fcNegFinite;
  }

  if (FPClassTest PartialCheck = Mask & (fcZero | fcSubnormal)) {
    // fcZero | fcSubnormal => test all exponent bits are 0
    // TODO: Handle sign bit specific cases
    // TODO: Handle inverted case
    if (PartialCheck == (fcZero | fcSubnormal)) {
      auto ExpBits = assignSPIRVTy(MIRBuilder.buildAnd(Dst: IntTy, Src0: AsInt, Src1: ExpMaskC));
      appendToRes(MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_EQ, Res: DstTy,
                                       Op0: ExpBits, Op1: ZeroC));
      Mask &= ~PartialCheck;
    }
  }

  // Check for individual classes.
  if (FPClassTest PartialCheck = Mask & fcZero) {
    if (PartialCheck == fcPosZero)
      appendToRes(MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_EQ, Res: DstTy,
                                       Op0: AsInt, Op1: ZeroC));
    else if (PartialCheck == fcZero)
      appendToRes(
          MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_EQ, Res: DstTy, Op0: Abs, Op1: ZeroC));
    else // fcNegZero
      appendToRes(MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_EQ, Res: DstTy,
                                       Op0: AsInt, Op1: SignBitC));
  }

  if (FPClassTest PartialCheck = Mask & fcSubnormal) {
    // issubnormal(V) ==> unsigned(abs(V) - 1) u< (all mantissa bits set)
    // issubnormal(V) && V>0 ==> unsigned(V - 1) u< (all mantissa bits set)
    auto V = (PartialCheck == fcPosSubnormal) ? AsInt : Abs;
    auto OneC = buildSPIRVConstant(IntTy, 1);
    auto VMinusOne = MIRBuilder.buildSub(Dst: IntTy, Src0: V, Src1: OneC);
    auto SubnormalRes = assignSPIRVTy(
        MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_ULT, Res: DstTy, Op0: VMinusOne,
                             Op1: buildSPIRVConstant(IntTy, AllOneMantissa)));
    if (PartialCheck == fcNegSubnormal)
      SubnormalRes = MIRBuilder.buildAnd(Dst: DstTy, Src0: SubnormalRes, Src1: Sign);
    appendToRes(std::move(SubnormalRes));
  }

  if (FPClassTest PartialCheck = Mask & fcInf) {
    if (PartialCheck == fcPosInf)
      appendToRes(MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_EQ, Res: DstTy,
                                       Op0: AsInt, Op1: InfC));
    else if (PartialCheck == fcInf)
      appendToRes(
          MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_EQ, Res: DstTy, Op0: Abs, Op1: InfC));
    else { // fcNegInf
      APInt NegInf = APFloat::getInf(Sem: Semantics, Negative: true).bitcastToAPInt();
      auto NegInfC = buildSPIRVConstant(IntTy, NegInf);
      appendToRes(MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_EQ, Res: DstTy,
                                       Op0: AsInt, Op1: NegInfC));
    }
  }

  if (FPClassTest PartialCheck = Mask & fcNan) {
    auto InfWithQnanBitC =
        buildSPIRVConstant(IntTy, std::move(Inf) | QNaNBitMask);
    if (PartialCheck == fcNan) {
      // isnan(V) ==> abs(V) u> int(inf)
      appendToRes(
          MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_UGT, Res: DstTy, Op0: Abs, Op1: InfC));
    } else if (PartialCheck == fcQNan) {
      // isquiet(V) ==> abs(V) u>= (unsigned(Inf) | quiet_bit)
      appendToRes(MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_UGE, Res: DstTy, Op0: Abs,
                                       Op1: InfWithQnanBitC));
    } else { // fcSNan
      // issignaling(V) ==> abs(V) u> unsigned(Inf) &&
      //                    abs(V) u< (unsigned(Inf) | quiet_bit)
      auto IsNan = assignSPIRVTy(
          MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_UGT, Res: DstTy, Op0: Abs, Op1: InfC));
      auto IsNotQnan = assignSPIRVTy(MIRBuilder.buildICmp(
          Pred: CmpInst::Predicate::ICMP_ULT, Res: DstTy, Op0: Abs, Op1: InfWithQnanBitC));
      appendToRes(MIRBuilder.buildAnd(Dst: DstTy, Src0: IsNan, Src1: IsNotQnan));
    }
  }

  if (FPClassTest PartialCheck = Mask & fcNormal) {
    // isnormal(V) ==> (0 u< exp u< max_exp) ==> (unsigned(exp-1) u<
    // (max_exp-1))
    APInt ExpLSB = ExpMask & ~(ExpMask.shl(shiftAmt: 1));
    auto ExpMinusOne = assignSPIRVTy(
        MIRBuilder.buildSub(Dst: IntTy, Src0: Abs, Src1: buildSPIRVConstant(IntTy, ExpLSB)));
    APInt MaxExpMinusOne = std::move(ExpMask) - ExpLSB;
    auto NormalRes = assignSPIRVTy(
        MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_ULT, Res: DstTy, Op0: ExpMinusOne,
                             Op1: buildSPIRVConstant(IntTy, MaxExpMinusOne)));
    if (PartialCheck == fcNegNormal)
      NormalRes = MIRBuilder.buildAnd(Dst: DstTy, Src0: NormalRes, Src1: Sign);
    else if (PartialCheck == fcPosNormal) {
      auto PosSign = assignSPIRVTy(MIRBuilder.buildXor(
          Dst: DstTy, Src0: Sign, Src1: buildSPIRVConstant(DstTy, InversionMask)));
      NormalRes = MIRBuilder.buildAnd(Dst: DstTy, Src0: NormalRes, Src1: PosSign);
    }
    appendToRes(std::move(NormalRes));
  }

  MIRBuilder.buildCopy(Res: DstReg, Op: Res);
  MI.eraseFromParent();
  return true;
}
1179