1//===-- SPIRVPostLegalizer.cpp - amend info after legalization -*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// The pass partially applies pre-legalization logic to new instructions
10// inserted as a result of legalization:
11// - assigns SPIR-V types to registers for new instructions.
12// - inserts ASSIGN_TYPE pseudo-instructions required for type folding.
13//
14//===----------------------------------------------------------------------===//
15
16#include "SPIRV.h"
17#include "SPIRVSubtarget.h"
18#include "SPIRVUtils.h"
19#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
20#include "llvm/CodeGen/MachineFrameInfo.h"
21#include "llvm/IR/IntrinsicsSPIRV.h"
22#include "llvm/Support/Debug.h"
23#include <stack>
24
25#define DEBUG_TYPE "spirv-postlegalizer"
26
27using namespace llvm;
28
29namespace {
30class SPIRVPostLegalizer : public MachineFunctionPass {
31public:
32 static char ID;
33 SPIRVPostLegalizer() : MachineFunctionPass(ID) {}
34 bool runOnMachineFunction(MachineFunction &MF) override;
35};
36} // namespace
37
38namespace llvm {
39// Defined in SPIRVPreLegalizer.cpp.
40extern void updateRegType(Register Reg, Type *Ty, SPIRVType *SpirvTy,
41 SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
42 MachineRegisterInfo &MRI);
43extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
44 MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR,
45 SPIRVType *KnownResType);
46} // namespace llvm
47
48static SPIRVType *deduceIntTypeFromResult(Register ResVReg,
49 MachineIRBuilder &MIB,
50 SPIRVGlobalRegistry *GR) {
51 const LLT &Ty = MIB.getMRI()->getType(Reg: ResVReg);
52 return GR->getOrCreateSPIRVIntegerType(BitWidth: Ty.getScalarSizeInBits(), MIRBuilder&: MIB);
53}
54
55static SPIRVType *deduceTypeFromSingleOperand(MachineInstr *I,
56 MachineIRBuilder &MIB,
57 SPIRVGlobalRegistry *GR,
58 unsigned OpIdx) {
59 Register OpReg = I->getOperand(i: OpIdx).getReg();
60 if (SPIRVType *OpType = GR->getSPIRVTypeForVReg(VReg: OpReg)) {
61 if (SPIRVType *CompType = GR->getScalarOrVectorComponentType(Type: OpType)) {
62 Register ResVReg = I->getOperand(i: 0).getReg();
63 const LLT &ResLLT = MIB.getMRI()->getType(Reg: ResVReg);
64 if (ResLLT.isVector())
65 return GR->getOrCreateSPIRVVectorType(BaseType: CompType, NumElements: ResLLT.getNumElements(),
66 MIRBuilder&: MIB, EmitIR: false);
67 return CompType;
68 }
69 }
70 return nullptr;
71}
72
73static SPIRVType *deduceTypeFromOperandRange(MachineInstr *I,
74 MachineIRBuilder &MIB,
75 SPIRVGlobalRegistry *GR,
76 unsigned StartOp, unsigned EndOp) {
77 SPIRVType *ResType = nullptr;
78 for (unsigned i = StartOp; i < EndOp; ++i) {
79 if (SPIRVType *Type = deduceTypeFromSingleOperand(I, MIB, GR, OpIdx: i)) {
80#ifdef EXPENSIVE_CHECKS
81 assert(!ResType || Type == ResType && "Conflicting type from operands.");
82 ResType = Type;
83#else
84 return Type;
85#endif
86 }
87 }
88 return ResType;
89}
90
91static SPIRVType *deduceTypeFromResultRegister(MachineInstr *Use,
92 Register UseRegister,
93 SPIRVGlobalRegistry *GR,
94 MachineIRBuilder &MIB) {
95 for (const MachineOperand &MO : Use->defs()) {
96 if (!MO.isReg())
97 continue;
98 if (SPIRVType *OpType = GR->getSPIRVTypeForVReg(VReg: MO.getReg())) {
99 if (SPIRVType *CompType = GR->getScalarOrVectorComponentType(Type: OpType)) {
100 const LLT &ResLLT = MIB.getMRI()->getType(Reg: UseRegister);
101 if (ResLLT.isVector())
102 return GR->getOrCreateSPIRVVectorType(
103 BaseType: CompType, NumElements: ResLLT.getNumElements(), MIRBuilder&: MIB, EmitIR: false);
104 return CompType;
105 }
106 }
107 }
108 return nullptr;
109}
110
111static SPIRVType *deducePointerTypeFromResultRegister(MachineInstr *Use,
112 Register UseRegister,
113 SPIRVGlobalRegistry *GR,
114 MachineIRBuilder &MIB) {
115 assert(Use->getOpcode() == TargetOpcode::G_LOAD ||
116 Use->getOpcode() == TargetOpcode::G_STORE);
117
118 Register ValueReg = Use->getOperand(i: 0).getReg();
119 SPIRVType *ValueType = GR->getSPIRVTypeForVReg(VReg: ValueReg);
120 if (!ValueType)
121 return nullptr;
122
123 return GR->getOrCreateSPIRVPointerType(BaseType: ValueType, MIRBuilder&: MIB,
124 SC: SPIRV::StorageClass::Function);
125}
126
127static SPIRVType *deduceTypeFromPointerOperand(MachineInstr *Use,
128 Register UseRegister,
129 SPIRVGlobalRegistry *GR,
130 MachineIRBuilder &MIB) {
131 assert(Use->getOpcode() == TargetOpcode::G_LOAD ||
132 Use->getOpcode() == TargetOpcode::G_STORE);
133
134 Register PtrReg = Use->getOperand(i: 1).getReg();
135 SPIRVType *PtrType = GR->getSPIRVTypeForVReg(VReg: PtrReg);
136 if (!PtrType)
137 return nullptr;
138
139 return GR->getPointeeType(PtrType);
140}
141
142static SPIRVType *deduceTypeFromUses(Register Reg, MachineFunction &MF,
143 SPIRVGlobalRegistry *GR,
144 MachineIRBuilder &MIB) {
145 MachineRegisterInfo &MRI = MF.getRegInfo();
146 for (MachineInstr &Use : MRI.use_nodbg_instructions(Reg)) {
147 SPIRVType *ResType = nullptr;
148 LLVM_DEBUG(dbgs() << "Looking at use " << Use);
149 switch (Use.getOpcode()) {
150 case TargetOpcode::G_BUILD_VECTOR:
151 case TargetOpcode::G_EXTRACT_VECTOR_ELT:
152 case TargetOpcode::G_UNMERGE_VALUES:
153 case TargetOpcode::G_ADD:
154 case TargetOpcode::G_SUB:
155 case TargetOpcode::G_MUL:
156 case TargetOpcode::G_SDIV:
157 case TargetOpcode::G_UDIV:
158 case TargetOpcode::G_SREM:
159 case TargetOpcode::G_UREM:
160 case TargetOpcode::G_FADD:
161 case TargetOpcode::G_FSUB:
162 case TargetOpcode::G_FMUL:
163 case TargetOpcode::G_FDIV:
164 case TargetOpcode::G_FREM:
165 case TargetOpcode::G_FMA:
166 case TargetOpcode::COPY:
167 case TargetOpcode::G_STRICT_FMA:
168 ResType = deduceTypeFromResultRegister(Use: &Use, UseRegister: Reg, GR, MIB);
169 break;
170 case TargetOpcode::G_LOAD:
171 case TargetOpcode::G_STORE:
172 if (Reg == Use.getOperand(i: 1).getReg())
173 ResType = deducePointerTypeFromResultRegister(Use: &Use, UseRegister: Reg, GR, MIB);
174 else
175 ResType = deduceTypeFromPointerOperand(Use: &Use, UseRegister: Reg, GR, MIB);
176 break;
177 case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
178 case TargetOpcode::G_INTRINSIC: {
179 auto IntrinsicID = cast<GIntrinsic>(Val&: Use).getIntrinsicID();
180 if (IntrinsicID == Intrinsic::spv_insertelt) {
181 if (Reg == Use.getOperand(i: 2).getReg())
182 ResType = deduceTypeFromResultRegister(Use: &Use, UseRegister: Reg, GR, MIB);
183 } else if (IntrinsicID == Intrinsic::spv_extractelt) {
184 if (Reg == Use.getOperand(i: 2).getReg())
185 ResType = deduceTypeFromResultRegister(Use: &Use, UseRegister: Reg, GR, MIB);
186 }
187 break;
188 }
189 }
190 if (ResType) {
191 LLVM_DEBUG(dbgs() << "Deduced type from use " << *ResType);
192 return ResType;
193 }
194 }
195 return nullptr;
196}
197
198static SPIRVType *deduceGEPType(MachineInstr *I, SPIRVGlobalRegistry *GR,
199 MachineIRBuilder &MIB) {
200 LLVM_DEBUG(dbgs() << "Deducing GEP type for: " << *I);
201 Register PtrReg = I->getOperand(i: 3).getReg();
202 SPIRVType *PtrType = GR->getSPIRVTypeForVReg(VReg: PtrReg);
203 if (!PtrType) {
204 LLVM_DEBUG(dbgs() << " Could not get type for pointer operand.\n");
205 return nullptr;
206 }
207
208 SPIRVType *PointeeType = GR->getPointeeType(PtrType);
209 if (!PointeeType) {
210 LLVM_DEBUG(dbgs() << " Could not get pointee type from pointer type.\n");
211 return nullptr;
212 }
213
214 MachineRegisterInfo *MRI = MIB.getMRI();
215
216 // The first index (operand 4) steps over the pointer, so the type doesn't
217 // change.
218 for (unsigned i = 5; i < I->getNumOperands(); ++i) {
219 LLVM_DEBUG(dbgs() << " Traversing index " << i
220 << ", current type: " << *PointeeType);
221 switch (PointeeType->getOpcode()) {
222 case SPIRV::OpTypeArray:
223 case SPIRV::OpTypeRuntimeArray:
224 case SPIRV::OpTypeVector: {
225 Register ElemTypeReg = PointeeType->getOperand(i: 1).getReg();
226 PointeeType = GR->getSPIRVTypeForVReg(VReg: ElemTypeReg);
227 break;
228 }
229 case SPIRV::OpTypeStruct: {
230 MachineOperand &IdxOp = I->getOperand(i);
231 if (!IdxOp.isReg()) {
232 LLVM_DEBUG(dbgs() << " Index is not a register.\n");
233 return nullptr;
234 }
235 MachineInstr *Def = MRI->getVRegDef(Reg: IdxOp.getReg());
236 if (!Def) {
237 LLVM_DEBUG(
238 dbgs() << " Could not find definition for index register.\n");
239 return nullptr;
240 }
241
242 uint64_t IndexVal = foldImm(MO: IdxOp, MRI);
243 if (IndexVal >= PointeeType->getNumOperands() - 1) {
244 LLVM_DEBUG(dbgs() << " Struct index out of bounds.\n");
245 return nullptr;
246 }
247
248 Register MemberTypeReg = PointeeType->getOperand(i: IndexVal + 1).getReg();
249 PointeeType = GR->getSPIRVTypeForVReg(VReg: MemberTypeReg);
250 break;
251 }
252 default:
253 LLVM_DEBUG(dbgs() << " Unknown type opcode for GEP traversal.\n");
254 return nullptr;
255 }
256
257 if (!PointeeType) {
258 LLVM_DEBUG(dbgs() << " Could not resolve next pointee type.\n");
259 return nullptr;
260 }
261 }
262 LLVM_DEBUG(dbgs() << " Final pointee type: " << *PointeeType);
263
264 SPIRV::StorageClass::StorageClass SC = GR->getPointerStorageClass(Type: PtrType);
265 SPIRVType *Res = GR->getOrCreateSPIRVPointerType(BaseType: PointeeType, MIRBuilder&: MIB, SC);
266 LLVM_DEBUG(dbgs() << " Deduced GEP type: " << *Res);
267 return Res;
268}
269
270static SPIRVType *deduceResultTypeFromOperands(MachineInstr *I,
271 SPIRVGlobalRegistry *GR,
272 MachineIRBuilder &MIB) {
273 Register ResVReg = I->getOperand(i: 0).getReg();
274 switch (I->getOpcode()) {
275 case TargetOpcode::G_CONSTANT:
276 case TargetOpcode::G_ANYEXT:
277 case TargetOpcode::G_SEXT:
278 case TargetOpcode::G_ZEXT:
279 return deduceIntTypeFromResult(ResVReg, MIB, GR);
280 case TargetOpcode::G_BUILD_VECTOR:
281 return deduceTypeFromOperandRange(I, MIB, GR, StartOp: 1, EndOp: I->getNumOperands());
282 case TargetOpcode::G_SHUFFLE_VECTOR:
283 return deduceTypeFromOperandRange(I, MIB, GR, StartOp: 1, EndOp: 3);
284 case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
285 case TargetOpcode::G_INTRINSIC: {
286 auto IntrinsicID = cast<GIntrinsic>(Val: I)->getIntrinsicID();
287 if (IntrinsicID == Intrinsic::spv_gep)
288 return deduceGEPType(I, GR, MIB);
289 break;
290 }
291 case TargetOpcode::G_LOAD: {
292 SPIRVType *PtrType = deduceTypeFromSingleOperand(I, MIB, GR, OpIdx: 1);
293 return PtrType ? GR->getPointeeType(PtrType) : nullptr;
294 }
295 default:
296 if (I->getNumDefs() == 1 && I->getNumOperands() > 1 &&
297 I->getOperand(i: 1).isReg())
298 return deduceTypeFromSingleOperand(I, MIB, GR, OpIdx: 1);
299 }
300 return nullptr;
301}
302
303static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF,
304 SPIRVGlobalRegistry *GR,
305 MachineIRBuilder &MIB) {
306 MachineRegisterInfo &MRI = MF.getRegInfo();
307 Register SrcReg = I->getOperand(i: I->getNumOperands() - 1).getReg();
308 SPIRVType *ScalarType = nullptr;
309 if (SPIRVType *DefType = GR->getSPIRVTypeForVReg(VReg: SrcReg)) {
310 assert(DefType->getOpcode() == SPIRV::OpTypeVector);
311 ScalarType = GR->getSPIRVTypeForVReg(VReg: DefType->getOperand(i: 1).getReg());
312 }
313
314 if (!ScalarType) {
315 // If we could not deduce the type from the source, try to deduce it from
316 // the uses of the results.
317 for (unsigned i = 0; i < I->getNumDefs(); ++i) {
318 Register DefReg = I->getOperand(i).getReg();
319 ScalarType = deduceTypeFromUses(Reg: DefReg, MF, GR, MIB);
320 if (ScalarType) {
321 ScalarType = GR->getScalarOrVectorComponentType(Type: ScalarType);
322 break;
323 }
324 }
325 }
326
327 if (!ScalarType)
328 return false;
329
330 for (unsigned i = 0; i < I->getNumOperands(); ++i) {
331 Register DefReg = I->getOperand(i).getReg();
332 if (GR->getSPIRVTypeForVReg(VReg: DefReg))
333 continue;
334
335 LLT DefLLT = MRI.getType(Reg: DefReg);
336 SPIRVType *ResType =
337 DefLLT.isVector()
338 ? GR->getOrCreateSPIRVVectorType(
339 BaseType: ScalarType, NumElements: DefLLT.getNumElements(), I&: *I,
340 TII: *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo())
341 : ScalarType;
342 setRegClassType(Reg: DefReg, SpvType: ResType, GR, MRI: &MRI, MF);
343 }
344 return true;
345}
346
347static bool deduceAndAssignSpirvType(MachineInstr *I, MachineFunction &MF,
348 SPIRVGlobalRegistry *GR,
349 MachineIRBuilder &MIB) {
350 LLVM_DEBUG(dbgs() << "\nProcessing instruction: " << *I);
351 MachineRegisterInfo &MRI = MF.getRegInfo();
352 Register ResVReg = I->getOperand(i: 0).getReg();
353
354 // G_UNMERGE_VALUES is handled separately because it has multiple definitions,
355 // unlike the other instructions which have a single result register. The main
356 // deduction logic is designed for the single-definition case.
357 if (I->getOpcode() == TargetOpcode::G_UNMERGE_VALUES)
358 return deduceAndAssignTypeForGUnmerge(I, MF, GR, MIB);
359
360 LLVM_DEBUG(dbgs() << "Inferring type from operands\n");
361 SPIRVType *ResType = deduceResultTypeFromOperands(I, GR, MIB);
362 if (!ResType) {
363 LLVM_DEBUG(dbgs() << "Inferring type from uses\n");
364 ResType = deduceTypeFromUses(Reg: ResVReg, MF, GR, MIB);
365 }
366
367 if (!ResType)
368 return false;
369
370 LLVM_DEBUG(dbgs() << "Assigned type to " << *I << ": " << *ResType);
371 GR->assignSPIRVTypeToVReg(Type: ResType, VReg: ResVReg, MF);
372
373 if (!MRI.getRegClassOrNull(Reg: ResVReg)) {
374 LLVM_DEBUG(dbgs() << "Updating the register class.\n");
375 setRegClassType(Reg: ResVReg, SpvType: ResType, GR, MRI: &MRI, MF: *GR->CurMF, Force: true);
376 }
377 return true;
378}
379
380static bool requiresSpirvType(MachineInstr &I, SPIRVGlobalRegistry *GR,
381 MachineRegisterInfo &MRI) {
382 LLVM_DEBUG(dbgs() << "Checking if instruction requires a SPIR-V type: "
383 << I;);
384 if (I.getNumDefs() == 0) {
385 LLVM_DEBUG(dbgs() << "Instruction does not have a definition.\n");
386 return false;
387 }
388
389 if (!I.isPreISelOpcode()) {
390 LLVM_DEBUG(dbgs() << "Instruction is not a generic instruction.\n");
391 return false;
392 }
393
394 Register ResultRegister = I.defs().begin()->getReg();
395 if (GR->getSPIRVTypeForVReg(VReg: ResultRegister)) {
396 LLVM_DEBUG(dbgs() << "Instruction already has a SPIR-V type.\n");
397 if (!MRI.getRegClassOrNull(Reg: ResultRegister)) {
398 LLVM_DEBUG(dbgs() << "Updating the register class.\n");
399 setRegClassType(Reg: ResultRegister, SpvType: GR->getSPIRVTypeForVReg(VReg: ResultRegister),
400 GR, MRI: &MRI, MF: *GR->CurMF, Force: true);
401 }
402 return false;
403 }
404
405 return true;
406}
407
408static void registerSpirvTypeForNewInstructions(MachineFunction &MF,
409 SPIRVGlobalRegistry *GR) {
410 MachineRegisterInfo &MRI = MF.getRegInfo();
411 SmallVector<MachineInstr *, 8> Worklist;
412 for (MachineBasicBlock &MBB : MF) {
413 for (MachineInstr &I : MBB) {
414 if (requiresSpirvType(I, GR, MRI)) {
415 Worklist.push_back(Elt: &I);
416 }
417 }
418 }
419
420 if (Worklist.empty()) {
421 LLVM_DEBUG(dbgs() << "Initial worklist is empty.\n");
422 return;
423 }
424
425 LLVM_DEBUG(dbgs() << "Initial worklist:\n";
426 for (auto *I : Worklist) { I->dump(); });
427
428 bool Changed;
429 do {
430 Changed = false;
431 SmallVector<MachineInstr *, 8> NextWorklist;
432
433 for (MachineInstr *I : Worklist) {
434 MachineIRBuilder MIB(*I);
435 if (deduceAndAssignSpirvType(I, MF, GR, MIB)) {
436 Changed = true;
437 } else {
438 NextWorklist.push_back(Elt: I);
439 }
440 }
441 Worklist = std::move(NextWorklist);
442 LLVM_DEBUG(dbgs() << "Worklist size: " << Worklist.size() << "\n");
443 } while (Changed);
444
445 if (Worklist.empty())
446 return;
447
448 for (auto *I : Worklist) {
449 MachineIRBuilder MIB(*I);
450 LLVM_DEBUG(dbgs() << "Assigning default type to results in " << *I);
451 for (unsigned Idx = 0; Idx < I->getNumDefs(); ++Idx) {
452 Register ResVReg = I->getOperand(i: Idx).getReg();
453 if (GR->getSPIRVTypeForVReg(VReg: ResVReg))
454 continue;
455 const LLT &ResLLT = MRI.getType(Reg: ResVReg);
456 SPIRVType *ResType = nullptr;
457 if (ResLLT.isVector()) {
458 SPIRVType *CompType = GR->getOrCreateSPIRVIntegerType(
459 BitWidth: ResLLT.getElementType().getSizeInBits(), MIRBuilder&: MIB);
460 ResType = GR->getOrCreateSPIRVVectorType(
461 BaseType: CompType, NumElements: ResLLT.getNumElements(), MIRBuilder&: MIB, EmitIR: false);
462 } else {
463 ResType = GR->getOrCreateSPIRVIntegerType(BitWidth: ResLLT.getSizeInBits(), MIRBuilder&: MIB);
464 }
465 setRegClassType(Reg: ResVReg, SpvType: ResType, GR, MRI: &MRI, MF, Force: true);
466 }
467 }
468}
469
470static bool hasAssignType(Register Reg, MachineRegisterInfo &MRI) {
471 for (MachineInstr &UseInstr : MRI.use_nodbg_instructions(Reg)) {
472 if (UseInstr.getOpcode() == SPIRV::ASSIGN_TYPE) {
473 return true;
474 }
475 }
476 return false;
477}
478
479static void generateAssignType(MachineInstr &MI, Register ResultRegister,
480 SPIRVType *ResultType, SPIRVGlobalRegistry *GR,
481 MachineRegisterInfo &MRI) {
482 LLVM_DEBUG(dbgs() << " Adding ASSIGN_TYPE for ResultRegister: "
483 << printReg(ResultRegister, MRI.getTargetRegisterInfo())
484 << " with type: " << *ResultType);
485 MachineIRBuilder MIB(MI);
486 updateRegType(Reg: ResultRegister, Ty: nullptr, SpirvTy: ResultType, GR, MIB, MRI);
487
488 // Tablegen definition assumes SPIRV::ASSIGN_TYPE pseudo-instruction is
489 // present after each auto-folded instruction to take a type reference
490 // from.
491 Register NewReg =
492 MRI.createGenericVirtualRegister(Ty: MRI.getType(Reg: ResultRegister));
493 const auto *RegClass = GR->getRegClass(SpvType: ResultType);
494 MRI.setRegClass(Reg: NewReg, RC: RegClass);
495 MRI.setRegClass(Reg: ResultRegister, RC: RegClass);
496
497 GR->assignSPIRVTypeToVReg(Type: ResultType, VReg: ResultRegister, MF: MIB.getMF());
498 // This is to make it convenient for Legalizer to get the SPIRVType
499 // when processing the actual MI (i.e. not pseudo one).
500 GR->assignSPIRVTypeToVReg(Type: ResultType, VReg: NewReg, MF: MIB.getMF());
501 // Copy MIFlags from Def to ASSIGN_TYPE instruction. It's required to
502 // keep the flags after instruction selection.
503 const uint32_t Flags = MI.getFlags();
504 MIB.buildInstr(Opcode: SPIRV::ASSIGN_TYPE)
505 .addDef(RegNo: ResultRegister)
506 .addUse(RegNo: NewReg)
507 .addUse(RegNo: GR->getSPIRVTypeID(SpirvType: ResultType))
508 .setMIFlags(Flags);
509 for (unsigned I = 0, E = MI.getNumDefs(); I != E; ++I) {
510 MachineOperand &MO = MI.getOperand(i: I);
511 if (MO.getReg() == ResultRegister) {
512 MO.setReg(NewReg);
513 break;
514 }
515 }
516}
517
518static void ensureAssignTypeForTypeFolding(MachineFunction &MF,
519 SPIRVGlobalRegistry *GR) {
520 LLVM_DEBUG(dbgs() << "Entering ensureAssignTypeForTypeFolding for function "
521 << MF.getName() << "\n");
522 MachineRegisterInfo &MRI = MF.getRegInfo();
523 for (MachineBasicBlock &MBB : MF) {
524 for (MachineInstr &MI : MBB) {
525 if (!isTypeFoldingSupported(Opcode: MI.getOpcode()))
526 continue;
527
528 LLVM_DEBUG(dbgs() << "Processing instruction: " << MI);
529
530 Register ResultRegister = MI.defs().begin()->getReg();
531 if (hasAssignType(Reg: ResultRegister, MRI)) {
532 LLVM_DEBUG(dbgs() << " Instruction already has ASSIGN_TYPE\n");
533 continue;
534 }
535
536 SPIRVType *ResultType = GR->getSPIRVTypeForVReg(VReg: ResultRegister);
537 assert(ResultType);
538 generateAssignType(MI, ResultRegister, ResultType, GR, MRI);
539 }
540 }
541}
542
543// Do a preorder traversal of the CFG starting from the BB |Start|.
544// point. Calls |op| on each basic block encountered during the traversal.
545void visit(MachineFunction &MF, MachineBasicBlock &Start,
546 std::function<void(MachineBasicBlock *)> op) {
547 std::stack<MachineBasicBlock *> ToVisit;
548 SmallPtrSet<MachineBasicBlock *, 8> Seen;
549
550 ToVisit.push(x: &Start);
551 Seen.insert(Ptr: ToVisit.top());
552 while (ToVisit.size() != 0) {
553 MachineBasicBlock *MBB = ToVisit.top();
554 ToVisit.pop();
555
556 op(MBB);
557
558 for (auto Succ : MBB->successors()) {
559 if (Seen.contains(Ptr: Succ))
560 continue;
561 ToVisit.push(x: Succ);
562 Seen.insert(Ptr: Succ);
563 }
564 }
565}
566
567// Do a preorder traversal of the CFG starting from the given function's entry
568// point. Calls |op| on each basic block encountered during the traversal.
569void visit(MachineFunction &MF, std::function<void(MachineBasicBlock *)> op) {
570 visit(MF, Start&: *MF.begin(), op: std::move(op));
571}
572
573bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) {
574 // Initialize the type registry.
575 const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
576 SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
577 GR->setCurrentFunc(MF);
578 registerSpirvTypeForNewInstructions(MF, GR);
579 ensureAssignTypeForTypeFolding(MF, GR);
580 return true;
581}
582
583INITIALIZE_PASS(SPIRVPostLegalizer, DEBUG_TYPE, "SPIRV post legalizer", false,
584 false)
585
586char SPIRVPostLegalizer::ID = 0;
587
588FunctionPass *llvm::createSPIRVPostLegalizerPass() {
589 return new SPIRVPostLegalizer();
590}
591