1//===--- SPIRVCallLowering.cpp - Call lowering ------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the lowering of LLVM calls to machine code calls for
10// GlobalISel.
11//
12//===----------------------------------------------------------------------===//
13
14#include "SPIRVCallLowering.h"
15#include "MCTargetDesc/SPIRVBaseInfo.h"
16#include "SPIRV.h"
17#include "SPIRVBuiltins.h"
18#include "SPIRVGlobalRegistry.h"
19#include "SPIRVISelLowering.h"
20#include "SPIRVMetadata.h"
21#include "SPIRVRegisterInfo.h"
22#include "SPIRVSubtarget.h"
23#include "SPIRVUtils.h"
24#include "llvm/CodeGen/FunctionLoweringInfo.h"
25#include "llvm/IR/IntrinsicInst.h"
26#include "llvm/IR/IntrinsicsSPIRV.h"
27#include "llvm/Support/ModRef.h"
28
29using namespace llvm;
30
// Construct the call-lowering helper. TLI supplies target lowering queries to
// the generic CallLowering base; GR is the backend-wide registry used to map
// LLVM types/values onto SPIR-V type definitions and vregs.
SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI,
                                     SPIRVGlobalRegistry *GR)
    : CallLowering(&TLI), GR(GR) {}
34
35bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
36 const Value *Val, ArrayRef<Register> VRegs,
37 FunctionLoweringInfo &FLI,
38 Register SwiftErrorVReg) const {
39 // Ignore if called from the internal service function
40 if (MIRBuilder.getMF()
41 .getFunction()
42 .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
43 .isValid())
44 return true;
45
46 // Currently all return types should use a single register.
47 // TODO: handle the case of multiple registers.
48 if (VRegs.size() > 1)
49 return false;
50 if (Val) {
51 const auto &STI = MIRBuilder.getMF().getSubtarget();
52 return MIRBuilder.buildInstr(Opcode: SPIRV::OpReturnValue)
53 .addUse(RegNo: VRegs[0])
54 .constrainAllUses(TII: MIRBuilder.getTII(), TRI: *STI.getRegisterInfo(),
55 RBI: *STI.getRegBankInfo());
56 }
57 MIRBuilder.buildInstr(Opcode: SPIRV::OpReturn);
58 return true;
59}
60
61// Based on the LLVM function attributes, get a SPIR-V FunctionControl.
62static uint32_t getFunctionControl(const Function &F,
63 const SPIRVSubtarget *ST) {
64 MemoryEffects MemEffects = F.getMemoryEffects();
65
66 uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None);
67
68 if (F.hasFnAttribute(Kind: Attribute::AttrKind::NoInline))
69 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline);
70 else if (F.hasFnAttribute(Kind: Attribute::AttrKind::AlwaysInline))
71 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline);
72
73 if (MemEffects.doesNotAccessMemory())
74 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure);
75 else if (MemEffects.onlyReadsMemory())
76 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const);
77
78 if (ST->canUseExtension(E: SPIRV::Extension::SPV_INTEL_optnone) ||
79 ST->canUseExtension(E: SPIRV::Extension::SPV_EXT_optnone))
80 if (F.hasFnAttribute(Kind: Attribute::OptimizeNone))
81 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::OptNoneEXT);
82
83 return FuncControl;
84}
85
86static ConstantInt *getConstInt(MDNode *MD, unsigned NumOp) {
87 if (MD->getNumOperands() > NumOp) {
88 auto *CMeta = dyn_cast<ConstantAsMetadata>(Val: MD->getOperand(I: NumOp));
89 if (CMeta)
90 return dyn_cast<ConstantInt>(Val: CMeta->getValue());
91 }
92 return nullptr;
93}
94
95// If the function has pointer arguments, we are forced to re-create this
96// function type from the very beginning, changing PointerType by
97// TypedPointerType for each pointer argument. Otherwise, the same `Type*`
98// potentially corresponds to different SPIR-V function type, effectively
99// invalidating logic behind global registry and duplicates tracker.
100static FunctionType *
101fixFunctionTypeIfPtrArgs(SPIRVGlobalRegistry *GR, const Function &F,
102 FunctionType *FTy, const SPIRVType *SRetTy,
103 const SmallVector<SPIRVType *, 4> &SArgTys) {
104 bool hasArgPtrs = false;
105 for (auto &Arg : F.args()) {
106 // check if it's an instance of a non-typed PointerType
107 if (Arg.getType()->isPointerTy()) {
108 hasArgPtrs = true;
109 break;
110 }
111 }
112 if (!hasArgPtrs) {
113 Type *RetTy = FTy->getReturnType();
114 // check if it's an instance of a non-typed PointerType
115 if (!RetTy->isPointerTy())
116 return FTy;
117 }
118
119 // re-create function type, using TypedPointerType instead of PointerType to
120 // properly trace argument types
121 const Type *RetTy = GR->getTypeForSPIRVType(Ty: SRetTy);
122 SmallVector<Type *, 4> ArgTys;
123 for (auto SArgTy : SArgTys)
124 ArgTys.push_back(Elt: const_cast<Type *>(GR->getTypeForSPIRVType(Ty: SArgTy)));
125 return FunctionType::get(Result: const_cast<Type *>(RetTy), Params: ArgTys, isVarArg: false);
126}
127
128static SPIRV::AccessQualifier::AccessQualifier
129getArgAccessQual(const Function &F, unsigned ArgIdx) {
130 if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
131 return SPIRV::AccessQualifier::ReadWrite;
132
133 MDString *ArgAttribute = getOCLKernelArgAccessQual(F, ArgIdx);
134 if (!ArgAttribute)
135 return SPIRV::AccessQualifier::ReadWrite;
136
137 if (ArgAttribute->getString() == "read_only")
138 return SPIRV::AccessQualifier::ReadOnly;
139 if (ArgAttribute->getString() == "write_only")
140 return SPIRV::AccessQualifier::WriteOnly;
141 return SPIRV::AccessQualifier::ReadWrite;
142}
143
144static std::vector<SPIRV::Decoration::Decoration>
145getKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
146 MDString *ArgAttribute = getOCLKernelArgTypeQual(F, ArgIdx);
147 if (ArgAttribute && ArgAttribute->getString() == "volatile")
148 return {SPIRV::Decoration::Volatile};
149 return {};
150}
151
// Compute the SPIR-V type for formal argument ArgIdx of F. Non-pointer
// arguments are converted directly; pointer arguments consult, in order:
// the TypedPointerType itself, byval/byref pointee attributes, and the
// spv_assign_type / spv_assign_ptr_type intrinsics attached to the argument's
// users, falling back to a default typed pointer otherwise.
static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
                                  SPIRVGlobalRegistry *GR,
                                  MachineIRBuilder &MIRBuilder,
                                  const SPIRVSubtarget &ST) {
  // Read argument's access qualifier from metadata or default.
  SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
      getArgAccessQual(F, ArgIdx);

  Type *OriginalArgType =
      SPIRV::getOriginalFunctionType(F)->getParamType(i: ArgIdx);

  // If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot
  // be legally reassigned later).
  if (!isPointerTy(T: OriginalArgType))
    return GR->getOrCreateSPIRVType(Type: OriginalArgType, MIRBuilder, AQ: ArgAccessQual,
                                    EmitIR: true);

  // Already a typed pointer: the pointee type is known, translate it directly.
  Argument *Arg = F.getArg(i: ArgIdx);
  Type *ArgType = Arg->getType();
  if (isTypedPointerTy(T: ArgType)) {
    return GR->getOrCreateSPIRVPointerType(
        BaseType: cast<TypedPointerType>(Val: ArgType)->getElementType(), MIRBuilder,
        SC: addressSpaceToStorageClass(AddrSpace: getPointerAddressSpace(T: ArgType), STI: ST));
  }

  // In case OriginalArgType is of untyped pointer type, there are three
  // possibilities:
  // 1) This is a pointer of an LLVM IR element type, passed byval/byref.
  // 2) This is an OpenCL/SPIR-V builtin type if there is spv_assign_type
  //    intrinsic assigning a TargetExtType.
  // 3) This is a pointer, try to retrieve pointer element type from a
  //    spv_assign_ptr_type intrinsic or otherwise use default pointer element
  //    type.
  if (hasPointeeTypeAttr(Arg)) {
    return GR->getOrCreateSPIRVPointerType(
        BaseType: getPointeeTypeByAttr(Arg), MIRBuilder,
        SC: addressSpaceToStorageClass(AddrSpace: getPointerAddressSpace(T: ArgType), STI: ST));
  }

  // Scan the argument's users for type-assignment intrinsics emitted earlier
  // in the pipeline; the first match wins.
  for (auto User : Arg->users()) {
    auto *II = dyn_cast<IntrinsicInst>(Val: User);
    // Check if this is spv_assign_type assigning OpenCL/SPIR-V builtin type.
    if (II && II->getIntrinsicID() == Intrinsic::spv_assign_type) {
      MetadataAsValue *VMD = cast<MetadataAsValue>(Val: II->getOperand(i_nocapture: 1));
      Type *BuiltinType =
          cast<ConstantAsMetadata>(Val: VMD->getMetadata())->getType();
      assert(BuiltinType->isTargetExtTy() && "Expected TargetExtType");
      return GR->getOrCreateSPIRVType(Type: BuiltinType, MIRBuilder, AQ: ArgAccessQual,
                                      EmitIR: true);
    }

    // Check if this is spv_assign_ptr_type assigning pointer element type.
    if (!II || II->getIntrinsicID() != Intrinsic::spv_assign_ptr_type)
      continue;

    // Operand 1 carries the pointee type as metadata; operand 2 the address
    // space used to pick the storage class.
    MetadataAsValue *VMD = cast<MetadataAsValue>(Val: II->getOperand(i_nocapture: 1));
    Type *ElementTy =
        toTypedPointer(Ty: cast<ConstantAsMetadata>(Val: VMD->getMetadata())->getType());
    return GR->getOrCreateSPIRVPointerType(
        BaseType: ElementTy, MIRBuilder,
        SC: addressSpaceToStorageClass(
            AddrSpace: cast<ConstantInt>(Val: II->getOperand(i_nocapture: 2))->getZExtValue(), STI: ST));
  }

  // Replace PointerType with TypedPointerType to be able to map SPIR-V types to
  // LLVM types in a consistent manner
  return GR->getOrCreateSPIRVType(Type: toTypedPointer(Ty: OriginalArgType), MIRBuilder,
                                  AQ: ArgAccessQual, EmitIR: true);
}
221
222static SPIRV::ExecutionModel::ExecutionModel
223getExecutionModel(const SPIRVSubtarget &STI, const Function &F) {
224 if (STI.isKernel())
225 return SPIRV::ExecutionModel::Kernel;
226
227 if (STI.isShader()) {
228 auto attribute = F.getFnAttribute(Kind: "hlsl.shader");
229 if (!attribute.isValid()) {
230 report_fatal_error(
231 reason: "This entry point lacks mandatory hlsl.shader attribute.");
232 }
233
234 const auto value = attribute.getValueAsString();
235 if (value == "compute")
236 return SPIRV::ExecutionModel::GLCompute;
237 if (value == "vertex")
238 return SPIRV::ExecutionModel::Vertex;
239 if (value == "pixel")
240 return SPIRV::ExecutionModel::Fragment;
241
242 report_fatal_error(
243 reason: "This HLSL entry point is not supported by this backend.");
244 }
245
246 assert(STI.getEnv() == SPIRVSubtarget::Unknown);
247 // "hlsl.shader" attribute is mandatory for Vulkan, so we can set Env to
248 // Shader whenever we find it, and to Kernel otherwise.
249
250 // We will now change the Env based on the attribute, so we need to strip
251 // `const` out of the ref to STI.
252 SPIRVSubtarget *NonConstSTI = const_cast<SPIRVSubtarget *>(&STI);
253 auto attribute = F.getFnAttribute(Kind: "hlsl.shader");
254 if (!attribute.isValid()) {
255 NonConstSTI->setEnv(SPIRVSubtarget::Kernel);
256 return SPIRV::ExecutionModel::Kernel;
257 }
258 NonConstSTI->setEnv(SPIRVSubtarget::Shader);
259
260 const auto value = attribute.getValueAsString();
261 if (value == "compute")
262 return SPIRV::ExecutionModel::GLCompute;
263 if (value == "vertex")
264 return SPIRV::ExecutionModel::Vertex;
265 if (value == "pixel")
266 return SPIRV::ExecutionModel::Fragment;
267
268 report_fatal_error(reason: "This HLSL entry point is not supported by this backend.");
269}
270
// Lower the formal arguments of F: assign SPIR-V types and decorations to the
// argument vregs, then emit the OpFunction / OpFunctionParameter instructions
// and, for entry points, OpEntryPoint (or linkage decorations otherwise).
bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
                                             const Function &F,
                                             ArrayRef<ArrayRef<Register>> VRegs,
                                             FunctionLoweringInfo &FLI) const {
  // Discard the internal service function
  if (F.getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME).isValid())
    return true;

  assert(GR && "Must initialize the SPIRV type registry before lowering args.");
  GR->setCurrentFunc(MIRBuilder.getMF());

  // Get access to information about available extensions
  const SPIRVSubtarget *ST =
      static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());

  // Assign types and names to all args, and store their types for later.
  SmallVector<SPIRVType *, 4> ArgTypeVRegs;
  if (VRegs.size() > 0) {
    unsigned i = 0;
    for (const auto &Arg : F.args()) {
      // Currently formal args should use single registers.
      // TODO: handle the case of multiple registers.
      if (VRegs[i].size() > 1)
        return false;
      auto *SpirvTy = getArgSPIRVType(F, ArgIdx: i, GR, MIRBuilder, ST: *ST);
      GR->assignSPIRVTypeToVReg(Type: SpirvTy, VReg: VRegs[i][0], MF: MIRBuilder.getMF());
      ArgTypeVRegs.push_back(Elt: SpirvTy);

      if (Arg.hasName())
        buildOpName(Target: VRegs[i][0], Name: Arg.getName(), MIRBuilder);
      // dereferenceable(N) on a pointer argument becomes MaxByteOffset.
      if (isPointerTyOrWrapper(Ty: Arg.getType())) {
        auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
        if (DerefBytes != 0)
          buildOpDecorate(Reg: VRegs[i][0], MIRBuilder,
                          Dec: SPIRV::Decoration::MaxByteOffset, DecArgs: {DerefBytes});
      }
      // Alignment decorations are only emitted for non-shader (kernel) SPIR-V.
      if (Arg.hasAttribute(Kind: Attribute::Alignment) && !ST->isShader()) {
        auto Alignment = static_cast<unsigned>(
            Arg.getAttribute(Kind: Attribute::Alignment).getValueAsInt());
        buildOpDecorate(Reg: VRegs[i][0], MIRBuilder, Dec: SPIRV::Decoration::Alignment,
                        DecArgs: {Alignment});
      }
      // Map LLVM parameter attributes onto SPIR-V FuncParamAttr decorations.
      if (Arg.hasAttribute(Kind: Attribute::ReadOnly)) {
        auto Attr =
            static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoWrite);
        buildOpDecorate(Reg: VRegs[i][0], MIRBuilder,
                        Dec: SPIRV::Decoration::FuncParamAttr, DecArgs: {Attr});
      }
      if (Arg.hasAttribute(Kind: Attribute::ZExt)) {
        auto Attr =
            static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Zext);
        buildOpDecorate(Reg: VRegs[i][0], MIRBuilder,
                        Dec: SPIRV::Decoration::FuncParamAttr, DecArgs: {Attr});
      }
      if (Arg.hasAttribute(Kind: Attribute::NoAlias)) {
        auto Attr =
            static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoAlias);
        buildOpDecorate(Reg: VRegs[i][0], MIRBuilder,
                        Dec: SPIRV::Decoration::FuncParamAttr, DecArgs: {Attr});
      }
      // TODO: the AMDGPU BE only supports ByRef argument passing, thus for
      //       AMDGCN flavoured SPIRV we CodeGen for ByRef, but lower it to
      //       ByVal, handling the impedance mismatch during reverse
      //       translation from SPIRV to LLVM IR; the vendor check should be
      //       removed once / if SPIRV adds ByRef support.
      if (Arg.hasAttribute(Kind: Attribute::ByVal) ||
          (Arg.hasAttribute(Kind: Attribute::ByRef) &&
           F.getParent()->getTargetTriple().getVendor() ==
               Triple::VendorType::AMD)) {
        auto Attr =
            static_cast<unsigned>(SPIRV::FunctionParameterAttribute::ByVal);
        buildOpDecorate(Reg: VRegs[i][0], MIRBuilder,
                        Dec: SPIRV::Decoration::FuncParamAttr, DecArgs: {Attr});
      }
      if (Arg.hasAttribute(Kind: Attribute::StructRet)) {
        auto Attr =
            static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Sret);
        buildOpDecorate(Reg: VRegs[i][0], MIRBuilder,
                        Dec: SPIRV::Decoration::FuncParamAttr, DecArgs: {Attr});
      }

      // Kernel-only: type qualifiers from metadata (e.g. "volatile").
      if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
        std::vector<SPIRV::Decoration::Decoration> ArgTypeQualDecs =
            getKernelArgTypeQual(F, ArgIdx: i);
        for (SPIRV::Decoration::Decoration Decoration : ArgTypeQualDecs)
          buildOpDecorate(Reg: VRegs[i][0], MIRBuilder, Dec: Decoration, DecArgs: {});
      }

      // Explicit per-parameter decorations carried as function metadata:
      // each operand i is a list of (decoration-id, operand...) tuples.
      MDNode *Node = F.getMetadata(Kind: "spirv.ParameterDecorations");
      if (Node && i < Node->getNumOperands() &&
          isa<MDNode>(Val: Node->getOperand(I: i))) {
        MDNode *MD = cast<MDNode>(Val: Node->getOperand(I: i));
        for (const MDOperand &MDOp : MD->operands()) {
          MDNode *MD2 = dyn_cast<MDNode>(Val: MDOp);
          assert(MD2 && "Metadata operand is expected");
          ConstantInt *Const = getConstInt(MD: MD2, NumOp: 0);
          assert(Const && "MDOperand should be ConstantInt");
          auto Dec =
              static_cast<SPIRV::Decoration::Decoration>(Const->getZExtValue());
          std::vector<uint32_t> DecVec;
          for (unsigned j = 1; j < MD2->getNumOperands(); j++) {
            ConstantInt *Const = getConstInt(MD: MD2, NumOp: j);
            assert(Const && "MDOperand should be ConstantInt");
            DecVec.push_back(x: static_cast<uint32_t>(Const->getZExtValue()));
          }
          buildOpDecorate(Reg: VRegs[i][0], MIRBuilder, Dec, DecArgs: DecVec);
        }
      }
      ++i;
    }
  }

  // Create the vreg that will hold the OpFunction result id.
  auto MRI = MIRBuilder.getMRI();
  Register FuncVReg = MRI->createGenericVirtualRegister(Ty: LLT::scalar(SizeInBits: 64));
  MRI->setRegClass(Reg: FuncVReg, RC: &SPIRV::iIDRegClass);
  FunctionType *FTy = SPIRV::getOriginalFunctionType(F);
  // For an untyped-pointer return, try to recover the pointee type deduced
  // earlier and record the derived TypedPointerType for later call sites.
  Type *FRetTy = FTy->getReturnType();
  if (isUntypedPointerTy(T: FRetTy)) {
    if (Type *FRetElemTy = GR->findDeducedElementType(Val: &F)) {
      TypedPointerType *DerivedTy = TypedPointerType::get(
          ElementType: toTypedPointer(Ty: FRetElemTy), AddressSpace: getPointerAddressSpace(T: FRetTy));
      GR->addReturnType(ArgF: &F, DerivedTy);
      FRetTy = DerivedTy;
    }
  }
  SPIRVType *RetTy = GR->getOrCreateSPIRVType(
      Type: FRetTy, MIRBuilder, AQ: SPIRV::AccessQualifier::ReadWrite, EmitIR: true);
  FTy = fixFunctionTypeIfPtrArgs(GR, F, FTy, SRetTy: RetTy, SArgTys: ArgTypeVRegs);
  SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
      Ty: FTy, RetType: RetTy, ArgTypes: ArgTypeVRegs, MIRBuilder);
  uint32_t FuncControl = getFunctionControl(F, ST);

  // Add OpFunction instruction
  MachineInstrBuilder MB = MIRBuilder.buildInstr(Opcode: SPIRV::OpFunction)
                               .addDef(RegNo: FuncVReg)
                               .addUse(RegNo: GR->getSPIRVTypeID(SpirvType: RetTy))
                               .addImm(Val: FuncControl)
                               .addUse(RegNo: GR->getSPIRVTypeID(SpirvType: FuncTy));
  GR->recordFunctionDefinition(F: &F, MO: &MB.getInstr()->getOperand(i: 0));
  GR->addGlobalObject(V: &F, MF: &MIRBuilder.getMF(), R: FuncVReg);
  if (F.isDeclaration())
    GR->add(V: &F, MI: MB);

  // Add OpFunctionParameter instructions
  int i = 0;
  for (const auto &Arg : F.args()) {
    assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
    Register ArgReg = VRegs[i][0];
    MRI->setRegClass(Reg: ArgReg, RC: GR->getRegClass(SpvType: ArgTypeVRegs[i]));
    MRI->setType(VReg: ArgReg, Ty: GR->getRegType(SpvType: ArgTypeVRegs[i]));
    auto MIB = MIRBuilder.buildInstr(Opcode: SPIRV::OpFunctionParameter)
                   .addDef(RegNo: ArgReg)
                   .addUse(RegNo: GR->getSPIRVTypeID(SpirvType: ArgTypeVRegs[i]));
    if (F.isDeclaration())
      GR->add(V: &Arg, MI: MIB);
    GR->addGlobalObject(V: &Arg, MF: &MIRBuilder.getMF(), R: ArgReg);
    i++;
  }
  // Name the function.
  if (F.hasName())
    buildOpName(Target: FuncVReg, Name: F.getName(), MIRBuilder);

  // Handle entry points and function linkage.
  if (isEntryPoint(F)) {
    // EntryPoints can help us to determine the environment we're working on.
    // Therefore, we need a non-const pointer to SPIRVSubtarget to update the
    // environment if we need to.
    const SPIRVSubtarget *ST =
        static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
    auto MIB = MIRBuilder.buildInstr(Opcode: SPIRV::OpEntryPoint)
                   .addImm(Val: static_cast<uint32_t>(getExecutionModel(STI: *ST, F)))
                   .addUse(RegNo: FuncVReg);
    addStringImm(Str: F.getName(), MIB);
  } else if (const auto LnkTy = getSpirvLinkageTypeFor(ST: *ST, GV: F)) {
    buildOpDecorate(Reg: FuncVReg, MIRBuilder, Dec: SPIRV::Decoration::LinkageAttributes,
                    DecArgs: {static_cast<uint32_t>(*LnkTy)}, StrImm: F.getName());
  }

  // Handle function pointers decoration
  bool hasFunctionPointers =
      ST->canUseExtension(E: SPIRV::Extension::SPV_INTEL_function_pointers);
  if (hasFunctionPointers) {
    if (F.hasFnAttribute(Kind: "referenced-indirectly")) {
      assert((F.getCallingConv() != CallingConv::SPIR_KERNEL) &&
             "Unexpected 'referenced-indirectly' attribute of the kernel "
             "function");
      buildOpDecorate(Reg: FuncVReg, MIRBuilder,
                      Dec: SPIRV::Decoration::ReferencedIndirectlyINTEL, DecArgs: {});
    }
  }

  return true;
}
464
465// TODO:
466// - add a topological sort of IndirectCalls to ensure the best types knowledge
467// - we may need to fix function formal parameter types if they are opaque
468// pointers used as function pointers in these indirect calls
469// - defaulting to StorageClass::Function in the absence of the
470// SPV_INTEL_function_pointers extension seems wrong, as that might not be
471// able to hold a full width pointer to function, and it also does not model
472// the semantics of a pointer to function in a generic fashion.
473void SPIRVCallLowering::produceIndirectPtrType(
474 MachineIRBuilder &MIRBuilder,
475 const SPIRVCallLowering::SPIRVIndirectCall &IC) const {
476 // Create indirect call data type if any
477 MachineFunction &MF = MIRBuilder.getMF();
478 const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
479 SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType(
480 Type: IC.RetTy, MIRBuilder, AQ: SPIRV::AccessQualifier::ReadWrite, EmitIR: true);
481 SmallVector<SPIRVType *, 4> SpirvArgTypes;
482 for (size_t i = 0; i < IC.ArgTys.size(); ++i) {
483 SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(
484 Type: IC.ArgTys[i], MIRBuilder, AQ: SPIRV::AccessQualifier::ReadWrite, EmitIR: true);
485 SpirvArgTypes.push_back(Elt: SPIRVTy);
486 if (!GR->getSPIRVTypeForVReg(VReg: IC.ArgRegs[i]))
487 GR->assignSPIRVTypeToVReg(Type: SPIRVTy, VReg: IC.ArgRegs[i], MF);
488 }
489 // SPIR-V function type:
490 FunctionType *FTy =
491 FunctionType::get(Result: const_cast<Type *>(IC.RetTy), Params: IC.ArgTys, isVarArg: false);
492 SPIRVType *SpirvFuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
493 Ty: FTy, RetType: SpirvRetTy, ArgTypes: SpirvArgTypes, MIRBuilder);
494 // SPIR-V pointer to function type:
495 auto SC = ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_function_pointers)
496 ? SPIRV::StorageClass::CodeSectionINTEL
497 : SPIRV::StorageClass::Function;
498 SPIRVType *IndirectFuncPtrTy =
499 GR->getOrCreateSPIRVPointerType(BaseType: SpirvFuncTy, MIRBuilder, SC);
500 // Correct the Callee type
501 GR->assignSPIRVTypeToVReg(Type: IndirectFuncPtrTy, VReg: IC.Callee, MF);
502}
503
// Lower a call site: handle OpenCL/SPIR-V builtins, emit forward declarations
// for external callees, support indirect calls via SPV_INTEL_function_pointers
// and finally emit the OpFunctionCall (or OpFunctionPointerCallINTEL).
bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
                                  CallLoweringInfo &Info) const {
  // Currently call returns should have single vregs.
  // TODO: handle the case of multiple registers.
  if (Info.OrigRet.Regs.size() > 1)
    return false;
  MachineFunction &MF = MIRBuilder.getMF();
  GR->setCurrentFunc(MF);
  const Function *CF = nullptr;
  std::string DemangledName;
  const Type *OrigRetTy = Info.OrigRet.Ty;

  // Emit a regular OpFunctionCall. If it's an externally declared function,
  // be sure to emit its type and function declaration here. It will be hoisted
  // globally later.
  if (Info.Callee.isGlobal()) {
    std::string FuncName = Info.Callee.getGlobal()->getName().str();
    DemangledName = getOclOrSpirvBuiltinDemangledName(Name: FuncName);
    CF = dyn_cast_or_null<const Function>(Val: Info.Callee.getGlobal());
    // TODO: support constexpr casts and indirect calls.
    if (CF == nullptr)
      return false;

    // Prefer the return type recorded during formal-argument lowering when the
    // declared return type is an untyped pointer.
    FunctionType *FTy = SPIRV::getOriginalFunctionType(F: *CF);
    OrigRetTy = FTy->getReturnType();
    if (isUntypedPointerTy(T: OrigRetTy)) {
      if (auto *DerivedRetTy = GR->findReturnType(ArgF: CF))
        OrigRetTy = DerivedRetTy;
    }
  }

  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
  Register ResVReg =
      Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
  const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());

  // A demangled declaration may be a builtin; prepare result/argument vregs
  // and try builtin lowering first.
  bool isFunctionDecl = CF && CF->isDeclaration();
  if (isFunctionDecl && !DemangledName.empty()) {
    if (ResVReg.isValid()) {
      if (!GR->getSPIRVTypeForVReg(VReg: ResVReg)) {
        // Refine an untyped-pointer return with a previously deduced pointee
        // type when available.
        const Type *RetTy = OrigRetTy;
        if (auto *PtrRetTy = dyn_cast<PointerType>(Val: OrigRetTy)) {
          const Value *OrigValue = Info.OrigRet.OrigValue;
          if (!OrigValue)
            OrigValue = Info.CB;
          if (OrigValue)
            if (Type *ElemTy = GR->findDeducedElementType(Val: OrigValue))
              RetTy =
                  TypedPointerType::get(ElementType: ElemTy, AddressSpace: PtrRetTy->getAddressSpace());
        }
        setRegClassType(Reg: ResVReg, Ty: RetTy, GR, MIRBuilder,
                        AccessQual: SPIRV::AccessQualifier::ReadWrite, EmitIR: true);
      }
    } else {
      ResVReg = createVirtualRegister(Ty: OrigRetTy, GR, MIRBuilder,
                                      AccessQual: SPIRV::AccessQualifier::ReadWrite, EmitIR: true);
    }
    SmallVector<Register, 8> ArgVRegs;
    for (auto Arg : Info.OrigArgs) {
      assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
      Register ArgReg = Arg.Regs[0];
      ArgVRegs.push_back(Elt: ArgReg);
      SPIRVType *SpvType = GR->getSPIRVTypeForVReg(VReg: ArgReg);
      if (!SpvType) {
        Type *ArgTy = nullptr;
        if (auto *PtrArgTy = dyn_cast<PointerType>(Val: Arg.Ty)) {
          // If Arg.Ty is an untyped pointer (i.e., ptr [addrspace(...)]) and we
          // don't have access to original value in LLVM IR or info about
          // deduced pointee type, then we should wait with setting the type for
          // the virtual register until pre-legalizer step when we access
          // @llvm.spv.assign.ptr.type.p...(...)'s info.
          if (Arg.OrigValue)
            if (Type *ElemTy = GR->findDeducedElementType(Val: Arg.OrigValue))
              ArgTy =
                  TypedPointerType::get(ElementType: ElemTy, AddressSpace: PtrArgTy->getAddressSpace());
        } else {
          ArgTy = Arg.Ty;
        }
        if (ArgTy) {
          SpvType = GR->getOrCreateSPIRVType(
              Type: ArgTy, MIRBuilder, AQ: SPIRV::AccessQualifier::ReadWrite, EmitIR: true);
          GR->assignSPIRVTypeToVReg(Type: SpvType, VReg: ArgReg, MF);
        }
      }
      if (!MRI->getRegClassOrNull(Reg: ArgReg)) {
        // Either we have SpvType created, or Arg.Ty is an untyped pointer and
        // we know its virtual register's class and type even if we don't know
        // pointee type.
        MRI->setRegClass(Reg: ArgReg, RC: SpvType ? GR->getRegClass(SpvType)
                                             : &SPIRV::pIDRegClass);
        MRI->setType(
            VReg: ArgReg,
            Ty: SpvType ? GR->getRegType(SpvType)
                      : LLT::pointer(AddressSpace: cast<PointerType>(Val: Arg.Ty)->getAddressSpace(),
                                     SizeInBits: GR->getPointerSize()));
      }
    }
    // If builtin lowering recognizes the call, we are done.
    if (auto Res = SPIRV::lowerBuiltin(
            DemangledCall: DemangledName, Set: ST->getPreferredInstructionSet(), MIRBuilder,
            OrigRet: ResVReg, OrigRetTy, Args: ArgVRegs, GR, CB: *Info.CB))
      return *Res;
  }

  // Not (or not recognized as) a builtin: make sure the external callee has a
  // forward declaration in this machine function.
  if (isFunctionDecl && !GR->find(V: CF, MF: &MF).isValid()) {
    // Emit the type info and forward function declaration to the first MBB
    // to ensure VReg definition dependencies are valid across all MBBs.
    MachineIRBuilder FirstBlockBuilder;
    FirstBlockBuilder.setMF(MF);
    FirstBlockBuilder.setMBB(*MF.getBlockNumbered(N: 0));

    SmallVector<ArrayRef<Register>, 8> VRegArgs;
    SmallVector<SmallVector<Register, 1>, 8> ToInsert;
    for (const Argument &Arg : CF->args()) {
      if (MIRBuilder.getDataLayout().getTypeStoreSize(Ty: Arg.getType()).isZero())
        continue; // Don't handle zero sized types.
      Register Reg = MRI->createGenericVirtualRegister(Ty: LLT::scalar(SizeInBits: 64));
      MRI->setRegClass(Reg, RC: &SPIRV::iIDRegClass);
      ToInsert.push_back(Elt: {Reg});
      VRegArgs.push_back(Elt: ToInsert.back());
    }
    // TODO: Reuse FunctionLoweringInfo
    FunctionLoweringInfo FuncInfo;
    lowerFormalArguments(MIRBuilder&: FirstBlockBuilder, F: *CF, VRegs: VRegArgs, FLI&: FuncInfo);
  }

  // Ignore the call if it's called from the internal service function
  if (MIRBuilder.getMF()
          .getFunction()
          .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
          .isValid()) {
    // insert a no-op
    MIRBuilder.buildTrap();
    return true;
  }

  // Choose the call opcode; indirect calls require the INTEL function-pointer
  // extension and extra type bookkeeping for the callee register.
  unsigned CallOp;
  if (Info.CB->isIndirectCall()) {
    if (!ST->canUseExtension(E: SPIRV::Extension::SPV_INTEL_function_pointers))
      report_fatal_error(reason: "An indirect call is encountered but SPIR-V without "
                                 "extensions does not support it",
                         gen_crash_diag: false);
    // Set instruction operation according to SPV_INTEL_function_pointers
    CallOp = SPIRV::OpFunctionPointerCallINTEL;
    // Collect information about the indirect call to create correct types.
    Register CalleeReg = Info.Callee.getReg();
    if (CalleeReg.isValid()) {
      SPIRVCallLowering::SPIRVIndirectCall IndirectCall;
      IndirectCall.Callee = CalleeReg;
      FunctionType *FTy = SPIRV::getOriginalFunctionType(CB: *Info.CB);
      IndirectCall.RetTy = OrigRetTy = FTy->getReturnType();
      assert(FTy->getNumParams() == Info.OrigArgs.size() &&
             "Function types mismatch");
      for (unsigned I = 0; I != Info.OrigArgs.size(); ++I) {
        assert(Info.OrigArgs[I].Regs.size() == 1 &&
               "Call arg has multiple VRegs");
        IndirectCall.ArgTys.push_back(Elt: FTy->getParamType(i: I));
        IndirectCall.ArgRegs.push_back(Elt: Info.OrigArgs[I].Regs[0]);
      }
      produceIndirectPtrType(MIRBuilder, IC: IndirectCall);
    }
  } else {
    // Emit a regular OpFunctionCall
    CallOp = SPIRV::OpFunctionCall;
  }

  // Make sure there's a valid return reg, even for functions returning void.
  if (!ResVReg.isValid())
    ResVReg = MIRBuilder.getMRI()->createVirtualRegister(RegClass: &SPIRV::iIDRegClass);
  SPIRVType *RetType = GR->assignTypeToVReg(
      Type: OrigRetTy, VReg: ResVReg, MIRBuilder, AQ: SPIRV::AccessQualifier::ReadWrite, EmitIR: true);

  // Emit the call instruction and its args.
  auto MIB = MIRBuilder.buildInstr(Opcode: CallOp)
                 .addDef(RegNo: ResVReg)
                 .addUse(RegNo: GR->getSPIRVTypeID(SpirvType: RetType))
                 .add(MO: Info.Callee);

  for (const auto &Arg : Info.OrigArgs) {
    // Currently call args should have single vregs.
    if (Arg.Regs.size() > 1)
      return false;
    MIB.addUse(RegNo: Arg.Regs[0]);
  }

  // Translate alias.scope/noalias metadata into INTEL aliasing decorations.
  if (ST->canUseExtension(E: SPIRV::Extension::SPV_INTEL_memory_access_aliasing)) {
    // Process aliasing metadata.
    const CallBase *CI = Info.CB;
    if (CI && CI->hasMetadata()) {
      if (MDNode *MD = CI->getMetadata(KindID: LLVMContext::MD_alias_scope))
        GR->buildMemAliasingOpDecorate(Reg: ResVReg, MIRBuilder,
                                       Dec: SPIRV::Decoration::AliasScopeINTEL, GVarMD: MD);
      if (MDNode *MD = CI->getMetadata(KindID: LLVMContext::MD_noalias))
        GR->buildMemAliasingOpDecorate(Reg: ResVReg, MIRBuilder,
                                       Dec: SPIRV::Decoration::NoAliasINTEL, GVarMD: MD);
    }
  }

  return MIB.constrainAllUses(TII: MIRBuilder.getTII(), TRI: *ST->getRegisterInfo(),
                              RBI: *ST->getRegBankInfo());
}
704