1//===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This provides an abstract class for HLSL code generation. Concrete
10// subclasses of this implement code generation for specific HLSL
11// runtime libraries.
12//
13//===----------------------------------------------------------------------===//
14
15#include "CGHLSLRuntime.h"
16#include "CGDebugInfo.h"
17#include "CGRecordLayout.h"
18#include "CodeGenFunction.h"
19#include "CodeGenModule.h"
20#include "HLSLBufferLayoutBuilder.h"
21#include "TargetInfo.h"
22#include "clang/AST/ASTContext.h"
23#include "clang/AST/Attr.h"
24#include "clang/AST/Decl.h"
25#include "clang/AST/Expr.h"
26#include "clang/AST/HLSLResource.h"
27#include "clang/AST/RecursiveASTVisitor.h"
28#include "clang/AST/Type.h"
29#include "clang/Basic/DiagnosticFrontend.h"
30#include "clang/Basic/TargetOptions.h"
31#include "llvm/ADT/DenseMap.h"
32#include "llvm/ADT/ScopeExit.h"
33#include "llvm/ADT/SmallString.h"
34#include "llvm/ADT/SmallVector.h"
35#include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
36#include "llvm/IR/Constants.h"
37#include "llvm/IR/DerivedTypes.h"
38#include "llvm/IR/GlobalVariable.h"
39#include "llvm/IR/IntrinsicInst.h"
40#include "llvm/IR/LLVMContext.h"
41#include "llvm/IR/Metadata.h"
42#include "llvm/IR/Module.h"
43#include "llvm/IR/Type.h"
44#include "llvm/IR/Value.h"
45#include "llvm/Support/Alignment.h"
46#include "llvm/Support/ErrorHandling.h"
47#include "llvm/Support/FormatVariadic.h"
48#include <cstdint>
49#include <optional>
50
51using namespace clang;
52using namespace CodeGen;
53using namespace clang::hlsl;
54using namespace llvm;
55
56using llvm::hlsl::CBufferRowSizeInBytes;
57
58namespace {
59
60void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {
61 // The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs.
62 // Assume ValVersionStr is legal here.
63 VersionTuple Version;
64 if (Version.tryParse(string: ValVersionStr) || Version.getBuild() ||
65 Version.getSubminor() || !Version.getMinor()) {
66 return;
67 }
68
69 uint64_t Major = Version.getMajor();
70 uint64_t Minor = *Version.getMinor();
71
72 auto &Ctx = M.getContext();
73 IRBuilder<> B(M.getContext());
74 MDNode *Val = MDNode::get(Context&: Ctx, MDs: {ConstantAsMetadata::get(C: B.getInt32(C: Major)),
75 ConstantAsMetadata::get(C: B.getInt32(C: Minor))});
76 StringRef DXILValKey = "dx.valver";
77 auto *DXILValMD = M.getOrInsertNamedMetadata(Name: DXILValKey);
78 DXILValMD->addOperand(M: Val);
79}
80
81void addRootSignatureMD(llvm::dxbc::RootSignatureVersion RootSigVer,
82 ArrayRef<llvm::hlsl::rootsig::RootElement> Elements,
83 llvm::Function *Fn, llvm::Module &M) {
84 auto &Ctx = M.getContext();
85
86 llvm::hlsl::rootsig::MetadataBuilder RSBuilder(Ctx, Elements);
87 MDNode *RootSignature = RSBuilder.BuildRootSignature();
88
89 ConstantAsMetadata *Version = ConstantAsMetadata::get(C: ConstantInt::get(
90 Ty: llvm::Type::getInt32Ty(C&: Ctx), V: llvm::to_underlying(E: RootSigVer)));
91 ValueAsMetadata *EntryFunc = Fn ? ValueAsMetadata::get(V: Fn) : nullptr;
92 MDNode *MDVals = MDNode::get(Context&: Ctx, MDs: {EntryFunc, RootSignature, Version});
93
94 StringRef RootSignatureValKey = "dx.rootsignatures";
95 auto *RootSignatureValMD = M.getOrInsertNamedMetadata(Name: RootSignatureValKey);
96 RootSignatureValMD->addOperand(M: MDVals);
97}
98
99// Find array variable declaration from DeclRef expression
100static const ValueDecl *getArrayDecl(const Expr *E) {
101 if (const DeclRefExpr *DRE =
102 dyn_cast_or_null<DeclRefExpr>(Val: E->IgnoreImpCasts()))
103 return DRE->getDecl();
104 return nullptr;
105}
106
107// Find array variable declaration from nested array subscript AST nodes
108static const ValueDecl *getArrayDecl(const ArraySubscriptExpr *ASE) {
109 const Expr *E = nullptr;
110 while (ASE != nullptr) {
111 E = ASE->getBase()->IgnoreImpCasts();
112 if (!E)
113 return nullptr;
114 ASE = dyn_cast<ArraySubscriptExpr>(Val: E);
115 }
116 return getArrayDecl(E);
117}
118
119// Get the total size of the array, or 0 if the array is unbounded.
120static int getTotalArraySize(ASTContext &AST, const clang::Type *Ty) {
121 Ty = Ty->getUnqualifiedDesugaredType();
122 assert(Ty->isArrayType() && "expected array type");
123 if (Ty->isIncompleteArrayType())
124 return 0;
125 return AST.getConstantArrayElementCount(CA: cast<ConstantArrayType>(Val: Ty));
126}
127
128static Value *buildNameForResource(llvm::StringRef BaseName,
129 CodeGenModule &CGM) {
130 llvm::SmallString<64> GlobalName = {BaseName, ".str"};
131 return CGM.GetAddrOfConstantCString(Str: BaseName.str(), GlobalName: GlobalName.c_str())
132 .getPointer();
133}
134
135static CXXMethodDecl *lookupMethod(CXXRecordDecl *Record, StringRef Name,
136 StorageClass SC = SC_None) {
137 for (auto *Method : Record->methods()) {
138 if (Method->getStorageClass() == SC && Method->getName() == Name)
139 return Method;
140 }
141 return nullptr;
142}
143
// Looks up the static "create" method on ResourceDecl that matches the
// binding kind (explicit register binding vs. implicit binding, each with or
// without an implicit counter) and appends the call arguments the create
// methods expect, in order:
//   <reg-slot | implicit-order-id>, space, range, index, name
//   [, counter-implicit-order-id]
// Returns nullptr if the record has no matching static create method.
static CXXMethodDecl *lookupResourceInitMethodAndSetupArgs(
    CodeGenModule &CGM, CXXRecordDecl *ResourceDecl, llvm::Value *Range,
    llvm::Value *Index, StringRef Name, ResourceBindingAttrs &Binding,
    CallArgList &Args) {
  assert(Binding.hasBinding() && "at least one binding attribute expected");

  ASTContext &AST = CGM.getContext();
  CXXMethodDecl *CreateMethod = nullptr;
  // Constant C-string global holding the resource name.
  Value *NameStr = buildNameForResource(BaseName: Name, CGM);
  Value *Space = llvm::ConstantInt::get(Ty: CGM.IntTy, V: Binding.getSpace());

  if (Binding.isExplicit()) {
    // explicit binding
    auto *RegSlot = llvm::ConstantInt::get(Ty: CGM.IntTy, V: Binding.getSlot());
    Args.add(rvalue: RValue::get(V: RegSlot), type: AST.UnsignedIntTy);
    const char *Name = Binding.hasCounterImplicitOrderID()
                           ? "__createFromBindingWithImplicitCounter"
                           : "__createFromBinding";
    CreateMethod = lookupMethod(Record: ResourceDecl, Name, SC: SC_Static);
  } else {
    // implicit binding
    auto *OrderID =
        llvm::ConstantInt::get(Ty: CGM.IntTy, V: Binding.getImplicitOrderID());
    Args.add(rvalue: RValue::get(V: OrderID), type: AST.UnsignedIntTy);
    const char *Name = Binding.hasCounterImplicitOrderID()
                           ? "__createFromImplicitBindingWithImplicitCounter"
                           : "__createFromImplicitBinding";
    CreateMethod = lookupMethod(Record: ResourceDecl, Name, SC: SC_Static);
  }
  // Arguments common to all four create-method variants.
  Args.add(rvalue: RValue::get(V: Space), type: AST.UnsignedIntTy);
  Args.add(rvalue: RValue::get(V: Range), type: AST.IntTy);
  Args.add(rvalue: RValue::get(V: Index), type: AST.UnsignedIntTy);
  Args.add(rvalue: RValue::get(V: NameStr), type: AST.getPointerType(T: AST.CharTy.withConst()));
  // The ...WithImplicitCounter variants take the counter's implicit order id
  // as a trailing argument.
  if (Binding.hasCounterImplicitOrderID()) {
    uint32_t CounterBinding = Binding.getCounterImplicitOrderID();
    auto *CounterOrderID = llvm::ConstantInt::get(Ty: CGM.IntTy, V: CounterBinding);
    Args.add(rvalue: RValue::get(V: CounterOrderID), type: AST.UnsignedIntTy);
  }

  return CreateMethod;
}
185
// Emits a call to the given resource "create" method with the prepared Args,
// writing the constructed resource handle to ReturnAddress.
static void callResourceInitMethod(CodeGenFunction &CGF,
                                   CXXMethodDecl *CreateMethod,
                                   CallArgList &Args, Address ReturnAddress) {
  llvm::Constant *CalleeFn = CGF.CGM.GetAddrOfFunction(GD: CreateMethod);
  const FunctionProtoType *Proto =
      CreateMethod->getType()->getAs<FunctionProtoType>();
  // The create methods are static, so the call is arranged as a free
  // function call rather than a member call.
  const CGFunctionInfo &FnInfo =
      CGF.CGM.getTypes().arrangeFreeFunctionCall(Args, Ty: Proto, ChainCall: false);
  ReturnValueSlot ReturnValue(ReturnAddress, false);
  CGCallee Callee(CGCalleeInfo(Proto), CalleeFn);
  CGF.EmitCall(CallInfo: FnInfo, Callee, ReturnValue, Args, CallOrInvoke: nullptr);
}
198
// Initializes local resource array variable. For multi-dimensional arrays it
// calls itself recursively to initialize its sub-arrays. The Index used in the
// resource constructor calls will begin at StartIndex and will be incremented
// for each array element. The last used resource Index is returned to the
// caller. If the function returns std::nullopt, it indicates an error.
//
// PrevGEPIndices carries the constant GEP indices accumulated for the outer
// dimensions; each recursion level appends one more index for its dimension.
static std::optional<llvm::Value *> initializeLocalResourceArray(
    CodeGenFunction &CGF, CXXRecordDecl *ResourceDecl,
    const ConstantArrayType *ArrayTy, AggValueSlot &ValueSlot,
    llvm::Value *Range, llvm::Value *StartIndex, StringRef ResourceName,
    ResourceBindingAttrs &Binding, ArrayRef<llvm::Value *> PrevGEPIndices,
    SourceLocation ArraySubsExprLoc) {

  ASTContext &AST = CGF.getContext();
  llvm::IntegerType *IntTy = CGF.CGM.IntTy;
  llvm::Value *Index = StartIndex;
  llvm::Value *One = llvm::ConstantInt::get(Ty: IntTy, V: 1);
  const uint64_t ArraySize = ArrayTy->getSExtSize();
  QualType ElemType = ArrayTy->getElementType();
  // Destination: the temporary aggregate holding the whole array.
  Address TmpArrayAddr = ValueSlot.getAddress();

  // Add additional index to the getelementptr call indices.
  // This index will be updated for each array element in the loops below.
  SmallVector<llvm::Value *> GEPIndices(PrevGEPIndices);
  GEPIndices.push_back(Elt: llvm::ConstantInt::get(Ty: IntTy, V: 0));

  // For array of arrays, recursively initialize the sub-arrays.
  if (ElemType->isArrayType()) {
    const ConstantArrayType *SubArrayTy = cast<ConstantArrayType>(Val&: ElemType);
    for (uint64_t I = 0; I < ArraySize; I++) {
      if (I > 0) {
        // Advance the flat resource index and retarget the GEP at element I.
        Index = CGF.Builder.CreateAdd(LHS: Index, RHS: One);
        GEPIndices.back() = llvm::ConstantInt::get(Ty: IntTy, V: I);
      }
      std::optional<llvm::Value *> MaybeIndex = initializeLocalResourceArray(
          CGF, ResourceDecl, ArrayTy: SubArrayTy, ValueSlot, Range, StartIndex: Index, ResourceName,
          Binding, PrevGEPIndices: GEPIndices, ArraySubsExprLoc);
      if (!MaybeIndex)
        return std::nullopt;
      Index = *MaybeIndex;
    }
    return Index;
  }

  // For array of resources, initialize each resource in the array.
  llvm::Type *Ty = CGF.ConvertTypeForMem(T: ElemType);
  CharUnits ElemSize = AST.getTypeSizeInChars(T: ElemType);
  CharUnits Align =
      TmpArrayAddr.getAlignment().alignmentOfArrayElement(elementSize: ElemSize);

  for (uint64_t I = 0; I < ArraySize; I++) {
    if (I > 0) {
      Index = CGF.Builder.CreateAdd(LHS: Index, RHS: One);
      GEPIndices.back() = llvm::ConstantInt::get(Ty: IntTy, V: I);
    }
    // Address of element I; the create method writes its result here.
    Address ReturnAddress =
        CGF.Builder.CreateGEP(Addr: TmpArrayAddr, IdxList: GEPIndices, ElementType: Ty, Align);

    CallArgList Args;
    CXXMethodDecl *CreateMethod = lookupResourceInitMethodAndSetupArgs(
        CGM&: CGF.CGM, ResourceDecl, Range, Index, Name: ResourceName, Binding, Args);

    if (!CreateMethod)
      // This can happen if someone creates an array of structs that looks like
      // an HLSL resource record array but it does not have the required static
      // create method. No binding will be generated for it.
      return std::nullopt;

    callResourceInitMethod(CGF, CreateMethod, Args, ReturnAddress);
  }
  return Index;
}
270
271} // namespace
272
273llvm::Type *
274CGHLSLRuntime::convertHLSLSpecificType(const Type *T,
275 const CGHLSLOffsetInfo &OffsetInfo) {
276 assert(T->isHLSLSpecificType() && "Not an HLSL specific type!");
277
278 // Check if the target has a specific translation for this type first.
279 if (llvm::Type *TargetTy =
280 CGM.getTargetCodeGenInfo().getHLSLType(CGM, T, OffsetInfo))
281 return TargetTy;
282
283 llvm_unreachable("Generic handling of HLSL types is not supported.");
284}
285
286llvm::Triple::ArchType CGHLSLRuntime::getArch() {
287 return CGM.getTarget().getTriple().getArch();
288}
289
// Emits constant global variables for buffer constants declarations
// and creates metadata linking the constant globals with the buffer global.
//
// The "hlsl.cbs" metadata node lists the buffer global first, followed by
// one global per constant, in layout-struct order (sorted by offset when
// explicit offsets are present).
void CGHLSLRuntime::emitBufferGlobalsAndMetadata(
    const HLSLBufferDecl *BufDecl, llvm::GlobalVariable *BufGV,
    const CGHLSLOffsetInfo &OffsetInfo) {
  LLVMContext &Ctx = CGM.getLLVMContext();

  // get the layout struct from constant buffer target type
  llvm::Type *BufType = BufGV->getValueType();
  llvm::StructType *LayoutStruct = cast<llvm::StructType>(
      Val: cast<llvm::TargetExtType>(Val: BufType)->getTypeParameter(i: 0));

  // Collect (decl, offset) pairs; OffsetIdx walks OffsetInfo in parallel with
  // the hlsl_constant declarations.
  SmallVector<std::pair<VarDecl *, uint32_t>> DeclsWithOffset;
  size_t OffsetIdx = 0;
  for (Decl *D : BufDecl->buffer_decls()) {
    if (isa<CXXRecordDecl, EmptyDecl>(Val: D))
      // Nothing to do for this declaration.
      continue;
    if (isa<FunctionDecl>(Val: D)) {
      // A function within an cbuffer is effectively a top-level function.
      CGM.EmitTopLevelDecl(D);
      continue;
    }
    VarDecl *VD = dyn_cast<VarDecl>(Val: D);
    if (!VD)
      continue;

    QualType VDTy = VD->getType();
    if (VDTy.getAddressSpace() != LangAS::hlsl_constant) {
      if (VD->getStorageClass() == SC_Static ||
          VDTy.getAddressSpace() == LangAS::hlsl_groupshared ||
          VDTy->isHLSLResourceRecord() || VDTy->isHLSLResourceRecordArray()) {
        // Emit static and groupshared variables and resource classes inside
        // cbuffer as regular globals
        CGM.EmitGlobal(D: VD);
      }
      continue;
    }

    DeclsWithOffset.emplace_back(Args&: VD, Args: OffsetInfo[OffsetIdx++]);
  }

  // With explicit offsets the declaration order may differ from the layout
  // order; sort so the decls line up with the layout struct elements below.
  if (!OffsetInfo.empty())
    llvm::stable_sort(Range&: DeclsWithOffset, C: [](const auto &LHS, const auto &RHS) {
      return CGHLSLOffsetInfo::compareOffsets(LHS: LHS.second, RHS: RHS.second);
    });

  // Associate the buffer global variable with its constants
  SmallVector<llvm::Metadata *> BufGlobals;
  BufGlobals.reserve(N: DeclsWithOffset.size() + 1);
  BufGlobals.push_back(Elt: ValueAsMetadata::get(V: BufGV));

  auto ElemIt = LayoutStruct->element_begin();
  for (auto &[VD, _] : DeclsWithOffset) {
    // Target-inserted padding elements have no corresponding declaration.
    if (CGM.getTargetCodeGenInfo().isHLSLPadding(Ty: *ElemIt))
      ++ElemIt;

    assert(ElemIt != LayoutStruct->element_end() &&
           "number of elements in layout struct does not match");
    llvm::Type *LayoutType = *ElemIt++;

    GlobalVariable *ElemGV =
        cast<GlobalVariable>(Val: CGM.GetAddrOfGlobalVar(D: VD, Ty: LayoutType));
    BufGlobals.push_back(Elt: ValueAsMetadata::get(V: ElemGV));
  }
  assert(ElemIt == LayoutStruct->element_end() &&
         "number of elements in layout struct does not match");

  // add buffer metadata to the module
  CGM.getModule()
      .getOrInsertNamedMetadata(Name: "hlsl.cbs")
      ->addOperand(M: MDNode::get(Context&: Ctx, MDs: BufGlobals));
}
363
364// Creates resource handle type for the HLSL buffer declaration
365static const clang::HLSLAttributedResourceType *
366createBufferHandleType(const HLSLBufferDecl *BufDecl) {
367 ASTContext &AST = BufDecl->getASTContext();
368 QualType QT = AST.getHLSLAttributedResourceType(
369 Wrapped: AST.HLSLResourceTy, Contained: AST.getCanonicalTagType(TD: BufDecl->getLayoutStruct()),
370 Attrs: HLSLAttributedResourceType::Attributes(ResourceClass::CBuffer));
371 return cast<HLSLAttributedResourceType>(Val: QT.getTypePtr());
372}
373
// Builds the per-declaration offset table for a cbuffer. Offsets are
// collected in declaration order, one entry per hlsl_constant variable;
// declarations without packoffset/register(c#) information get the
// `Unspecified` sentinel. Returns an empty table when the buffer carries no
// valid packoffset information at all.
CGHLSLOffsetInfo CGHLSLOffsetInfo::fromDecl(const HLSLBufferDecl &BufDecl) {
  CGHLSLOffsetInfo Result;

  // If we don't have packoffset info, just return an empty result.
  if (!BufDecl.hasValidPackoffset())
    return Result;

  for (Decl *D : BufDecl.buffer_decls()) {
    if (isa<CXXRecordDecl, EmptyDecl>(Val: D) || isa<FunctionDecl>(Val: D)) {
      continue;
    }
    // Only variables in the hlsl_constant address space occupy cbuffer slots.
    VarDecl *VD = dyn_cast<VarDecl>(Val: D);
    if (!VD || VD->getType().getAddressSpace() != LangAS::hlsl_constant)
      continue;

    if (!VD->hasAttrs()) {
      Result.Offsets.push_back(Elt: Unspecified);
      continue;
    }

    uint32_t Offset = Unspecified;
    for (auto *Attr : VD->getAttrs()) {
      // packoffset(...) gives the byte offset directly.
      if (auto *POA = dyn_cast<HLSLPackOffsetAttr>(Val: Attr)) {
        Offset = POA->getOffsetInBytes();
        break;
      }
      // register(c#) pins the variable to the start of cbuffer row #.
      auto *RBA = dyn_cast<HLSLResourceBindingAttr>(Val: Attr);
      if (RBA &&
          RBA->getRegisterType() == HLSLResourceBindingAttr::RegisterType::C) {
        Offset = RBA->getSlotNumber() * CBufferRowSizeInBytes;
        break;
      }
    }
    Result.Offsets.push_back(Elt: Offset);
  }
  return Result;
}
411
// Codegen for HLSLBufferDecl: creates the buffer's global handle variable,
// emits globals/metadata for its constants, and emits the handle
// initialization from its binding.
void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *BufDecl) {

  assert(BufDecl->isCBuffer() && "tbuffer codegen is not supported yet");

  // create resource handle type for the buffer
  const clang::HLSLAttributedResourceType *ResHandleTy =
      createBufferHandleType(BufDecl);

  // empty constant buffer is ignored
  if (ResHandleTy->getContainedType()->getAsCXXRecordDecl()->isEmpty())
    return;

  // create global variable for the constant buffer
  CGHLSLOffsetInfo OffsetInfo = CGHLSLOffsetInfo::fromDecl(BufDecl: *BufDecl);
  llvm::Type *LayoutTy = convertHLSLSpecificType(T: ResHandleTy, OffsetInfo);
  // The global is named "<name>.cb" and starts as poison; the real handle is
  // stored by initializeBufferFromBinding below.
  llvm::GlobalVariable *BufGV = new GlobalVariable(
      LayoutTy, /*isConstant*/ false,
      GlobalValue::LinkageTypes::ExternalLinkage, PoisonValue::get(T: LayoutTy),
      llvm::formatv(Fmt: "{0}{1}", Vals: BufDecl->getName(),
                    Vals: BufDecl->isCBuffer() ? ".cb" : ".tb"),
      GlobalValue::NotThreadLocal);
  CGM.getModule().insertGlobalVariable(GV: BufGV);

  // Add globals for constant buffer elements and create metadata nodes
  emitBufferGlobalsAndMetadata(BufDecl, BufGV, OffsetInfo);

  // Initialize cbuffer from binding (implicit or explicit)
  initializeBufferFromBinding(BufDecl, GV: BufGV);
}
442
443void CGHLSLRuntime::addRootSignature(
444 const HLSLRootSignatureDecl *SignatureDecl) {
445 llvm::Module &M = CGM.getModule();
446 Triple T(M.getTargetTriple());
447
448 // Generated later with the function decl if not targeting root signature
449 if (T.getEnvironment() != Triple::EnvironmentType::RootSignature)
450 return;
451
452 addRootSignatureMD(RootSigVer: SignatureDecl->getVersion(),
453 Elements: SignatureDecl->getRootElements(), Fn: nullptr, M);
454}
455
456llvm::StructType *
457CGHLSLRuntime::getHLSLBufferLayoutType(const RecordType *StructType) {
458 const auto Entry = LayoutTypes.find(Val: StructType);
459 if (Entry != LayoutTypes.end())
460 return Entry->getSecond();
461 return nullptr;
462}
463
// Registers the IR layout struct computed for a cbuffer record type so
// later getHLSLBufferLayoutType() lookups can reuse it. Each record may
// only be registered once.
void CGHLSLRuntime::addHLSLBufferLayoutType(const RecordType *StructType,
                                            llvm::StructType *LayoutTy) {
  assert(getHLSLBufferLayoutType(StructType) == nullptr &&
         "layout type for this struct already exist");
  LayoutTypes[StructType] = LayoutTy;
}
470
// Module-level finalization for HLSL codegen: records the DXIL validator
// version, translates codegen/language options into module flags, and emits
// the global constructor/destructor calls.
void CGHLSLRuntime::finishCodeGen() {
  auto &TargetOpts = CGM.getTarget().getTargetOpts();
  auto &CodeGenOpts = CGM.getCodeGenOpts();
  auto &LangOpts = CGM.getLangOpts();
  llvm::Module &M = CGM.getModule();
  Triple T(M.getTargetTriple());
  // The validator version metadata only applies to DXIL targets.
  if (T.getArch() == Triple::ArchType::dxil)
    addDxilValVersion(ValVersionStr: TargetOpts.DxilValidatorVersion, M);
  if (CodeGenOpts.ResMayAlias)
    M.setModuleFlag(Behavior: llvm::Module::ModFlagBehavior::Error, Key: "dx.resmayalias", Val: 1);
  if (CodeGenOpts.AllResourcesBound)
    M.setModuleFlag(Behavior: llvm::Module::ModFlagBehavior::Error,
                    Key: "dx.allresourcesbound", Val: 1);
  // NOTE(review): this one uses addModuleFlag/Override while the others use
  // setModuleFlag/Error -- confirm the asymmetry is intentional.
  if (CodeGenOpts.OptimizationLevel == 0)
    M.addModuleFlag(Behavior: llvm::Module::ModFlagBehavior::Override,
                    Key: "dx.disable_optimizations", Val: 1);

  // NativeHalfType corresponds to the -fnative-half-type clang option which is
  // aliased by clang-dxc's -enable-16bit-types option. This option is used to
  // set the UseNativeLowPrecision DXIL module flag in the DirectX backend
  if (LangOpts.NativeHalfType)
    M.setModuleFlag(Behavior: llvm::Module::ModFlagBehavior::Error, Key: "dx.nativelowprec",
                    Val: 1);

  generateGlobalCtorDtorCalls();
}
497
// Translates HLSL entry-point attributes (shader stage, numthreads,
// wavesize, ...) into string function attributes consumed by the backends,
// and applies definition-time attributes that the regular codegen path would
// miss for materialized entry functions.
void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
    const FunctionDecl *FD, llvm::Function *Fn) {
  const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
  assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
  const StringRef ShaderAttrKindStr = "hlsl.shader";
  Fn->addFnAttr(Kind: ShaderAttrKindStr,
                Val: llvm::Triple::getEnvironmentTypeName(Kind: ShaderAttr->getType()));
  if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {
    // Encoded as "x,y,z".
    const StringRef NumThreadsKindStr = "hlsl.numthreads";
    std::string NumThreadsStr =
        formatv(Fmt: "{0},{1},{2}", Vals: NumThreadsAttr->getX(), Vals: NumThreadsAttr->getY(),
                Vals: NumThreadsAttr->getZ());
    Fn->addFnAttr(Kind: NumThreadsKindStr, Val: NumThreadsStr);
  }
  if (HLSLWaveSizeAttr *WaveSizeAttr = FD->getAttr<HLSLWaveSizeAttr>()) {
    // Encoded as "min,max,preferred".
    const StringRef WaveSizeKindStr = "hlsl.wavesize";
    std::string WaveSizeStr =
        formatv(Fmt: "{0},{1},{2}", Vals: WaveSizeAttr->getMin(), Vals: WaveSizeAttr->getMax(),
                Vals: WaveSizeAttr->getPreferred());
    Fn->addFnAttr(Kind: WaveSizeKindStr, Val: WaveSizeStr);
  }
  // HLSL entry functions are materialized for module functions with
  // HLSLShaderAttr attribute. SetLLVMFunctionAttributesForDefinition called
  // later in the compiler-flow for such module functions is not aware of and
  // hence not able to set attributes of the newly materialized entry functions.
  // So, set attributes of entry function here, as appropriate.
  Fn->addFnAttr(Kind: llvm::Attribute::NoInline);

  if (CGM.getLangOpts().HLSLSpvEnableMaximalReconvergence) {
    Fn->addFnAttr(Kind: "enable-maximal-reconvergence", Val: "true");
  }
}
530
531static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) {
532 if (const auto *VT = dyn_cast<FixedVectorType>(Val: Ty)) {
533 Value *Result = PoisonValue::get(T: Ty);
534 for (unsigned I = 0; I < VT->getNumElements(); ++I) {
535 Value *Elt = B.CreateCall(Callee: F, Args: {B.getInt32(C: I)});
536 Result = B.CreateInsertElement(Vec: Result, NewElt: Elt, Idx: I);
537 }
538 return Result;
539 }
540 return B.CreateCall(Callee: F, Args: {B.getInt32(C: 0)});
541}
542
543static void addSPIRVBuiltinDecoration(llvm::GlobalVariable *GV,
544 unsigned BuiltIn) {
545 LLVMContext &Ctx = GV->getContext();
546 IRBuilder<> B(GV->getContext());
547 MDNode *Operands = MDNode::get(
548 Context&: Ctx,
549 MDs: {ConstantAsMetadata::get(C: B.getInt32(/* Spirv::Decoration::BuiltIn */ C: 11)),
550 ConstantAsMetadata::get(C: B.getInt32(C: BuiltIn))});
551 MDNode *Decoration = MDNode::get(Context&: Ctx, MDs: {Operands});
552 GV->addMetadata(Kind: "spirv.Decorations", MD&: *Decoration);
553}
554
555static void addLocationDecoration(llvm::GlobalVariable *GV, unsigned Location) {
556 LLVMContext &Ctx = GV->getContext();
557 IRBuilder<> B(GV->getContext());
558 MDNode *Operands =
559 MDNode::get(Context&: Ctx, MDs: {ConstantAsMetadata::get(C: B.getInt32(/* Location */ C: 30)),
560 ConstantAsMetadata::get(C: B.getInt32(C: Location))});
561 MDNode *Decoration = MDNode::get(Context&: Ctx, MDs: {Operands});
562 GV->addMetadata(Kind: "spirv.Decorations", MD&: *Decoration);
563}
564
// Creates an external, externally-initialized global in address space 7
// (presumably the SPIR-V Input storage class -- TODO confirm against the
// SPIR-V target's address-space mapping), decorates it with the given
// BuiltIn id, and returns a load of its value.
static llvm::Value *createSPIRVBuiltinLoad(IRBuilder<> &B, llvm::Module &M,
                                           llvm::Type *Ty, const Twine &Name,
                                           unsigned BuiltInID) {
  auto *GV = new llvm::GlobalVariable(
      M, Ty, /* isConstant= */ true, llvm::GlobalValue::ExternalLinkage,
      /* Initializer= */ nullptr, Name, /* insertBefore= */ nullptr,
      llvm::GlobalVariable::GeneralDynamicTLSModel,
      /* AddressSpace */ 7, /* isExternallyInitialized= */ true);
  addSPIRVBuiltinDecoration(GV, BuiltIn: BuiltInID);
  GV->setVisibility(llvm::GlobalValue::HiddenVisibility);
  return B.CreateLoad(Ty, Ptr: GV);
}
577
578static llvm::Value *createSPIRVLocationLoad(IRBuilder<> &B, llvm::Module &M,
579 llvm::Type *Ty, unsigned Location,
580 StringRef Name) {
581 auto *GV = new llvm::GlobalVariable(
582 M, Ty, /* isConstant= */ true, llvm::GlobalValue::ExternalLinkage,
583 /* Initializer= */ nullptr, /* Name= */ Name, /* insertBefore= */ nullptr,
584 llvm::GlobalVariable::GeneralDynamicTLSModel,
585 /* AddressSpace */ 7, /* isExternallyInitialized= */ true);
586 GV->setVisibility(llvm::GlobalValue::HiddenVisibility);
587 addLocationDecoration(GV, Location);
588 return B.CreateLoad(Ty, Ptr: GV);
589}
590
591llvm::Value *CGHLSLRuntime::emitSPIRVUserSemanticLoad(
592 llvm::IRBuilder<> &B, llvm::Type *Type, const clang::DeclaratorDecl *Decl,
593 HLSLAppliedSemanticAttr *Semantic, std::optional<unsigned> Index) {
594 Twine BaseName = Twine(Semantic->getAttrName()->getName());
595 Twine VariableName = BaseName.concat(Suffix: Twine(Index.value_or(u: 0)));
596
597 unsigned Location = SPIRVLastAssignedInputSemanticLocation;
598 if (auto *L = Decl->getAttr<HLSLVkLocationAttr>())
599 Location = L->getLocation();
600
601 // DXC completely ignores the semantic/index pair. Location are assigned from
602 // the first semantic to the last.
603 llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Val: Type);
604 unsigned ElementCount = AT ? AT->getNumElements() : 1;
605 SPIRVLastAssignedInputSemanticLocation += ElementCount;
606
607 return createSPIRVLocationLoad(B, M&: CGM.getModule(), Ty: Type, Location,
608 Name: VariableName.str());
609}
610
611static void createSPIRVLocationStore(IRBuilder<> &B, llvm::Module &M,
612 llvm::Value *Source, unsigned Location,
613 StringRef Name) {
614 auto *GV = new llvm::GlobalVariable(
615 M, Source->getType(), /* isConstant= */ false,
616 llvm::GlobalValue::ExternalLinkage,
617 /* Initializer= */ nullptr, /* Name= */ Name, /* insertBefore= */ nullptr,
618 llvm::GlobalVariable::GeneralDynamicTLSModel,
619 /* AddressSpace */ 8, /* isExternallyInitialized= */ false);
620 GV->setVisibility(llvm::GlobalValue::HiddenVisibility);
621 addLocationDecoration(GV, Location);
622 B.CreateStore(Val: Source, Ptr: GV);
623}
624
625void CGHLSLRuntime::emitSPIRVUserSemanticStore(
626 llvm::IRBuilder<> &B, llvm::Value *Source,
627 const clang::DeclaratorDecl *Decl, HLSLAppliedSemanticAttr *Semantic,
628 std::optional<unsigned> Index) {
629 Twine BaseName = Twine(Semantic->getAttrName()->getName());
630 Twine VariableName = BaseName.concat(Suffix: Twine(Index.value_or(u: 0)));
631
632 unsigned Location = SPIRVLastAssignedOutputSemanticLocation;
633 if (auto *L = Decl->getAttr<HLSLVkLocationAttr>())
634 Location = L->getLocation();
635
636 // DXC completely ignores the semantic/index pair. Location are assigned from
637 // the first semantic to the last.
638 llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Val: Source->getType());
639 unsigned ElementCount = AT ? AT->getNumElements() : 1;
640 SPIRVLastAssignedOutputSemanticLocation += ElementCount;
641 createSPIRVLocationStore(B, M&: CGM.getModule(), Source, Location,
642 Name: VariableName.str());
643}
644
645llvm::Value *
646CGHLSLRuntime::emitDXILUserSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
647 HLSLAppliedSemanticAttr *Semantic,
648 std::optional<unsigned> Index) {
649 Twine BaseName = Twine(Semantic->getAttrName()->getName());
650 Twine VariableName = BaseName.concat(Suffix: Twine(Index.value_or(u: 0)));
651
652 // DXIL packing rules etc shall be handled here.
653 // FIXME: generate proper sigpoint, index, col, row values.
654 // FIXME: also DXIL loads vectors element by element.
655 SmallVector<Value *> Args{B.getInt32(C: 4), B.getInt32(C: 0), B.getInt32(C: 0),
656 B.getInt8(C: 0),
657 llvm::PoisonValue::get(T: B.getInt32Ty())};
658
659 llvm::Intrinsic::ID IntrinsicID = llvm::Intrinsic::dx_load_input;
660 llvm::Value *Value = B.CreateIntrinsic(/*ReturnType=*/RetTy: Type, ID: IntrinsicID, Args,
661 FMFSource: nullptr, Name: VariableName);
662 return Value;
663}
664
665void CGHLSLRuntime::emitDXILUserSemanticStore(llvm::IRBuilder<> &B,
666 llvm::Value *Source,
667 HLSLAppliedSemanticAttr *Semantic,
668 std::optional<unsigned> Index) {
669 // DXIL packing rules etc shall be handled here.
670 // FIXME: generate proper sigpoint, index, col, row values.
671 SmallVector<Value *> Args{B.getInt32(C: 4),
672 B.getInt32(C: 0),
673 B.getInt32(C: 0),
674 B.getInt8(C: 0),
675 llvm::PoisonValue::get(T: B.getInt32Ty()),
676 Source};
677
678 llvm::Intrinsic::ID IntrinsicID = llvm::Intrinsic::dx_store_output;
679 B.CreateIntrinsic(/*ReturnType=*/RetTy: CGM.VoidTy, ID: IntrinsicID, Args, FMFSource: nullptr);
680}
681
682llvm::Value *CGHLSLRuntime::emitUserSemanticLoad(
683 IRBuilder<> &B, llvm::Type *Type, const clang::DeclaratorDecl *Decl,
684 HLSLAppliedSemanticAttr *Semantic, std::optional<unsigned> Index) {
685 if (CGM.getTarget().getTriple().isSPIRV())
686 return emitSPIRVUserSemanticLoad(B, Type, Decl, Semantic, Index);
687
688 if (CGM.getTarget().getTriple().isDXIL())
689 return emitDXILUserSemanticLoad(B, Type, Semantic, Index);
690
691 llvm_unreachable("Unsupported target for user-semantic load.");
692}
693
694void CGHLSLRuntime::emitUserSemanticStore(IRBuilder<> &B, llvm::Value *Source,
695 const clang::DeclaratorDecl *Decl,
696 HLSLAppliedSemanticAttr *Semantic,
697 std::optional<unsigned> Index) {
698 if (CGM.getTarget().getTriple().isSPIRV())
699 return emitSPIRVUserSemanticStore(B, Source, Decl, Semantic, Index);
700
701 if (CGM.getTarget().getTriple().isDXIL())
702 return emitDXILUserSemanticStore(B, Source, Semantic, Index);
703
704 llvm_unreachable("Unsupported target for user-semantic load.");
705}
706
// Emits the load of a system-value ("SV_*") input semantic. Compute-style
// semantics map directly to target intrinsics; SV_Position and SV_VertexID
// are lowered per target and shader stage.
llvm::Value *CGHLSLRuntime::emitSystemSemanticLoad(
    IRBuilder<> &B, const FunctionDecl *FD, llvm::Type *Type,
    const clang::DeclaratorDecl *Decl, HLSLAppliedSemanticAttr *Semantic,
    std::optional<unsigned> Index) {

  // Semantic names are matched case-insensitively via upper-casing.
  std::string SemanticName = Semantic->getAttrName()->getName().upper();
  if (SemanticName == "SV_GROUPINDEX") {
    llvm::Function *GroupIndex =
        CGM.getIntrinsic(IID: getFlattenedThreadIdInGroupIntrinsic());
    return B.CreateCall(Callee: FunctionCallee(GroupIndex));
  }

  if (SemanticName == "SV_DISPATCHTHREADID") {
    llvm::Intrinsic::ID IntrinID = getThreadIdIntrinsic();
    // Overloaded intrinsics need an explicit i32 element type; the value is
    // assembled component-by-component for vector types.
    llvm::Function *ThreadIDIntrinsic =
        llvm::Intrinsic::isOverloaded(id: IntrinID)
            ? CGM.getIntrinsic(IID: IntrinID, Tys: {CGM.Int32Ty})
            : CGM.getIntrinsic(IID: IntrinID);
    return buildVectorInput(B, F: ThreadIDIntrinsic, Ty: Type);
  }

  if (SemanticName == "SV_GROUPTHREADID") {
    llvm::Intrinsic::ID IntrinID = getGroupThreadIdIntrinsic();
    llvm::Function *GroupThreadIDIntrinsic =
        llvm::Intrinsic::isOverloaded(id: IntrinID)
            ? CGM.getIntrinsic(IID: IntrinID, Tys: {CGM.Int32Ty})
            : CGM.getIntrinsic(IID: IntrinID);
    return buildVectorInput(B, F: GroupThreadIDIntrinsic, Ty: Type);
  }

  if (SemanticName == "SV_GROUPID") {
    llvm::Intrinsic::ID IntrinID = getGroupIdIntrinsic();
    llvm::Function *GroupIDIntrinsic =
        llvm::Intrinsic::isOverloaded(id: IntrinID)
            ? CGM.getIntrinsic(IID: IntrinID, Tys: {CGM.Int32Ty})
            : CGM.getIntrinsic(IID: IntrinID);
    return buildVectorInput(B, F: GroupIDIntrinsic, Ty: Type);
  }

  // The remaining semantics are stage-dependent.
  const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
  assert(ShaderAttr && "Entry point has no shader attribute");
  llvm::Triple::EnvironmentType ST = ShaderAttr->getType();

  if (SemanticName == "SV_POSITION") {
    // Pixel stage: FragCoord builtin on SPIR-V, generic input on DXIL.
    if (ST == Triple::EnvironmentType::Pixel) {
      if (CGM.getTarget().getTriple().isSPIRV())
        return createSPIRVBuiltinLoad(B, M&: CGM.getModule(), Ty: Type,
                                      Name: Semantic->getAttrName()->getName(),
                                      /* BuiltIn::FragCoord */ BuiltInID: 15);
      if (CGM.getTarget().getTriple().isDXIL())
        return emitDXILUserSemanticLoad(B, Type, Semantic, Index);
    }

    // Vertex stage: treated as a plain user input.
    if (ST == Triple::EnvironmentType::Vertex) {
      return emitUserSemanticLoad(B, Type, Decl, Semantic, Index);
    }
  }

  if (SemanticName == "SV_VERTEXID") {
    if (ST == Triple::EnvironmentType::Vertex) {
      if (CGM.getTarget().getTriple().isSPIRV())
        return createSPIRVBuiltinLoad(B, M&: CGM.getModule(), Ty: Type,
                                      Name: Semantic->getAttrName()->getName(),
                                      /* BuiltIn::VertexIndex */ BuiltInID: 42);
      else
        return emitDXILUserSemanticLoad(B, Type, Semantic, Index);
    }
  }

  llvm_unreachable(
      "Load hasn't been implemented yet for this system semantic. FIXME");
}
779
780static void createSPIRVBuiltinStore(IRBuilder<> &B, llvm::Module &M,
781 llvm::Value *Source, const Twine &Name,
782 unsigned BuiltInID) {
783 auto *GV = new llvm::GlobalVariable(
784 M, Source->getType(), /* isConstant= */ false,
785 llvm::GlobalValue::ExternalLinkage,
786 /* Initializer= */ nullptr, Name, /* insertBefore= */ nullptr,
787 llvm::GlobalVariable::GeneralDynamicTLSModel,
788 /* AddressSpace */ 8, /* isExternallyInitialized= */ false);
789 addSPIRVBuiltinDecoration(GV, BuiltIn: BuiltInID);
790 GV->setVisibility(llvm::GlobalValue::HiddenVisibility);
791 B.CreateStore(Val: Source, Ptr: GV);
792}
793
// Lowers a store to a system ("SV_*") output semantic, dispatching on the
// semantic name and the codegen target (DXIL vs SPIR-V). Semantics without a
// store implementation yet hit the llvm_unreachable at the bottom.
void CGHLSLRuntime::emitSystemSemanticStore(IRBuilder<> &B, llvm::Value *Source,
                                            const clang::DeclaratorDecl *Decl,
                                            HLSLAppliedSemanticAttr *Semantic,
                                            std::optional<unsigned> Index) {

  // Semantic names are case-insensitive; compare the uppercased form.
  std::string SemanticName = Semantic->getAttrName()->getName().upper();
  if (SemanticName == "SV_POSITION") {
    // On DXIL, SV_Position is stored like a user semantic.
    if (CGM.getTarget().getTriple().isDXIL()) {
      emitDXILUserSemanticStore(B, Source, Semantic, Index);
      return;
    }

    // On SPIR-V, SV_Position maps to the Position builtin (decoration 0).
    if (CGM.getTarget().getTriple().isSPIRV()) {
      createSPIRVBuiltinStore(B, M&: CGM.getModule(), Source,
                              Name: Semantic->getAttrName()->getName(),
                              /* BuiltIn::Position */ BuiltInID: 0);
      return;
    }
  }

  // SV_Target has no dedicated builtin here; it is emitted as a user
  // semantic on all targets.
  if (SemanticName == "SV_TARGET") {
    emitUserSemanticStore(B, Source, Decl, Semantic, Index);
    return;
  }

  llvm_unreachable(
      "Store hasn't been implemented yet for this system semantic. FIXME");
}
822
823llvm::Value *CGHLSLRuntime::handleScalarSemanticLoad(
824 IRBuilder<> &B, const FunctionDecl *FD, llvm::Type *Type,
825 const clang::DeclaratorDecl *Decl, HLSLAppliedSemanticAttr *Semantic) {
826
827 std::optional<unsigned> Index = Semantic->getSemanticIndex();
828 if (Semantic->getAttrName()->getName().starts_with_insensitive(Prefix: "SV_"))
829 return emitSystemSemanticLoad(B, FD, Type, Decl, Semantic, Index);
830 return emitUserSemanticLoad(B, Type, Decl, Semantic, Index);
831}
832
833void CGHLSLRuntime::handleScalarSemanticStore(
834 IRBuilder<> &B, const FunctionDecl *FD, llvm::Value *Source,
835 const clang::DeclaratorDecl *Decl, HLSLAppliedSemanticAttr *Semantic) {
836 std::optional<unsigned> Index = Semantic->getSemanticIndex();
837 if (Semantic->getAttrName()->getName().starts_with_insensitive(Prefix: "SV_"))
838 emitSystemSemanticStore(B, Source, Decl, Semantic, Index);
839 else
840 emitUserSemanticStore(B, Source, Decl, Semantic, Index);
841}
842
// Recursively loads every field of a struct-typed semantic, consuming one
// applied-semantic attribute per scalar leaf. Returns the assembled aggregate
// value together with the iterator to the first unconsumed attribute, so the
// caller can continue threading attributes through sibling fields.
std::pair<llvm::Value *, specific_attr_iterator<HLSLAppliedSemanticAttr>>
CGHLSLRuntime::handleStructSemanticLoad(
    IRBuilder<> &B, const FunctionDecl *FD, llvm::Type *Type,
    const clang::DeclaratorDecl *Decl,
    specific_attr_iterator<HLSLAppliedSemanticAttr> AttrBegin,
    specific_attr_iterator<HLSLAppliedSemanticAttr> AttrEnd) {
  const llvm::StructType *ST = cast<StructType>(Val: Type);
  const clang::RecordDecl *RD = Decl->getType()->getAsRecordDecl();

  // The IR struct must mirror the AST record one element per field.
  assert(RD->getNumFields() == ST->getNumElements());

  // Build the aggregate by inserting each loaded field into a poison value.
  llvm::Value *Aggregate = llvm::PoisonValue::get(T: Type);
  auto FieldDecl = RD->field_begin();
  for (unsigned I = 0; I < ST->getNumElements(); ++I) {
    // Each field may itself be an aggregate; the recursion consumes as many
    // attributes as it has scalar leaves and hands back the next iterator.
    auto [ChildValue, NextAttr] = handleSemanticLoad(
        B, FD, Type: ST->getElementType(N: I), Decl: *FieldDecl, begin: AttrBegin, end: AttrEnd);
    AttrBegin = NextAttr;
    assert(ChildValue);
    Aggregate = B.CreateInsertValue(Agg: Aggregate, Val: ChildValue, Idxs: I);
    ++FieldDecl;
  }

  return std::make_pair(x&: Aggregate, y&: AttrBegin);
}
867
868specific_attr_iterator<HLSLAppliedSemanticAttr>
869CGHLSLRuntime::handleStructSemanticStore(
870 IRBuilder<> &B, const FunctionDecl *FD, llvm::Value *Source,
871 const clang::DeclaratorDecl *Decl,
872 specific_attr_iterator<HLSLAppliedSemanticAttr> AttrBegin,
873 specific_attr_iterator<HLSLAppliedSemanticAttr> AttrEnd) {
874
875 const llvm::StructType *ST = cast<StructType>(Val: Source->getType());
876
877 const clang::RecordDecl *RD = nullptr;
878 if (const FunctionDecl *FD = dyn_cast<FunctionDecl>(Val: Decl))
879 RD = FD->getDeclaredReturnType()->getAsRecordDecl();
880 else
881 RD = Decl->getType()->getAsRecordDecl();
882 assert(RD);
883
884 assert(RD->getNumFields() == ST->getNumElements());
885
886 auto FieldDecl = RD->field_begin();
887 for (unsigned I = 0; I < ST->getNumElements(); ++I, ++FieldDecl) {
888 llvm::Value *Extract = B.CreateExtractValue(Agg: Source, Idxs: I);
889 AttrBegin =
890 handleSemanticStore(B, FD, Source: Extract, Decl: *FieldDecl, AttrBegin, AttrEnd);
891 }
892
893 return AttrBegin;
894}
895
896std::pair<llvm::Value *, specific_attr_iterator<HLSLAppliedSemanticAttr>>
897CGHLSLRuntime::handleSemanticLoad(
898 IRBuilder<> &B, const FunctionDecl *FD, llvm::Type *Type,
899 const clang::DeclaratorDecl *Decl,
900 specific_attr_iterator<HLSLAppliedSemanticAttr> AttrBegin,
901 specific_attr_iterator<HLSLAppliedSemanticAttr> AttrEnd) {
902 assert(AttrBegin != AttrEnd);
903 if (Type->isStructTy())
904 return handleStructSemanticLoad(B, FD, Type, Decl, AttrBegin, AttrEnd);
905
906 HLSLAppliedSemanticAttr *Attr = *AttrBegin;
907 ++AttrBegin;
908 return std::make_pair(x: handleScalarSemanticLoad(B, FD, Type, Decl, Semantic: Attr),
909 y&: AttrBegin);
910}
911
912specific_attr_iterator<HLSLAppliedSemanticAttr>
913CGHLSLRuntime::handleSemanticStore(
914 IRBuilder<> &B, const FunctionDecl *FD, llvm::Value *Source,
915 const clang::DeclaratorDecl *Decl,
916 specific_attr_iterator<HLSLAppliedSemanticAttr> AttrBegin,
917 specific_attr_iterator<HLSLAppliedSemanticAttr> AttrEnd) {
918 assert(AttrBegin != AttrEnd);
919 if (Source->getType()->isStructTy())
920 return handleStructSemanticStore(B, FD, Source, Decl, AttrBegin, AttrEnd);
921
922 HLSLAppliedSemanticAttr *Attr = *AttrBegin;
923 ++AttrBegin;
924 handleScalarSemanticStore(B, FD, Source, Decl, Semantic: Attr);
925 return AttrBegin;
926}
927
// Emits the driver-visible shader entry point for \p FD as a void() wrapper
// that (1) loads every input semantic, (2) calls the user-authored function
// \p Fn (demoted to internal linkage), (3) stores every output semantic, and
// finally attaches root-signature metadata when present.
void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
                                      llvm::Function *Fn) {
  llvm::Module &M = CGM.getModule();
  llvm::LLVMContext &Ctx = M.getContext();
  // The real entry has signature void(): all shader I/O flows through
  // semantics rather than parameters or the return value.
  auto *EntryTy = llvm::FunctionType::get(Result: llvm::Type::getVoidTy(C&: Ctx), isVarArg: false);
  Function *EntryFn =
      Function::Create(Ty: EntryTy, Linkage: Function::ExternalLinkage, N: FD->getName(), M: &M);

  // Copy function attributes over, we have no argument or return attributes
  // that can be valid on the real entry.
  AttributeList NewAttrs = AttributeList::get(C&: Ctx, Index: AttributeList::FunctionIndex,
                                              Attrs: Fn->getAttributes().getFnAttrs());
  EntryFn->setAttributes(NewAttrs);
  setHLSLEntryAttributes(FD, Fn: EntryFn);

  // Set the called function as internal linkage.
  Fn->setLinkage(GlobalValue::InternalLinkage);

  BasicBlock *BB = BasicBlock::Create(Context&: Ctx, Name: "entry", Parent: EntryFn);
  IRBuilder<> B(BB);
  llvm::SmallVector<Value *> Args;

  // When convergence tokens are required, open an entry token and attach it
  // as an operand bundle to the call emitted below.
  SmallVector<OperandBundleDef, 1> OB;
  if (CGM.shouldEmitConvergenceTokens()) {
    assert(EntryFn->isConvergent());
    llvm::Value *I =
        B.CreateIntrinsic(ID: llvm::Intrinsic::experimental_convergence_entry, Args: {});
    llvm::Value *bundleArgs[] = {I};
    OB.emplace_back(Args: "convergencectrl", Args&: bundleArgs);
  }

  // Values that must be written back to output semantics after the call,
  // keyed by declaration and mapped to (value, element type). The element
  // type is only consulted when the value is an alloca to reload.
  llvm::DenseMap<const DeclaratorDecl *, std::pair<llvm::Value *, llvm::Type *>>
      OutputSemantic;

  // When the return value is passed via sret, argument 0 is the hidden
  // return slot, so ParmVarDecl indices are shifted by one.
  unsigned SRetOffset = 0;
  for (const auto &Param : Fn->args()) {
    if (Param.hasStructRetAttr()) {
      SRetOffset = 1;
      llvm::Type *VarType = Param.getParamStructRetType();
      llvm::Value *Var = B.CreateAlloca(Ty: VarType);
      OutputSemantic.try_emplace(Key: FD, Args: std::make_pair(x&: Var, y&: VarType));
      Args.push_back(Elt: Var);
      continue;
    }

    const ParmVarDecl *PD = FD->getParamDecl(i: Param.getArgNo() - SRetOffset);
    llvm::Value *SemanticValue = nullptr;
    // FIXME: support inout/out parameters for semantics.
    if ([[maybe_unused]] HLSLParamModifierAttr *MA =
            PD->getAttr<HLSLParamModifierAttr>()) {
      llvm_unreachable("Not handled yet");
    } else {
      llvm::Type *ParamType =
          Param.hasByValAttr() ? Param.getParamByValType() : Param.getType();
      auto AttrBegin = PD->specific_attr_begin<HLSLAppliedSemanticAttr>();
      auto AttrEnd = PD->specific_attr_end<HLSLAppliedSemanticAttr>();
      auto Result =
          handleSemanticLoad(B, FD, Type: ParamType, Decl: PD, AttrBegin, AttrEnd);
      SemanticValue = Result.first;
      if (!SemanticValue)
        return;
      // byval parameters are passed by pointer: spill the loaded value into
      // a temporary alloca and pass its address instead.
      if (Param.hasByValAttr()) {
        llvm::Value *Var = B.CreateAlloca(Ty: Param.getParamByValType());
        B.CreateStore(Val: SemanticValue, Ptr: Var);
        SemanticValue = Var;
      }
    }

    assert(SemanticValue);
    Args.push_back(Elt: SemanticValue);
  }

  CallInst *CI = B.CreateCall(Callee: FunctionCallee(Fn), Args, OpBundles: OB);
  CI->setCallingConv(Fn->getCallingConv());

  // A direct (non-sret) return value is itself an output semantic source.
  if (Fn->getReturnType() != CGM.VoidTy)
    // Element type is unused, so set to dummy value (NULL).
    OutputSemantic.try_emplace(Key: FD, Args: std::make_pair(x&: CI, y: nullptr));

  // Write every recorded output back to its semantic. Allocas are reloaded
  // first; direct values (the call result) are stored as-is.
  for (auto &[Decl, SourcePair] : OutputSemantic) {
    llvm::Value *Source = SourcePair.first;
    llvm::Type *ElementType = SourcePair.second;
    AllocaInst *AI = dyn_cast<AllocaInst>(Val: Source);
    llvm::Value *SourceValue = AI ? B.CreateLoad(Ty: ElementType, Ptr: Source) : Source;

    auto AttrBegin = Decl->specific_attr_begin<HLSLAppliedSemanticAttr>();
    auto AttrEnd = Decl->specific_attr_end<HLSLAppliedSemanticAttr>();
    handleSemanticStore(B, FD, Source: SourceValue, Decl, AttrBegin, AttrEnd);
  }

  B.CreateRetVoid();

  // Add and identify root signature to function, if applicable
  for (const Attr *Attr : FD->getAttrs()) {
    if (const auto *RSAttr = dyn_cast<RootSignatureAttr>(Val: Attr)) {
      auto *RSDecl = RSAttr->getSignatureDecl();
      addRootSignatureMD(RootSigVer: RSDecl->getVersion(), Elements: RSDecl->getRootElements(),
                         Fn: EntryFn, M);
    }
  }
}
1029
// Collects into \p Fns the functions registered in llvm.global_ctors (when
// \p CtorOrDtor is true) or llvm.global_dtors (false), in array order.
// Zero-initialized (empty) entries are skipped.
static void gatherFunctions(SmallVectorImpl<Function *> &Fns, llvm::Module &M,
                            bool CtorOrDtor) {
  const auto *GV =
      M.getNamedGlobal(Name: CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors");
  if (!GV)
    return;
  const auto *CA = dyn_cast<ConstantArray>(Val: GV->getInitializer());
  if (!CA)
    return;
  // The global_ctor array elements are a struct [Priority, Fn *, COMDat].
  // HLSL supports neither priorities nor COMDat values, so we will check
  // those in an assert but not handle them.

  for (const auto &Ctor : CA->operands()) {
    if (isa<ConstantAggregateZero>(Val: Ctor))
      continue;
    ConstantStruct *CS = cast<ConstantStruct>(Val: Ctor);

    // 65535 is the default priority, meaning "no priority specified".
    assert(cast<ConstantInt>(CS->getOperand(0))->getValue() == 65535 &&
           "HLSL doesn't support setting priority for global ctors.");
    assert(isa<ConstantPointerNull>(CS->getOperand(2)) &&
           "HLSL doesn't support COMDat for global ctors.");
    Fns.push_back(Elt: cast<Function>(Val: CS->getOperand(i_nocapture: 1)));
  }
}
1055
// Inlines global constructor/destructor calls into every shader entry point
// (functions carrying the "hlsl.shader" attribute): ctors at the top of the
// entry block, dtors just before the final terminator. For non-library
// profiles the llvm.global_ctors/dtors arrays are then removed.
void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
  llvm::Module &M = CGM.getModule();
  SmallVector<Function *> CtorFns;
  SmallVector<Function *> DtorFns;
  gatherFunctions(Fns&: CtorFns, M, CtorOrDtor: true);
  gatherFunctions(Fns&: DtorFns, M, CtorOrDtor: false);

  // Insert a call to the global constructor at the beginning of the entry block
  // to externally exported functions. This is a bit of a hack, but HLSL allows
  // global constructors, but doesn't support driver initialization of globals.
  for (auto &F : M.functions()) {
    if (!F.hasFnAttribute(Kind: "hlsl.shader"))
      continue;
    // If the entry carries a convergence token, the ctor/dtor calls must be
    // placed after it and carry it as an operand bundle.
    auto *Token = getConvergenceToken(BB&: F.getEntryBlock());
    Instruction *IP = &*F.getEntryBlock().begin();
    SmallVector<OperandBundleDef, 1> OB;
    if (Token) {
      llvm::Value *bundleArgs[] = {Token};
      OB.emplace_back(Args: "convergencectrl", Args&: bundleArgs);
      IP = Token->getNextNode();
    }
    IRBuilder<> B(IP);
    for (auto *Fn : CtorFns) {
      auto CI = B.CreateCall(Callee: FunctionCallee(Fn), Args: {}, OpBundles: OB);
      CI->setCallingConv(Fn->getCallingConv());
    }

    // Insert global dtors before the terminator of the last instruction
    B.SetInsertPoint(F.back().getTerminator());
    for (auto *Fn : DtorFns) {
      auto CI = B.CreateCall(Callee: FunctionCallee(Fn), Args: {}, OpBundles: OB);
      CI->setCallingConv(Fn->getCallingConv());
    }
  }

  // No need to keep global ctors/dtors for non-lib profile after call to
  // ctors/dtors added for entry.
  Triple T(M.getTargetTriple());
  if (T.getEnvironment() != Triple::EnvironmentType::Library) {
    if (auto *GV = M.getNamedGlobal(Name: "llvm.global_ctors"))
      GV->eraseFromParent();
    if (auto *GV = M.getNamedGlobal(Name: "llvm.global_dtors"))
      GV->eraseFromParent();
  }
}
1101
// Emits an internal, always-inline "_init_buffer_<name>" function that
// creates the resource handle for buffer global \p GV via intrinsic
// \p IntrID (with \p Args) and stores it into the global, then registers
// that function as a global initializer.
static void initializeBuffer(CodeGenModule &CGM, llvm::GlobalVariable *GV,
                             Intrinsic::ID IntrID,
                             ArrayRef<llvm::Value *> Args) {

  LLVMContext &Ctx = CGM.getLLVMContext();
  llvm::Function *InitResFunc = llvm::Function::Create(
      Ty: llvm::FunctionType::get(Result: CGM.VoidTy, isVarArg: false),
      Linkage: llvm::GlobalValue::InternalLinkage,
      N: ("_init_buffer_" + GV->getName()).str(), M&: CGM.getModule());
  InitResFunc->addFnAttr(Kind: llvm::Attribute::AlwaysInline);

  llvm::BasicBlock *EntryBB =
      llvm::BasicBlock::Create(Context&: Ctx, Name: "entry", Parent: InitResFunc);
  CGBuilderTy Builder(CGM, Ctx);
  const DataLayout &DL = CGM.getModule().getDataLayout();
  Builder.SetInsertPoint(EntryBB);

  // Make sure the global variable is buffer resource handle
  llvm::Type *HandleTy = GV->getValueType();
  assert(HandleTy->isTargetExtTy() && "unexpected type of the buffer global");

  // Create the handle ("<name>_h") and store it into the global.
  llvm::Value *CreateHandle = Builder.CreateIntrinsic(
      /*ReturnType=*/RetTy: HandleTy, ID: IntrID, Args, FMFSource: nullptr,
      Name: Twine(GV->getName()).concat(Suffix: "_h"));

  Builder.CreateAlignedStore(Val: CreateHandle, Ptr: GV, Align: GV->getPointerAlignment(DL));
  Builder.CreateRetVoid();

  // Run the initializer alongside the other C++-style global initializers.
  CGM.AddCXXGlobalInit(F: InitResFunc);
}
1132
// Emits the handle initialization for a cbuffer/tbuffer global, choosing the
// from-binding intrinsic when the declaration has an explicit register slot
// and the implicit-binding intrinsic otherwise.
void CGHLSLRuntime::initializeBufferFromBinding(const HLSLBufferDecl *BufDecl,
                                                llvm::GlobalVariable *GV) {
  ResourceBindingAttrs Binding(BufDecl);
  assert(Binding.hasBinding() &&
         "cbuffer/tbuffer should always have resource binding attribute");

  // A single buffer is a range of size 1 indexed at 0.
  auto *Index = llvm::ConstantInt::get(Ty: CGM.IntTy, V: 0);
  auto *RangeSize = llvm::ConstantInt::get(Ty: CGM.IntTy, V: 1);
  auto *Space = llvm::ConstantInt::get(Ty: CGM.IntTy, V: Binding.getSpace());
  Value *Name = buildNameForResource(BaseName: BufDecl->getName(), CGM);

  // buffer with explicit binding
  if (Binding.isExplicit()) {
    llvm::Intrinsic::ID IntrinsicID =
        CGM.getHLSLRuntime().getCreateHandleFromBindingIntrinsic();
    auto *RegSlot = llvm::ConstantInt::get(Ty: CGM.IntTy, V: Binding.getSlot());
    SmallVector<Value *> Args{Space, RegSlot, RangeSize, Index, Name};
    initializeBuffer(CGM, GV, IntrID: IntrinsicID, Args);
  } else {
    // buffer with implicit binding
    llvm::Intrinsic::ID IntrinsicID =
        CGM.getHLSLRuntime().getCreateHandleFromImplicitBindingIntrinsic();
    auto *OrderID =
        llvm::ConstantInt::get(Ty: CGM.IntTy, V: Binding.getImplicitOrderID());
    SmallVector<Value *> Args{OrderID, Space, RangeSize, Index, Name};
    initializeBuffer(CGM, GV, IntrID: IntrinsicID, Args);
  }
}
1161
1162void CGHLSLRuntime::handleGlobalVarDefinition(const VarDecl *VD,
1163 llvm::GlobalVariable *GV) {
1164 if (auto Attr = VD->getAttr<HLSLVkExtBuiltinInputAttr>())
1165 addSPIRVBuiltinDecoration(GV, BuiltIn: Attr->getBuiltIn());
1166}
1167
1168llvm::Instruction *CGHLSLRuntime::getConvergenceToken(BasicBlock &BB) {
1169 if (!CGM.shouldEmitConvergenceTokens())
1170 return nullptr;
1171
1172 auto E = BB.end();
1173 for (auto I = BB.begin(); I != E; ++I) {
1174 auto *II = dyn_cast<llvm::IntrinsicInst>(Val: &*I);
1175 if (II && llvm::isConvergenceControlIntrinsic(IntrinsicID: II->getIntrinsicID())) {
1176 return II;
1177 }
1178 }
1179 llvm_unreachable("Convergence token should have been emitted.");
1180 return nullptr;
1181}
1182
// AST visitor that collects the OpaqueValueExprs reachable from a statement,
// deduplicated, with each OVE's source expression traversed before the OVE
// itself so dependencies appear first in OVEs.
class OpaqueValueVisitor : public RecursiveASTVisitor<OpaqueValueVisitor> {
public:
  // Collected OVEs, in dependency order (read by emitInitListOpaqueValues).
  llvm::SmallVector<OpaqueValueExpr *, 8> OVEs;
  // Set used to deduplicate OVEs encountered more than once.
  llvm::SmallPtrSet<OpaqueValueExpr *, 8> Visited;
  OpaqueValueVisitor() {}

