1//===--- SPIRVCallLowering.cpp - Call lowering ------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the lowering of LLVM calls to machine code calls for
10// GlobalISel.
11//
12//===----------------------------------------------------------------------===//
13
14#include "SPIRVCallLowering.h"
15#include "MCTargetDesc/SPIRVBaseInfo.h"
16#include "SPIRV.h"
17#include "SPIRVBuiltins.h"
18#include "SPIRVGlobalRegistry.h"
19#include "SPIRVISelLowering.h"
20#include "SPIRVMetadata.h"
21#include "SPIRVRegisterInfo.h"
22#include "SPIRVSubtarget.h"
23#include "SPIRVUtils.h"
24#include "llvm/CodeGen/FunctionLoweringInfo.h"
25#include "llvm/IR/IntrinsicInst.h"
26#include "llvm/IR/IntrinsicsSPIRV.h"
27#include "llvm/Support/ModRef.h"
28
29using namespace llvm;
30
31SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI,
32 SPIRVGlobalRegistry *GR)
33 : CallLowering(&TLI), GR(GR) {}
34
35bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
36 const Value *Val, ArrayRef<Register> VRegs,
37 FunctionLoweringInfo &FLI,
38 Register SwiftErrorVReg) const {
39 // Ignore if called from the internal service function
40 if (MIRBuilder.getMF()
41 .getFunction()
42 .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
43 .isValid())
44 return true;
45
46 // Currently all return types should use a single register.
47 // TODO: handle the case of multiple registers.
48 if (VRegs.size() > 1)
49 return false;
50
51 if (Val) {
52 const auto &STI = MIRBuilder.getMF().getSubtarget();
53 MIRBuilder.buildInstr(Opcode: SPIRV::OpReturnValue)
54 .addUse(RegNo: VRegs[0])
55 .constrainAllUses(TII: MIRBuilder.getTII(), TRI: *STI.getRegisterInfo(),
56 RBI: *STI.getRegBankInfo());
57 return true;
58 }
59 MIRBuilder.buildInstr(Opcode: SPIRV::OpReturn);
60 return true;
61}
62
63// Based on the LLVM function attributes, get a SPIR-V FunctionControl.
64static uint32_t getFunctionControl(const Function &F,
65 const SPIRVSubtarget *ST) {
66 MemoryEffects MemEffects = F.getMemoryEffects();
67
68 uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None);
69
70 if (F.hasFnAttribute(Kind: Attribute::AttrKind::NoInline))
71 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline);
72 else if (F.hasFnAttribute(Kind: Attribute::AttrKind::AlwaysInline))
73 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline);
74
75 if (MemEffects.doesNotAccessMemory())
76 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure);
77 else if (MemEffects.onlyReadsMemory())
78 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const);
79
80 if (ST->canUseExtension(E: SPIRV::Extension::SPV_INTEL_optnone) ||
81 ST->canUseExtension(E: SPIRV::Extension::SPV_EXT_optnone))
82 if (F.hasFnAttribute(Kind: Attribute::OptimizeNone))
83 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::OptNoneEXT);
84
85 return FuncControl;
86}
87
88static ConstantInt *getConstInt(MDNode *MD, unsigned NumOp) {
89 if (MD->getNumOperands() > NumOp) {
90 auto *CMeta = dyn_cast<ConstantAsMetadata>(Val: MD->getOperand(I: NumOp));
91 if (CMeta)
92 return dyn_cast<ConstantInt>(Val: CMeta->getValue());
93 }
94 return nullptr;
95}
96
97// If the function has pointer arguments, we are forced to re-create this
98// function type from the very beginning, changing PointerType by
99// TypedPointerType for each pointer argument. Otherwise, the same `Type*`
100// potentially corresponds to different SPIR-V function type, effectively
101// invalidating logic behind global registry and duplicates tracker.
102static FunctionType *
103fixFunctionTypeIfPtrArgs(SPIRVGlobalRegistry *GR, const Function &F,
104 FunctionType *FTy, SPIRVTypeInst SRetTy,
105 const SmallVector<SPIRVTypeInst, 4> &SArgTys) {
106 bool hasArgPtrs = false;
107 for (auto &Arg : F.args()) {
108 // check if it's an instance of a non-typed PointerType
109 if (Arg.getType()->isPointerTy()) {
110 hasArgPtrs = true;
111 break;
112 }
113 }
114 if (!hasArgPtrs) {
115 Type *RetTy = FTy->getReturnType();
116 // check if it's an instance of a non-typed PointerType
117 if (!RetTy->isPointerTy())
118 return FTy;
119 }
120
121 // re-create function type, using TypedPointerType instead of PointerType to
122 // properly trace argument types
123 const Type *RetTy = GR->getTypeForSPIRVType(Ty: SRetTy);
124 SmallVector<Type *, 4> ArgTys;
125 for (auto SArgTy : SArgTys)
126 ArgTys.push_back(Elt: const_cast<Type *>(GR->getTypeForSPIRVType(Ty: SArgTy)));
127 return FunctionType::get(Result: const_cast<Type *>(RetTy), Params: ArgTys, isVarArg: false);
128}
129
130static SPIRV::AccessQualifier::AccessQualifier
131getArgAccessQual(const Function &F, unsigned ArgIdx) {
132 if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
133 return SPIRV::AccessQualifier::ReadWrite;
134
135 MDString *ArgAttribute = getOCLKernelArgAccessQual(F, ArgIdx);
136 if (!ArgAttribute)
137 return SPIRV::AccessQualifier::ReadWrite;
138
139 if (ArgAttribute->getString() == "read_only")
140 return SPIRV::AccessQualifier::ReadOnly;
141 if (ArgAttribute->getString() == "write_only")
142 return SPIRV::AccessQualifier::WriteOnly;
143 return SPIRV::AccessQualifier::ReadWrite;
144}
145
146static std::vector<SPIRV::Decoration::Decoration>
147getKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
148 MDString *ArgAttribute = getOCLKernelArgTypeQual(F, ArgIdx);
149 if (ArgAttribute && ArgAttribute->getString() == "volatile")
150 return {SPIRV::Decoration::Volatile};
151 return {};
152}
153
// Compute the SPIR-V type for argument ArgIdx of F. Candidates are tried in
// this order, and the first that applies is returned:
//   1. The original parameter type is a fixed vector of untyped pointers and
//      the registry has a deduced element type for the argument: a vector of
//      pointers to that element type.
//   2. The original parameter type is not a pointer: that type, created with
//      the argument's access qualifier.
//   3. The argument's own type is a TypedPointerType: a pointer to its element
//      type.
//   4. hasPointeeTypeAttr(Arg) holds: a pointer to getPointeeTypeByAttr(Arg).
//   5. The first user of the argument that is an spv_assign_type intrinsic
//      (builtin TargetExtType) or an spv_assign_ptr_type intrinsic (pointer
//      element type).
//   6. Otherwise: toTypedPointer(OriginalArgType), with the access qualifier.
// Pointer storage classes are derived via addressSpaceToStorageClass.
154static SPIRVTypeInst getArgSPIRVType(const Function &F, unsigned ArgIdx,
155 SPIRVGlobalRegistry *GR,
156 MachineIRBuilder &MIRBuilder,
157 const SPIRVSubtarget &ST) {
158 // Read argument's access qualifier from metadata or default.
159 SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
160 getArgAccessQual(F, ArgIdx);
161
162 Type *OriginalArgType =
163 SPIRV::getOriginalFunctionType(F)->getParamType(i: ArgIdx);
164
165 // Vector of untyped pointers: build with the deduced pointee instead of
166 // the default i8 (mismatches typed uses downstream).
  // If no element type was deduced, fall through to the checks below.
167 Argument *Arg = F.getArg(i: ArgIdx);
168 if (auto *VTy = dyn_cast<FixedVectorType>(Val: OriginalArgType);
169 VTy && isUntypedPointerTy(T: VTy->getElementType()))
170 if (Type *ElemTy = GR->findDeducedElementType(Val: Arg))
171 return GR->getOrCreateSPIRVVectorType(
172 BaseType: GR->getOrCreateSPIRVPointerType(
173 BaseType: ElemTy, MIRBuilder,
174 SC: addressSpaceToStorageClass(
175 AddrSpace: getPointerAddressSpace(T: OriginalArgType), STI: ST)),
176 NumElements: VTy->getNumElements(), MIRBuilder, EmitIR: true);
177
178 // If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot
179 // be legally reassigned later).
180 if (!isPointerTy(T: OriginalArgType))
181 return GR->getOrCreateSPIRVType(Type: OriginalArgType, MIRBuilder, AQ: ArgAccessQual,
182 EmitIR: true);
183
  // From here on, OriginalArgType is a pointer type (per isPointerTy above).
  // The argument's current IR type may already be a TypedPointerType.
184 Type *ArgType = Arg->getType();
185 if (isTypedPointerTy(T: ArgType)) {
186 return GR->getOrCreateSPIRVPointerType(
187 BaseType: cast<TypedPointerType>(Val: ArgType)->getElementType(), MIRBuilder,
188 SC: addressSpaceToStorageClass(AddrSpace: getPointerAddressSpace(T: ArgType), STI: ST));
189 }
190
191 // In case OriginalArgType is of untyped pointer type, there are three
192 // possibilities:
193 // 1) This is a pointer of an LLVM IR element type, passed byval/byref.
194 // 2) This is an OpenCL/SPIR-V builtin type if there is spv_assign_type
195 // intrinsic assigning a TargetExtType.
196 // 3) This is a pointer, try to retrieve pointer element type from a
197 // spv_assign_ptr_type intrinsic or otherwise use default pointer element
198 // type.
199 if (hasPointeeTypeAttr(Arg)) {
200 return GR->getOrCreateSPIRVPointerType(
201 BaseType: getPointeeTypeByAttr(Arg), MIRBuilder,
202 SC: addressSpaceToStorageClass(AddrSpace: getPointerAddressSpace(T: ArgType), STI: ST));
203 }
204
  // Scan the argument's users. The first spv_assign_type or
  // spv_assign_ptr_type intrinsic found decides the type. All other users are
  // skipped.
205 for (auto User : Arg->users()) {
206 auto *II = dyn_cast<IntrinsicInst>(Val: User);
207 // Check if this is spv_assign_type assigning OpenCL/SPIR-V builtin type.
208 if (II && II->getIntrinsicID() == Intrinsic::spv_assign_type) {
209 MetadataAsValue *VMD = cast<MetadataAsValue>(Val: II->getOperand(i_nocapture: 1));
210 Type *BuiltinType =
211 cast<ConstantAsMetadata>(Val: VMD->getMetadata())->getType();
212 assert(BuiltinType->isTargetExtTy() && "Expected TargetExtType");
213 return GR->getOrCreateSPIRVType(Type: BuiltinType, MIRBuilder, AQ: ArgAccessQual,
214 EmitIR: true);
215 }
216
217 // Check if this is spv_assign_ptr_type assigning pointer element type.
218 if (!II || II->getIntrinsicID() != Intrinsic::spv_assign_ptr_type)
219 continue;
220
    // Operand 1 carries the element type as metadata. Operand 2 is the address
    // space used for the storage class; note that it is taken from the
    // intrinsic, not from the argument's own pointer type.
221 MetadataAsValue *VMD = cast<MetadataAsValue>(Val: II->getOperand(i_nocapture: 1));
222 Type *ElementTy =
223 toTypedPointer(Ty: cast<ConstantAsMetadata>(Val: VMD->getMetadata())->getType());
224 return GR->getOrCreateSPIRVPointerType(
225 BaseType: ElementTy, MIRBuilder,
226 SC: addressSpaceToStorageClass(
227 AddrSpace: cast<ConstantInt>(Val: II->getOperand(i_nocapture: 2))->getZExtValue(), STI: ST));
228 }
229
230 // Replace PointerType with TypedPointerType to be able to map SPIR-V types to
231 // LLVM types in a consistent manner
232 return GR->getOrCreateSPIRVType(Type: toTypedPointer(Ty: OriginalArgType), MIRBuilder,
233 AQ: ArgAccessQual, EmitIR: true);
234}
235
236static SPIRV::ExecutionModel::ExecutionModel
237getExecutionModel(const SPIRVSubtarget &STI, const Function &F) {
238 assert(STI.getEnv() != SPIRVSubtarget::Unknown &&
239 "Environment must be resolved before lowering entry points.");
240
241 if (STI.isKernel())
242 return SPIRV::ExecutionModel::Kernel;
243
244 auto attribute = F.getFnAttribute(Kind: "hlsl.shader");
245 if (!attribute.isValid()) {
246 report_fatal_error(
247 reason: "This entry point lacks mandatory hlsl.shader attribute.");
248 }
249
250 const auto value = attribute.getValueAsString();
251 if (value == "compute")
252 return SPIRV::ExecutionModel::GLCompute;
253 if (value == "vertex")
254 return SPIRV::ExecutionModel::Vertex;
255 if (value == "pixel")
256 return SPIRV::ExecutionModel::Fragment;
257
258 report_fatal_error(reason: "This HLSL entry point is not supported by this backend.");
259}
260
261bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
262 const Function &F,
263 ArrayRef<ArrayRef<Register>> VRegs,
264 FunctionLoweringInfo &FLI) const {
265 // Discard the internal service function
266 if (F.getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME).isValid())
267 return true;
268
269 assert(GR && "Must initialize the SPIRV type registry before lowering args.");
270 GR->setCurrentFunc(MIRBuilder.getMF());
271
272 // Get access to information about available extensions
273 const SPIRVSubtarget *ST =
274 static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
275
276 // Assign types and names to all args, and store their types for later.
277 SmallVector<SPIRVTypeInst, 4> ArgTypeVRegs;
278 if (VRegs.size() > 0) {
279 unsigned i = 0;
280 for (const auto &Arg : F.args()) {
281 // Currently formal args should use single registers.
282 // TODO: handle the case of multiple registers.
283 if (VRegs[i].size() > 1)
284 return false;
285 SPIRVTypeInst SpirvTy = getArgSPIRVType(F, ArgIdx: i, GR, MIRBuilder, ST: *ST);
286 GR->assignSPIRVTypeToVReg(Type: SpirvTy, VReg: VRegs[i][0], MF: MIRBuilder.getMF());
287 ArgTypeVRegs.push_back(Elt: SpirvTy);
288
289 if (Arg.hasName())
290 buildOpName(Target: VRegs[i][0], Name: Arg.getName(), MIRBuilder);
291 if (isPointerTyOrWrapper(Ty: Arg.getType())) {
292 auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
293 if (DerefBytes != 0)
294 buildOpDecorate(Reg: VRegs[i][0], MIRBuilder,
295 Dec: SPIRV::Decoration::MaxByteOffset, DecArgs: {DerefBytes});
296 }
297 if (Arg.hasAttribute(Kind: Attribute::Alignment) && !ST->isShader()) {
298 auto Alignment = static_cast<unsigned>(
299 Arg.getAttribute(Kind: Attribute::Alignment).getValueAsInt());
300 buildOpDecorate(Reg: VRegs[i][0], MIRBuilder, Dec: SPIRV::Decoration::Alignment,
301 DecArgs: {Alignment});
302 }
303 if (!ST->isShader()) {
304 if (Arg.hasAttribute(Kind: Attribute::ReadOnly)) {
305 auto Attr =
306 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoWrite);
307 buildOpDecorate(Reg: VRegs[i][0], MIRBuilder,
308 Dec: SPIRV::Decoration::FuncParamAttr, DecArgs: {Attr});
309 }
310 if (Arg.hasAttribute(Kind: Attribute::ZExt)) {
311 auto Attr =
312 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Zext);
313 buildOpDecorate(Reg: VRegs[i][0], MIRBuilder,
314 Dec: SPIRV::Decoration::FuncParamAttr, DecArgs: {Attr});
315 }
316 if (Arg.hasAttribute(Kind: Attribute::SExt)) {
317 auto Attr =
318 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Sext);
319 buildOpDecorate(Reg: VRegs[i][0], MIRBuilder,
320 Dec: SPIRV::Decoration::FuncParamAttr, DecArgs: {Attr});
321 }
322 if (Arg.hasAttribute(Kind: Attribute::NoAlias)) {
323 auto Attr =
324 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoAlias);
325 buildOpDecorate(Reg: VRegs[i][0], MIRBuilder,
326 Dec: SPIRV::Decoration::FuncParamAttr, DecArgs: {Attr});
327 }
328 // TODO: the AMDGPU BE only supports ByRef argument passing, thus for
329 // AMDGCN flavoured SPIRV we CodeGen for ByRef, but lower it to
330 // ByVal, handling the impedance mismatch during reverse
331 // translation from SPIRV to LLVM IR; the vendor check should be
332 // removed once / if SPIRV adds ByRef support.
333 if (Arg.hasAttribute(Kind: Attribute::ByVal) ||
334 (Arg.hasAttribute(Kind: Attribute::ByRef) &&
335 F.getParent()->getTargetTriple().getVendor() ==
336 Triple::VendorType::AMD)) {
337 auto Attr =
338 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::ByVal);
339 buildOpDecorate(Reg: VRegs[i][0], MIRBuilder,
340 Dec: SPIRV::Decoration::FuncParamAttr, DecArgs: {Attr});
341 }
342 if (Arg.hasAttribute(Kind: Attribute::StructRet)) {
343 auto Attr =
344 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Sret);
345 buildOpDecorate(Reg: VRegs[i][0], MIRBuilder,
346 Dec: SPIRV::Decoration::FuncParamAttr, DecArgs: {Attr});
347 }
348 }
349
350 if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
351 std::vector<SPIRV::Decoration::Decoration> ArgTypeQualDecs =
352 getKernelArgTypeQual(F, ArgIdx: i);
353 for (SPIRV::Decoration::Decoration Decoration : ArgTypeQualDecs)
354 buildOpDecorate(Reg: VRegs[i][0], MIRBuilder, Dec: Decoration, DecArgs: {});
355 }
356
357 MDNode *Node = F.getMetadata(Kind: "spirv.ParameterDecorations");
358 if (Node && i < Node->getNumOperands() &&
359 isa<MDNode>(Val: Node->getOperand(I: i))) {
360 MDNode *MD = cast<MDNode>(Val: Node->getOperand(I: i));
361 for (const MDOperand &MDOp : MD->operands()) {
362 MDNode *MD2 = dyn_cast<MDNode>(Val: MDOp);
363 assert(MD2 && "Metadata operand is expected");
364 ConstantInt *Const = getConstInt(MD: MD2, NumOp: 0);
365 assert(Const && "MDOperand should be ConstantInt");
366 auto Dec =
367 static_cast<SPIRV::Decoration::Decoration>(Const->getZExtValue());
368 std::vector<uint32_t> DecVec;
369 for (unsigned j = 1; j < MD2->getNumOperands(); j++) {
370 ConstantInt *Const = getConstInt(MD: MD2, NumOp: j);
371 assert(Const && "MDOperand should be ConstantInt");
372 DecVec.push_back(x: static_cast<uint32_t>(Const->getZExtValue()));
373 }
374 buildOpDecorate(Reg: VRegs[i][0], MIRBuilder, Dec, DecArgs: DecVec);
375 }
376 }
377 ++i;
378 }
379 }
380
381 auto MRI = MIRBuilder.getMRI();
382 Register FuncVReg = MRI->createGenericVirtualRegister(Ty: LLT::scalar(SizeInBits: 64));
383 MRI->setRegClass(Reg: FuncVReg, RC: &SPIRV::iIDRegClass);
384 FunctionType *FTy = SPIRV::getOriginalFunctionType(F);
385 Type *FRetTy = FTy->getReturnType();
386 if (isUntypedPointerTy(T: FRetTy)) {
387 if (Type *FRetElemTy = GR->findDeducedElementType(Val: &F)) {
388 TypedPointerType *DerivedTy = TypedPointerType::get(
389 ElementType: toTypedPointer(Ty: FRetElemTy), AddressSpace: getPointerAddressSpace(T: FRetTy));
390 GR->addReturnType(ArgF: &F, DerivedTy);
391 FRetTy = DerivedTy;
392 }
393 }
394 SPIRVTypeInst RetTy = GR->getOrCreateSPIRVType(
395 Type: FRetTy, MIRBuilder, AQ: SPIRV::AccessQualifier::ReadWrite, EmitIR: true);
396 FTy = fixFunctionTypeIfPtrArgs(GR, F, FTy, SRetTy: RetTy, SArgTys: ArgTypeVRegs);
397 SPIRVTypeInst FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
398 Ty: FTy, RetType: RetTy, ArgTypes: ArgTypeVRegs, MIRBuilder);
399 uint32_t FuncControl = getFunctionControl(F, ST);
400
401 // Add OpFunction instruction
402 MachineInstrBuilder MB = MIRBuilder.buildInstr(Opcode: SPIRV::OpFunction)
403 .addDef(RegNo: FuncVReg)
404 .addUse(RegNo: GR->getSPIRVTypeID(SpirvType: RetTy))
405 .addImm(Val: FuncControl)
406 .addUse(RegNo: GR->getSPIRVTypeID(SpirvType: FuncTy));
407 GR->recordFunctionDefinition(F: &F, MO: &MB.getInstr()->getOperand(i: 0));
408 GR->addGlobalObject(V: &F, MF: &MIRBuilder.getMF(), R: FuncVReg);
409 if (F.isDeclaration())
410 GR->add(V: &F, MI: MB);
411
412 // Add OpFunctionParameter instructions
413 int i = 0;
414 for (const auto &Arg : F.args()) {
415 assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
416 Register ArgReg = VRegs[i][0];
417 MRI->setRegClass(Reg: ArgReg, RC: GR->getRegClass(SpvType: ArgTypeVRegs[i]));
418 auto MIB = MIRBuilder.buildInstr(Opcode: SPIRV::OpFunctionParameter)
419 .addDef(RegNo: ArgReg)
420 .addUse(RegNo: GR->getSPIRVTypeID(SpirvType: ArgTypeVRegs[i]));
421 if (F.isDeclaration())
422 GR->add(V: &Arg, MI: MIB);
423 GR->addGlobalObject(V: &Arg, MF: &MIRBuilder.getMF(), R: ArgReg);
424 i++;
425 }
426 // Name the function.
427 if (F.hasName())
428 buildOpName(Target: FuncVReg, Name: F.getName(), MIRBuilder);
429
430 // Handle entry points and function linkage.
431 if (isEntryPoint(F)) {
432 auto MIB = MIRBuilder.buildInstr(Opcode: SPIRV::OpEntryPoint)
433 .addImm(Val: static_cast<uint32_t>(getExecutionModel(STI: *ST, F)))
434 .addUse(RegNo: FuncVReg);
435 addStringImm(Str: F.getName(), MIB);
436 } else if (const auto LnkTy = getSpirvLinkageTypeFor(ST: *ST, GV: F)) {
437 buildOpDecorate(Reg: FuncVReg, MIRBuilder, Dec: SPIRV::Decoration::LinkageAttributes,
438 DecArgs: {static_cast<uint32_t>(*LnkTy)}, StrImm: F.getName());
439 }
440
441 // Handle function pointers decoration
442 bool hasFunctionPointers =
443 ST->canUseExtension(E: SPIRV::Extension::SPV_INTEL_function_pointers);
444 if (hasFunctionPointers) {
445 if (F.hasFnAttribute(Kind: "referenced-indirectly")) {
446 assert((F.getCallingConv() != CallingConv::SPIR_KERNEL) &&
447 "Unexpected 'referenced-indirectly' attribute of the kernel "
448 "function");
449 buildOpDecorate(Reg: FuncVReg, MIRBuilder,
450 Dec: SPIRV::Decoration::ReferencedIndirectlyINTEL, DecArgs: {});
451 }
452 }
453
454 return true;
455}
456
457// TODO:
458// - add a topological sort of IndirectCalls to ensure the best types knowledge
459// - we may need to fix function formal parameter types if they are opaque
460// pointers used as function pointers in these indirect calls
461// - defaulting to StorageClass::Function in the absence of the
462// SPV_INTEL_function_pointers extension seems wrong, as that might not be
463// able to hold a full width pointer to function, and it also does not model
464// the semantics of a pointer to function in a generic fashion.
465void SPIRVCallLowering::produceIndirectPtrType(
466 MachineIRBuilder &MIRBuilder,
467 const SPIRVCallLowering::SPIRVIndirectCall &IC) const {
468 // Create indirect call data type if any
469 MachineFunction &MF = MIRBuilder.getMF();
470 const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
471 SPIRVTypeInst SpirvRetTy = GR->getOrCreateSPIRVType(
472 Type: IC.RetTy, MIRBuilder, AQ: SPIRV::AccessQualifier::ReadWrite, EmitIR: true);
473 SmallVector<SPIRVTypeInst, 4> SpirvArgTypes;
474 for (size_t i = 0; i < IC.ArgTys.size(); ++i) {
475 SPIRVTypeInst SPIRVTy = GR->getOrCreateSPIRVType(
476 Type: IC.ArgTys[i], MIRBuilder, AQ: SPIRV::AccessQualifier::ReadWrite, EmitIR: true);
477 SpirvArgTypes.push_back(Elt: SPIRVTy);
478 if (!GR->getSPIRVTypeForVReg(VReg: IC.ArgRegs[i]))
479 GR->assignSPIRVTypeToVReg(Type: SPIRVTy, VReg: IC.ArgRegs[i], MF);
480 }
481 // SPIR-V function type:
482 FunctionType *FTy =
483 FunctionType::get(Result: const_cast<Type *>(IC.RetTy), Params: IC.ArgTys, isVarArg: false);
484 SPIRVTypeInst SpirvFuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
485 Ty: FTy, RetType: SpirvRetTy, ArgTypes: SpirvArgTypes, MIRBuilder);
486 // SPIR-V pointer to function type:
487 auto SC = ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_function_pointers)
488 ? SPIRV::StorageClass::CodeSectionINTEL
489 : SPIRV::StorageClass::Function;
490 SPIRVTypeInst IndirectFuncPtrTy =
491 GR->getOrCreateSPIRVPointerType(BaseType: SpirvFuncTy, MIRBuilder, SC);
492 // Correct the Callee type
493 GR->assignSPIRVTypeToVReg(Type: IndirectFuncPtrTy, VReg: IC.Callee, MF);
494}
495
496bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
497 CallLoweringInfo &Info) const {
498 // Currently call returns should have single vregs.
499 // TODO: handle the case of multiple registers.
500 if (Info.OrigRet.Regs.size() > 1)
501 return false;
502 MachineFunction &MF = MIRBuilder.getMF();
503 GR->setCurrentFunc(MF);
504 const Function *CF = nullptr;
505 std::string DemangledName;
506 const Type *OrigRetTy = Info.OrigRet.Ty;
507
508 // Emit a regular OpFunctionCall. If it's an externally declared function,
509 // be sure to emit its type and function declaration here. It will be hoisted
510 // globally later.
511 if (Info.Callee.isGlobal()) {
512 std::string FuncName = Info.Callee.getGlobal()->getName().str();
513 DemangledName = getOclOrSpirvBuiltinDemangledName(Name: FuncName);
514 CF = dyn_cast_or_null<const Function>(Val: Info.Callee.getGlobal());
515 // TODO: support constexpr casts and indirect calls.
516 if (CF == nullptr)
517 return false;
518
519 FunctionType *FTy = SPIRV::getOriginalFunctionType(F: *CF);
520 OrigRetTy = FTy->getReturnType();
521 if (isUntypedPointerTy(T: OrigRetTy)) {
522 if (auto *DerivedRetTy = GR->findReturnType(ArgF: CF))
523 OrigRetTy = DerivedRetTy;
524 }
525 }
526
527 MachineRegisterInfo *MRI = MIRBuilder.getMRI();
528 Register ResVReg =
529 Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
530 const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
531
532 bool isFunctionDecl = CF && CF->isDeclaration();
533 if (isFunctionDecl && !DemangledName.empty()) {
534 if (ResVReg.isValid()) {
535 if (!GR->getSPIRVTypeForVReg(VReg: ResVReg)) {
536 const Type *RetTy = OrigRetTy;
537 if (auto *PtrRetTy = dyn_cast<PointerType>(Val: OrigRetTy)) {
538 const Value *OrigValue = Info.OrigRet.OrigValue;
539 if (!OrigValue)
540 OrigValue = Info.CB;
541 if (OrigValue)
542 if (Type *ElemTy = GR->findDeducedElementType(Val: OrigValue))
543 RetTy =
544 TypedPointerType::get(ElementType: ElemTy, AddressSpace: PtrRetTy->getAddressSpace());
545 }
546 setRegClassType(Reg: ResVReg, Ty: RetTy, GR, MIRBuilder,
547 AccessQual: SPIRV::AccessQualifier::ReadWrite, EmitIR: true);
548 }
549 } else {
550 ResVReg = createVirtualRegister(Ty: OrigRetTy, GR, MIRBuilder,
551 AccessQual: SPIRV::AccessQualifier::ReadWrite, EmitIR: true);
552 }
553 SmallVector<Register, 8> ArgVRegs;
554 for (auto Arg : Info.OrigArgs) {
555 assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
556 Register ArgReg = Arg.Regs[0];
557 ArgVRegs.push_back(Elt: ArgReg);
558 SPIRVTypeInst SpvType = GR->getSPIRVTypeForVReg(VReg: ArgReg);
559 if (!SpvType) {
560 Type *ArgTy = nullptr;
561 if (auto *PtrArgTy = dyn_cast<PointerType>(Val: Arg.Ty)) {
562 // If Arg.Ty is an untyped pointer (i.e., ptr [addrspace(...)]) and we
563 // don't have access to original value in LLVM IR or info about
564 // deduced pointee type, then we should wait with setting the type for
565 // the virtual register until pre-legalizer step when we access
566 // @llvm.spv.assign.ptr.type.p...(...)'s info.
567 if (Arg.OrigValue)
568 if (Type *ElemTy = GR->findDeducedElementType(Val: Arg.OrigValue))
569 ArgTy =
570 TypedPointerType::get(ElementType: ElemTy, AddressSpace: PtrArgTy->getAddressSpace());
571 } else {
572 ArgTy = Arg.Ty;
573 }
574 if (ArgTy) {
575 SpvType = GR->getOrCreateSPIRVType(
576 Type: ArgTy, MIRBuilder, AQ: SPIRV::AccessQualifier::ReadWrite, EmitIR: true);
577 GR->assignSPIRVTypeToVReg(Type: SpvType, VReg: ArgReg, MF);
578 }
579 }
580 if (!MRI->getRegClassOrNull(Reg: ArgReg)) {
581 // Either we have SpvType created, or Arg.Ty is an untyped pointer and
582 // we know its virtual register's class and type even if we don't know
583 // pointee type.
584 MRI->setRegClass(Reg: ArgReg, RC: SpvType ? GR->getRegClass(SpvType)
585 : &SPIRV::pIDRegClass);
586 MRI->setType(
587 VReg: ArgReg,
588 Ty: SpvType ? GR->getRegType(SpvType)
589 : LLT::pointer(AddressSpace: cast<PointerType>(Val: Arg.Ty)->getAddressSpace(),
590 SizeInBits: GR->getPointerSize()));
591 }
592 }
593 if (auto Res = SPIRV::lowerBuiltin(
594 DemangledCall: DemangledName, Set: ST->getPreferredInstructionSet(), MIRBuilder,
595 OrigRet: ResVReg, OrigRetTy, Args: ArgVRegs, GR, CB: *Info.CB))
596 return *Res;
597 }
598
599 if (isFunctionDecl && !GR->find(V: CF, MF: &MF).isValid()) {
600 // Emit the type info and forward function declaration to the first MBB
601 // to ensure VReg definition dependencies are valid across all MBBs.
602 MachineIRBuilder FirstBlockBuilder;
603 FirstBlockBuilder.setMF(MF);
604 FirstBlockBuilder.setMBB(*MF.getBlockNumbered(N: 0));
605
606 SmallVector<ArrayRef<Register>, 8> VRegArgs;
607 SmallVector<SmallVector<Register, 1>, 8> ToInsert;
608 for (const Argument &Arg : CF->args()) {
609 if (MIRBuilder.getDataLayout().getTypeStoreSize(Ty: Arg.getType()).isZero())
610 continue; // Don't handle zero sized types.
611 Register Reg = MRI->createGenericVirtualRegister(Ty: LLT::scalar(SizeInBits: 64));
612 MRI->setRegClass(Reg, RC: &SPIRV::iIDRegClass);
613 ToInsert.push_back(Elt: {Reg});
614 VRegArgs.push_back(Elt: ToInsert.back());
615 }
616 // TODO: Reuse FunctionLoweringInfo
617 FunctionLoweringInfo FuncInfo;
618 lowerFormalArguments(MIRBuilder&: FirstBlockBuilder, F: *CF, VRegs: VRegArgs, FLI&: FuncInfo);
619 }
620
621 // Ignore the call if it's called from the internal service function
622 if (MIRBuilder.getMF()
623 .getFunction()
624 .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
625 .isValid()) {
626 // insert a no-op
627 MIRBuilder.buildTrap();
628 return true;
629 }
630
631 unsigned CallOp;
632 if (Info.CB->isIndirectCall()) {
633 if (!ST->canUseExtension(E: SPIRV::Extension::SPV_INTEL_function_pointers))
634 report_fatal_error(reason: "An indirect call is encountered but SPIR-V without "
635 "extensions does not support it",
636 gen_crash_diag: false);
637 // Set instruction operation according to SPV_INTEL_function_pointers
638 CallOp = SPIRV::OpFunctionPointerCallINTEL;
639 // Collect information about the indirect call to create correct types.
640 Register CalleeReg = Info.Callee.getReg();
641 if (CalleeReg.isValid()) {
642 SPIRVCallLowering::SPIRVIndirectCall IndirectCall;
643 IndirectCall.Callee = CalleeReg;
644 FunctionType *FTy = SPIRV::getOriginalFunctionType(CB: *Info.CB);
645 IndirectCall.RetTy = OrigRetTy = FTy->getReturnType();
646 assert(FTy->getNumParams() == Info.OrigArgs.size() &&
647 "Function types mismatch");
648 for (unsigned I = 0; I != Info.OrigArgs.size(); ++I) {
649 assert(Info.OrigArgs[I].Regs.size() == 1 &&
650 "Call arg has multiple VRegs");
651 IndirectCall.ArgTys.push_back(Elt: FTy->getParamType(i: I));
652 IndirectCall.ArgRegs.push_back(Elt: Info.OrigArgs[I].Regs[0]);
653 }
654 produceIndirectPtrType(MIRBuilder, IC: IndirectCall);
655 }
656 } else {
657 // Emit a regular OpFunctionCall
658 CallOp = SPIRV::OpFunctionCall;
659 }
660
661 // Make sure there's a valid return reg, even for functions returning void.
662 if (!ResVReg.isValid())
663 ResVReg = MIRBuilder.getMRI()->createVirtualRegister(RegClass: &SPIRV::iIDRegClass);
664 SPIRVTypeInst RetType = GR->assignTypeToVReg(
665 Type: OrigRetTy, VReg: ResVReg, MIRBuilder, AQ: SPIRV::AccessQualifier::ReadWrite, EmitIR: true);
666
667 // Emit the call instruction and its args.
668 auto MIB = MIRBuilder.buildInstr(Opcode: CallOp)
669 .addDef(RegNo: ResVReg)
670 .addUse(RegNo: GR->getSPIRVTypeID(SpirvType: RetType))
671 .add(MO: Info.Callee);
672
673 for (const auto &Arg : Info.OrigArgs) {
674 // Currently call args should have single vregs.
675 if (Arg.Regs.size() > 1)
676 return false;
677 MIB.addUse(RegNo: Arg.Regs[0]);
678 }
679
680 if (ST->canUseExtension(E: SPIRV::Extension::SPV_INTEL_memory_access_aliasing)) {
681 // Process aliasing metadata.
682 const CallBase *CI = Info.CB;
683 if (CI && CI->hasMetadata()) {
684 if (MDNode *MD = CI->getMetadata(KindID: LLVMContext::MD_alias_scope))
685 GR->buildMemAliasingOpDecorate(Reg: ResVReg, MIRBuilder,
686 Dec: SPIRV::Decoration::AliasScopeINTEL, GVarMD: MD);
687 if (MDNode *MD = CI->getMetadata(KindID: LLVMContext::MD_noalias))
688 GR->buildMemAliasingOpDecorate(Reg: ResVReg, MIRBuilder,
689 Dec: SPIRV::Decoration::NoAliasINTEL, GVarMD: MD);
690 }
691 }
692
693 MIB.constrainAllUses(TII: MIRBuilder.getTII(), TRI: *ST->getRegisterInfo(),
694 RBI: *ST->getRegBankInfo());
695 return true;
696}
697