| 1 | //===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements the targeting of the GlobalISel LegalizerInfo class for SPIR-V. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "SPIRVLegalizerInfo.h" |
| 14 | #include "SPIRV.h" |
| 15 | #include "SPIRVGlobalRegistry.h" |
| 16 | #include "SPIRVSubtarget.h" |
| 17 | #include "SPIRVUtils.h" |
| 18 | #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h" |
| 19 | #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" |
| 20 | #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" |
| 21 | #include "llvm/CodeGen/MachineInstr.h" |
| 22 | #include "llvm/CodeGen/MachineRegisterInfo.h" |
| 23 | #include "llvm/CodeGen/TargetOpcodes.h" |
| 24 | #include "llvm/IR/IntrinsicsSPIRV.h" |
| 25 | #include "llvm/Support/Debug.h" |
| 26 | #include "llvm/Support/MathExtras.h" |
| 27 | |
| 28 | using namespace llvm; |
| 29 | using namespace llvm::LegalizeActions; |
| 30 | using namespace llvm::LegalityPredicates; |
| 31 | |
| 32 | #define DEBUG_TYPE "spirv-legalizer" |
| 33 | |
| 34 | LegalityPredicate typeOfExtendedScalars(unsigned TypeIdx, bool IsExtendedInts) { |
| 35 | return [IsExtendedInts, TypeIdx](const LegalityQuery &Query) { |
| 36 | const LLT Ty = Query.Types[TypeIdx]; |
| 37 | return IsExtendedInts && Ty.isValid() && Ty.isScalar(); |
| 38 | }; |
| 39 | } |
| 40 | |
| 41 | SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { |
| 42 | using namespace TargetOpcode; |
| 43 | |
| 44 | this->ST = &ST; |
| 45 | GR = ST.getSPIRVGlobalRegistry(); |
| 46 | |
| 47 | const LLT s1 = LLT::scalar(SizeInBits: 1); |
| 48 | const LLT s8 = LLT::scalar(SizeInBits: 8); |
| 49 | const LLT s16 = LLT::scalar(SizeInBits: 16); |
| 50 | const LLT s32 = LLT::scalar(SizeInBits: 32); |
| 51 | const LLT s64 = LLT::scalar(SizeInBits: 64); |
| 52 | const LLT s128 = LLT::scalar(SizeInBits: 128); |
| 53 | |
| 54 | const LLT v16s64 = LLT::fixed_vector(NumElements: 16, ScalarSizeInBits: 64); |
| 55 | const LLT v16s32 = LLT::fixed_vector(NumElements: 16, ScalarSizeInBits: 32); |
| 56 | const LLT v16s16 = LLT::fixed_vector(NumElements: 16, ScalarSizeInBits: 16); |
| 57 | const LLT v16s8 = LLT::fixed_vector(NumElements: 16, ScalarSizeInBits: 8); |
| 58 | const LLT v16s1 = LLT::fixed_vector(NumElements: 16, ScalarSizeInBits: 1); |
| 59 | |
| 60 | const LLT v8s64 = LLT::fixed_vector(NumElements: 8, ScalarSizeInBits: 64); |
| 61 | const LLT v8s32 = LLT::fixed_vector(NumElements: 8, ScalarSizeInBits: 32); |
| 62 | const LLT v8s16 = LLT::fixed_vector(NumElements: 8, ScalarSizeInBits: 16); |
| 63 | const LLT v8s8 = LLT::fixed_vector(NumElements: 8, ScalarSizeInBits: 8); |
| 64 | const LLT v8s1 = LLT::fixed_vector(NumElements: 8, ScalarSizeInBits: 1); |
| 65 | |
| 66 | const LLT v4s64 = LLT::fixed_vector(NumElements: 4, ScalarSizeInBits: 64); |
| 67 | const LLT v4s32 = LLT::fixed_vector(NumElements: 4, ScalarSizeInBits: 32); |
| 68 | const LLT v4s16 = LLT::fixed_vector(NumElements: 4, ScalarSizeInBits: 16); |
| 69 | const LLT v4s8 = LLT::fixed_vector(NumElements: 4, ScalarSizeInBits: 8); |
| 70 | const LLT v4s1 = LLT::fixed_vector(NumElements: 4, ScalarSizeInBits: 1); |
| 71 | |
| 72 | const LLT v3s64 = LLT::fixed_vector(NumElements: 3, ScalarSizeInBits: 64); |
| 73 | const LLT v3s32 = LLT::fixed_vector(NumElements: 3, ScalarSizeInBits: 32); |
| 74 | const LLT v3s16 = LLT::fixed_vector(NumElements: 3, ScalarSizeInBits: 16); |
| 75 | const LLT v3s8 = LLT::fixed_vector(NumElements: 3, ScalarSizeInBits: 8); |
| 76 | const LLT v3s1 = LLT::fixed_vector(NumElements: 3, ScalarSizeInBits: 1); |
| 77 | |
| 78 | const LLT v2s64 = LLT::fixed_vector(NumElements: 2, ScalarSizeInBits: 64); |
| 79 | const LLT v2s32 = LLT::fixed_vector(NumElements: 2, ScalarSizeInBits: 32); |
| 80 | const LLT v2s16 = LLT::fixed_vector(NumElements: 2, ScalarSizeInBits: 16); |
| 81 | const LLT v2s8 = LLT::fixed_vector(NumElements: 2, ScalarSizeInBits: 8); |
| 82 | const LLT v2s1 = LLT::fixed_vector(NumElements: 2, ScalarSizeInBits: 1); |
| 83 | |
| 84 | const unsigned PSize = ST.getPointerSize(); |
| 85 | const LLT p0 = LLT::pointer(AddressSpace: 0, SizeInBits: PSize); // Function |
| 86 | const LLT p1 = LLT::pointer(AddressSpace: 1, SizeInBits: PSize); // CrossWorkgroup |
| 87 | const LLT p2 = LLT::pointer(AddressSpace: 2, SizeInBits: PSize); // UniformConstant |
| 88 | const LLT p3 = LLT::pointer(AddressSpace: 3, SizeInBits: PSize); // Workgroup |
| 89 | const LLT p4 = LLT::pointer(AddressSpace: 4, SizeInBits: PSize); // Generic |
| 90 | const LLT p5 = |
| 91 | LLT::pointer(AddressSpace: 5, SizeInBits: PSize); // Input, SPV_INTEL_usm_storage_classes (Device) |
| 92 | const LLT p6 = LLT::pointer(AddressSpace: 6, SizeInBits: PSize); // SPV_INTEL_usm_storage_classes (Host) |
| 93 | const LLT p7 = LLT::pointer(AddressSpace: 7, SizeInBits: PSize); // Input |
| 94 | const LLT p8 = LLT::pointer(AddressSpace: 8, SizeInBits: PSize); // Output |
| 95 | const LLT p9 = |
| 96 | LLT::pointer(AddressSpace: 9, SizeInBits: PSize); // CodeSectionINTEL, SPV_INTEL_function_pointers |
| 97 | const LLT p10 = LLT::pointer(AddressSpace: 10, SizeInBits: PSize); // Private |
| 98 | const LLT p11 = LLT::pointer(AddressSpace: 11, SizeInBits: PSize); // StorageBuffer |
| 99 | const LLT p12 = LLT::pointer(AddressSpace: 12, SizeInBits: PSize); // Uniform |
| 100 | const LLT p13 = LLT::pointer(AddressSpace: 13, SizeInBits: PSize); // PushConstant |
| 101 | |
| 102 | // TODO: remove copy-pasting here by using concatenation in some way. |
| 103 | auto allPtrsScalarsAndVectors = { |
| 104 | p0, p1, p2, p3, p4, p5, p6, p7, p8, |
| 105 | p9, p10, p11, p12, p13, s1, s8, s16, s32, |
| 106 | s64, s128, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, |
| 107 | v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, |
| 108 | v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; |
| 109 | |
| 110 | auto allVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, |
| 111 | v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, |
| 112 | v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, |
| 113 | v16s8, v16s16, v16s32, v16s64}; |
| 114 | |
| 115 | auto allShaderVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, |
| 116 | v3s1, v3s8, v3s16, v3s32, v3s64, |
| 117 | v4s1, v4s8, v4s16, v4s32, v4s64}; |
| 118 | |
| 119 | auto allScalars = {s1, s8, s16, s32, s64}; |
| 120 | |
| 121 | auto allScalarsAndVectors = { |
| 122 | s1, s8, s16, s32, s64, s128, v2s1, v2s8, |
| 123 | v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, v3s32, v3s64, |
| 124 | v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16, |
| 125 | v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; |
| 126 | |
| 127 | auto allIntScalarsAndVectors = { |
| 128 | s8, s16, s32, s64, s128, v2s8, v2s16, v2s32, v2s64, |
| 129 | v3s8, v3s16, v3s32, v3s64, v4s8, v4s16, v4s32, v4s64, v8s8, |
| 130 | v8s16, v8s32, v8s64, v16s8, v16s16, v16s32, v16s64}; |
| 131 | |
| 132 | auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1}; |
| 133 | |
| 134 | auto allIntScalars = {s8, s16, s32, s64, s128}; |
| 135 | |
| 136 | auto allFloatScalarsAndF16Vector2AndVector4s = {s16, s32, s64, v2s16, v4s16}; |
| 137 | |
| 138 | auto allFloatScalars = {s16, s32, s64}; |
| 139 | |
| 140 | auto allFloatScalarsAndVectors = { |
| 141 | s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64, |
| 142 | v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64}; |
| 143 | |
| 144 | auto allShaderFloatVectors = {v2s16, v2s32, v2s64, v3s16, v3s32, |
| 145 | v3s64, v4s16, v4s32, v4s64}; |
| 146 | |
| 147 | auto allFloatVectors = {v2s16, v2s32, v2s64, v3s16, v3s32, |
| 148 | v3s64, v4s16, v4s32, v4s64, v8s16, |
| 149 | v8s32, v8s64, v16s16, v16s32, v16s64}; |
| 150 | |
| 151 | auto &allowedFloatVectorTypes = |
| 152 | ST.isShader() ? allShaderFloatVectors : allFloatVectors; |
| 153 | |
| 154 | auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1, |
| 155 | p2, p3, p4, p5, p6, p7, |
| 156 | p8, p9, p10, p11, p12, p13}; |
| 157 | |
| 158 | auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11, p12, p13}; |
| 159 | |
| 160 | auto &allowedVectorTypes = ST.isShader() ? allShaderVectors : allVectors; |
| 161 | |
| 162 | bool IsExtendedInts = |
| 163 | ST.canUseExtension( |
| 164 | E: SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers) || |
| 165 | ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_bit_instructions) || |
| 166 | ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_int4); |
| 167 | auto extendedScalarsAndVectors = |
| 168 | [IsExtendedInts](const LegalityQuery &Query) { |
| 169 | const LLT Ty = Query.Types[0]; |
| 170 | return IsExtendedInts && Ty.isValid() && !Ty.isPointerOrPointerVector(); |
| 171 | }; |
| 172 | auto extendedScalarsAndVectorsProduct = [IsExtendedInts]( |
| 173 | const LegalityQuery &Query) { |
| 174 | const LLT Ty1 = Query.Types[0], Ty2 = Query.Types[1]; |
| 175 | return IsExtendedInts && Ty1.isValid() && Ty2.isValid() && |
| 176 | !Ty1.isPointerOrPointerVector() && !Ty2.isPointerOrPointerVector(); |
| 177 | }; |
| 178 | auto extendedPtrsScalarsAndVectors = |
| 179 | [IsExtendedInts](const LegalityQuery &Query) { |
| 180 | const LLT Ty = Query.Types[0]; |
| 181 | return IsExtendedInts && Ty.isValid(); |
| 182 | }; |
| 183 | |
| 184 | // The universal validation rules in the SPIR-V specification state that |
| 185 | // vector sizes are typically limited to 2, 3, or 4. However, larger vector |
| 186 | // sizes (8 and 16) are enabled when the Kernel capability is present. For |
| 187 | // shader execution models, vector sizes are strictly limited to 4. In |
| 188 | // non-shader contexts, vector sizes of 8 and 16 are also permitted, but |
| 189 | // arbitrary sizes (e.g., 6 or 11) are not. |
| 190 | uint32_t MaxVectorSize = ST.isShader() ? 4 : 16; |
| 191 | LLVM_DEBUG(dbgs() << "MaxVectorSize: " << MaxVectorSize << "\n" ); |
| 192 | |
| 193 | for (auto Opc : getTypeFoldingSupportedOpcodes()) { |
| 194 | switch (Opc) { |
| 195 | case G_EXTRACT_VECTOR_ELT: |
| 196 | case G_UREM: |
| 197 | case G_SREM: |
| 198 | case G_UDIV: |
| 199 | case G_SDIV: |
| 200 | case G_FREM: |
| 201 | break; |
| 202 | default: |
| 203 | getActionDefinitionsBuilder(Opcode: Opc) |
| 204 | .customFor(Types: allScalars) |
| 205 | .customFor(Types: allowedVectorTypes) |
| 206 | .moreElementsToNextPow2(TypeIdx: 0) |
| 207 | .fewerElementsIf(Predicate: vectorElementCountIsGreaterThan(TypeIdx: 0, Size: MaxVectorSize), |
| 208 | Mutation: LegalizeMutations::changeElementCountTo( |
| 209 | TypeIdx: 0, EC: ElementCount::getFixed(MinVal: MaxVectorSize))) |
| 210 | .custom(); |
| 211 | break; |
| 212 | } |
| 213 | } |
| 214 | |
| 215 | getActionDefinitionsBuilder(Opcodes: {G_UREM, G_SREM, G_SDIV, G_UDIV, G_FREM}) |
| 216 | .customFor(Types: allScalars) |
| 217 | .customFor(Types: allowedVectorTypes) |
| 218 | .scalarizeIf(Predicate: numElementsNotPow2(TypeIdx: 0), TypeIdx: 0) |
| 219 | .fewerElementsIf(Predicate: vectorElementCountIsGreaterThan(TypeIdx: 0, Size: MaxVectorSize), |
| 220 | Mutation: LegalizeMutations::changeElementCountTo( |
| 221 | TypeIdx: 0, EC: ElementCount::getFixed(MinVal: MaxVectorSize))) |
| 222 | .custom(); |
| 223 | |
| 224 | getActionDefinitionsBuilder(Opcodes: {G_FMA, G_STRICT_FMA}) |
| 225 | .legalFor(Types: allScalars) |
| 226 | .legalFor(Types: allowedVectorTypes) |
| 227 | .moreElementsToNextPow2(TypeIdx: 0) |
| 228 | .fewerElementsIf(Predicate: vectorElementCountIsGreaterThan(TypeIdx: 0, Size: MaxVectorSize), |
| 229 | Mutation: LegalizeMutations::changeElementCountTo( |
| 230 | TypeIdx: 0, EC: ElementCount::getFixed(MinVal: MaxVectorSize))) |
| 231 | .alwaysLegal(); |
| 232 | |
| 233 | getActionDefinitionsBuilder(Opcode: G_INTRINSIC_W_SIDE_EFFECTS).custom(); |
| 234 | |
| 235 | getActionDefinitionsBuilder(Opcode: G_SHUFFLE_VECTOR) |
| 236 | .legalForCartesianProduct(Types0: allowedVectorTypes, Types1: allowedVectorTypes) |
| 237 | .moreElementsToNextPow2(TypeIdx: 0) |
| 238 | .lowerIf(Predicate: vectorElementCountIsGreaterThan(TypeIdx: 0, Size: MaxVectorSize)) |
| 239 | .moreElementsToNextPow2(TypeIdx: 1) |
| 240 | .lowerIf(Predicate: vectorElementCountIsGreaterThan(TypeIdx: 1, Size: MaxVectorSize)); |
| 241 | |
| 242 | getActionDefinitionsBuilder(Opcode: G_EXTRACT_VECTOR_ELT) |
| 243 | .moreElementsToNextPow2(TypeIdx: 1) |
| 244 | .fewerElementsIf(Predicate: vectorElementCountIsGreaterThan(TypeIdx: 1, Size: MaxVectorSize), |
| 245 | Mutation: LegalizeMutations::changeElementCountTo( |
| 246 | TypeIdx: 1, EC: ElementCount::getFixed(MinVal: MaxVectorSize))) |
| 247 | .custom(); |
| 248 | |
| 249 | getActionDefinitionsBuilder(Opcode: G_INSERT_VECTOR_ELT) |
| 250 | .moreElementsToNextPow2(TypeIdx: 0) |
| 251 | .fewerElementsIf(Predicate: vectorElementCountIsGreaterThan(TypeIdx: 0, Size: MaxVectorSize), |
| 252 | Mutation: LegalizeMutations::changeElementCountTo( |
| 253 | TypeIdx: 0, EC: ElementCount::getFixed(MinVal: MaxVectorSize))) |
| 254 | .custom(); |
| 255 | |
| 256 | // Illegal G_UNMERGE_VALUES instructions should be handled |
| 257 | // during the combine phase. |
| 258 | getActionDefinitionsBuilder(Opcode: G_BUILD_VECTOR) |
| 259 | .legalIf(Predicate: vectorElementCountIsLessThanOrEqualTo(TypeIdx: 0, Size: MaxVectorSize)); |
| 260 | |
| 261 | // When entering the legalizer, there should be no G_BITCAST instructions. |
| 262 | // They should all be calls to the `spv_bitcast` intrinsic. The call to |
| 263 | // the intrinsic will be converted to a G_BITCAST during legalization if |
| 264 | // the vectors are not legal. After using the rules to legalize a G_BITCAST, |
| 265 | // we turn it back into a call to the intrinsic with a custom rule to avoid |
| 266 | // potential machine verifier failures. |
| 267 | getActionDefinitionsBuilder(Opcode: G_BITCAST) |
| 268 | .moreElementsToNextPow2(TypeIdx: 0) |
| 269 | .moreElementsToNextPow2(TypeIdx: 1) |
| 270 | .fewerElementsIf(Predicate: vectorElementCountIsGreaterThan(TypeIdx: 0, Size: MaxVectorSize), |
| 271 | Mutation: LegalizeMutations::changeElementCountTo( |
| 272 | TypeIdx: 0, EC: ElementCount::getFixed(MinVal: MaxVectorSize))) |
| 273 | .lowerIf(Predicate: vectorElementCountIsGreaterThan(TypeIdx: 1, Size: MaxVectorSize)) |
| 274 | .custom(); |
| 275 | |
| 276 | // If the result is still illegal, the combiner should be able to remove it. |
| 277 | getActionDefinitionsBuilder(Opcode: G_CONCAT_VECTORS) |
| 278 | .legalForCartesianProduct(Types0: allowedVectorTypes, Types1: allowedVectorTypes); |
| 279 | |
| 280 | getActionDefinitionsBuilder(Opcode: G_SPLAT_VECTOR) |
| 281 | .legalFor(Types: allowedVectorTypes) |
| 282 | .moreElementsToNextPow2(TypeIdx: 0) |
| 283 | .fewerElementsIf(Predicate: vectorElementCountIsGreaterThan(TypeIdx: 0, Size: MaxVectorSize), |
| 284 | Mutation: LegalizeMutations::changeElementSizeTo(TypeIdx: 0, FromTypeIdx: MaxVectorSize)) |
| 285 | .alwaysLegal(); |
| 286 | |
| 287 | // Vector Reduction Operations |
| 288 | getActionDefinitionsBuilder( |
| 289 | Opcodes: {G_VECREDUCE_SMIN, G_VECREDUCE_SMAX, G_VECREDUCE_UMIN, G_VECREDUCE_UMAX, |
| 290 | G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN, |
| 291 | G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM, |
| 292 | G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR}) |
| 293 | .legalFor(Types: allowedVectorTypes) |
| 294 | .scalarize(TypeIdx: 1) |
| 295 | .lower(); |
| 296 | |
| 297 | getActionDefinitionsBuilder(Opcodes: {G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL}) |
| 298 | .scalarize(TypeIdx: 2) |
| 299 | .lower(); |
| 300 | |
| 301 | // Illegal G_UNMERGE_VALUES instructions should be handled |
| 302 | // during the combine phase. |
| 303 | getActionDefinitionsBuilder(Opcode: G_UNMERGE_VALUES) |
| 304 | .legalIf(Predicate: vectorElementCountIsLessThanOrEqualTo(TypeIdx: 1, Size: MaxVectorSize)); |
| 305 | |
| 306 | getActionDefinitionsBuilder(Opcodes: {G_MEMCPY, G_MEMCPY_INLINE, G_MEMMOVE}) |
| 307 | .unsupportedIf(Predicate: LegalityPredicates::any(P0: typeIs(TypeIdx: 0, TypesInit: p9), P1: typeIs(TypeIdx: 1, TypesInit: p9))) |
| 308 | .legalIf(Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allPtrs), P1: typeInSet(TypeIdx: 1, TypesInit: allPtrs))); |
| 309 | |
| 310 | getActionDefinitionsBuilder(Opcodes: {G_MEMSET, G_MEMSET_INLINE}) |
| 311 | .unsupportedIf(Predicate: typeIs(TypeIdx: 0, TypesInit: p9)) |
| 312 | .legalIf(Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allPtrs), P1: typeInSet(TypeIdx: 1, TypesInit: allIntScalars))); |
| 313 | |
| 314 | getActionDefinitionsBuilder(Opcode: G_ADDRSPACE_CAST) |
| 315 | .legalForCartesianProduct(Types0: allPtrs, Types1: allPtrs); |
| 316 | |
| 317 | // Should we be legalizing bad scalar sizes like s5 here instead |
| 318 | // of handling them in the instruction selector? |
| 319 | getActionDefinitionsBuilder(Opcodes: {G_LOAD, G_STORE}) |
| 320 | .unsupportedIf(Predicate: typeIs(TypeIdx: 1, TypesInit: p9)) |
| 321 | .legalForCartesianProduct(Types0: allowedVectorTypes, Types1: allPtrs) |
| 322 | .legalForCartesianProduct(Types0: allPtrs, Types1: allPtrs) |
| 323 | .legalIf(Predicate: isScalar(TypeIdx: 0)) |
| 324 | .custom(); |
| 325 | |
| 326 | getActionDefinitionsBuilder(Opcodes: {G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS, |
| 327 | G_BITREVERSE, G_SADDSAT, G_UADDSAT, G_SSUBSAT, |
| 328 | G_USUBSAT, G_SCMP, G_UCMP}) |
| 329 | .legalFor(Types: allIntScalarsAndVectors) |
| 330 | .legalIf(Predicate: extendedScalarsAndVectors); |
| 331 | |
| 332 | getActionDefinitionsBuilder(Opcode: G_STRICT_FLDEXP) |
| 333 | .legalForCartesianProduct(Types0: allFloatScalarsAndVectors, Types1: allIntScalars); |
| 334 | |
| 335 | getActionDefinitionsBuilder(Opcodes: {G_FPTOSI, G_FPTOUI}) |
| 336 | .legalForCartesianProduct(Types0: allIntScalarsAndVectors, |
| 337 | Types1: allFloatScalarsAndVectors); |
| 338 | |
| 339 | getActionDefinitionsBuilder(Opcodes: {G_FPTOSI_SAT, G_FPTOUI_SAT}) |
| 340 | .legalForCartesianProduct(Types0: allIntScalarsAndVectors, |
| 341 | Types1: allFloatScalarsAndVectors); |
| 342 | |
| 343 | getActionDefinitionsBuilder(Opcodes: {G_SITOFP, G_UITOFP}) |
| 344 | .legalForCartesianProduct(Types0: allFloatScalarsAndVectors, |
| 345 | Types1: allScalarsAndVectors); |
| 346 | |
| 347 | getActionDefinitionsBuilder(Opcode: G_CTPOP) |
| 348 | .legalForCartesianProduct(Types: allIntScalarsAndVectors) |
| 349 | .legalIf(Predicate: extendedScalarsAndVectorsProduct); |
| 350 | |
| 351 | // Extensions. |
| 352 | getActionDefinitionsBuilder(Opcodes: {G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT}) |
| 353 | .legalForCartesianProduct(Types: allScalarsAndVectors) |
| 354 | .legalIf(Predicate: extendedScalarsAndVectorsProduct); |
| 355 | |
| 356 | getActionDefinitionsBuilder(Opcode: G_PHI) |
| 357 | .legalFor(Types: allPtrsScalarsAndVectors) |
| 358 | .legalIf(Predicate: extendedPtrsScalarsAndVectors) |
| 359 | .moreElementsToNextPow2(TypeIdx: 0) |
| 360 | .fewerElementsIf(Predicate: vectorElementCountIsGreaterThan(TypeIdx: 0, Size: MaxVectorSize), |
| 361 | Mutation: LegalizeMutations::changeElementCountTo( |
| 362 | TypeIdx: 0, EC: ElementCount::getFixed(MinVal: MaxVectorSize))); |
| 363 | |
| 364 | getActionDefinitionsBuilder(Opcode: G_BITCAST).legalIf( |
| 365 | Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allPtrsScalarsAndVectors), |
| 366 | P1: typeInSet(TypeIdx: 1, TypesInit: allPtrsScalarsAndVectors))); |
| 367 | |
| 368 | getActionDefinitionsBuilder(Opcodes: {G_IMPLICIT_DEF, G_FREEZE}) |
| 369 | .legalFor(Types: {s1, s128}) |
| 370 | .legalFor(Types: allFloatAndIntScalarsAndPtrs) |
| 371 | .legalFor(Types: allowedVectorTypes) |
| 372 | .legalIf(Predicate: [](const LegalityQuery &Query) { |
| 373 | return Query.Types[0].isPointerVector(); |
| 374 | }) |
| 375 | .moreElementsToNextPow2(TypeIdx: 0) |
| 376 | .fewerElementsIf(Predicate: vectorElementCountIsGreaterThan(TypeIdx: 0, Size: MaxVectorSize), |
| 377 | Mutation: LegalizeMutations::changeElementCountTo( |
| 378 | TypeIdx: 0, EC: ElementCount::getFixed(MinVal: MaxVectorSize))); |
| 379 | |
| 380 | getActionDefinitionsBuilder(Opcodes: {G_STACKSAVE, G_STACKRESTORE}).alwaysLegal(); |
| 381 | |
| 382 | getActionDefinitionsBuilder(Opcode: G_INTTOPTR) |
| 383 | .legalForCartesianProduct(Types0: allPtrs, Types1: allIntScalars) |
| 384 | .legalIf( |
| 385 | Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allPtrs), P1: typeOfExtendedScalars(TypeIdx: 1, IsExtendedInts))) |
| 386 | .legalIf(Predicate: [](const LegalityQuery &Query) { |
| 387 | const LLT DstTy = Query.Types[0]; |
| 388 | const LLT SrcTy = Query.Types[1]; |
| 389 | return DstTy.isPointerVector() && SrcTy.isVector() && |
| 390 | !SrcTy.isPointer() && |
| 391 | DstTy.getNumElements() == SrcTy.getNumElements(); |
| 392 | }); |
| 393 | getActionDefinitionsBuilder(Opcode: G_PTRTOINT) |
| 394 | .legalForCartesianProduct(Types0: allIntScalars, Types1: allPtrs) |
| 395 | .legalIf( |
| 396 | Predicate: all(P0: typeOfExtendedScalars(TypeIdx: 0, IsExtendedInts), P1: typeInSet(TypeIdx: 1, TypesInit: allPtrs))) |
| 397 | .legalIf(Predicate: [](const LegalityQuery &Query) { |
| 398 | const LLT DstTy = Query.Types[0]; |
| 399 | const LLT SrcTy = Query.Types[1]; |
| 400 | return SrcTy.isPointerVector() && DstTy.isVector() && |
| 401 | !DstTy.isPointer() && |
| 402 | DstTy.getNumElements() == SrcTy.getNumElements(); |
| 403 | }); |
| 404 | getActionDefinitionsBuilder(Opcode: G_PTR_ADD) |
| 405 | .legalForCartesianProduct(Types0: allPtrs, Types1: allIntScalars) |
| 406 | .legalIf( |
| 407 | Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allPtrs), P1: typeOfExtendedScalars(TypeIdx: 1, IsExtendedInts))); |
| 408 | |
| 409 | getActionDefinitionsBuilder(Opcode: G_PTRMASK) |
| 410 | .legalForCartesianProduct(Types0: allPtrs, Types1: allIntScalars) |
| 411 | .legalIf( |
| 412 | Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allPtrs), P1: typeOfExtendedScalars(TypeIdx: 1, IsExtendedInts))) |
| 413 | .legalIf(Predicate: [](const LegalityQuery &Query) { |
| 414 | const LLT PtrTy = Query.Types[0]; |
| 415 | const LLT MaskTy = Query.Types[1]; |
| 416 | return PtrTy.isPointerVector() && MaskTy.isVector() && |
| 417 | !MaskTy.isPointer() && |
| 418 | PtrTy.getNumElements() == MaskTy.getNumElements(); |
| 419 | }); |
| 420 | |
| 421 | // ST.canDirectlyComparePointers() for pointer args is supported in |
| 422 | // legalizeCustom(). |
| 423 | getActionDefinitionsBuilder(Opcode: G_ICMP) |
| 424 | .unsupportedIf(Predicate: LegalityPredicates::any( |
| 425 | P0: all(P0: typeIs(TypeIdx: 0, TypesInit: p9), P1: typeInSet(TypeIdx: 1, TypesInit: allPtrs), args: typeIsNot(TypeIdx: 1, Type: p9)), |
| 426 | P1: all(P0: typeInSet(TypeIdx: 0, TypesInit: allPtrs), P1: typeIsNot(TypeIdx: 0, Type: p9), args: typeIs(TypeIdx: 1, TypesInit: p9)))) |
| 427 | .legalIf(Predicate: [IsExtendedInts](const LegalityQuery &Query) { |
| 428 | const LLT Ty = Query.Types[1]; |
| 429 | return IsExtendedInts && Ty.isValid() && !Ty.isPointerOrPointerVector(); |
| 430 | }) |
| 431 | .customIf(Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allBoolScalarsAndVectors), |
| 432 | P1: typeInSet(TypeIdx: 1, TypesInit: allPtrsScalarsAndVectors))); |
| 433 | |
| 434 | getActionDefinitionsBuilder(Opcode: G_FCMP).legalIf( |
| 435 | Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allBoolScalarsAndVectors), |
| 436 | P1: typeInSet(TypeIdx: 1, TypesInit: allFloatScalarsAndVectors))); |
| 437 | |
| 438 | getActionDefinitionsBuilder(Opcodes: {G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND, |
| 439 | G_ATOMICRMW_MAX, G_ATOMICRMW_MIN, |
| 440 | G_ATOMICRMW_SUB, G_ATOMICRMW_XOR, |
| 441 | G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN}) |
| 442 | .legalForCartesianProduct(Types0: allIntScalars, Types1: allPtrs); |
| 443 | |
| 444 | getActionDefinitionsBuilder( |
| 445 | Opcodes: {G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX}) |
| 446 | .legalForCartesianProduct(Types0: allFloatScalarsAndF16Vector2AndVector4s, |
| 447 | Types1: allPtrs); |
| 448 | |
| 449 | getActionDefinitionsBuilder(Opcode: G_ATOMICRMW_XCHG) |
| 450 | .legalForCartesianProduct(Types0: allFloatAndIntScalarsAndPtrs, Types1: allPtrs); |
| 451 | |
| 452 | getActionDefinitionsBuilder(Opcode: G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower(); |
| 453 | // TODO: add proper legalization rules. |
| 454 | getActionDefinitionsBuilder(Opcode: G_ATOMIC_CMPXCHG).alwaysLegal(); |
| 455 | |
| 456 | getActionDefinitionsBuilder(Opcodes: {G_UADDO, G_USUBO, G_UMULO, G_SMULO}) |
| 457 | .alwaysLegal(); |
| 458 | |
| 459 | getActionDefinitionsBuilder(Opcodes: {G_SADDO, G_SSUBO}).lower(); |
| 460 | |
| 461 | getActionDefinitionsBuilder(Opcodes: {G_LROUND, G_LLROUND}) |
| 462 | .legalForCartesianProduct(Types0: allFloatScalarsAndVectors, |
| 463 | Types1: allIntScalarsAndVectors); |
| 464 | |
| 465 | // FP conversions. |
| 466 | getActionDefinitionsBuilder(Opcodes: {G_FPTRUNC, G_FPEXT}) |
| 467 | .legalForCartesianProduct(Types: allFloatScalarsAndVectors); |
| 468 | |
| 469 | // Pointer-handling. |
| 470 | getActionDefinitionsBuilder(Opcode: G_FRAME_INDEX).legalFor(Types: {p0}); |
| 471 | |
| 472 | getActionDefinitionsBuilder(Opcode: G_GLOBAL_VALUE).legalFor(Types: allPtrs); |
| 473 | |
| 474 | // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32. |
| 475 | getActionDefinitionsBuilder(Opcode: G_BR).alwaysLegal(); |
| 476 | getActionDefinitionsBuilder(Opcode: G_BRCOND).legalFor(Types: {s1, s32}); |
| 477 | |
| 478 | getActionDefinitionsBuilder(Opcode: G_FFREXP).legalForCartesianProduct( |
| 479 | Types0: allFloatScalarsAndVectors, Types1: {s32, v2s32, v3s32, v4s32, v8s32, v16s32}); |
| 480 | |
| 481 | // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to |
| 482 | // tighten these requirements. Many of these math functions are only legal on |
| 483 | // specific bitwidths, so they are not selectable for |
| 484 | // allFloatScalarsAndVectors. |
| 485 | // clang-format off |
| 486 | getActionDefinitionsBuilder(Opcodes: {G_STRICT_FSQRT, |
| 487 | G_FPOW, |
| 488 | G_FEXP, |
| 489 | G_FMODF, |
| 490 | G_FSINCOS, |
| 491 | G_FEXP2, |
| 492 | G_FEXP10, |
| 493 | G_FLOG, |
| 494 | G_FLOG2, |
| 495 | G_FLOG10, |
| 496 | G_FABS, |
| 497 | G_FMINNUM, |
| 498 | G_FMAXNUM, |
| 499 | G_FCEIL, |
| 500 | G_FCOS, |
| 501 | G_FSIN, |
| 502 | G_FTAN, |
| 503 | G_FACOS, |
| 504 | G_FASIN, |
| 505 | G_FATAN, |
| 506 | G_FATAN2, |
| 507 | G_FCOSH, |
| 508 | G_FSINH, |
| 509 | G_FTANH, |
| 510 | G_FSQRT, |
| 511 | G_FFLOOR, |
| 512 | G_FRINT, |
| 513 | G_FNEARBYINT, |
| 514 | G_INTRINSIC_ROUND, |
| 515 | G_INTRINSIC_TRUNC, |
| 516 | G_FMINIMUM, |
| 517 | G_FMAXIMUM, |
| 518 | G_INTRINSIC_ROUNDEVEN}) |
| 519 | .legalFor(Types: allFloatScalars) |
| 520 | .legalFor(Types: allowedFloatVectorTypes) |
| 521 | .moreElementsToNextPow2(TypeIdx: 0) |
| 522 | .fewerElementsIf(Predicate: vectorElementCountIsGreaterThan(TypeIdx: 0, Size: MaxVectorSize), |
| 523 | Mutation: LegalizeMutations::changeElementCountTo( |
| 524 | TypeIdx: 0, EC: ElementCount::getFixed(MinVal: MaxVectorSize))); |
| 525 | // clang-format on |
| 526 | |
| 527 | getActionDefinitionsBuilder(Opcode: G_FCOPYSIGN) |
| 528 | .legalForCartesianProduct(Types0: allFloatScalarsAndVectors, |
| 529 | Types1: allFloatScalarsAndVectors); |
| 530 | |
| 531 | getActionDefinitionsBuilder(Opcode: G_FPOWI).legalForCartesianProduct( |
| 532 | Types0: allFloatScalarsAndVectors, Types1: allIntScalarsAndVectors); |
| 533 | |
| 534 | if (ST.canUseExtInstSet(E: SPIRV::InstructionSet::OpenCL_std)) { |
| 535 | getActionDefinitionsBuilder( |
| 536 | Opcodes: {G_CTTZ, G_CTTZ_ZERO_POISON, G_CTLZ, G_CTLZ_ZERO_POISON}) |
| 537 | .legalForCartesianProduct(Types0: allIntScalarsAndVectors, |
| 538 | Types1: allIntScalarsAndVectors); |
| 539 | |
| 540 | // Struct return types become a single scalar, so cannot easily legalize. |
| 541 | getActionDefinitionsBuilder(Opcodes: {G_SMULH, G_UMULH}).alwaysLegal(); |
| 542 | } |
| 543 | |
| 544 | getActionDefinitionsBuilder(Opcode: G_IS_FPCLASS).custom(); |
| 545 | |
| 546 | getActionDefinitionsBuilder(Opcodes: {G_INTRINSIC, G_INTRINSIC_CONVERGENT, |
| 547 | G_INTRINSIC_CONVERGENT_W_SIDE_EFFECTS}) |
| 548 | .alwaysLegal(); |
| 549 | getActionDefinitionsBuilder(Opcode: G_FENCE).alwaysLegal(); |
| 550 | getActionDefinitionsBuilder(Opcodes: {G_TRAP, G_DEBUGTRAP, G_UBSANTRAP}).alwaysLegal(); |
| 551 | |
| 552 | getLegacyLegalizerInfo().computeTables(); |
| 553 | verify(MII: *ST.getInstrInfo()); |
| 554 | } |
| 555 | |
| 556 | static bool (LegalizerHelper &Helper, MachineInstr &MI, |
| 557 | SPIRVGlobalRegistry *GR) { |
| 558 | MachineIRBuilder &MIRBuilder = Helper.MIRBuilder; |
| 559 | Register DstReg = MI.getOperand(i: 0).getReg(); |
| 560 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
| 561 | Register IdxReg = MI.getOperand(i: 2).getReg(); |
| 562 | |
| 563 | MIRBuilder |
| 564 | .buildIntrinsic(ID: Intrinsic::spv_extractelt, Res: ArrayRef<Register>{DstReg}) |
| 565 | .addUse(RegNo: SrcReg) |
| 566 | .addUse(RegNo: IdxReg); |
| 567 | MI.eraseFromParent(); |
| 568 | return true; |
| 569 | } |
| 570 | |
| 571 | static bool legalizeInsertVectorElt(LegalizerHelper &Helper, MachineInstr &MI, |
| 572 | SPIRVGlobalRegistry *GR) { |
| 573 | MachineIRBuilder &MIRBuilder = Helper.MIRBuilder; |
| 574 | Register DstReg = MI.getOperand(i: 0).getReg(); |
| 575 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
| 576 | Register ValReg = MI.getOperand(i: 2).getReg(); |
| 577 | Register IdxReg = MI.getOperand(i: 3).getReg(); |
| 578 | |
| 579 | MIRBuilder |
| 580 | .buildIntrinsic(ID: Intrinsic::spv_insertelt, Res: ArrayRef<Register>{DstReg}) |
| 581 | .addUse(RegNo: SrcReg) |
| 582 | .addUse(RegNo: ValReg) |
| 583 | .addUse(RegNo: IdxReg); |
| 584 | MI.eraseFromParent(); |
| 585 | return true; |
| 586 | } |
| 587 | |
| 588 | static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVTypeInst SpvType, |
| 589 | LegalizerHelper &Helper, |
| 590 | MachineRegisterInfo &MRI, |
| 591 | SPIRVGlobalRegistry *GR) { |
| 592 | Register ConvReg = MRI.createGenericVirtualRegister(Ty: ConvTy); |
| 593 | MRI.setRegClass(Reg: ConvReg, RC: GR->getRegClass(SpvType)); |
| 594 | GR->assignSPIRVTypeToVReg(Type: SpvType, VReg: ConvReg, MF: Helper.MIRBuilder.getMF()); |
| 595 | Helper.MIRBuilder.buildInstr(Opcode: TargetOpcode::G_PTRTOINT) |
| 596 | .addDef(RegNo: ConvReg) |
| 597 | .addUse(RegNo: Reg); |
| 598 | return ConvReg; |
| 599 | } |
| 600 | |
| 601 | static bool needsVectorLegalization(const LLT &Ty, const SPIRVSubtarget &ST) { |
| 602 | if (!Ty.isVector()) |
| 603 | return false; |
| 604 | unsigned NumElements = Ty.getNumElements(); |
| 605 | unsigned MaxVectorSize = ST.isShader() ? 4 : 16; |
| 606 | return (NumElements > 4 && !isPowerOf2_32(Value: NumElements)) || |
| 607 | NumElements > MaxVectorSize; |
| 608 | } |
| 609 | |
| 610 | static bool legalizeLoad(LegalizerHelper &Helper, MachineInstr &MI, |
| 611 | SPIRVGlobalRegistry *GR) { |
| 612 | MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); |
| 613 | MachineIRBuilder &MIRBuilder = Helper.MIRBuilder; |
| 614 | Register DstReg = MI.getOperand(i: 0).getReg(); |
| 615 | Register PtrReg = MI.getOperand(i: 1).getReg(); |
| 616 | LLT DstTy = MRI.getType(Reg: DstReg); |
| 617 | |
| 618 | if (!DstTy.isVector()) |
| 619 | return true; |
| 620 | |
| 621 | const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>(); |
| 622 | if (!needsVectorLegalization(Ty: DstTy, ST)) |
| 623 | return true; |
| 624 | |
| 625 | SmallVector<Register, 8> SplitRegs; |
| 626 | LLT EltTy = DstTy.getElementType(); |
| 627 | unsigned NumElts = DstTy.getNumElements(); |
| 628 | |
| 629 | LLT PtrTy = MRI.getType(Reg: PtrReg); |
| 630 | auto Zero = MIRBuilder.buildConstant(Res: LLT::scalar(SizeInBits: 32), Val: 0); |
| 631 | |
| 632 | for (unsigned i = 0; i < NumElts; ++i) { |
| 633 | auto Idx = MIRBuilder.buildConstant(Res: LLT::scalar(SizeInBits: 32), Val: i); |
| 634 | Register EltPtr = MRI.createGenericVirtualRegister(Ty: PtrTy); |
| 635 | |
| 636 | MIRBuilder.buildIntrinsic(ID: Intrinsic::spv_gep, Res: ArrayRef<Register>{EltPtr}) |
| 637 | .addImm(Val: 1) // InBounds |
| 638 | .addUse(RegNo: PtrReg) |
| 639 | .addUse(RegNo: Zero.getReg(Idx: 0)) |
| 640 | .addUse(RegNo: Idx.getReg(Idx: 0)); |
| 641 | |
| 642 | MachinePointerInfo EltPtrInfo; |
| 643 | Align EltAlign = Align(1); |
| 644 | if (!MI.memoperands_empty()) { |
| 645 | MachineMemOperand *MMO = *MI.memoperands_begin(); |
| 646 | EltPtrInfo = |
| 647 | MMO->getPointerInfo().getWithOffset(O: i * EltTy.getSizeInBytes()); |
| 648 | EltAlign = commonAlignment(A: MMO->getAlign(), Offset: i * EltTy.getSizeInBytes()); |
| 649 | } |
| 650 | |
| 651 | Register EltReg = MRI.createGenericVirtualRegister(Ty: EltTy); |
| 652 | MIRBuilder.buildLoad(Res: EltReg, Addr: EltPtr, PtrInfo: EltPtrInfo, Alignment: EltAlign); |
| 653 | SplitRegs.push_back(Elt: EltReg); |
| 654 | } |
| 655 | |
| 656 | MIRBuilder.buildBuildVector(Res: DstReg, Ops: SplitRegs); |
| 657 | MI.eraseFromParent(); |
| 658 | return true; |
| 659 | } |
| 660 | |
| 661 | static bool legalizeStore(LegalizerHelper &Helper, MachineInstr &MI, |
| 662 | SPIRVGlobalRegistry *GR) { |
| 663 | MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); |
| 664 | MachineIRBuilder &MIRBuilder = Helper.MIRBuilder; |
| 665 | Register ValReg = MI.getOperand(i: 0).getReg(); |
| 666 | Register PtrReg = MI.getOperand(i: 1).getReg(); |
| 667 | LLT ValTy = MRI.getType(Reg: ValReg); |
| 668 | |
| 669 | assert(ValTy.isVector() && "Expected vector store" ); |
| 670 | |
| 671 | SmallVector<Register, 8> SplitRegs; |
| 672 | LLT EltTy = ValTy.getElementType(); |
| 673 | unsigned NumElts = ValTy.getNumElements(); |
| 674 | |
| 675 | for (unsigned i = 0; i < NumElts; ++i) |
| 676 | SplitRegs.push_back(Elt: MRI.createGenericVirtualRegister(Ty: EltTy)); |
| 677 | |
| 678 | MIRBuilder.buildUnmerge(Res: SplitRegs, Op: ValReg); |
| 679 | |
| 680 | LLT PtrTy = MRI.getType(Reg: PtrReg); |
| 681 | auto Zero = MIRBuilder.buildConstant(Res: LLT::scalar(SizeInBits: 32), Val: 0); |
| 682 | |
| 683 | for (unsigned i = 0; i < NumElts; ++i) { |
| 684 | auto Idx = MIRBuilder.buildConstant(Res: LLT::scalar(SizeInBits: 32), Val: i); |
| 685 | Register EltPtr = MRI.createGenericVirtualRegister(Ty: PtrTy); |
| 686 | |
| 687 | MIRBuilder.buildIntrinsic(ID: Intrinsic::spv_gep, Res: ArrayRef<Register>{EltPtr}) |
| 688 | .addImm(Val: 1) // InBounds |
| 689 | .addUse(RegNo: PtrReg) |
| 690 | .addUse(RegNo: Zero.getReg(Idx: 0)) |
| 691 | .addUse(RegNo: Idx.getReg(Idx: 0)); |
| 692 | |
| 693 | MachinePointerInfo EltPtrInfo; |
| 694 | Align EltAlign = Align(1); |
| 695 | if (!MI.memoperands_empty()) { |
| 696 | MachineMemOperand *MMO = *MI.memoperands_begin(); |
| 697 | EltPtrInfo = |
| 698 | MMO->getPointerInfo().getWithOffset(O: i * EltTy.getSizeInBytes()); |
| 699 | EltAlign = commonAlignment(A: MMO->getAlign(), Offset: i * EltTy.getSizeInBytes()); |
| 700 | } |
| 701 | |
| 702 | MIRBuilder.buildStore(Val: SplitRegs[i], Addr: EltPtr, PtrInfo: EltPtrInfo, Alignment: EltAlign); |
| 703 | } |
| 704 | |
| 705 | MI.eraseFromParent(); |
| 706 | return true; |
| 707 | } |
| 708 | |
| 709 | bool SPIRVLegalizerInfo::legalizeCustom( |
| 710 | LegalizerHelper &Helper, MachineInstr &MI, |
| 711 | LostDebugLocObserver &LocObserver) const { |
| 712 | MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); |
| 713 | switch (MI.getOpcode()) { |
| 714 | default: |
| 715 | // TODO: implement legalization for other opcodes. |
| 716 | return true; |
| 717 | case TargetOpcode::G_BITCAST: |
| 718 | return legalizeBitcast(Helper, MI); |
| 719 | case TargetOpcode::G_EXTRACT_VECTOR_ELT: |
| 720 | return legalizeExtractVectorElt(Helper, MI, GR); |
| 721 | case TargetOpcode::G_INSERT_VECTOR_ELT: |
| 722 | return legalizeInsertVectorElt(Helper, MI, GR); |
| 723 | case TargetOpcode::G_INTRINSIC: |
| 724 | case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: |
| 725 | return legalizeIntrinsic(Helper, MI); |
| 726 | case TargetOpcode::G_IS_FPCLASS: |
| 727 | return legalizeIsFPClass(Helper, MI, LocObserver); |
| 728 | case TargetOpcode::G_ICMP: { |
| 729 | auto &Op0 = MI.getOperand(i: 2); |
| 730 | auto &Op1 = MI.getOperand(i: 3); |
| 731 | Register Reg0 = Op0.getReg(); |
| 732 | Register Reg1 = Op1.getReg(); |
| 733 | CmpInst::Predicate Cond = |
| 734 | static_cast<CmpInst::Predicate>(MI.getOperand(i: 1).getPredicate()); |
| 735 | if ((!ST->canDirectlyComparePointers() || |
| 736 | (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) && |
| 737 | MRI.getType(Reg: Reg0).isPointer() && MRI.getType(Reg: Reg1).isPointer()) { |
| 738 | LLT ConvT = LLT::scalar(SizeInBits: ST->getPointerSize()); |
| 739 | Type *LLVMTy = IntegerType::get(C&: MI.getMF()->getFunction().getContext(), |
| 740 | NumBits: ST->getPointerSize()); |
| 741 | SPIRVTypeInst SpirvTy = GR->getOrCreateSPIRVType( |
| 742 | Type: LLVMTy, MIRBuilder&: Helper.MIRBuilder, AQ: SPIRV::AccessQualifier::ReadWrite, EmitIR: true); |
| 743 | Op0.setReg(convertPtrToInt(Reg: Reg0, ConvTy: ConvT, SpvType: SpirvTy, Helper, MRI, GR)); |
| 744 | Op1.setReg(convertPtrToInt(Reg: Reg1, ConvTy: ConvT, SpvType: SpirvTy, Helper, MRI, GR)); |
| 745 | } |
| 746 | return true; |
| 747 | } |
| 748 | case TargetOpcode::G_LOAD: |
| 749 | return legalizeLoad(Helper, MI, GR); |
| 750 | case TargetOpcode::G_STORE: |
| 751 | return legalizeStore(Helper, MI, GR); |
| 752 | } |
| 753 | } |
| 754 | |
| 755 | static MachineInstrBuilder |
| 756 | createStackTemporaryForVector(LegalizerHelper &Helper, SPIRVGlobalRegistry *GR, |
| 757 | Register SrcReg, LLT SrcTy, |
| 758 | MachinePointerInfo &PtrInfo, Align &VecAlign) { |
| 759 | MachineIRBuilder &MIRBuilder = Helper.MIRBuilder; |
| 760 | MachineRegisterInfo &MRI = *MIRBuilder.getMRI(); |
| 761 | |
| 762 | VecAlign = Helper.getStackTemporaryAlignment(Type: SrcTy); |
| 763 | auto StackTemp = Helper.createStackTemporary( |
| 764 | Bytes: TypeSize::getFixed(ExactSize: SrcTy.getSizeInBytes()), Alignment: VecAlign, PtrInfo); |
| 765 | |
| 766 | // Set the type of StackTemp to a pointer to an array of the element type. |
| 767 | SPIRVTypeInst SpvSrcTy = GR->getSPIRVTypeForVReg(VReg: SrcReg); |
| 768 | SPIRVTypeInst EltSpvTy = GR->getScalarOrVectorComponentType(Type: SpvSrcTy); |
| 769 | const Type *LLVMEltTy = GR->getTypeForSPIRVType(Ty: EltSpvTy); |
| 770 | const Type *LLVMArrTy = |
| 771 | ArrayType::get(ElementType: const_cast<Type *>(LLVMEltTy), NumElements: SrcTy.getNumElements()); |
| 772 | SPIRVTypeInst ArrSpvTy = GR->getOrCreateSPIRVType( |
| 773 | Type: LLVMArrTy, MIRBuilder, AQ: SPIRV::AccessQualifier::ReadWrite, EmitIR: true); |
| 774 | SPIRVTypeInst PtrToArrSpvTy = GR->getOrCreateSPIRVPointerType( |
| 775 | BaseType: ArrSpvTy, MIRBuilder, SC: SPIRV::StorageClass::Function); |
| 776 | |
| 777 | Register StackReg = StackTemp.getReg(Idx: 0); |
| 778 | MRI.setRegClass(Reg: StackReg, RC: GR->getRegClass(SpvType: PtrToArrSpvTy)); |
| 779 | GR->assignSPIRVTypeToVReg(Type: PtrToArrSpvTy, VReg: StackReg, MF: MIRBuilder.getMF()); |
| 780 | |
| 781 | return StackTemp; |
| 782 | } |
| 783 | |
| 784 | static bool legalizeSpvBitcast(LegalizerHelper &Helper, MachineInstr &MI, |
| 785 | SPIRVGlobalRegistry *GR) { |
| 786 | LLVM_DEBUG(dbgs() << "Found a bitcast instruction\n" ); |
| 787 | MachineIRBuilder &MIRBuilder = Helper.MIRBuilder; |
| 788 | MachineRegisterInfo &MRI = *MIRBuilder.getMRI(); |
| 789 | const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>(); |
| 790 | |
| 791 | Register DstReg = MI.getOperand(i: 0).getReg(); |
| 792 | Register SrcReg = MI.getOperand(i: 2).getReg(); |
| 793 | LLT DstTy = MRI.getType(Reg: DstReg); |
| 794 | LLT SrcTy = MRI.getType(Reg: SrcReg); |
| 795 | |
| 796 | // If an spv_bitcast needs to be legalized, we convert it to G_BITCAST to |
| 797 | // allow using the generic legalization rules. |
| 798 | if (needsVectorLegalization(Ty: DstTy, ST) || |
| 799 | needsVectorLegalization(Ty: SrcTy, ST)) { |
| 800 | LLVM_DEBUG(dbgs() << "Replacing with a G_BITCAST\n" ); |
| 801 | MIRBuilder.buildBitcast(Dst: DstReg, Src: SrcReg); |
| 802 | MI.eraseFromParent(); |
| 803 | } |
| 804 | return true; |
| 805 | } |
| 806 | |
| 807 | static bool legalizeSpvInsertElt(LegalizerHelper &Helper, MachineInstr &MI, |
| 808 | SPIRVGlobalRegistry *GR) { |
| 809 | MachineIRBuilder &MIRBuilder = Helper.MIRBuilder; |
| 810 | MachineRegisterInfo &MRI = *MIRBuilder.getMRI(); |
| 811 | const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>(); |
| 812 | |
| 813 | Register DstReg = MI.getOperand(i: 0).getReg(); |
| 814 | LLT DstTy = MRI.getType(Reg: DstReg); |
| 815 | |
| 816 | if (needsVectorLegalization(Ty: DstTy, ST)) { |
| 817 | Register SrcReg = MI.getOperand(i: 2).getReg(); |
| 818 | Register ValReg = MI.getOperand(i: 3).getReg(); |
| 819 | LLT SrcTy = MRI.getType(Reg: SrcReg); |
| 820 | MachineOperand &IdxOperand = MI.getOperand(i: 4); |
| 821 | |
| 822 | if (getImm(MO: IdxOperand, MRI: &MRI)) { |
| 823 | uint64_t IdxVal = foldImm(MO: IdxOperand, MRI: &MRI); |
| 824 | if (IdxVal < SrcTy.getNumElements()) { |
| 825 | SmallVector<Register, 8> Regs; |
| 826 | SPIRVTypeInst ElementType = |
| 827 | GR->getScalarOrVectorComponentType(Type: GR->getSPIRVTypeForVReg(VReg: DstReg)); |
| 828 | LLT ElementLLTTy = GR->getRegType(SpvType: ElementType); |
| 829 | for (unsigned I = 0, E = SrcTy.getNumElements(); I < E; ++I) { |
| 830 | Register Reg = MRI.createGenericVirtualRegister(Ty: ElementLLTTy); |
| 831 | MRI.setRegClass(Reg, RC: GR->getRegClass(SpvType: ElementType)); |
| 832 | GR->assignSPIRVTypeToVReg(Type: ElementType, VReg: Reg, MF: *MI.getMF()); |
| 833 | Regs.push_back(Elt: Reg); |
| 834 | } |
| 835 | MIRBuilder.buildUnmerge(Res: Regs, Op: SrcReg); |
| 836 | Regs[IdxVal] = ValReg; |
| 837 | MIRBuilder.buildBuildVector(Res: DstReg, Ops: Regs); |
| 838 | MI.eraseFromParent(); |
| 839 | return true; |
| 840 | } |
| 841 | } |
| 842 | |
| 843 | LLT EltTy = SrcTy.getElementType(); |
| 844 | Align VecAlign; |
| 845 | MachinePointerInfo PtrInfo; |
| 846 | auto StackTemp = createStackTemporaryForVector(Helper, GR, SrcReg, SrcTy, |
| 847 | PtrInfo, VecAlign); |
| 848 | |
| 849 | MIRBuilder.buildStore(Val: SrcReg, Addr: StackTemp, PtrInfo, Alignment: VecAlign); |
| 850 | |
| 851 | Register IdxReg = IdxOperand.getReg(); |
| 852 | LLT PtrTy = MRI.getType(Reg: StackTemp.getReg(Idx: 0)); |
| 853 | Register EltPtr = MRI.createGenericVirtualRegister(Ty: PtrTy); |
| 854 | auto Zero = MIRBuilder.buildConstant(Res: LLT::scalar(SizeInBits: 32), Val: 0); |
| 855 | |
| 856 | MIRBuilder.buildIntrinsic(ID: Intrinsic::spv_gep, Res: ArrayRef<Register>{EltPtr}) |
| 857 | .addImm(Val: 1) // InBounds |
| 858 | .addUse(RegNo: StackTemp.getReg(Idx: 0)) |
| 859 | .addUse(RegNo: Zero.getReg(Idx: 0)) |
| 860 | .addUse(RegNo: IdxReg); |
| 861 | |
| 862 | MachinePointerInfo EltPtrInfo = MachinePointerInfo(PtrTy.getAddressSpace()); |
| 863 | Align EltAlign = Helper.getStackTemporaryAlignment(Type: EltTy); |
| 864 | MIRBuilder.buildStore(Val: ValReg, Addr: EltPtr, PtrInfo: EltPtrInfo, Alignment: EltAlign); |
| 865 | |
| 866 | MIRBuilder.buildLoad(Res: DstReg, Addr: StackTemp, PtrInfo, Alignment: VecAlign); |
| 867 | MI.eraseFromParent(); |
| 868 | return true; |
| 869 | } |
| 870 | return true; |
| 871 | } |
| 872 | |
| 873 | static bool (LegalizerHelper &Helper, MachineInstr &MI, |
| 874 | SPIRVGlobalRegistry *GR) { |
| 875 | MachineIRBuilder &MIRBuilder = Helper.MIRBuilder; |
| 876 | MachineRegisterInfo &MRI = *MIRBuilder.getMRI(); |
| 877 | const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>(); |
| 878 | |
| 879 | Register SrcReg = MI.getOperand(i: 2).getReg(); |
| 880 | LLT SrcTy = MRI.getType(Reg: SrcReg); |
| 881 | |
| 882 | if (needsVectorLegalization(Ty: SrcTy, ST)) { |
| 883 | Register DstReg = MI.getOperand(i: 0).getReg(); |
| 884 | MachineOperand &IdxOperand = MI.getOperand(i: 3); |
| 885 | |
| 886 | if (getImm(MO: IdxOperand, MRI: &MRI)) { |
| 887 | uint64_t IdxVal = foldImm(MO: IdxOperand, MRI: &MRI); |
| 888 | if (IdxVal < SrcTy.getNumElements()) { |
| 889 | LLT DstTy = MRI.getType(Reg: DstReg); |
| 890 | SmallVector<Register, 8> Regs; |
| 891 | SPIRVTypeInst DstSpvTy = GR->getSPIRVTypeForVReg(VReg: DstReg); |
| 892 | for (unsigned I = 0, E = SrcTy.getNumElements(); I < E; ++I) { |
| 893 | if (I == IdxVal) { |
| 894 | Regs.push_back(Elt: DstReg); |
| 895 | } else { |
| 896 | Register Reg = MRI.createGenericVirtualRegister(Ty: DstTy); |
| 897 | MRI.setRegClass(Reg, RC: GR->getRegClass(SpvType: DstSpvTy)); |
| 898 | GR->assignSPIRVTypeToVReg(Type: DstSpvTy, VReg: Reg, MF: *MI.getMF()); |
| 899 | Regs.push_back(Elt: Reg); |
| 900 | } |
| 901 | } |
| 902 | MIRBuilder.buildUnmerge(Res: Regs, Op: SrcReg); |
| 903 | MI.eraseFromParent(); |
| 904 | return true; |
| 905 | } |
| 906 | } |
| 907 | |
| 908 | LLT EltTy = SrcTy.getElementType(); |
| 909 | Align VecAlign; |
| 910 | MachinePointerInfo PtrInfo; |
| 911 | auto StackTemp = createStackTemporaryForVector(Helper, GR, SrcReg, SrcTy, |
| 912 | PtrInfo, VecAlign); |
| 913 | |
| 914 | MIRBuilder.buildStore(Val: SrcReg, Addr: StackTemp, PtrInfo, Alignment: VecAlign); |
| 915 | |
| 916 | Register IdxReg = IdxOperand.getReg(); |
| 917 | LLT PtrTy = MRI.getType(Reg: StackTemp.getReg(Idx: 0)); |
| 918 | Register EltPtr = MRI.createGenericVirtualRegister(Ty: PtrTy); |
| 919 | auto Zero = MIRBuilder.buildConstant(Res: LLT::scalar(SizeInBits: 32), Val: 0); |
| 920 | |
| 921 | MIRBuilder.buildIntrinsic(ID: Intrinsic::spv_gep, Res: ArrayRef<Register>{EltPtr}) |
| 922 | .addImm(Val: 1) // InBounds |
| 923 | .addUse(RegNo: StackTemp.getReg(Idx: 0)) |
| 924 | .addUse(RegNo: Zero.getReg(Idx: 0)) |
| 925 | .addUse(RegNo: IdxReg); |
| 926 | |
| 927 | MachinePointerInfo EltPtrInfo = MachinePointerInfo(PtrTy.getAddressSpace()); |
| 928 | Align EltAlign = Helper.getStackTemporaryAlignment(Type: EltTy); |
| 929 | MIRBuilder.buildLoad(Res: DstReg, Addr: EltPtr, PtrInfo: EltPtrInfo, Alignment: EltAlign); |
| 930 | |
| 931 | MI.eraseFromParent(); |
| 932 | return true; |
| 933 | } |
| 934 | return true; |
| 935 | } |
| 936 | |
| 937 | static bool legalizeSpvConstComposite(LegalizerHelper &Helper, MachineInstr &MI, |
| 938 | SPIRVGlobalRegistry *GR) { |
| 939 | MachineIRBuilder &MIRBuilder = Helper.MIRBuilder; |
| 940 | MachineRegisterInfo &MRI = *MIRBuilder.getMRI(); |
| 941 | const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>(); |
| 942 | |
| 943 | Register DstReg = MI.getOperand(i: 0).getReg(); |
| 944 | LLT DstTy = MRI.getType(Reg: DstReg); |
| 945 | |
| 946 | if (!needsVectorLegalization(Ty: DstTy, ST)) |
| 947 | return true; |
| 948 | |
| 949 | SmallVector<Register, 8> SrcRegs; |
| 950 | if (MI.getNumOperands() == 2) { |
| 951 | // The "null" case: no values are attached. |
| 952 | LLT EltTy = DstTy.getElementType(); |
| 953 | auto Zero = MIRBuilder.buildConstant(Res: EltTy, Val: 0); |
| 954 | SPIRVTypeInst SpvDstTy = GR->getSPIRVTypeForVReg(VReg: DstReg); |
| 955 | SPIRVTypeInst SpvEltTy = GR->getScalarOrVectorComponentType(Type: SpvDstTy); |
| 956 | GR->assignSPIRVTypeToVReg(Type: SpvEltTy, VReg: Zero.getReg(Idx: 0), MF: MIRBuilder.getMF()); |
| 957 | for (unsigned i = 0; i < DstTy.getNumElements(); ++i) |
| 958 | SrcRegs.push_back(Elt: Zero.getReg(Idx: 0)); |
| 959 | } else { |
| 960 | for (unsigned i = 2; i < MI.getNumOperands(); ++i) { |
| 961 | SrcRegs.push_back(Elt: MI.getOperand(i).getReg()); |
| 962 | } |
| 963 | } |
| 964 | MIRBuilder.buildBuildVector(Res: DstReg, Ops: SrcRegs); |
| 965 | MI.eraseFromParent(); |
| 966 | return true; |
| 967 | } |
| 968 | |
| 969 | bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper, |
| 970 | MachineInstr &MI) const { |
| 971 | LLVM_DEBUG(dbgs() << "legalizeIntrinsic: " << MI); |
| 972 | auto IntrinsicID = cast<GIntrinsic>(Val&: MI).getIntrinsicID(); |
| 973 | switch (IntrinsicID) { |
| 974 | case Intrinsic::spv_bitcast: |
| 975 | return legalizeSpvBitcast(Helper, MI, GR); |
| 976 | case Intrinsic::spv_insertelt: |
| 977 | return legalizeSpvInsertElt(Helper, MI, GR); |
| 978 | case Intrinsic::spv_extractelt: |
| 979 | return legalizeSpvExtractElt(Helper, MI, GR); |
| 980 | case Intrinsic::spv_const_composite: |
| 981 | return legalizeSpvConstComposite(Helper, MI, GR); |
| 982 | } |
| 983 | return true; |
| 984 | } |
| 985 | |
| 986 | bool SPIRVLegalizerInfo::legalizeBitcast(LegalizerHelper &Helper, |
| 987 | MachineInstr &MI) const { |
| 988 | // Once the G_BITCAST is using vectors that are allowed, we turn it back into |
| 989 | // an spv_bitcast to avoid verifier problems when the register types are the |
| 990 | // same for the source and the result. Note that the SPIR-V types associated |
| 991 | // with the bitcast can be different even if the register types are the same. |
| 992 | MachineIRBuilder &MIRBuilder = Helper.MIRBuilder; |
| 993 | Register DstReg = MI.getOperand(i: 0).getReg(); |
| 994 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
| 995 | SmallVector<Register, 1> DstRegs = {DstReg}; |
| 996 | MIRBuilder.buildIntrinsic(ID: Intrinsic::spv_bitcast, Res: DstRegs).addUse(RegNo: SrcReg); |
| 997 | MI.eraseFromParent(); |
| 998 | return true; |
| 999 | } |
| 1000 | |
| 1001 | // Note this code was copied from LegalizerHelper::lowerISFPCLASS and adjusted |
| 1002 | // to ensure that all instructions created during the lowering have SPIR-V types |
| 1003 | // assigned to them. |
| 1004 | bool SPIRVLegalizerInfo::legalizeIsFPClass( |
| 1005 | LegalizerHelper &Helper, MachineInstr &MI, |
| 1006 | LostDebugLocObserver &LocObserver) const { |
| 1007 | auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs(); |
| 1008 | FPClassTest Mask = static_cast<FPClassTest>(MI.getOperand(i: 2).getImm()); |
| 1009 | |
| 1010 | auto &MIRBuilder = Helper.MIRBuilder; |
| 1011 | auto &MF = MIRBuilder.getMF(); |
| 1012 | MachineRegisterInfo &MRI = MF.getRegInfo(); |
| 1013 | |
| 1014 | Type *LLVMDstTy = |
| 1015 | IntegerType::get(C&: MIRBuilder.getContext(), NumBits: DstTy.getScalarSizeInBits()); |
| 1016 | if (DstTy.isVector()) |
| 1017 | LLVMDstTy = VectorType::get(ElementType: LLVMDstTy, EC: DstTy.getElementCount()); |
| 1018 | SPIRVTypeInst SPIRVDstTy = GR->getOrCreateSPIRVType( |
| 1019 | Type: LLVMDstTy, MIRBuilder, AQ: SPIRV::AccessQualifier::ReadWrite, |
| 1020 | /*EmitIR*/ true); |
| 1021 | |
| 1022 | unsigned BitSize = SrcTy.getScalarSizeInBits(); |
| 1023 | const fltSemantics &Semantics = getFltSemanticForLLT(Ty: SrcTy.getScalarType()); |
| 1024 | |
| 1025 | LLT IntTy = LLT::scalar(SizeInBits: BitSize); |
| 1026 | Type *LLVMIntTy = IntegerType::get(C&: MIRBuilder.getContext(), NumBits: BitSize); |
| 1027 | if (SrcTy.isVector()) { |
| 1028 | IntTy = LLT::vector(EC: SrcTy.getElementCount(), ScalarTy: IntTy); |
| 1029 | LLVMIntTy = VectorType::get(ElementType: LLVMIntTy, EC: SrcTy.getElementCount()); |
| 1030 | } |
| 1031 | SPIRVTypeInst SPIRVIntTy = GR->getOrCreateSPIRVType( |
| 1032 | Type: LLVMIntTy, MIRBuilder, AQ: SPIRV::AccessQualifier::ReadWrite, |
| 1033 | /*EmitIR*/ true); |
| 1034 | |
| 1035 | // Clang doesn't support capture of structured bindings: |
| 1036 | LLT DstTyCopy = DstTy; |
| 1037 | const auto assignSPIRVTy = [&](MachineInstrBuilder &&MI) { |
| 1038 | // Assign this MI's (assumed only) destination to one of the two types we |
| 1039 | // expect: either the G_IS_FPCLASS's destination type, or the integer type |
| 1040 | // bitcast from the source type. |
| 1041 | LLT MITy = MRI.getType(Reg: MI.getReg(Idx: 0)); |
| 1042 | assert((MITy == IntTy || MITy == DstTyCopy) && |
| 1043 | "Unexpected LLT type while lowering G_IS_FPCLASS" ); |
| 1044 | SPIRVTypeInst SPVTy = MITy == IntTy ? SPIRVIntTy : SPIRVDstTy; |
| 1045 | GR->assignSPIRVTypeToVReg(Type: SPVTy, VReg: MI.getReg(Idx: 0), MF); |
| 1046 | return MI; |
| 1047 | }; |
| 1048 | |
| 1049 | // Helper to build and assign a constant in one go |
| 1050 | const auto buildSPIRVConstant = [&](LLT Ty, auto &&C) -> MachineInstrBuilder { |
| 1051 | if (!Ty.isFixedVector()) |
| 1052 | return assignSPIRVTy(MIRBuilder.buildConstant(Ty, C)); |
| 1053 | auto ScalarC = MIRBuilder.buildConstant(Ty.getScalarType(), C); |
| 1054 | assert((Ty == IntTy || Ty == DstTyCopy) && |
| 1055 | "Unexpected LLT type while lowering constant for G_IS_FPCLASS" ); |
| 1056 | SPIRVTypeInst VecEltTy = GR->getOrCreateSPIRVType( |
| 1057 | Type: (Ty == IntTy ? LLVMIntTy : LLVMDstTy)->getScalarType(), MIRBuilder, |
| 1058 | AQ: SPIRV::AccessQualifier::ReadWrite, |
| 1059 | /*EmitIR*/ true); |
| 1060 | GR->assignSPIRVTypeToVReg(Type: VecEltTy, VReg: ScalarC.getReg(0), MF); |
| 1061 | return assignSPIRVTy(MIRBuilder.buildSplatBuildVector(Res: Ty, Src: ScalarC)); |
| 1062 | }; |
| 1063 | |
| 1064 | if (Mask == fcNone) { |
| 1065 | MIRBuilder.buildCopy(Res: DstReg, Op: buildSPIRVConstant(DstTy, 0)); |
| 1066 | MI.eraseFromParent(); |
| 1067 | return true; |
| 1068 | } |
| 1069 | if (Mask == fcAllFlags) { |
| 1070 | MIRBuilder.buildCopy(Res: DstReg, Op: buildSPIRVConstant(DstTy, 1)); |
| 1071 | MI.eraseFromParent(); |
| 1072 | return true; |
| 1073 | } |
| 1074 | |
| 1075 | // Note that rather than creating a COPY here (between a floating-point and |
| 1076 | // integer type of the same size) we create a SPIR-V bitcast immediately. We |
| 1077 | // can't create a G_BITCAST because the LLTs are the same, and we can't seem |
| 1078 | // to correctly lower COPYs to SPIR-V bitcasts at this moment. |
| 1079 | Register ResVReg = MRI.createGenericVirtualRegister(Ty: IntTy); |
| 1080 | MRI.setRegClass(Reg: ResVReg, RC: GR->getRegClass(SpvType: SPIRVIntTy)); |
| 1081 | GR->assignSPIRVTypeToVReg(Type: SPIRVIntTy, VReg: ResVReg, MF: Helper.MIRBuilder.getMF()); |
| 1082 | auto AsInt = MIRBuilder.buildInstr(Opcode: SPIRV::OpBitcast) |
| 1083 | .addDef(RegNo: ResVReg) |
| 1084 | .addUse(RegNo: GR->getSPIRVTypeID(SpirvType: SPIRVIntTy)) |
| 1085 | .addUse(RegNo: SrcReg); |
| 1086 | AsInt = assignSPIRVTy(std::move(AsInt)); |
| 1087 | |
| 1088 | // Various masks. |
| 1089 | APInt SignBit = APInt::getSignMask(BitWidth: BitSize); |
| 1090 | APInt ValueMask = APInt::getSignedMaxValue(numBits: BitSize); // All bits but sign. |
| 1091 | APInt Inf = APFloat::getInf(Sem: Semantics).bitcastToAPInt(); // Exp and int bit. |
| 1092 | APInt ExpMask = Inf; |
| 1093 | APInt AllOneMantissa = APFloat::getLargest(Sem: Semantics).bitcastToAPInt() & ~Inf; |
| 1094 | APInt QNaNBitMask = |
| 1095 | APInt::getOneBitSet(numBits: BitSize, BitNo: AllOneMantissa.getActiveBits() - 1); |
| 1096 | APInt InversionMask = APInt::getAllOnes(numBits: DstTy.getScalarSizeInBits()); |
| 1097 | |
| 1098 | auto SignBitC = buildSPIRVConstant(IntTy, SignBit); |
| 1099 | auto ValueMaskC = buildSPIRVConstant(IntTy, ValueMask); |
| 1100 | auto InfC = buildSPIRVConstant(IntTy, Inf); |
| 1101 | auto ExpMaskC = buildSPIRVConstant(IntTy, ExpMask); |
| 1102 | auto ZeroC = buildSPIRVConstant(IntTy, 0); |
| 1103 | |
| 1104 | auto Abs = assignSPIRVTy(MIRBuilder.buildAnd(Dst: IntTy, Src0: AsInt, Src1: ValueMaskC)); |
| 1105 | auto Sign = assignSPIRVTy( |
| 1106 | MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_NE, Res: DstTy, Op0: AsInt, Op1: Abs)); |
| 1107 | |
| 1108 | auto Res = buildSPIRVConstant(DstTy, 0); |
| 1109 | |
| 1110 | const auto appendToRes = [&](MachineInstrBuilder &&ToAppend) { |
| 1111 | Res = assignSPIRVTy( |
| 1112 | MIRBuilder.buildOr(Dst: DstTyCopy, Src0: Res, Src1: assignSPIRVTy(std::move(ToAppend)))); |
| 1113 | }; |
| 1114 | |
| 1115 | // Tests that involve more than one class should be processed first. |
| 1116 | if ((Mask & fcFinite) == fcFinite) { |
| 1117 | // finite(V) ==> abs(V) u< exp_mask |
| 1118 | appendToRes(MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_ULT, Res: DstTy, Op0: Abs, |
| 1119 | Op1: ExpMaskC)); |
| 1120 | Mask &= ~fcFinite; |
| 1121 | } else if ((Mask & fcFinite) == fcPosFinite) { |
| 1122 | // finite(V) && V > 0 ==> V u< exp_mask |
| 1123 | appendToRes(MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_ULT, Res: DstTy, Op0: AsInt, |
| 1124 | Op1: ExpMaskC)); |
| 1125 | Mask &= ~fcPosFinite; |
| 1126 | } else if ((Mask & fcFinite) == fcNegFinite) { |
| 1127 | // finite(V) && V < 0 ==> abs(V) u< exp_mask && signbit == 1 |
| 1128 | auto Cmp = assignSPIRVTy(MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_ULT, |
| 1129 | Res: DstTy, Op0: Abs, Op1: ExpMaskC)); |
| 1130 | appendToRes(MIRBuilder.buildAnd(Dst: DstTy, Src0: Cmp, Src1: Sign)); |
| 1131 | Mask &= ~fcNegFinite; |
| 1132 | } |
| 1133 | |
| 1134 | if (FPClassTest PartialCheck = Mask & (fcZero | fcSubnormal)) { |
| 1135 | // fcZero | fcSubnormal => test all exponent bits are 0 |
| 1136 | // TODO: Handle sign bit specific cases |
| 1137 | // TODO: Handle inverted case |
| 1138 | if (PartialCheck == (fcZero | fcSubnormal)) { |
| 1139 | auto ExpBits = assignSPIRVTy(MIRBuilder.buildAnd(Dst: IntTy, Src0: AsInt, Src1: ExpMaskC)); |
| 1140 | appendToRes(MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_EQ, Res: DstTy, |
| 1141 | Op0: ExpBits, Op1: ZeroC)); |
| 1142 | Mask &= ~PartialCheck; |
| 1143 | } |
| 1144 | } |
| 1145 | |
| 1146 | // Check for individual classes. |
| 1147 | if (FPClassTest PartialCheck = Mask & fcZero) { |
| 1148 | if (PartialCheck == fcPosZero) |
| 1149 | appendToRes(MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_EQ, Res: DstTy, |
| 1150 | Op0: AsInt, Op1: ZeroC)); |
| 1151 | else if (PartialCheck == fcZero) |
| 1152 | appendToRes( |
| 1153 | MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_EQ, Res: DstTy, Op0: Abs, Op1: ZeroC)); |
| 1154 | else // fcNegZero |
| 1155 | appendToRes(MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_EQ, Res: DstTy, |
| 1156 | Op0: AsInt, Op1: SignBitC)); |
| 1157 | } |
| 1158 | |
| 1159 | if (FPClassTest PartialCheck = Mask & fcSubnormal) { |
| 1160 | // issubnormal(V) ==> unsigned(abs(V) - 1) u< (all mantissa bits set) |
| 1161 | // issubnormal(V) && V>0 ==> unsigned(V - 1) u< (all mantissa bits set) |
| 1162 | auto V = (PartialCheck == fcPosSubnormal) ? AsInt : Abs; |
| 1163 | auto OneC = buildSPIRVConstant(IntTy, 1); |
| 1164 | auto VMinusOne = MIRBuilder.buildSub(Dst: IntTy, Src0: V, Src1: OneC); |
| 1165 | auto SubnormalRes = assignSPIRVTy( |
| 1166 | MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_ULT, Res: DstTy, Op0: VMinusOne, |
| 1167 | Op1: buildSPIRVConstant(IntTy, AllOneMantissa))); |
| 1168 | if (PartialCheck == fcNegSubnormal) |
| 1169 | SubnormalRes = MIRBuilder.buildAnd(Dst: DstTy, Src0: SubnormalRes, Src1: Sign); |
| 1170 | appendToRes(std::move(SubnormalRes)); |
| 1171 | } |
| 1172 | |
| 1173 | if (FPClassTest PartialCheck = Mask & fcInf) { |
| 1174 | if (PartialCheck == fcPosInf) |
| 1175 | appendToRes(MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_EQ, Res: DstTy, |
| 1176 | Op0: AsInt, Op1: InfC)); |
| 1177 | else if (PartialCheck == fcInf) |
| 1178 | appendToRes( |
| 1179 | MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_EQ, Res: DstTy, Op0: Abs, Op1: InfC)); |
| 1180 | else { // fcNegInf |
| 1181 | APInt NegInf = APFloat::getInf(Sem: Semantics, Negative: true).bitcastToAPInt(); |
| 1182 | auto NegInfC = buildSPIRVConstant(IntTy, NegInf); |
| 1183 | appendToRes(MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_EQ, Res: DstTy, |
| 1184 | Op0: AsInt, Op1: NegInfC)); |
| 1185 | } |
| 1186 | } |
| 1187 | |
| 1188 | if (FPClassTest PartialCheck = Mask & fcNan) { |
| 1189 | auto InfWithQnanBitC = |
| 1190 | buildSPIRVConstant(IntTy, std::move(Inf) | QNaNBitMask); |
| 1191 | if (PartialCheck == fcNan) { |
| 1192 | // isnan(V) ==> abs(V) u> int(inf) |
| 1193 | appendToRes( |
| 1194 | MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_UGT, Res: DstTy, Op0: Abs, Op1: InfC)); |
| 1195 | } else if (PartialCheck == fcQNan) { |
| 1196 | // isquiet(V) ==> abs(V) u>= (unsigned(Inf) | quiet_bit) |
| 1197 | appendToRes(MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_UGE, Res: DstTy, Op0: Abs, |
| 1198 | Op1: InfWithQnanBitC)); |
| 1199 | } else { // fcSNan |
| 1200 | // issignaling(V) ==> abs(V) u> unsigned(Inf) && |
| 1201 | // abs(V) u< (unsigned(Inf) | quiet_bit) |
| 1202 | auto IsNan = assignSPIRVTy( |
| 1203 | MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_UGT, Res: DstTy, Op0: Abs, Op1: InfC)); |
| 1204 | auto IsNotQnan = assignSPIRVTy(MIRBuilder.buildICmp( |
| 1205 | Pred: CmpInst::Predicate::ICMP_ULT, Res: DstTy, Op0: Abs, Op1: InfWithQnanBitC)); |
| 1206 | appendToRes(MIRBuilder.buildAnd(Dst: DstTy, Src0: IsNan, Src1: IsNotQnan)); |
| 1207 | } |
| 1208 | } |
| 1209 | |
| 1210 | if (FPClassTest PartialCheck = Mask & fcNormal) { |
| 1211 | // isnormal(V) ==> (0 u< exp u< max_exp) ==> (unsigned(exp-1) u< |
| 1212 | // (max_exp-1)) |
| 1213 | APInt ExpLSB = ExpMask & ~(ExpMask.shl(shiftAmt: 1)); |
| 1214 | auto ExpMinusOne = assignSPIRVTy( |
| 1215 | MIRBuilder.buildSub(Dst: IntTy, Src0: Abs, Src1: buildSPIRVConstant(IntTy, ExpLSB))); |
| 1216 | APInt MaxExpMinusOne = std::move(ExpMask) - ExpLSB; |
| 1217 | auto NormalRes = assignSPIRVTy( |
| 1218 | MIRBuilder.buildICmp(Pred: CmpInst::Predicate::ICMP_ULT, Res: DstTy, Op0: ExpMinusOne, |
| 1219 | Op1: buildSPIRVConstant(IntTy, MaxExpMinusOne))); |
| 1220 | if (PartialCheck == fcNegNormal) |
| 1221 | NormalRes = MIRBuilder.buildAnd(Dst: DstTy, Src0: NormalRes, Src1: Sign); |
| 1222 | else if (PartialCheck == fcPosNormal) { |
| 1223 | auto PosSign = assignSPIRVTy(MIRBuilder.buildXor( |
| 1224 | Dst: DstTy, Src0: Sign, Src1: buildSPIRVConstant(DstTy, InversionMask))); |
| 1225 | NormalRes = MIRBuilder.buildAnd(Dst: DstTy, Src0: NormalRes, Src1: PosSign); |
| 1226 | } |
| 1227 | appendToRes(std::move(NormalRes)); |
| 1228 | } |
| 1229 | |
| 1230 | MIRBuilder.buildCopy(Res: DstReg, Op: Res); |
| 1231 | MI.eraseFromParent(); |
| 1232 | return true; |
| 1233 | } |
| 1234 | |