| 1 | //===-- SPIRVPostLegalizer.cpp - amend info after legalization -*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // The pass partially applies pre-legalization logic to new instructions |
| 10 | // inserted as a result of legalization: |
| 11 | // - assigns SPIR-V types to registers for new instructions. |
| 12 | // - inserts ASSIGN_TYPE pseudo-instructions required for type folding. |
| 13 | // |
| 14 | //===----------------------------------------------------------------------===// |
| 15 | |
| 16 | #include "SPIRV.h" |
| 17 | #include "SPIRVSubtarget.h" |
| 18 | #include "SPIRVUtils.h" |
| 19 | #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h" |
| 20 | #include "llvm/CodeGen/MachineFrameInfo.h" |
| 21 | #include "llvm/IR/IntrinsicsSPIRV.h" |
| 22 | #include "llvm/Support/Debug.h" |
| 23 | #include <stack> |
| 24 | |
| 25 | #define DEBUG_TYPE "spirv-postlegalizer" |
| 26 | |
| 27 | using namespace llvm; |
| 28 | |
| 29 | namespace { |
| 30 | class SPIRVPostLegalizer : public MachineFunctionPass { |
| 31 | public: |
| 32 | static char ID; |
| 33 | SPIRVPostLegalizer() : MachineFunctionPass(ID) {} |
| 34 | bool runOnMachineFunction(MachineFunction &MF) override; |
| 35 | }; |
| 36 | } // namespace |
| 37 | |
| 38 | namespace llvm { |
| 39 | // Defined in SPIRVPreLegalizer.cpp. |
| 40 | extern void updateRegType(Register Reg, Type *Ty, SPIRVType *SpirvTy, |
| 41 | SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB, |
| 42 | MachineRegisterInfo &MRI); |
| 43 | extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB, |
| 44 | MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR, |
| 45 | SPIRVType *KnownResType); |
| 46 | } // namespace llvm |
| 47 | |
| 48 | static SPIRVType *deduceIntTypeFromResult(Register ResVReg, |
| 49 | MachineIRBuilder &MIB, |
| 50 | SPIRVGlobalRegistry *GR) { |
| 51 | const LLT &Ty = MIB.getMRI()->getType(Reg: ResVReg); |
| 52 | return GR->getOrCreateSPIRVIntegerType(BitWidth: Ty.getScalarSizeInBits(), MIRBuilder&: MIB); |
| 53 | } |
| 54 | |
| 55 | static SPIRVType *deduceTypeFromSingleOperand(MachineInstr *I, |
| 56 | MachineIRBuilder &MIB, |
| 57 | SPIRVGlobalRegistry *GR, |
| 58 | unsigned OpIdx) { |
| 59 | Register OpReg = I->getOperand(i: OpIdx).getReg(); |
| 60 | if (SPIRVType *OpType = GR->getSPIRVTypeForVReg(VReg: OpReg)) { |
| 61 | if (SPIRVType *CompType = GR->getScalarOrVectorComponentType(Type: OpType)) { |
| 62 | Register ResVReg = I->getOperand(i: 0).getReg(); |
| 63 | const LLT &ResLLT = MIB.getMRI()->getType(Reg: ResVReg); |
| 64 | if (ResLLT.isVector()) |
| 65 | return GR->getOrCreateSPIRVVectorType(BaseType: CompType, NumElements: ResLLT.getNumElements(), |
| 66 | MIRBuilder&: MIB, EmitIR: false); |
| 67 | return CompType; |
| 68 | } |
| 69 | } |
| 70 | return nullptr; |
| 71 | } |
| 72 | |
| 73 | static SPIRVType *deduceTypeFromOperandRange(MachineInstr *I, |
| 74 | MachineIRBuilder &MIB, |
| 75 | SPIRVGlobalRegistry *GR, |
| 76 | unsigned StartOp, unsigned EndOp) { |
| 77 | SPIRVType *ResType = nullptr; |
| 78 | for (unsigned i = StartOp; i < EndOp; ++i) { |
| 79 | if (SPIRVType *Type = deduceTypeFromSingleOperand(I, MIB, GR, OpIdx: i)) { |
| 80 | #ifdef EXPENSIVE_CHECKS |
| 81 | assert(!ResType || Type == ResType && "Conflicting type from operands." ); |
| 82 | ResType = Type; |
| 83 | #else |
| 84 | return Type; |
| 85 | #endif |
| 86 | } |
| 87 | } |
| 88 | return ResType; |
| 89 | } |
| 90 | |
| 91 | static SPIRVType *deduceTypeFromResultRegister(MachineInstr *Use, |
| 92 | Register UseRegister, |
| 93 | SPIRVGlobalRegistry *GR, |
| 94 | MachineIRBuilder &MIB) { |
| 95 | for (const MachineOperand &MO : Use->defs()) { |
| 96 | if (!MO.isReg()) |
| 97 | continue; |
| 98 | if (SPIRVType *OpType = GR->getSPIRVTypeForVReg(VReg: MO.getReg())) { |
| 99 | if (SPIRVType *CompType = GR->getScalarOrVectorComponentType(Type: OpType)) { |
| 100 | const LLT &ResLLT = MIB.getMRI()->getType(Reg: UseRegister); |
| 101 | if (ResLLT.isVector()) |
| 102 | return GR->getOrCreateSPIRVVectorType( |
| 103 | BaseType: CompType, NumElements: ResLLT.getNumElements(), MIRBuilder&: MIB, EmitIR: false); |
| 104 | return CompType; |
| 105 | } |
| 106 | } |
| 107 | } |
| 108 | return nullptr; |
| 109 | } |
| 110 | |
| 111 | static SPIRVType *deducePointerTypeFromResultRegister(MachineInstr *Use, |
| 112 | Register UseRegister, |
| 113 | SPIRVGlobalRegistry *GR, |
| 114 | MachineIRBuilder &MIB) { |
| 115 | assert(Use->getOpcode() == TargetOpcode::G_LOAD || |
| 116 | Use->getOpcode() == TargetOpcode::G_STORE); |
| 117 | |
| 118 | Register ValueReg = Use->getOperand(i: 0).getReg(); |
| 119 | SPIRVType *ValueType = GR->getSPIRVTypeForVReg(VReg: ValueReg); |
| 120 | if (!ValueType) |
| 121 | return nullptr; |
| 122 | |
| 123 | return GR->getOrCreateSPIRVPointerType(BaseType: ValueType, MIRBuilder&: MIB, |
| 124 | SC: SPIRV::StorageClass::Function); |
| 125 | } |
| 126 | |
| 127 | static SPIRVType *deduceTypeFromPointerOperand(MachineInstr *Use, |
| 128 | Register UseRegister, |
| 129 | SPIRVGlobalRegistry *GR, |
| 130 | MachineIRBuilder &MIB) { |
| 131 | assert(Use->getOpcode() == TargetOpcode::G_LOAD || |
| 132 | Use->getOpcode() == TargetOpcode::G_STORE); |
| 133 | |
| 134 | Register PtrReg = Use->getOperand(i: 1).getReg(); |
| 135 | SPIRVType *PtrType = GR->getSPIRVTypeForVReg(VReg: PtrReg); |
| 136 | if (!PtrType) |
| 137 | return nullptr; |
| 138 | |
| 139 | return GR->getPointeeType(PtrType); |
| 140 | } |
| 141 | |
| 142 | static SPIRVType *deduceTypeFromUses(Register Reg, MachineFunction &MF, |
| 143 | SPIRVGlobalRegistry *GR, |
| 144 | MachineIRBuilder &MIB) { |
| 145 | MachineRegisterInfo &MRI = MF.getRegInfo(); |
| 146 | for (MachineInstr &Use : MRI.use_nodbg_instructions(Reg)) { |
| 147 | SPIRVType *ResType = nullptr; |
| 148 | LLVM_DEBUG(dbgs() << "Looking at use " << Use); |
| 149 | switch (Use.getOpcode()) { |
| 150 | case TargetOpcode::G_BUILD_VECTOR: |
| 151 | case TargetOpcode::G_EXTRACT_VECTOR_ELT: |
| 152 | case TargetOpcode::G_UNMERGE_VALUES: |
| 153 | case TargetOpcode::G_ADD: |
| 154 | case TargetOpcode::G_SUB: |
| 155 | case TargetOpcode::G_MUL: |
| 156 | case TargetOpcode::G_SDIV: |
| 157 | case TargetOpcode::G_UDIV: |
| 158 | case TargetOpcode::G_SREM: |
| 159 | case TargetOpcode::G_UREM: |
| 160 | case TargetOpcode::G_FADD: |
| 161 | case TargetOpcode::G_FSUB: |
| 162 | case TargetOpcode::G_FMUL: |
| 163 | case TargetOpcode::G_FDIV: |
| 164 | case TargetOpcode::G_FREM: |
| 165 | case TargetOpcode::G_FMA: |
| 166 | case TargetOpcode::COPY: |
| 167 | case TargetOpcode::G_STRICT_FMA: |
| 168 | ResType = deduceTypeFromResultRegister(Use: &Use, UseRegister: Reg, GR, MIB); |
| 169 | break; |
| 170 | case TargetOpcode::G_LOAD: |
| 171 | case TargetOpcode::G_STORE: |
| 172 | if (Reg == Use.getOperand(i: 1).getReg()) |
| 173 | ResType = deducePointerTypeFromResultRegister(Use: &Use, UseRegister: Reg, GR, MIB); |
| 174 | else |
| 175 | ResType = deduceTypeFromPointerOperand(Use: &Use, UseRegister: Reg, GR, MIB); |
| 176 | break; |
| 177 | case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: |
| 178 | case TargetOpcode::G_INTRINSIC: { |
| 179 | auto IntrinsicID = cast<GIntrinsic>(Val&: Use).getIntrinsicID(); |
| 180 | if (IntrinsicID == Intrinsic::spv_insertelt) { |
| 181 | if (Reg == Use.getOperand(i: 2).getReg()) |
| 182 | ResType = deduceTypeFromResultRegister(Use: &Use, UseRegister: Reg, GR, MIB); |
| 183 | } else if (IntrinsicID == Intrinsic::spv_extractelt) { |
| 184 | if (Reg == Use.getOperand(i: 2).getReg()) |
| 185 | ResType = deduceTypeFromResultRegister(Use: &Use, UseRegister: Reg, GR, MIB); |
| 186 | } |
| 187 | break; |
| 188 | } |
| 189 | } |
| 190 | if (ResType) { |
| 191 | LLVM_DEBUG(dbgs() << "Deduced type from use " << *ResType); |
| 192 | return ResType; |
| 193 | } |
| 194 | } |
| 195 | return nullptr; |
| 196 | } |
| 197 | |
| 198 | static SPIRVType *deduceGEPType(MachineInstr *I, SPIRVGlobalRegistry *GR, |
| 199 | MachineIRBuilder &MIB) { |
| 200 | LLVM_DEBUG(dbgs() << "Deducing GEP type for: " << *I); |
| 201 | Register PtrReg = I->getOperand(i: 3).getReg(); |
| 202 | SPIRVType *PtrType = GR->getSPIRVTypeForVReg(VReg: PtrReg); |
| 203 | if (!PtrType) { |
| 204 | LLVM_DEBUG(dbgs() << " Could not get type for pointer operand.\n" ); |
| 205 | return nullptr; |
| 206 | } |
| 207 | |
| 208 | SPIRVType *PointeeType = GR->getPointeeType(PtrType); |
| 209 | if (!PointeeType) { |
| 210 | LLVM_DEBUG(dbgs() << " Could not get pointee type from pointer type.\n" ); |
| 211 | return nullptr; |
| 212 | } |
| 213 | |
| 214 | MachineRegisterInfo *MRI = MIB.getMRI(); |
| 215 | |
| 216 | // The first index (operand 4) steps over the pointer, so the type doesn't |
| 217 | // change. |
| 218 | for (unsigned i = 5; i < I->getNumOperands(); ++i) { |
| 219 | LLVM_DEBUG(dbgs() << " Traversing index " << i |
| 220 | << ", current type: " << *PointeeType); |
| 221 | switch (PointeeType->getOpcode()) { |
| 222 | case SPIRV::OpTypeArray: |
| 223 | case SPIRV::OpTypeRuntimeArray: |
| 224 | case SPIRV::OpTypeVector: { |
| 225 | Register ElemTypeReg = PointeeType->getOperand(i: 1).getReg(); |
| 226 | PointeeType = GR->getSPIRVTypeForVReg(VReg: ElemTypeReg); |
| 227 | break; |
| 228 | } |
| 229 | case SPIRV::OpTypeStruct: { |
| 230 | MachineOperand &IdxOp = I->getOperand(i); |
| 231 | if (!IdxOp.isReg()) { |
| 232 | LLVM_DEBUG(dbgs() << " Index is not a register.\n" ); |
| 233 | return nullptr; |
| 234 | } |
| 235 | MachineInstr *Def = MRI->getVRegDef(Reg: IdxOp.getReg()); |
| 236 | if (!Def) { |
| 237 | LLVM_DEBUG( |
| 238 | dbgs() << " Could not find definition for index register.\n" ); |
| 239 | return nullptr; |
| 240 | } |
| 241 | |
| 242 | uint64_t IndexVal = foldImm(MO: IdxOp, MRI); |
| 243 | if (IndexVal >= PointeeType->getNumOperands() - 1) { |
| 244 | LLVM_DEBUG(dbgs() << " Struct index out of bounds.\n" ); |
| 245 | return nullptr; |
| 246 | } |
| 247 | |
| 248 | Register MemberTypeReg = PointeeType->getOperand(i: IndexVal + 1).getReg(); |
| 249 | PointeeType = GR->getSPIRVTypeForVReg(VReg: MemberTypeReg); |
| 250 | break; |
| 251 | } |
| 252 | default: |
| 253 | LLVM_DEBUG(dbgs() << " Unknown type opcode for GEP traversal.\n" ); |
| 254 | return nullptr; |
| 255 | } |
| 256 | |
| 257 | if (!PointeeType) { |
| 258 | LLVM_DEBUG(dbgs() << " Could not resolve next pointee type.\n" ); |
| 259 | return nullptr; |
| 260 | } |
| 261 | } |
| 262 | LLVM_DEBUG(dbgs() << " Final pointee type: " << *PointeeType); |
| 263 | |
| 264 | SPIRV::StorageClass::StorageClass SC = GR->getPointerStorageClass(Type: PtrType); |
| 265 | SPIRVType *Res = GR->getOrCreateSPIRVPointerType(BaseType: PointeeType, MIRBuilder&: MIB, SC); |
| 266 | LLVM_DEBUG(dbgs() << " Deduced GEP type: " << *Res); |
| 267 | return Res; |
| 268 | } |
| 269 | |
| 270 | static SPIRVType *deduceResultTypeFromOperands(MachineInstr *I, |
| 271 | SPIRVGlobalRegistry *GR, |
| 272 | MachineIRBuilder &MIB) { |
| 273 | Register ResVReg = I->getOperand(i: 0).getReg(); |
| 274 | switch (I->getOpcode()) { |
| 275 | case TargetOpcode::G_CONSTANT: |
| 276 | case TargetOpcode::G_ANYEXT: |
| 277 | case TargetOpcode::G_SEXT: |
| 278 | case TargetOpcode::G_ZEXT: |
| 279 | return deduceIntTypeFromResult(ResVReg, MIB, GR); |
| 280 | case TargetOpcode::G_BUILD_VECTOR: |
| 281 | return deduceTypeFromOperandRange(I, MIB, GR, StartOp: 1, EndOp: I->getNumOperands()); |
| 282 | case TargetOpcode::G_SHUFFLE_VECTOR: |
| 283 | return deduceTypeFromOperandRange(I, MIB, GR, StartOp: 1, EndOp: 3); |
| 284 | case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: |
| 285 | case TargetOpcode::G_INTRINSIC: { |
| 286 | auto IntrinsicID = cast<GIntrinsic>(Val: I)->getIntrinsicID(); |
| 287 | if (IntrinsicID == Intrinsic::spv_gep) |
| 288 | return deduceGEPType(I, GR, MIB); |
| 289 | break; |
| 290 | } |
| 291 | case TargetOpcode::G_LOAD: { |
| 292 | SPIRVType *PtrType = deduceTypeFromSingleOperand(I, MIB, GR, OpIdx: 1); |
| 293 | return PtrType ? GR->getPointeeType(PtrType) : nullptr; |
| 294 | } |
| 295 | default: |
| 296 | if (I->getNumDefs() == 1 && I->getNumOperands() > 1 && |
| 297 | I->getOperand(i: 1).isReg()) |
| 298 | return deduceTypeFromSingleOperand(I, MIB, GR, OpIdx: 1); |
| 299 | } |
| 300 | return nullptr; |
| 301 | } |
| 302 | |
| 303 | static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF, |
| 304 | SPIRVGlobalRegistry *GR, |
| 305 | MachineIRBuilder &MIB) { |
| 306 | MachineRegisterInfo &MRI = MF.getRegInfo(); |
| 307 | Register SrcReg = I->getOperand(i: I->getNumOperands() - 1).getReg(); |
| 308 | SPIRVType *ScalarType = nullptr; |
| 309 | if (SPIRVType *DefType = GR->getSPIRVTypeForVReg(VReg: SrcReg)) { |
| 310 | assert(DefType->getOpcode() == SPIRV::OpTypeVector); |
| 311 | ScalarType = GR->getSPIRVTypeForVReg(VReg: DefType->getOperand(i: 1).getReg()); |
| 312 | } |
| 313 | |
| 314 | if (!ScalarType) { |
| 315 | // If we could not deduce the type from the source, try to deduce it from |
| 316 | // the uses of the results. |
| 317 | for (unsigned i = 0; i < I->getNumDefs(); ++i) { |
| 318 | Register DefReg = I->getOperand(i).getReg(); |
| 319 | ScalarType = deduceTypeFromUses(Reg: DefReg, MF, GR, MIB); |
| 320 | if (ScalarType) { |
| 321 | ScalarType = GR->getScalarOrVectorComponentType(Type: ScalarType); |
| 322 | break; |
| 323 | } |
| 324 | } |
| 325 | } |
| 326 | |
| 327 | if (!ScalarType) |
| 328 | return false; |
| 329 | |
| 330 | for (unsigned i = 0; i < I->getNumOperands(); ++i) { |
| 331 | Register DefReg = I->getOperand(i).getReg(); |
| 332 | if (GR->getSPIRVTypeForVReg(VReg: DefReg)) |
| 333 | continue; |
| 334 | |
| 335 | LLT DefLLT = MRI.getType(Reg: DefReg); |
| 336 | SPIRVType *ResType = |
| 337 | DefLLT.isVector() |
| 338 | ? GR->getOrCreateSPIRVVectorType( |
| 339 | BaseType: ScalarType, NumElements: DefLLT.getNumElements(), I&: *I, |
| 340 | TII: *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo()) |
| 341 | : ScalarType; |
| 342 | setRegClassType(Reg: DefReg, SpvType: ResType, GR, MRI: &MRI, MF); |
| 343 | } |
| 344 | return true; |
| 345 | } |
| 346 | |
| 347 | static bool deduceAndAssignSpirvType(MachineInstr *I, MachineFunction &MF, |
| 348 | SPIRVGlobalRegistry *GR, |
| 349 | MachineIRBuilder &MIB) { |
| 350 | LLVM_DEBUG(dbgs() << "\nProcessing instruction: " << *I); |
| 351 | MachineRegisterInfo &MRI = MF.getRegInfo(); |
| 352 | Register ResVReg = I->getOperand(i: 0).getReg(); |
| 353 | |
| 354 | // G_UNMERGE_VALUES is handled separately because it has multiple definitions, |
| 355 | // unlike the other instructions which have a single result register. The main |
| 356 | // deduction logic is designed for the single-definition case. |
| 357 | if (I->getOpcode() == TargetOpcode::G_UNMERGE_VALUES) |
| 358 | return deduceAndAssignTypeForGUnmerge(I, MF, GR, MIB); |
| 359 | |
| 360 | LLVM_DEBUG(dbgs() << "Inferring type from operands\n" ); |
| 361 | SPIRVType *ResType = deduceResultTypeFromOperands(I, GR, MIB); |
| 362 | if (!ResType) { |
| 363 | LLVM_DEBUG(dbgs() << "Inferring type from uses\n" ); |
| 364 | ResType = deduceTypeFromUses(Reg: ResVReg, MF, GR, MIB); |
| 365 | } |
| 366 | |
| 367 | if (!ResType) |
| 368 | return false; |
| 369 | |
| 370 | LLVM_DEBUG(dbgs() << "Assigned type to " << *I << ": " << *ResType); |
| 371 | GR->assignSPIRVTypeToVReg(Type: ResType, VReg: ResVReg, MF); |
| 372 | |
| 373 | if (!MRI.getRegClassOrNull(Reg: ResVReg)) { |
| 374 | LLVM_DEBUG(dbgs() << "Updating the register class.\n" ); |
| 375 | setRegClassType(Reg: ResVReg, SpvType: ResType, GR, MRI: &MRI, MF: *GR->CurMF, Force: true); |
| 376 | } |
| 377 | return true; |
| 378 | } |
| 379 | |
| 380 | static bool requiresSpirvType(MachineInstr &I, SPIRVGlobalRegistry *GR, |
| 381 | MachineRegisterInfo &MRI) { |
| 382 | LLVM_DEBUG(dbgs() << "Checking if instruction requires a SPIR-V type: " |
| 383 | << I;); |
| 384 | if (I.getNumDefs() == 0) { |
| 385 | LLVM_DEBUG(dbgs() << "Instruction does not have a definition.\n" ); |
| 386 | return false; |
| 387 | } |
| 388 | |
| 389 | if (!I.isPreISelOpcode()) { |
| 390 | LLVM_DEBUG(dbgs() << "Instruction is not a generic instruction.\n" ); |
| 391 | return false; |
| 392 | } |
| 393 | |
| 394 | Register ResultRegister = I.defs().begin()->getReg(); |
| 395 | if (GR->getSPIRVTypeForVReg(VReg: ResultRegister)) { |
| 396 | LLVM_DEBUG(dbgs() << "Instruction already has a SPIR-V type.\n" ); |
| 397 | if (!MRI.getRegClassOrNull(Reg: ResultRegister)) { |
| 398 | LLVM_DEBUG(dbgs() << "Updating the register class.\n" ); |
| 399 | setRegClassType(Reg: ResultRegister, SpvType: GR->getSPIRVTypeForVReg(VReg: ResultRegister), |
| 400 | GR, MRI: &MRI, MF: *GR->CurMF, Force: true); |
| 401 | } |
| 402 | return false; |
| 403 | } |
| 404 | |
| 405 | return true; |
| 406 | } |
| 407 | |
| 408 | static void registerSpirvTypeForNewInstructions(MachineFunction &MF, |
| 409 | SPIRVGlobalRegistry *GR) { |
| 410 | MachineRegisterInfo &MRI = MF.getRegInfo(); |
| 411 | SmallVector<MachineInstr *, 8> Worklist; |
| 412 | for (MachineBasicBlock &MBB : MF) { |
| 413 | for (MachineInstr &I : MBB) { |
| 414 | if (requiresSpirvType(I, GR, MRI)) { |
| 415 | Worklist.push_back(Elt: &I); |
| 416 | } |
| 417 | } |
| 418 | } |
| 419 | |
| 420 | if (Worklist.empty()) { |
| 421 | LLVM_DEBUG(dbgs() << "Initial worklist is empty.\n" ); |
| 422 | return; |
| 423 | } |
| 424 | |
| 425 | LLVM_DEBUG(dbgs() << "Initial worklist:\n" ; |
| 426 | for (auto *I : Worklist) { I->dump(); }); |
| 427 | |
| 428 | bool Changed; |
| 429 | do { |
| 430 | Changed = false; |
| 431 | SmallVector<MachineInstr *, 8> NextWorklist; |
| 432 | |
| 433 | for (MachineInstr *I : Worklist) { |
| 434 | MachineIRBuilder MIB(*I); |
| 435 | if (deduceAndAssignSpirvType(I, MF, GR, MIB)) { |
| 436 | Changed = true; |
| 437 | } else { |
| 438 | NextWorklist.push_back(Elt: I); |
| 439 | } |
| 440 | } |
| 441 | Worklist = std::move(NextWorklist); |
| 442 | LLVM_DEBUG(dbgs() << "Worklist size: " << Worklist.size() << "\n" ); |
| 443 | } while (Changed); |
| 444 | |
| 445 | if (Worklist.empty()) |
| 446 | return; |
| 447 | |
| 448 | for (auto *I : Worklist) { |
| 449 | MachineIRBuilder MIB(*I); |
| 450 | LLVM_DEBUG(dbgs() << "Assigning default type to results in " << *I); |
| 451 | for (unsigned Idx = 0; Idx < I->getNumDefs(); ++Idx) { |
| 452 | Register ResVReg = I->getOperand(i: Idx).getReg(); |
| 453 | if (GR->getSPIRVTypeForVReg(VReg: ResVReg)) |
| 454 | continue; |
| 455 | const LLT &ResLLT = MRI.getType(Reg: ResVReg); |
| 456 | SPIRVType *ResType = nullptr; |
| 457 | if (ResLLT.isVector()) { |
| 458 | SPIRVType *CompType = GR->getOrCreateSPIRVIntegerType( |
| 459 | BitWidth: ResLLT.getElementType().getSizeInBits(), MIRBuilder&: MIB); |
| 460 | ResType = GR->getOrCreateSPIRVVectorType( |
| 461 | BaseType: CompType, NumElements: ResLLT.getNumElements(), MIRBuilder&: MIB, EmitIR: false); |
| 462 | } else { |
| 463 | ResType = GR->getOrCreateSPIRVIntegerType(BitWidth: ResLLT.getSizeInBits(), MIRBuilder&: MIB); |
| 464 | } |
| 465 | setRegClassType(Reg: ResVReg, SpvType: ResType, GR, MRI: &MRI, MF, Force: true); |
| 466 | } |
| 467 | } |
| 468 | } |
| 469 | |
| 470 | static bool hasAssignType(Register Reg, MachineRegisterInfo &MRI) { |
| 471 | for (MachineInstr &UseInstr : MRI.use_nodbg_instructions(Reg)) { |
| 472 | if (UseInstr.getOpcode() == SPIRV::ASSIGN_TYPE) { |
| 473 | return true; |
| 474 | } |
| 475 | } |
| 476 | return false; |
| 477 | } |
| 478 | |
| 479 | static void generateAssignType(MachineInstr &MI, Register ResultRegister, |
| 480 | SPIRVType *ResultType, SPIRVGlobalRegistry *GR, |
| 481 | MachineRegisterInfo &MRI) { |
| 482 | LLVM_DEBUG(dbgs() << " Adding ASSIGN_TYPE for ResultRegister: " |
| 483 | << printReg(ResultRegister, MRI.getTargetRegisterInfo()) |
| 484 | << " with type: " << *ResultType); |
| 485 | MachineIRBuilder MIB(MI); |
| 486 | updateRegType(Reg: ResultRegister, Ty: nullptr, SpirvTy: ResultType, GR, MIB, MRI); |
| 487 | |
| 488 | // Tablegen definition assumes SPIRV::ASSIGN_TYPE pseudo-instruction is |
| 489 | // present after each auto-folded instruction to take a type reference |
| 490 | // from. |
| 491 | Register NewReg = |
| 492 | MRI.createGenericVirtualRegister(Ty: MRI.getType(Reg: ResultRegister)); |
| 493 | const auto *RegClass = GR->getRegClass(SpvType: ResultType); |
| 494 | MRI.setRegClass(Reg: NewReg, RC: RegClass); |
| 495 | MRI.setRegClass(Reg: ResultRegister, RC: RegClass); |
| 496 | |
| 497 | GR->assignSPIRVTypeToVReg(Type: ResultType, VReg: ResultRegister, MF: MIB.getMF()); |
| 498 | // This is to make it convenient for Legalizer to get the SPIRVType |
| 499 | // when processing the actual MI (i.e. not pseudo one). |
| 500 | GR->assignSPIRVTypeToVReg(Type: ResultType, VReg: NewReg, MF: MIB.getMF()); |
| 501 | // Copy MIFlags from Def to ASSIGN_TYPE instruction. It's required to |
| 502 | // keep the flags after instruction selection. |
| 503 | const uint32_t Flags = MI.getFlags(); |
| 504 | MIB.buildInstr(Opcode: SPIRV::ASSIGN_TYPE) |
| 505 | .addDef(RegNo: ResultRegister) |
| 506 | .addUse(RegNo: NewReg) |
| 507 | .addUse(RegNo: GR->getSPIRVTypeID(SpirvType: ResultType)) |
| 508 | .setMIFlags(Flags); |
| 509 | for (unsigned I = 0, E = MI.getNumDefs(); I != E; ++I) { |
| 510 | MachineOperand &MO = MI.getOperand(i: I); |
| 511 | if (MO.getReg() == ResultRegister) { |
| 512 | MO.setReg(NewReg); |
| 513 | break; |
| 514 | } |
| 515 | } |
| 516 | } |
| 517 | |
| 518 | static void ensureAssignTypeForTypeFolding(MachineFunction &MF, |
| 519 | SPIRVGlobalRegistry *GR) { |
| 520 | LLVM_DEBUG(dbgs() << "Entering ensureAssignTypeForTypeFolding for function " |
| 521 | << MF.getName() << "\n" ); |
| 522 | MachineRegisterInfo &MRI = MF.getRegInfo(); |
| 523 | for (MachineBasicBlock &MBB : MF) { |
| 524 | for (MachineInstr &MI : MBB) { |
| 525 | if (!isTypeFoldingSupported(Opcode: MI.getOpcode())) |
| 526 | continue; |
| 527 | |
| 528 | LLVM_DEBUG(dbgs() << "Processing instruction: " << MI); |
| 529 | |
| 530 | Register ResultRegister = MI.defs().begin()->getReg(); |
| 531 | if (hasAssignType(Reg: ResultRegister, MRI)) { |
| 532 | LLVM_DEBUG(dbgs() << " Instruction already has ASSIGN_TYPE\n" ); |
| 533 | continue; |
| 534 | } |
| 535 | |
| 536 | SPIRVType *ResultType = GR->getSPIRVTypeForVReg(VReg: ResultRegister); |
| 537 | assert(ResultType); |
| 538 | generateAssignType(MI, ResultRegister, ResultType, GR, MRI); |
| 539 | } |
| 540 | } |
| 541 | } |
| 542 | |
| 543 | // Do a preorder traversal of the CFG starting from the BB |Start|. |
| 544 | // point. Calls |op| on each basic block encountered during the traversal. |
| 545 | void visit(MachineFunction &MF, MachineBasicBlock &Start, |
| 546 | std::function<void(MachineBasicBlock *)> op) { |
| 547 | std::stack<MachineBasicBlock *> ToVisit; |
| 548 | SmallPtrSet<MachineBasicBlock *, 8> Seen; |
| 549 | |
| 550 | ToVisit.push(x: &Start); |
| 551 | Seen.insert(Ptr: ToVisit.top()); |
| 552 | while (ToVisit.size() != 0) { |
| 553 | MachineBasicBlock *MBB = ToVisit.top(); |
| 554 | ToVisit.pop(); |
| 555 | |
| 556 | op(MBB); |
| 557 | |
| 558 | for (auto Succ : MBB->successors()) { |
| 559 | if (Seen.contains(Ptr: Succ)) |
| 560 | continue; |
| 561 | ToVisit.push(x: Succ); |
| 562 | Seen.insert(Ptr: Succ); |
| 563 | } |
| 564 | } |
| 565 | } |
| 566 | |
| 567 | // Do a preorder traversal of the CFG starting from the given function's entry |
| 568 | // point. Calls |op| on each basic block encountered during the traversal. |
| 569 | void visit(MachineFunction &MF, std::function<void(MachineBasicBlock *)> op) { |
| 570 | visit(MF, Start&: *MF.begin(), op: std::move(op)); |
| 571 | } |
| 572 | |
| 573 | bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) { |
| 574 | // Initialize the type registry. |
| 575 | const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>(); |
| 576 | SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry(); |
| 577 | GR->setCurrentFunc(MF); |
| 578 | registerSpirvTypeForNewInstructions(MF, GR); |
| 579 | ensureAssignTypeForTypeFolding(MF, GR); |
| 580 | return true; |
| 581 | } |
| 582 | |
| 583 | INITIALIZE_PASS(SPIRVPostLegalizer, DEBUG_TYPE, "SPIRV post legalizer" , false, |
| 584 | false) |
| 585 | |
| 586 | char SPIRVPostLegalizer::ID = 0; |
| 587 | |
| 588 | FunctionPass *llvm::createSPIRVPostLegalizerPass() { |
| 589 | return new SPIRVPostLegalizer(); |
| 590 | } |
| 591 | |