1//===-- SPIRVPostLegalizer.cpp - amend info after legalization -*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// The pass partially applies pre-legalization logic to new instructions
10// inserted as a result of legalization:
11// - assigns SPIR-V types to registers for new instructions.
12// - inserts ASSIGN_TYPE pseudo-instructions required for type folding.
13//
14//===----------------------------------------------------------------------===//
15
16#include "SPIRV.h"
17#include "SPIRVSubtarget.h"
18#include "SPIRVUtils.h"
19#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
20#include "llvm/CodeGen/MachineFrameInfo.h"
21#include "llvm/IR/IntrinsicsSPIRV.h"
22#include "llvm/Support/Debug.h"
23#include <stack>
24
25#define DEBUG_TYPE "spirv-postlegalizer"
26
27using namespace llvm;
28
29namespace {
30class SPIRVPostLegalizer : public MachineFunctionPass {
31public:
32 static char ID;
33 SPIRVPostLegalizer() : MachineFunctionPass(ID) {}
34 bool runOnMachineFunction(MachineFunction &MF) override;
35};
36} // namespace
37
38namespace llvm {
39// Defined in SPIRVPreLegalizer.cpp.
40extern void updateRegType(Register Reg, Type *Ty, SPIRVTypeInst SpirvTy,
41 SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
42 MachineRegisterInfo &MRI);
43extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
44 MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR,
45 SPIRVTypeInst KnownResType);
46} // namespace llvm
47
48static SPIRVTypeInst deduceIntTypeFromResult(Register ResVReg,
49 MachineIRBuilder &MIB,
50 SPIRVGlobalRegistry *GR) {
51 const LLT &Ty = MIB.getMRI()->getType(Reg: ResVReg);
52 return GR->getOrCreateSPIRVIntegerType(BitWidth: Ty.getScalarSizeInBits(), MIRBuilder&: MIB);
53}
54
55static SPIRVTypeInst deduceTypeFromSingleOperand(MachineInstr *I,
56 MachineIRBuilder &MIB,
57 SPIRVGlobalRegistry *GR,
58 unsigned OpIdx) {
59 Register OpReg = I->getOperand(i: OpIdx).getReg();
60 if (SPIRVTypeInst OpType = GR->getSPIRVTypeForVReg(VReg: OpReg)) {
61 if (SPIRVTypeInst CompType = GR->getScalarOrVectorComponentType(Type: OpType)) {
62 Register ResVReg = I->getOperand(i: 0).getReg();
63 const LLT &ResLLT = MIB.getMRI()->getType(Reg: ResVReg);
64 if (ResLLT.isVector())
65 return GR->getOrCreateSPIRVVectorType(BaseType: CompType, NumElements: ResLLT.getNumElements(),
66 MIRBuilder&: MIB, EmitIR: false);
67 return CompType;
68 }
69 }
70 return nullptr;
71}
72
73static SPIRVTypeInst deduceTypeFromOperandRange(MachineInstr *I,
74 MachineIRBuilder &MIB,
75 SPIRVGlobalRegistry *GR,
76 unsigned StartOp,
77 unsigned EndOp) {
78 SPIRVTypeInst ResType = nullptr;
79 for (unsigned i = StartOp; i < EndOp; ++i) {
80 if (SPIRVTypeInst Type = deduceTypeFromSingleOperand(I, MIB, GR, OpIdx: i)) {
81#ifdef EXPENSIVE_CHECKS
82 assert(!ResType || Type == ResType && "Conflicting type from operands.");
83 ResType = Type;
84#else
85 return Type;
86#endif
87 }
88 }
89 return ResType;
90}
91
92static SPIRVTypeInst deduceTypeFromResultRegister(MachineInstr *Use,
93 Register UseRegister,
94 SPIRVGlobalRegistry *GR,
95 MachineIRBuilder &MIB) {
96 for (const MachineOperand &MO : Use->defs()) {
97 if (!MO.isReg())
98 continue;
99 if (SPIRVTypeInst OpType = GR->getSPIRVTypeForVReg(VReg: MO.getReg())) {
100 if (SPIRVTypeInst CompType = GR->getScalarOrVectorComponentType(Type: OpType)) {
101 const LLT &ResLLT = MIB.getMRI()->getType(Reg: UseRegister);
102 if (ResLLT.isVector())
103 return GR->getOrCreateSPIRVVectorType(
104 BaseType: CompType, NumElements: ResLLT.getNumElements(), MIRBuilder&: MIB, EmitIR: false);
105 return CompType;
106 }
107 }
108 }
109 return nullptr;
110}
111
112static SPIRVTypeInst
113deducePointerTypeFromResultRegister(MachineInstr *Use, Register UseRegister,
114 SPIRVGlobalRegistry *GR,
115 MachineIRBuilder &MIB) {
116 assert(Use->getOpcode() == TargetOpcode::G_LOAD ||
117 Use->getOpcode() == TargetOpcode::G_STORE);
118
119 Register ValueReg = Use->getOperand(i: 0).getReg();
120 SPIRVTypeInst ValueType = GR->getSPIRVTypeForVReg(VReg: ValueReg);
121 if (!ValueType)
122 return nullptr;
123
124 return GR->getOrCreateSPIRVPointerType(BaseType: ValueType, MIRBuilder&: MIB,
125 SC: SPIRV::StorageClass::Function);
126}
127
128static SPIRVTypeInst deduceTypeFromPointerOperand(MachineInstr *Use,
129 Register UseRegister,
130 SPIRVGlobalRegistry *GR,
131 MachineIRBuilder &MIB) {
132 assert(Use->getOpcode() == TargetOpcode::G_LOAD ||
133 Use->getOpcode() == TargetOpcode::G_STORE);
134
135 Register PtrReg = Use->getOperand(i: 1).getReg();
136 SPIRVTypeInst PtrType = GR->getSPIRVTypeForVReg(VReg: PtrReg);
137 if (!PtrType)
138 return nullptr;
139
140 return GR->getPointeeType(PtrType);
141}
142
143static SPIRVTypeInst deduceTypeFromUses(Register Reg, MachineFunction &MF,
144 SPIRVGlobalRegistry *GR,
145 MachineIRBuilder &MIB) {
146 MachineRegisterInfo &MRI = MF.getRegInfo();
147 for (MachineInstr &Use : MRI.use_nodbg_instructions(Reg)) {
148 SPIRVTypeInst ResType = nullptr;
149 LLVM_DEBUG(dbgs() << "Looking at use " << Use);
150 switch (Use.getOpcode()) {
151 case TargetOpcode::G_BUILD_VECTOR:
152 case TargetOpcode::G_EXTRACT_VECTOR_ELT:
153 case TargetOpcode::G_UNMERGE_VALUES:
154 case TargetOpcode::G_ADD:
155 case TargetOpcode::G_SUB:
156 case TargetOpcode::G_MUL:
157 case TargetOpcode::G_SDIV:
158 case TargetOpcode::G_UDIV:
159 case TargetOpcode::G_SREM:
160 case TargetOpcode::G_UREM:
161 case TargetOpcode::G_FADD:
162 case TargetOpcode::G_FSUB:
163 case TargetOpcode::G_FMUL:
164 case TargetOpcode::G_FDIV:
165 case TargetOpcode::G_FREM:
166 case TargetOpcode::G_FMA:
167 case TargetOpcode::COPY:
168 case TargetOpcode::G_STRICT_FMA:
169 ResType = deduceTypeFromResultRegister(Use: &Use, UseRegister: Reg, GR, MIB);
170 break;
171 case TargetOpcode::G_LOAD:
172 case TargetOpcode::G_STORE:
173 if (Reg == Use.getOperand(i: 1).getReg())
174 ResType = deducePointerTypeFromResultRegister(Use: &Use, UseRegister: Reg, GR, MIB);
175 else
176 ResType = deduceTypeFromPointerOperand(Use: &Use, UseRegister: Reg, GR, MIB);
177 break;
178 case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
179 case TargetOpcode::G_INTRINSIC: {
180 auto IntrinsicID = cast<GIntrinsic>(Val&: Use).getIntrinsicID();
181 if (IntrinsicID == Intrinsic::spv_insertelt) {
182 if (Reg == Use.getOperand(i: 2).getReg())
183 ResType = deduceTypeFromResultRegister(Use: &Use, UseRegister: Reg, GR, MIB);
184 } else if (IntrinsicID == Intrinsic::spv_extractelt) {
185 if (Reg == Use.getOperand(i: 2).getReg())
186 ResType = deduceTypeFromResultRegister(Use: &Use, UseRegister: Reg, GR, MIB);
187 }
188 break;
189 }
190 }
191 if (ResType) {
192 LLVM_DEBUG(dbgs() << "Deduced type from use " << *ResType);
193 return ResType;
194 }
195 }
196 return nullptr;
197}
198
199static SPIRVTypeInst deduceGEPType(MachineInstr *I, SPIRVGlobalRegistry *GR,
200 MachineIRBuilder &MIB) {
201 LLVM_DEBUG(dbgs() << "Deducing GEP type for: " << *I);
202 Register PtrReg = I->getOperand(i: 3).getReg();
203 SPIRVTypeInst PtrType = GR->getSPIRVTypeForVReg(VReg: PtrReg);
204 if (!PtrType) {
205 LLVM_DEBUG(dbgs() << " Could not get type for pointer operand.\n");
206 return nullptr;
207 }
208
209 SPIRVTypeInst PointeeType = GR->getPointeeType(PtrType);
210 if (!PointeeType) {
211 LLVM_DEBUG(dbgs() << " Could not get pointee type from pointer type.\n");
212 return nullptr;
213 }
214
215 MachineRegisterInfo *MRI = MIB.getMRI();
216
217 // The first index (operand 4) steps over the pointer, so the type doesn't
218 // change.
219 for (unsigned i = 5; i < I->getNumOperands(); ++i) {
220 LLVM_DEBUG(dbgs() << " Traversing index " << i
221 << ", current type: " << *PointeeType);
222 switch (PointeeType->getOpcode()) {
223 case SPIRV::OpTypeArray:
224 case SPIRV::OpTypeRuntimeArray:
225 case SPIRV::OpTypeVector: {
226 Register ElemTypeReg = PointeeType->getOperand(i: 1).getReg();
227 PointeeType = GR->getSPIRVTypeForVReg(VReg: ElemTypeReg);
228 break;
229 }
230 case SPIRV::OpTypeStruct: {
231 MachineOperand &IdxOp = I->getOperand(i);
232 if (!IdxOp.isReg()) {
233 LLVM_DEBUG(dbgs() << " Index is not a register.\n");
234 return nullptr;
235 }
236 MachineInstr *Def = MRI->getVRegDef(Reg: IdxOp.getReg());
237 if (!Def) {
238 LLVM_DEBUG(
239 dbgs() << " Could not find definition for index register.\n");
240 return nullptr;
241 }
242
243 uint64_t IndexVal = foldImm(MO: IdxOp, MRI);
244 if (IndexVal >= PointeeType->getNumOperands() - 1) {
245 LLVM_DEBUG(dbgs() << " Struct index out of bounds.\n");
246 return nullptr;
247 }
248
249 Register MemberTypeReg = PointeeType->getOperand(i: IndexVal + 1).getReg();
250 PointeeType = GR->getSPIRVTypeForVReg(VReg: MemberTypeReg);
251 break;
252 }
253 default:
254 LLVM_DEBUG(dbgs() << " Unknown type opcode for GEP traversal.\n");
255 return nullptr;
256 }
257
258 if (!PointeeType) {
259 LLVM_DEBUG(dbgs() << " Could not resolve next pointee type.\n");
260 return nullptr;
261 }
262 }
263 LLVM_DEBUG(dbgs() << " Final pointee type: " << *PointeeType);
264
265 SPIRV::StorageClass::StorageClass SC = GR->getPointerStorageClass(Type: PtrType);
266 SPIRVTypeInst Res = GR->getOrCreateSPIRVPointerType(BaseType: PointeeType, MIRBuilder&: MIB, SC);
267 LLVM_DEBUG(dbgs() << " Deduced GEP type: " << *Res);
268 return Res;
269}
270
271static SPIRVTypeInst deduceResultTypeFromOperands(MachineInstr *I,
272 SPIRVGlobalRegistry *GR,
273 MachineIRBuilder &MIB) {
274 Register ResVReg = I->getOperand(i: 0).getReg();
275 switch (I->getOpcode()) {
276 case TargetOpcode::G_CONSTANT:
277 case TargetOpcode::G_ANYEXT:
278 case TargetOpcode::G_SEXT:
279 case TargetOpcode::G_ZEXT:
280 return deduceIntTypeFromResult(ResVReg, MIB, GR);
281 case TargetOpcode::G_BUILD_VECTOR:
282 return deduceTypeFromOperandRange(I, MIB, GR, StartOp: 1, EndOp: I->getNumOperands());
283 case TargetOpcode::G_SHUFFLE_VECTOR:
284 return deduceTypeFromOperandRange(I, MIB, GR, StartOp: 1, EndOp: 3);
285 case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
286 case TargetOpcode::G_INTRINSIC: {
287 auto IntrinsicID = cast<GIntrinsic>(Val: I)->getIntrinsicID();
288 if (IntrinsicID == Intrinsic::spv_gep)
289 return deduceGEPType(I, GR, MIB);
290 break;
291 }
292 case TargetOpcode::G_LOAD: {
293 SPIRVTypeInst PtrType = deduceTypeFromSingleOperand(I, MIB, GR, OpIdx: 1);
294 return PtrType ? GR->getPointeeType(PtrType) : nullptr;
295 }
296 default:
297 if (I->getNumDefs() == 1 && I->getNumOperands() > 1 &&
298 I->getOperand(i: 1).isReg())
299 return deduceTypeFromSingleOperand(I, MIB, GR, OpIdx: 1);
300 }
301 return nullptr;
302}
303
304static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF,
305 SPIRVGlobalRegistry *GR,
306 MachineIRBuilder &MIB) {
307 MachineRegisterInfo &MRI = MF.getRegInfo();
308 Register SrcReg = I->getOperand(i: I->getNumOperands() - 1).getReg();
309 SPIRVTypeInst ScalarType = nullptr;
310 if (SPIRVTypeInst DefType = GR->getSPIRVTypeForVReg(VReg: SrcReg)) {
311 assert(DefType->getOpcode() == SPIRV::OpTypeVector);
312 ScalarType = GR->getSPIRVTypeForVReg(VReg: DefType->getOperand(i: 1).getReg());
313 }
314
315 if (!ScalarType) {
316 // If we could not deduce the type from the source, try to deduce it from
317 // the uses of the results.
318 for (unsigned i = 0; i < I->getNumDefs(); ++i) {
319 Register DefReg = I->getOperand(i).getReg();
320 ScalarType = deduceTypeFromUses(Reg: DefReg, MF, GR, MIB);
321 if (ScalarType) {
322 ScalarType = GR->getScalarOrVectorComponentType(Type: ScalarType);
323 break;
324 }
325 }
326 }
327
328 if (!ScalarType)
329 return false;
330
331 for (unsigned i = 0; i < I->getNumOperands(); ++i) {
332 Register DefReg = I->getOperand(i).getReg();
333 if (GR->getSPIRVTypeForVReg(VReg: DefReg))
334 continue;
335
336 LLT DefLLT = MRI.getType(Reg: DefReg);
337 SPIRVTypeInst ResType =
338 DefLLT.isVector()
339 ? GR->getOrCreateSPIRVVectorType(
340 BaseType: ScalarType, NumElements: DefLLT.getNumElements(), I&: *I,
341 TII: *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo())
342 : ScalarType;
343 setRegClassType(Reg: DefReg, SpvType: ResType, GR, MRI: &MRI, MF);
344 }
345 return true;
346}
347
348static bool deduceAndAssignSpirvType(MachineInstr *I, MachineFunction &MF,
349 SPIRVGlobalRegistry *GR,
350 MachineIRBuilder &MIB) {
351 LLVM_DEBUG(dbgs() << "\nProcessing instruction: " << *I);
352 MachineRegisterInfo &MRI = MF.getRegInfo();
353 Register ResVReg = I->getOperand(i: 0).getReg();
354
355 // G_UNMERGE_VALUES is handled separately because it has multiple definitions,
356 // unlike the other instructions which have a single result register. The main
357 // deduction logic is designed for the single-definition case.
358 if (I->getOpcode() == TargetOpcode::G_UNMERGE_VALUES)
359 return deduceAndAssignTypeForGUnmerge(I, MF, GR, MIB);
360
361 LLVM_DEBUG(dbgs() << "Inferring type from operands\n");
362 SPIRVTypeInst ResType = deduceResultTypeFromOperands(I, GR, MIB);
363 if (!ResType) {
364 LLVM_DEBUG(dbgs() << "Inferring type from uses\n");
365 ResType = deduceTypeFromUses(Reg: ResVReg, MF, GR, MIB);
366 }
367
368 if (!ResType)
369 return false;
370
371 LLVM_DEBUG(dbgs() << "Assigned type to " << *I << ": " << *ResType);
372 GR->assignSPIRVTypeToVReg(Type: ResType, VReg: ResVReg, MF);
373
374 if (!MRI.getRegClassOrNull(Reg: ResVReg)) {
375 LLVM_DEBUG(dbgs() << "Updating the register class.\n");
376 setRegClassType(Reg: ResVReg, SpvType: ResType, GR, MRI: &MRI, MF: *GR->CurMF, Force: true);
377 }
378 return true;
379}
380
381static bool requiresSpirvType(MachineInstr &I, SPIRVGlobalRegistry *GR,
382 MachineRegisterInfo &MRI) {
383 LLVM_DEBUG(dbgs() << "Checking if instruction requires a SPIR-V type: "
384 << I;);
385 if (I.getNumDefs() == 0) {
386 LLVM_DEBUG(dbgs() << "Instruction does not have a definition.\n");
387 return false;
388 }
389
390 if (!I.isPreISelOpcode()) {
391 LLVM_DEBUG(dbgs() << "Instruction is not a generic instruction.\n");
392 return false;
393 }
394
395 Register ResultRegister = I.defs().begin()->getReg();
396 if (GR->getSPIRVTypeForVReg(VReg: ResultRegister)) {
397 LLVM_DEBUG(dbgs() << "Instruction already has a SPIR-V type.\n");
398 if (!MRI.getRegClassOrNull(Reg: ResultRegister)) {
399 LLVM_DEBUG(dbgs() << "Updating the register class.\n");
400 setRegClassType(Reg: ResultRegister, SpvType: GR->getSPIRVTypeForVReg(VReg: ResultRegister),
401 GR, MRI: &MRI, MF: *GR->CurMF, Force: true);
402 }
403 return false;
404 }
405
406 return true;
407}
408
409static void registerSpirvTypeForNewInstructions(MachineFunction &MF,
410 SPIRVGlobalRegistry *GR) {
411 MachineRegisterInfo &MRI = MF.getRegInfo();
412 SmallVector<MachineInstr *, 8> Worklist;
413 for (MachineBasicBlock &MBB : MF) {
414 for (MachineInstr &I : MBB) {
415 if (requiresSpirvType(I, GR, MRI)) {
416 Worklist.push_back(Elt: &I);
417 }
418 }
419 }
420
421 if (Worklist.empty()) {
422 LLVM_DEBUG(dbgs() << "Initial worklist is empty.\n");
423 return;
424 }
425
426 LLVM_DEBUG(dbgs() << "Initial worklist:\n";
427 for (auto *I : Worklist) { I->dump(); });
428
429 bool Changed;
430 do {
431 Changed = false;
432 SmallVector<MachineInstr *, 8> NextWorklist;
433
434 for (MachineInstr *I : Worklist) {
435 MachineIRBuilder MIB(*I);
436 if (deduceAndAssignSpirvType(I, MF, GR, MIB)) {
437 Changed = true;
438 } else {
439 NextWorklist.push_back(Elt: I);
440 }
441 }
442 Worklist = std::move(NextWorklist);
443 LLVM_DEBUG(dbgs() << "Worklist size: " << Worklist.size() << "\n");
444 } while (Changed);
445
446 if (Worklist.empty())
447 return;
448
449 for (auto *I : Worklist) {
450 MachineIRBuilder MIB(*I);
451 LLVM_DEBUG(dbgs() << "Assigning default type to results in " << *I);
452 for (unsigned Idx = 0; Idx < I->getNumDefs(); ++Idx) {
453 Register ResVReg = I->getOperand(i: Idx).getReg();
454 if (GR->getSPIRVTypeForVReg(VReg: ResVReg))
455 continue;
456 const LLT &ResLLT = MRI.getType(Reg: ResVReg);
457 SPIRVTypeInst ResType = nullptr;
458 if (ResLLT.isVector()) {
459 SPIRVTypeInst CompType = GR->getOrCreateSPIRVIntegerType(
460 BitWidth: ResLLT.getElementType().getSizeInBits(), MIRBuilder&: MIB);
461 ResType = GR->getOrCreateSPIRVVectorType(
462 BaseType: CompType, NumElements: ResLLT.getNumElements(), MIRBuilder&: MIB, EmitIR: false);
463 } else {
464 ResType = GR->getOrCreateSPIRVIntegerType(BitWidth: ResLLT.getSizeInBits(), MIRBuilder&: MIB);
465 }
466 setRegClassType(Reg: ResVReg, SpvType: ResType, GR, MRI: &MRI, MF, Force: true);
467 }
468 }
469}
470
471static bool hasAssignType(Register Reg, MachineRegisterInfo &MRI) {
472 for (MachineInstr &UseInstr : MRI.use_nodbg_instructions(Reg)) {
473 if (UseInstr.getOpcode() == SPIRV::ASSIGN_TYPE) {
474 return true;
475 }
476 }
477 return false;
478}
479
480static void generateAssignType(MachineInstr &MI, Register ResultRegister,
481 SPIRVTypeInst ResultType,
482 SPIRVGlobalRegistry *GR,
483 MachineRegisterInfo &MRI) {
484 LLVM_DEBUG(dbgs() << " Adding ASSIGN_TYPE for ResultRegister: "
485 << printReg(ResultRegister, MRI.getTargetRegisterInfo())
486 << " with type: " << *ResultType);
487 MachineIRBuilder MIB(MI);
488 updateRegType(Reg: ResultRegister, Ty: nullptr, SpirvTy: ResultType, GR, MIB, MRI);
489
490 // Tablegen definition assumes SPIRV::ASSIGN_TYPE pseudo-instruction is
491 // present after each auto-folded instruction to take a type reference
492 // from.
493 Register NewReg =
494 MRI.createGenericVirtualRegister(Ty: MRI.getType(Reg: ResultRegister));
495 const auto *RegClass = GR->getRegClass(SpvType: ResultType);
496 MRI.setRegClass(Reg: NewReg, RC: RegClass);
497 MRI.setRegClass(Reg: ResultRegister, RC: RegClass);
498
499 GR->assignSPIRVTypeToVReg(Type: ResultType, VReg: ResultRegister, MF: MIB.getMF());
500 // This is to make it convenient for Legalizer to get the SPIRVType
501 // when processing the actual MI (i.e. not pseudo one).
502 GR->assignSPIRVTypeToVReg(Type: ResultType, VReg: NewReg, MF: MIB.getMF());
503 // Copy MIFlags from Def to ASSIGN_TYPE instruction. It's required to
504 // keep the flags after instruction selection.
505 const uint32_t Flags = MI.getFlags();
506 MIB.buildInstr(Opcode: SPIRV::ASSIGN_TYPE)
507 .addDef(RegNo: ResultRegister)
508 .addUse(RegNo: NewReg)
509 .addUse(RegNo: GR->getSPIRVTypeID(SpirvType: ResultType))
510 .setMIFlags(Flags);
511 for (unsigned I = 0, E = MI.getNumDefs(); I != E; ++I) {
512 MachineOperand &MO = MI.getOperand(i: I);
513 if (MO.getReg() == ResultRegister) {
514 MO.setReg(NewReg);
515 break;
516 }
517 }
518}
519
520static void ensureAssignTypeForTypeFolding(MachineFunction &MF,
521 SPIRVGlobalRegistry *GR) {
522 LLVM_DEBUG(dbgs() << "Entering ensureAssignTypeForTypeFolding for function "
523 << MF.getName() << "\n");
524 MachineRegisterInfo &MRI = MF.getRegInfo();
525 for (MachineBasicBlock &MBB : MF) {
526 for (MachineInstr &MI : MBB) {
527 if (!isTypeFoldingSupported(Opcode: MI.getOpcode()))
528 continue;
529
530 LLVM_DEBUG(dbgs() << "Processing instruction: " << MI);
531
532 Register ResultRegister = MI.defs().begin()->getReg();
533 if (hasAssignType(Reg: ResultRegister, MRI)) {
534 LLVM_DEBUG(dbgs() << " Instruction already has ASSIGN_TYPE\n");
535 continue;
536 }
537
538 SPIRVTypeInst ResultType = GR->getSPIRVTypeForVReg(VReg: ResultRegister);
539 generateAssignType(MI, ResultRegister, ResultType, GR, MRI);
540 }
541 }
542}
543
544// Do a preorder traversal of the CFG starting from the BB |Start|.
545// point. Calls |op| on each basic block encountered during the traversal.
546void visit(MachineFunction &MF, MachineBasicBlock &Start,
547 std::function<void(MachineBasicBlock *)> op) {
548 std::stack<MachineBasicBlock *> ToVisit;
549 SmallPtrSet<MachineBasicBlock *, 8> Seen;
550
551 ToVisit.push(x: &Start);
552 Seen.insert(Ptr: ToVisit.top());
553 while (ToVisit.size() != 0) {
554 MachineBasicBlock *MBB = ToVisit.top();
555 ToVisit.pop();
556
557 op(MBB);
558
559 for (auto Succ : MBB->successors()) {
560 if (Seen.contains(Ptr: Succ))
561 continue;
562 ToVisit.push(x: Succ);
563 Seen.insert(Ptr: Succ);
564 }
565 }
566}
567
568// Do a preorder traversal of the CFG starting from the given function's entry
569// point. Calls |op| on each basic block encountered during the traversal.
570void visit(MachineFunction &MF, std::function<void(MachineBasicBlock *)> op) {
571 visit(MF, Start&: *MF.begin(), op: std::move(op));
572}
573
574bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) {
575 // Initialize the type registry.
576 const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
577 SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
578 GR->setCurrentFunc(MF);
579 registerSpirvTypeForNewInstructions(MF, GR);
580 ensureAssignTypeForTypeFolding(MF, GR);
581 return true;
582}
583
584INITIALIZE_PASS(SPIRVPostLegalizer, DEBUG_TYPE, "SPIRV post legalizer", false,
585 false)
586
587char SPIRVPostLegalizer::ID = 0;
588
589FunctionPass *llvm::createSPIRVPostLegalizerPass() {
590 return new SPIRVPostLegalizer();
591}
592