1//===--- SPIRVCallLowering.cpp - Call lowering ------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the lowering of LLVM calls to machine code calls for
10// GlobalISel.
11//
12//===----------------------------------------------------------------------===//
13
14#include "SPIRVCallLowering.h"
15#include "MCTargetDesc/SPIRVBaseInfo.h"
16#include "SPIRV.h"
17#include "SPIRVBuiltins.h"
18#include "SPIRVGlobalRegistry.h"
19#include "SPIRVISelLowering.h"
20#include "SPIRVMetadata.h"
21#include "SPIRVRegisterInfo.h"
22#include "SPIRVSubtarget.h"
23#include "SPIRVUtils.h"
24#include "llvm/CodeGen/FunctionLoweringInfo.h"
25#include "llvm/IR/IntrinsicInst.h"
26#include "llvm/IR/IntrinsicsSPIRV.h"
27#include "llvm/Support/ModRef.h"
28
29using namespace llvm;
30
31SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI,
32 SPIRVGlobalRegistry *GR)
33 : CallLowering(&TLI), GR(GR) {}
34
35bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
36 const Value *Val, ArrayRef<Register> VRegs,
37 FunctionLoweringInfo &FLI,
38 Register SwiftErrorVReg) const {
39 // Ignore if called from the internal service function
40 if (MIRBuilder.getMF()
41 .getFunction()
42 .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
43 .isValid())
44 return true;
45
46 // Maybe run postponed production of types for function pointers
47 if (IndirectCalls.size() > 0) {
48 produceIndirectPtrTypes(MIRBuilder);
49 IndirectCalls.clear();
50 }
51
52 // Currently all return types should use a single register.
53 // TODO: handle the case of multiple registers.
54 if (VRegs.size() > 1)
55 return false;
56 if (Val) {
57 const auto &STI = MIRBuilder.getMF().getSubtarget();
58 return MIRBuilder.buildInstr(Opcode: SPIRV::OpReturnValue)
59 .addUse(RegNo: VRegs[0])
60 .constrainAllUses(TII: MIRBuilder.getTII(), TRI: *STI.getRegisterInfo(),
61 RBI: *STI.getRegBankInfo());
62 }
63 MIRBuilder.buildInstr(Opcode: SPIRV::OpReturn);
64 return true;
65}
66
67// Based on the LLVM function attributes, get a SPIR-V FunctionControl.
68static uint32_t getFunctionControl(const Function &F,
69 const SPIRVSubtarget *ST) {
70 MemoryEffects MemEffects = F.getMemoryEffects();
71
72 uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None);
73
74 if (F.hasFnAttribute(Kind: Attribute::AttrKind::NoInline))
75 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline);
76 else if (F.hasFnAttribute(Kind: Attribute::AttrKind::AlwaysInline))
77 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline);
78
79 if (MemEffects.doesNotAccessMemory())
80 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure);
81 else if (MemEffects.onlyReadsMemory())
82 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const);
83
84 if (ST->canUseExtension(E: SPIRV::Extension::SPV_INTEL_optnone) ||
85 ST->canUseExtension(E: SPIRV::Extension::SPV_EXT_optnone))
86 if (F.hasFnAttribute(Kind: Attribute::OptimizeNone))
87 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::OptNoneEXT);
88
89 return FuncControl;
90}
91
92static ConstantInt *getConstInt(MDNode *MD, unsigned NumOp) {
93 if (MD->getNumOperands() > NumOp) {
94 auto *CMeta = dyn_cast<ConstantAsMetadata>(Val: MD->getOperand(I: NumOp));
95 if (CMeta)
96 return dyn_cast<ConstantInt>(Val: CMeta->getValue());
97 }
98 return nullptr;
99}
100
101// If the function has pointer arguments, we are forced to re-create this
102// function type from the very beginning, changing PointerType by
103// TypedPointerType for each pointer argument. Otherwise, the same `Type*`
104// potentially corresponds to different SPIR-V function type, effectively
105// invalidating logic behind global registry and duplicates tracker.
106static FunctionType *
107fixFunctionTypeIfPtrArgs(SPIRVGlobalRegistry *GR, const Function &F,
108 FunctionType *FTy, const SPIRVType *SRetTy,
109 const SmallVector<SPIRVType *, 4> &SArgTys) {
110 bool hasArgPtrs = false;
111 for (auto &Arg : F.args()) {
112 // check if it's an instance of a non-typed PointerType
113 if (Arg.getType()->isPointerTy()) {
114 hasArgPtrs = true;
115 break;
116 }
117 }
118 if (!hasArgPtrs) {
119 Type *RetTy = FTy->getReturnType();
120 // check if it's an instance of a non-typed PointerType
121 if (!RetTy->isPointerTy())
122 return FTy;
123 }
124
125 // re-create function type, using TypedPointerType instead of PointerType to
126 // properly trace argument types
127 const Type *RetTy = GR->getTypeForSPIRVType(Ty: SRetTy);
128 SmallVector<Type *, 4> ArgTys;
129 for (auto SArgTy : SArgTys)
130 ArgTys.push_back(Elt: const_cast<Type *>(GR->getTypeForSPIRVType(Ty: SArgTy)));
131 return FunctionType::get(Result: const_cast<Type *>(RetTy), Params: ArgTys, isVarArg: false);
132}
133
134// This code restores function args/retvalue types for composite cases
135// because the final types should still be aggregate whereas they're i32
136// during the translation to cope with aggregate flattening etc.
137static FunctionType *getOriginalFunctionType(const Function &F) {
138 auto *NamedMD = F.getParent()->getNamedMetadata(Name: "spv.cloned_funcs");
139 if (NamedMD == nullptr)
140 return F.getFunctionType();
141
142 Type *RetTy = F.getFunctionType()->getReturnType();
143 SmallVector<Type *, 4> ArgTypes;
144 for (auto &Arg : F.args())
145 ArgTypes.push_back(Elt: Arg.getType());
146
147 auto ThisFuncMDIt =
148 std::find_if(first: NamedMD->op_begin(), last: NamedMD->op_end(), pred: [&F](MDNode *N) {
149 return isa<MDString>(Val: N->getOperand(I: 0)) &&
150 cast<MDString>(Val: N->getOperand(I: 0))->getString() == F.getName();
151 });
152 // TODO: probably one function can have numerous type mutations,
153 // so we should support this.
154 if (ThisFuncMDIt != NamedMD->op_end()) {
155 auto *ThisFuncMD = *ThisFuncMDIt;
156 MDNode *MD = dyn_cast<MDNode>(Val: ThisFuncMD->getOperand(I: 1));
157 assert(MD && "MDNode operand is expected");
158 ConstantInt *Const = getConstInt(MD, NumOp: 0);
159 if (Const) {
160 auto *CMeta = dyn_cast<ConstantAsMetadata>(Val: MD->getOperand(I: 1));
161 assert(CMeta && "ConstantAsMetadata operand is expected");
162 assert(Const->getSExtValue() >= -1);
163 // Currently -1 indicates return value, greater values mean
164 // argument numbers.
165 if (Const->getSExtValue() == -1)
166 RetTy = CMeta->getType();
167 else
168 ArgTypes[Const->getSExtValue()] = CMeta->getType();
169 }
170 }
171
172 return FunctionType::get(Result: RetTy, Params: ArgTypes, isVarArg: F.isVarArg());
173}
174
175static SPIRV::AccessQualifier::AccessQualifier
176getArgAccessQual(const Function &F, unsigned ArgIdx) {
177 if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
178 return SPIRV::AccessQualifier::ReadWrite;
179
180 MDString *ArgAttribute = getOCLKernelArgAccessQual(F, ArgIdx);
181 if (!ArgAttribute)
182 return SPIRV::AccessQualifier::ReadWrite;
183
184 if (ArgAttribute->getString() == "read_only")
185 return SPIRV::AccessQualifier::ReadOnly;
186 if (ArgAttribute->getString() == "write_only")
187 return SPIRV::AccessQualifier::WriteOnly;
188 return SPIRV::AccessQualifier::ReadWrite;
189}
190
191static std::vector<SPIRV::Decoration::Decoration>
192getKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
193 MDString *ArgAttribute = getOCLKernelArgTypeQual(F, ArgIdx);
194 if (ArgAttribute && ArgAttribute->getString() == "volatile")
195 return {SPIRV::Decoration::Volatile};
196 return {};
197}
198
199static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
200 SPIRVGlobalRegistry *GR,
201 MachineIRBuilder &MIRBuilder,
202 const SPIRVSubtarget &ST) {
203 // Read argument's access qualifier from metadata or default.
204 SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
205 getArgAccessQual(F, ArgIdx);
206
207 Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(i: ArgIdx);
208
209 // If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot
210 // be legally reassigned later).
211 if (!isPointerTy(T: OriginalArgType))
212 return GR->getOrCreateSPIRVType(Type: OriginalArgType, MIRBuilder, AQ: ArgAccessQual,
213 EmitIR: true);
214
215 Argument *Arg = F.getArg(i: ArgIdx);
216 Type *ArgType = Arg->getType();
217 if (isTypedPointerTy(T: ArgType)) {
218 return GR->getOrCreateSPIRVPointerType(
219 BaseType: cast<TypedPointerType>(Val: ArgType)->getElementType(), MIRBuilder,
220 SC: addressSpaceToStorageClass(AddrSpace: getPointerAddressSpace(T: ArgType), STI: ST));
221 }
222
223 // In case OriginalArgType is of untyped pointer type, there are three
224 // possibilities:
225 // 1) This is a pointer of an LLVM IR element type, passed byval/byref.
226 // 2) This is an OpenCL/SPIR-V builtin type if there is spv_assign_type
227 // intrinsic assigning a TargetExtType.
228 // 3) This is a pointer, try to retrieve pointer element type from a
229 // spv_assign_ptr_type intrinsic or otherwise use default pointer element
230 // type.
231 if (hasPointeeTypeAttr(Arg)) {
232 return GR->getOrCreateSPIRVPointerType(
233 BaseType: getPointeeTypeByAttr(Arg), MIRBuilder,
234 SC: addressSpaceToStorageClass(AddrSpace: getPointerAddressSpace(T: ArgType), STI: ST));
235 }
236
237 for (auto User : Arg->users()) {
238 auto *II = dyn_cast<IntrinsicInst>(Val: User);
239 // Check if this is spv_assign_type assigning OpenCL/SPIR-V builtin type.
240 if (II && II->getIntrinsicID() == Intrinsic::spv_assign_type) {
241 MetadataAsValue *VMD = cast<MetadataAsValue>(Val: II->getOperand(i_nocapture: 1));
242 Type *BuiltinType =
243 cast<ConstantAsMetadata>(Val: VMD->getMetadata())->getType();
244 assert(BuiltinType->isTargetExtTy() && "Expected TargetExtType");
245 return GR->getOrCreateSPIRVType(Type: BuiltinType, MIRBuilder, AQ: ArgAccessQual,
246 EmitIR: true);
247 }
248
249 // Check if this is spv_assign_ptr_type assigning pointer element type.
250 if (!II || II->getIntrinsicID() != Intrinsic::spv_assign_ptr_type)
251 continue;
252
253 MetadataAsValue *VMD = cast<MetadataAsValue>(Val: II->getOperand(i_nocapture: 1));
254 Type *ElementTy =
255 toTypedPointer(Ty: cast<ConstantAsMetadata>(Val: VMD->getMetadata())->getType());
256 return GR->getOrCreateSPIRVPointerType(
257 BaseType: ElementTy, MIRBuilder,
258 SC: addressSpaceToStorageClass(
259 AddrSpace: cast<ConstantInt>(Val: II->getOperand(i_nocapture: 2))->getZExtValue(), STI: ST));
260 }
261
262 // Replace PointerType with TypedPointerType to be able to map SPIR-V types to
263 // LLVM types in a consistent manner
264 return GR->getOrCreateSPIRVType(Type: toTypedPointer(Ty: OriginalArgType), MIRBuilder,
265 AQ: ArgAccessQual, EmitIR: true);
266}
267
268static SPIRV::ExecutionModel::ExecutionModel
269getExecutionModel(const SPIRVSubtarget &STI, const Function &F) {
270 if (STI.isKernel())
271 return SPIRV::ExecutionModel::Kernel;
272
273 if (STI.isShader()) {
274 auto attribute = F.getFnAttribute(Kind: "hlsl.shader");
275 if (!attribute.isValid()) {
276 report_fatal_error(
277 reason: "This entry point lacks mandatory hlsl.shader attribute.");
278 }
279
280 const auto value = attribute.getValueAsString();
281 if (value == "compute")
282 return SPIRV::ExecutionModel::GLCompute;
283 if (value == "vertex")
284 return SPIRV::ExecutionModel::Vertex;
285 if (value == "pixel")
286 return SPIRV::ExecutionModel::Fragment;
287
288 report_fatal_error(
289 reason: "This HLSL entry point is not supported by this backend.");
290 }
291
292 assert(STI.getEnv() == SPIRVSubtarget::Unknown);
293 // "hlsl.shader" attribute is mandatory for Vulkan, so we can set Env to
294 // Shader whenever we find it, and to Kernel otherwise.
295
296 // We will now change the Env based on the attribute, so we need to strip
297 // `const` out of the ref to STI.
298 SPIRVSubtarget *NonConstSTI = const_cast<SPIRVSubtarget *>(&STI);
299 auto attribute = F.getFnAttribute(Kind: "hlsl.shader");
300 if (!attribute.isValid()) {
301 NonConstSTI->setEnv(SPIRVSubtarget::Kernel);
302 return SPIRV::ExecutionModel::Kernel;
303 }
304 NonConstSTI->setEnv(SPIRVSubtarget::Shader);
305
306 const auto value = attribute.getValueAsString();
307 if (value == "compute")
308 return SPIRV::ExecutionModel::GLCompute;
309 if (value == "vertex")
310 return SPIRV::ExecutionModel::Vertex;
311 if (value == "pixel")
312 return SPIRV::ExecutionModel::Fragment;
313
314 report_fatal_error(reason: "This HLSL entry point is not supported by this backend.");
315}
316
317bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
318 const Function &F,
319 ArrayRef<ArrayRef<Register>> VRegs,
320 FunctionLoweringInfo &FLI) const {
321 // Discard the internal service function
322 if (F.getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME).isValid())
323 return true;
324
325 assert(GR && "Must initialize the SPIRV type registry before lowering args.");
326 GR->setCurrentFunc(MIRBuilder.getMF());
327
328 // Get access to information about available extensions
329 const SPIRVSubtarget *ST =
330 static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
331
332 // Assign types and names to all args, and store their types for later.
333 SmallVector<SPIRVType *, 4> ArgTypeVRegs;
334 if (VRegs.size() > 0) {
335 unsigned i = 0;
336 for (const auto &Arg : F.args()) {
337 // Currently formal args should use single registers.
338 // TODO: handle the case of multiple registers.
339 if (VRegs[i].size() > 1)
340 return false;
341 auto *SpirvTy = getArgSPIRVType(F, ArgIdx: i, GR, MIRBuilder, ST: *ST);
342 GR->assignSPIRVTypeToVReg(Type: SpirvTy, VReg: VRegs[i][0], MF: MIRBuilder.getMF());
343 ArgTypeVRegs.push_back(Elt: SpirvTy);
344
345 if (Arg.hasName())
346 buildOpName(Target: VRegs[i][0], Name: Arg.getName(), MIRBuilder);
347 if (isPointerTyOrWrapper(Ty: Arg.getType())) {
348 auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
349 if (DerefBytes != 0)
350 buildOpDecorate(Reg: VRegs[i][0], MIRBuilder,
351 Dec: SPIRV::Decoration::MaxByteOffset, DecArgs: {DerefBytes});
352 }
353 if (Arg.hasAttribute(Kind: Attribute::Alignment) && !ST->isShader()) {
354 auto Alignment = static_cast<unsigned>(
355 Arg.getAttribute(Kind: Attribute::Alignment).getValueAsInt());
356 buildOpDecorate(Reg: VRegs[i][0], MIRBuilder, Dec: SPIRV::Decoration::Alignment,
357 DecArgs: {Alignment});
358 }
359 if (Arg.hasAttribute(Kind: Attribute::ReadOnly)) {
360 auto Attr =
361 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoWrite);
362 buildOpDecorate(Reg: VRegs[i][0], MIRBuilder,
363 Dec: SPIRV::Decoration::FuncParamAttr, DecArgs: {Attr});
364 }
365 if (Arg.hasAttribute(Kind: Attribute::ZExt)) {
366 auto Attr =
367 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Zext);
368 buildOpDecorate(Reg: VRegs[i][0], MIRBuilder,
369 Dec: SPIRV::Decoration::FuncParamAttr, DecArgs: {Attr});
370 }
371 if (Arg.hasAttribute(Kind: Attribute::NoAlias)) {
372 auto Attr =
373 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoAlias);
374 buildOpDecorate(Reg: VRegs[i][0], MIRBuilder,
375 Dec: SPIRV::Decoration::FuncParamAttr, DecArgs: {Attr});
376 }
377 if (Arg.hasAttribute(Kind: Attribute::ByVal)) {
378 auto Attr =
379 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::ByVal);
380 buildOpDecorate(Reg: VRegs[i][0], MIRBuilder,
381 Dec: SPIRV::Decoration::FuncParamAttr, DecArgs: {Attr});
382 }
383 if (Arg.hasAttribute(Kind: Attribute::StructRet)) {
384 auto Attr =
385 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Sret);
386 buildOpDecorate(Reg: VRegs[i][0], MIRBuilder,
387 Dec: SPIRV::Decoration::FuncParamAttr, DecArgs: {Attr});
388 }
389
390 if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
391 std::vector<SPIRV::Decoration::Decoration> ArgTypeQualDecs =
392 getKernelArgTypeQual(F, ArgIdx: i);
393 for (SPIRV::Decoration::Decoration Decoration : ArgTypeQualDecs)
394 buildOpDecorate(Reg: VRegs[i][0], MIRBuilder, Dec: Decoration, DecArgs: {});
395 }
396
397 MDNode *Node = F.getMetadata(Kind: "spirv.ParameterDecorations");
398 if (Node && i < Node->getNumOperands() &&
399 isa<MDNode>(Val: Node->getOperand(I: i))) {
400 MDNode *MD = cast<MDNode>(Val: Node->getOperand(I: i));
401 for (const MDOperand &MDOp : MD->operands()) {
402 MDNode *MD2 = dyn_cast<MDNode>(Val: MDOp);
403 assert(MD2 && "Metadata operand is expected");
404 ConstantInt *Const = getConstInt(MD: MD2, NumOp: 0);
405 assert(Const && "MDOperand should be ConstantInt");
406 auto Dec =
407 static_cast<SPIRV::Decoration::Decoration>(Const->getZExtValue());
408 std::vector<uint32_t> DecVec;
409 for (unsigned j = 1; j < MD2->getNumOperands(); j++) {
410 ConstantInt *Const = getConstInt(MD: MD2, NumOp: j);
411 assert(Const && "MDOperand should be ConstantInt");
412 DecVec.push_back(x: static_cast<uint32_t>(Const->getZExtValue()));
413 }
414 buildOpDecorate(Reg: VRegs[i][0], MIRBuilder, Dec, DecArgs: DecVec);
415 }
416 }
417 ++i;
418 }
419 }
420
421 auto MRI = MIRBuilder.getMRI();
422 Register FuncVReg = MRI->createGenericVirtualRegister(Ty: LLT::scalar(SizeInBits: 64));
423 MRI->setRegClass(Reg: FuncVReg, RC: &SPIRV::iIDRegClass);
424 FunctionType *FTy = getOriginalFunctionType(F);
425 Type *FRetTy = FTy->getReturnType();
426 if (isUntypedPointerTy(T: FRetTy)) {
427 if (Type *FRetElemTy = GR->findDeducedElementType(Val: &F)) {
428 TypedPointerType *DerivedTy = TypedPointerType::get(
429 ElementType: toTypedPointer(Ty: FRetElemTy), AddressSpace: getPointerAddressSpace(T: FRetTy));
430 GR->addReturnType(ArgF: &F, DerivedTy);
431 FRetTy = DerivedTy;
432 }
433 }
434 SPIRVType *RetTy = GR->getOrCreateSPIRVType(
435 Type: FRetTy, MIRBuilder, AQ: SPIRV::AccessQualifier::ReadWrite, EmitIR: true);
436 FTy = fixFunctionTypeIfPtrArgs(GR, F, FTy, SRetTy: RetTy, SArgTys: ArgTypeVRegs);
437 SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
438 Ty: FTy, RetType: RetTy, ArgTypes: ArgTypeVRegs, MIRBuilder);
439 uint32_t FuncControl = getFunctionControl(F, ST);
440
441 // Add OpFunction instruction
442 MachineInstrBuilder MB = MIRBuilder.buildInstr(Opcode: SPIRV::OpFunction)
443 .addDef(RegNo: FuncVReg)
444 .addUse(RegNo: GR->getSPIRVTypeID(SpirvType: RetTy))
445 .addImm(Val: FuncControl)
446 .addUse(RegNo: GR->getSPIRVTypeID(SpirvType: FuncTy));
447 GR->recordFunctionDefinition(F: &F, MO: &MB.getInstr()->getOperand(i: 0));
448 GR->addGlobalObject(V: &F, MF: &MIRBuilder.getMF(), R: FuncVReg);
449 if (F.isDeclaration())
450 GR->add(V: &F, MI: MB);
451
452 // Add OpFunctionParameter instructions
453 int i = 0;
454 for (const auto &Arg : F.args()) {
455 assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
456 Register ArgReg = VRegs[i][0];
457 MRI->setRegClass(Reg: ArgReg, RC: GR->getRegClass(SpvType: ArgTypeVRegs[i]));
458 MRI->setType(VReg: ArgReg, Ty: GR->getRegType(SpvType: ArgTypeVRegs[i]));
459 auto MIB = MIRBuilder.buildInstr(Opcode: SPIRV::OpFunctionParameter)
460 .addDef(RegNo: ArgReg)
461 .addUse(RegNo: GR->getSPIRVTypeID(SpirvType: ArgTypeVRegs[i]));
462 if (F.isDeclaration())
463 GR->add(V: &Arg, MI: MIB);
464 GR->addGlobalObject(V: &Arg, MF: &MIRBuilder.getMF(), R: ArgReg);
465 i++;
466 }
467 // Name the function.
468 if (F.hasName())
469 buildOpName(Target: FuncVReg, Name: F.getName(), MIRBuilder);
470
471 // Handle entry points and function linkage.
472 if (isEntryPoint(F)) {
473 // EntryPoints can help us to determine the environment we're working on.
474 // Therefore, we need a non-const pointer to SPIRVSubtarget to update the
475 // environment if we need to.
476 const SPIRVSubtarget *ST =
477 static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
478 auto MIB = MIRBuilder.buildInstr(Opcode: SPIRV::OpEntryPoint)
479 .addImm(Val: static_cast<uint32_t>(getExecutionModel(STI: *ST, F)))
480 .addUse(RegNo: FuncVReg);
481 addStringImm(Str: F.getName(), MIB);
482 } else if (F.getLinkage() != GlobalValue::InternalLinkage &&
483 F.getLinkage() != GlobalValue::PrivateLinkage &&
484 F.getVisibility() != GlobalValue::HiddenVisibility) {
485 SPIRV::LinkageType::LinkageType LnkTy =
486 F.isDeclaration()
487 ? SPIRV::LinkageType::Import
488 : (F.getLinkage() == GlobalValue::LinkOnceODRLinkage &&
489 ST->canUseExtension(
490 E: SPIRV::Extension::SPV_KHR_linkonce_odr)
491 ? SPIRV::LinkageType::LinkOnceODR
492 : SPIRV::LinkageType::Export);
493 buildOpDecorate(Reg: FuncVReg, MIRBuilder, Dec: SPIRV::Decoration::LinkageAttributes,
494 DecArgs: {static_cast<uint32_t>(LnkTy)}, StrImm: F.getName());
495 }
496
497 // Handle function pointers decoration
498 bool hasFunctionPointers =
499 ST->canUseExtension(E: SPIRV::Extension::SPV_INTEL_function_pointers);
500 if (hasFunctionPointers) {
501 if (F.hasFnAttribute(Kind: "referenced-indirectly")) {
502 assert((F.getCallingConv() != CallingConv::SPIR_KERNEL) &&
503 "Unexpected 'referenced-indirectly' attribute of the kernel "
504 "function");
505 buildOpDecorate(Reg: FuncVReg, MIRBuilder,
506 Dec: SPIRV::Decoration::ReferencedIndirectlyINTEL, DecArgs: {});
507 }
508 }
509
510 return true;
511}
512
513// Used to postpone producing of indirect function pointer types after all
514// indirect calls info is collected
515// TODO:
516// - add a topological sort of IndirectCalls to ensure the best types knowledge
517// - we may need to fix function formal parameter types if they are opaque
518// pointers used as function pointers in these indirect calls
519void SPIRVCallLowering::produceIndirectPtrTypes(
520 MachineIRBuilder &MIRBuilder) const {
521 // Create indirect call data types if any
522 MachineFunction &MF = MIRBuilder.getMF();
523 for (auto const &IC : IndirectCalls) {
524 SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType(
525 Type: IC.RetTy, MIRBuilder, AQ: SPIRV::AccessQualifier::ReadWrite, EmitIR: true);
526 SmallVector<SPIRVType *, 4> SpirvArgTypes;
527 for (size_t i = 0; i < IC.ArgTys.size(); ++i) {
528 SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(
529 Type: IC.ArgTys[i], MIRBuilder, AQ: SPIRV::AccessQualifier::ReadWrite, EmitIR: true);
530 SpirvArgTypes.push_back(Elt: SPIRVTy);
531 if (!GR->getSPIRVTypeForVReg(VReg: IC.ArgRegs[i]))
532 GR->assignSPIRVTypeToVReg(Type: SPIRVTy, VReg: IC.ArgRegs[i], MF);
533 }
534 // SPIR-V function type:
535 FunctionType *FTy =
536 FunctionType::get(Result: const_cast<Type *>(IC.RetTy), Params: IC.ArgTys, isVarArg: false);
537 SPIRVType *SpirvFuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
538 Ty: FTy, RetType: SpirvRetTy, ArgTypes: SpirvArgTypes, MIRBuilder);
539 // SPIR-V pointer to function type:
540 SPIRVType *IndirectFuncPtrTy = GR->getOrCreateSPIRVPointerType(
541 BaseType: SpirvFuncTy, MIRBuilder, SC: SPIRV::StorageClass::Function);
542 // Correct the Callee type
543 GR->assignSPIRVTypeToVReg(Type: IndirectFuncPtrTy, VReg: IC.Callee, MF);
544 }
545}
546
547bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
548 CallLoweringInfo &Info) const {
549 // Currently call returns should have single vregs.
550 // TODO: handle the case of multiple registers.
551 if (Info.OrigRet.Regs.size() > 1)
552 return false;
553 MachineFunction &MF = MIRBuilder.getMF();
554 GR->setCurrentFunc(MF);
555 const Function *CF = nullptr;
556 std::string DemangledName;
557 const Type *OrigRetTy = Info.OrigRet.Ty;
558
559 // Emit a regular OpFunctionCall. If it's an externally declared function,
560 // be sure to emit its type and function declaration here. It will be hoisted
561 // globally later.
562 if (Info.Callee.isGlobal()) {
563 std::string FuncName = Info.Callee.getGlobal()->getName().str();
564 DemangledName = getOclOrSpirvBuiltinDemangledName(Name: FuncName);
565 CF = dyn_cast_or_null<const Function>(Val: Info.Callee.getGlobal());
566 // TODO: support constexpr casts and indirect calls.
567 if (CF == nullptr)
568 return false;
569 if (FunctionType *FTy = getOriginalFunctionType(F: *CF)) {
570 OrigRetTy = FTy->getReturnType();
571 if (isUntypedPointerTy(T: OrigRetTy)) {
572 if (auto *DerivedRetTy = GR->findReturnType(ArgF: CF))
573 OrigRetTy = DerivedRetTy;
574 }
575 }
576 }
577
578 MachineRegisterInfo *MRI = MIRBuilder.getMRI();
579 Register ResVReg =
580 Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
581 const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
582
583 bool isFunctionDecl = CF && CF->isDeclaration();
584 if (isFunctionDecl && !DemangledName.empty()) {
585 if (ResVReg.isValid()) {
586 if (!GR->getSPIRVTypeForVReg(VReg: ResVReg)) {
587 const Type *RetTy = OrigRetTy;
588 if (auto *PtrRetTy = dyn_cast<PointerType>(Val: OrigRetTy)) {
589 const Value *OrigValue = Info.OrigRet.OrigValue;
590 if (!OrigValue)
591 OrigValue = Info.CB;
592 if (OrigValue)
593 if (Type *ElemTy = GR->findDeducedElementType(Val: OrigValue))
594 RetTy =
595 TypedPointerType::get(ElementType: ElemTy, AddressSpace: PtrRetTy->getAddressSpace());
596 }
597 setRegClassType(Reg: ResVReg, Ty: RetTy, GR, MIRBuilder,
598 AccessQual: SPIRV::AccessQualifier::ReadWrite, EmitIR: true);
599 }
600 } else {
601 ResVReg = createVirtualRegister(Ty: OrigRetTy, GR, MIRBuilder,
602 AccessQual: SPIRV::AccessQualifier::ReadWrite, EmitIR: true);
603 }
604 SmallVector<Register, 8> ArgVRegs;
605 for (auto Arg : Info.OrigArgs) {
606 assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
607 Register ArgReg = Arg.Regs[0];
608 ArgVRegs.push_back(Elt: ArgReg);
609 SPIRVType *SpvType = GR->getSPIRVTypeForVReg(VReg: ArgReg);
610 if (!SpvType) {
611 Type *ArgTy = nullptr;
612 if (auto *PtrArgTy = dyn_cast<PointerType>(Val: Arg.Ty)) {
613 // If Arg.Ty is an untyped pointer (i.e., ptr [addrspace(...)]) and we
614 // don't have access to original value in LLVM IR or info about
615 // deduced pointee type, then we should wait with setting the type for
616 // the virtual register until pre-legalizer step when we access
617 // @llvm.spv.assign.ptr.type.p...(...)'s info.
618 if (Arg.OrigValue)
619 if (Type *ElemTy = GR->findDeducedElementType(Val: Arg.OrigValue))
620 ArgTy =
621 TypedPointerType::get(ElementType: ElemTy, AddressSpace: PtrArgTy->getAddressSpace());
622 } else {
623 ArgTy = Arg.Ty;
624 }
625 if (ArgTy) {
626 SpvType = GR->getOrCreateSPIRVType(
627 Type: ArgTy, MIRBuilder, AQ: SPIRV::AccessQualifier::ReadWrite, EmitIR: true);
628 GR->assignSPIRVTypeToVReg(Type: SpvType, VReg: ArgReg, MF);
629 }
630 }
631 if (!MRI->getRegClassOrNull(Reg: ArgReg)) {
632 // Either we have SpvType created, or Arg.Ty is an untyped pointer and
633 // we know its virtual register's class and type even if we don't know
634 // pointee type.
635 MRI->setRegClass(Reg: ArgReg, RC: SpvType ? GR->getRegClass(SpvType)
636 : &SPIRV::pIDRegClass);
637 MRI->setType(
638 VReg: ArgReg,
639 Ty: SpvType ? GR->getRegType(SpvType)
640 : LLT::pointer(AddressSpace: cast<PointerType>(Val: Arg.Ty)->getAddressSpace(),
641 SizeInBits: GR->getPointerSize()));
642 }
643 }
644 if (auto Res =
645 SPIRV::lowerBuiltin(DemangledCall: DemangledName, Set: ST->getPreferredInstructionSet(),
646 MIRBuilder, OrigRet: ResVReg, OrigRetTy, Args: ArgVRegs, GR))
647 return *Res;
648 }
649
650 if (isFunctionDecl && !GR->find(V: CF, MF: &MF).isValid()) {
651 // Emit the type info and forward function declaration to the first MBB
652 // to ensure VReg definition dependencies are valid across all MBBs.
653 MachineIRBuilder FirstBlockBuilder;
654 FirstBlockBuilder.setMF(MF);
655 FirstBlockBuilder.setMBB(*MF.getBlockNumbered(N: 0));
656
657 SmallVector<ArrayRef<Register>, 8> VRegArgs;
658 SmallVector<SmallVector<Register, 1>, 8> ToInsert;
659 for (const Argument &Arg : CF->args()) {
660 if (MIRBuilder.getDataLayout().getTypeStoreSize(Ty: Arg.getType()).isZero())
661 continue; // Don't handle zero sized types.
662 Register Reg = MRI->createGenericVirtualRegister(Ty: LLT::scalar(SizeInBits: 64));
663 MRI->setRegClass(Reg, RC: &SPIRV::iIDRegClass);
664 ToInsert.push_back(Elt: {Reg});
665 VRegArgs.push_back(Elt: ToInsert.back());
666 }
667 // TODO: Reuse FunctionLoweringInfo
668 FunctionLoweringInfo FuncInfo;
669 lowerFormalArguments(MIRBuilder&: FirstBlockBuilder, F: *CF, VRegs: VRegArgs, FLI&: FuncInfo);
670 }
671
672 // Ignore the call if it's called from the internal service function
673 if (MIRBuilder.getMF()
674 .getFunction()
675 .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
676 .isValid()) {
677 // insert a no-op
678 MIRBuilder.buildTrap();
679 return true;
680 }
681
682 unsigned CallOp;
683 if (Info.CB->isIndirectCall()) {
684 if (!ST->canUseExtension(E: SPIRV::Extension::SPV_INTEL_function_pointers))
685 report_fatal_error(reason: "An indirect call is encountered but SPIR-V without "
686 "extensions does not support it",
687 gen_crash_diag: false);
688 // Set instruction operation according to SPV_INTEL_function_pointers
689 CallOp = SPIRV::OpFunctionPointerCallINTEL;
690 // Collect information about the indirect call to support possible
691 // specification of opaque ptr types of parent function's parameters
692 Register CalleeReg = Info.Callee.getReg();
693 if (CalleeReg.isValid()) {
694 SPIRVCallLowering::SPIRVIndirectCall IndirectCall;
695 IndirectCall.Callee = CalleeReg;
696 IndirectCall.RetTy = OrigRetTy;
697 for (const auto &Arg : Info.OrigArgs) {
698 assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
699 IndirectCall.ArgTys.push_back(Elt: Arg.Ty);
700 IndirectCall.ArgRegs.push_back(Elt: Arg.Regs[0]);
701 }
702 IndirectCalls.push_back(Elt: IndirectCall);
703 }
704 } else {
705 // Emit a regular OpFunctionCall
706 CallOp = SPIRV::OpFunctionCall;
707 }
708
709 // Make sure there's a valid return reg, even for functions returning void.
710 if (!ResVReg.isValid())
711 ResVReg = MIRBuilder.getMRI()->createVirtualRegister(RegClass: &SPIRV::iIDRegClass);
712 SPIRVType *RetType = GR->assignTypeToVReg(
713 Type: OrigRetTy, VReg: ResVReg, MIRBuilder, AQ: SPIRV::AccessQualifier::ReadWrite, EmitIR: true);
714
715 // Emit the call instruction and its args.
716 auto MIB = MIRBuilder.buildInstr(Opcode: CallOp)
717 .addDef(RegNo: ResVReg)
718 .addUse(RegNo: GR->getSPIRVTypeID(SpirvType: RetType))
719 .add(MO: Info.Callee);
720
721 for (const auto &Arg : Info.OrigArgs) {
722 // Currently call args should have single vregs.
723 if (Arg.Regs.size() > 1)
724 return false;
725 MIB.addUse(RegNo: Arg.Regs[0]);
726 }
727
728 if (ST->canUseExtension(E: SPIRV::Extension::SPV_INTEL_memory_access_aliasing)) {
729 // Process aliasing metadata.
730 const CallBase *CI = Info.CB;
731 if (CI && CI->hasMetadata()) {
732 if (MDNode *MD = CI->getMetadata(KindID: LLVMContext::MD_alias_scope))
733 GR->buildMemAliasingOpDecorate(Reg: ResVReg, MIRBuilder,
734 Dec: SPIRV::Decoration::AliasScopeINTEL, GVarMD: MD);
735 if (MDNode *MD = CI->getMetadata(KindID: LLVMContext::MD_noalias))
736 GR->buildMemAliasingOpDecorate(Reg: ResVReg, MIRBuilder,
737 Dec: SPIRV::Decoration::NoAliasINTEL, GVarMD: MD);
738 }
739 }
740
741 return MIB.constrainAllUses(TII: MIRBuilder.getTII(), TRI: *ST->getRegisterInfo(),
742 RBI: *ST->getRegBankInfo());
743}
744