1//===----- CGCUDARuntime.cpp - Interface to CUDA Runtimes -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This provides an abstract class for CUDA code generation. Concrete
10// subclasses of this implement code generation for specific CUDA
11// runtime libraries.
12//
13//===----------------------------------------------------------------------===//
14
15#include "CGCUDARuntime.h"
16#include "CGCall.h"
17#include "CodeGenFunction.h"
18#include "clang/AST/ExprCXX.h"
19
20using namespace clang;
21using namespace CodeGen;
22
23CGCUDARuntime::~CGCUDARuntime() {}
24
25static llvm::Value *emitGetParamBuf(CodeGenFunction &CGF,
26 const CUDAKernelCallExpr *E) {
27 auto *GetParamBuf = CGF.getContext().getcudaGetParameterBufferDecl();
28 const FunctionProtoType *GetParamBufProto =
29 GetParamBuf->getType()->getAs<FunctionProtoType>();
30
31 DeclRefExpr *DRE = DeclRefExpr::Create(
32 Context: CGF.getContext(), QualifierLoc: {}, TemplateKWLoc: {}, D: GetParamBuf,
33 /*RefersToEnclosingVariableOrCapture=*/false, NameInfo: GetParamBuf->getNameInfo(),
34 T: GetParamBuf->getType(), VK: VK_PRValue);
35 auto *ImpCast = ImplicitCastExpr::Create(
36 Context: CGF.getContext(), T: CGF.getContext().getPointerType(T: GetParamBuf->getType()),
37 Kind: CK_FunctionToPointerDecay, Operand: DRE, BasePath: nullptr, Cat: VK_PRValue, FPO: FPOptionsOverride());
38
39 CGCallee Callee = CGF.EmitCallee(E: ImpCast);
40 CallArgList Args;
41 // Use 64B alignment.
42 Args.add(rvalue: RValue::get(V: CGF.CGM.getSize(numChars: CharUnits::fromQuantity(Quantity: 64))),
43 type: CGF.getContext().getSizeType());
44 // Calculate parameter sizes.
45 const PointerType *PT = E->getCallee()->getType()->getAs<PointerType>();
46 const FunctionProtoType *FTP =
47 PT->getPointeeType()->getAs<FunctionProtoType>();
48 CharUnits Offset = CharUnits::Zero();
49 for (auto ArgTy : FTP->getParamTypes()) {
50 auto TInfo = CGF.CGM.getContext().getTypeInfoInChars(T: ArgTy);
51 Offset = Offset.alignTo(Align: TInfo.Align) + TInfo.Width;
52 }
53 Args.add(rvalue: RValue::get(V: CGF.CGM.getSize(numChars: Offset)),
54 type: CGF.getContext().getSizeType());
55 const CGFunctionInfo &CallInfo = CGF.CGM.getTypes().arrangeFreeFunctionCall(
56 Args, Ty: GetParamBufProto, /*ChainCall=*/false);
57 auto Ret = CGF.EmitCall(CallInfo, Callee, /*ReturnValue=*/{}, Args);
58
59 return Ret.getScalarVal();
60}
61
62RValue CGCUDARuntime::EmitCUDADeviceKernelCallExpr(
63 CodeGenFunction &CGF, const CUDAKernelCallExpr *E,
64 ReturnValueSlot ReturnValue, llvm::CallBase **CallOrInvoke) {
65 assert(CGM.getContext().getcudaLaunchDeviceDecl() ==
66 E->getConfig()->getDirectCallee());
67
68 llvm::BasicBlock *ConfigOKBlock = CGF.createBasicBlock(name: "dkcall.configok");
69 llvm::BasicBlock *ContBlock = CGF.createBasicBlock(name: "dkcall.end");
70
71 llvm::Value *Config = emitGetParamBuf(CGF, E);
72 CGF.Builder.CreateCondBr(
73 Cond: CGF.Builder.CreateICmpNE(LHS: Config,
74 RHS: llvm::Constant::getNullValue(Ty: Config->getType())),
75 True: ConfigOKBlock, False: ContBlock);
76
77 CodeGenFunction::ConditionalEvaluation eval(CGF);
78
79 eval.begin(CGF);
80 CGF.EmitBlock(BB: ConfigOKBlock);
81
82 QualType KernelCalleeFuncTy =
83 E->getCallee()->getType()->getAs<PointerType>()->getPointeeType();
84 CGCallee KernelCallee = CGF.EmitCallee(E: E->getCallee());
85 // Emit kernel arguments.
86 CallArgList KernelCallArgs;
87 CGF.EmitCallArgs(Args&: KernelCallArgs,
88 Prototype: KernelCalleeFuncTy->getAs<FunctionProtoType>(),
89 ArgRange: E->arguments(), AC: E->getDirectCallee());
90 // Copy emitted kernel arguments into that parameter buffer.
91 RawAddress CfgBase(Config, CGM.Int8Ty,
92 /*Alignment=*/CharUnits::fromQuantity(Quantity: 64));
93 CharUnits Offset = CharUnits::Zero();
94 for (auto &Arg : KernelCallArgs) {
95 auto TInfo = CGM.getContext().getTypeInfoInChars(T: Arg.getType());
96 Offset = Offset.alignTo(Align: TInfo.Align);
97 Address Addr =
98 CGF.Builder.CreateConstInBoundsGEP(Addr: CfgBase, Index: Offset.getQuantity());
99 Arg.copyInto(CGF, A: Addr);
100 Offset += TInfo.Width;
101 }
102 // Make `cudaLaunchDevice` call, i.e. E->getConfig().
103 const CallExpr *LaunchCall = E->getConfig();
104 QualType LaunchCalleeFuncTy = LaunchCall->getCallee()
105 ->getType()
106 ->getAs<PointerType>()
107 ->getPointeeType();
108 CGCallee LaunchCallee = CGF.EmitCallee(E: LaunchCall->getCallee());
109 CallArgList LaunchCallArgs;
110 CGF.EmitCallArgs(Args&: LaunchCallArgs,
111 Prototype: LaunchCalleeFuncTy->getAs<FunctionProtoType>(),
112 ArgRange: LaunchCall->arguments(), AC: LaunchCall->getDirectCallee());
113 // Replace func and paramterbuffer arguments.
114 LaunchCallArgs[0] = CallArg(RValue::get(V: KernelCallee.getFunctionPointer()),
115 CGM.getContext().VoidPtrTy);
116 LaunchCallArgs[1] = CallArg(RValue::get(V: Config), CGM.getContext().VoidPtrTy);
117 const CGFunctionInfo &LaunchCallInfo = CGM.getTypes().arrangeFreeFunctionCall(
118 Args: LaunchCallArgs, Ty: LaunchCalleeFuncTy->getAs<FunctionProtoType>(),
119 /*ChainCall=*/false);
120 CGF.EmitCall(CallInfo: LaunchCallInfo, Callee: LaunchCallee, ReturnValue, Args: LaunchCallArgs,
121 CallOrInvoke,
122 /*IsMustTail=*/false, Loc: E->getExprLoc());
123 CGF.EmitBranch(Block: ContBlock);
124
125 CGF.EmitBlock(BB: ContBlock);
126 eval.end(CGF);
127
128 return RValue::get(V: nullptr);
129}
130
131RValue CGCUDARuntime::EmitCUDAKernelCallExpr(CodeGenFunction &CGF,
132 const CUDAKernelCallExpr *E,
133 ReturnValueSlot ReturnValue,
134 llvm::CallBase **CallOrInvoke) {
135 llvm::BasicBlock *ConfigOKBlock = CGF.createBasicBlock(name: "kcall.configok");
136 llvm::BasicBlock *ContBlock = CGF.createBasicBlock(name: "kcall.end");
137
138 CodeGenFunction::ConditionalEvaluation eval(CGF);
139 CGF.EmitBranchOnBoolExpr(Cond: E->getConfig(), TrueBlock: ContBlock, FalseBlock: ConfigOKBlock,
140 /*TrueCount=*/0);
141
142 eval.begin(CGF);
143 CGF.EmitBlock(BB: ConfigOKBlock);
144 CGF.EmitSimpleCallExpr(E, ReturnValue, CallOrInvoke);
145 CGF.EmitBranch(Block: ContBlock);
146
147 CGF.EmitBlock(BB: ContBlock);
148 eval.end(CGF);
149
150 return RValue::get(V: nullptr);
151}
152