  bool VisitHLSLOutArgExpr(HLSLOutArgExpr *) {
    // These need to be bound in CodeGenFunction::EmitHLSLOutArgLValues
    // or CodeGenFunction::EmitHLSLOutArgExpr. If they are part of this
    // traversal, the temporary containing the copy out will not have
    // been created yet.
    return false;
  }

  bool VisitOpaqueValueExpr(OpaqueValueExpr *E) {
    // Traverse the source expression first.
    if (E->getSourceExpr())
      TraverseStmt(S: E->getSourceExpr());

    // Then add this OVE if we haven't seen it before.
    if (Visited.insert(Ptr: E).second)
      OVEs.push_back(Elt: E);

    return true;
  }
};
1209
// Emits and binds every not-yet-emitted OpaqueValueExpr reachable from the
// initializer list \p E, binding each one as an LValue or RValue according
// to OpaqueValueMappingData::shouldBindAsLValue.
void CGHLSLRuntime::emitInitListOpaqueValues(CodeGenFunction &CGF,
                                             InitListExpr *E) {

  typedef CodeGenFunction::OpaqueValueMappingData OpaqueValueMappingData;
  OpaqueValueVisitor Visitor;
  Visitor.TraverseStmt(S: E);
  // OVEs arrive in dependency order; skip any already emitted elsewhere.
  for (auto *OVE : Visitor.OVEs) {
    if (CGF.isOpaqueValueEmitted(E: OVE))
      continue;
    if (OpaqueValueMappingData::shouldBindAsLValue(expr: OVE)) {
      LValue LV = CGF.EmitLValue(E: OVE->getSourceExpr());
      OpaqueValueMappingData::bind(CGF, ov: OVE, lv: LV);
    } else {
      RValue RV = CGF.EmitAnyExpr(E: OVE->getSourceExpr());
      OpaqueValueMappingData::bind(CGF, ov: OVE, rv: RV);
    }
  }
}
1228
// Emits a subscript into a global (non-static) resource array by computing
// the flattened index, then initializing either a single resource or a local
// sub-array into a fresh temporary, whose LValue is returned. Returns
// std::nullopt when regular Clang codegen should handle the subscript
// instead (local/static arrays, or no usable create method).
std::optional<LValue> CGHLSLRuntime::emitResourceArraySubscriptExpr(
    const ArraySubscriptExpr *ArraySubsExpr, CodeGenFunction &CGF) {
  assert((ArraySubsExpr->getType()->isHLSLResourceRecord() ||
          ArraySubsExpr->getType()->isHLSLResourceRecordArray()) &&
         "expected resource array subscript expression");

  // Let clang codegen handle local and static resource array subscripts,
  // or when the subscript references an opaque expression (as part of
  // ArrayInitLoopExpr AST node).
  const VarDecl *ArrayDecl =
      dyn_cast_or_null<VarDecl>(Val: getArrayDecl(ASE: ArraySubsExpr));
  if (!ArrayDecl || !ArrayDecl->hasGlobalStorage() ||
      ArrayDecl->getStorageClass() == SC_Static)
    return std::nullopt;

  // get the resource array type
  ASTContext &AST = ArrayDecl->getASTContext();
  const Type *ResArrayTy = ArrayDecl->getType().getTypePtr();
  assert(ResArrayTy->isHLSLResourceRecordArray() &&
         "expected array of resource classes");

  // Iterate through all nested array subscript expressions to calculate
  // the index in the flattened resource array (if this is a multi-
  // dimensional array). The index is calculated as a sum of all indices
  // multiplied by the total size of the array at that level.
  Value *Index = nullptr;
  const ArraySubscriptExpr *ASE = ArraySubsExpr;
  while (ASE != nullptr) {
    Value *SubIndex = CGF.EmitScalarExpr(E: ASE->getIdx());
    if (const auto *ArrayTy =
            dyn_cast<ConstantArrayType>(Val: ASE->getType().getTypePtr())) {
      Value *Multiplier = llvm::ConstantInt::get(
          Ty: CGM.IntTy, V: AST.getConstantArrayElementCount(CA: ArrayTy));
      SubIndex = CGF.Builder.CreateMul(LHS: SubIndex, RHS: Multiplier);
    }
    Index = Index ? CGF.Builder.CreateAdd(LHS: Index, RHS: SubIndex) : SubIndex;
    ASE = dyn_cast<ArraySubscriptExpr>(Val: ASE->getBase()->IgnoreParenImpCasts());
  }

  // Find binding info for the resource array. For implicit binding
  // an HLSLResourceBindingAttr should have been added by SemaHLSL.
  ResourceBindingAttrs Binding(ArrayDecl);
  assert(Binding.hasBinding() &&
         "resource array must have a binding attribute");

  // Find the individual resource type.
  QualType ResultTy = ArraySubsExpr->getType();
  QualType ResourceTy =
      ResultTy->isArrayType() ? AST.getBaseElementType(QT: ResultTy) : ResultTy;

  // Create a temporary variable for the result, which is either going
  // to be a single resource instance or a local array of resources (we need to
  // return an LValue).
  RawAddress TmpVar = CGF.CreateMemTemp(T: ResultTy);
  if (CGF.EmitLifetimeStart(Addr: TmpVar.getPointer()))
    CGF.pushFullExprCleanup<CodeGenFunction::CallLifetimeEnd>(
        kind: NormalEHLifetimeMarker, A: TmpVar);

  AggValueSlot ValueSlot = AggValueSlot::forAddr(
      addr: TmpVar, quals: Qualifiers(), isDestructed: AggValueSlot::IsDestructed_t(true),
      needsGC: AggValueSlot::DoesNotNeedGCBarriers, isAliased: AggValueSlot::IsAliased_t(false),
      mayOverlap: AggValueSlot::DoesNotOverlap);

  // Calculate total array size (= range size).
  llvm::Value *Range = llvm::ConstantInt::getSigned(
      Ty: CGM.IntTy, V: getTotalArraySize(AST, Ty: ResArrayTy));

  // If the result of the subscript operation is a single resource, call the
  // constructor.
  if (ResultTy == ResourceTy) {
    CallArgList Args;
    CXXMethodDecl *CreateMethod = lookupResourceInitMethodAndSetupArgs(
        CGM&: CGF.CGM, ResourceDecl: ResourceTy->getAsCXXRecordDecl(), Range, Index,
        Name: ArrayDecl->getName(), Binding, Args);

    if (!CreateMethod)
      // This can happen if someone creates an array of structs that looks like
      // an HLSL resource record array but it does not have the required static
      // create method. No binding will be generated for it.
      return std::nullopt;

    callResourceInitMethod(CGF, CreateMethod, Args, ReturnAddress: ValueSlot.getAddress());

  } else {
    // The result of the subscript operation is a local resource array which
    // needs to be initialized.
    const ConstantArrayType *ArrayTy =
        cast<ConstantArrayType>(Val: ResultTy.getTypePtr());
    std::optional<llvm::Value *> EndIndex = initializeLocalResourceArray(
        CGF, ResourceDecl: ResourceTy->getAsCXXRecordDecl(), ArrayTy, ValueSlot, Range, StartIndex: Index,
        ResourceName: ArrayDecl->getName(), Binding, PrevGEPIndices: {llvm::ConstantInt::get(Ty: CGM.IntTy, V: 0)},
        ArraySubsExprLoc: ArraySubsExpr->getExprLoc());
    if (!EndIndex)
      return std::nullopt;
  }
  return CGF.MakeAddrLValue(Addr: TmpVar, T: ResultTy, Source: AlignmentSource::Decl);
}
1326
// If RHSExpr is a global resource array, initialize all of its resources and
// set them into LHS. Returns false if no copy has been performed and the
// array copy should be handled by Clang codegen.
bool CGHLSLRuntime::emitResourceArrayCopy(LValue &LHS, Expr *RHSExpr,
                                          CodeGenFunction &CGF) {
  QualType ResultTy = RHSExpr->getType();
  assert(ResultTy->isHLSLResourceRecordArray() && "expected resource array");

  // Let Clang codegen handle local and static resource array copies.
  const VarDecl *ArrayDecl = dyn_cast_or_null<VarDecl>(Val: getArrayDecl(E: RHSExpr));
  if (!ArrayDecl || !ArrayDecl->hasGlobalStorage() ||
      ArrayDecl->getStorageClass() == SC_Static)
    return false;

  // Find binding info for the resource array. For implicit binding
  // the HLSLResourceBindingAttr should have been added by SemaHLSL.
  ResourceBindingAttrs Binding(ArrayDecl);
  assert(Binding.hasBinding() &&
         "resource array must have a binding attribute");

  // Find the individual resource type.
  ASTContext &AST = ArrayDecl->getASTContext();
  QualType ResTy = AST.getBaseElementType(QT: ResultTy);
  const auto *ResArrayTy = cast<ConstantArrayType>(Val: ResultTy.getTypePtr());

  // Use the provided LHS for the result.
  AggValueSlot ValueSlot = AggValueSlot::forAddr(
      addr: LHS.getAddress(), quals: Qualifiers(), isDestructed: AggValueSlot::IsDestructed_t(true),
      needsGC: AggValueSlot::DoesNotNeedGCBarriers, isAliased: AggValueSlot::IsAliased_t(false),
      mayOverlap: AggValueSlot::DoesNotOverlap);

  // Create Value for index and total array size (= range size).
  int Size = getTotalArraySize(AST, Ty: ResArrayTy);
  llvm::Value *Zero = llvm::ConstantInt::get(Ty: CGM.IntTy, V: 0);
  llvm::Value *Range = llvm::ConstantInt::get(Ty: CGM.IntTy, V: Size);

  // Initialize individual resources in the array into LHS.
  // A successful initialization reports a copy was performed.
  std::optional<llvm::Value *> EndIndex = initializeLocalResourceArray(
      CGF, ResourceDecl: ResTy->getAsCXXRecordDecl(), ArrayTy: ResArrayTy, ValueSlot, Range, StartIndex: Zero,
      ResourceName: ArrayDecl->getName(), Binding, PrevGEPIndices: {Zero}, ArraySubsExprLoc: RHSExpr->getExprLoc());
  return EndIndex.has_value();
}
1369
// Returns an address holding a cbuffer matrix in its natural (non-padded)
// in-memory layout. When the buffer layout matches the memory layout the
// original address is returned directly; otherwise the matrix is copied into
// a fresh temporary following HLSL buffer-layout rules.
RawAddress CGHLSLRuntime::createBufferMatrixTempAddress(const LValue &LV,
                                                        SourceLocation Loc,
                                                        CodeGenFunction &CGF) {

  assert(LV.getType()->isConstantMatrixType() && "expected matrix type");
  assert(LV.getType().getAddressSpace() == LangAS::hlsl_constant &&
         "expected cbuffer matrix");

  QualType MatQualTy = LV.getType();
  llvm::Type *MemTy = CGF.ConvertTypeForMem(T: MatQualTy);
  llvm::Type *LayoutTy = HLSLBufferLayoutBuilder(CGF.CGM).layOutType(Type: MatQualTy);

  // No layout padding: the cbuffer bytes are already in memory layout.
  if (LayoutTy == MemTy)
    return LV.getAddress();

  Address SrcAddr = LV.getAddress();
  // NOTE: Because CreateMemTemp flattens MatrixTypes which causes
  // overlapping GEPs in emitBufferCopy. Use CreateTempAlloca with
  // the non-padded layout.
  CharUnits Align =
      CharUnits::fromQuantity(Quantity: CGF.CGM.getDataLayout().getABITypeAlign(Ty: MemTy));
  RawAddress DestAlloca = CGF.CreateTempAlloca(Ty: MemTy, align: Align, Name: "matrix.buf.copy");
  emitBufferCopy(CGF, DestPtr: DestAlloca, SrcPtr: SrcAddr, CType: MatQualTy);
  return DestAlloca;
}
1395
// Emits an array subscript into a cbuffer-resident array whose element type
// needs HLSL buffer-layout padding. Returns std::nullopt when the padded
// layout equals the natural memory layout and regular codegen suffices.
std::optional<LValue> CGHLSLRuntime::emitBufferArraySubscriptExpr(
    const ArraySubscriptExpr *E, CodeGenFunction &CGF,
    llvm::function_ref<llvm::Value *(bool Promote)> EmitIdxAfterBase) {
  // Find the element type to index by first padding the element type per HLSL
  // buffer rules, and then padding out to a 16-byte register boundary if
  // necessary.
  llvm::Type *LayoutTy =
      HLSLBufferLayoutBuilder(CGF.CGM).layOutType(Type: E->getType());
  uint64_t LayoutSizeInBits =
      CGM.getDataLayout().getTypeSizeInBits(Ty: LayoutTy).getFixedValue();
  CharUnits ElementSize = CharUnits::fromQuantity(Quantity: LayoutSizeInBits / 8);
  CharUnits RowAlignedSize = ElementSize.alignTo(Align: CharUnits::fromQuantity(Quantity: 16));
  if (RowAlignedSize > ElementSize) {
    // Wrap the element in a packed struct with explicit target padding so
    // each array element occupies a full 16-byte row.
    llvm::Type *Padding = CGM.getTargetCodeGenInfo().getHLSLPadding(
        CGM, NumBytes: RowAlignedSize - ElementSize);
    assert(Padding && "No padding type for target?");
    LayoutTy = llvm::StructType::get(Context&: CGF.getLLVMContext(), Elements: {LayoutTy, Padding},
                                     /*isPacked=*/true);
  }

  // If the layout type doesn't introduce any padding, we don't need to do
  // anything special.
  llvm::Type *OrigTy = CGF.CGM.getTypes().ConvertTypeForMem(T: E->getType());
  if (LayoutTy == OrigTy)
    return std::nullopt;

  LValueBaseInfo EltBaseInfo;
  TBAAAccessInfo EltTBAAInfo;

  // Index into the object as-if we have an array of the padded element type,
  // and then dereference the element itself to avoid reading padding that may
  // be past the end of the in-memory object.
  SmallVector<llvm::Value *, 2> Indices;
  llvm::Value *Idx = EmitIdxAfterBase(/*Promote*/ true);
  Indices.push_back(Elt: Idx);
  Indices.push_back(Elt: llvm::ConstantInt::get(Ty: CGF.Int32Ty, V: 0));

  if (CGF.getLangOpts().EmitStructuredGEP) {
    // The fact that we emit an array-to-pointer decay might be an oversight,
    // but for now, we simply ignore it (see #179951).
    const CastExpr *CE = cast<CastExpr>(Val: E->getBase());
    assert(CE->getCastKind() == CastKind::CK_ArrayToPointerDecay);

    LValue LV = CGF.EmitLValue(E: CE->getSubExpr());
    Address Addr = LV.getAddress();
    LayoutTy = llvm::ArrayType::get(
        ElementType: LayoutTy,
        NumElements: cast<llvm::ArrayType>(Val: Addr.getElementType())->getNumElements());
    auto *GEP = cast<StructuredGEPInst>(Val: CGF.Builder.CreateStructuredGEP(
        BaseType: LayoutTy, PtrBase: Addr.emitRawPointer(CGF), Indices, Name: "cbufferidx"));
    Addr =
        Address(GEP, GEP->getResultElementType(), RowAlignedSize, KnownNonNull);
    return CGF.MakeAddrLValue(Addr, T: E->getType(), BaseInfo: EltBaseInfo, TBAAInfo: EltTBAAInfo);
  }

  Address Addr =
      CGF.EmitPointerWithAlignment(Addr: E->getBase(), BaseInfo: &EltBaseInfo, TBAAInfo: &EltTBAAInfo);
  llvm::Value *GEP = CGF.Builder.CreateGEP(Ty: LayoutTy, Ptr: Addr.emitRawPointer(CGF),
                                           IdxList: Indices, Name: "cbufferidx");
  Addr = Address(GEP, Addr.getElementType(), RowAlignedSize, KnownNonNull);
  return CGF.MakeAddrLValue(Addr, T: E->getType(), BaseInfo: EltBaseInfo, TBAAInfo: EltTBAAInfo);
}
1458
1459namespace {
1460/// Utility for emitting copies following the HLSL buffer layout rules (ie,
1461/// copying out of a cbuffer).
1462class HLSLBufferCopyEmitter {
1463 CodeGenFunction &CGF;
1464 Address DestPtr;
1465 Address SrcPtr;
1466 llvm::Type *LayoutTy = nullptr;
1467
1468 SmallVector<llvm::Value *> CurStoreIndices;
1469 SmallVector<llvm::Value *> CurLoadIndices;
1470
  // Copies one (possibly aggregate) field from SrcPtr to DestPtr at the
  // given store/load index pair. Aggregates recurse; scalar and vector
  // leaves produce a load/store pair. The index stacks are restored on exit
  // so siblings see the parent's indices.
  void emitCopyAtIndices(llvm::Type *FieldTy, llvm::ConstantInt *StoreIndex,
                         llvm::ConstantInt *LoadIndex) {
    CurStoreIndices.push_back(Elt: StoreIndex);
    CurLoadIndices.push_back(Elt: LoadIndex);
    llvm::scope_exit RestoreIndices([&]() {
      CurStoreIndices.pop_back();
      CurLoadIndices.pop_back();
    });

    // First, see if this is some kind of aggregate and recurse.
    if (processArray(FieldTy))
      return;
    if (processBufferLayoutArray(FieldTy))
      return;
    if (processStruct(FieldTy))
      return;

    // When we have a scalar or vector element we can emit the copy.
    CharUnits Align = CharUnits::fromQuantity(
        Quantity: CGF.CGM.getDataLayout().getABITypeAlign(Ty: FieldTy));
    Address SrcGEP = RawAddress(
        CGF.Builder.CreateInBoundsGEP(Ty: LayoutTy, Ptr: SrcPtr.getBasePointer(),
                                      IdxList: CurLoadIndices, Name: "cbuf.src"),
        FieldTy, Align, SrcPtr.isKnownNonNull());
    Address DestGEP = CGF.Builder.CreateInBoundsGEP(
        Addr: DestPtr, IdxList: CurStoreIndices, ElementType: FieldTy, Align, Name: "cbuf.dest");
    llvm::Value *Load = CGF.Builder.CreateLoad(Addr: SrcGEP, Name: "cbuf.load");
    CGF.Builder.CreateStore(Val: Load, Addr: DestGEP);
  }
1500
  // Handles a plain llvm::ArrayType field by copying each element with the
  // same index on both sides. Returns false when FieldTy is not an array so
  // the caller can try the other aggregate shapes.
  bool processArray(llvm::Type *FieldTy) {
    auto *AT = dyn_cast<llvm::ArrayType>(Val: FieldTy);
    if (!AT)
      return false;

    // If we have an llvm::ArrayType this is just a regular array with no top
    // level padding, so all we need to do is copy each member.
    for (unsigned I = 0, E = AT->getNumElements(); I < E; ++I)
      emitCopyAtIndices(FieldTy: AT->getElementType(),
                        StoreIndex: llvm::ConstantInt::get(Ty: CGF.SizeTy, V: I),
                        LoadIndex: llvm::ConstantInt::get(Ty: CGF.SizeTy, V: I));
    return true;
  }
1514
  // Handles an HLSL buffer-layout array field, copying only the real
  // elements and skipping the per-element padding. Returns false when
  // FieldTy does not match the buffer-layout-array shape.
  bool processBufferLayoutArray(llvm::Type *FieldTy) {
    // A buffer layout array is a struct with two elements: the padded array,
    // and the last element. That is, is should look something like this:
    //
    // { [%n x { %type, %padding }], %type }
    //
    auto *ST = dyn_cast<llvm::StructType>(Val: FieldTy);
    if (!ST || ST->getNumElements() != 2)
      return false;

    auto *PaddedEltsTy = dyn_cast<llvm::ArrayType>(Val: ST->getElementType(N: 0));
    if (!PaddedEltsTy)
      return false;

    auto *PaddedTy = dyn_cast<llvm::StructType>(Val: PaddedEltsTy->getElementType());
    if (!PaddedTy || PaddedTy->getNumElements() != 2)
      return false;

    // The second member of each padded element must be the target's HLSL
    // padding type, and the first must match the trailing element type.
    if (!CGF.CGM.getTargetCodeGenInfo().isHLSLPadding(
            Ty: PaddedTy->getElementType(N: 1)))
      return false;

    llvm::Type *ElementTy = ST->getElementType(N: 1);
    if (PaddedTy->getElementType(N: 0) != ElementTy)
      return false;

    // All but the last of the logical array elements are in the padded array.
    unsigned NumElts = PaddedEltsTy->getNumElements() + 1;

    // Add an extra indirection to the load for the struct and walk the
    // array prefix.
    CurLoadIndices.push_back(Elt: llvm::ConstantInt::get(Ty: CGF.Int32Ty, V: 0));
    for (unsigned I = 0; I < NumElts - 1; ++I) {
      // We need to copy the element itself, without the padding.
      CurLoadIndices.push_back(Elt: llvm::ConstantInt::get(Ty: CGF.SizeTy, V: I));
      emitCopyAtIndices(FieldTy: ElementTy, StoreIndex: llvm::ConstantInt::get(Ty: CGF.SizeTy, V: I),
                        LoadIndex: llvm::ConstantInt::get(Ty: CGF.Int32Ty, V: 0));
      CurLoadIndices.pop_back();
    }
    CurLoadIndices.pop_back();

    // Now copy the last element.
    emitCopyAtIndices(FieldTy: ElementTy,
                      StoreIndex: llvm::ConstantInt::get(Ty: CGF.SizeTy, V: NumElts - 1),
                      LoadIndex: llvm::ConstantInt::get(Ty: CGF.Int32Ty, V: 1));

    return true;
  }
1563
1564 bool processStruct(llvm::Type *FieldTy) {
1565 auto *ST = dyn_cast<llvm::StructType>(Val: FieldTy);
1566 if (!ST)
1567 return false;
1568
1569 // Copy the struct field by field, but skip any explicit padding.
1570 unsigned Skipped = 0;
1571 for (unsigned I = 0, E = ST->getNumElements(); I < E; ++I) {
1572 llvm::Type *ElementTy = ST->getElementType(N: I);
1573 if (CGF.CGM.getTargetCodeGenInfo().isHLSLPadding(Ty: ElementTy))
1574 ++Skipped;
1575 else
1576 emitCopyAtIndices(FieldTy: ElementTy, StoreIndex: llvm::ConstantInt::get(Ty: CGF.Int32Ty, V: I),
1577 LoadIndex: llvm::ConstantInt::get(Ty: CGF.Int32Ty, V: I + Skipped));
1578 }
1579 return true;
1580 }
1581
public:
  // Captures the function being emitted and the destination/source addresses
  // of the copy; the actual work happens in emitCopy().
  HLSLBufferCopyEmitter(CodeGenFunction &CGF, Address DestPtr, Address SrcPtr)
      : CGF(CGF), DestPtr(DestPtr), SrcPtr(SrcPtr) {}
1585
1586 bool emitCopy(QualType CType) {
1587 LayoutTy = HLSLBufferLayoutBuilder(CGF.CGM).layOutType(Type: CType);
1588
1589 // TODO: We should be able to fall back to a regular memcpy if the layout
1590 // type doesn't have any padding, but that runs into issues in the backend
1591 // currently.
1592 //
1593 // See https://github.com/llvm/wg-hlsl/issues/351
1594 emitCopyAtIndices(FieldTy: LayoutTy, StoreIndex: llvm::ConstantInt::get(Ty: CGF.SizeTy, V: 0),
1595 LoadIndex: llvm::ConstantInt::get(Ty: CGF.SizeTy, V: 0));
1596 return true;
1597 }
1598};
1599} // namespace
1600
1601bool CGHLSLRuntime::emitBufferCopy(CodeGenFunction &CGF, Address DestPtr,
1602 Address SrcPtr, QualType CType) {
1603 return HLSLBufferCopyEmitter(CGF, DestPtr, SrcPtr).emitCopy(CType);
1604}
1605
1606LValue CGHLSLRuntime::emitBufferMemberExpr(CodeGenFunction &CGF,
1607 const MemberExpr *E) {
1608 LValue Base =
1609 CGF.EmitCheckedLValue(E: E->getBase(), TCK: CodeGenFunction::TCK_MemberAccess);
1610 auto *Field = dyn_cast<FieldDecl>(Val: E->getMemberDecl());
1611 assert(Field && "Unexpected access into HLSL buffer");
1612
1613 const RecordDecl *Rec = Field->getParent();
1614
1615 // Work out the buffer layout type to index into.
1616 QualType RecType = CGM.getContext().getCanonicalTagType(TD: Rec);
1617 assert(RecType->isStructureOrClassType() && "Invalid type in HLSL buffer");
1618 // Since this is a member of an object in the buffer and not the buffer's
1619 // struct/class itself, we shouldn't have any offsets on the members we need
1620 // to contend with.
1621 CGHLSLOffsetInfo EmptyOffsets;
1622 llvm::StructType *LayoutTy = HLSLBufferLayoutBuilder(CGM).layOutStruct(
1623 StructType: RecType->getAsCanonical<RecordType>(), OffsetInfo: EmptyOffsets);
1624
1625 // Get the field index for the layout struct, accounting for padding.
1626 unsigned FieldIdx =
1627 CGM.getTypes().getCGRecordLayout(Rec).getLLVMFieldNo(FD: Field);
1628 assert(FieldIdx < LayoutTy->getNumElements() &&
1629 "Layout struct is smaller than member struct");
1630 unsigned Skipped = 0;
1631 for (unsigned I = 0; I <= FieldIdx;) {
1632 llvm::Type *ElementTy = LayoutTy->getElementType(N: I + Skipped);
1633 if (CGF.CGM.getTargetCodeGenInfo().isHLSLPadding(Ty: ElementTy))
1634 ++Skipped;
1635 else
1636 ++I;
1637 }
1638 FieldIdx += Skipped;
1639 assert(FieldIdx < LayoutTy->getNumElements() && "Access out of bounds");
1640
1641 // Now index into the struct, making sure that the type we return is the
1642 // buffer layout type rather than the original type in the AST.
1643 QualType FieldType = Field->getType();
1644 llvm::Type *FieldLLVMTy = CGM.getTypes().ConvertTypeForMem(T: FieldType);
1645 CharUnits Align = CharUnits::fromQuantity(
1646 Quantity: CGF.CGM.getDataLayout().getABITypeAlign(Ty: FieldLLVMTy));
1647
1648 Value *Ptr = CGF.getLangOpts().EmitStructuredGEP
1649 ? CGF.Builder.CreateStructuredGEP(
1650 BaseType: LayoutTy, PtrBase: Base.getPointer(CGF),
1651 Indices: llvm::ConstantInt::get(Ty: CGM.IntTy, V: FieldIdx))
1652 : CGF.Builder.CreateStructGEP(Ty: LayoutTy, Ptr: Base.getPointer(CGF),
1653 Idx: FieldIdx, Name: Field->getName());
1654 Address Addr(Ptr, FieldLLVMTy, Align, KnownNonNull);
1655
1656 LValue LV = LValue::MakeAddr(Addr, type: FieldType, Context&: CGM.getContext(),
1657 BaseInfo: LValueBaseInfo(AlignmentSource::Type),
1658 TBAAInfo: CGM.getTBAAAccessInfo(AccessType: FieldType));
1659 LV.getQuals().addCVRQualifiers(mask: Base.getVRQualifiers());
1660
1661 return LV;
1662}
1663