//===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This provides an abstract class for HLSL code generation. Concrete
// subclasses of this implement code generation for specific HLSL
// runtime libraries.
//
//===----------------------------------------------------------------------===//

#include "CGHLSLRuntime.h"
#include "CGDebugInfo.h"
#include "CGRecordLayout.h"
#include "CodeGenFunction.h"
#include "CodeGenModule.h"
#include "HLSLBufferLayoutBuilder.h"
#include "TargetInfo.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/Attrs.inc"
#include "clang/AST/Decl.h"
#include "clang/AST/Expr.h"
#include "clang/AST/HLSLResource.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/Type.h"
#include "clang/Basic/DiagnosticFrontend.h"
#include "clang/Basic/TargetOptions.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include <cstdint>
#include <optional>

using namespace clang;
using namespace CodeGen;
using namespace clang::hlsl;
using namespace llvm;

using llvm::hlsl::CBufferRowSizeInBytes;

namespace {

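// Records the DXIL validator version as module metadata. For example, a
// hypothetical validator version string of "1.8" produces:
//   !dx.valver = !{!0}
//   !0 = !{i32 1, i32 8}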
void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {
  // The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs.
  // Assume ValVersionStr is legal here.
  VersionTuple Version;
  if (Version.tryParse(ValVersionStr) || Version.getBuild() ||
      Version.getSubminor() || !Version.getMinor()) {
    return;
  }

  uint64_t Major = Version.getMajor();
  uint64_t Minor = *Version.getMinor();

  auto &Ctx = M.getContext();
  IRBuilder<> B(M.getContext());
  MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)),
                                  ConstantAsMetadata::get(B.getInt32(Minor))});
  StringRef DXILValKey = "dx.valver";
  auto *DXILValMD = M.getOrInsertNamedMetadata(DXILValKey);
  DXILValMD->addOperand(Val);
}

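// Emits a !dx.rootsignatures metadata operand that associates an entry
// function with its root signature. Each operand roughly has the shape
//   !{ptr @entry, !<root signature elements>, i32 <version>}
// where the entry function may be null when compiling a standalone root
// signature target.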
void addRootSignatureMD(llvm::dxbc::RootSignatureVersion RootSigVer,
                        ArrayRef<llvm::hlsl::rootsig::RootElement> Elements,
                        llvm::Function *Fn, llvm::Module &M) {
  auto &Ctx = M.getContext();

  llvm::hlsl::rootsig::MetadataBuilder RSBuilder(Ctx, Elements);
  MDNode *RootSignature = RSBuilder.BuildRootSignature();

  ConstantAsMetadata *Version = ConstantAsMetadata::get(ConstantInt::get(
      llvm::Type::getInt32Ty(Ctx), llvm::to_underlying(RootSigVer)));
  ValueAsMetadata *EntryFunc = Fn ? ValueAsMetadata::get(Fn) : nullptr;
  MDNode *MDVals = MDNode::get(Ctx, {EntryFunc, RootSignature, Version});

  StringRef RootSignatureValKey = "dx.rootsignatures";
  auto *RootSignatureValMD = M.getOrInsertNamedMetadata(RootSignatureValKey);
  RootSignatureValMD->addOperand(MDVals);
}

// Find the array variable declaration from a DeclRef expression.
static const ValueDecl *getArrayDecl(const Expr *E) {
  if (const DeclRefExpr *DRE =
          dyn_cast_or_null<DeclRefExpr>(E->IgnoreImpCasts()))
    return DRE->getDecl();
  return nullptr;
}

// Find the array variable declaration from nested array subscript AST nodes.
static const ValueDecl *getArrayDecl(const ArraySubscriptExpr *ASE) {
  const Expr *E = nullptr;
  while (ASE != nullptr) {
    E = ASE->getBase()->IgnoreImpCasts();
    if (!E)
      return nullptr;
    ASE = dyn_cast<ArraySubscriptExpr>(E);
  }
  return getArrayDecl(E);
}

// Get the total size of the array, or -1 if the array is unbounded.
static int getTotalArraySize(ASTContext &AST, const clang::Type *Ty) {
  Ty = Ty->getUnqualifiedDesugaredType();
  assert(Ty->isArrayType() && "expected array type");
  if (Ty->isIncompleteArrayType())
    return -1;
  return AST.getConstantArrayElementCount(cast<ConstantArrayType>(Ty));
}

static Value *buildNameForResource(llvm::StringRef BaseName,
                                   CodeGenModule &CGM) {
  llvm::SmallString<64> GlobalName = {BaseName, ".str"};
  return CGM.GetAddrOfConstantCString(BaseName.str(), GlobalName.c_str())
      .getPointer();
}

static CXXMethodDecl *lookupMethod(CXXRecordDecl *Record, StringRef Name,
                                   StorageClass SC = SC_None) {
  for (auto *Method : Record->methods()) {
    if (Method->getStorageClass() == SC && Method->getName() == Name)
      return Method;
  }
  return nullptr;
}

static CXXMethodDecl *lookupResourceInitMethodAndSetupArgs(
    CodeGenModule &CGM, CXXRecordDecl *ResourceDecl, llvm::Value *Range,
    llvm::Value *Index, StringRef Name, ResourceBindingAttrs &Binding,
    CallArgList &Args) {
  assert(Binding.hasBinding() && "at least one binding attribute expected");

  ASTContext &AST = CGM.getContext();
  CXXMethodDecl *CreateMethod = nullptr;
  Value *NameStr = buildNameForResource(Name, CGM);
  Value *Space = llvm::ConstantInt::get(CGM.IntTy, Binding.getSpace());

  if (Binding.isExplicit()) {
    // explicit binding
    auto *RegSlot = llvm::ConstantInt::get(CGM.IntTy, Binding.getSlot());
    Args.add(RValue::get(RegSlot), AST.UnsignedIntTy);
    const char *Name = Binding.hasCounterImplicitOrderID()
                           ? "__createFromBindingWithImplicitCounter"
                           : "__createFromBinding";
    CreateMethod = lookupMethod(ResourceDecl, Name, SC_Static);
  } else {
    // implicit binding
    auto *OrderID =
        llvm::ConstantInt::get(CGM.IntTy, Binding.getImplicitOrderID());
    Args.add(RValue::get(OrderID), AST.UnsignedIntTy);
    const char *Name = Binding.hasCounterImplicitOrderID()
                           ? "__createFromImplicitBindingWithImplicitCounter"
                           : "__createFromImplicitBinding";
    CreateMethod = lookupMethod(ResourceDecl, Name, SC_Static);
  }
  Args.add(RValue::get(Space), AST.UnsignedIntTy);
  Args.add(RValue::get(Range), AST.IntTy);
  Args.add(RValue::get(Index), AST.UnsignedIntTy);
  Args.add(RValue::get(NameStr), AST.getPointerType(AST.CharTy.withConst()));
  if (Binding.hasCounterImplicitOrderID()) {
    uint32_t CounterBinding = Binding.getCounterImplicitOrderID();
    auto *CounterOrderID = llvm::ConstantInt::get(CGM.IntTy, CounterBinding);
    Args.add(RValue::get(CounterOrderID), AST.UnsignedIntTy);
  }

  return CreateMethod;
}

static void callResourceInitMethod(CodeGenFunction &CGF,
                                   CXXMethodDecl *CreateMethod,
                                   CallArgList &Args, Address ReturnAddress) {
  llvm::Constant *CalleeFn = CGF.CGM.GetAddrOfFunction(CreateMethod);
  const FunctionProtoType *Proto =
      CreateMethod->getType()->getAs<FunctionProtoType>();
  const CGFunctionInfo &FnInfo =
      CGF.CGM.getTypes().arrangeFreeFunctionCall(Args, Proto, false);
  ReturnValueSlot ReturnValue(ReturnAddress, false);
  CGCallee Callee(CGCalleeInfo(Proto), CalleeFn);
  CGF.EmitCall(FnInfo, Callee, ReturnValue, Args, nullptr);
}

// Initializes a local resource array variable. For multi-dimensional arrays
// it calls itself recursively to initialize its sub-arrays. The Index used in
// the resource constructor calls will begin at StartIndex and will be
// incremented for each array element. The last used resource Index is
// returned to the caller. If the function returns std::nullopt, it indicates
// an error.
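// For example, given a hypothetical declaration
//   RWBuffer<float> Buf[2][2] : register(u0);
// the four elements are constructed with flattened indices StartIndex + 0..3,
// while GEPIndices addresses the corresponding slot of the nested
// [2 x [2 x handle]] temporary.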
static std::optional<llvm::Value *> initializeLocalResourceArray(
    CodeGenFunction &CGF, CXXRecordDecl *ResourceDecl,
    const ConstantArrayType *ArrayTy, AggValueSlot &ValueSlot,
    llvm::Value *Range, llvm::Value *StartIndex, StringRef ResourceName,
    ResourceBindingAttrs &Binding, ArrayRef<llvm::Value *> PrevGEPIndices,
    SourceLocation ArraySubsExprLoc) {

  ASTContext &AST = CGF.getContext();
  llvm::IntegerType *IntTy = CGF.CGM.IntTy;
  llvm::Value *Index = StartIndex;
  llvm::Value *One = llvm::ConstantInt::get(IntTy, 1);
  const uint64_t ArraySize = ArrayTy->getSExtSize();
  QualType ElemType = ArrayTy->getElementType();
  Address TmpArrayAddr = ValueSlot.getAddress();

  // Add an additional index to the getelementptr call indices.
  // This index will be updated for each array element in the loops below.
  SmallVector<llvm::Value *> GEPIndices(PrevGEPIndices);
  GEPIndices.push_back(llvm::ConstantInt::get(IntTy, 0));

  // For an array of arrays, recursively initialize the sub-arrays.
  if (ElemType->isArrayType()) {
    const ConstantArrayType *SubArrayTy = cast<ConstantArrayType>(ElemType);
    for (uint64_t I = 0; I < ArraySize; I++) {
      if (I > 0) {
        Index = CGF.Builder.CreateAdd(Index, One);
        GEPIndices.back() = llvm::ConstantInt::get(IntTy, I);
      }
      std::optional<llvm::Value *> MaybeIndex = initializeLocalResourceArray(
          CGF, ResourceDecl, SubArrayTy, ValueSlot, Range, Index, ResourceName,
          Binding, GEPIndices, ArraySubsExprLoc);
      if (!MaybeIndex)
        return std::nullopt;
      Index = *MaybeIndex;
    }
    return Index;
  }

  // For an array of resources, initialize each resource in the array.
  llvm::Type *Ty = CGF.ConvertTypeForMem(ElemType);
  CharUnits ElemSize = AST.getTypeSizeInChars(ElemType);
  CharUnits Align =
      TmpArrayAddr.getAlignment().alignmentOfArrayElement(ElemSize);

  for (uint64_t I = 0; I < ArraySize; I++) {
    if (I > 0) {
      Index = CGF.Builder.CreateAdd(Index, One);
      GEPIndices.back() = llvm::ConstantInt::get(IntTy, I);
    }
    Address ReturnAddress =
        CGF.Builder.CreateGEP(TmpArrayAddr, GEPIndices, Ty, Align);

    CallArgList Args;
    CXXMethodDecl *CreateMethod = lookupResourceInitMethodAndSetupArgs(
        CGF.CGM, ResourceDecl, Range, Index, ResourceName, Binding, Args);

    if (!CreateMethod)
      // This can happen if someone creates an array of structs that looks
      // like an HLSL resource record array but does not have the required
      // static create method. No binding will be generated for it.
      return std::nullopt;

    callResourceInitMethod(CGF, CreateMethod, Args, ReturnAddress);
  }
  return Index;
}

} // namespace

llvm::Type *
CGHLSLRuntime::convertHLSLSpecificType(const Type *T,
                                       const CGHLSLOffsetInfo &OffsetInfo) {
  assert(T->isHLSLSpecificType() && "Not an HLSL specific type!");

  // Check if the target has a specific translation for this type first.
  if (llvm::Type *TargetTy =
          CGM.getTargetCodeGenInfo().getHLSLType(CGM, T, OffsetInfo))
    return TargetTy;

  llvm_unreachable("Generic handling of HLSL types is not supported.");
}

llvm::Triple::ArchType CGHLSLRuntime::getArch() {
  return CGM.getTarget().getTriple().getArch();
}

// Emits constant global variables for buffer constant declarations
// and creates metadata linking the constant globals with the buffer global.
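// For example, a cbuffer named CB with constants x and y may produce:
//   !hlsl.cbs = !{!0}
//   !0 = !{ptr @CB.cb, ptr @x, ptr @y}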
void CGHLSLRuntime::emitBufferGlobalsAndMetadata(
    const HLSLBufferDecl *BufDecl, llvm::GlobalVariable *BufGV,
    const CGHLSLOffsetInfo &OffsetInfo) {
  LLVMContext &Ctx = CGM.getLLVMContext();

  // Get the layout struct from the constant buffer target type.
  llvm::Type *BufType = BufGV->getValueType();
  llvm::StructType *LayoutStruct = cast<llvm::StructType>(
      cast<llvm::TargetExtType>(BufType)->getTypeParameter(0));

  SmallVector<std::pair<VarDecl *, uint32_t>> DeclsWithOffset;
  size_t OffsetIdx = 0;
  for (Decl *D : BufDecl->buffer_decls()) {
    if (isa<CXXRecordDecl, EmptyDecl>(D))
      // Nothing to do for this declaration.
      continue;
    if (isa<FunctionDecl>(D)) {
      // A function within a cbuffer is effectively a top-level function.
      CGM.EmitTopLevelDecl(D);
      continue;
    }
    VarDecl *VD = dyn_cast<VarDecl>(D);
    if (!VD)
      continue;

    QualType VDTy = VD->getType();
    if (VDTy.getAddressSpace() != LangAS::hlsl_constant) {
      if (VD->getStorageClass() == SC_Static ||
          VDTy.getAddressSpace() == LangAS::hlsl_groupshared ||
          VDTy->isHLSLResourceRecord() || VDTy->isHLSLResourceRecordArray()) {
        // Emit static and groupshared variables and resource classes inside
        // a cbuffer as regular globals.
        CGM.EmitGlobal(VD);
      } else {
        // Anything else that is not in the hlsl_constant address space must
        // be an empty struct or a zero-sized array and can be ignored.
        assert(BufDecl->getASTContext().getTypeSize(VDTy) == 0 &&
               "constant buffer decl with non-zero sized type outside of "
               "hlsl_constant address space");
      }
      continue;
    }

    DeclsWithOffset.emplace_back(VD, OffsetInfo[OffsetIdx++]);
  }

  if (!OffsetInfo.empty())
    llvm::stable_sort(DeclsWithOffset, [](const auto &LHS, const auto &RHS) {
      return CGHLSLOffsetInfo::compareOffsets(LHS.second, RHS.second);
    });

  // Associate the buffer global variable with its constants.
  SmallVector<llvm::Metadata *> BufGlobals;
  BufGlobals.reserve(DeclsWithOffset.size() + 1);
  BufGlobals.push_back(ValueAsMetadata::get(BufGV));

  auto ElemIt = LayoutStruct->element_begin();
  for (auto &[VD, _] : DeclsWithOffset) {
    if (CGM.getTargetCodeGenInfo().isHLSLPadding(*ElemIt))
      ++ElemIt;

    assert(ElemIt != LayoutStruct->element_end() &&
           "number of elements in layout struct does not match");
    llvm::Type *LayoutType = *ElemIt++;

    GlobalVariable *ElemGV =
        cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(VD, LayoutType));
    BufGlobals.push_back(ValueAsMetadata::get(ElemGV));
  }
  assert(ElemIt == LayoutStruct->element_end() &&
         "number of elements in layout struct does not match");

  // Add the buffer metadata to the module.
  CGM.getModule()
      .getOrInsertNamedMetadata("hlsl.cbs")
      ->addOperand(MDNode::get(Ctx, BufGlobals));
}

// Creates the resource handle type for the HLSL buffer declaration.
static const clang::HLSLAttributedResourceType *
createBufferHandleType(const HLSLBufferDecl *BufDecl) {
  ASTContext &AST = BufDecl->getASTContext();
  QualType QT = AST.getHLSLAttributedResourceType(
      AST.HLSLResourceTy, AST.getCanonicalTagType(BufDecl->getLayoutStruct()),
      HLSLAttributedResourceType::Attributes(ResourceClass::CBuffer));
  return cast<HLSLAttributedResourceType>(QT.getTypePtr());
}

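// Collects the byte offset of each constant declared in the buffer, if
// offsets were specified. For example, assuming a declaration of
//   cbuffer CB { float X : packoffset(c1); }
// the recorded offset for X is 16 bytes, since each cbuffer row is 16 bytes
// wide. Constants without an offset are recorded as Unspecified.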
CGHLSLOffsetInfo CGHLSLOffsetInfo::fromDecl(const HLSLBufferDecl &BufDecl) {
  CGHLSLOffsetInfo Result;

  // If we don't have packoffset info, just return an empty result.
  if (!BufDecl.hasValidPackoffset())
    return Result;

  for (Decl *D : BufDecl.buffer_decls()) {
    if (isa<CXXRecordDecl, EmptyDecl>(D) || isa<FunctionDecl>(D)) {
      continue;
    }
    VarDecl *VD = dyn_cast<VarDecl>(D);
    if (!VD || VD->getType().getAddressSpace() != LangAS::hlsl_constant)
      continue;

    if (!VD->hasAttrs()) {
      Result.Offsets.push_back(Unspecified);
      continue;
    }

    uint32_t Offset = Unspecified;
    for (auto *Attr : VD->getAttrs()) {
      if (auto *POA = dyn_cast<HLSLPackOffsetAttr>(Attr)) {
        Offset = POA->getOffsetInBytes();
        break;
      }
      auto *RBA = dyn_cast<HLSLResourceBindingAttr>(Attr);
      if (RBA &&
          RBA->getRegisterType() == HLSLResourceBindingAttr::RegisterType::C) {
        Offset = RBA->getSlotNumber() * CBufferRowSizeInBytes;
        break;
      }
    }
    Result.Offsets.push_back(Offset);
  }
  return Result;
}

// Codegen for HLSLBufferDecl
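// For example, `cbuffer MyConstants { ... }` produces a module-level global
// named "MyConstants.cb" whose type is the target-specific cbuffer handle
// type wrapping the buffer's layout struct.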
void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *BufDecl) {

  assert(BufDecl->isCBuffer() && "tbuffer codegen is not supported yet");

  // Create the resource handle type for the buffer.
  const clang::HLSLAttributedResourceType *ResHandleTy =
      createBufferHandleType(BufDecl);

  // An empty constant buffer is ignored.
  if (ResHandleTy->getContainedType()->getAsCXXRecordDecl()->isEmpty())
    return;

  // Create the global variable for the constant buffer.
  CGHLSLOffsetInfo OffsetInfo = CGHLSLOffsetInfo::fromDecl(*BufDecl);
  llvm::Type *LayoutTy = convertHLSLSpecificType(ResHandleTy, OffsetInfo);
  llvm::GlobalVariable *BufGV = new GlobalVariable(
      LayoutTy, /*isConstant*/ false,
      GlobalValue::LinkageTypes::ExternalLinkage, PoisonValue::get(LayoutTy),
      llvm::formatv("{0}{1}", BufDecl->getName(),
                    BufDecl->isCBuffer() ? ".cb" : ".tb"),
      GlobalValue::NotThreadLocal);
  CGM.getModule().insertGlobalVariable(BufGV);

  // Add globals for constant buffer elements and create metadata nodes.
  emitBufferGlobalsAndMetadata(BufDecl, BufGV, OffsetInfo);

  // Initialize the cbuffer from its binding (implicit or explicit).
  initializeBufferFromBinding(BufDecl, BufGV);
}

void CGHLSLRuntime::addRootSignature(
    const HLSLRootSignatureDecl *SignatureDecl) {
  llvm::Module &M = CGM.getModule();
  Triple T(M.getTargetTriple());

  // Generated later with the function decl if not targeting root signature.
  if (T.getEnvironment() != Triple::EnvironmentType::RootSignature)
    return;

  addRootSignatureMD(SignatureDecl->getVersion(),
                     SignatureDecl->getRootElements(), nullptr, M);
}

llvm::StructType *
CGHLSLRuntime::getHLSLBufferLayoutType(const RecordType *StructType) {
  const auto Entry = LayoutTypes.find(StructType);
  if (Entry != LayoutTypes.end())
    return Entry->getSecond();
  return nullptr;
}

void CGHLSLRuntime::addHLSLBufferLayoutType(const RecordType *StructType,
                                            llvm::StructType *LayoutTy) {
  assert(getHLSLBufferLayoutType(StructType) == nullptr &&
         "layout type for this struct already exists");
  LayoutTypes[StructType] = LayoutTy;
}

void CGHLSLRuntime::finishCodeGen() {
  auto &TargetOpts = CGM.getTarget().getTargetOpts();
  auto &CodeGenOpts = CGM.getCodeGenOpts();
  auto &LangOpts = CGM.getLangOpts();
  llvm::Module &M = CGM.getModule();
  Triple T(M.getTargetTriple());
  if (T.getArch() == Triple::ArchType::dxil)
    addDxilValVersion(TargetOpts.DxilValidatorVersion, M);
  if (CodeGenOpts.ResMayAlias)
    M.setModuleFlag(llvm::Module::ModFlagBehavior::Error, "dx.resmayalias", 1);
  if (CodeGenOpts.AllResourcesBound)
    M.setModuleFlag(llvm::Module::ModFlagBehavior::Error,
                    "dx.allresourcesbound", 1);
  // NativeHalfType corresponds to the -fnative-half-type clang option, which
  // is aliased by clang-dxc's -enable-16bit-types option. This option is used
  // to set the UseNativeLowPrecision DXIL module flag in the DirectX backend.
  if (LangOpts.NativeHalfType)
    M.setModuleFlag(llvm::Module::ModFlagBehavior::Error, "dx.nativelowprec",
                    1);

  generateGlobalCtorDtorCalls();
}

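// Translates HLSL entry-point attributes into string function attributes on
// the entry. For example, a compute shader declared with [numthreads(8, 8, 1)]
// ends up with "hlsl.shader"="compute" and "hlsl.numthreads"="8,8,1".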
void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
    const FunctionDecl *FD, llvm::Function *Fn) {
  const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
  assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
  const StringRef ShaderAttrKindStr = "hlsl.shader";
  Fn->addFnAttr(ShaderAttrKindStr,
                llvm::Triple::getEnvironmentTypeName(ShaderAttr->getType()));
  if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {
    const StringRef NumThreadsKindStr = "hlsl.numthreads";
    std::string NumThreadsStr =
        formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(),
                NumThreadsAttr->getZ());
    Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr);
  }
  if (HLSLWaveSizeAttr *WaveSizeAttr = FD->getAttr<HLSLWaveSizeAttr>()) {
    const StringRef WaveSizeKindStr = "hlsl.wavesize";
    std::string WaveSizeStr =
        formatv("{0},{1},{2}", WaveSizeAttr->getMin(), WaveSizeAttr->getMax(),
                WaveSizeAttr->getPreferred());
    Fn->addFnAttr(WaveSizeKindStr, WaveSizeStr);
  }
  // HLSL entry functions are materialized from module functions with the
  // HLSLShaderAttr attribute. SetLLVMFunctionAttributesForDefinition, called
  // later in the compilation flow for such module functions, is not aware of
  // the newly materialized entry functions and hence cannot set their
  // attributes. So, set the attributes of the entry function here, as
  // appropriate.
  if (CGM.getCodeGenOpts().OptimizationLevel == 0)
    Fn->addFnAttr(llvm::Attribute::OptimizeNone);
  Fn->addFnAttr(llvm::Attribute::NoInline);

  if (CGM.getLangOpts().HLSLSpvEnableMaximalReconvergence) {
    Fn->addFnAttr("enable-maximal-reconvergence", "true");
  }
}

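// Builds an input value by calling F once per vector element, or once for a
// scalar. For example, a <3 x i32> input becomes three calls F(0), F(1) and
// F(2) whose results are inserted into an initially-poison vector.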
static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) {
  if (const auto *VT = dyn_cast<FixedVectorType>(Ty)) {
    Value *Result = PoisonValue::get(Ty);
    for (unsigned I = 0; I < VT->getNumElements(); ++I) {
      Value *Elt = B.CreateCall(F, {B.getInt32(I)});
      Result = B.CreateInsertElement(Result, Elt, I);
    }
    return Result;
  }
  return B.CreateCall(F, {B.getInt32(0)});
}

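// Attaches a !spirv.Decorations node marking the global as a SPIR-V built-in.
// For example, decorating a global as FragCoord (built-in ID 15) yields
//   !{!{i32 11, i32 15}}
// where 11 is Spirv::Decoration::BuiltIn.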
static void addSPIRVBuiltinDecoration(llvm::GlobalVariable *GV,
                                      unsigned BuiltIn) {
  LLVMContext &Ctx = GV->getContext();
  IRBuilder<> B(GV->getContext());
  MDNode *Operands = MDNode::get(
      Ctx,
      {ConstantAsMetadata::get(B.getInt32(/* Spirv::Decoration::BuiltIn */ 11)),
       ConstantAsMetadata::get(B.getInt32(BuiltIn))});
  MDNode *Decoration = MDNode::get(Ctx, {Operands});
  GV->addMetadata("spirv.Decorations", *Decoration);
}

static void addLocationDecoration(llvm::GlobalVariable *GV, unsigned Location) {
  LLVMContext &Ctx = GV->getContext();
  IRBuilder<> B(GV->getContext());
  MDNode *Operands =
      MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(/* Location */ 30)),
                        ConstantAsMetadata::get(B.getInt32(Location))});
  MDNode *Decoration = MDNode::get(Ctx, {Operands});
  GV->addMetadata("spirv.Decorations", *Decoration);
}

static llvm::Value *createSPIRVBuiltinLoad(IRBuilder<> &B, llvm::Module &M,
                                           llvm::Type *Ty, const Twine &Name,
                                           unsigned BuiltInID) {
  auto *GV = new llvm::GlobalVariable(
      M, Ty, /* isConstant= */ true, llvm::GlobalValue::ExternalLinkage,
      /* Initializer= */ nullptr, Name, /* insertBefore= */ nullptr,
      llvm::GlobalVariable::GeneralDynamicTLSModel,
      /* AddressSpace */ 7, /* isExternallyInitialized= */ true);
  addSPIRVBuiltinDecoration(GV, BuiltInID);
  GV->setVisibility(llvm::GlobalValue::HiddenVisibility);
  return B.CreateLoad(Ty, GV);
}

static llvm::Value *createSPIRVLocationLoad(IRBuilder<> &B, llvm::Module &M,
                                            llvm::Type *Ty, unsigned Location,
                                            StringRef Name) {
  auto *GV = new llvm::GlobalVariable(
      M, Ty, /* isConstant= */ true, llvm::GlobalValue::ExternalLinkage,
      /* Initializer= */ nullptr, /* Name= */ Name, /* insertBefore= */ nullptr,
      llvm::GlobalVariable::GeneralDynamicTLSModel,
      /* AddressSpace */ 7, /* isExternallyInitialized= */ true);
  GV->setVisibility(llvm::GlobalValue::HiddenVisibility);
  addLocationDecoration(GV, Location);
  return B.CreateLoad(Ty, GV);
}

llvm::Value *CGHLSLRuntime::emitSPIRVUserSemanticLoad(
    llvm::IRBuilder<> &B, llvm::Type *Type, const clang::DeclaratorDecl *Decl,
    HLSLAppliedSemanticAttr *Semantic, std::optional<unsigned> Index) {
  Twine BaseName = Twine(Semantic->getAttrName()->getName());
  Twine VariableName = BaseName.concat(Twine(Index.value_or(0)));

  unsigned Location = SPIRVLastAssignedInputSemanticLocation;
  if (auto *L = Decl->getAttr<HLSLVkLocationAttr>())
    Location = L->getLocation();

  // DXC completely ignores the semantic/index pair. Locations are assigned
  // from the first semantic to the last.
  llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Type);
  unsigned ElementCount = AT ? AT->getNumElements() : 1;
  SPIRVLastAssignedInputSemanticLocation += ElementCount;

  return createSPIRVLocationLoad(B, CGM.getModule(), Type, Location,
                                 VariableName.str());
}

static void createSPIRVLocationStore(IRBuilder<> &B, llvm::Module &M,
                                     llvm::Value *Source, unsigned Location,
                                     StringRef Name) {
  auto *GV = new llvm::GlobalVariable(
      M, Source->getType(), /* isConstant= */ false,
      llvm::GlobalValue::ExternalLinkage,
      /* Initializer= */ nullptr, /* Name= */ Name, /* insertBefore= */ nullptr,
      llvm::GlobalVariable::GeneralDynamicTLSModel,
      /* AddressSpace */ 8, /* isExternallyInitialized= */ false);
  GV->setVisibility(llvm::GlobalValue::HiddenVisibility);
  addLocationDecoration(GV, Location);
  B.CreateStore(Source, GV);
}

void CGHLSLRuntime::emitSPIRVUserSemanticStore(
    llvm::IRBuilder<> &B, llvm::Value *Source,
    const clang::DeclaratorDecl *Decl, HLSLAppliedSemanticAttr *Semantic,
    std::optional<unsigned> Index) {
  Twine BaseName = Twine(Semantic->getAttrName()->getName());
  Twine VariableName = BaseName.concat(Twine(Index.value_or(0)));

  unsigned Location = SPIRVLastAssignedOutputSemanticLocation;
  if (auto *L = Decl->getAttr<HLSLVkLocationAttr>())
    Location = L->getLocation();

  // DXC completely ignores the semantic/index pair. Locations are assigned
  // from the first semantic to the last.
  llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Source->getType());
  unsigned ElementCount = AT ? AT->getNumElements() : 1;
  SPIRVLastAssignedOutputSemanticLocation += ElementCount;
  createSPIRVLocationStore(B, CGM.getModule(), Source, Location,
                           VariableName.str());
}

llvm::Value *
CGHLSLRuntime::emitDXILUserSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
                                        HLSLAppliedSemanticAttr *Semantic,
                                        std::optional<unsigned> Index) {
  Twine BaseName = Twine(Semantic->getAttrName()->getName());
  Twine VariableName = BaseName.concat(Twine(Index.value_or(0)));

  // DXIL packing rules, etc., will be handled here.
  // FIXME: generate proper sigpoint, index, col, row values.
  // FIXME: also DXIL loads vectors element by element.
  SmallVector<Value *> Args{B.getInt32(4), B.getInt32(0), B.getInt32(0),
                            B.getInt8(0),
                            llvm::PoisonValue::get(B.getInt32Ty())};

  llvm::Intrinsic::ID IntrinsicID = llvm::Intrinsic::dx_load_input;
  llvm::Value *Value = B.CreateIntrinsic(/*ReturnType=*/Type, IntrinsicID, Args,
                                         nullptr, VariableName);
  return Value;
}

void CGHLSLRuntime::emitDXILUserSemanticStore(llvm::IRBuilder<> &B,
                                              llvm::Value *Source,
                                              HLSLAppliedSemanticAttr *Semantic,
                                              std::optional<unsigned> Index) {
  // DXIL packing rules, etc., will be handled here.
  // FIXME: generate proper sigpoint, index, col, row values.
  SmallVector<Value *> Args{B.getInt32(4),
                            B.getInt32(0),
                            B.getInt32(0),
                            B.getInt8(0),
                            llvm::PoisonValue::get(B.getInt32Ty()),
                            Source};

  llvm::Intrinsic::ID IntrinsicID = llvm::Intrinsic::dx_store_output;
  B.CreateIntrinsic(/*ReturnType=*/CGM.VoidTy, IntrinsicID, Args, nullptr);
}

llvm::Value *CGHLSLRuntime::emitUserSemanticLoad(
    IRBuilder<> &B, llvm::Type *Type, const clang::DeclaratorDecl *Decl,
    HLSLAppliedSemanticAttr *Semantic, std::optional<unsigned> Index) {
  if (CGM.getTarget().getTriple().isSPIRV())
    return emitSPIRVUserSemanticLoad(B, Type, Decl, Semantic, Index);

  if (CGM.getTarget().getTriple().isDXIL())
    return emitDXILUserSemanticLoad(B, Type, Semantic, Index);

  llvm_unreachable("Unsupported target for user-semantic load.");
}

void CGHLSLRuntime::emitUserSemanticStore(IRBuilder<> &B, llvm::Value *Source,
                                          const clang::DeclaratorDecl *Decl,
                                          HLSLAppliedSemanticAttr *Semantic,
                                          std::optional<unsigned> Index) {
  if (CGM.getTarget().getTriple().isSPIRV())
    return emitSPIRVUserSemanticStore(B, Source, Decl, Semantic, Index);

  if (CGM.getTarget().getTriple().isDXIL())
    return emitDXILUserSemanticStore(B, Source, Semantic, Index);

  llvm_unreachable("Unsupported target for user-semantic store.");
}

llvm::Value *CGHLSLRuntime::emitSystemSemanticLoad(
    IRBuilder<> &B, const FunctionDecl *FD, llvm::Type *Type,
    const clang::DeclaratorDecl *Decl, HLSLAppliedSemanticAttr *Semantic,
    std::optional<unsigned> Index) {

  std::string SemanticName = Semantic->getAttrName()->getName().upper();
  if (SemanticName == "SV_GROUPINDEX") {
    llvm::Function *GroupIndex =
        CGM.getIntrinsic(getFlattenedThreadIdInGroupIntrinsic());
    return B.CreateCall(FunctionCallee(GroupIndex));
  }

  if (SemanticName == "SV_DISPATCHTHREADID") {
    llvm::Intrinsic::ID IntrinID = getThreadIdIntrinsic();
    llvm::Function *ThreadIDIntrinsic =
        llvm::Intrinsic::isOverloaded(IntrinID)
            ? CGM.getIntrinsic(IntrinID, {CGM.Int32Ty})
            : CGM.getIntrinsic(IntrinID);
    return buildVectorInput(B, ThreadIDIntrinsic, Type);
  }

  if (SemanticName == "SV_GROUPTHREADID") {
    llvm::Intrinsic::ID IntrinID = getGroupThreadIdIntrinsic();
    llvm::Function *GroupThreadIDIntrinsic =
        llvm::Intrinsic::isOverloaded(IntrinID)
            ? CGM.getIntrinsic(IntrinID, {CGM.Int32Ty})
            : CGM.getIntrinsic(IntrinID);
    return buildVectorInput(B, GroupThreadIDIntrinsic, Type);
  }

  if (SemanticName == "SV_GROUPID") {
    llvm::Intrinsic::ID IntrinID = getGroupIdIntrinsic();
    llvm::Function *GroupIDIntrinsic =
        llvm::Intrinsic::isOverloaded(IntrinID)
            ? CGM.getIntrinsic(IntrinID, {CGM.Int32Ty})
            : CGM.getIntrinsic(IntrinID);
    return buildVectorInput(B, GroupIDIntrinsic, Type);
  }

  const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
  assert(ShaderAttr && "Entry point has no shader attribute");
  llvm::Triple::EnvironmentType ST = ShaderAttr->getType();

  if (SemanticName == "SV_POSITION") {
    if (ST == Triple::EnvironmentType::Pixel) {
      if (CGM.getTarget().getTriple().isSPIRV())
        return createSPIRVBuiltinLoad(B, CGM.getModule(), Type,
                                      Semantic->getAttrName()->getName(),
                                      /* BuiltIn::FragCoord */ 15);
      if (CGM.getTarget().getTriple().isDXIL())
        return emitDXILUserSemanticLoad(B, Type, Semantic, Index);
    }

    if (ST == Triple::EnvironmentType::Vertex) {
      return emitUserSemanticLoad(B, Type, Decl, Semantic, Index);
    }
  }

  llvm_unreachable(
      "Load hasn't been implemented yet for this system semantic. FIXME");
}

static void createSPIRVBuiltinStore(IRBuilder<> &B, llvm::Module &M,
                                    llvm::Value *Source, const Twine &Name,
                                    unsigned BuiltInID) {
  auto *GV = new llvm::GlobalVariable(
      M, Source->getType(), /* isConstant= */ false,
      llvm::GlobalValue::ExternalLinkage,
      /* Initializer= */ nullptr, Name, /* insertBefore= */ nullptr,
      llvm::GlobalVariable::GeneralDynamicTLSModel,
      /* AddressSpace */ 8, /* isExternallyInitialized= */ false);
  addSPIRVBuiltinDecoration(GV, BuiltInID);
  GV->setVisibility(llvm::GlobalValue::HiddenVisibility);
  B.CreateStore(Source, GV);
}

void CGHLSLRuntime::emitSystemSemanticStore(IRBuilder<> &B, llvm::Value *Source,
                                            const clang::DeclaratorDecl *Decl,
                                            HLSLAppliedSemanticAttr *Semantic,
                                            std::optional<unsigned> Index) {

  std::string SemanticName = Semantic->getAttrName()->getName().upper();
  if (SemanticName == "SV_POSITION") {
    if (CGM.getTarget().getTriple().isDXIL()) {
      emitDXILUserSemanticStore(B, Source, Semantic, Index);
      return;
    }

    if (CGM.getTarget().getTriple().isSPIRV()) {
      createSPIRVBuiltinStore(B, CGM.getModule(), Source,
                              Semantic->getAttrName()->getName(),
                              /* BuiltIn::Position */ 0);
      return;
    }
  }

  if (SemanticName == "SV_TARGET") {
    emitUserSemanticStore(B, Source, Decl, Semantic, Index);
    return;
  }

  llvm_unreachable(
      "Store hasn't been implemented yet for this system semantic. FIXME");
}

llvm::Value *CGHLSLRuntime::handleScalarSemanticLoad(
    IRBuilder<> &B, const FunctionDecl *FD, llvm::Type *Type,
    const clang::DeclaratorDecl *Decl, HLSLAppliedSemanticAttr *Semantic) {

  std::optional<unsigned> Index = Semantic->getSemanticIndex();
  if (Semantic->getAttrName()->getName().starts_with_insensitive("SV_"))
    return emitSystemSemanticLoad(B, FD, Type, Decl, Semantic, Index);
  return emitUserSemanticLoad(B, Type, Decl, Semantic, Index);
}

void CGHLSLRuntime::handleScalarSemanticStore(
    IRBuilder<> &B, const FunctionDecl *FD, llvm::Value *Source,
    const clang::DeclaratorDecl *Decl, HLSLAppliedSemanticAttr *Semantic) {
  std::optional<unsigned> Index = Semantic->getSemanticIndex();
  if (Semantic->getAttrName()->getName().starts_with_insensitive("SV_"))
    emitSystemSemanticStore(B, Source, Decl, Semantic, Index);
  else
    emitUserSemanticStore(B, Source, Decl, Semantic, Index);
}

std::pair<llvm::Value *, specific_attr_iterator<HLSLAppliedSemanticAttr>>
CGHLSLRuntime::handleStructSemanticLoad(
    IRBuilder<> &B, const FunctionDecl *FD, llvm::Type *Type,
    const clang::DeclaratorDecl *Decl,
    specific_attr_iterator<HLSLAppliedSemanticAttr> AttrBegin,
    specific_attr_iterator<HLSLAppliedSemanticAttr> AttrEnd) {
  const llvm::StructType *ST = cast<StructType>(Type);
  const clang::RecordDecl *RD = Decl->getType()->getAsRecordDecl();

  assert(RD->getNumFields() == ST->getNumElements());

  llvm::Value *Aggregate = llvm::PoisonValue::get(Type);
  auto FieldDecl = RD->field_begin();
  for (unsigned I = 0; I < ST->getNumElements(); ++I) {
    auto [ChildValue, NextAttr] = handleSemanticLoad(
        B, FD, ST->getElementType(I), *FieldDecl, AttrBegin, AttrEnd);
    AttrBegin = NextAttr;
    assert(ChildValue);
    Aggregate = B.CreateInsertValue(Aggregate, ChildValue, I);
    ++FieldDecl;
  }

  return std::make_pair(Aggregate, AttrBegin);
}

specific_attr_iterator<HLSLAppliedSemanticAttr>
CGHLSLRuntime::handleStructSemanticStore(
    IRBuilder<> &B, const FunctionDecl *FD, llvm::Value *Source,
    const clang::DeclaratorDecl *Decl,
    specific_attr_iterator<HLSLAppliedSemanticAttr> AttrBegin,
    specific_attr_iterator<HLSLAppliedSemanticAttr> AttrEnd) {

  const llvm::StructType *ST = cast<StructType>(Source->getType());

  const clang::RecordDecl *RD = nullptr;
  if (const FunctionDecl *FD = dyn_cast<FunctionDecl>(Decl))
    RD = FD->getDeclaredReturnType()->getAsRecordDecl();
  else
    RD = Decl->getType()->getAsRecordDecl();
  assert(RD);

  assert(RD->getNumFields() == ST->getNumElements());

  auto FieldDecl = RD->field_begin();
  for (unsigned I = 0; I < ST->getNumElements(); ++I) {
    llvm::Value *Extract = B.CreateExtractValue(Source, I);
    AttrBegin =
        handleSemanticStore(B, FD, Extract, *FieldDecl, AttrBegin, AttrEnd);
  }

  return AttrBegin;
}

std::pair<llvm::Value *, specific_attr_iterator<HLSLAppliedSemanticAttr>>
CGHLSLRuntime::handleSemanticLoad(
    IRBuilder<> &B, const FunctionDecl *FD, llvm::Type *Type,
    const clang::DeclaratorDecl *Decl,
    specific_attr_iterator<HLSLAppliedSemanticAttr> AttrBegin,
    specific_attr_iterator<HLSLAppliedSemanticAttr> AttrEnd) {
  assert(AttrBegin != AttrEnd);
  if (Type->isStructTy())
    return handleStructSemanticLoad(B, FD, Type, Decl, AttrBegin, AttrEnd);

  HLSLAppliedSemanticAttr *Attr = *AttrBegin;
  ++AttrBegin;
  return std::make_pair(handleScalarSemanticLoad(B, FD, Type, Decl, Attr),
                        AttrBegin);
}

specific_attr_iterator<HLSLAppliedSemanticAttr>
CGHLSLRuntime::handleSemanticStore(
    IRBuilder<> &B, const FunctionDecl *FD, llvm::Value *Source,
    const clang::DeclaratorDecl *Decl,
    specific_attr_iterator<HLSLAppliedSemanticAttr> AttrBegin,
    specific_attr_iterator<HLSLAppliedSemanticAttr> AttrEnd) {
  assert(AttrBegin != AttrEnd);
  if (Source->getType()->isStructTy())
    return handleStructSemanticStore(B, FD, Source, Decl, AttrBegin, AttrEnd);

  HLSLAppliedSemanticAttr *Attr = *AttrBegin;
  ++AttrBegin;
  handleScalarSemanticStore(B, FD, Source, Decl, Attr);
  return AttrBegin;
}

void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
                                      llvm::Function *Fn) {
  llvm::Module &M = CGM.getModule();
  llvm::LLVMContext &Ctx = M.getContext();
  auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
  Function *EntryFn =
      Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M);

  // Copy function attributes over; we have no argument or return attributes
  // that can be valid on the real entry.
  AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex,
                                              Fn->getAttributes().getFnAttrs());
  EntryFn->setAttributes(NewAttrs);
  setHLSLEntryAttributes(FD, EntryFn);

  // Set the called function as internal linkage.
  Fn->setLinkage(GlobalValue::InternalLinkage);

  BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn);
  IRBuilder<> B(BB);
  llvm::SmallVector<Value *> Args;

  SmallVector<OperandBundleDef, 1> OB;
  if (CGM.shouldEmitConvergenceTokens()) {
    assert(EntryFn->isConvergent());
    llvm::Value *I =
        B.CreateIntrinsic(llvm::Intrinsic::experimental_convergence_entry, {});
    llvm::Value *bundleArgs[] = {I};
    OB.emplace_back("convergencectrl", bundleArgs);
  }

  llvm::DenseMap<const DeclaratorDecl *, llvm::Value *> OutputSemantic;

  unsigned SRetOffset = 0;
  for (const auto &Param : Fn->args()) {
    if (Param.hasStructRetAttr()) {
      SRetOffset = 1;
      llvm::Type *VarType = Param.getParamStructRetType();
      llvm::Value *Var = B.CreateAlloca(VarType);
      OutputSemantic.try_emplace(FD, Var);
      Args.push_back(Var);
      continue;
    }

    const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset);
    llvm::Value *SemanticValue = nullptr;
    // FIXME: support inout/out parameters for semantics.
    if ([[maybe_unused]] HLSLParamModifierAttr *MA =
            PD->getAttr<HLSLParamModifierAttr>()) {
      llvm_unreachable("Not handled yet");
    } else {
      llvm::Type *ParamType =
          Param.hasByValAttr() ? Param.getParamByValType() : Param.getType();
      auto AttrBegin = PD->specific_attr_begin<HLSLAppliedSemanticAttr>();
      auto AttrEnd = PD->specific_attr_end<HLSLAppliedSemanticAttr>();
      auto Result =
          handleSemanticLoad(B, FD, ParamType, PD, AttrBegin, AttrEnd);
      SemanticValue = Result.first;
      if (!SemanticValue)
        return;
      if (Param.hasByValAttr()) {
        llvm::Value *Var = B.CreateAlloca(Param.getParamByValType());
        B.CreateStore(SemanticValue, Var);
        SemanticValue = Var;
      }
    }

    assert(SemanticValue);
    Args.push_back(SemanticValue);
  }

  CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args, OB);
  CI->setCallingConv(Fn->getCallingConv());

  if (Fn->getReturnType() != CGM.VoidTy)
    OutputSemantic.try_emplace(FD, CI);

  for (auto &[Decl, Source] : OutputSemantic) {
    AllocaInst *AI = dyn_cast<AllocaInst>(Source);
    llvm::Value *SourceValue =
        AI ? B.CreateLoad(AI->getAllocatedType(), Source) : Source;

    auto AttrBegin = Decl->specific_attr_begin<HLSLAppliedSemanticAttr>();
    auto AttrEnd = Decl->specific_attr_end<HLSLAppliedSemanticAttr>();
    handleSemanticStore(B, FD, SourceValue, Decl, AttrBegin, AttrEnd);
  }

  B.CreateRetVoid();

  // Add and identify the root signature to the function, if applicable.
  for (const Attr *Attr : FD->getAttrs()) {
    if (const auto *RSAttr = dyn_cast<RootSignatureAttr>(Attr)) {
      auto *RSDecl = RSAttr->getSignatureDecl();
      addRootSignatureMD(RSDecl->getVersion(), RSDecl->getRootElements(),
                         EntryFn, M);
    }
  }
}

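// Collects the functions registered in llvm.global_ctors or llvm.global_dtors.
// Each array entry is a { i32 priority, ptr fn, ptr comdat-key } struct; HLSL
// only ever emits priority 65535 and a null key, which the asserts below
// verify.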
static void gatherFunctions(SmallVectorImpl<Function *> &Fns, llvm::Module &M,
                            bool CtorOrDtor) {
  const auto *GV =
      M.getNamedGlobal(CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors");
  if (!GV)
    return;
  const auto *CA = dyn_cast<ConstantArray>(GV->getInitializer());
  if (!CA)
    return;
  // The global_ctor array elements are a struct [Priority, Fn *, COMDat].
  // HLSL supports neither priorities nor COMDat values, so we will check
  // those in an assert but not handle them.

  for (const auto &Ctor : CA->operands()) {
    if (isa<ConstantAggregateZero>(Ctor))
      continue;
    ConstantStruct *CS = cast<ConstantStruct>(Ctor);

    assert(cast<ConstantInt>(CS->getOperand(0))->getValue() == 65535 &&
           "HLSL doesn't support setting priority for global ctors.");
    assert(isa<ConstantPointerNull>(CS->getOperand(2)) &&
           "HLSL doesn't support COMDat for global ctors.");
    Fns.push_back(cast<Function>(CS->getOperand(1)));
  }
}

void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
  llvm::Module &M = CGM.getModule();
  SmallVector<Function *> CtorFns;
  SmallVector<Function *> DtorFns;
  gatherFunctions(CtorFns, M, true);
  gatherFunctions(DtorFns, M, false);

  // Insert calls to the global constructors at the beginning of the entry
  // block of externally exported functions. This is a bit of a hack: HLSL
  // allows global constructors but doesn't support driver initialization of
  // globals.
  for (auto &F : M.functions()) {
    if (!F.hasFnAttribute("hlsl.shader"))
      continue;
    auto *Token = getConvergenceToken(F.getEntryBlock());
    Instruction *IP = &*F.getEntryBlock().begin();
    SmallVector<OperandBundleDef, 1> OB;
    if (Token) {
      llvm::Value *bundleArgs[] = {Token};
      OB.emplace_back("convergencectrl", bundleArgs);
      IP = Token->getNextNode();
    }
    IRBuilder<> B(IP);
    for (auto *Fn : CtorFns) {
      auto CI = B.CreateCall(FunctionCallee(Fn), {}, OB);
      CI->setCallingConv(Fn->getCallingConv());
    }

    // Insert calls to the global destructors before the terminator of the
    // last basic block.
    B.SetInsertPoint(F.back().getTerminator());
    for (auto *Fn : DtorFns) {
      auto CI = B.CreateCall(FunctionCallee(Fn), {}, OB);
      CI->setCallingConv(Fn->getCallingConv());
    }
  }

  // No need to keep global ctors/dtors for the non-lib profile after the
  // calls to ctors/dtors have been added to each entry.
  Triple T(M.getTargetTriple());
  if (T.getEnvironment() != Triple::EnvironmentType::Library) {
    if (auto *GV = M.getNamedGlobal("llvm.global_ctors"))
      GV->eraseFromParent();
    if (auto *GV = M.getNamedGlobal("llvm.global_dtors"))
      GV->eraseFromParent();
  }
}

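// Emits a small internal function named "_init_buffer_<global name>" that
// creates the buffer's resource handle via the given intrinsic and stores it
// into the handle global. The function is registered as a C++ global
// initializer so it runs before the shader entry point.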
static void initializeBuffer(CodeGenModule &CGM, llvm::GlobalVariable *GV,
                             Intrinsic::ID IntrID,
                             ArrayRef<llvm::Value *> Args) {

  LLVMContext &Ctx = CGM.getLLVMContext();
  llvm::Function *InitResFunc = llvm::Function::Create(
      llvm::FunctionType::get(CGM.VoidTy, false),
      llvm::GlobalValue::InternalLinkage,
      ("_init_buffer_" + GV->getName()).str(), CGM.getModule());
  InitResFunc->addFnAttr(llvm::Attribute::AlwaysInline);

  llvm::BasicBlock *EntryBB =
      llvm::BasicBlock::Create(Ctx, "entry", InitResFunc);
  CGBuilderTy Builder(CGM, Ctx);
  const DataLayout &DL = CGM.getModule().getDataLayout();
  Builder.SetInsertPoint(EntryBB);

  // Make sure the global variable is a buffer resource handle.
  llvm::Type *HandleTy = GV->getValueType();
  assert(HandleTy->isTargetExtTy() && "unexpected type of the buffer global");

  llvm::Value *CreateHandle = Builder.CreateIntrinsic(
      /*ReturnType=*/HandleTy, IntrID, Args, nullptr,
      Twine(GV->getName()).concat("_h"));

  Builder.CreateAlignedStore(CreateHandle, GV, GV->getPointerAlignment(DL));
  Builder.CreateRetVoid();

  CGM.AddCXXGlobalInit(InitResFunc);
}

void CGHLSLRuntime::initializeBufferFromBinding(const HLSLBufferDecl *BufDecl,
                                                llvm::GlobalVariable *GV) {
  ResourceBindingAttrs Binding(BufDecl);
  assert(Binding.hasBinding() &&
         "cbuffer/tbuffer should always have resource binding attribute");

  auto *Index = llvm::ConstantInt::get(CGM.IntTy, 0);
  auto *RangeSize = llvm::ConstantInt::get(CGM.IntTy, 1);
  auto *Space = llvm::ConstantInt::get(CGM.IntTy, Binding.getSpace());
  Value *Name = buildNameForResource(BufDecl->getName(), CGM);

  // buffer with explicit binding
  if (Binding.isExplicit()) {
    llvm::Intrinsic::ID IntrinsicID =
        CGM.getHLSLRuntime().getCreateHandleFromBindingIntrinsic();
    auto *RegSlot = llvm::ConstantInt::get(CGM.IntTy, Binding.getSlot());
    SmallVector<Value *> Args{Space, RegSlot, RangeSize, Index, Name};
    initializeBuffer(CGM, GV, IntrinsicID, Args);
  } else {
    // buffer with implicit binding
    llvm::Intrinsic::ID IntrinsicID =
        CGM.getHLSLRuntime().getCreateHandleFromImplicitBindingIntrinsic();
    auto *OrderID =
        llvm::ConstantInt::get(CGM.IntTy, Binding.getImplicitOrderID());
    SmallVector<Value *> Args{OrderID, Space, RangeSize, Index, Name};
    initializeBuffer(CGM, GV, IntrinsicID, Args);
  }
}

void CGHLSLRuntime::handleGlobalVarDefinition(const VarDecl *VD,
                                              llvm::GlobalVariable *GV) {
  if (auto Attr = VD->getAttr<HLSLVkExtBuiltinInputAttr>())
    addSPIRVBuiltinDecoration(GV, Attr->getBuiltIn());
}

llvm::Instruction *CGHLSLRuntime::getConvergenceToken(BasicBlock &BB) {
  if (!CGM.shouldEmitConvergenceTokens())
    return nullptr;

  auto E = BB.end();
  for (auto I = BB.begin(); I != E; ++I) {
    auto *II = dyn_cast<llvm::IntrinsicInst>(&*I);
    if (II && llvm::isConvergenceControlIntrinsic(II->getIntrinsicID())) {
      return II;
    }
  }
  llvm_unreachable("Convergence token should have been emitted.");
  return nullptr;
}

class OpaqueValueVisitor : public RecursiveASTVisitor<OpaqueValueVisitor> {
public:
  llvm::SmallVector<OpaqueValueExpr *, 8> OVEs;
  llvm::SmallPtrSet<OpaqueValueExpr *, 8> Visited;
  OpaqueValueVisitor() {}

  bool VisitHLSLOutArgExpr(HLSLOutArgExpr *) {
    // These need to be bound in CodeGenFunction::EmitHLSLOutArgLValues
    // or CodeGenFunction::EmitHLSLOutArgExpr. If they are part of this
    // traversal, the temporary containing the copy out will not have
    // been created yet.
    return false;
  }

  bool VisitOpaqueValueExpr(OpaqueValueExpr *E) {
    // Traverse the source expression first.
    if (E->getSourceExpr())
      TraverseStmt(E->getSourceExpr());

    // Then add this OVE if we haven't seen it before.
    if (Visited.insert(E).second)
      OVEs.push_back(E);

    return true;
  }
};

void CGHLSLRuntime::emitInitListOpaqueValues(CodeGenFunction &CGF,
                                             InitListExpr *E) {

  typedef CodeGenFunction::OpaqueValueMappingData OpaqueValueMappingData;
  OpaqueValueVisitor Visitor;
  Visitor.TraverseStmt(E);
  for (auto *OVE : Visitor.OVEs) {
    if (CGF.isOpaqueValueEmitted(OVE))
      continue;
    if (OpaqueValueMappingData::shouldBindAsLValue(OVE)) {
      LValue LV = CGF.EmitLValue(OVE->getSourceExpr());
      OpaqueValueMappingData::bind(CGF, OVE, LV);
    } else {
      RValue RV = CGF.EmitAnyExpr(OVE->getSourceExpr());
      OpaqueValueMappingData::bind(CGF, OVE, RV);
    }
  }
}

std::optional<LValue> CGHLSLRuntime::emitResourceArraySubscriptExpr(
    const ArraySubscriptExpr *ArraySubsExpr, CodeGenFunction &CGF) {
  assert((ArraySubsExpr->getType()->isHLSLResourceRecord() ||
          ArraySubsExpr->getType()->isHLSLResourceRecordArray()) &&
         "expected resource array subscript expression");

  // Let clang codegen handle local and static resource array subscripts, or
  // cases where the subscript references an opaque expression (as part of an
  // ArrayInitLoopExpr AST node).
  const VarDecl *ArrayDecl =
      dyn_cast_or_null<VarDecl>(getArrayDecl(ArraySubsExpr));
  if (!ArrayDecl || !ArrayDecl->hasGlobalStorage() ||
      ArrayDecl->getStorageClass() == SC_Static)
    return std::nullopt;

  // Get the resource array type.
  ASTContext &AST = ArrayDecl->getASTContext();
  const Type *ResArrayTy = ArrayDecl->getType().getTypePtr();
  assert(ResArrayTy->isHLSLResourceRecordArray() &&
         "expected array of resource classes");

  // Iterate through all nested array subscript expressions to calculate
  // the index in the flattened resource array (if this is a multi-
  // dimensional array). The index is calculated as a sum of all indices
  // multiplied by the total size of the array at that level.
  Value *Index = nullptr;
  const ArraySubscriptExpr *ASE = ArraySubsExpr;
  while (ASE != nullptr) {
    Value *SubIndex = CGF.EmitScalarExpr(ASE->getIdx());
    if (const auto *ArrayTy =
            dyn_cast<ConstantArrayType>(ASE->getType().getTypePtr())) {
      Value *Multiplier = llvm::ConstantInt::get(
          CGM.IntTy, AST.getConstantArrayElementCount(ArrayTy));
      SubIndex = CGF.Builder.CreateMul(SubIndex, Multiplier);
    }
    Index = Index ? CGF.Builder.CreateAdd(Index, SubIndex) : SubIndex;
    ASE = dyn_cast<ArraySubscriptExpr>(ASE->getBase()->IgnoreParenImpCasts());
  }

  // Find binding info for the resource array. For implicit binding
  // an HLSLResourceBindingAttr should have been added by SemaHLSL.
  ResourceBindingAttrs Binding(ArrayDecl);
  assert(Binding.hasBinding() &&
         "resource array must have a binding attribute");

  // Find the individual resource type.
  QualType ResultTy = ArraySubsExpr->getType();
  QualType ResourceTy =
      ResultTy->isArrayType() ? AST.getBaseElementType(ResultTy) : ResultTy;

  // Create a temporary variable for the result, which is either going
  // to be a single resource instance or a local array of resources (we need
  // to return an LValue).
  RawAddress TmpVar = CGF.CreateMemTemp(ResultTy);
  if (CGF.EmitLifetimeStart(TmpVar.getPointer()))
    CGF.pushFullExprCleanup<CodeGenFunction::CallLifetimeEnd>(
        NormalEHLifetimeMarker, TmpVar);

  AggValueSlot ValueSlot = AggValueSlot::forAddr(
      TmpVar, Qualifiers(), AggValueSlot::IsDestructed_t(true),
      AggValueSlot::DoesNotNeedGCBarriers, AggValueSlot::IsAliased_t(false),
      AggValueSlot::DoesNotOverlap);

  // Calculate the total array size (= range size).
  llvm::Value *Range = llvm::ConstantInt::getSigned(
      CGM.IntTy, getTotalArraySize(AST, ResArrayTy));

  // If the result of the subscript operation is a single resource, call the
  // constructor.
  if (ResultTy == ResourceTy) {
    CallArgList Args;
    CXXMethodDecl *CreateMethod = lookupResourceInitMethodAndSetupArgs(
        CGF.CGM, ResourceTy->getAsCXXRecordDecl(), Range, Index,
        ArrayDecl->getName(), Binding, Args);

    if (!CreateMethod)
      // This can happen if someone creates an array of structs that looks
      // like an HLSL resource record array but does not have the required
      // static create method. No binding will be generated for it.
      return std::nullopt;

    callResourceInitMethod(CGF, CreateMethod, Args, ValueSlot.getAddress());

  } else {
    // The result of the subscript operation is a local resource array which
    // needs to be initialized.
    const ConstantArrayType *ArrayTy =
        cast<ConstantArrayType>(ResultTy.getTypePtr());
    std::optional<llvm::Value *> EndIndex = initializeLocalResourceArray(
        CGF, ResourceTy->getAsCXXRecordDecl(), ArrayTy, ValueSlot, Range,
        Index, ArrayDecl->getName(), Binding,
        {llvm::ConstantInt::get(CGM.IntTy, 0)}, ArraySubsExpr->getExprLoc());
    if (!EndIndex)
      return std::nullopt;
  }
  return CGF.MakeAddrLValue(TmpVar, ResultTy, AlignmentSource::Decl);
}

// If RHSExpr is a global resource array, initialize all of its resources and
// set them into LHS. Returns false if no copy has been performed and the
// array copy should be handled by Clang codegen.
bool CGHLSLRuntime::emitResourceArrayCopy(LValue &LHS, Expr *RHSExpr,
                                          CodeGenFunction &CGF) {
  QualType ResultTy = RHSExpr->getType();
  assert(ResultTy->isHLSLResourceRecordArray() && "expected resource array");

  // Let Clang codegen handle local and static resource array copies.
  const VarDecl *ArrayDecl = dyn_cast_or_null<VarDecl>(getArrayDecl(RHSExpr));
  if (!ArrayDecl || !ArrayDecl->hasGlobalStorage() ||
      ArrayDecl->getStorageClass() == SC_Static)
    return false;

  // Find binding info for the resource array. For implicit binding
  // the HLSLResourceBindingAttr should have been added by SemaHLSL.
  ResourceBindingAttrs Binding(ArrayDecl);
  assert(Binding.hasBinding() &&
         "resource array must have a binding attribute");

  // Find the individual resource type.
  ASTContext &AST = ArrayDecl->getASTContext();
  QualType ResTy = AST.getBaseElementType(ResultTy);
  const auto *ResArrayTy = cast<ConstantArrayType>(ResultTy.getTypePtr());

  // Use the provided LHS for the result.
  AggValueSlot ValueSlot = AggValueSlot::forAddr(
      LHS.getAddress(), Qualifiers(), AggValueSlot::IsDestructed_t(true),
      AggValueSlot::DoesNotNeedGCBarriers, AggValueSlot::IsAliased_t(false),
      AggValueSlot::DoesNotOverlap);

  // Create Values for the index and the total array size (= range size).
  int Size = getTotalArraySize(AST, ResArrayTy);
  llvm::Value *Zero = llvm::ConstantInt::get(CGM.IntTy, 0);
  llvm::Value *Range = llvm::ConstantInt::get(CGM.IntTy, Size);

  // Initialize the individual resources in the array into LHS.
  std::optional<llvm::Value *> EndIndex = initializeLocalResourceArray(
      CGF, ResTy->getAsCXXRecordDecl(), ResArrayTy, ValueSlot, Range, Zero,
      ArrayDecl->getName(), Binding, {Zero}, RHSExpr->getExprLoc());
  return EndIndex.has_value();
}

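// Emits an LValue for a subscript into an array that lives in a constant
// buffer, where each element is padded out to a 16-byte row. For example, an
// array of float in a cbuffer is indexed as if its element type were roughly
// a packed { float, <12 bytes of target padding> } struct; the GEP then
// dereferences the element itself so no trailing padding is read past the end
// of the in-memory object.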
std::optional<LValue> CGHLSLRuntime::emitBufferArraySubscriptExpr(
    const ArraySubscriptExpr *E, CodeGenFunction &CGF,
    llvm::function_ref<llvm::Value *(bool Promote)> EmitIdxAfterBase) {
  // Find the element type to index by first padding the element type per HLSL
  // buffer rules, and then padding out to a 16-byte register boundary if
  // necessary.
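  // For example, a 12-byte element (such as a float3) rounds up to a 16-byte
  // row, so the padded layout type becomes a packed struct of the element
  // followed by 4 bytes of target padding.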
  llvm::Type *LayoutTy =
      HLSLBufferLayoutBuilder(CGF.CGM).layOutType(E->getType());
  uint64_t LayoutSizeInBits =
      CGM.getDataLayout().getTypeSizeInBits(LayoutTy).getFixedValue();
  CharUnits ElementSize = CharUnits::fromQuantity(LayoutSizeInBits / 8);
  CharUnits RowAlignedSize = ElementSize.alignTo(CharUnits::fromQuantity(16));
  if (RowAlignedSize > ElementSize) {
    llvm::Type *Padding = CGM.getTargetCodeGenInfo().getHLSLPadding(
        CGM, RowAlignedSize - ElementSize);
    assert(Padding && "No padding type for target?");
    LayoutTy = llvm::StructType::get(CGF.getLLVMContext(), {LayoutTy, Padding},
                                     /*isPacked=*/true);
  }

  // If the layout type doesn't introduce any padding, we don't need to do
  // anything special.
  llvm::Type *OrigTy = CGF.CGM.getTypes().ConvertTypeForMem(E->getType());
  if (LayoutTy == OrigTy)
    return std::nullopt;

  LValueBaseInfo EltBaseInfo;
  TBAAAccessInfo EltTBAAInfo;
  Address Addr =
      CGF.EmitPointerWithAlignment(E->getBase(), &EltBaseInfo, &EltTBAAInfo);
  llvm::Value *Idx = EmitIdxAfterBase(/*Promote=*/true);

  // Index into the object as if we had an array of the padded element type,
  // and then dereference the element itself to avoid reading padding that may
  // be past the end of the in-memory object.
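  // The trailing zero index below yields a pointer to the start of the
  // element itself, so the access type never includes the trailing padding.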
  SmallVector<llvm::Value *, 2> Indices;
  Indices.push_back(Idx);
  Indices.push_back(llvm::ConstantInt::get(CGF.Int32Ty, 0));

  llvm::Value *GEP = CGF.Builder.CreateGEP(LayoutTy, Addr.emitRawPointer(CGF),
                                           Indices, "cbufferidx");
  Addr = Address(GEP, Addr.getElementType(), RowAlignedSize, KnownNonNull);

  return CGF.MakeAddrLValue(Addr, E->getType(), EltBaseInfo, EltTBAAInfo);
}

namespace {
/// Utility for emitting copies following the HLSL buffer layout rules (i.e.,
/// copying out of a cbuffer).
class HLSLBufferCopyEmitter {
  CodeGenFunction &CGF;
  Address DestPtr;
  Address SrcPtr;
  llvm::Type *LayoutTy = nullptr;

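  // GEP index paths into the destination (in-memory layout) and source
  // (buffer layout) objects for the element currently being copied. The two
  // paths can diverge because the buffer layout contains explicit padding
  // elements that the in-memory type does not.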
  SmallVector<llvm::Value *> CurStoreIndices;
  SmallVector<llvm::Value *> CurLoadIndices;

  void emitCopyAtIndices(llvm::Type *FieldTy, llvm::ConstantInt *StoreIndex,
                         llvm::ConstantInt *LoadIndex) {
    CurStoreIndices.push_back(StoreIndex);
    CurLoadIndices.push_back(LoadIndex);
    auto RestoreIndices = llvm::make_scope_exit([&]() {
      CurStoreIndices.pop_back();
      CurLoadIndices.pop_back();
    });

    // First, see if this is some kind of aggregate and recurse.
    if (processArray(FieldTy))
      return;
    if (processBufferLayoutArray(FieldTy))
      return;
    if (processStruct(FieldTy))
      return;

    // When we have a scalar or vector element we can emit the copy.
    CharUnits Align = CharUnits::fromQuantity(
        CGF.CGM.getDataLayout().getABITypeAlign(FieldTy));
    Address SrcGEP = RawAddress(
        CGF.Builder.CreateInBoundsGEP(LayoutTy, SrcPtr.getBasePointer(),
                                      CurLoadIndices, "cbuf.src"),
        FieldTy, Align, SrcPtr.isKnownNonNull());
    Address DestGEP = CGF.Builder.CreateInBoundsGEP(
        DestPtr, CurStoreIndices, FieldTy, Align, "cbuf.dest");
    llvm::Value *Load = CGF.Builder.CreateLoad(SrcGEP, "cbuf.load");
    CGF.Builder.CreateStore(Load, DestGEP);
  }

  bool processArray(llvm::Type *FieldTy) {
    auto *AT = dyn_cast<llvm::ArrayType>(FieldTy);
    if (!AT)
      return false;

    // If we have an llvm::ArrayType this is just a regular array with no
    // top-level padding, so all we need to do is copy each member.
    for (unsigned I = 0, E = AT->getNumElements(); I < E; ++I)
      emitCopyAtIndices(AT->getElementType(),
                        llvm::ConstantInt::get(CGF.SizeTy, I),
                        llvm::ConstantInt::get(CGF.SizeTy, I));
    return true;
  }

  bool processBufferLayoutArray(llvm::Type *FieldTy) {
    // A buffer layout array is a struct with two elements: the padded array,
    // and the last element. That is, it should look something like this:
    //
    //   { [%n x { %type, %padding }], %type }
    //
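    // For example (an illustrative sketch, with %padN standing in for the
    // target's N-byte padding type), a cbuffer array `float a[3]` would lay
    // out as { [2 x { float, %pad12 }], float }: the first two elements each
    // occupy a full 16-byte row, and the final element is left unpadded.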
    auto *ST = dyn_cast<llvm::StructType>(FieldTy);
    if (!ST || ST->getNumElements() != 2)
      return false;

    auto *PaddedEltsTy = dyn_cast<llvm::ArrayType>(ST->getElementType(0));
    if (!PaddedEltsTy)
      return false;

    auto *PaddedTy =
        dyn_cast<llvm::StructType>(PaddedEltsTy->getElementType());
    if (!PaddedTy || PaddedTy->getNumElements() != 2)
      return false;

    if (!CGF.CGM.getTargetCodeGenInfo().isHLSLPadding(
            PaddedTy->getElementType(1)))
      return false;

    llvm::Type *ElementTy = ST->getElementType(1);
    if (PaddedTy->getElementType(0) != ElementTy)
      return false;

    // All but the last of the logical array elements are in the padded array.
    unsigned NumElts = PaddedEltsTy->getNumElements() + 1;

    // Add an extra indirection to the load for the struct and walk the
    // array prefix.
    CurLoadIndices.push_back(llvm::ConstantInt::get(CGF.Int32Ty, 0));
    for (unsigned I = 0; I < NumElts - 1; ++I) {
      // We need to copy the element itself, without the padding.
      CurLoadIndices.push_back(llvm::ConstantInt::get(CGF.SizeTy, I));
      emitCopyAtIndices(ElementTy, llvm::ConstantInt::get(CGF.SizeTy, I),
                        llvm::ConstantInt::get(CGF.Int32Ty, 0));
      CurLoadIndices.pop_back();
    }
    CurLoadIndices.pop_back();

    // Now copy the last element.
    emitCopyAtIndices(ElementTy,
                      llvm::ConstantInt::get(CGF.SizeTy, NumElts - 1),
                      llvm::ConstantInt::get(CGF.Int32Ty, 1));

    return true;
  }

  bool processStruct(llvm::Type *FieldTy) {
    auto *ST = dyn_cast<llvm::StructType>(FieldTy);
    if (!ST)
      return false;

    // Copy the struct field by field, but skip any explicit padding.
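    // Note that `I` indexes the buffer layout struct, which includes the
    // padding elements, so the corresponding field of the unpadded in-memory
    // struct is at `I - Skipped`.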
    unsigned Skipped = 0;
    for (unsigned I = 0, E = ST->getNumElements(); I < E; ++I) {
      llvm::Type *ElementTy = ST->getElementType(I);
      if (CGF.CGM.getTargetCodeGenInfo().isHLSLPadding(ElementTy))
        ++Skipped;
      else
        emitCopyAtIndices(ElementTy,
                          llvm::ConstantInt::get(CGF.Int32Ty, I - Skipped),
                          llvm::ConstantInt::get(CGF.Int32Ty, I));
    }
    return true;
  }

public:
  HLSLBufferCopyEmitter(CodeGenFunction &CGF, Address DestPtr, Address SrcPtr)
      : CGF(CGF), DestPtr(DestPtr), SrcPtr(SrcPtr) {}

  bool emitCopy(QualType CType) {
    LayoutTy = HLSLBufferLayoutBuilder(CGF.CGM).layOutType(CType);

    // TODO: We should be able to fall back to a regular memcpy if the layout
    // type doesn't have any padding, but that currently runs into issues in
    // the backend.
    //
    // See https://github.com/llvm/wg-hlsl/issues/351
    emitCopyAtIndices(LayoutTy, llvm::ConstantInt::get(CGF.SizeTy, 0),
                      llvm::ConstantInt::get(CGF.SizeTy, 0));
    return true;
  }
};
} // namespace

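// Emit a copy of a value of type CType from SrcPtr, which uses the HLSL
// buffer layout, into DestPtr, which uses the normal in-memory layout.
// Returns true if the copy was handled here.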
bool CGHLSLRuntime::emitBufferCopy(CodeGenFunction &CGF, Address DestPtr,
                                   Address SrcPtr, QualType CType) {
  return HLSLBufferCopyEmitter(CGF, DestPtr, SrcPtr).emitCopy(CType);
}

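// Emit an LValue for a member access into a struct that lives in a cbuffer,
// indexing via the buffer layout type rather than the record's normal
// in-memory layout.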
LValue CGHLSLRuntime::emitBufferMemberExpr(CodeGenFunction &CGF,
                                           const MemberExpr *E) {
  LValue Base =
      CGF.EmitCheckedLValue(E->getBase(), CodeGenFunction::TCK_MemberAccess);
  auto *Field = dyn_cast<FieldDecl>(E->getMemberDecl());
  assert(Field && "Unexpected access into HLSL buffer");

  // Get the field index for the struct.
  const RecordDecl *Rec = Field->getParent();
  unsigned FieldIdx =
      CGM.getTypes().getCGRecordLayout(Rec).getLLVMFieldNo(Field);

  // Work out the buffer layout type to index into.
  QualType RecType = CGM.getContext().getCanonicalTagType(Rec);
  assert(RecType->isStructureOrClassType() && "Invalid type in HLSL buffer");
  // Since this is a member of an object in the buffer rather than the
  // buffer's struct/class itself, there are no member offset annotations to
  // contend with.
  CGHLSLOffsetInfo EmptyOffsets;
  llvm::StructType *LayoutTy = HLSLBufferLayoutBuilder(CGM).layOutStruct(
      RecType->getAsCanonical<RecordType>(), EmptyOffsets);

  // Now index into the struct, making sure that the type we return is the
  // buffer layout type rather than the original type in the AST.
  QualType FieldType = Field->getType();
  llvm::Type *FieldLLVMTy = CGM.getTypes().ConvertTypeForMem(FieldType);
  CharUnits Align = CharUnits::fromQuantity(
      CGF.CGM.getDataLayout().getABITypeAlign(FieldLLVMTy));
  Address Addr(CGF.Builder.CreateStructGEP(LayoutTy, Base.getPointer(CGF),
                                           FieldIdx, Field->getName()),
               FieldLLVMTy, Align, KnownNonNull);

  LValue LV = LValue::MakeAddr(Addr, FieldType, CGM.getContext(),
                               LValueBaseInfo(AlignmentSource::Type),
                               CGM.getTBAAAccessInfo(FieldType));
  LV.getQuals().addCVRQualifiers(Base.getVRQualifiers());

  return LV;
}