1//===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This provides an abstract class for HLSL code generation. Concrete
10// subclasses of this implement code generation for specific HLSL
11// runtime libraries.
12//
13//===----------------------------------------------------------------------===//
14
15#include "CGHLSLRuntime.h"
16#include "CGDebugInfo.h"
17#include "CodeGenModule.h"
18#include "clang/AST/Decl.h"
19#include "clang/Basic/TargetOptions.h"
20#include "llvm/IR/Metadata.h"
21#include "llvm/IR/Module.h"
22#include "llvm/Support/FormatVariadic.h"
23
24using namespace clang;
25using namespace CodeGen;
26using namespace clang::hlsl;
27using namespace llvm;
28
29namespace {
30
31void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {
32 // The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs.
33 // Assume ValVersionStr is legal here.
34 VersionTuple Version;
35 if (Version.tryParse(string: ValVersionStr) || Version.getBuild() ||
36 Version.getSubminor() || !Version.getMinor()) {
37 return;
38 }
39
40 uint64_t Major = Version.getMajor();
41 uint64_t Minor = *Version.getMinor();
42
43 auto &Ctx = M.getContext();
44 IRBuilder<> B(M.getContext());
45 MDNode *Val = MDNode::get(Context&: Ctx, MDs: {ConstantAsMetadata::get(C: B.getInt32(C: Major)),
46 ConstantAsMetadata::get(C: B.getInt32(C: Minor))});
47 StringRef DXILValKey = "dx.valver";
48 auto *DXILValMD = M.getOrInsertNamedMetadata(Name: DXILValKey);
49 DXILValMD->addOperand(M: Val);
50}
51void addDisableOptimizations(llvm::Module &M) {
52 StringRef Key = "dx.disable_optimizations";
53 M.addModuleFlag(Behavior: llvm::Module::ModFlagBehavior::Override, Key, Val: 1);
54}
// A cbuffer is translated into a global variable in a special address space.
// Expressed in C, the HLSL code
//
//   cbuffer A {
//     float a;
//     float b;
//   }
//   float foo() { return a + b; }
//
// is translated into
//
//   struct A {
//     float a;
//     float b;
//   } cbuffer_A __attribute__((address_space(4)));
//   float foo() { return cbuffer_A.a + cbuffer_A.b; }
//
// layoutBuffer creates the struct A type.
// replaceBuffer replaces the uses of the global variables a and b with
// cbuffer_A.a and cbuffer_A.b.
//
75void layoutBuffer(CGHLSLRuntime::Buffer &Buf, const DataLayout &DL) {
76 if (Buf.Constants.empty())
77 return;
78
79 std::vector<llvm::Type *> EltTys;
80 for (auto &Const : Buf.Constants) {
81 GlobalVariable *GV = Const.first;
82 Const.second = EltTys.size();
83 llvm::Type *Ty = GV->getValueType();
84 EltTys.emplace_back(args&: Ty);
85 }
86 Buf.LayoutStruct = llvm::StructType::get(Context&: EltTys[0]->getContext(), Elements: EltTys);
87}
88
89GlobalVariable *replaceBuffer(CGHLSLRuntime::Buffer &Buf) {
90 // Create global variable for CB.
91 GlobalVariable *CBGV = new GlobalVariable(
92 Buf.LayoutStruct, /*isConstant*/ true,
93 GlobalValue::LinkageTypes::ExternalLinkage, nullptr,
94 llvm::formatv(Fmt: "{0}{1}", Vals&: Buf.Name, Vals: Buf.IsCBuffer ? ".cb." : ".tb."),
95 GlobalValue::NotThreadLocal);
96
97 IRBuilder<> B(CBGV->getContext());
98 Value *ZeroIdx = B.getInt32(C: 0);
99 // Replace Const use with CB use.
100 for (auto &[GV, Offset] : Buf.Constants) {
101 Value *GEP =
102 B.CreateGEP(Ty: Buf.LayoutStruct, Ptr: CBGV, IdxList: {ZeroIdx, B.getInt32(C: Offset)});
103
104 assert(Buf.LayoutStruct->getElementType(Offset) == GV->getValueType() &&
105 "constant type mismatch");
106
107 // Replace.
108 GV->replaceAllUsesWith(V: GEP);
109 // Erase GV.
110 GV->removeDeadConstantUsers();
111 GV->eraseFromParent();
112 }
113 return CBGV;
114}
115
116} // namespace
117
118llvm::Triple::ArchType CGHLSLRuntime::getArch() {
119 return CGM.getTarget().getTriple().getArch();
120}
121
122void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) {
123 if (D->getStorageClass() == SC_Static) {
124 // For static inside cbuffer, take as global static.
125 // Don't add to cbuffer.
126 CGM.EmitGlobal(D);
127 return;
128 }
129
130 auto *GV = cast<GlobalVariable>(Val: CGM.GetAddrOfGlobalVar(D));
131 // Add debug info for constVal.
132 if (CGDebugInfo *DI = CGM.getModuleDebugInfo())
133 if (CGM.getCodeGenOpts().getDebugInfo() >=
134 codegenoptions::DebugInfoKind::LimitedDebugInfo)
135 DI->EmitGlobalVariable(GV: cast<GlobalVariable>(Val: GV), Decl: D);
136
137 // FIXME: support packoffset.
138 // See https://github.com/llvm/llvm-project/issues/57914.
139 uint32_t Offset = 0;
140 bool HasUserOffset = false;
141
142 unsigned LowerBound = HasUserOffset ? Offset : UINT_MAX;
143 CB.Constants.emplace_back(args: std::make_pair(x&: GV, y&: LowerBound));
144}
145
146void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) {
147 for (Decl *it : DC->decls()) {
148 if (auto *ConstDecl = dyn_cast<VarDecl>(Val: it)) {
149 addConstant(D: ConstDecl, CB);
150 } else if (isa<CXXRecordDecl, EmptyDecl>(Val: it)) {
151 // Nothing to do for this declaration.
152 } else if (isa<FunctionDecl>(Val: it)) {
153 // A function within an cbuffer is effectively a top-level function,
154 // as it only refers to globally scoped declarations.
155 CGM.EmitTopLevelDecl(D: it);
156 }
157 }
158}
159
160void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *D) {
161 Buffers.emplace_back(Args: Buffer(D));
162 addBufferDecls(DC: D, CB&: Buffers.back());
163}
164
165void CGHLSLRuntime::finishCodeGen() {
166 auto &TargetOpts = CGM.getTarget().getTargetOpts();
167 llvm::Module &M = CGM.getModule();
168 Triple T(M.getTargetTriple());
169 if (T.getArch() == Triple::ArchType::dxil)
170 addDxilValVersion(ValVersionStr: TargetOpts.DxilValidatorVersion, M);
171
172 generateGlobalCtorDtorCalls();
173 if (CGM.getCodeGenOpts().OptimizationLevel == 0)
174 addDisableOptimizations(M);
175
176 const DataLayout &DL = M.getDataLayout();
177
178 for (auto &Buf : Buffers) {
179 layoutBuffer(Buf, DL);
180 GlobalVariable *GV = replaceBuffer(Buf);
181 M.insertGlobalVariable(GV);
182 llvm::hlsl::ResourceClass RC = Buf.IsCBuffer
183 ? llvm::hlsl::ResourceClass::CBuffer
184 : llvm::hlsl::ResourceClass::SRV;
185 llvm::hlsl::ResourceKind RK = Buf.IsCBuffer
186 ? llvm::hlsl::ResourceKind::CBuffer
187 : llvm::hlsl::ResourceKind::TBuffer;
188 addBufferResourceAnnotation(GV, RC, RK, /*IsROV=*/false,
189 ET: llvm::hlsl::ElementType::Invalid, Binding&: Buf.Binding);
190 }
191}
192
193CGHLSLRuntime::Buffer::Buffer(const HLSLBufferDecl *D)
194 : Name(D->getName()), IsCBuffer(D->isCBuffer()),
195 Binding(D->getAttr<HLSLResourceBindingAttr>()) {}
196
197void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV,
198 llvm::hlsl::ResourceClass RC,
199 llvm::hlsl::ResourceKind RK,
200 bool IsROV,
201 llvm::hlsl::ElementType ET,
202 BufferResBinding &Binding) {
203 llvm::Module &M = CGM.getModule();
204
205 NamedMDNode *ResourceMD = nullptr;
206 switch (RC) {
207 case llvm::hlsl::ResourceClass::UAV:
208 ResourceMD = M.getOrInsertNamedMetadata(Name: "hlsl.uavs");
209 break;
210 case llvm::hlsl::ResourceClass::SRV:
211 ResourceMD = M.getOrInsertNamedMetadata(Name: "hlsl.srvs");
212 break;
213 case llvm::hlsl::ResourceClass::CBuffer:
214 ResourceMD = M.getOrInsertNamedMetadata(Name: "hlsl.cbufs");
215 break;
216 default:
217 assert(false && "Unsupported buffer type!");
218 return;
219 }
220 assert(ResourceMD != nullptr &&
221 "ResourceMD must have been set by the switch above.");
222
223 llvm::hlsl::FrontendResource Res(
224 GV, RK, ET, IsROV, Binding.Reg.value_or(UINT_MAX), Binding.Space);
225 ResourceMD->addOperand(M: Res.getMetadata());
226}
227
228static llvm::hlsl::ElementType
229calculateElementType(const ASTContext &Context, const clang::Type *ResourceTy) {
230 using llvm::hlsl::ElementType;
231
232 // TODO: We may need to update this when we add things like ByteAddressBuffer
233 // that don't have a template parameter (or, indeed, an element type).
234 const auto *TST = ResourceTy->getAs<TemplateSpecializationType>();
235 assert(TST && "Resource types must be template specializations");
236 ArrayRef<TemplateArgument> Args = TST->template_arguments();
237 assert(!Args.empty() && "Resource has no element type");
238
239 // At this point we have a resource with an element type, so we can assume
240 // that it's valid or we would have diagnosed the error earlier.
241 QualType ElTy = Args[0].getAsType();
242
243 // We should either have a basic type or a vector of a basic type.
244 if (const auto *VecTy = ElTy->getAs<clang::VectorType>())
245 ElTy = VecTy->getElementType();
246
247 if (ElTy->isSignedIntegerType()) {
248 switch (Context.getTypeSize(T: ElTy)) {
249 case 16:
250 return ElementType::I16;
251 case 32:
252 return ElementType::I32;
253 case 64:
254 return ElementType::I64;
255 }
256 } else if (ElTy->isUnsignedIntegerType()) {
257 switch (Context.getTypeSize(T: ElTy)) {
258 case 16:
259 return ElementType::U16;
260 case 32:
261 return ElementType::U32;
262 case 64:
263 return ElementType::U64;
264 }
265 } else if (ElTy->isSpecificBuiltinType(K: BuiltinType::Half))
266 return ElementType::F16;
267 else if (ElTy->isSpecificBuiltinType(K: BuiltinType::Float))
268 return ElementType::F32;
269 else if (ElTy->isSpecificBuiltinType(K: BuiltinType::Double))
270 return ElementType::F64;
271
272 // TODO: We need to handle unorm/snorm float types here once we support them
273 llvm_unreachable("Invalid element type for resource");
274}
275
276void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) {
277 const Type *Ty = D->getType()->getPointeeOrArrayElementType();
278 if (!Ty)
279 return;
280 const auto *RD = Ty->getAsCXXRecordDecl();
281 if (!RD)
282 return;
283 const auto *HLSLResAttr = RD->getAttr<HLSLResourceAttr>();
284 const auto *HLSLResClassAttr = RD->getAttr<HLSLResourceClassAttr>();
285 if (!HLSLResAttr || !HLSLResClassAttr)
286 return;
287
288 llvm::hlsl::ResourceClass RC = HLSLResClassAttr->getResourceClass();
289 llvm::hlsl::ResourceKind RK = HLSLResAttr->getResourceKind();
290 bool IsROV = HLSLResAttr->getIsROV();
291 llvm::hlsl::ElementType ET = calculateElementType(Context: CGM.getContext(), ResourceTy: Ty);
292
293 BufferResBinding Binding(D->getAttr<HLSLResourceBindingAttr>());
294 addBufferResourceAnnotation(GV, RC, RK, IsROV, ET, Binding);
295}
296
297CGHLSLRuntime::BufferResBinding::BufferResBinding(
298 HLSLResourceBindingAttr *Binding) {
299 if (Binding) {
300 llvm::APInt RegInt(64, 0);
301 Binding->getSlot().substr(Start: 1).getAsInteger(Radix: 10, Result&: RegInt);
302 Reg = RegInt.getLimitedValue();
303 llvm::APInt SpaceInt(64, 0);
304 Binding->getSpace().substr(Start: 5).getAsInteger(Radix: 10, Result&: SpaceInt);
305 Space = SpaceInt.getLimitedValue();
306 } else {
307 Space = 0;
308 }
309}
310
311void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
312 const FunctionDecl *FD, llvm::Function *Fn) {
313 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
314 assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
315 const StringRef ShaderAttrKindStr = "hlsl.shader";
316 Fn->addFnAttr(Kind: ShaderAttrKindStr,
317 Val: llvm::Triple::getEnvironmentTypeName(Kind: ShaderAttr->getType()));
318 if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {
319 const StringRef NumThreadsKindStr = "hlsl.numthreads";
320 std::string NumThreadsStr =
321 formatv(Fmt: "{0},{1},{2}", Vals: NumThreadsAttr->getX(), Vals: NumThreadsAttr->getY(),
322 Vals: NumThreadsAttr->getZ());
323 Fn->addFnAttr(Kind: NumThreadsKindStr, Val: NumThreadsStr);
324 }
325}
326
327static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) {
328 if (const auto *VT = dyn_cast<FixedVectorType>(Val: Ty)) {
329 Value *Result = PoisonValue::get(T: Ty);
330 for (unsigned I = 0; I < VT->getNumElements(); ++I) {
331 Value *Elt = B.CreateCall(Callee: F, Args: {B.getInt32(C: I)});
332 Result = B.CreateInsertElement(Vec: Result, NewElt: Elt, Idx: I);
333 }
334 return Result;
335 }
336 return B.CreateCall(Callee: F, Args: {B.getInt32(C: 0)});
337}
338
339llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
340 const ParmVarDecl &D,
341 llvm::Type *Ty) {
342 assert(D.hasAttrs() && "Entry parameter missing annotation attribute!");
343 if (D.hasAttr<HLSLSV_GroupIndexAttr>()) {
344 llvm::Function *DxGroupIndex =
345 CGM.getIntrinsic(IID: Intrinsic::dx_flattened_thread_id_in_group);
346 return B.CreateCall(Callee: FunctionCallee(DxGroupIndex));
347 }
348 if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) {
349 llvm::Function *ThreadIDIntrinsic =
350 CGM.getIntrinsic(IID: getThreadIdIntrinsic());
351 return buildVectorInput(B, F: ThreadIDIntrinsic, Ty);
352 }
353 assert(false && "Unhandled parameter attribute");
354 return nullptr;
355}
356
357void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
358 llvm::Function *Fn) {
359 llvm::Module &M = CGM.getModule();
360 llvm::LLVMContext &Ctx = M.getContext();
361 auto *EntryTy = llvm::FunctionType::get(Result: llvm::Type::getVoidTy(C&: Ctx), isVarArg: false);
362 Function *EntryFn =
363 Function::Create(Ty: EntryTy, Linkage: Function::ExternalLinkage, N: FD->getName(), M: &M);
364
365 // Copy function attributes over, we have no argument or return attributes
366 // that can be valid on the real entry.
367 AttributeList NewAttrs = AttributeList::get(C&: Ctx, Index: AttributeList::FunctionIndex,
368 Attrs: Fn->getAttributes().getFnAttrs());
369 EntryFn->setAttributes(NewAttrs);
370 setHLSLEntryAttributes(FD, Fn: EntryFn);
371
372 // Set the called function as internal linkage.
373 Fn->setLinkage(GlobalValue::InternalLinkage);
374
375 BasicBlock *BB = BasicBlock::Create(Context&: Ctx, Name: "entry", Parent: EntryFn);
376 IRBuilder<> B(BB);
377 llvm::SmallVector<Value *> Args;
378 // FIXME: support struct parameters where semantics are on members.
379 // See: https://github.com/llvm/llvm-project/issues/57874
380 unsigned SRetOffset = 0;
381 for (const auto &Param : Fn->args()) {
382 if (Param.hasStructRetAttr()) {
383 // FIXME: support output.
384 // See: https://github.com/llvm/llvm-project/issues/57874
385 SRetOffset = 1;
386 Args.emplace_back(Args: PoisonValue::get(T: Param.getType()));
387 continue;
388 }
389 const ParmVarDecl *PD = FD->getParamDecl(i: Param.getArgNo() - SRetOffset);
390 Args.push_back(Elt: emitInputSemantic(B, D: *PD, Ty: Param.getType()));
391 }
392
393 CallInst *CI = B.CreateCall(Callee: FunctionCallee(Fn), Args);
394 (void)CI;
395 // FIXME: Handle codegen for return type semantics.
396 // See: https://github.com/llvm/llvm-project/issues/57875
397 B.CreateRetVoid();
398}
399
400static void gatherFunctions(SmallVectorImpl<Function *> &Fns, llvm::Module &M,
401 bool CtorOrDtor) {
402 const auto *GV =
403 M.getNamedGlobal(Name: CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors");
404 if (!GV)
405 return;
406 const auto *CA = dyn_cast<ConstantArray>(Val: GV->getInitializer());
407 if (!CA)
408 return;
409 // The global_ctor array elements are a struct [Priority, Fn *, COMDat].
410 // HLSL neither supports priorities or COMDat values, so we will check those
411 // in an assert but not handle them.
412
413 llvm::SmallVector<Function *> CtorFns;
414 for (const auto &Ctor : CA->operands()) {
415 if (isa<ConstantAggregateZero>(Val: Ctor))
416 continue;
417 ConstantStruct *CS = cast<ConstantStruct>(Val: Ctor);
418
419 assert(cast<ConstantInt>(CS->getOperand(0))->getValue() == 65535 &&
420 "HLSL doesn't support setting priority for global ctors.");
421 assert(isa<ConstantPointerNull>(CS->getOperand(2)) &&
422 "HLSL doesn't support COMDat for global ctors.");
423 Fns.push_back(Elt: cast<Function>(Val: CS->getOperand(i_nocapture: 1)));
424 }
425}
426
427void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
428 llvm::Module &M = CGM.getModule();
429 SmallVector<Function *> CtorFns;
430 SmallVector<Function *> DtorFns;
431 gatherFunctions(Fns&: CtorFns, M, CtorOrDtor: true);
432 gatherFunctions(Fns&: DtorFns, M, CtorOrDtor: false);
433
434 // Insert a call to the global constructor at the beginning of the entry block
435 // to externally exported functions. This is a bit of a hack, but HLSL allows
436 // global constructors, but doesn't support driver initialization of globals.
437 for (auto &F : M.functions()) {
438 if (!F.hasFnAttribute(Kind: "hlsl.shader"))
439 continue;
440 IRBuilder<> B(&F.getEntryBlock(), F.getEntryBlock().begin());
441 for (auto *Fn : CtorFns)
442 B.CreateCall(Callee: FunctionCallee(Fn));
443
444 // Insert global dtors before the terminator of the last instruction
445 B.SetInsertPoint(F.back().getTerminator());
446 for (auto *Fn : DtorFns)
447 B.CreateCall(Callee: FunctionCallee(Fn));
448 }
449
450 // No need to keep global ctors/dtors for non-lib profile after call to
451 // ctors/dtors added for entry.
452 Triple T(M.getTargetTriple());
453 if (T.getEnvironment() != Triple::EnvironmentType::Library) {
454 if (auto *GV = M.getNamedGlobal(Name: "llvm.global_ctors"))
455 GV->eraseFromParent();
456 if (auto *GV = M.getNamedGlobal(Name: "llvm.global_dtors"))
457 GV->eraseFromParent();
458 }
459}
460