1//===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the SPIRVGlobalRegistry class,
10// which is used to maintain rich type information required for SPIR-V even
11// after lowering from LLVM IR to GMIR. It can convert an llvm::Type into
12// an OpTypeXXX instruction, and map it to a virtual register. Also it builds
13// and supports consistency of constants and global variables.
14//
15//===----------------------------------------------------------------------===//
16
17#include "SPIRVGlobalRegistry.h"
18#include "SPIRV.h"
19#include "SPIRVBuiltins.h"
20#include "SPIRVSubtarget.h"
21#include "SPIRVUtils.h"
22#include "llvm/ADT/APInt.h"
23#include "llvm/IR/Constants.h"
24#include "llvm/IR/IntrinsicInst.h"
25#include "llvm/IR/Intrinsics.h"
26#include "llvm/IR/IntrinsicsSPIRV.h"
27#include "llvm/IR/Type.h"
28#include "llvm/Support/Casting.h"
29#include "llvm/Support/MathExtras.h"
30#include <cassert>
31#include <functional>
32
33using namespace llvm;
34
35static bool allowEmitFakeUse(const Value *Arg) {
36 if (isSpvIntrinsic(Arg))
37 return false;
38 if (isa<AtomicCmpXchgInst, InsertValueInst, UndefValue>(Val: Arg))
39 return false;
40 if (const auto *LI = dyn_cast<LoadInst>(Val: Arg))
41 if (LI->getType()->isAggregateType())
42 return false;
43 return true;
44}
45
46static unsigned typeToAddressSpace(const Type *Ty) {
47 if (auto PType = dyn_cast<TypedPointerType>(Val: Ty))
48 return PType->getAddressSpace();
49 if (auto PType = dyn_cast<PointerType>(Val: Ty))
50 return PType->getAddressSpace();
51 if (auto *ExtTy = dyn_cast<TargetExtType>(Val: Ty);
52 ExtTy && isTypedPointerWrapper(ExtTy))
53 return ExtTy->getIntParameter(i: 0);
54 reportFatalInternalError(reason: "Unable to convert LLVM type to SPIRVType");
55}
56
57static bool
58storageClassRequiresExplictLayout(SPIRV::StorageClass::StorageClass SC) {
59 switch (SC) {
60 case SPIRV::StorageClass::Uniform:
61 case SPIRV::StorageClass::PushConstant:
62 case SPIRV::StorageClass::StorageBuffer:
63 case SPIRV::StorageClass::PhysicalStorageBufferEXT:
64 return true;
65 case SPIRV::StorageClass::UniformConstant:
66 case SPIRV::StorageClass::Input:
67 case SPIRV::StorageClass::Output:
68 case SPIRV::StorageClass::Workgroup:
69 case SPIRV::StorageClass::CrossWorkgroup:
70 case SPIRV::StorageClass::Private:
71 case SPIRV::StorageClass::Function:
72 case SPIRV::StorageClass::Generic:
73 case SPIRV::StorageClass::AtomicCounter:
74 case SPIRV::StorageClass::Image:
75 case SPIRV::StorageClass::CallableDataNV:
76 case SPIRV::StorageClass::IncomingCallableDataNV:
77 case SPIRV::StorageClass::RayPayloadNV:
78 case SPIRV::StorageClass::HitAttributeNV:
79 case SPIRV::StorageClass::IncomingRayPayloadNV:
80 case SPIRV::StorageClass::ShaderRecordBufferNV:
81 case SPIRV::StorageClass::CodeSectionINTEL:
82 case SPIRV::StorageClass::DeviceOnlyINTEL:
83 case SPIRV::StorageClass::HostOnlyINTEL:
84 return false;
85 }
86 llvm_unreachable("Unknown SPIRV::StorageClass enum");
87}
88
89SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
90 : PointerSize(PointerSize), Bound(0) {}
91
92SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth,
93 Register VReg,
94 MachineInstr &I,
95 const SPIRVInstrInfo &TII) {
96 SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
97 assignSPIRVTypeToVReg(Type: SpirvType, VReg, MF: *CurMF);
98 return SpirvType;
99}
100
101SPIRVType *
102SPIRVGlobalRegistry::assignFloatTypeToVReg(unsigned BitWidth, Register VReg,
103 MachineInstr &I,
104 const SPIRVInstrInfo &TII) {
105 SPIRVType *SpirvType = getOrCreateSPIRVFloatType(BitWidth, I, TII);
106 assignSPIRVTypeToVReg(Type: SpirvType, VReg, MF: *CurMF);
107 return SpirvType;
108}
109
110SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
111 SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I,
112 const SPIRVInstrInfo &TII) {
113 SPIRVType *SpirvType =
114 getOrCreateSPIRVVectorType(BaseType, NumElements, I, TII);
115 assignSPIRVTypeToVReg(Type: SpirvType, VReg, MF: *CurMF);
116 return SpirvType;
117}
118
119SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
120 const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
121 SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
122 SPIRVType *SpirvType =
123 getOrCreateSPIRVType(Type, MIRBuilder, AQ: AccessQual, EmitIR);
124 assignSPIRVTypeToVReg(Type: SpirvType, VReg, MF: MIRBuilder.getMF());
125 return SpirvType;
126}
127
128void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
129 Register VReg,
130 const MachineFunction &MF) {
131 VRegToTypeMap[&MF][VReg] = SpirvType;
132}
133
134static Register createTypeVReg(MachineRegisterInfo &MRI) {
135 auto Res = MRI.createGenericVirtualRegister(Ty: LLT::scalar(SizeInBits: 64));
136 MRI.setRegClass(Reg: Res, RC: &SPIRV::TYPERegClass);
137 return Res;
138}
139
140inline Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
141 return createTypeVReg(MRI&: MIRBuilder.getMF().getRegInfo());
142}
143
144SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
145 return createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
146 return MIRBuilder.buildInstr(Opcode: SPIRV::OpTypeBool)
147 .addDef(RegNo: createTypeVReg(MIRBuilder));
148 });
149}
150
151unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
152 if (Width > 64)
153 report_fatal_error(reason: "Unsupported integer width!");
154 const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(Val: CurMF->getSubtarget());
155 if (ST.canUseExtension(
156 E: SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
157 ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_int4))
158 return Width;
159 if (Width <= 8)
160 Width = 8;
161 else if (Width <= 16)
162 Width = 16;
163 else if (Width <= 32)
164 Width = 32;
165 else
166 Width = 64;
167 return Width;
168}
169
170SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
171 MachineIRBuilder &MIRBuilder,
172 bool IsSigned) {
173 Width = adjustOpTypeIntWidth(Width);
174 const SPIRVSubtarget &ST =
175 cast<SPIRVSubtarget>(Val: MIRBuilder.getMF().getSubtarget());
176 return createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
177 if (Width == 4 && ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_int4)) {
178 MIRBuilder.buildInstr(Opcode: SPIRV::OpExtension)
179 .addImm(Val: SPIRV::Extension::SPV_INTEL_int4);
180 MIRBuilder.buildInstr(Opcode: SPIRV::OpCapability)
181 .addImm(Val: SPIRV::Capability::Int4TypeINTEL);
182 } else if ((!isPowerOf2_32(Value: Width) || Width < 8) &&
183 ST.canUseExtension(
184 E: SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
185 MIRBuilder.buildInstr(Opcode: SPIRV::OpExtension)
186 .addImm(Val: SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
187 MIRBuilder.buildInstr(Opcode: SPIRV::OpCapability)
188 .addImm(Val: SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
189 }
190 return MIRBuilder.buildInstr(Opcode: SPIRV::OpTypeInt)
191 .addDef(RegNo: createTypeVReg(MIRBuilder))
192 .addImm(Val: Width)
193 .addImm(Val: IsSigned ? 1 : 0);
194 });
195}
196
197SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
198 MachineIRBuilder &MIRBuilder) {
199 return createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
200 return MIRBuilder.buildInstr(Opcode: SPIRV::OpTypeFloat)
201 .addDef(RegNo: createTypeVReg(MIRBuilder))
202 .addImm(Val: Width);
203 });
204}
205
206SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
207 return createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
208 return MIRBuilder.buildInstr(Opcode: SPIRV::OpTypeVoid)
209 .addDef(RegNo: createTypeVReg(MIRBuilder));
210 });
211}
212
213void SPIRVGlobalRegistry::invalidateMachineInstr(MachineInstr *MI) {
214 // TODO:
215 // - review other data structure wrt. possible issues related to removal
216 // of a machine instruction during instruction selection.
217 const MachineFunction *MF = MI->getMF();
218 auto It = LastInsertedTypeMap.find(Val: MF);
219 if (It == LastInsertedTypeMap.end())
220 return;
221 if (It->second == MI)
222 LastInsertedTypeMap.erase(Val: MF);
223 // remove from the duplicate tracker to avoid incorrect reuse
224 erase(MI);
225}
226
227SPIRVType *SPIRVGlobalRegistry::createOpType(
228 MachineIRBuilder &MIRBuilder,
229 std::function<MachineInstr *(MachineIRBuilder &)> Op) {
230 auto oldInsertPoint = MIRBuilder.getInsertPt();
231 MachineBasicBlock *OldMBB = &MIRBuilder.getMBB();
232 MachineBasicBlock *NewMBB = &*MIRBuilder.getMF().begin();
233
234 auto LastInsertedType = LastInsertedTypeMap.find(Val: CurMF);
235 if (LastInsertedType != LastInsertedTypeMap.end()) {
236 auto It = LastInsertedType->second->getIterator();
237 // It might happen that this instruction was removed from the first MBB,
238 // hence the Parent's check.
239 MachineBasicBlock::iterator InsertAt;
240 if (It->getParent() != NewMBB)
241 InsertAt = oldInsertPoint->getParent() == NewMBB
242 ? oldInsertPoint
243 : getInsertPtValidEnd(MBB: NewMBB);
244 else if (It->getNextNode())
245 InsertAt = It->getNextNode()->getIterator();
246 else
247 InsertAt = getInsertPtValidEnd(MBB: NewMBB);
248 MIRBuilder.setInsertPt(MBB&: *NewMBB, II: InsertAt);
249 } else {
250 MIRBuilder.setInsertPt(MBB&: *NewMBB, II: NewMBB->begin());
251 auto Result = LastInsertedTypeMap.try_emplace(Key: CurMF, Args: nullptr);
252 assert(Result.second);
253 LastInsertedType = Result.first;
254 }
255
256 MachineInstr *Type = Op(MIRBuilder);
257 // We expect all users of this function to insert definitions at the insertion
258 // point set above that is always the first MBB.
259 assert(Type->getParent() == NewMBB);
260 LastInsertedType->second = Type;
261
262 MIRBuilder.setInsertPt(MBB&: *OldMBB, II: oldInsertPoint);
263 return Type;
264}
265
266SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
267 SPIRVType *ElemType,
268 MachineIRBuilder &MIRBuilder) {
269 auto EleOpc = ElemType->getOpcode();
270 (void)EleOpc;
271 assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
272 EleOpc == SPIRV::OpTypeBool) &&
273 "Invalid vector element type");
274
275 return createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
276 return MIRBuilder.buildInstr(Opcode: SPIRV::OpTypeVector)
277 .addDef(RegNo: createTypeVReg(MIRBuilder))
278 .addUse(RegNo: getSPIRVTypeID(SpirvType: ElemType))
279 .addImm(Val: NumElems);
280 });
281}
282
283Register SPIRVGlobalRegistry::getOrCreateConstFP(APFloat Val, MachineInstr &I,
284 SPIRVType *SpvType,
285 const SPIRVInstrInfo &TII,
286 bool ZeroAsNull) {
287 LLVMContext &Ctx = CurMF->getFunction().getContext();
288 auto *const CF = ConstantFP::get(Context&: Ctx, V: Val);
289 const MachineInstr *MI = findMI(Obj: CF, MF: CurMF);
290 if (MI && (MI->getOpcode() == SPIRV::OpConstantNull ||
291 MI->getOpcode() == SPIRV::OpConstantF))
292 return MI->getOperand(i: 0).getReg();
293 return createConstFP(CF, I, SpvType, TII, ZeroAsNull);
294}
295
296Register SPIRVGlobalRegistry::createConstFP(const ConstantFP *CF,
297 MachineInstr &I, SPIRVType *SpvType,
298 const SPIRVInstrInfo &TII,
299 bool ZeroAsNull) {
300 unsigned BitWidth = getScalarOrVectorBitWidth(Type: SpvType);
301 LLT LLTy = LLT::scalar(SizeInBits: BitWidth);
302 Register Res = CurMF->getRegInfo().createGenericVirtualRegister(Ty: LLTy);
303 CurMF->getRegInfo().setRegClass(Reg: Res, RC: &SPIRV::fIDRegClass);
304 assignFloatTypeToVReg(BitWidth, VReg: Res, I, TII);
305
306 MachineInstr *DepMI = const_cast<MachineInstr *>(SpvType);
307 MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator());
308 SPIRVType *NewType =
309 createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
310 MachineInstrBuilder MIB;
311 // In OpenCL OpConstantNull - Scalar floating point: +0.0 (all bits 0)
312 if (CF->getValue().isPosZero() && ZeroAsNull) {
313 MIB = MIRBuilder.buildInstr(Opcode: SPIRV::OpConstantNull)
314 .addDef(RegNo: Res)
315 .addUse(RegNo: getSPIRVTypeID(SpirvType: SpvType));
316 } else {
317 MIB = MIRBuilder.buildInstr(Opcode: SPIRV::OpConstantF)
318 .addDef(RegNo: Res)
319 .addUse(RegNo: getSPIRVTypeID(SpirvType: SpvType));
320 addNumImm(Imm: APInt(BitWidth,
321 CF->getValueAPF().bitcastToAPInt().getZExtValue()),
322 MIB);
323 }
324 const auto &ST = CurMF->getSubtarget();
325 constrainSelectedInstRegOperands(I&: *MIB, TII: *ST.getInstrInfo(),
326 TRI: *ST.getRegisterInfo(),
327 RBI: *ST.getRegBankInfo());
328 return MIB;
329 });
330 add(V: CF, MI: NewType);
331 return Res;
332}
333
334Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
335 SPIRVType *SpvType,
336 const SPIRVInstrInfo &TII,
337 bool ZeroAsNull) {
338 const IntegerType *Ty = cast<IntegerType>(Val: getTypeForSPIRVType(Ty: SpvType));
339 auto *const CI = ConstantInt::get(Ty: const_cast<IntegerType *>(Ty), V: Val);
340 const MachineInstr *MI = findMI(Obj: CI, MF: CurMF);
341 if (MI && (MI->getOpcode() == SPIRV::OpConstantNull ||
342 MI->getOpcode() == SPIRV::OpConstantI))
343 return MI->getOperand(i: 0).getReg();
344 return createConstInt(CI, I, SpvType, TII, ZeroAsNull);
345}
346
347Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI,
348 MachineInstr &I,
349 SPIRVType *SpvType,
350 const SPIRVInstrInfo &TII,
351 bool ZeroAsNull) {
352 unsigned BitWidth = getScalarOrVectorBitWidth(Type: SpvType);
353 LLT LLTy = LLT::scalar(SizeInBits: BitWidth);
354 Register Res = CurMF->getRegInfo().createGenericVirtualRegister(Ty: LLTy);
355 CurMF->getRegInfo().setRegClass(Reg: Res, RC: &SPIRV::iIDRegClass);
356 assignIntTypeToVReg(BitWidth, VReg: Res, I, TII);
357
358 MachineInstr *DepMI = const_cast<MachineInstr *>(SpvType);
359 MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator());
360 SPIRVType *NewType =
361 createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
362 MachineInstrBuilder MIB;
363 if (BitWidth == 1) {
364 MIB = MIRBuilder
365 .buildInstr(Opcode: CI->isZero() ? SPIRV::OpConstantFalse
366 : SPIRV::OpConstantTrue)
367 .addDef(RegNo: Res)
368 .addUse(RegNo: getSPIRVTypeID(SpirvType: SpvType));
369 } else if (!CI->isZero() || !ZeroAsNull) {
370 MIB = MIRBuilder.buildInstr(Opcode: SPIRV::OpConstantI)
371 .addDef(RegNo: Res)
372 .addUse(RegNo: getSPIRVTypeID(SpirvType: SpvType));
373 addNumImm(Imm: APInt(BitWidth, CI->getZExtValue()), MIB);
374 } else {
375 MIB = MIRBuilder.buildInstr(Opcode: SPIRV::OpConstantNull)
376 .addDef(RegNo: Res)
377 .addUse(RegNo: getSPIRVTypeID(SpirvType: SpvType));
378 }
379 const auto &ST = CurMF->getSubtarget();
380 constrainSelectedInstRegOperands(I&: *MIB, TII: *ST.getInstrInfo(),
381 TRI: *ST.getRegisterInfo(),
382 RBI: *ST.getRegBankInfo());
383 return MIB;
384 });
385 add(V: CI, MI: NewType);
386 return Res;
387}
388
389Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
390 MachineIRBuilder &MIRBuilder,
391 SPIRVType *SpvType, bool EmitIR,
392 bool ZeroAsNull) {
393 assert(SpvType);
394 auto &MF = MIRBuilder.getMF();
395 const IntegerType *Ty = cast<IntegerType>(Val: getTypeForSPIRVType(Ty: SpvType));
396 auto *const CI = ConstantInt::get(Ty: const_cast<IntegerType *>(Ty), V: Val);
397 Register Res = find(V: CI, MF: &MF);
398 if (Res.isValid())
399 return Res;
400
401 unsigned BitWidth = getScalarOrVectorBitWidth(Type: SpvType);
402 LLT LLTy = LLT::scalar(SizeInBits: BitWidth);
403 MachineRegisterInfo &MRI = MF.getRegInfo();
404 Res = MRI.createGenericVirtualRegister(Ty: LLTy);
405 MRI.setRegClass(Reg: Res, RC: &SPIRV::iIDRegClass);
406 assignTypeToVReg(Type: Ty, VReg: Res, MIRBuilder, AccessQual: SPIRV::AccessQualifier::ReadWrite,
407 EmitIR);
408
409 SPIRVType *NewType =
410 createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
411 if (EmitIR)
412 return MIRBuilder.buildConstant(Res, Val: *CI);
413 Register SpvTypeReg = getSPIRVTypeID(SpirvType: SpvType);
414 MachineInstrBuilder MIB;
415 if (Val || !ZeroAsNull) {
416 MIB = MIRBuilder.buildInstr(Opcode: SPIRV::OpConstantI)
417 .addDef(RegNo: Res)
418 .addUse(RegNo: SpvTypeReg);
419 addNumImm(Imm: APInt(BitWidth, Val), MIB);
420 } else {
421 MIB = MIRBuilder.buildInstr(Opcode: SPIRV::OpConstantNull)
422 .addDef(RegNo: Res)
423 .addUse(RegNo: SpvTypeReg);
424 }
425 const auto &Subtarget = CurMF->getSubtarget();
426 constrainSelectedInstRegOperands(I&: *MIB, TII: *Subtarget.getInstrInfo(),
427 TRI: *Subtarget.getRegisterInfo(),
428 RBI: *Subtarget.getRegBankInfo());
429 return MIB;
430 });
431 add(V: CI, MI: NewType);
432 return Res;
433}
434
435Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
436 MachineIRBuilder &MIRBuilder,
437 SPIRVType *SpvType) {
438 auto &MF = MIRBuilder.getMF();
439 LLVMContext &Ctx = MF.getFunction().getContext();
440 if (!SpvType)
441 SpvType = getOrCreateSPIRVType(Type: Type::getFloatTy(C&: Ctx), MIRBuilder,
442 AQ: SPIRV::AccessQualifier::ReadWrite, EmitIR: true);
443 auto *const CF = ConstantFP::get(Context&: Ctx, V: Val);
444 Register Res = find(V: CF, MF: &MF);
445 if (Res.isValid())
446 return Res;
447
448 LLT LLTy = LLT::scalar(SizeInBits: getScalarOrVectorBitWidth(Type: SpvType));
449 Res = MF.getRegInfo().createGenericVirtualRegister(Ty: LLTy);
450 MF.getRegInfo().setRegClass(Reg: Res, RC: &SPIRV::fIDRegClass);
451 assignSPIRVTypeToVReg(SpirvType: SpvType, VReg: Res, MF);
452
453 SPIRVType *NewType =
454 createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
455 MachineInstrBuilder MIB;
456 MIB = MIRBuilder.buildInstr(Opcode: SPIRV::OpConstantF)
457 .addDef(RegNo: Res)
458 .addUse(RegNo: getSPIRVTypeID(SpirvType: SpvType));
459 addNumImm(Imm: CF->getValueAPF().bitcastToAPInt(), MIB);
460 return MIB;
461 });
462 add(V: CF, MI: NewType);
463 return Res;
464}
465
466Register SPIRVGlobalRegistry::getOrCreateBaseRegister(
467 Constant *Val, MachineInstr &I, SPIRVType *SpvType,
468 const SPIRVInstrInfo &TII, unsigned BitWidth, bool ZeroAsNull) {
469 SPIRVType *Type = SpvType;
470 if (SpvType->getOpcode() == SPIRV::OpTypeVector ||
471 SpvType->getOpcode() == SPIRV::OpTypeArray) {
472 auto EleTypeReg = SpvType->getOperand(i: 1).getReg();
473 Type = getSPIRVTypeForVReg(VReg: EleTypeReg);
474 }
475 if (Type->getOpcode() == SPIRV::OpTypeFloat) {
476 SPIRVType *SpvBaseType = getOrCreateSPIRVFloatType(BitWidth, I, TII);
477 return getOrCreateConstFP(Val: dyn_cast<ConstantFP>(Val)->getValue(), I,
478 SpvType: SpvBaseType, TII, ZeroAsNull);
479 }
480 assert(Type->getOpcode() == SPIRV::OpTypeInt);
481 SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
482 return getOrCreateConstInt(Val: Val->getUniqueInteger().getZExtValue(), I,
483 SpvType: SpvBaseType, TII, ZeroAsNull);
484}
485
486Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
487 Constant *Val, MachineInstr &I, SPIRVType *SpvType,
488 const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,
489 unsigned ElemCnt, bool ZeroAsNull) {
490 if (Register R = find(V: CA, MF: CurMF); R.isValid())
491 return R;
492
493 bool IsNull = Val->isNullValue() && ZeroAsNull;
494 Register ElemReg;
495 if (!IsNull)
496 ElemReg =
497 getOrCreateBaseRegister(Val, I, SpvType, TII, BitWidth, ZeroAsNull);
498
499 LLT LLTy = LLT::scalar(SizeInBits: 64);
500 Register Res = CurMF->getRegInfo().createGenericVirtualRegister(Ty: LLTy);
501 CurMF->getRegInfo().setRegClass(Reg: Res, RC: getRegClass(SpvType));
502 assignSPIRVTypeToVReg(SpirvType: SpvType, VReg: Res, MF: *CurMF);
503
504 MachineInstr *DepMI = const_cast<MachineInstr *>(SpvType);
505 MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator());
506 const MachineInstr *NewMI =
507 createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
508 MachineInstrBuilder MIB;
509 if (!IsNull) {
510 MIB = MIRBuilder.buildInstr(Opcode: SPIRV::OpConstantComposite)
511 .addDef(RegNo: Res)
512 .addUse(RegNo: getSPIRVTypeID(SpirvType: SpvType));
513 for (unsigned i = 0; i < ElemCnt; ++i)
514 MIB.addUse(RegNo: ElemReg);
515 } else {
516 MIB = MIRBuilder.buildInstr(Opcode: SPIRV::OpConstantNull)
517 .addDef(RegNo: Res)
518 .addUse(RegNo: getSPIRVTypeID(SpirvType: SpvType));
519 }
520 const auto &Subtarget = CurMF->getSubtarget();
521 constrainSelectedInstRegOperands(I&: *MIB, TII: *Subtarget.getInstrInfo(),
522 TRI: *Subtarget.getRegisterInfo(),
523 RBI: *Subtarget.getRegBankInfo());
524 return MIB;
525 });
526 add(V: CA, MI: NewMI);
527 return Res;
528}
529
530Register SPIRVGlobalRegistry::getOrCreateConstVector(uint64_t Val,
531 MachineInstr &I,
532 SPIRVType *SpvType,
533 const SPIRVInstrInfo &TII,
534 bool ZeroAsNull) {
535 const Type *LLVMTy = getTypeForSPIRVType(Ty: SpvType);
536 assert(LLVMTy->isVectorTy());
537 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(Val: LLVMTy);
538 Type *LLVMBaseTy = LLVMVecTy->getElementType();
539 assert(LLVMBaseTy->isIntegerTy());
540 auto *ConstVal = ConstantInt::get(Ty: LLVMBaseTy, V: Val);
541 auto *ConstVec =
542 ConstantVector::getSplat(EC: LLVMVecTy->getElementCount(), Elt: ConstVal);
543 unsigned BW = getScalarOrVectorBitWidth(Type: SpvType);
544 return getOrCreateCompositeOrNull(Val: ConstVal, I, SpvType, TII, CA: ConstVec, BitWidth: BW,
545 ElemCnt: SpvType->getOperand(i: 2).getImm(),
546 ZeroAsNull);
547}
548
549Register SPIRVGlobalRegistry::getOrCreateConstVector(APFloat Val,
550 MachineInstr &I,
551 SPIRVType *SpvType,
552 const SPIRVInstrInfo &TII,
553 bool ZeroAsNull) {
554 const Type *LLVMTy = getTypeForSPIRVType(Ty: SpvType);
555 assert(LLVMTy->isVectorTy());
556 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(Val: LLVMTy);
557 Type *LLVMBaseTy = LLVMVecTy->getElementType();
558 assert(LLVMBaseTy->isFloatingPointTy());
559 auto *ConstVal = ConstantFP::get(Ty: LLVMBaseTy, V: Val);
560 auto *ConstVec =
561 ConstantVector::getSplat(EC: LLVMVecTy->getElementCount(), Elt: ConstVal);
562 unsigned BW = getScalarOrVectorBitWidth(Type: SpvType);
563 return getOrCreateCompositeOrNull(Val: ConstVal, I, SpvType, TII, CA: ConstVec, BitWidth: BW,
564 ElemCnt: SpvType->getOperand(i: 2).getImm(),
565 ZeroAsNull);
566}
567
568Register SPIRVGlobalRegistry::getOrCreateConstIntArray(
569 uint64_t Val, size_t Num, MachineInstr &I, SPIRVType *SpvType,
570 const SPIRVInstrInfo &TII) {
571 const Type *LLVMTy = getTypeForSPIRVType(Ty: SpvType);
572 assert(LLVMTy->isArrayTy());
573 const ArrayType *LLVMArrTy = cast<ArrayType>(Val: LLVMTy);
574 Type *LLVMBaseTy = LLVMArrTy->getElementType();
575 Constant *CI = ConstantInt::get(Ty: LLVMBaseTy, V: Val);
576 SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(VReg: SpvType->getOperand(i: 1).getReg());
577 unsigned BW = getScalarOrVectorBitWidth(Type: SpvBaseTy);
578 // The following is reasonably unique key that is better that [Val]. The naive
579 // alternative would be something along the lines of:
580 // SmallVector<Constant *> NumCI(Num, CI);
581 // Constant *UniqueKey =
582 // ConstantArray::get(const_cast<ArrayType*>(LLVMArrTy), NumCI);
583 // that would be a truly unique but dangerous key, because it could lead to
584 // the creation of constants of arbitrary length (that is, the parameter of
585 // memset) which were missing in the original module.
586 Constant *UniqueKey = ConstantStruct::getAnon(
587 V: {PoisonValue::get(T: const_cast<ArrayType *>(LLVMArrTy)),
588 ConstantInt::get(Ty: LLVMBaseTy, V: Val), ConstantInt::get(Ty: LLVMBaseTy, V: Num)});
589 return getOrCreateCompositeOrNull(Val: CI, I, SpvType, TII, CA: UniqueKey, BitWidth: BW,
590 ElemCnt: LLVMArrTy->getNumElements());
591}
592
593Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
594 uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR,
595 Constant *CA, unsigned BitWidth, unsigned ElemCnt) {
596 if (Register R = find(V: CA, MF: CurMF); R.isValid())
597 return R;
598
599 Register ElemReg;
600 if (Val || EmitIR) {
601 SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
602 ElemReg = buildConstantInt(Val, MIRBuilder, SpvType: SpvBaseType, EmitIR);
603 }
604 LLT LLTy = EmitIR ? LLT::fixed_vector(NumElements: ElemCnt, ScalarSizeInBits: BitWidth) : LLT::scalar(SizeInBits: 64);
605 Register Res = CurMF->getRegInfo().createGenericVirtualRegister(Ty: LLTy);
606 CurMF->getRegInfo().setRegClass(Reg: Res, RC: &SPIRV::iIDRegClass);
607 assignSPIRVTypeToVReg(SpirvType: SpvType, VReg: Res, MF: *CurMF);
608
609 const MachineInstr *NewMI =
610 createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
611 if (EmitIR)
612 return MIRBuilder.buildSplatBuildVector(Res, Src: ElemReg);
613
614 if (Val) {
615 auto MIB = MIRBuilder.buildInstr(Opcode: SPIRV::OpConstantComposite)
616 .addDef(RegNo: Res)
617 .addUse(RegNo: getSPIRVTypeID(SpirvType: SpvType));
618 for (unsigned i = 0; i < ElemCnt; ++i)
619 MIB.addUse(RegNo: ElemReg);
620 return MIB;
621 }
622
623 return MIRBuilder.buildInstr(Opcode: SPIRV::OpConstantNull)
624 .addDef(RegNo: Res)
625 .addUse(RegNo: getSPIRVTypeID(SpirvType: SpvType));
626 });
627 add(V: CA, MI: NewMI);
628 return Res;
629}
630
631Register
632SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val,
633 MachineIRBuilder &MIRBuilder,
634 SPIRVType *SpvType, bool EmitIR) {
635 const Type *LLVMTy = getTypeForSPIRVType(Ty: SpvType);
636 assert(LLVMTy->isVectorTy());
637 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(Val: LLVMTy);
638 Type *LLVMBaseTy = LLVMVecTy->getElementType();
639 const auto ConstInt = ConstantInt::get(Ty: LLVMBaseTy, V: Val);
640 auto ConstVec =
641 ConstantVector::getSplat(EC: LLVMVecTy->getElementCount(), Elt: ConstInt);
642 unsigned BW = getScalarOrVectorBitWidth(Type: SpvType);
643 return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
644 CA: ConstVec, BitWidth: BW,
645 ElemCnt: SpvType->getOperand(i: 2).getImm());
646}
647
648Register
649SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
650 SPIRVType *SpvType) {
651 const Type *Ty = getTypeForSPIRVType(Ty: SpvType);
652 unsigned AddressSpace = typeToAddressSpace(Ty);
653 Type *ElemTy = ::getPointeeType(Ty);
654 assert(ElemTy);
655 const Constant *CP = ConstantTargetNone::get(
656 T: dyn_cast<TargetExtType>(Val: getTypedPointerWrapper(ElemTy, AS: AddressSpace)));
657 Register Res = find(V: CP, MF: CurMF);
658 if (Res.isValid())
659 return Res;
660
661 LLT LLTy = LLT::pointer(AddressSpace, SizeInBits: PointerSize);
662 Res = CurMF->getRegInfo().createGenericVirtualRegister(Ty: LLTy);
663 CurMF->getRegInfo().setRegClass(Reg: Res, RC: &SPIRV::pIDRegClass);
664 assignSPIRVTypeToVReg(SpirvType: SpvType, VReg: Res, MF: *CurMF);
665
666 const MachineInstr *NewMI =
667 createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
668 return MIRBuilder.buildInstr(Opcode: SPIRV::OpConstantNull)
669 .addDef(RegNo: Res)
670 .addUse(RegNo: getSPIRVTypeID(SpirvType: SpvType));
671 });
672 add(V: CP, MI: NewMI);
673 return Res;
674}
675
676Register
677SPIRVGlobalRegistry::buildConstantSampler(Register ResReg, unsigned AddrMode,
678 unsigned Param, unsigned FilerMode,
679 MachineIRBuilder &MIRBuilder) {
680 auto Sampler =
681 ResReg.isValid()
682 ? ResReg
683 : MIRBuilder.getMRI()->createVirtualRegister(RegClass: &SPIRV::iIDRegClass);
684 SPIRVType *TypeSampler = getOrCreateOpTypeSampler(MIRBuilder);
685 Register TypeSamplerReg = getSPIRVTypeID(SpirvType: TypeSampler);
686 // We cannot use createOpType() logic here, because of the
687 // GlobalISel/IRTranslator.cpp check for a tail call that expects that
688 // MIRBuilder.getInsertPt() has a previous instruction. If this constant is
689 // inserted as a result of "__translate_sampler_initializer()" this would
690 // break this IRTranslator assumption.
691 MIRBuilder.buildInstr(Opcode: SPIRV::OpConstantSampler)
692 .addDef(RegNo: Sampler)
693 .addUse(RegNo: TypeSamplerReg)
694 .addImm(Val: AddrMode)
695 .addImm(Val: Param)
696 .addImm(Val: FilerMode);
697 return Sampler;
698}
699
700Register SPIRVGlobalRegistry::buildGlobalVariable(
701 Register ResVReg, SPIRVType *BaseType, StringRef Name,
702 const GlobalValue *GV, SPIRV::StorageClass::StorageClass Storage,
703 const MachineInstr *Init, bool IsConst, bool HasLinkageTy,
704 SPIRV::LinkageType::LinkageType LinkageType, MachineIRBuilder &MIRBuilder,
705 bool IsInstSelector) {
706 const GlobalVariable *GVar = nullptr;
707 if (GV) {
708 GVar = cast<const GlobalVariable>(Val: GV);
709 } else {
710 // If GV is not passed explicitly, use the name to find or construct
711 // the global variable.
712 Module *M = MIRBuilder.getMF().getFunction().getParent();
713 GVar = M->getGlobalVariable(Name);
714 if (GVar == nullptr) {
715 const Type *Ty = getTypeForSPIRVType(Ty: BaseType); // TODO: check type.
716 // Module takes ownership of the global var.
717 GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false,
718 GlobalValue::ExternalLinkage, nullptr,
719 Twine(Name));
720 }
721 GV = GVar;
722 }
723
724 const MachineFunction *MF = &MIRBuilder.getMF();
725 Register Reg = find(V: GVar, MF);
726 if (Reg.isValid()) {
727 if (Reg != ResVReg)
728 MIRBuilder.buildCopy(Res: ResVReg, Op: Reg);
729 return ResVReg;
730 }
731
732 auto MIB = MIRBuilder.buildInstr(Opcode: SPIRV::OpVariable)
733 .addDef(RegNo: ResVReg)
734 .addUse(RegNo: getSPIRVTypeID(SpirvType: BaseType))
735 .addImm(Val: static_cast<uint32_t>(Storage));
736 if (Init != 0)
737 MIB.addUse(RegNo: Init->getOperand(i: 0).getReg());
738 // ISel may introduce a new register on this step, so we need to add it to
739 // DT and correct its type avoiding fails on the next stage.
740 if (IsInstSelector) {
741 const auto &Subtarget = CurMF->getSubtarget();
742 constrainSelectedInstRegOperands(I&: *MIB, TII: *Subtarget.getInstrInfo(),
743 TRI: *Subtarget.getRegisterInfo(),
744 RBI: *Subtarget.getRegBankInfo());
745 }
746 add(V: GVar, MI: MIB);
747
748 Reg = MIB->getOperand(i: 0).getReg();
749 addGlobalObject(V: GVar, MF, R: Reg);
750
751 // Set to Reg the same type as ResVReg has.
752 auto MRI = MIRBuilder.getMRI();
753 if (Reg != ResVReg) {
754 LLT RegLLTy =
755 LLT::pointer(AddressSpace: MRI->getType(Reg: ResVReg).getAddressSpace(), SizeInBits: getPointerSize());
756 MRI->setType(VReg: Reg, Ty: RegLLTy);
757 assignSPIRVTypeToVReg(SpirvType: BaseType, VReg: Reg, MF: MIRBuilder.getMF());
758 } else {
759 // Our knowledge about the type may be updated.
760 // If that's the case, we need to update a type
761 // associated with the register.
762 SPIRVType *DefType = getSPIRVTypeForVReg(VReg: ResVReg);
763 if (!DefType || DefType != BaseType)
764 assignSPIRVTypeToVReg(SpirvType: BaseType, VReg: Reg, MF: MIRBuilder.getMF());
765 }
766
767 // If it's a global variable with name, output OpName for it.
768 if (GVar && GVar->hasName())
769 buildOpName(Target: Reg, Name: GVar->getName(), MIRBuilder);
770
771 // Output decorations for the GV.
772 // TODO: maybe move to GenerateDecorations pass.
773 const SPIRVSubtarget &ST =
774 cast<SPIRVSubtarget>(Val: MIRBuilder.getMF().getSubtarget());
775 if (IsConst && !ST.isShader())
776 buildOpDecorate(Reg, MIRBuilder, Dec: SPIRV::Decoration::Constant, DecArgs: {});
777
778 if (GVar && GVar->getAlign().valueOrOne().value() != 1 && !ST.isShader()) {
779 unsigned Alignment = (unsigned)GVar->getAlign().valueOrOne().value();
780 buildOpDecorate(Reg, MIRBuilder, Dec: SPIRV::Decoration::Alignment, DecArgs: {Alignment});
781 }
782
783 if (HasLinkageTy)
784 buildOpDecorate(Reg, MIRBuilder, Dec: SPIRV::Decoration::LinkageAttributes,
785 DecArgs: {static_cast<uint32_t>(LinkageType)}, StrImm: Name);
786
787 SPIRV::BuiltIn::BuiltIn BuiltInId;
788 if (getSpirvBuiltInIdByName(Name, BI&: BuiltInId))
789 buildOpDecorate(Reg, MIRBuilder, Dec: SPIRV::Decoration::BuiltIn,
790 DecArgs: {static_cast<uint32_t>(BuiltInId)});
791
792 // If it's a global variable with "spirv.Decorations" metadata node
793 // recognize it as a SPIR-V friendly LLVM IR and parse "spirv.Decorations"
794 // arguments.
795 MDNode *GVarMD = nullptr;
796 if (GVar && (GVarMD = GVar->getMetadata(Kind: "spirv.Decorations")) != nullptr)
797 buildOpSpirvDecorations(Reg, MIRBuilder, GVarMD);
798
799 return Reg;
800}
801
802// Returns a name based on the Type. Notes that this does not look at
803// decorations, and will return the same string for two types that are the same
804// except for decorations.
805Register SPIRVGlobalRegistry::getOrCreateGlobalVariableWithBinding(
806 const SPIRVType *VarType, uint32_t Set, uint32_t Binding, StringRef Name,
807 MachineIRBuilder &MIRBuilder) {
808 Register VarReg =
809 MIRBuilder.getMRI()->createVirtualRegister(RegClass: &SPIRV::iIDRegClass);
810
811 buildGlobalVariable(ResVReg: VarReg, BaseType: VarType, Name, GV: nullptr,
812 Storage: getPointerStorageClass(Type: VarType), Init: nullptr, IsConst: false, HasLinkageTy: false,
813 LinkageType: SPIRV::LinkageType::Import, MIRBuilder, IsInstSelector: false);
814
815 buildOpDecorate(Reg: VarReg, MIRBuilder, Dec: SPIRV::Decoration::DescriptorSet, DecArgs: {Set});
816 buildOpDecorate(Reg: VarReg, MIRBuilder, Dec: SPIRV::Decoration::Binding, DecArgs: {Binding});
817 return VarReg;
818}
819
820// TODO: Double check the calls to getOpTypeArray to make sure that `ElemType`
821// is explicitly laid out when required.
822SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
823 SPIRVType *ElemType,
824 MachineIRBuilder &MIRBuilder,
825 bool ExplicitLayoutRequired,
826 bool EmitIR) {
827 assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) &&
828 "Invalid array element type");
829 SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(BitWidth: 32, MIRBuilder);
830 SPIRVType *ArrayType = nullptr;
831 if (NumElems != 0) {
832 Register NumElementsVReg =
833 buildConstantInt(Val: NumElems, MIRBuilder, SpvType: SpvTypeInt32, EmitIR);
834 ArrayType = createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
835 return MIRBuilder.buildInstr(Opcode: SPIRV::OpTypeArray)
836 .addDef(RegNo: createTypeVReg(MIRBuilder))
837 .addUse(RegNo: getSPIRVTypeID(SpirvType: ElemType))
838 .addUse(RegNo: NumElementsVReg);
839 });
840 } else {
841 ArrayType = createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
842 return MIRBuilder.buildInstr(Opcode: SPIRV::OpTypeRuntimeArray)
843 .addDef(RegNo: createTypeVReg(MIRBuilder))
844 .addUse(RegNo: getSPIRVTypeID(SpirvType: ElemType));
845 });
846 }
847
848 if (ExplicitLayoutRequired && !isResourceType(Type: ElemType)) {
849 Type *ET = const_cast<Type *>(getTypeForSPIRVType(Ty: ElemType));
850 addArrayStrideDecorations(Reg: ArrayType->defs().begin()->getReg(), ElementType: ET,
851 MIRBuilder);
852 }
853
854 return ArrayType;
855}
856
857SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,
858 MachineIRBuilder &MIRBuilder) {
859 assert(Ty->hasName());
860 const StringRef Name = Ty->hasName() ? Ty->getName() : "";
861 Register ResVReg = createTypeVReg(MIRBuilder);
862 return createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
863 auto MIB = MIRBuilder.buildInstr(Opcode: SPIRV::OpTypeOpaque).addDef(RegNo: ResVReg);
864 addStringImm(Str: Name, MIB);
865 buildOpName(Target: ResVReg, Name, MIRBuilder);
866 return MIB;
867 });
868}
869
870SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(
871 const StructType *Ty, MachineIRBuilder &MIRBuilder,
872 SPIRV::AccessQualifier::AccessQualifier AccQual,
873 StructOffsetDecorator Decorator, bool EmitIR) {
874 const SPIRVSubtarget &ST =
875 cast<SPIRVSubtarget>(Val: MIRBuilder.getMF().getSubtarget());
876 SmallVector<Register, 4> FieldTypes;
877 constexpr unsigned MaxWordCount = UINT16_MAX;
878 const size_t NumElements = Ty->getNumElements();
879
880 size_t MaxNumElements = MaxWordCount - 2;
881 size_t SPIRVStructNumElements = NumElements;
882 if (NumElements > MaxNumElements) {
883 // Do adjustments for continued instructions.
884 SPIRVStructNumElements = MaxNumElements;
885 MaxNumElements = MaxWordCount - 1;
886 }
887
888 for (const auto &Elem : Ty->elements()) {
889 SPIRVType *ElemTy = findSPIRVType(
890 Ty: toTypedPointer(Ty: Elem), MIRBuilder, accessQual: AccQual,
891 /* ExplicitLayoutRequired= */ Decorator != nullptr, EmitIR);
892 assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
893 "Invalid struct element type");
894 FieldTypes.push_back(Elt: getSPIRVTypeID(SpirvType: ElemTy));
895 }
896 Register ResVReg = createTypeVReg(MIRBuilder);
897 if (Ty->hasName())
898 buildOpName(Target: ResVReg, Name: Ty->getName(), MIRBuilder);
899 if (Ty->isPacked() && !ST.isShader())
900 buildOpDecorate(Reg: ResVReg, MIRBuilder, Dec: SPIRV::Decoration::CPacked, DecArgs: {});
901
902 SPIRVType *SPVType =
903 createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
904 auto MIBStruct =
905 MIRBuilder.buildInstr(Opcode: SPIRV::OpTypeStruct).addDef(RegNo: ResVReg);
906 for (size_t I = 0; I < SPIRVStructNumElements; ++I)
907 MIBStruct.addUse(RegNo: FieldTypes[I]);
908 for (size_t I = SPIRVStructNumElements; I < NumElements;
909 I += MaxNumElements) {
910 auto MIBCont =
911 MIRBuilder.buildInstr(Opcode: SPIRV::OpTypeStructContinuedINTEL);
912 for (size_t J = I; J < std::min(a: I + MaxNumElements, b: NumElements); ++J)
913 MIBCont.addUse(RegNo: FieldTypes[I]);
914 }
915 return MIBStruct;
916 });
917
918 if (Decorator)
919 Decorator(SPVType->defs().begin()->getReg());
920
921 return SPVType;
922}
923
924SPIRVType *SPIRVGlobalRegistry::getOrCreateSpecialType(
925 const Type *Ty, MachineIRBuilder &MIRBuilder,
926 SPIRV::AccessQualifier::AccessQualifier AccQual) {
927 assert(isSpecialOpaqueType(Ty) && "Not a special opaque builtin type");
928 return SPIRV::lowerBuiltinType(Type: Ty, AccessQual: AccQual, MIRBuilder, GR: this);
929}
930
931SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(
932 SPIRV::StorageClass::StorageClass SC, SPIRVType *ElemType,
933 MachineIRBuilder &MIRBuilder, Register Reg) {
934 if (!Reg.isValid())
935 Reg = createTypeVReg(MIRBuilder);
936
937 return createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
938 return MIRBuilder.buildInstr(Opcode: SPIRV::OpTypePointer)
939 .addDef(RegNo: Reg)
940 .addImm(Val: static_cast<uint32_t>(SC))
941 .addUse(RegNo: getSPIRVTypeID(SpirvType: ElemType));
942 });
943}
944
945SPIRVType *SPIRVGlobalRegistry::getOpTypeForwardPointer(
946 SPIRV::StorageClass::StorageClass SC, MachineIRBuilder &MIRBuilder) {
947 return createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
948 return MIRBuilder.buildInstr(Opcode: SPIRV::OpTypeForwardPointer)
949 .addUse(RegNo: createTypeVReg(MIRBuilder))
950 .addImm(Val: static_cast<uint32_t>(SC));
951 });
952}
953
954SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction(
955 SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes,
956 MachineIRBuilder &MIRBuilder) {
957 return createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
958 auto MIB = MIRBuilder.buildInstr(Opcode: SPIRV::OpTypeFunction)
959 .addDef(RegNo: createTypeVReg(MIRBuilder))
960 .addUse(RegNo: getSPIRVTypeID(SpirvType: RetType));
961 for (const SPIRVType *ArgType : ArgTypes)
962 MIB.addUse(RegNo: getSPIRVTypeID(SpirvType: ArgType));
963 return MIB;
964 });
965}
966
967SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
968 const Type *Ty, SPIRVType *RetType,
969 const SmallVectorImpl<SPIRVType *> &ArgTypes,
970 MachineIRBuilder &MIRBuilder) {
971 if (const MachineInstr *MI = findMI(T: Ty, RequiresExplicitLayout: false, MF: &MIRBuilder.getMF()))
972 return MI;
973 const MachineInstr *NewMI = getOpTypeFunction(RetType, ArgTypes, MIRBuilder);
974 add(T: Ty, RequiresExplicitLayout: false, MI: NewMI);
975 return finishCreatingSPIRVType(LLVMTy: Ty, SpirvType: NewMI);
976}
977
978SPIRVType *SPIRVGlobalRegistry::findSPIRVType(
979 const Type *Ty, MachineIRBuilder &MIRBuilder,
980 SPIRV::AccessQualifier::AccessQualifier AccQual,
981 bool ExplicitLayoutRequired, bool EmitIR) {
982 Ty = adjustIntTypeByWidth(Ty);
983 // TODO: findMI needs to know if a layout is required.
984 if (const MachineInstr *MI =
985 findMI(T: Ty, RequiresExplicitLayout: ExplicitLayoutRequired, MF: &MIRBuilder.getMF()))
986 return MI;
987 if (auto It = ForwardPointerTypes.find(Val: Ty); It != ForwardPointerTypes.end())
988 return It->second;
989 return restOfCreateSPIRVType(Type: Ty, MIRBuilder, AccessQual: AccQual, ExplicitLayoutRequired,
990 EmitIR);
991}
992
993Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
994 assert(SpirvType && "Attempting to get type id for nullptr type.");
995 if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer ||
996 SpirvType->getOpcode() == SPIRV::OpTypeStructContinuedINTEL)
997 return SpirvType->uses().begin()->getReg();
998 return SpirvType->defs().begin()->getReg();
999}
1000
1001// We need to use a new LLVM integer type if there is a mismatch between
1002// number of bits in LLVM and SPIRV integer types to let DuplicateTracker
1003// ensure uniqueness of a SPIRV type by the corresponding LLVM type. Without
1004// such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create the
1005// same "OpTypeInt 8" type for a series of LLVM integer types with number of
1006// bits less than 8. This would lead to duplicate type definitions
1007// eventually due to the method that DuplicateTracker utilizes to reason
1008// about uniqueness of type records.
1009const Type *SPIRVGlobalRegistry::adjustIntTypeByWidth(const Type *Ty) const {
1010 if (auto IType = dyn_cast<IntegerType>(Val: Ty)) {
1011 unsigned SrcBitWidth = IType->getBitWidth();
1012 if (SrcBitWidth > 1) {
1013 unsigned BitWidth = adjustOpTypeIntWidth(Width: SrcBitWidth);
1014 // Maybe change source LLVM type to keep DuplicateTracker consistent.
1015 if (SrcBitWidth != BitWidth)
1016 Ty = IntegerType::get(C&: Ty->getContext(), NumBits: BitWidth);
1017 }
1018 }
1019 return Ty;
1020}
1021
1022SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
1023 const Type *Ty, MachineIRBuilder &MIRBuilder,
1024 SPIRV::AccessQualifier::AccessQualifier AccQual,
1025 bool ExplicitLayoutRequired, bool EmitIR) {
1026 if (isSpecialOpaqueType(Ty))
1027 return getOrCreateSpecialType(Ty, MIRBuilder, AccQual);
1028
1029 if (const MachineInstr *MI =
1030 findMI(T: Ty, RequiresExplicitLayout: ExplicitLayoutRequired, MF: &MIRBuilder.getMF()))
1031 return MI;
1032
1033 if (auto IType = dyn_cast<IntegerType>(Val: Ty)) {
1034 const unsigned Width = IType->getBitWidth();
1035 return Width == 1 ? getOpTypeBool(MIRBuilder)
1036 : getOpTypeInt(Width, MIRBuilder, IsSigned: false);
1037 }
1038 if (Ty->isFloatingPointTy())
1039 return getOpTypeFloat(Width: Ty->getPrimitiveSizeInBits(), MIRBuilder);
1040 if (Ty->isVoidTy())
1041 return getOpTypeVoid(MIRBuilder);
1042 if (Ty->isVectorTy()) {
1043 SPIRVType *El =
1044 findSPIRVType(Ty: cast<FixedVectorType>(Val: Ty)->getElementType(), MIRBuilder,
1045 AccQual, ExplicitLayoutRequired, EmitIR);
1046 return getOpTypeVector(NumElems: cast<FixedVectorType>(Val: Ty)->getNumElements(), ElemType: El,
1047 MIRBuilder);
1048 }
1049 if (Ty->isArrayTy()) {
1050 SPIRVType *El = findSPIRVType(Ty: Ty->getArrayElementType(), MIRBuilder,
1051 AccQual, ExplicitLayoutRequired, EmitIR);
1052 return getOpTypeArray(NumElems: Ty->getArrayNumElements(), ElemType: El, MIRBuilder,
1053 ExplicitLayoutRequired, EmitIR);
1054 }
1055 if (auto SType = dyn_cast<StructType>(Val: Ty)) {
1056 if (SType->isOpaque())
1057 return getOpTypeOpaque(Ty: SType, MIRBuilder);
1058
1059 StructOffsetDecorator Decorator = nullptr;
1060 if (ExplicitLayoutRequired) {
1061 Decorator = [&MIRBuilder, SType, this](Register Reg) {
1062 addStructOffsetDecorations(Reg, Ty: const_cast<StructType *>(SType),
1063 MIRBuilder);
1064 };
1065 }
1066 return getOpTypeStruct(Ty: SType, MIRBuilder, AccQual, Decorator, EmitIR);
1067 }
1068 if (auto FType = dyn_cast<FunctionType>(Val: Ty)) {
1069 SPIRVType *RetTy = findSPIRVType(Ty: FType->getReturnType(), MIRBuilder,
1070 AccQual, ExplicitLayoutRequired, EmitIR);
1071 SmallVector<SPIRVType *, 4> ParamTypes;
1072 for (const auto &ParamTy : FType->params())
1073 ParamTypes.push_back(Elt: findSPIRVType(Ty: ParamTy, MIRBuilder, AccQual,
1074 ExplicitLayoutRequired, EmitIR));
1075 return getOpTypeFunction(RetType: RetTy, ArgTypes: ParamTypes, MIRBuilder);
1076 }
1077
1078 unsigned AddrSpace = typeToAddressSpace(Ty);
1079 SPIRVType *SpvElementType = nullptr;
1080 if (Type *ElemTy = ::getPointeeType(Ty))
1081 SpvElementType = getOrCreateSPIRVType(Type: ElemTy, MIRBuilder, AQ: AccQual, EmitIR);
1082 else
1083 SpvElementType = getOrCreateSPIRVIntegerType(BitWidth: 8, MIRBuilder);
1084
1085 // Get access to information about available extensions
1086 const SPIRVSubtarget *ST =
1087 static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
1088 auto SC = addressSpaceToStorageClass(AddrSpace, STI: *ST);
1089
1090 Type *ElemTy = ::getPointeeType(Ty);
1091 if (!ElemTy) {
1092 ElemTy = Type::getInt8Ty(C&: MIRBuilder.getContext());
1093 }
1094
1095 // If we have forward pointer associated with this type, use its register
1096 // operand to create OpTypePointer.
1097 if (auto It = ForwardPointerTypes.find(Val: Ty); It != ForwardPointerTypes.end()) {
1098 Register Reg = getSPIRVTypeID(SpirvType: It->second);
1099 // TODO: what does getOpTypePointer do?
1100 return getOpTypePointer(SC, ElemType: SpvElementType, MIRBuilder, Reg);
1101 }
1102
1103 return getOrCreateSPIRVPointerType(BaseType: ElemTy, MIRBuilder, SC);
1104}
1105
1106SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
1107 const Type *Ty, MachineIRBuilder &MIRBuilder,
1108 SPIRV::AccessQualifier::AccessQualifier AccessQual,
1109 bool ExplicitLayoutRequired, bool EmitIR) {
1110 // TODO: Could this create a problem if one requires an explicit layout, and
1111 // the next time it does not?
1112 if (TypesInProcessing.count(Ptr: Ty) && !isPointerTyOrWrapper(Ty))
1113 return nullptr;
1114 TypesInProcessing.insert(Ptr: Ty);
1115 SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccQual: AccessQual,
1116 ExplicitLayoutRequired, EmitIR);
1117 TypesInProcessing.erase(Ptr: Ty);
1118 VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
1119
1120 // TODO: We could end up with two SPIR-V types pointing to the same llvm type.
1121 // Is that a problem?
1122 SPIRVToLLVMType[SpirvType] = unifyPtrType(Ty);
1123
1124 if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer ||
1125 findMI(T: Ty, RequiresExplicitLayout: false, MF: &MIRBuilder.getMF()) || isSpecialOpaqueType(Ty))
1126 return SpirvType;
1127
1128 if (auto *ExtTy = dyn_cast<TargetExtType>(Val: Ty);
1129 ExtTy && isTypedPointerWrapper(ExtTy))
1130 add(PointeeTy: ExtTy->getTypeParameter(i: 0), AddressSpace: ExtTy->getIntParameter(i: 0), MI: SpirvType);
1131 else if (!isPointerTy(T: Ty))
1132 add(T: Ty, RequiresExplicitLayout: ExplicitLayoutRequired, MI: SpirvType);
1133 else if (isTypedPointerTy(T: Ty))
1134 add(PointeeTy: cast<TypedPointerType>(Val: Ty)->getElementType(),
1135 AddressSpace: getPointerAddressSpace(T: Ty), MI: SpirvType);
1136 else
1137 add(PointeeTy: Type::getInt8Ty(C&: MIRBuilder.getMF().getFunction().getContext()),
1138 AddressSpace: getPointerAddressSpace(T: Ty), MI: SpirvType);
1139 return SpirvType;
1140}
1141
1142SPIRVType *
1143SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
1144 const MachineFunction *MF) const {
1145 auto t = VRegToTypeMap.find(Val: MF ? MF : CurMF);
1146 if (t != VRegToTypeMap.end()) {
1147 auto tt = t->second.find(Val: VReg);
1148 if (tt != t->second.end())
1149 return tt->second;
1150 }
1151 return nullptr;
1152}
1153
1154SPIRVType *SPIRVGlobalRegistry::getResultType(Register VReg,
1155 MachineFunction *MF) {
1156 if (!MF)
1157 MF = CurMF;
1158 MachineInstr *Instr = getVRegDef(MRI&: MF->getRegInfo(), Reg: VReg);
1159 return getSPIRVTypeForVReg(VReg: Instr->getOperand(i: 1).getReg(), MF);
1160}
1161
1162SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
1163 const Type *Ty, MachineIRBuilder &MIRBuilder,
1164 SPIRV::AccessQualifier::AccessQualifier AccessQual,
1165 bool ExplicitLayoutRequired, bool EmitIR) {
1166 const MachineFunction *MF = &MIRBuilder.getMF();
1167 Register Reg;
1168 if (auto *ExtTy = dyn_cast<TargetExtType>(Val: Ty);
1169 ExtTy && isTypedPointerWrapper(ExtTy))
1170 Reg = find(PointeeTy: ExtTy->getTypeParameter(i: 0), AddressSpace: ExtTy->getIntParameter(i: 0), MF);
1171 else if (!isPointerTy(T: Ty))
1172 Reg = find(T: Ty = adjustIntTypeByWidth(Ty), RequiresExplicitLayout: ExplicitLayoutRequired, MF);
1173 else if (isTypedPointerTy(T: Ty))
1174 Reg = find(PointeeTy: cast<TypedPointerType>(Val: Ty)->getElementType(),
1175 AddressSpace: getPointerAddressSpace(T: Ty), MF);
1176 else
1177 Reg = find(PointeeTy: Type::getInt8Ty(C&: MIRBuilder.getMF().getFunction().getContext()),
1178 AddressSpace: getPointerAddressSpace(T: Ty), MF);
1179 if (Reg.isValid() && !isSpecialOpaqueType(Ty))
1180 return getSPIRVTypeForVReg(VReg: Reg);
1181
1182 TypesInProcessing.clear();
1183 SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual,
1184 ExplicitLayoutRequired, EmitIR);
1185 // Create normal pointer types for the corresponding OpTypeForwardPointers.
1186 for (auto &CU : ForwardPointerTypes) {
1187 // Pointer type themselves do not require an explicit layout. The types
1188 // they pointer to might, but that is taken care of when creating the type.
1189 bool PtrNeedsLayout = false;
1190 const Type *Ty2 = CU.first;
1191 SPIRVType *STy2 = CU.second;
1192 if ((Reg = find(T: Ty2, RequiresExplicitLayout: PtrNeedsLayout, MF)).isValid())
1193 STy2 = getSPIRVTypeForVReg(VReg: Reg);
1194 else
1195 STy2 = restOfCreateSPIRVType(Ty: Ty2, MIRBuilder, AccessQual, ExplicitLayoutRequired: PtrNeedsLayout,
1196 EmitIR);
1197 if (Ty == Ty2)
1198 STy = STy2;
1199 }
1200 ForwardPointerTypes.clear();
1201 return STy;
1202}
1203
1204bool SPIRVGlobalRegistry::isScalarOfType(Register VReg,
1205 unsigned TypeOpcode) const {
1206 SPIRVType *Type = getSPIRVTypeForVReg(VReg);
1207 assert(Type && "isScalarOfType VReg has no type assigned");
1208 return Type->getOpcode() == TypeOpcode;
1209}
1210
1211bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
1212 unsigned TypeOpcode) const {
1213 SPIRVType *Type = getSPIRVTypeForVReg(VReg);
1214 assert(Type && "isScalarOrVectorOfType VReg has no type assigned");
1215 if (Type->getOpcode() == TypeOpcode)
1216 return true;
1217 if (Type->getOpcode() == SPIRV::OpTypeVector) {
1218 Register ScalarTypeVReg = Type->getOperand(i: 1).getReg();
1219 SPIRVType *ScalarType = getSPIRVTypeForVReg(VReg: ScalarTypeVReg);
1220 return ScalarType->getOpcode() == TypeOpcode;
1221 }
1222 return false;
1223}
1224
1225bool SPIRVGlobalRegistry::isResourceType(SPIRVType *Type) const {
1226 switch (Type->getOpcode()) {
1227 case SPIRV::OpTypeImage:
1228 case SPIRV::OpTypeSampler:
1229 case SPIRV::OpTypeSampledImage:
1230 return true;
1231 case SPIRV::OpTypeStruct:
1232 return hasBlockDecoration(Type);
1233 default:
1234 return false;
1235 }
1236 return false;
1237}
1238unsigned
1239SPIRVGlobalRegistry::getScalarOrVectorComponentCount(Register VReg) const {
1240 return getScalarOrVectorComponentCount(Type: getSPIRVTypeForVReg(VReg));
1241}
1242
1243unsigned
1244SPIRVGlobalRegistry::getScalarOrVectorComponentCount(SPIRVType *Type) const {
1245 if (!Type)
1246 return 0;
1247 return Type->getOpcode() == SPIRV::OpTypeVector
1248 ? static_cast<unsigned>(Type->getOperand(i: 2).getImm())
1249 : 1;
1250}
1251
1252SPIRVType *
1253SPIRVGlobalRegistry::getScalarOrVectorComponentType(Register VReg) const {
1254 return getScalarOrVectorComponentType(Type: getSPIRVTypeForVReg(VReg));
1255}
1256
1257SPIRVType *
1258SPIRVGlobalRegistry::getScalarOrVectorComponentType(SPIRVType *Type) const {
1259 if (!Type)
1260 return nullptr;
1261 Register ScalarReg = Type->getOpcode() == SPIRV::OpTypeVector
1262 ? Type->getOperand(i: 1).getReg()
1263 : Type->getOperand(i: 0).getReg();
1264 SPIRVType *ScalarType = getSPIRVTypeForVReg(VReg: ScalarReg);
1265 assert(isScalarOrVectorOfType(Type->getOperand(0).getReg(),
1266 ScalarType->getOpcode()));
1267 return ScalarType;
1268}
1269
1270unsigned
1271SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
1272 assert(Type && "Invalid Type pointer");
1273 if (Type->getOpcode() == SPIRV::OpTypeVector) {
1274 auto EleTypeReg = Type->getOperand(i: 1).getReg();
1275 Type = getSPIRVTypeForVReg(VReg: EleTypeReg);
1276 }
1277 if (Type->getOpcode() == SPIRV::OpTypeInt ||
1278 Type->getOpcode() == SPIRV::OpTypeFloat)
1279 return Type->getOperand(i: 1).getImm();
1280 if (Type->getOpcode() == SPIRV::OpTypeBool)
1281 return 1;
1282 llvm_unreachable("Attempting to get bit width of non-integer/float type.");
1283}
1284
1285unsigned SPIRVGlobalRegistry::getNumScalarOrVectorTotalBitWidth(
1286 const SPIRVType *Type) const {
1287 assert(Type && "Invalid Type pointer");
1288 unsigned NumElements = 1;
1289 if (Type->getOpcode() == SPIRV::OpTypeVector) {
1290 NumElements = static_cast<unsigned>(Type->getOperand(i: 2).getImm());
1291 Type = getSPIRVTypeForVReg(VReg: Type->getOperand(i: 1).getReg());
1292 }
1293 return Type->getOpcode() == SPIRV::OpTypeInt ||
1294 Type->getOpcode() == SPIRV::OpTypeFloat
1295 ? NumElements * Type->getOperand(i: 1).getImm()
1296 : 0;
1297}
1298
1299const SPIRVType *SPIRVGlobalRegistry::retrieveScalarOrVectorIntType(
1300 const SPIRVType *Type) const {
1301 if (Type && Type->getOpcode() == SPIRV::OpTypeVector)
1302 Type = getSPIRVTypeForVReg(VReg: Type->getOperand(i: 1).getReg());
1303 return Type && Type->getOpcode() == SPIRV::OpTypeInt ? Type : nullptr;
1304}
1305
1306bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
1307 const SPIRVType *IntType = retrieveScalarOrVectorIntType(Type);
1308 return IntType && IntType->getOperand(i: 2).getImm() != 0;
1309}
1310
1311SPIRVType *SPIRVGlobalRegistry::getPointeeType(SPIRVType *PtrType) {
1312 return PtrType && PtrType->getOpcode() == SPIRV::OpTypePointer
1313 ? getSPIRVTypeForVReg(VReg: PtrType->getOperand(i: 2).getReg())
1314 : nullptr;
1315}
1316
1317unsigned SPIRVGlobalRegistry::getPointeeTypeOp(Register PtrReg) {
1318 SPIRVType *ElemType = getPointeeType(PtrType: getSPIRVTypeForVReg(VReg: PtrReg));
1319 return ElemType ? ElemType->getOpcode() : 0;
1320}
1321
1322bool SPIRVGlobalRegistry::isBitcastCompatible(const SPIRVType *Type1,
1323 const SPIRVType *Type2) const {
1324 if (!Type1 || !Type2)
1325 return false;
1326 auto Op1 = Type1->getOpcode(), Op2 = Type2->getOpcode();
1327 // Ignore difference between <1.5 and >=1.5 protocol versions:
1328 // it's valid if either Result Type or Operand is a pointer, and the other
1329 // is a pointer, an integer scalar, or an integer vector.
1330 if (Op1 == SPIRV::OpTypePointer &&
1331 (Op2 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type: Type2)))
1332 return true;
1333 if (Op2 == SPIRV::OpTypePointer &&
1334 (Op1 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type: Type1)))
1335 return true;
1336 unsigned Bits1 = getNumScalarOrVectorTotalBitWidth(Type: Type1),
1337 Bits2 = getNumScalarOrVectorTotalBitWidth(Type: Type2);
1338 return Bits1 > 0 && Bits1 == Bits2;
1339}
1340
1341SPIRV::StorageClass::StorageClass
1342SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const {
1343 SPIRVType *Type = getSPIRVTypeForVReg(VReg);
1344 assert(Type && Type->getOpcode() == SPIRV::OpTypePointer &&
1345 Type->getOperand(1).isImm() && "Pointer type is expected");
1346 return getPointerStorageClass(Type);
1347}
1348
1349SPIRV::StorageClass::StorageClass
1350SPIRVGlobalRegistry::getPointerStorageClass(const SPIRVType *Type) const {
1351 return static_cast<SPIRV::StorageClass::StorageClass>(
1352 Type->getOperand(i: 1).getImm());
1353}
1354
1355SPIRVType *SPIRVGlobalRegistry::getOrCreateVulkanBufferType(
1356 MachineIRBuilder &MIRBuilder, Type *ElemType,
1357 SPIRV::StorageClass::StorageClass SC, bool IsWritable, bool EmitIr) {
1358 auto Key = SPIRV::irhandle_vkbuffer(ElementType: ElemType, SC, IsWriteable: IsWritable);
1359 if (const MachineInstr *MI = findMI(Handle: Key, MF: &MIRBuilder.getMF()))
1360 return MI;
1361
1362 bool ExplicitLayoutRequired = storageClassRequiresExplictLayout(SC);
1363 // We need to get the SPIR-V type for the element here, so we can add the
1364 // decoration to it.
1365 auto *T = StructType::create(Elements: ElemType);
1366 auto *BlockType =
1367 getOrCreateSPIRVType(Ty: T, MIRBuilder, AccessQual: SPIRV::AccessQualifier::None,
1368 ExplicitLayoutRequired, EmitIR: EmitIr);
1369
1370 buildOpDecorate(Reg: BlockType->defs().begin()->getReg(), MIRBuilder,
1371 Dec: SPIRV::Decoration::Block, DecArgs: {});
1372
1373 if (!IsWritable) {
1374 buildOpMemberDecorate(Reg: BlockType->defs().begin()->getReg(), MIRBuilder,
1375 Dec: SPIRV::Decoration::NonWritable, Member: 0, DecArgs: {});
1376 }
1377
1378 SPIRVType *R = getOrCreateSPIRVPointerTypeInternal(BaseType: BlockType, MIRBuilder, SC);
1379 add(Handle: Key, MI: R);
1380 return R;
1381}
1382
1383SPIRVType *SPIRVGlobalRegistry::getOrCreateLayoutType(
1384 MachineIRBuilder &MIRBuilder, const TargetExtType *T, bool EmitIr) {
1385 auto Key = SPIRV::handle(Ty: T);
1386 if (const MachineInstr *MI = findMI(Handle: Key, MF: &MIRBuilder.getMF()))
1387 return MI;
1388
1389 StructType *ST = cast<StructType>(Val: T->getTypeParameter(i: 0));
1390 ArrayRef<uint32_t> Offsets = T->int_params().slice(N: 1);
1391 assert(ST->getNumElements() == Offsets.size());
1392
1393 StructOffsetDecorator Decorator = [&MIRBuilder, &Offsets](Register Reg) {
1394 for (uint32_t I = 0; I < Offsets.size(); ++I) {
1395 buildOpMemberDecorate(Reg, MIRBuilder, Dec: SPIRV::Decoration::Offset, Member: I,
1396 DecArgs: {Offsets[I]});
1397 }
1398 };
1399
1400 // We need a new OpTypeStruct instruction because decorations will be
1401 // different from a struct with an explicit layout created from a different
1402 // entry point.
1403 SPIRVType *SPIRVStructType = getOpTypeStruct(
1404 Ty: ST, MIRBuilder, AccQual: SPIRV::AccessQualifier::None, Decorator, EmitIR: EmitIr);
1405 add(Handle: Key, MI: SPIRVStructType);
1406 return SPIRVStructType;
1407}
1408
1409SPIRVType *SPIRVGlobalRegistry::getImageType(
1410 const TargetExtType *ExtensionType,
1411 const SPIRV::AccessQualifier::AccessQualifier Qualifier,
1412 MachineIRBuilder &MIRBuilder) {
1413 assert(ExtensionType->getNumTypeParameters() == 1 &&
1414 "SPIR-V image builtin type must have sampled type parameter!");
1415 const SPIRVType *SampledType =
1416 getOrCreateSPIRVType(Type: ExtensionType->getTypeParameter(i: 0), MIRBuilder,
1417 AQ: SPIRV::AccessQualifier::ReadWrite, EmitIR: true);
1418 assert((ExtensionType->getNumIntParameters() == 7 ||
1419 ExtensionType->getNumIntParameters() == 6) &&
1420 "Invalid number of parameters for SPIR-V image builtin!");
1421
1422 SPIRV::AccessQualifier::AccessQualifier accessQualifier =
1423 SPIRV::AccessQualifier::None;
1424 if (ExtensionType->getNumIntParameters() == 7) {
1425 accessQualifier = Qualifier == SPIRV::AccessQualifier::WriteOnly
1426 ? SPIRV::AccessQualifier::WriteOnly
1427 : SPIRV::AccessQualifier::AccessQualifier(
1428 ExtensionType->getIntParameter(i: 6));
1429 }
1430
1431 // Create or get an existing type from GlobalRegistry.
1432 SPIRVType *R = getOrCreateOpTypeImage(
1433 MIRBuilder, SampledType,
1434 Dim: SPIRV::Dim::Dim(ExtensionType->getIntParameter(i: 0)),
1435 Depth: ExtensionType->getIntParameter(i: 1), Arrayed: ExtensionType->getIntParameter(i: 2),
1436 Multisampled: ExtensionType->getIntParameter(i: 3), Sampled: ExtensionType->getIntParameter(i: 4),
1437 ImageFormat: SPIRV::ImageFormat::ImageFormat(ExtensionType->getIntParameter(i: 5)),
1438 AccQual: accessQualifier);
1439 SPIRVToLLVMType[R] = ExtensionType;
1440 return R;
1441}
1442
1443SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
1444 MachineIRBuilder &MIRBuilder, SPIRVType *SampledType, SPIRV::Dim::Dim Dim,
1445 uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled,
1446 SPIRV::ImageFormat::ImageFormat ImageFormat,
1447 SPIRV::AccessQualifier::AccessQualifier AccessQual) {
1448 auto Key = SPIRV::irhandle_image(SampledTy: SPIRVToLLVMType.lookup(Val: SampledType), Dim,
1449 Depth, Arrayed, MS: Multisampled, Sampled,
1450 ImageFormat, AQ: AccessQual);
1451 if (const MachineInstr *MI = findMI(Handle: Key, MF: &MIRBuilder.getMF()))
1452 return MI;
1453 const MachineInstr *NewMI =
1454 createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
1455 auto MIB =
1456 MIRBuilder.buildInstr(Opcode: SPIRV::OpTypeImage)
1457 .addDef(RegNo: createTypeVReg(MIRBuilder))
1458 .addUse(RegNo: getSPIRVTypeID(SpirvType: SampledType))
1459 .addImm(Val: Dim)
1460 .addImm(Val: Depth) // Depth (whether or not it is a Depth image).
1461 .addImm(Val: Arrayed) // Arrayed.
1462 .addImm(Val: Multisampled) // Multisampled (0 = only single-sample).
1463 .addImm(Val: Sampled) // Sampled (0 = usage known at runtime).
1464 .addImm(Val: ImageFormat);
1465 if (AccessQual != SPIRV::AccessQualifier::None)
1466 MIB.addImm(Val: AccessQual);
1467 return MIB;
1468 });
1469 add(Handle: Key, MI: NewMI);
1470 return NewMI;
1471}
1472
1473SPIRVType *
1474SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
1475 auto Key = SPIRV::irhandle_sampler();
1476 const MachineFunction *MF = &MIRBuilder.getMF();
1477 if (const MachineInstr *MI = findMI(Handle: Key, MF))
1478 return MI;
1479 const MachineInstr *NewMI =
1480 createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
1481 return MIRBuilder.buildInstr(Opcode: SPIRV::OpTypeSampler)
1482 .addDef(RegNo: createTypeVReg(MIRBuilder));
1483 });
1484 add(Handle: Key, MI: NewMI);
1485 return NewMI;
1486}
1487
1488SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(
1489 MachineIRBuilder &MIRBuilder,
1490 SPIRV::AccessQualifier::AccessQualifier AccessQual) {
1491 auto Key = SPIRV::irhandle_pipe(AQ: AccessQual);
1492 if (const MachineInstr *MI = findMI(Handle: Key, MF: &MIRBuilder.getMF()))
1493 return MI;
1494 const MachineInstr *NewMI =
1495 createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
1496 return MIRBuilder.buildInstr(Opcode: SPIRV::OpTypePipe)
1497 .addDef(RegNo: createTypeVReg(MIRBuilder))
1498 .addImm(Val: AccessQual);
1499 });
1500 add(Handle: Key, MI: NewMI);
1501 return NewMI;
1502}
1503
1504SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
1505 MachineIRBuilder &MIRBuilder) {
1506 auto Key = SPIRV::irhandle_event();
1507 if (const MachineInstr *MI = findMI(Handle: Key, MF: &MIRBuilder.getMF()))
1508 return MI;
1509 const MachineInstr *NewMI =
1510 createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
1511 return MIRBuilder.buildInstr(Opcode: SPIRV::OpTypeDeviceEvent)
1512 .addDef(RegNo: createTypeVReg(MIRBuilder));
1513 });
1514 add(Handle: Key, MI: NewMI);
1515 return NewMI;
1516}
1517
1518SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
1519 SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) {
1520 auto Key = SPIRV::irhandle_sampled_image(
1521 SampledTy: SPIRVToLLVMType.lookup(Val: MIRBuilder.getMF().getRegInfo().getVRegDef(
1522 Reg: ImageType->getOperand(i: 1).getReg())),
1523 ImageTy: ImageType);
1524 if (const MachineInstr *MI = findMI(Handle: Key, MF: &MIRBuilder.getMF()))
1525 return MI;
1526 const MachineInstr *NewMI =
1527 createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
1528 return MIRBuilder.buildInstr(Opcode: SPIRV::OpTypeSampledImage)
1529 .addDef(RegNo: createTypeVReg(MIRBuilder))
1530 .addUse(RegNo: getSPIRVTypeID(SpirvType: ImageType));
1531 });
1532 add(Handle: Key, MI: NewMI);
1533 return NewMI;
1534}
1535
// Returns the OpTypeCooperativeMatrixKHR registered for ExtensionType (looked
// up without explicit layout), creating it on first use.
// Operands of the emitted type: result id, element type id, then Scope, Rows,
// Columns and Use as 32-bit integer constants (built with buildConstantInt and
// the caller's EmitIR flag), in that order.
// When the element type is a 4-bit integer and the subtarget can use
// SPV_INTEL_int4, an "OpCapability Int4CooperativeMatrixINTEL" instruction is
// built with the same builder immediately before the type instruction.
1536SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
1537 MachineIRBuilder &MIRBuilder, const TargetExtType *ExtensionType,
1538 const SPIRVType *ElemType, uint32_t Scope, uint32_t Rows, uint32_t Columns,
1539 uint32_t Use, bool EmitIR) {
 // Reuse a previously created matrix type for this extension type.
1540 if (const MachineInstr *MI =
1541 findMI(T: ExtensionType, RequiresExplicitLayout: false, MF: &MIRBuilder.getMF()))
1542 return MI;
 // Everything below runs inside the createOpType callback, including the
 // creation of the i32 type used for the constant operands.
1543 const MachineInstr *NewMI =
1544 createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
1545 SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(BitWidth: 32, MIRBuilder);
1546 const Type *ET = getTypeForSPIRVType(Ty: ElemType);
 // int4 element types additionally require the INTEL int4 capability.
1547 if (ET->isIntegerTy() && ET->getIntegerBitWidth() == 4 &&
1548 cast<SPIRVSubtarget>(Val: MIRBuilder.getMF().getSubtarget())
1549 .canUseExtension(E: SPIRV::Extension::SPV_INTEL_int4)) {
1550 MIRBuilder.buildInstr(Opcode: SPIRV::OpCapability)
1551 .addImm(Val: SPIRV::Capability::Int4CooperativeMatrixINTEL);
1552 }
 // The chained call order fixes the operand order; the constants are
 // built while the chain is evaluated.
1553 return MIRBuilder.buildInstr(Opcode: SPIRV::OpTypeCooperativeMatrixKHR)
1554 .addDef(RegNo: createTypeVReg(MIRBuilder))
1555 .addUse(RegNo: getSPIRVTypeID(SpirvType: ElemType))
1556 .addUse(RegNo: buildConstantInt(Val: Scope, MIRBuilder, SpvType: SpvTypeInt32, EmitIR))
1557 .addUse(RegNo: buildConstantInt(Val: Rows, MIRBuilder, SpvType: SpvTypeInt32, EmitIR))
1558 .addUse(RegNo: buildConstantInt(Val: Columns, MIRBuilder, SpvType: SpvTypeInt32, EmitIR))
1559 .addUse(RegNo: buildConstantInt(Val: Use, MIRBuilder, SpvType: SpvTypeInt32, EmitIR));
1560 });
1561 add(T: ExtensionType, RequiresExplicitLayout: false, MI: NewMI);
1562 return NewMI;
1563}
1564
1565SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
1566 const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) {
1567 if (const MachineInstr *MI = findMI(T: Ty, RequiresExplicitLayout: false, MF: &MIRBuilder.getMF()))
1568 return MI;
1569 const MachineInstr *NewMI =
1570 createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
1571 return MIRBuilder.buildInstr(Opcode).addDef(RegNo: createTypeVReg(MIRBuilder));
1572 });
1573 add(T: Ty, RequiresExplicitLayout: false, MI: NewMI);
1574 return NewMI;
1575}
1576
1577SPIRVType *SPIRVGlobalRegistry::getOrCreateUnknownType(
1578 const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode,
1579 const ArrayRef<MCOperand> Operands) {
1580 if (const MachineInstr *MI = findMI(T: Ty, RequiresExplicitLayout: false, MF: &MIRBuilder.getMF()))
1581 return MI;
1582 Register ResVReg = createTypeVReg(MIRBuilder);
1583 const MachineInstr *NewMI =
1584 createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
1585 MachineInstrBuilder MIB = MIRBuilder.buildInstr(Opcode: SPIRV::UNKNOWN_type)
1586 .addDef(RegNo: ResVReg)
1587 .addImm(Val: Opcode);
1588 for (MCOperand Operand : Operands) {
1589 if (Operand.isReg()) {
1590 MIB.addUse(RegNo: Operand.getReg());
1591 } else if (Operand.isImm()) {
1592 MIB.addImm(Val: Operand.getImm());
1593 }
1594 }
1595 return MIB;
1596 });
1597 add(T: Ty, RequiresExplicitLayout: false, MI: NewMI);
1598 return NewMI;
1599}
1600
1601// Returns nullptr if unable to recognize SPIRV type name
// Resolves a textual type name to a SPIR-V type.
//  - Names with a builtin type prefix are turned into a target extension type
//    and resolved through getOrCreateSPIRVType.
//  - Otherwise parseBasicTypeName is applied to TypeStr. It takes TypeStr by
//    reference, and the code below assumes only the unparsed suffix remains.
//    That suffix may then add:
//      "*"                    -> pointer (in SC) to the basic type,
//      "N" or " vector[N]"    -> vector of N elements of the current type
//                                (a vector of pointers after a leading "*"),
//      a trailing "*"         -> pointer (in SC) to the resulting type.
// NOTE(review): the builtin and basic scalar types are always requested with
// EmitIR=true; only the vector step honours the EmitIR argument.
1602SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
1603 StringRef TypeStr, MachineIRBuilder &MIRBuilder, bool EmitIR,
1604 SPIRV::StorageClass::StorageClass SC,
1605 SPIRV::AccessQualifier::AccessQualifier AQ) {
1606 unsigned VecElts = 0;
1607 auto &Ctx = MIRBuilder.getMF().getFunction().getContext();
1608
1609 // Parse strings representing either a SPIR-V or OpenCL builtin type.
1610 if (hasBuiltinTypePrefix(Name: TypeStr))
1611 return getOrCreateSPIRVType(Ty: SPIRV::parseBuiltinTypeNameToTargetExtType(
1612 TypeName: TypeStr.str(), Context&: MIRBuilder.getContext()),
1613 MIRBuilder, AccessQual: AQ, ExplicitLayoutRequired: false, EmitIR: true);
1614
1615 // Parse type name in either "typeN" or "type vector[N]" format, where
1616 // N is the number of elements of the vector.
1617 Type *Ty;
1618
1619 Ty = parseBasicTypeName(TypeName&: TypeStr, Ctx);
1620 if (!Ty)
1621 // Unable to recognize SPIRV type name
1622 return nullptr;
1623
1624 const SPIRVType *SpirvTy =
1625 getOrCreateSPIRVType(Ty, MIRBuilder, AccessQual: AQ, ExplicitLayoutRequired: false, EmitIR: true);
1626
1627 // Handle "type*" or "type* vector[N]".
1628 if (TypeStr.consume_front(Prefix: "*"))
1629 SpirvTy = getOrCreateSPIRVPointerType(BaseType: Ty, MIRBuilder, SC);
1630
1631 // Handle "typeN*" or "type vector[N]*".
1632 bool IsPtrToVec = TypeStr.consume_back(Suffix: "*");
1633
 // Reduce " vector[N]" to "N"; a plain "N" suffix is already in that form.
1634 if (TypeStr.consume_front(Prefix: " vector[")) {
1635 TypeStr = TypeStr.substr(Start: 0, N: TypeStr.find(C: ']'));
1636 }
 // A non-numeric or empty remainder leaves VecElts at 0 (no vector).
1637 TypeStr.getAsInteger(Radix: 10, Result&: VecElts);
1638 if (VecElts > 0)
1639 SpirvTy = getOrCreateSPIRVVectorType(BaseType: SpirvTy, NumElements: VecElts, MIRBuilder, EmitIR);
1640
1641 if (IsPtrToVec)
1642 SpirvTy = getOrCreateSPIRVPointerType(BaseType: SpirvTy, MIRBuilder, SC);
1643
1644 return SpirvTy;
1645}
1646
1647SPIRVType *
1648SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth,
1649 MachineIRBuilder &MIRBuilder) {
1650 return getOrCreateSPIRVType(
1651 Ty: IntegerType::get(C&: MIRBuilder.getMF().getFunction().getContext(), NumBits: BitWidth),
1652 MIRBuilder, AccessQual: SPIRV::AccessQualifier::ReadWrite, ExplicitLayoutRequired: false, EmitIR: true);
1653}
1654
1655SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
1656 SPIRVType *SpirvType) {
1657 assert(CurMF == SpirvType->getMF());
1658 VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
1659 SPIRVToLLVMType[SpirvType] = unifyPtrType(Ty: LLVMTy);
1660 return SpirvType;
1661}
1662
1663SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
1664 MachineInstr &I,
1665 const SPIRVInstrInfo &TII,
1666 unsigned SPIRVOPcode,
1667 Type *Ty) {
1668 if (const MachineInstr *MI = findMI(T: Ty, RequiresExplicitLayout: false, MF: CurMF))
1669 return MI;
1670 MachineBasicBlock &DepMBB = I.getMF()->front();
1671 MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI());
1672 const MachineInstr *NewMI =
1673 createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
1674 return BuildMI(BB&: MIRBuilder.getMBB(), I&: *MIRBuilder.getInsertPt(),
1675 MIMD: MIRBuilder.getDL(), MCID: TII.get(Opcode: SPIRVOPcode))
1676 .addDef(RegNo: createTypeVReg(MRI&: CurMF->getRegInfo()))
1677 .addImm(Val: BitWidth)
1678 .addImm(Val: 0);
1679 });
1680 add(T: Ty, RequiresExplicitLayout: false, MI: NewMI);
1681 return finishCreatingSPIRVType(LLVMTy: Ty, SpirvType: NewMI);
1682}
1683
1684SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
1685 unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
1686 // Maybe adjust bit width to keep DuplicateTracker consistent. Without
1687 // such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create, for
1688 // example, the same "OpTypeInt 8" type for a series of LLVM integer types
1689 // with number of bits less than 8, causing duplicate type definitions.
1690 if (BitWidth > 1)
1691 BitWidth = adjustOpTypeIntWidth(Width: BitWidth);
1692 Type *LLVMTy = IntegerType::get(C&: CurMF->getFunction().getContext(), NumBits: BitWidth);
1693 return getOrCreateSPIRVType(BitWidth, I, TII, SPIRVOPcode: SPIRV::OpTypeInt, Ty: LLVMTy);
1694}
1695
1696SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
1697 unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
1698 LLVMContext &Ctx = CurMF->getFunction().getContext();
1699 Type *LLVMTy;
1700 switch (BitWidth) {
1701 case 16:
1702 LLVMTy = Type::getHalfTy(C&: Ctx);
1703 break;
1704 case 32:
1705 LLVMTy = Type::getFloatTy(C&: Ctx);
1706 break;
1707 case 64:
1708 LLVMTy = Type::getDoubleTy(C&: Ctx);
1709 break;
1710 default:
1711 llvm_unreachable("Bit width is of unexpected size.");
1712 }
1713 return getOrCreateSPIRVType(BitWidth, I, TII, SPIRVOPcode: SPIRV::OpTypeFloat, Ty: LLVMTy);
1714}
1715
1716SPIRVType *
1717SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder,
1718 bool EmitIR) {
1719 return getOrCreateSPIRVType(
1720 Ty: IntegerType::get(C&: MIRBuilder.getMF().getFunction().getContext(), NumBits: 1),
1721 MIRBuilder, AccessQual: SPIRV::AccessQualifier::ReadWrite, ExplicitLayoutRequired: false, EmitIR);
1722}
1723
1724SPIRVType *
1725SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I,
1726 const SPIRVInstrInfo &TII) {
1727 Type *Ty = IntegerType::get(C&: CurMF->getFunction().getContext(), NumBits: 1);
1728 if (const MachineInstr *MI = findMI(T: Ty, RequiresExplicitLayout: false, MF: CurMF))
1729 return MI;
1730 MachineBasicBlock &DepMBB = I.getMF()->front();
1731 MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI());
1732 const MachineInstr *NewMI =
1733 createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
1734 return BuildMI(BB&: MIRBuilder.getMBB(), I&: *MIRBuilder.getInsertPt(),
1735 MIMD: MIRBuilder.getDL(), MCID: TII.get(Opcode: SPIRV::OpTypeBool))
1736 .addDef(RegNo: createTypeVReg(MRI&: CurMF->getRegInfo()));
1737 });
1738 add(T: Ty, RequiresExplicitLayout: false, MI: NewMI);
1739 return finishCreatingSPIRVType(LLVMTy: Ty, SpirvType: NewMI);
1740}
1741
1742SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1743 SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder,
1744 bool EmitIR) {
1745 return getOrCreateSPIRVType(
1746 Ty: FixedVectorType::get(ElementType: const_cast<Type *>(getTypeForSPIRVType(Ty: BaseType)),
1747 NumElts: NumElements),
1748 MIRBuilder, AccessQual: SPIRV::AccessQualifier::ReadWrite, ExplicitLayoutRequired: false, EmitIR);
1749}
1750
1751SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1752 SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
1753 const SPIRVInstrInfo &TII) {
1754 Type *Ty = FixedVectorType::get(
1755 ElementType: const_cast<Type *>(getTypeForSPIRVType(Ty: BaseType)), NumElts: NumElements);
1756 if (const MachineInstr *MI = findMI(T: Ty, RequiresExplicitLayout: false, MF: CurMF))
1757 return MI;
1758 MachineInstr *DepMI = const_cast<MachineInstr *>(BaseType);
1759 MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator());
1760 const MachineInstr *NewMI =
1761 createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
1762 return BuildMI(BB&: MIRBuilder.getMBB(), I&: *MIRBuilder.getInsertPt(),
1763 MIMD: MIRBuilder.getDL(), MCID: TII.get(Opcode: SPIRV::OpTypeVector))
1764 .addDef(RegNo: createTypeVReg(MRI&: CurMF->getRegInfo()))
1765 .addUse(RegNo: getSPIRVTypeID(SpirvType: BaseType))
1766 .addImm(Val: NumElements);
1767 });
1768 add(T: Ty, RequiresExplicitLayout: false, MI: NewMI);
1769 return finishCreatingSPIRVType(LLVMTy: Ty, SpirvType: NewMI);
1770}
1771
1772SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1773 const Type *BaseType, MachineInstr &I,
1774 SPIRV::StorageClass::StorageClass SC) {
1775 MachineIRBuilder MIRBuilder(I);
1776 return getOrCreateSPIRVPointerType(BaseType, MIRBuilder, SC);
1777}
1778
1779SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1780 const Type *BaseType, MachineIRBuilder &MIRBuilder,
1781 SPIRV::StorageClass::StorageClass SC) {
1782 // TODO: Need to check if EmitIr should always be true.
1783 SPIRVType *SpirvBaseType = getOrCreateSPIRVType(
1784 Ty: BaseType, MIRBuilder, AccessQual: SPIRV::AccessQualifier::ReadWrite,
1785 ExplicitLayoutRequired: storageClassRequiresExplictLayout(SC), EmitIR: true);
1786 assert(SpirvBaseType);
1787 return getOrCreateSPIRVPointerTypeInternal(BaseType: SpirvBaseType, MIRBuilder, SC);
1788}
1789
1790SPIRVType *SPIRVGlobalRegistry::changePointerStorageClass(
1791 SPIRVType *PtrType, SPIRV::StorageClass::StorageClass SC, MachineInstr &I) {
1792 [[maybe_unused]] SPIRV::StorageClass::StorageClass OldSC =
1793 getPointerStorageClass(Type: PtrType);
1794 assert(storageClassRequiresExplictLayout(OldSC) ==
1795 storageClassRequiresExplictLayout(SC));
1796
1797 SPIRVType *PointeeType = getPointeeType(PtrType);
1798 MachineIRBuilder MIRBuilder(I);
1799 return getOrCreateSPIRVPointerTypeInternal(BaseType: PointeeType, MIRBuilder, SC);
1800}
1801
1802SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1803 SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
1804 SPIRV::StorageClass::StorageClass SC) {
1805 const Type *LLVMType = getTypeForSPIRVType(Ty: BaseType);
1806 assert(!storageClassRequiresExplictLayout(SC));
1807 SPIRVType *R = getOrCreateSPIRVPointerType(BaseType: LLVMType, MIRBuilder, SC);
1808 assert(
1809 getPointeeType(R) == BaseType &&
1810 "The base type was not correctly laid out for the given storage class.");
1811 return R;
1812}
1813
1814SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerTypeInternal(
1815 SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
1816 SPIRV::StorageClass::StorageClass SC) {
1817 const Type *PointerElementType = getTypeForSPIRVType(Ty: BaseType);
1818 unsigned AddressSpace = storageClassToAddressSpace(SC);
1819 if (const MachineInstr *MI = findMI(PointeeTy: PointerElementType, AddressSpace, MF: CurMF))
1820 return MI;
1821 Type *Ty = TypedPointerType::get(ElementType: const_cast<Type *>(PointerElementType),
1822 AddressSpace);
1823 const MachineInstr *NewMI =
1824 createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
1825 return BuildMI(BB&: MIRBuilder.getMBB(), I: MIRBuilder.getInsertPt(),
1826 MIMD: MIRBuilder.getDebugLoc(),
1827 MCID: MIRBuilder.getTII().get(Opcode: SPIRV::OpTypePointer))
1828 .addDef(RegNo: createTypeVReg(MRI&: CurMF->getRegInfo()))
1829 .addImm(Val: static_cast<uint32_t>(SC))
1830 .addUse(RegNo: getSPIRVTypeID(SpirvType: BaseType));
1831 });
1832 add(PointeeTy: PointerElementType, AddressSpace, MI: NewMI);
1833 return finishCreatingSPIRVType(LLVMTy: Ty, SpirvType: NewMI);
1834}
1835
1836Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
1837 SPIRVType *SpvType,
1838 const SPIRVInstrInfo &TII) {
1839 UndefValue *UV =
1840 UndefValue::get(T: const_cast<Type *>(getTypeForSPIRVType(Ty: SpvType)));
1841 Register Res = find(V: UV, MF: CurMF);
1842 if (Res.isValid())
1843 return Res;
1844
1845 LLT LLTy = LLT::scalar(SizeInBits: 64);
1846 Res = CurMF->getRegInfo().createGenericVirtualRegister(Ty: LLTy);
1847 CurMF->getRegInfo().setRegClass(Reg: Res, RC: &SPIRV::iIDRegClass);
1848 assignSPIRVTypeToVReg(SpirvType: SpvType, VReg: Res, MF: *CurMF);
1849
1850 MachineInstr *DepMI = const_cast<MachineInstr *>(SpvType);
1851 MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator());
1852 const MachineInstr *NewMI =
1853 createOpType(MIRBuilder, Op: [&](MachineIRBuilder &MIRBuilder) {
1854 auto MIB = BuildMI(BB&: MIRBuilder.getMBB(), I&: *MIRBuilder.getInsertPt(),
1855 MIMD: MIRBuilder.getDL(), MCID: TII.get(Opcode: SPIRV::OpUndef))
1856 .addDef(RegNo: Res)
1857 .addUse(RegNo: getSPIRVTypeID(SpirvType: SpvType));
1858 const auto &ST = CurMF->getSubtarget();
1859 constrainSelectedInstRegOperands(I&: *MIB, TII: *ST.getInstrInfo(),
1860 TRI: *ST.getRegisterInfo(),
1861 RBI: *ST.getRegBankInfo());
1862 return MIB;
1863 });
1864 add(V: UV, MI: NewMI);
1865 return Res;
1866}
1867
1868const TargetRegisterClass *
1869SPIRVGlobalRegistry::getRegClass(SPIRVType *SpvType) const {
1870 unsigned Opcode = SpvType->getOpcode();
1871 switch (Opcode) {
1872 case SPIRV::OpTypeFloat:
1873 return &SPIRV::fIDRegClass;
1874 case SPIRV::OpTypePointer:
1875 return &SPIRV::pIDRegClass;
1876 case SPIRV::OpTypeVector: {
1877 SPIRVType *ElemType = getSPIRVTypeForVReg(VReg: SpvType->getOperand(i: 1).getReg());
1878 unsigned ElemOpcode = ElemType ? ElemType->getOpcode() : 0;
1879 if (ElemOpcode == SPIRV::OpTypeFloat)
1880 return &SPIRV::vfIDRegClass;
1881 if (ElemOpcode == SPIRV::OpTypePointer)
1882 return &SPIRV::vpIDRegClass;
1883 return &SPIRV::vIDRegClass;
1884 }
1885 }
1886 return &SPIRV::iIDRegClass;
1887}
1888
1889inline unsigned getAS(SPIRVType *SpvType) {
1890 return storageClassToAddressSpace(
1891 SC: static_cast<SPIRV::StorageClass::StorageClass>(
1892 SpvType->getOperand(i: 1).getImm()));
1893}
1894
1895LLT SPIRVGlobalRegistry::getRegType(SPIRVType *SpvType) const {
1896 unsigned Opcode = SpvType ? SpvType->getOpcode() : 0;
1897 switch (Opcode) {
1898 case SPIRV::OpTypeInt:
1899 case SPIRV::OpTypeFloat:
1900 case SPIRV::OpTypeBool:
1901 return LLT::scalar(SizeInBits: getScalarOrVectorBitWidth(Type: SpvType));
1902 case SPIRV::OpTypePointer:
1903 return LLT::pointer(AddressSpace: getAS(SpvType), SizeInBits: getPointerSize());
1904 case SPIRV::OpTypeVector: {
1905 SPIRVType *ElemType = getSPIRVTypeForVReg(VReg: SpvType->getOperand(i: 1).getReg());
1906 LLT ET;
1907 switch (ElemType ? ElemType->getOpcode() : 0) {
1908 case SPIRV::OpTypePointer:
1909 ET = LLT::pointer(AddressSpace: getAS(SpvType: ElemType), SizeInBits: getPointerSize());
1910 break;
1911 case SPIRV::OpTypeInt:
1912 case SPIRV::OpTypeFloat:
1913 case SPIRV::OpTypeBool:
1914 ET = LLT::scalar(SizeInBits: getScalarOrVectorBitWidth(Type: ElemType));
1915 break;
1916 default:
1917 ET = LLT::scalar(SizeInBits: 64);
1918 }
1919 return LLT::fixed_vector(
1920 NumElements: static_cast<unsigned>(SpvType->getOperand(i: 2).getImm()), ScalarTy: ET);
1921 }
1922 }
1923 return LLT::scalar(SizeInBits: 64);
1924}
1925
1926// Aliasing list MD contains several scope MD nodes within it. Each scope MD
1927// has a self-reference and an extra MD node for the aliasing domain, and it
1928// can contain an optional string operand. Domain MD contains a self-reference
1929// with an optional string operand. Here we unfold the list, creating SPIR-V
1930// aliasing instructions.
1931// TODO: add support for an optional string operand.
//
// Returns the OpAliasScopeListDeclINTEL built for AliasingListMD, or nullptr
// when the list has no operands or a scope node is malformed (fewer than two
// operands, or operand 1 is not an MDNode). List operands that are not MDNodes
// are skipped. Domain, scope and list instructions are all cached in
// AliasInstMDMap by their metadata node, so each is built at most once. An
// early nullptr return may leave domain/scope entries already cached.
// NOTE(review): AliasInstMDMap is keyed only by the metadata node, not by
// MachineFunction. A cached instruction (and its result register) belongs to
// the function where it was first built. Confirm this is intended when the
// same metadata is used from several functions.
1932MachineInstr *SPIRVGlobalRegistry::getOrAddMemAliasingINTELInst(
1933 MachineIRBuilder &MIRBuilder, const MDNode *AliasingListMD) {
1934 if (AliasingListMD->getNumOperands() == 0)
1935 return nullptr;
1936 if (auto L = AliasInstMDMap.find(x: AliasingListMD); L != AliasInstMDMap.end())
1937 return L->second;
1938
1939 SmallVector<MachineInstr *> ScopeList;
1940 MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1941 for (const MDOperand &MDListOp : AliasingListMD->operands()) {
1942 if (MDNode *ScopeMD = dyn_cast<MDNode>(Val: MDListOp)) {
 // A scope node needs at least its self-reference and its domain.
1943 if (ScopeMD->getNumOperands() < 2)
1944 return nullptr;
1945 MDNode *DomainMD = dyn_cast<MDNode>(Val: ScopeMD->getOperand(I: 1));
1946 if (!DomainMD)
1947 return nullptr;
 // Reuse the cached OpAliasDomainDeclINTEL or build a new one.
1948 auto *Domain = [&] {
1949 auto D = AliasInstMDMap.find(x: DomainMD);
1950 if (D != AliasInstMDMap.end())
1951 return D->second;
1952 const Register Ret = MRI->createVirtualRegister(RegClass: &SPIRV::IDRegClass);
1953 auto MIB =
1954 MIRBuilder.buildInstr(Opcode: SPIRV::OpAliasDomainDeclINTEL).addDef(RegNo: Ret);
1955 return MIB.getInstr();
1956 }();
1957 AliasInstMDMap.insert(x: std::make_pair(x&: DomainMD, y&: Domain));
 // Reuse the cached OpAliasScopeDeclINTEL or build one that refers to
 // the domain's result id.
1958 auto *Scope = [&] {
1959 auto S = AliasInstMDMap.find(x: ScopeMD);
1960 if (S != AliasInstMDMap.end())
1961 return S->second;
1962 const Register Ret = MRI->createVirtualRegister(RegClass: &SPIRV::IDRegClass);
1963 auto MIB = MIRBuilder.buildInstr(Opcode: SPIRV::OpAliasScopeDeclINTEL)
1964 .addDef(RegNo: Ret)
1965 .addUse(RegNo: Domain->getOperand(i: 0).getReg());
1966 return MIB.getInstr();
1967 }();
1968 AliasInstMDMap.insert(x: std::make_pair(x&: ScopeMD, y&: Scope));
1969 ScopeList.push_back(Elt: Scope);
1970 }
1971 }
1972
 // The list declaration uses the result id of every scope, in list order.
1973 const Register Ret = MRI->createVirtualRegister(RegClass: &SPIRV::IDRegClass);
1974 auto MIB =
1975 MIRBuilder.buildInstr(Opcode: SPIRV::OpAliasScopeListDeclINTEL).addDef(RegNo: Ret);
1976 for (auto *Scope : ScopeList)
1977 MIB.addUse(RegNo: Scope->getOperand(i: 0).getReg());
1978 auto List = MIB.getInstr();
1979 AliasInstMDMap.insert(x: std::make_pair(x&: AliasingListMD, y&: List));
1980 return List;
1981}
1982
1983void SPIRVGlobalRegistry::buildMemAliasingOpDecorate(
1984 Register Reg, MachineIRBuilder &MIRBuilder, uint32_t Dec,
1985 const MDNode *AliasingListMD) {
1986 MachineInstr *AliasList =
1987 getOrAddMemAliasingINTELInst(MIRBuilder, AliasingListMD);
1988 if (!AliasList)
1989 return;
1990 MIRBuilder.buildInstr(Opcode: SPIRV::OpDecorate)
1991 .addUse(RegNo: Reg)
1992 .addImm(Val: Dec)
1993 .addUse(RegNo: AliasList->getOperand(i: 0).getReg());
1994}
1995void SPIRVGlobalRegistry::replaceAllUsesWith(Value *Old, Value *New,
1996 bool DeleteOld) {
1997 Old->replaceAllUsesWith(V: New);
1998 updateIfExistDeducedElementType(OldVal: Old, NewVal: New, DeleteOld);
1999 updateIfExistAssignPtrTypeInstr(OldVal: Old, NewVal: New, DeleteOld);
2000}
2001
2002void SPIRVGlobalRegistry::buildAssignType(IRBuilder<> &B, Type *Ty,
2003 Value *Arg) {
2004 Value *OfType = getNormalizedPoisonValue(Ty);
2005 CallInst *AssignCI = nullptr;
2006 if (Arg->getType()->isAggregateType() && Ty->isAggregateType() &&
2007 allowEmitFakeUse(Arg)) {
2008 LLVMContext &Ctx = Arg->getContext();
2009 SmallVector<Metadata *, 2> ArgMDs{
2010 MDNode::get(Context&: Ctx, MDs: ValueAsMetadata::getConstant(C: OfType)),
2011 MDString::get(Context&: Ctx, Str: Arg->getName())};
2012 B.CreateIntrinsic(ID: Intrinsic::spv_value_md,
2013 Args: {MetadataAsValue::get(Context&: Ctx, MD: MDTuple::get(Context&: Ctx, MDs: ArgMDs))});
2014 AssignCI = B.CreateIntrinsic(ID: Intrinsic::fake_use, Args: {Arg});
2015 } else {
2016 AssignCI = buildIntrWithMD(IntrID: Intrinsic::spv_assign_type, Types: {Arg->getType()},
2017 Arg: OfType, Arg2: Arg, Imms: {}, B);
2018 }
2019 addAssignPtrTypeInstr(Val: Arg, AssignPtrTyCI: AssignCI);
2020}
2021
2022void SPIRVGlobalRegistry::buildAssignPtr(IRBuilder<> &B, Type *ElemTy,
2023 Value *Arg) {
2024 Value *OfType = PoisonValue::get(T: ElemTy);
2025 CallInst *AssignPtrTyCI = findAssignPtrTypeInstr(Val: Arg);
2026 Function *CurrF =
2027 B.GetInsertBlock() ? B.GetInsertBlock()->getParent() : nullptr;
2028 if (AssignPtrTyCI == nullptr ||
2029 AssignPtrTyCI->getParent()->getParent() != CurrF) {
2030 AssignPtrTyCI = buildIntrWithMD(
2031 IntrID: Intrinsic::spv_assign_ptr_type, Types: {Arg->getType()}, Arg: OfType, Arg2: Arg,
2032 Imms: {B.getInt32(C: getPointerAddressSpace(T: Arg->getType()))}, B);
2033 addDeducedElementType(Val: AssignPtrTyCI, Ty: ElemTy);
2034 addDeducedElementType(Val: Arg, Ty: ElemTy);
2035 addAssignPtrTypeInstr(Val: Arg, AssignPtrTyCI);
2036 } else {
2037 updateAssignType(AssignCI: AssignPtrTyCI, Arg, OfType);
2038 }
2039}
2040
2041void SPIRVGlobalRegistry::updateAssignType(CallInst *AssignCI, Value *Arg,
2042 Value *OfType) {
2043 AssignCI->setArgOperand(i: 1, v: buildMD(Arg: OfType));
2044 if (cast<IntrinsicInst>(Val: AssignCI)->getIntrinsicID() !=
2045 Intrinsic::spv_assign_ptr_type)
2046 return;
2047
2048 // update association with the pointee type
2049 Type *ElemTy = OfType->getType();
2050 addDeducedElementType(Val: AssignCI, Ty: ElemTy);
2051 addDeducedElementType(Val: Arg, Ty: ElemTy);
2052}
2053
2054void SPIRVGlobalRegistry::addStructOffsetDecorations(
2055 Register Reg, StructType *Ty, MachineIRBuilder &MIRBuilder) {
2056 DataLayout DL;
2057 ArrayRef<TypeSize> Offsets = DL.getStructLayout(Ty)->getMemberOffsets();
2058 for (uint32_t I = 0; I < Ty->getNumElements(); ++I) {
2059 buildOpMemberDecorate(Reg, MIRBuilder, Dec: SPIRV::Decoration::Offset, Member: I,
2060 DecArgs: {static_cast<uint32_t>(Offsets[I])});
2061 }
2062}
2063
2064void SPIRVGlobalRegistry::addArrayStrideDecorations(
2065 Register Reg, Type *ElementType, MachineIRBuilder &MIRBuilder) {
2066 uint32_t SizeInBytes = DataLayout().getTypeSizeInBits(Ty: ElementType) / 8;
2067 buildOpDecorate(Reg, MIRBuilder, Dec: SPIRV::Decoration::ArrayStride,
2068 DecArgs: {SizeInBytes});
2069}
2070
2071bool SPIRVGlobalRegistry::hasBlockDecoration(SPIRVType *Type) const {
2072 Register Def = getSPIRVTypeID(SpirvType: Type);
2073 for (const MachineInstr &Use :
2074 Type->getMF()->getRegInfo().use_instructions(Reg: Def)) {
2075 if (Use.getOpcode() != SPIRV::OpDecorate)
2076 continue;
2077
2078 if (Use.getOperand(i: 1).getImm() == SPIRV::Decoration::Block)
2079 return true;
2080 }
2081 return false;
2082}
2083