1//===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This provides an abstract class for HLSL code generation. Concrete
10// subclasses of this implement code generation for specific HLSL
11// runtime libraries.
12//
13//===----------------------------------------------------------------------===//
14
15#include "CGHLSLRuntime.h"
16#include "CGDebugInfo.h"
17#include "CGRecordLayout.h"
18#include "CodeGenFunction.h"
19#include "CodeGenModule.h"
20#include "HLSLBufferLayoutBuilder.h"
21#include "TargetInfo.h"
22#include "clang/AST/ASTContext.h"
23#include "clang/AST/Attr.h"
24#include "clang/AST/Decl.h"
25#include "clang/AST/Expr.h"
26#include "clang/AST/HLSLResource.h"
27#include "clang/AST/RecursiveASTVisitor.h"
28#include "clang/AST/Type.h"
29#include "clang/Basic/DiagnosticFrontend.h"
30#include "clang/Basic/TargetOptions.h"
31#include "llvm/ADT/DenseMap.h"
32#include "llvm/ADT/ScopeExit.h"
33#include "llvm/ADT/SmallString.h"
34#include "llvm/ADT/SmallVector.h"
35#include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
36#include "llvm/IR/Constants.h"
37#include "llvm/IR/DerivedTypes.h"
38#include "llvm/IR/GlobalVariable.h"
39#include "llvm/IR/IntrinsicInst.h"
40#include "llvm/IR/LLVMContext.h"
41#include "llvm/IR/Metadata.h"
42#include "llvm/IR/Module.h"
43#include "llvm/IR/Type.h"
44#include "llvm/IR/Value.h"
45#include "llvm/Support/Alignment.h"
46#include "llvm/Support/ErrorHandling.h"
47#include "llvm/Support/FormatVariadic.h"
48#include <cstdint>
49#include <optional>
50
51using namespace clang;
52using namespace CodeGen;
53using namespace clang::hlsl;
54using namespace llvm;
55
56using llvm::hlsl::CBufferRowSizeInBytes;
57
58namespace {
59
60void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {
61 // The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs.
62 // Assume ValVersionStr is legal here.
63 VersionTuple Version;
64 if (Version.tryParse(string: ValVersionStr) || Version.getBuild() ||
65 Version.getSubminor() || !Version.getMinor()) {
66 return;
67 }
68
69 uint64_t Major = Version.getMajor();
70 uint64_t Minor = *Version.getMinor();
71
72 auto &Ctx = M.getContext();
73 IRBuilder<> B(M.getContext());
74 MDNode *Val = MDNode::get(Context&: Ctx, MDs: {ConstantAsMetadata::get(C: B.getInt32(C: Major)),
75 ConstantAsMetadata::get(C: B.getInt32(C: Minor))});
76 StringRef DXILValKey = "dx.valver";
77 auto *DXILValMD = M.getOrInsertNamedMetadata(Name: DXILValKey);
78 DXILValMD->addOperand(M: Val);
79}
80
81void addRootSignatureMD(llvm::dxbc::RootSignatureVersion RootSigVer,
82 ArrayRef<llvm::hlsl::rootsig::RootElement> Elements,
83 llvm::Function *Fn, llvm::Module &M) {
84 auto &Ctx = M.getContext();
85
86 llvm::hlsl::rootsig::MetadataBuilder RSBuilder(Ctx, Elements);
87 MDNode *RootSignature = RSBuilder.BuildRootSignature();
88
89 ConstantAsMetadata *Version = ConstantAsMetadata::get(C: ConstantInt::get(
90 Ty: llvm::Type::getInt32Ty(C&: Ctx), V: llvm::to_underlying(E: RootSigVer)));
91 ValueAsMetadata *EntryFunc = Fn ? ValueAsMetadata::get(V: Fn) : nullptr;
92 MDNode *MDVals = MDNode::get(Context&: Ctx, MDs: {EntryFunc, RootSignature, Version});
93
94 StringRef RootSignatureValKey = "dx.rootsignatures";
95 auto *RootSignatureValMD = M.getOrInsertNamedMetadata(Name: RootSignatureValKey);
96 RootSignatureValMD->addOperand(M: MDVals);
97}
98
99// Find array variable declaration from DeclRef expression
100static const ValueDecl *getArrayDecl(const Expr *E) {
101 if (const DeclRefExpr *DRE =
102 dyn_cast_or_null<DeclRefExpr>(Val: E->IgnoreImpCasts()))
103 return DRE->getDecl();
104 return nullptr;
105}
106
107// Find array variable declaration from nested array subscript AST nodes
108static const ValueDecl *getArrayDecl(const ArraySubscriptExpr *ASE) {
109 const Expr *E = nullptr;
110 while (ASE != nullptr) {
111 E = ASE->getBase()->IgnoreImpCasts();
112 if (!E)
113 return nullptr;
114 ASE = dyn_cast<ArraySubscriptExpr>(Val: E);
115 }
116 return getArrayDecl(E);
117}
118
119// Get the total size of the array, or -1 if the array is unbounded.
120static int getTotalArraySize(ASTContext &AST, const clang::Type *Ty) {
121 Ty = Ty->getUnqualifiedDesugaredType();
122 assert(Ty->isArrayType() && "expected array type");
123 if (Ty->isIncompleteArrayType())
124 return -1;
125 return AST.getConstantArrayElementCount(CA: cast<ConstantArrayType>(Val: Ty));
126}
127
128static Value *buildNameForResource(llvm::StringRef BaseName,
129 CodeGenModule &CGM) {
130 llvm::SmallString<64> GlobalName = {BaseName, ".str"};
131 return CGM.GetAddrOfConstantCString(Str: BaseName.str(), GlobalName: GlobalName.c_str())
132 .getPointer();
133}
134
135static CXXMethodDecl *lookupMethod(CXXRecordDecl *Record, StringRef Name,
136 StorageClass SC = SC_None) {
137 for (auto *Method : Record->methods()) {
138 if (Method->getStorageClass() == SC && Method->getName() == Name)
139 return Method;
140 }
141 return nullptr;
142}
143
// Looks up the static resource "create" method matching the binding kind
// (explicit register vs. implicit, optionally with an implicit counter
// binding) and fills Args with the arguments that method expects, in order:
//   reg-slot | order-id, space, range, index, name, [counter order-id].
// Returns nullptr if ResourceDecl does not declare the expected method (e.g. a
// struct that merely looks like a resource record).
static CXXMethodDecl *lookupResourceInitMethodAndSetupArgs(
    CodeGenModule &CGM, CXXRecordDecl *ResourceDecl, llvm::Value *Range,
    llvm::Value *Index, StringRef Name, ResourceBindingAttrs &Binding,
    CallArgList &Args) {
  assert(Binding.hasBinding() && "at least one binding attribute expected");

  ASTContext &AST = CGM.getContext();
  CXXMethodDecl *CreateMethod = nullptr;
  Value *NameStr = buildNameForResource(BaseName: Name, CGM);
  Value *Space = llvm::ConstantInt::get(Ty: CGM.IntTy, V: Binding.getSpace());

  if (Binding.isExplicit()) {
    // Explicit binding: the first argument is the register slot.
    auto *RegSlot = llvm::ConstantInt::get(Ty: CGM.IntTy, V: Binding.getSlot());
    Args.add(rvalue: RValue::get(V: RegSlot), type: AST.UnsignedIntTy);
    const char *Name = Binding.hasCounterImplicitOrderID()
                           ? "__createFromBindingWithImplicitCounter"
                           : "__createFromBinding";
    CreateMethod = lookupMethod(Record: ResourceDecl, Name, SC: SC_Static);
  } else {
    // Implicit binding: the first argument is the implicit binding order ID.
    auto *OrderID =
        llvm::ConstantInt::get(Ty: CGM.IntTy, V: Binding.getImplicitOrderID());
    Args.add(rvalue: RValue::get(V: OrderID), type: AST.UnsignedIntTy);
    const char *Name = Binding.hasCounterImplicitOrderID()
                           ? "__createFromImplicitBindingWithImplicitCounter"
                           : "__createFromImplicitBinding";
    CreateMethod = lookupMethod(Record: ResourceDecl, Name, SC: SC_Static);
  }
  // Arguments common to all create methods, in the order they are declared.
  Args.add(rvalue: RValue::get(V: Space), type: AST.UnsignedIntTy);
  Args.add(rvalue: RValue::get(V: Range), type: AST.IntTy);
  Args.add(rvalue: RValue::get(V: Index), type: AST.UnsignedIntTy);
  Args.add(rvalue: RValue::get(V: NameStr), type: AST.getPointerType(T: AST.CharTy.withConst()));
  if (Binding.hasCounterImplicitOrderID()) {
    // A counter's implicit order ID, when present, is always appended last.
    uint32_t CounterBinding = Binding.getCounterImplicitOrderID();
    auto *CounterOrderID = llvm::ConstantInt::get(Ty: CGM.IntTy, V: CounterBinding);
    Args.add(rvalue: RValue::get(V: CounterOrderID), type: AST.UnsignedIntTy);
  }

  return CreateMethod;
}
185
186static void callResourceInitMethod(CodeGenFunction &CGF,
187 CXXMethodDecl *CreateMethod,
188 CallArgList &Args, Address ReturnAddress) {
189 llvm::Constant *CalleeFn = CGF.CGM.GetAddrOfFunction(GD: CreateMethod);
190 const FunctionProtoType *Proto =
191 CreateMethod->getType()->getAs<FunctionProtoType>();
192 const CGFunctionInfo &FnInfo =
193 CGF.CGM.getTypes().arrangeFreeFunctionCall(Args, Ty: Proto, ChainCall: false);
194 ReturnValueSlot ReturnValue(ReturnAddress, false);
195 CGCallee Callee(CGCalleeInfo(Proto), CalleeFn);
196 CGF.EmitCall(CallInfo: FnInfo, Callee, ReturnValue, Args, CallOrInvoke: nullptr);
197}
198
// Initializes local resource array variable. For multi-dimensional arrays it
// calls itself recursively to initialize its sub-arrays. The Index used in the
// resource constructor calls will begin at StartIndex and will be incremented
// for each array element. The last used resource Index is returned to the
// caller. If the function returns std::nullopt, it indicates an error (the
// element type has no usable create method).
// PrevGEPIndices carries the getelementptr indices accumulated by the outer
// array dimensions; one more index is appended per recursion level.
static std::optional<llvm::Value *> initializeLocalResourceArray(
    CodeGenFunction &CGF, CXXRecordDecl *ResourceDecl,
    const ConstantArrayType *ArrayTy, AggValueSlot &ValueSlot,
    llvm::Value *Range, llvm::Value *StartIndex, StringRef ResourceName,
    ResourceBindingAttrs &Binding, ArrayRef<llvm::Value *> PrevGEPIndices,
    SourceLocation ArraySubsExprLoc) {

  ASTContext &AST = CGF.getContext();
  llvm::IntegerType *IntTy = CGF.CGM.IntTy;
  llvm::Value *Index = StartIndex;
  llvm::Value *One = llvm::ConstantInt::get(Ty: IntTy, V: 1);
  const uint64_t ArraySize = ArrayTy->getSExtSize();
  QualType ElemType = ArrayTy->getElementType();
  Address TmpArrayAddr = ValueSlot.getAddress();

  // Add additional index to the getelementptr call indices.
  // This index will be updated for each array element in the loops below.
  SmallVector<llvm::Value *> GEPIndices(PrevGEPIndices);
  GEPIndices.push_back(Elt: llvm::ConstantInt::get(Ty: IntTy, V: 0));

  // For array of arrays, recursively initialize the sub-arrays.
  if (ElemType->isArrayType()) {
    const ConstantArrayType *SubArrayTy = cast<ConstantArrayType>(Val&: ElemType);
    for (uint64_t I = 0; I < ArraySize; I++) {
      if (I > 0) {
        // Advance the running resource index and point the trailing GEP
        // index at element I. Skipped for I == 0 since Index/GEPIndices
        // already hold the correct initial values.
        Index = CGF.Builder.CreateAdd(LHS: Index, RHS: One);
        GEPIndices.back() = llvm::ConstantInt::get(Ty: IntTy, V: I);
      }
      std::optional<llvm::Value *> MaybeIndex = initializeLocalResourceArray(
          CGF, ResourceDecl, ArrayTy: SubArrayTy, ValueSlot, Range, StartIndex: Index, ResourceName,
          Binding, PrevGEPIndices: GEPIndices, ArraySubsExprLoc);
      if (!MaybeIndex)
        return std::nullopt;
      // Continue numbering from the last index the sub-array consumed.
      Index = *MaybeIndex;
    }
    return Index;
  }

  // For array of resources, initialize each resource in the array.
  llvm::Type *Ty = CGF.ConvertTypeForMem(T: ElemType);
  CharUnits ElemSize = AST.getTypeSizeInChars(T: ElemType);
  CharUnits Align =
      TmpArrayAddr.getAlignment().alignmentOfArrayElement(elementSize: ElemSize);

  for (uint64_t I = 0; I < ArraySize; I++) {
    if (I > 0) {
      Index = CGF.Builder.CreateAdd(LHS: Index, RHS: One);
      GEPIndices.back() = llvm::ConstantInt::get(Ty: IntTy, V: I);
    }
    // Address of element I inside the temporary array aggregate.
    Address ReturnAddress =
        CGF.Builder.CreateGEP(Addr: TmpArrayAddr, IdxList: GEPIndices, ElementType: Ty, Align);

    CallArgList Args;
    CXXMethodDecl *CreateMethod = lookupResourceInitMethodAndSetupArgs(
        CGM&: CGF.CGM, ResourceDecl, Range, Index, Name: ResourceName, Binding, Args);

    if (!CreateMethod)
      // This can happen if someone creates an array of structs that looks like
      // an HLSL resource record array but it does not have the required static
      // create method. No binding will be generated for it.
      return std::nullopt;

    callResourceInitMethod(CGF, CreateMethod, Args, ReturnAddress);
  }
  return Index;
}
270
271} // namespace
272
273llvm::Type *
274CGHLSLRuntime::convertHLSLSpecificType(const Type *T,
275 const CGHLSLOffsetInfo &OffsetInfo) {
276 assert(T->isHLSLSpecificType() && "Not an HLSL specific type!");
277
278 // Check if the target has a specific translation for this type first.
279 if (llvm::Type *TargetTy =
280 CGM.getTargetCodeGenInfo().getHLSLType(CGM, T, OffsetInfo))
281 return TargetTy;
282
283 llvm_unreachable("Generic handling of HLSL types is not supported.");
284}
285
286llvm::Triple::ArchType CGHLSLRuntime::getArch() {
287 return CGM.getTarget().getTriple().getArch();
288}
289
// Emits constant global variables for buffer constants declarations
// and creates metadata linking the constant globals with the buffer global.
// The globals are matched positionally against the elements of the buffer's
// layout struct (skipping target padding elements), so the per-decl offsets
// from OffsetInfo are used to put the decls in layout order first.
void CGHLSLRuntime::emitBufferGlobalsAndMetadata(
    const HLSLBufferDecl *BufDecl, llvm::GlobalVariable *BufGV,
    const CGHLSLOffsetInfo &OffsetInfo) {
  LLVMContext &Ctx = CGM.getLLVMContext();

  // get the layout struct from constant buffer target type
  llvm::Type *BufType = BufGV->getValueType();
  llvm::StructType *LayoutStruct = cast<llvm::StructType>(
      Val: cast<llvm::TargetExtType>(Val: BufType)->getTypeParameter(i: 0));

  // Collect the hlsl_constant decls together with their layout offsets.
  // OffsetIdx tracks the next entry of OffsetInfo, which was built over the
  // same decl sequence by CGHLSLOffsetInfo::fromDecl.
  SmallVector<std::pair<VarDecl *, uint32_t>> DeclsWithOffset;
  size_t OffsetIdx = 0;
  for (Decl *D : BufDecl->buffer_decls()) {
    if (isa<CXXRecordDecl, EmptyDecl>(Val: D))
      // Nothing to do for this declaration.
      continue;
    if (isa<FunctionDecl>(Val: D)) {
      // A function within an cbuffer is effectively a top-level function.
      CGM.EmitTopLevelDecl(D);
      continue;
    }
    VarDecl *VD = dyn_cast<VarDecl>(Val: D);
    if (!VD)
      continue;

    QualType VDTy = VD->getType();
    if (VDTy.getAddressSpace() != LangAS::hlsl_constant) {
      if (VD->getStorageClass() == SC_Static ||
          VDTy.getAddressSpace() == LangAS::hlsl_groupshared ||
          VDTy->isHLSLResourceRecord() || VDTy->isHLSLResourceRecordArray()) {
        // Emit static and groupshared variables and resource classes inside
        // cbuffer as regular globals
        CGM.EmitGlobal(D: VD);
      } else {
        // Anything else that is not in the hlsl_constant address space must be
        // an empty struct or a zero-sized array and can be ignored
        assert(BufDecl->getASTContext().getTypeSize(VDTy) == 0 &&
               "constant buffer decl with non-zero sized type outside of "
               "hlsl_constant address space");
      }
      continue;
    }

    DeclsWithOffset.emplace_back(Args&: VD, Args: OffsetInfo[OffsetIdx++]);
  }

  // Put the decls into layout order; without packoffset info the decl order
  // already matches the layout struct, so no sort is needed.
  if (!OffsetInfo.empty())
    llvm::stable_sort(Range&: DeclsWithOffset, C: [](const auto &LHS, const auto &RHS) {
      return CGHLSLOffsetInfo::compareOffsets(LHS: LHS.second, RHS: RHS.second);
    });

  // Associate the buffer global variable with its constants
  SmallVector<llvm::Metadata *> BufGlobals;
  BufGlobals.reserve(N: DeclsWithOffset.size() + 1);
  BufGlobals.push_back(Elt: ValueAsMetadata::get(V: BufGV));

  // Walk the layout struct in parallel with the sorted decls; padding
  // elements inserted by the target are skipped.
  auto ElemIt = LayoutStruct->element_begin();
  for (auto &[VD, _] : DeclsWithOffset) {
    if (CGM.getTargetCodeGenInfo().isHLSLPadding(Ty: *ElemIt))
      ++ElemIt;

    assert(ElemIt != LayoutStruct->element_end() &&
           "number of elements in layout struct does not match");
    llvm::Type *LayoutType = *ElemIt++;

    GlobalVariable *ElemGV =
        cast<GlobalVariable>(Val: CGM.GetAddrOfGlobalVar(D: VD, Ty: LayoutType));
    BufGlobals.push_back(Elt: ValueAsMetadata::get(V: ElemGV));
  }
  assert(ElemIt == LayoutStruct->element_end() &&
         "number of elements in layout struct does not match");

  // add buffer metadata to the module
  CGM.getModule()
      .getOrInsertNamedMetadata(Name: "hlsl.cbs")
      ->addOperand(M: MDNode::get(Context&: Ctx, MDs: BufGlobals));
}
369
370// Creates resource handle type for the HLSL buffer declaration
371static const clang::HLSLAttributedResourceType *
372createBufferHandleType(const HLSLBufferDecl *BufDecl) {
373 ASTContext &AST = BufDecl->getASTContext();
374 QualType QT = AST.getHLSLAttributedResourceType(
375 Wrapped: AST.HLSLResourceTy, Contained: AST.getCanonicalTagType(TD: BufDecl->getLayoutStruct()),
376 Attrs: HLSLAttributedResourceType::Attributes(ResourceClass::CBuffer));
377 return cast<HLSLAttributedResourceType>(Val: QT.getTypePtr());
378}
379
380CGHLSLOffsetInfo CGHLSLOffsetInfo::fromDecl(const HLSLBufferDecl &BufDecl) {
381 CGHLSLOffsetInfo Result;
382
383 // If we don't have packoffset info, just return an empty result.
384 if (!BufDecl.hasValidPackoffset())
385 return Result;
386
387 for (Decl *D : BufDecl.buffer_decls()) {
388 if (isa<CXXRecordDecl, EmptyDecl>(Val: D) || isa<FunctionDecl>(Val: D)) {
389 continue;
390 }
391 VarDecl *VD = dyn_cast<VarDecl>(Val: D);
392 if (!VD || VD->getType().getAddressSpace() != LangAS::hlsl_constant)
393 continue;
394
395 if (!VD->hasAttrs()) {
396 Result.Offsets.push_back(Elt: Unspecified);
397 continue;
398 }
399
400 uint32_t Offset = Unspecified;
401 for (auto *Attr : VD->getAttrs()) {
402 if (auto *POA = dyn_cast<HLSLPackOffsetAttr>(Val: Attr)) {
403 Offset = POA->getOffsetInBytes();
404 break;
405 }
406 auto *RBA = dyn_cast<HLSLResourceBindingAttr>(Val: Attr);
407 if (RBA &&
408 RBA->getRegisterType() == HLSLResourceBindingAttr::RegisterType::C) {
409 Offset = RBA->getSlotNumber() * CBufferRowSizeInBytes;
410 break;
411 }
412 }
413 Result.Offsets.push_back(Elt: Offset);
414 }
415 return Result;
416}
417
// Codegen for HLSLBufferDecl: creates the cbuffer's global handle variable,
// emits globals for its constants plus the linking metadata, and emits the
// binding-based initialization of the handle.
void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *BufDecl) {

  assert(BufDecl->isCBuffer() && "tbuffer codegen is not supported yet");

  // create resource handle type for the buffer
  const clang::HLSLAttributedResourceType *ResHandleTy =
      createBufferHandleType(BufDecl);

  // empty constant buffer is ignored
  if (ResHandleTy->getContainedType()->getAsCXXRecordDecl()->isEmpty())
    return;

  // create global variable for the constant buffer; the contents are poison
  // until the handle is initialized from its binding below. The ".tb" suffix
  // is dead for now given the isCBuffer assert above.
  CGHLSLOffsetInfo OffsetInfo = CGHLSLOffsetInfo::fromDecl(BufDecl: *BufDecl);
  llvm::Type *LayoutTy = convertHLSLSpecificType(T: ResHandleTy, OffsetInfo);
  llvm::GlobalVariable *BufGV = new GlobalVariable(
      LayoutTy, /*isConstant*/ false,
      GlobalValue::LinkageTypes::ExternalLinkage, PoisonValue::get(T: LayoutTy),
      llvm::formatv(Fmt: "{0}{1}", Vals: BufDecl->getName(),
                    Vals: BufDecl->isCBuffer() ? ".cb" : ".tb"),
      GlobalValue::NotThreadLocal);
  CGM.getModule().insertGlobalVariable(GV: BufGV);

  // Add globals for constant buffer elements and create metadata nodes
  emitBufferGlobalsAndMetadata(BufDecl, BufGV, OffsetInfo);

  // Initialize cbuffer from binding (implicit or explicit)
  initializeBufferFromBinding(BufDecl, GV: BufGV);
}
448
449void CGHLSLRuntime::addRootSignature(
450 const HLSLRootSignatureDecl *SignatureDecl) {
451 llvm::Module &M = CGM.getModule();
452 Triple T(M.getTargetTriple());
453
454 // Generated later with the function decl if not targeting root signature
455 if (T.getEnvironment() != Triple::EnvironmentType::RootSignature)
456 return;
457
458 addRootSignatureMD(RootSigVer: SignatureDecl->getVersion(),
459 Elements: SignatureDecl->getRootElements(), Fn: nullptr, M);
460}
461
462llvm::StructType *
463CGHLSLRuntime::getHLSLBufferLayoutType(const RecordType *StructType) {
464 const auto Entry = LayoutTypes.find(Val: StructType);
465 if (Entry != LayoutTypes.end())
466 return Entry->getSecond();
467 return nullptr;
468}
469
470void CGHLSLRuntime::addHLSLBufferLayoutType(const RecordType *StructType,
471 llvm::StructType *LayoutTy) {
472 assert(getHLSLBufferLayoutType(StructType) == nullptr &&
473 "layout type for this struct already exist");
474 LayoutTypes[StructType] = LayoutTy;
475}
476
477void CGHLSLRuntime::finishCodeGen() {
478 auto &TargetOpts = CGM.getTarget().getTargetOpts();
479 auto &CodeGenOpts = CGM.getCodeGenOpts();
480 auto &LangOpts = CGM.getLangOpts();
481 llvm::Module &M = CGM.getModule();
482 Triple T(M.getTargetTriple());
483 if (T.getArch() == Triple::ArchType::dxil)
484 addDxilValVersion(ValVersionStr: TargetOpts.DxilValidatorVersion, M);
485 if (CodeGenOpts.ResMayAlias)
486 M.setModuleFlag(Behavior: llvm::Module::ModFlagBehavior::Error, Key: "dx.resmayalias", Val: 1);
487 if (CodeGenOpts.AllResourcesBound)
488 M.setModuleFlag(Behavior: llvm::Module::ModFlagBehavior::Error,
489 Key: "dx.allresourcesbound", Val: 1);
490 // NativeHalfType corresponds to the -fnative-half-type clang option which is
491 // aliased by clang-dxc's -enable-16bit-types option. This option is used to
492 // set the UseNativeLowPrecision DXIL module flag in the DirectX backend
493 if (LangOpts.NativeHalfType)
494 M.setModuleFlag(Behavior: llvm::Module::ModFlagBehavior::Error, Key: "dx.nativelowprec",
495 Val: 1);
496
497 generateGlobalCtorDtorCalls();
498}
499
500void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
501 const FunctionDecl *FD, llvm::Function *Fn) {
502 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
503 assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
504 const StringRef ShaderAttrKindStr = "hlsl.shader";
505 Fn->addFnAttr(Kind: ShaderAttrKindStr,
506 Val: llvm::Triple::getEnvironmentTypeName(Kind: ShaderAttr->getType()));
507 if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {
508 const StringRef NumThreadsKindStr = "hlsl.numthreads";
509 std::string NumThreadsStr =
510 formatv(Fmt: "{0},{1},{2}", Vals: NumThreadsAttr->getX(), Vals: NumThreadsAttr->getY(),
511 Vals: NumThreadsAttr->getZ());
512 Fn->addFnAttr(Kind: NumThreadsKindStr, Val: NumThreadsStr);
513 }
514 if (HLSLWaveSizeAttr *WaveSizeAttr = FD->getAttr<HLSLWaveSizeAttr>()) {
515 const StringRef WaveSizeKindStr = "hlsl.wavesize";
516 std::string WaveSizeStr =
517 formatv(Fmt: "{0},{1},{2}", Vals: WaveSizeAttr->getMin(), Vals: WaveSizeAttr->getMax(),
518 Vals: WaveSizeAttr->getPreferred());
519 Fn->addFnAttr(Kind: WaveSizeKindStr, Val: WaveSizeStr);
520 }
521 // HLSL entry functions are materialized for module functions with
522 // HLSLShaderAttr attribute. SetLLVMFunctionAttributesForDefinition called
523 // later in the compiler-flow for such module functions is not aware of and
524 // hence not able to set attributes of the newly materialized entry functions.
525 // So, set attributes of entry function here, as appropriate.
526 if (CGM.getCodeGenOpts().OptimizationLevel == 0)
527 Fn->addFnAttr(Kind: llvm::Attribute::OptimizeNone);
528 Fn->addFnAttr(Kind: llvm::Attribute::NoInline);
529
530 if (CGM.getLangOpts().HLSLSpvEnableMaximalReconvergence) {
531 Fn->addFnAttr(Kind: "enable-maximal-reconvergence", Val: "true");
532 }
533}
534
535static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) {
536 if (const auto *VT = dyn_cast<FixedVectorType>(Val: Ty)) {
537 Value *Result = PoisonValue::get(T: Ty);
538 for (unsigned I = 0; I < VT->getNumElements(); ++I) {
539 Value *Elt = B.CreateCall(Callee: F, Args: {B.getInt32(C: I)});
540 Result = B.CreateInsertElement(Vec: Result, NewElt: Elt, Idx: I);
541 }
542 return Result;
543 }
544 return B.CreateCall(Callee: F, Args: {B.getInt32(C: 0)});
545}
546
547static void addSPIRVBuiltinDecoration(llvm::GlobalVariable *GV,
548 unsigned BuiltIn) {
549 LLVMContext &Ctx = GV->getContext();
550 IRBuilder<> B(GV->getContext());
551 MDNode *Operands = MDNode::get(
552 Context&: Ctx,
553 MDs: {ConstantAsMetadata::get(C: B.getInt32(/* Spirv::Decoration::BuiltIn */ C: 11)),
554 ConstantAsMetadata::get(C: B.getInt32(C: BuiltIn))});
555 MDNode *Decoration = MDNode::get(Context&: Ctx, MDs: {Operands});
556 GV->addMetadata(Kind: "spirv.Decorations", MD&: *Decoration);
557}
558
559static void addLocationDecoration(llvm::GlobalVariable *GV, unsigned Location) {
560 LLVMContext &Ctx = GV->getContext();
561 IRBuilder<> B(GV->getContext());
562 MDNode *Operands =
563 MDNode::get(Context&: Ctx, MDs: {ConstantAsMetadata::get(C: B.getInt32(/* Location */ C: 30)),
564 ConstantAsMetadata::get(C: B.getInt32(C: Location))});
565 MDNode *Decoration = MDNode::get(Context&: Ctx, MDs: {Operands});
566 GV->addMetadata(Kind: "spirv.Decorations", MD&: *Decoration);
567}
568
569static llvm::Value *createSPIRVBuiltinLoad(IRBuilder<> &B, llvm::Module &M,
570 llvm::Type *Ty, const Twine &Name,
571 unsigned BuiltInID) {
572 auto *GV = new llvm::GlobalVariable(
573 M, Ty, /* isConstant= */ true, llvm::GlobalValue::ExternalLinkage,
574 /* Initializer= */ nullptr, Name, /* insertBefore= */ nullptr,
575 llvm::GlobalVariable::GeneralDynamicTLSModel,
576 /* AddressSpace */ 7, /* isExternallyInitialized= */ true);
577 addSPIRVBuiltinDecoration(GV, BuiltIn: BuiltInID);
578 GV->setVisibility(llvm::GlobalValue::HiddenVisibility);
579 return B.CreateLoad(Ty, Ptr: GV);
580}
581
582static llvm::Value *createSPIRVLocationLoad(IRBuilder<> &B, llvm::Module &M,
583 llvm::Type *Ty, unsigned Location,
584 StringRef Name) {
585 auto *GV = new llvm::GlobalVariable(
586 M, Ty, /* isConstant= */ true, llvm::GlobalValue::ExternalLinkage,
587 /* Initializer= */ nullptr, /* Name= */ Name, /* insertBefore= */ nullptr,
588 llvm::GlobalVariable::GeneralDynamicTLSModel,
589 /* AddressSpace */ 7, /* isExternallyInitialized= */ true);
590 GV->setVisibility(llvm::GlobalValue::HiddenVisibility);
591 addLocationDecoration(GV, Location);
592 return B.CreateLoad(Ty, Ptr: GV);
593}
594
595llvm::Value *CGHLSLRuntime::emitSPIRVUserSemanticLoad(
596 llvm::IRBuilder<> &B, llvm::Type *Type, const clang::DeclaratorDecl *Decl,
597 HLSLAppliedSemanticAttr *Semantic, std::optional<unsigned> Index) {
598 Twine BaseName = Twine(Semantic->getAttrName()->getName());
599 Twine VariableName = BaseName.concat(Suffix: Twine(Index.value_or(u: 0)));
600
601 unsigned Location = SPIRVLastAssignedInputSemanticLocation;
602 if (auto *L = Decl->getAttr<HLSLVkLocationAttr>())
603 Location = L->getLocation();
604
605 // DXC completely ignores the semantic/index pair. Location are assigned from
606 // the first semantic to the last.
607 llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Val: Type);
608 unsigned ElementCount = AT ? AT->getNumElements() : 1;
609 SPIRVLastAssignedInputSemanticLocation += ElementCount;
610
611 return createSPIRVLocationLoad(B, M&: CGM.getModule(), Ty: Type, Location,
612 Name: VariableName.str());
613}
614
615static void createSPIRVLocationStore(IRBuilder<> &B, llvm::Module &M,
616 llvm::Value *Source, unsigned Location,
617 StringRef Name) {
618 auto *GV = new llvm::GlobalVariable(
619 M, Source->getType(), /* isConstant= */ false,
620 llvm::GlobalValue::ExternalLinkage,
621 /* Initializer= */ nullptr, /* Name= */ Name, /* insertBefore= */ nullptr,
622 llvm::GlobalVariable::GeneralDynamicTLSModel,
623 /* AddressSpace */ 8, /* isExternallyInitialized= */ false);
624 GV->setVisibility(llvm::GlobalValue::HiddenVisibility);
625 addLocationDecoration(GV, Location);
626 B.CreateStore(Val: Source, Ptr: GV);
627}
628
629void CGHLSLRuntime::emitSPIRVUserSemanticStore(
630 llvm::IRBuilder<> &B, llvm::Value *Source,
631 const clang::DeclaratorDecl *Decl, HLSLAppliedSemanticAttr *Semantic,
632 std::optional<unsigned> Index) {
633 Twine BaseName = Twine(Semantic->getAttrName()->getName());
634 Twine VariableName = BaseName.concat(Suffix: Twine(Index.value_or(u: 0)));
635
636 unsigned Location = SPIRVLastAssignedOutputSemanticLocation;
637 if (auto *L = Decl->getAttr<HLSLVkLocationAttr>())
638 Location = L->getLocation();
639
640 // DXC completely ignores the semantic/index pair. Location are assigned from
641 // the first semantic to the last.
642 llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Val: Source->getType());
643 unsigned ElementCount = AT ? AT->getNumElements() : 1;
644 SPIRVLastAssignedOutputSemanticLocation += ElementCount;
645 createSPIRVLocationStore(B, M&: CGM.getModule(), Source, Location,
646 Name: VariableName.str());
647}
648
649llvm::Value *
650CGHLSLRuntime::emitDXILUserSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
651 HLSLAppliedSemanticAttr *Semantic,
652 std::optional<unsigned> Index) {
653 Twine BaseName = Twine(Semantic->getAttrName()->getName());
654 Twine VariableName = BaseName.concat(Suffix: Twine(Index.value_or(u: 0)));
655
656 // DXIL packing rules etc shall be handled here.
657 // FIXME: generate proper sigpoint, index, col, row values.
658 // FIXME: also DXIL loads vectors element by element.
659 SmallVector<Value *> Args{B.getInt32(C: 4), B.getInt32(C: 0), B.getInt32(C: 0),
660 B.getInt8(C: 0),
661 llvm::PoisonValue::get(T: B.getInt32Ty())};
662
663 llvm::Intrinsic::ID IntrinsicID = llvm::Intrinsic::dx_load_input;
664 llvm::Value *Value = B.CreateIntrinsic(/*ReturnType=*/RetTy: Type, ID: IntrinsicID, Args,
665 FMFSource: nullptr, Name: VariableName);
666 return Value;
667}
668
669void CGHLSLRuntime::emitDXILUserSemanticStore(llvm::IRBuilder<> &B,
670 llvm::Value *Source,
671 HLSLAppliedSemanticAttr *Semantic,
672 std::optional<unsigned> Index) {
673 // DXIL packing rules etc shall be handled here.
674 // FIXME: generate proper sigpoint, index, col, row values.
675 SmallVector<Value *> Args{B.getInt32(C: 4),
676 B.getInt32(C: 0),
677 B.getInt32(C: 0),
678 B.getInt8(C: 0),
679 llvm::PoisonValue::get(T: B.getInt32Ty()),
680 Source};
681
682 llvm::Intrinsic::ID IntrinsicID = llvm::Intrinsic::dx_store_output;
683 B.CreateIntrinsic(/*ReturnType=*/RetTy: CGM.VoidTy, ID: IntrinsicID, Args, FMFSource: nullptr);
684}
685
686llvm::Value *CGHLSLRuntime::emitUserSemanticLoad(
687 IRBuilder<> &B, llvm::Type *Type, const clang::DeclaratorDecl *Decl,
688 HLSLAppliedSemanticAttr *Semantic, std::optional<unsigned> Index) {
689 if (CGM.getTarget().getTriple().isSPIRV())
690 return emitSPIRVUserSemanticLoad(B, Type, Decl, Semantic, Index);
691
692 if (CGM.getTarget().getTriple().isDXIL())
693 return emitDXILUserSemanticLoad(B, Type, Semantic, Index);
694
695 llvm_unreachable("Unsupported target for user-semantic load.");
696}
697
698void CGHLSLRuntime::emitUserSemanticStore(IRBuilder<> &B, llvm::Value *Source,
699 const clang::DeclaratorDecl *Decl,
700 HLSLAppliedSemanticAttr *Semantic,
701 std::optional<unsigned> Index) {
702 if (CGM.getTarget().getTriple().isSPIRV())
703 return emitSPIRVUserSemanticStore(B, Source, Decl, Semantic, Index);
704
705 if (CGM.getTarget().getTriple().isDXIL())
706 return emitDXILUserSemanticStore(B, Source, Semantic, Index);
707
708 llvm_unreachable("Unsupported target for user-semantic load.");
709}
710
711llvm::Value *CGHLSLRuntime::emitSystemSemanticLoad(
712 IRBuilder<> &B, const FunctionDecl *FD, llvm::Type *Type,
713 const clang::DeclaratorDecl *Decl, HLSLAppliedSemanticAttr *Semantic,
714 std::optional<unsigned> Index) {
715
716 std::string SemanticName = Semantic->getAttrName()->getName().upper();
717 if (SemanticName == "SV_GROUPINDEX") {
718 llvm::Function *GroupIndex =
719 CGM.getIntrinsic(IID: getFlattenedThreadIdInGroupIntrinsic());
720 return B.CreateCall(Callee: FunctionCallee(GroupIndex));
721 }
722
723 if (SemanticName == "SV_DISPATCHTHREADID") {
724 llvm::Intrinsic::ID IntrinID = getThreadIdIntrinsic();
725 llvm::Function *ThreadIDIntrinsic =
726 llvm::Intrinsic::isOverloaded(id: IntrinID)
727 ? CGM.getIntrinsic(IID: IntrinID, Tys: {CGM.Int32Ty})
728 : CGM.getIntrinsic(IID: IntrinID);
729 return buildVectorInput(B, F: ThreadIDIntrinsic, Ty: Type);
730 }
731
732 if (SemanticName == "SV_GROUPTHREADID") {
733 llvm::Intrinsic::ID IntrinID = getGroupThreadIdIntrinsic();
734 llvm::Function *GroupThreadIDIntrinsic =
735 llvm::Intrinsic::isOverloaded(id: IntrinID)
736 ? CGM.getIntrinsic(IID: IntrinID, Tys: {CGM.Int32Ty})
737 : CGM.getIntrinsic(IID: IntrinID);
738 return buildVectorInput(B, F: GroupThreadIDIntrinsic, Ty: Type);
739 }
740
741 if (SemanticName == "SV_GROUPID") {
742 llvm::Intrinsic::ID IntrinID = getGroupIdIntrinsic();
743 llvm::Function *GroupIDIntrinsic =
744 llvm::Intrinsic::isOverloaded(id: IntrinID)
745 ? CGM.getIntrinsic(IID: IntrinID, Tys: {CGM.Int32Ty})
746 : CGM.getIntrinsic(IID: IntrinID);
747 return buildVectorInput(B, F: GroupIDIntrinsic, Ty: Type);
748 }
749
750 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
751 assert(ShaderAttr && "Entry point has no shader attribute");
752 llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
753
754 if (SemanticName == "SV_POSITION") {
755 if (ST == Triple::EnvironmentType::Pixel) {
756 if (CGM.getTarget().getTriple().isSPIRV())
757 return createSPIRVBuiltinLoad(B, M&: CGM.getModule(), Ty: Type,
758 Name: Semantic->getAttrName()->getName(),
759 /* BuiltIn::FragCoord */ BuiltInID: 15);
760 if (CGM.getTarget().getTriple().isDXIL())
761 return emitDXILUserSemanticLoad(B, Type, Semantic, Index);
762 }
763
764 if (ST == Triple::EnvironmentType::Vertex) {
765 return emitUserSemanticLoad(B, Type, Decl, Semantic, Index);
766 }
767 }
768
769 llvm_unreachable(
770 "Load hasn't been implemented yet for this system semantic. FIXME");
771}
772
773static void createSPIRVBuiltinStore(IRBuilder<> &B, llvm::Module &M,
774 llvm::Value *Source, const Twine &Name,
775 unsigned BuiltInID) {
776 auto *GV = new llvm::GlobalVariable(
777 M, Source->getType(), /* isConstant= */ false,
778 llvm::GlobalValue::ExternalLinkage,
779 /* Initializer= */ nullptr, Name, /* insertBefore= */ nullptr,
780 llvm::GlobalVariable::GeneralDynamicTLSModel,
781 /* AddressSpace */ 8, /* isExternallyInitialized= */ false);
782 addSPIRVBuiltinDecoration(GV, BuiltIn: BuiltInID);
783 GV->setVisibility(llvm::GlobalValue::HiddenVisibility);
784 B.CreateStore(Val: Source, Ptr: GV);
785}
786
// Emit the store of `Source` to a system-value ("SV_*") output semantic.
// Dispatches on the semantic name and the current target (DXIL vs SPIR-V).
// `Index` is the optional semantic index (e.g. the 1 in SV_TARGET1).
void CGHLSLRuntime::emitSystemSemanticStore(IRBuilder<> &B, llvm::Value *Source,
                                            const clang::DeclaratorDecl *Decl,
                                            HLSLAppliedSemanticAttr *Semantic,
                                            std::optional<unsigned> Index) {

  // Semantic names are case-insensitive in HLSL; compare the uppercased form.
  std::string SemanticName = Semantic->getAttrName()->getName().upper();
  if (SemanticName == "SV_POSITION") {
    // DXIL has no dedicated output intrinsic here; the value is emitted like a
    // user semantic and identified later from the semantic metadata.
    if (CGM.getTarget().getTriple().isDXIL()) {
      emitDXILUserSemanticStore(B, Source, Semantic, Index);
      return;
    }

    // SPIR-V: write through a hidden global decorated as BuiltIn::Position.
    if (CGM.getTarget().getTriple().isSPIRV()) {
      createSPIRVBuiltinStore(B, M&: CGM.getModule(), Source,
                              Name: Semantic->getAttrName()->getName(),
                              /* BuiltIn::Position */ BuiltInID: 0);
      return;
    }
  }

  // SV_Target outputs are handled like user semantics on both targets.
  if (SemanticName == "SV_TARGET") {
    emitUserSemanticStore(B, Source, Decl, Semantic, Index);
    return;
  }

  llvm_unreachable(
      "Store hasn't been implemented yet for this system semantic. FIXME");
}
815
816llvm::Value *CGHLSLRuntime::handleScalarSemanticLoad(
817 IRBuilder<> &B, const FunctionDecl *FD, llvm::Type *Type,
818 const clang::DeclaratorDecl *Decl, HLSLAppliedSemanticAttr *Semantic) {
819
820 std::optional<unsigned> Index = Semantic->getSemanticIndex();
821 if (Semantic->getAttrName()->getName().starts_with_insensitive(Prefix: "SV_"))
822 return emitSystemSemanticLoad(B, FD, Type, Decl, Semantic, Index);
823 return emitUserSemanticLoad(B, Type, Decl, Semantic, Index);
824}
825
826void CGHLSLRuntime::handleScalarSemanticStore(
827 IRBuilder<> &B, const FunctionDecl *FD, llvm::Value *Source,
828 const clang::DeclaratorDecl *Decl, HLSLAppliedSemanticAttr *Semantic) {
829 std::optional<unsigned> Index = Semantic->getSemanticIndex();
830 if (Semantic->getAttrName()->getName().starts_with_insensitive(Prefix: "SV_"))
831 emitSystemSemanticStore(B, Source, Decl, Semantic, Index);
832 else
833 emitUserSemanticStore(B, Source, Decl, Semantic, Index);
834}
835
// Load a struct-typed shader input by loading each field's semantic value and
// assembling the result with insertvalue. Semantic attributes are consumed
// from [AttrBegin, AttrEnd) in field order; returns the aggregate together
// with the iterator past the attributes that were consumed, so the caller can
// continue with the remaining attributes.
std::pair<llvm::Value *, specific_attr_iterator<HLSLAppliedSemanticAttr>>
CGHLSLRuntime::handleStructSemanticLoad(
    IRBuilder<> &B, const FunctionDecl *FD, llvm::Type *Type,
    const clang::DeclaratorDecl *Decl,
    specific_attr_iterator<HLSLAppliedSemanticAttr> AttrBegin,
    specific_attr_iterator<HLSLAppliedSemanticAttr> AttrEnd) {
  const llvm::StructType *ST = cast<StructType>(Val: Type);
  const clang::RecordDecl *RD = Decl->getType()->getAsRecordDecl();

  // The IR struct must mirror the AST record, one element per field.
  assert(RD->getNumFields() == ST->getNumElements());

  llvm::Value *Aggregate = llvm::PoisonValue::get(T: Type);
  auto FieldDecl = RD->field_begin();
  for (unsigned I = 0; I < ST->getNumElements(); ++I) {
    // Recurse: a field may itself be a struct and consume several attributes.
    auto [ChildValue, NextAttr] = handleSemanticLoad(
        B, FD, Type: ST->getElementType(N: I), Decl: *FieldDecl, begin: AttrBegin, end: AttrEnd);
    AttrBegin = NextAttr;
    assert(ChildValue);
    Aggregate = B.CreateInsertValue(Agg: Aggregate, Val: ChildValue, Idxs: I);
    ++FieldDecl;
  }

  return std::make_pair(x&: Aggregate, y&: AttrBegin);
}
860
861specific_attr_iterator<HLSLAppliedSemanticAttr>
862CGHLSLRuntime::handleStructSemanticStore(
863 IRBuilder<> &B, const FunctionDecl *FD, llvm::Value *Source,
864 const clang::DeclaratorDecl *Decl,
865 specific_attr_iterator<HLSLAppliedSemanticAttr> AttrBegin,
866 specific_attr_iterator<HLSLAppliedSemanticAttr> AttrEnd) {
867
868 const llvm::StructType *ST = cast<StructType>(Val: Source->getType());
869
870 const clang::RecordDecl *RD = nullptr;
871 if (const FunctionDecl *FD = dyn_cast<FunctionDecl>(Val: Decl))
872 RD = FD->getDeclaredReturnType()->getAsRecordDecl();
873 else
874 RD = Decl->getType()->getAsRecordDecl();
875 assert(RD);
876
877 assert(RD->getNumFields() == ST->getNumElements());
878
879 auto FieldDecl = RD->field_begin();
880 for (unsigned I = 0; I < ST->getNumElements(); ++I, ++FieldDecl) {
881 llvm::Value *Extract = B.CreateExtractValue(Agg: Source, Idxs: I);
882 AttrBegin =
883 handleSemanticStore(B, FD, Source: Extract, Decl: *FieldDecl, AttrBegin, AttrEnd);
884 }
885
886 return AttrBegin;
887}
888
889std::pair<llvm::Value *, specific_attr_iterator<HLSLAppliedSemanticAttr>>
890CGHLSLRuntime::handleSemanticLoad(
891 IRBuilder<> &B, const FunctionDecl *FD, llvm::Type *Type,
892 const clang::DeclaratorDecl *Decl,
893 specific_attr_iterator<HLSLAppliedSemanticAttr> AttrBegin,
894 specific_attr_iterator<HLSLAppliedSemanticAttr> AttrEnd) {
895 assert(AttrBegin != AttrEnd);
896 if (Type->isStructTy())
897 return handleStructSemanticLoad(B, FD, Type, Decl, AttrBegin, AttrEnd);
898
899 HLSLAppliedSemanticAttr *Attr = *AttrBegin;
900 ++AttrBegin;
901 return std::make_pair(x: handleScalarSemanticLoad(B, FD, Type, Decl, Semantic: Attr),
902 y&: AttrBegin);
903}
904
905specific_attr_iterator<HLSLAppliedSemanticAttr>
906CGHLSLRuntime::handleSemanticStore(
907 IRBuilder<> &B, const FunctionDecl *FD, llvm::Value *Source,
908 const clang::DeclaratorDecl *Decl,
909 specific_attr_iterator<HLSLAppliedSemanticAttr> AttrBegin,
910 specific_attr_iterator<HLSLAppliedSemanticAttr> AttrEnd) {
911 assert(AttrBegin != AttrEnd);
912 if (Source->getType()->isStructTy())
913 return handleStructSemanticStore(B, FD, Source, Decl, AttrBegin, AttrEnd);
914
915 HLSLAppliedSemanticAttr *Attr = *AttrBegin;
916 ++AttrBegin;
917 handleScalarSemanticStore(B, FD, Source, Decl, Semantic: Attr);
918 return AttrBegin;
919}
920
// Emit the real shader entry point for `FD`. The user's function `Fn` is
// demoted to internal linkage and wrapped by a void() function that loads all
// input semantics, calls the user function, and stores all output semantics.
void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
                                      llvm::Function *Fn) {
  llvm::Module &M = CGM.getModule();
  llvm::LLVMContext &Ctx = M.getContext();
  // The wrapper takes no arguments and returns void; inputs and outputs are
  // wired up through semantics instead.
  auto *EntryTy = llvm::FunctionType::get(Result: llvm::Type::getVoidTy(C&: Ctx), isVarArg: false);
  Function *EntryFn =
      Function::Create(Ty: EntryTy, Linkage: Function::ExternalLinkage, N: FD->getName(), M: &M);

  // Copy function attributes over, we have no argument or return attributes
  // that can be valid on the real entry.
  AttributeList NewAttrs = AttributeList::get(C&: Ctx, Index: AttributeList::FunctionIndex,
                                              Attrs: Fn->getAttributes().getFnAttrs());
  EntryFn->setAttributes(NewAttrs);
  setHLSLEntryAttributes(FD, Fn: EntryFn);

  // Set the called function as internal linkage.
  Fn->setLinkage(GlobalValue::InternalLinkage);

  BasicBlock *BB = BasicBlock::Create(Context&: Ctx, Name: "entry", Parent: EntryFn);
  IRBuilder<> B(BB);
  llvm::SmallVector<Value *> Args;

  // When convergence tokens are required, anchor the call to the user
  // function (and later the ctor/dtor calls) to the entry token via an
  // operand bundle.
  SmallVector<OperandBundleDef, 1> OB;
  if (CGM.shouldEmitConvergenceTokens()) {
    assert(EntryFn->isConvergent());
    llvm::Value *I =
        B.CreateIntrinsic(ID: llvm::Intrinsic::experimental_convergence_entry, Args: {});
    llvm::Value *bundleArgs[] = {I};
    OB.emplace_back(Args: "convergencectrl", Args&: bundleArgs);
  }

  // Declarations whose value must be written to output semantics after the
  // user function returns, mapped to the value (or alloca) holding it.
  llvm::DenseMap<const DeclaratorDecl *, llvm::Value *> OutputSemantic;

  // If the first IR parameter is an sret, AST parameter indices are shifted
  // by one relative to IR argument numbers.
  unsigned SRetOffset = 0;
  for (const auto &Param : Fn->args()) {
    if (Param.hasStructRetAttr()) {
      SRetOffset = 1;
      llvm::Type *VarType = Param.getParamStructRetType();
      llvm::Value *Var = B.CreateAlloca(Ty: VarType);
      // The sret slot receives the return value; it is stored out through
      // FD's semantic attributes after the call.
      OutputSemantic.try_emplace(Key: FD, Args&: Var);
      Args.push_back(Elt: Var);
      continue;
    }

    const ParmVarDecl *PD = FD->getParamDecl(i: Param.getArgNo() - SRetOffset);
    llvm::Value *SemanticValue = nullptr;
    // FIXME: support inout/out parameters for semantics.
    if ([[maybe_unused]] HLSLParamModifierAttr *MA =
            PD->getAttr<HLSLParamModifierAttr>()) {
      llvm_unreachable("Not handled yet");
    } else {
      llvm::Type *ParamType =
          Param.hasByValAttr() ? Param.getParamByValType() : Param.getType();
      auto AttrBegin = PD->specific_attr_begin<HLSLAppliedSemanticAttr>();
      auto AttrEnd = PD->specific_attr_end<HLSLAppliedSemanticAttr>();
      auto Result =
          handleSemanticLoad(B, FD, Type: ParamType, Decl: PD, AttrBegin, AttrEnd);
      SemanticValue = Result.first;
      if (!SemanticValue)
        return;
      // byval parameters are passed indirectly: spill the loaded value to a
      // stack slot and pass its address instead.
      if (Param.hasByValAttr()) {
        llvm::Value *Var = B.CreateAlloca(Ty: Param.getParamByValType());
        B.CreateStore(Val: SemanticValue, Ptr: Var);
        SemanticValue = Var;
      }
    }

    assert(SemanticValue);
    Args.push_back(Elt: SemanticValue);
  }

  CallInst *CI = B.CreateCall(Callee: FunctionCallee(Fn), Args, OpBundles: OB);
  CI->setCallingConv(Fn->getCallingConv());

  // A direct (non-sret) return value is also stored through FD's semantics.
  if (Fn->getReturnType() != CGM.VoidTy)
    OutputSemantic.try_emplace(Key: FD, Args&: CI);

  for (auto &[Decl, Source] : OutputSemantic) {
    // sret results live in an alloca and must be reloaded before storing.
    AllocaInst *AI = dyn_cast<AllocaInst>(Val: Source);
    llvm::Value *SourceValue =
        AI ? B.CreateLoad(Ty: AI->getAllocatedType(), Ptr: Source) : Source;

    auto AttrBegin = Decl->specific_attr_begin<HLSLAppliedSemanticAttr>();
    auto AttrEnd = Decl->specific_attr_end<HLSLAppliedSemanticAttr>();
    handleSemanticStore(B, FD, Source: SourceValue, Decl, AttrBegin, AttrEnd);
  }

  B.CreateRetVoid();

  // Add and identify root signature to function, if applicable
  for (const Attr *Attr : FD->getAttrs()) {
    if (const auto *RSAttr = dyn_cast<RootSignatureAttr>(Val: Attr)) {
      auto *RSDecl = RSAttr->getSignatureDecl();
      addRootSignatureMD(RootSigVer: RSDecl->getVersion(), Elements: RSDecl->getRootElements(),
                         Fn: EntryFn, M);
    }
  }
}
1019
1020static void gatherFunctions(SmallVectorImpl<Function *> &Fns, llvm::Module &M,
1021 bool CtorOrDtor) {
1022 const auto *GV =
1023 M.getNamedGlobal(Name: CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors");
1024 if (!GV)
1025 return;
1026 const auto *CA = dyn_cast<ConstantArray>(Val: GV->getInitializer());
1027 if (!CA)
1028 return;
1029 // The global_ctor array elements are a struct [Priority, Fn *, COMDat].
1030 // HLSL neither supports priorities or COMDat values, so we will check those
1031 // in an assert but not handle them.
1032
1033 for (const auto &Ctor : CA->operands()) {
1034 if (isa<ConstantAggregateZero>(Val: Ctor))
1035 continue;
1036 ConstantStruct *CS = cast<ConstantStruct>(Val: Ctor);
1037
1038 assert(cast<ConstantInt>(CS->getOperand(0))->getValue() == 65535 &&
1039 "HLSL doesn't support setting priority for global ctors.");
1040 assert(isa<ConstantPointerNull>(CS->getOperand(2)) &&
1041 "HLSL doesn't support COMDat for global ctors.");
1042 Fns.push_back(Elt: cast<Function>(Val: CS->getOperand(i_nocapture: 1)));
1043 }
1044}
1045
// Inline calls to the collected global ctors at the start of every shader
// entry and calls to the dtors before each entry's terminator, then (for
// non-library profiles) drop the llvm.global_ctors/dtors arrays entirely.
void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
  llvm::Module &M = CGM.getModule();
  SmallVector<Function *> CtorFns;
  SmallVector<Function *> DtorFns;
  gatherFunctions(Fns&: CtorFns, M, CtorOrDtor: true);
  gatherFunctions(Fns&: DtorFns, M, CtorOrDtor: false);

  // Insert a call to the global constructor at the beginning of the entry block
  // to externally exported functions. This is a bit of a hack, but HLSL allows
  // global constructors, but doesn't support driver initialization of globals.
  for (auto &F : M.functions()) {
    if (!F.hasFnAttribute(Kind: "hlsl.shader"))
      continue;
    auto *Token = getConvergenceToken(BB&: F.getEntryBlock());
    Instruction *IP = &*F.getEntryBlock().begin();
    SmallVector<OperandBundleDef, 1> OB;
    // If a convergence token exists, the ctor/dtor calls must carry it and be
    // inserted after it (tokens must dominate their uses).
    if (Token) {
      llvm::Value *bundleArgs[] = {Token};
      OB.emplace_back(Args: "convergencectrl", Args&: bundleArgs);
      IP = Token->getNextNode();
    }
    IRBuilder<> B(IP);
    for (auto *Fn : CtorFns) {
      auto CI = B.CreateCall(Callee: FunctionCallee(Fn), Args: {}, OpBundles: OB);
      CI->setCallingConv(Fn->getCallingConv());
    }

    // Insert global dtors before the terminator of the last instruction
    B.SetInsertPoint(F.back().getTerminator());
    for (auto *Fn : DtorFns) {
      auto CI = B.CreateCall(Callee: FunctionCallee(Fn), Args: {}, OpBundles: OB);
      CI->setCallingConv(Fn->getCallingConv());
    }
  }

  // No need to keep global ctors/dtors for non-lib profile after call to
  // ctors/dtors added for entry.
  Triple T(M.getTargetTriple());
  if (T.getEnvironment() != Triple::EnvironmentType::Library) {
    if (auto *GV = M.getNamedGlobal(Name: "llvm.global_ctors"))
      GV->eraseFromParent();
    if (auto *GV = M.getNamedGlobal(Name: "llvm.global_dtors"))
      GV->eraseFromParent();
  }
}
1091
// Create a small always-inline init function that calls the given handle-
// creation intrinsic with `Args` and stores the resulting handle into the
// buffer global `GV`, then register that function as a global initializer.
static void initializeBuffer(CodeGenModule &CGM, llvm::GlobalVariable *GV,
                             Intrinsic::ID IntrID,
                             ArrayRef<llvm::Value *> Args) {

  LLVMContext &Ctx = CGM.getLLVMContext();
  llvm::Function *InitResFunc = llvm::Function::Create(
      Ty: llvm::FunctionType::get(Result: CGM.VoidTy, isVarArg: false),
      Linkage: llvm::GlobalValue::InternalLinkage,
      N: ("_init_buffer_" + GV->getName()).str(), M&: CGM.getModule());
  InitResFunc->addFnAttr(Kind: llvm::Attribute::AlwaysInline);

  llvm::BasicBlock *EntryBB =
      llvm::BasicBlock::Create(Context&: Ctx, Name: "entry", Parent: InitResFunc);
  CGBuilderTy Builder(CGM, Ctx);
  const DataLayout &DL = CGM.getModule().getDataLayout();
  Builder.SetInsertPoint(EntryBB);

  // Make sure the global variable is buffer resource handle
  llvm::Type *HandleTy = GV->getValueType();
  assert(HandleTy->isTargetExtTy() && "unexpected type of the buffer global");

  llvm::Value *CreateHandle = Builder.CreateIntrinsic(
      /*ReturnType=*/RetTy: HandleTy, ID: IntrID, Args, FMFSource: nullptr,
      Name: Twine(GV->getName()).concat(Suffix: "_h"));

  Builder.CreateAlignedStore(Val: CreateHandle, Ptr: GV, Align: GV->getPointerAlignment(DL));
  Builder.CreateRetVoid();

  // Run the init function with the rest of the CXX global initializers.
  CGM.AddCXXGlobalInit(F: InitResFunc);
}
1122
// Emit the initializer for a cbuffer/tbuffer global from its resource binding,
// choosing the explicit-binding intrinsic when a register slot was given and
// the implicit-binding intrinsic otherwise.
void CGHLSLRuntime::initializeBufferFromBinding(const HLSLBufferDecl *BufDecl,
                                                llvm::GlobalVariable *GV) {
  ResourceBindingAttrs Binding(BufDecl);
  assert(Binding.hasBinding() &&
         "cbuffer/tbuffer should always have resource binding attribute");

  // A buffer is a single resource: index 0 within a range of size 1.
  auto *Index = llvm::ConstantInt::get(Ty: CGM.IntTy, V: 0);
  auto *RangeSize = llvm::ConstantInt::get(Ty: CGM.IntTy, V: 1);
  auto *Space = llvm::ConstantInt::get(Ty: CGM.IntTy, V: Binding.getSpace());
  Value *Name = buildNameForResource(BaseName: BufDecl->getName(), CGM);

  // buffer with explicit binding
  if (Binding.isExplicit()) {
    llvm::Intrinsic::ID IntrinsicID =
        CGM.getHLSLRuntime().getCreateHandleFromBindingIntrinsic();
    auto *RegSlot = llvm::ConstantInt::get(Ty: CGM.IntTy, V: Binding.getSlot());
    SmallVector<Value *> Args{Space, RegSlot, RangeSize, Index, Name};
    initializeBuffer(CGM, GV, IntrID: IntrinsicID, Args);
  } else {
    // buffer with implicit binding
    llvm::Intrinsic::ID IntrinsicID =
        CGM.getHLSLRuntime().getCreateHandleFromImplicitBindingIntrinsic();
    auto *OrderID =
        llvm::ConstantInt::get(Ty: CGM.IntTy, V: Binding.getImplicitOrderID());
    SmallVector<Value *> Args{OrderID, Space, RangeSize, Index, Name};
    initializeBuffer(CGM, GV, IntrID: IntrinsicID, Args);
  }
}
1151
1152void CGHLSLRuntime::handleGlobalVarDefinition(const VarDecl *VD,
1153 llvm::GlobalVariable *GV) {
1154 if (auto Attr = VD->getAttr<HLSLVkExtBuiltinInputAttr>())
1155 addSPIRVBuiltinDecoration(GV, BuiltIn: Attr->getBuiltIn());
1156}
1157
1158llvm::Instruction *CGHLSLRuntime::getConvergenceToken(BasicBlock &BB) {
1159 if (!CGM.shouldEmitConvergenceTokens())
1160 return nullptr;
1161
1162 auto E = BB.end();
1163 for (auto I = BB.begin(); I != E; ++I) {
1164 auto *II = dyn_cast<llvm::IntrinsicInst>(Val: &*I);
1165 if (II && llvm::isConvergenceControlIntrinsic(IntrinsicID: II->getIntrinsicID())) {
1166 return II;
1167 }
1168 }
1169 llvm_unreachable("Convergence token should have been emitted.");
1170 return nullptr;
1171}
1172
// Collects the OpaqueValueExprs reachable from an expression, in dependency
// order (an OVE's source expression is visited before the OVE itself), so the
// caller can bind them in a valid order.
class OpaqueValueVisitor : public RecursiveASTVisitor<OpaqueValueVisitor> {
public:
  // OVEs in the order they should be bound.
  llvm::SmallVector<OpaqueValueExpr *, 8> OVEs;
  // Deduplication set: an OVE may be reached through multiple paths.
  llvm::SmallPtrSet<OpaqueValueExpr *, 8> Visited;
  OpaqueValueVisitor() {}

  bool VisitHLSLOutArgExpr(HLSLOutArgExpr *) {
    // These need to be bound in CodeGenFunction::EmitHLSLOutArgLValues
    // or CodeGenFunction::EmitHLSLOutArgExpr. If they are part of this
    // traversal, the temporary containing the copy out will not have
    // been created yet.
    return false;
  }

  bool VisitOpaqueValueExpr(OpaqueValueExpr *E) {
    // Traverse the source expression first.
    if (E->getSourceExpr())
      TraverseStmt(S: E->getSourceExpr());

    // Then add this OVE if we haven't seen it before.
    if (Visited.insert(Ptr: E).second)
      OVEs.push_back(Elt: E);

    return true;
  }
};
1199
// Emit and bind every not-yet-emitted OpaqueValueExpr reachable from the init
// list, so later codegen of the list can refer to the bound values.
void CGHLSLRuntime::emitInitListOpaqueValues(CodeGenFunction &CGF,
                                             InitListExpr *E) {

  typedef CodeGenFunction::OpaqueValueMappingData OpaqueValueMappingData;
  OpaqueValueVisitor Visitor;
  Visitor.TraverseStmt(S: E);
  // OVEs arrive in dependency order (see OpaqueValueVisitor).
  for (auto *OVE : Visitor.OVEs) {
    if (CGF.isOpaqueValueEmitted(E: OVE))
      continue;
    // Bind as an lvalue or rvalue depending on how the OVE will be used.
    if (OpaqueValueMappingData::shouldBindAsLValue(expr: OVE)) {
      LValue LV = CGF.EmitLValue(E: OVE->getSourceExpr());
      OpaqueValueMappingData::bind(CGF, ov: OVE, lv: LV);
    } else {
      RValue RV = CGF.EmitAnyExpr(E: OVE->getSourceExpr());
      OpaqueValueMappingData::bind(CGF, ov: OVE, rv: RV);
    }
  }
}
1218
// Emit a subscript into a global (non-static) resource array by constructing
// the selected resource(s) into a fresh temporary and returning an LValue for
// it. Returns std::nullopt when clang's default codegen should handle the
// subscript instead (local/static arrays, or missing create method).
std::optional<LValue> CGHLSLRuntime::emitResourceArraySubscriptExpr(
    const ArraySubscriptExpr *ArraySubsExpr, CodeGenFunction &CGF) {
  assert((ArraySubsExpr->getType()->isHLSLResourceRecord() ||
          ArraySubsExpr->getType()->isHLSLResourceRecordArray()) &&
         "expected resource array subscript expression");

  // Let clang codegen handle local and static resource array subscripts,
  // or when the subscript references on opaque expression (as part of
  // ArrayInitLoopExpr AST node).
  const VarDecl *ArrayDecl =
      dyn_cast_or_null<VarDecl>(Val: getArrayDecl(ASE: ArraySubsExpr));
  if (!ArrayDecl || !ArrayDecl->hasGlobalStorage() ||
      ArrayDecl->getStorageClass() == SC_Static)
    return std::nullopt;

  // get the resource array type
  ASTContext &AST = ArrayDecl->getASTContext();
  const Type *ResArrayTy = ArrayDecl->getType().getTypePtr();
  assert(ResArrayTy->isHLSLResourceRecordArray() &&
         "expected array of resource classes");

  // Iterate through all nested array subscript expressions to calculate
  // the index in the flattened resource array (if this is a multi-
  // dimensional array). The index is calculated as a sum of all indices
  // multiplied by the total size of the array at that level.
  Value *Index = nullptr;
  const ArraySubscriptExpr *ASE = ArraySubsExpr;
  while (ASE != nullptr) {
    Value *SubIndex = CGF.EmitScalarExpr(E: ASE->getIdx());
    // If this subscript still yields an array, scale the index by the
    // number of elements remaining at this level.
    if (const auto *ArrayTy =
            dyn_cast<ConstantArrayType>(Val: ASE->getType().getTypePtr())) {
      Value *Multiplier = llvm::ConstantInt::get(
          Ty: CGM.IntTy, V: AST.getConstantArrayElementCount(CA: ArrayTy));
      SubIndex = CGF.Builder.CreateMul(LHS: SubIndex, RHS: Multiplier);
    }
    Index = Index ? CGF.Builder.CreateAdd(LHS: Index, RHS: SubIndex) : SubIndex;
    ASE = dyn_cast<ArraySubscriptExpr>(Val: ASE->getBase()->IgnoreParenImpCasts());
  }

  // Find binding info for the resource array. For implicit binding
  // an HLSLResourceBindingAttr should have been added by SemaHLSL.
  ResourceBindingAttrs Binding(ArrayDecl);
  assert(Binding.hasBinding() &&
         "resource array must have a binding attribute");

  // Find the individual resource type.
  QualType ResultTy = ArraySubsExpr->getType();
  QualType ResourceTy =
      ResultTy->isArrayType() ? AST.getBaseElementType(QT: ResultTy) : ResultTy;

  // Create a temporary variable for the result, which is either going
  // to be a single resource instance or a local array of resources (we need to
  // return an LValue).
  RawAddress TmpVar = CGF.CreateMemTemp(T: ResultTy);
  if (CGF.EmitLifetimeStart(Addr: TmpVar.getPointer()))
    CGF.pushFullExprCleanup<CodeGenFunction::CallLifetimeEnd>(
        kind: NormalEHLifetimeMarker, A: TmpVar);

  AggValueSlot ValueSlot = AggValueSlot::forAddr(
      addr: TmpVar, quals: Qualifiers(), isDestructed: AggValueSlot::IsDestructed_t(true),
      needsGC: AggValueSlot::DoesNotNeedGCBarriers, isAliased: AggValueSlot::IsAliased_t(false),
      mayOverlap: AggValueSlot::DoesNotOverlap);

  // Calculate total array size (= range size).
  llvm::Value *Range = llvm::ConstantInt::getSigned(
      Ty: CGM.IntTy, V: getTotalArraySize(AST, Ty: ResArrayTy));

  // If the result of the subscript operation is a single resource, call the
  // constructor.
  if (ResultTy == ResourceTy) {
    CallArgList Args;
    CXXMethodDecl *CreateMethod = lookupResourceInitMethodAndSetupArgs(
        CGM&: CGF.CGM, ResourceDecl: ResourceTy->getAsCXXRecordDecl(), Range, Index,
        Name: ArrayDecl->getName(), Binding, Args);

    if (!CreateMethod)
      // This can happen if someone creates an array of structs that looks like
      // an HLSL resource record array but it does not have the required static
      // create method. No binding will be generated for it.
      return std::nullopt;

    callResourceInitMethod(CGF, CreateMethod, Args, ReturnAddress: ValueSlot.getAddress());

  } else {
    // The result of the subscript operation is a local resource array which
    // needs to be initialized.
    const ConstantArrayType *ArrayTy =
        cast<ConstantArrayType>(Val: ResultTy.getTypePtr());
    std::optional<llvm::Value *> EndIndex = initializeLocalResourceArray(
        CGF, ResourceDecl: ResourceTy->getAsCXXRecordDecl(), ArrayTy, ValueSlot, Range, StartIndex: Index,
        ResourceName: ArrayDecl->getName(), Binding, PrevGEPIndices: {llvm::ConstantInt::get(Ty: CGM.IntTy, V: 0)},
        ArraySubsExprLoc: ArraySubsExpr->getExprLoc());
    if (!EndIndex)
      return std::nullopt;
  }
  return CGF.MakeAddrLValue(Addr: TmpVar, T: ResultTy, Source: AlignmentSource::Decl);
}
1316
// If RHSExpr is a global resource array, initialize all of its resources and
// set them into LHS. Returns false if no copy has been performed and the
// array copy should be handled by Clang codegen.
bool CGHLSLRuntime::emitResourceArrayCopy(LValue &LHS, Expr *RHSExpr,
                                          CodeGenFunction &CGF) {
  QualType ResultTy = RHSExpr->getType();
  assert(ResultTy->isHLSLResourceRecordArray() && "expected resource array");

  // Let Clang codegen handle local and static resource array copies.
  const VarDecl *ArrayDecl = dyn_cast_or_null<VarDecl>(Val: getArrayDecl(E: RHSExpr));
  if (!ArrayDecl || !ArrayDecl->hasGlobalStorage() ||
      ArrayDecl->getStorageClass() == SC_Static)
    return false;

  // Find binding info for the resource array. For implicit binding
  // the HLSLResourceBindingAttr should have been added by SemaHLSL.
  ResourceBindingAttrs Binding(ArrayDecl);
  assert(Binding.hasBinding() &&
         "resource array must have a binding attribute");

  // Find the individual resource type.
  ASTContext &AST = ArrayDecl->getASTContext();
  QualType ResTy = AST.getBaseElementType(QT: ResultTy);
  const auto *ResArrayTy = cast<ConstantArrayType>(Val: ResultTy.getTypePtr());

  // Use the provided LHS for the result.
  AggValueSlot ValueSlot = AggValueSlot::forAddr(
      addr: LHS.getAddress(), quals: Qualifiers(), isDestructed: AggValueSlot::IsDestructed_t(true),
      needsGC: AggValueSlot::DoesNotNeedGCBarriers, isAliased: AggValueSlot::IsAliased_t(false),
      mayOverlap: AggValueSlot::DoesNotOverlap);

  // Create Value for index and total array size (= range size). The whole
  // array is copied, so the start index is 0.
  int Size = getTotalArraySize(AST, Ty: ResArrayTy);
  llvm::Value *Zero = llvm::ConstantInt::get(Ty: CGM.IntTy, V: 0);
  llvm::Value *Range = llvm::ConstantInt::get(Ty: CGM.IntTy, V: Size);

  // Initialize individual resources in the array into LHS. A missing create
  // method makes initializeLocalResourceArray return nullopt, in which case
  // no copy was performed.
  std::optional<llvm::Value *> EndIndex = initializeLocalResourceArray(
      CGF, ResourceDecl: ResTy->getAsCXXRecordDecl(), ArrayTy: ResArrayTy, ValueSlot, Range, StartIndex: Zero,
      ResourceName: ArrayDecl->getName(), Binding, PrevGEPIndices: {Zero}, ArraySubsExprLoc: RHSExpr->getExprLoc());
  return EndIndex.has_value();
}
1359
// Copy a cbuffer-resident matrix into a local temporary laid out with the
// normal (non-padded) in-memory representation, returning the temporary's
// address. If the buffer layout already matches the memory layout, the
// original address is returned and no copy is made.
RawAddress CGHLSLRuntime::createBufferMatrixTempAddress(const LValue &LV,
                                                        SourceLocation Loc,
                                                        CodeGenFunction &CGF) {

  assert(LV.getType()->isConstantMatrixType() && "expected matrix type");
  assert(LV.getType().getAddressSpace() == LangAS::hlsl_constant &&
         "expected cbuffer matrix");

  QualType MatQualTy = LV.getType();
  llvm::Type *MemTy = CGF.ConvertTypeForMem(T: MatQualTy);
  llvm::Type *LayoutTy = HLSLBufferLayoutBuilder(CGF.CGM).layOutType(Type: MatQualTy);

  // No padding difference: the cbuffer address can be used directly.
  if (LayoutTy == MemTy)
    return LV.getAddress();

  Address SrcAddr = LV.getAddress();
  // NOTE: B\C CreateMemTemp flattens MatrixTypes which causes
  // overlapping GEPs in emitBufferCopy. Use CreateTempAlloca with
  // the non-padded layout.
  CharUnits Align =
      CharUnits::fromQuantity(Quantity: CGF.CGM.getDataLayout().getABITypeAlign(Ty: MemTy));
  RawAddress DestAlloca = CGF.CreateTempAlloca(Ty: MemTy, align: Align, Name: "matrix.buf.copy");
  emitBufferCopy(CGF, DestPtr: DestAlloca, SrcPtr: SrcAddr, CType: MatQualTy);
  return DestAlloca;
}
1385
// Emit an array subscript into a cbuffer-resident array, accounting for the
// HLSL buffer layout's 16-byte register rows. Returns std::nullopt when the
// layout type equals the normal memory type, i.e. default codegen suffices.
std::optional<LValue> CGHLSLRuntime::emitBufferArraySubscriptExpr(
    const ArraySubscriptExpr *E, CodeGenFunction &CGF,
    llvm::function_ref<llvm::Value *(bool Promote)> EmitIdxAfterBase) {
  // Find the element type to index by first padding the element type per HLSL
  // buffer rules, and then padding out to a 16-byte register boundary if
  // necessary.
  llvm::Type *LayoutTy =
      HLSLBufferLayoutBuilder(CGF.CGM).layOutType(Type: E->getType());
  uint64_t LayoutSizeInBits =
      CGM.getDataLayout().getTypeSizeInBits(Ty: LayoutTy).getFixedValue();
  CharUnits ElementSize = CharUnits::fromQuantity(Quantity: LayoutSizeInBits / 8);
  CharUnits RowAlignedSize = ElementSize.alignTo(Align: CharUnits::fromQuantity(Quantity: 16));
  if (RowAlignedSize > ElementSize) {
    // Append an explicit target-specific padding member so each element spans
    // a whole number of 16-byte rows.
    llvm::Type *Padding = CGM.getTargetCodeGenInfo().getHLSLPadding(
        CGM, NumBytes: RowAlignedSize - ElementSize);
    assert(Padding && "No padding type for target?");
    LayoutTy = llvm::StructType::get(Context&: CGF.getLLVMContext(), Elements: {LayoutTy, Padding},
                                     /*isPacked=*/true);
  }

  // If the layout type doesn't introduce any padding, we don't need to do
  // anything special.
  llvm::Type *OrigTy = CGF.CGM.getTypes().ConvertTypeForMem(T: E->getType());
  if (LayoutTy == OrigTy)
    return std::nullopt;

  LValueBaseInfo EltBaseInfo;
  TBAAAccessInfo EltTBAAInfo;

  // Index into the object as-if we have an array of the padded element type,
  // and then dereference the element itself to avoid reading padding that may
  // be past the end of the in-memory object.
  SmallVector<llvm::Value *, 2> Indices;
  llvm::Value *Idx = EmitIdxAfterBase(/*Promote*/ true);
  Indices.push_back(Elt: Idx);
  Indices.push_back(Elt: llvm::ConstantInt::get(Ty: CGF.Int32Ty, V: 0));

  if (CGF.getLangOpts().EmitStructuredGEP) {
    // The fact that we emit an array-to-pointer decay might be an oversight,
    // but for now, we simply ignore it (see #179951).
    const CastExpr *CE = cast<CastExpr>(Val: E->getBase());
    assert(CE->getCastKind() == CastKind::CK_ArrayToPointerDecay);

    // Re-derive the base address from below the decay so the structured GEP
    // sees the full array type.
    LValue LV = CGF.EmitLValue(E: CE->getSubExpr());
    Address Addr = LV.getAddress();
    LayoutTy = llvm::ArrayType::get(
        ElementType: LayoutTy,
        NumElements: cast<llvm::ArrayType>(Val: Addr.getElementType())->getNumElements());
    auto *GEP = cast<StructuredGEPInst>(Val: CGF.Builder.CreateStructuredGEP(
        BaseType: LayoutTy, PtrBase: Addr.emitRawPointer(CGF), Indices, Name: "cbufferidx"));
    Addr =
        Address(GEP, GEP->getResultElementType(), RowAlignedSize, KnownNonNull);
    return CGF.MakeAddrLValue(Addr, T: E->getType(), BaseInfo: EltBaseInfo, TBAAInfo: EltTBAAInfo);
  }

  Address Addr =
      CGF.EmitPointerWithAlignment(Addr: E->getBase(), BaseInfo: &EltBaseInfo, TBAAInfo: &EltTBAAInfo);
  llvm::Value *GEP = CGF.Builder.CreateGEP(Ty: LayoutTy, Ptr: Addr.emitRawPointer(CGF),
                                           IdxList: Indices, Name: "cbufferidx");
  Addr = Address(GEP, Addr.getElementType(), RowAlignedSize, KnownNonNull);
  return CGF.MakeAddrLValue(Addr, T: E->getType(), BaseInfo: EltBaseInfo, TBAAInfo: EltTBAAInfo);
}
1448
1449namespace {
1450/// Utility for emitting copies following the HLSL buffer layout rules (ie,
1451/// copying out of a cbuffer).
1452class HLSLBufferCopyEmitter {
1453 CodeGenFunction &CGF;
1454 Address DestPtr;
1455 Address SrcPtr;
1456 llvm::Type *LayoutTy = nullptr;
1457
1458 SmallVector<llvm::Value *> CurStoreIndices;
1459 SmallVector<llvm::Value *> CurLoadIndices;
1460
  // Copy one (possibly aggregate) field from SrcPtr to DestPtr at the current
  // index paths extended by StoreIndex/LoadIndex. Aggregates recurse; leaves
  // become a load/store pair. The two index paths can differ because the
  // buffer layout (load side) contains padding the memory layout (store side)
  // does not.
  void emitCopyAtIndices(llvm::Type *FieldTy, llvm::ConstantInt *StoreIndex,
                         llvm::ConstantInt *LoadIndex) {
    CurStoreIndices.push_back(Elt: StoreIndex);
    CurLoadIndices.push_back(Elt: LoadIndex);
    // Pop the indices again on every exit path of this frame.
    llvm::scope_exit RestoreIndices([&]() {
      CurStoreIndices.pop_back();
      CurLoadIndices.pop_back();
    });

    // First, see if this is some kind of aggregate and recurse.
    if (processArray(FieldTy))
      return;
    if (processBufferLayoutArray(FieldTy))
      return;
    if (processStruct(FieldTy))
      return;

    // When we have a scalar or vector element we can emit the copy.
    CharUnits Align = CharUnits::fromQuantity(
        Quantity: CGF.CGM.getDataLayout().getABITypeAlign(Ty: FieldTy));
    Address SrcGEP = RawAddress(
        CGF.Builder.CreateInBoundsGEP(Ty: LayoutTy, Ptr: SrcPtr.getBasePointer(),
                                      IdxList: CurLoadIndices, Name: "cbuf.src"),
        FieldTy, Align, SrcPtr.isKnownNonNull());
    Address DestGEP = CGF.Builder.CreateInBoundsGEP(
        Addr: DestPtr, IdxList: CurStoreIndices, ElementType: FieldTy, Align, Name: "cbuf.dest");
    llvm::Value *Load = CGF.Builder.CreateLoad(Addr: SrcGEP, Name: "cbuf.load");
    CGF.Builder.CreateStore(Val: Load, Addr: DestGEP);
  }
1490
  // Handle a plain llvm::ArrayType field: copy element-wise with identical
  // load/store indices. Returns false if FieldTy is not an array.
  bool processArray(llvm::Type *FieldTy) {
    auto *AT = dyn_cast<llvm::ArrayType>(Val: FieldTy);
    if (!AT)
      return false;

    // If we have an llvm::ArrayType this is just a regular array with no top
    // level padding, so all we need to do is copy each member.
    for (unsigned I = 0, E = AT->getNumElements(); I < E; ++I)
      emitCopyAtIndices(FieldTy: AT->getElementType(),
                        StoreIndex: llvm::ConstantInt::get(Ty: CGF.SizeTy, V: I),
                        LoadIndex: llvm::ConstantInt::get(Ty: CGF.SizeTy, V: I));
    return true;
  }
1504
  // Handle an array in HLSL buffer layout, where every element except the
  // last is padded out to a 16-byte row. Returns false if FieldTy does not
  // match that shape; otherwise copies each logical element (sans padding).
  bool processBufferLayoutArray(llvm::Type *FieldTy) {
    // A buffer layout array is a struct with two elements: the padded array,
    // and the last element. That is, is should look something like this:
    //
    //   { [%n x { %type, %padding }], %type }
    //
    auto *ST = dyn_cast<llvm::StructType>(Val: FieldTy);
    if (!ST || ST->getNumElements() != 2)
      return false;

    auto *PaddedEltsTy = dyn_cast<llvm::ArrayType>(Val: ST->getElementType(N: 0));
    if (!PaddedEltsTy)
      return false;

    auto *PaddedTy = dyn_cast<llvm::StructType>(Val: PaddedEltsTy->getElementType());
    if (!PaddedTy || PaddedTy->getNumElements() != 2)
      return false;

    // The second member of each padded element must be target padding.
    if (!CGF.CGM.getTargetCodeGenInfo().isHLSLPadding(
            Ty: PaddedTy->getElementType(N: 1)))
      return false;

    // The trailing unpadded element must have the same type as the padded
    // ones.
    llvm::Type *ElementTy = ST->getElementType(N: 1);
    if (PaddedTy->getElementType(N: 0) != ElementTy)
      return false;

    // All but the last of the logical array elements are in the padded array.
    unsigned NumElts = PaddedEltsTy->getNumElements() + 1;

    // Add an extra indirection to the load for the struct and walk the
    // array prefix.
    CurLoadIndices.push_back(Elt: llvm::ConstantInt::get(Ty: CGF.Int32Ty, V: 0));
    for (unsigned I = 0; I < NumElts - 1; ++I) {
      // We need to copy the element itself, without the padding.
      CurLoadIndices.push_back(Elt: llvm::ConstantInt::get(Ty: CGF.SizeTy, V: I));
      emitCopyAtIndices(FieldTy: ElementTy, StoreIndex: llvm::ConstantInt::get(Ty: CGF.SizeTy, V: I),
                        LoadIndex: llvm::ConstantInt::get(Ty: CGF.Int32Ty, V: 0));
      CurLoadIndices.pop_back();
    }
    CurLoadIndices.pop_back();

    // Now copy the last element.
    emitCopyAtIndices(FieldTy: ElementTy,
                      StoreIndex: llvm::ConstantInt::get(Ty: CGF.SizeTy, V: NumElts - 1),
                      LoadIndex: llvm::ConstantInt::get(Ty: CGF.Int32Ty, V: 1));

    return true;
  }
1553
1554 bool processStruct(llvm::Type *FieldTy) {
1555 auto *ST = dyn_cast<llvm::StructType>(Val: FieldTy);
1556 if (!ST)
1557 return false;
1558
1559 // Copy the struct field by field, but skip any explicit padding.
1560 unsigned Skipped = 0;
1561 for (unsigned I = 0, E = ST->getNumElements(); I < E; ++I) {
1562 llvm::Type *ElementTy = ST->getElementType(N: I);
1563 if (CGF.CGM.getTargetCodeGenInfo().isHLSLPadding(Ty: ElementTy))
1564 ++Skipped;
1565 else
1566 emitCopyAtIndices(FieldTy: ElementTy, StoreIndex: llvm::ConstantInt::get(Ty: CGF.Int32Ty, V: I),
1567 LoadIndex: llvm::ConstantInt::get(Ty: CGF.Int32Ty, V: I + Skipped));
1568 }
1569 return true;
1570 }
1571
1572public:
1573 HLSLBufferCopyEmitter(CodeGenFunction &CGF, Address DestPtr, Address SrcPtr)
1574 : CGF(CGF), DestPtr(DestPtr), SrcPtr(SrcPtr) {}
1575
1576 bool emitCopy(QualType CType) {
1577 LayoutTy = HLSLBufferLayoutBuilder(CGF.CGM).layOutType(Type: CType);
1578
1579 // TODO: We should be able to fall back to a regular memcpy if the layout
1580 // type doesn't have any padding, but that runs into issues in the backend
1581 // currently.
1582 //
1583 // See https://github.com/llvm/wg-hlsl/issues/351
1584 emitCopyAtIndices(FieldTy: LayoutTy, StoreIndex: llvm::ConstantInt::get(Ty: CGF.SizeTy, V: 0),
1585 LoadIndex: llvm::ConstantInt::get(Ty: CGF.SizeTy, V: 0));
1586 return true;
1587 }
1588};
1589} // namespace
1590
1591bool CGHLSLRuntime::emitBufferCopy(CodeGenFunction &CGF, Address DestPtr,
1592 Address SrcPtr, QualType CType) {
1593 return HLSLBufferCopyEmitter(CGF, DestPtr, SrcPtr).emitCopy(CType);
1594}
1595
1596LValue CGHLSLRuntime::emitBufferMemberExpr(CodeGenFunction &CGF,
1597 const MemberExpr *E) {
1598 LValue Base =
1599 CGF.EmitCheckedLValue(E: E->getBase(), TCK: CodeGenFunction::TCK_MemberAccess);
1600 auto *Field = dyn_cast<FieldDecl>(Val: E->getMemberDecl());
1601 assert(Field && "Unexpected access into HLSL buffer");
1602
1603 const RecordDecl *Rec = Field->getParent();
1604
1605 // Work out the buffer layout type to index into.
1606 QualType RecType = CGM.getContext().getCanonicalTagType(TD: Rec);
1607 assert(RecType->isStructureOrClassType() && "Invalid type in HLSL buffer");
1608 // Since this is a member of an object in the buffer and not the buffer's
1609 // struct/class itself, we shouldn't have any offsets on the members we need
1610 // to contend with.
1611 CGHLSLOffsetInfo EmptyOffsets;
1612 llvm::StructType *LayoutTy = HLSLBufferLayoutBuilder(CGM).layOutStruct(
1613 StructType: RecType->getAsCanonical<RecordType>(), OffsetInfo: EmptyOffsets);
1614
1615 // Get the field index for the layout struct, accounting for padding.
1616 unsigned FieldIdx =
1617 CGM.getTypes().getCGRecordLayout(Rec).getLLVMFieldNo(FD: Field);
1618 assert(FieldIdx < LayoutTy->getNumElements() &&
1619 "Layout struct is smaller than member struct");
1620 unsigned Skipped = 0;
1621 for (unsigned I = 0; I <= FieldIdx;) {
1622 llvm::Type *ElementTy = LayoutTy->getElementType(N: I + Skipped);
1623 if (CGF.CGM.getTargetCodeGenInfo().isHLSLPadding(Ty: ElementTy))
1624 ++Skipped;
1625 else
1626 ++I;
1627 }
1628 FieldIdx += Skipped;
1629 assert(FieldIdx < LayoutTy->getNumElements() && "Access out of bounds");
1630
1631 // Now index into the struct, making sure that the type we return is the
1632 // buffer layout type rather than the original type in the AST.
1633 QualType FieldType = Field->getType();
1634 llvm::Type *FieldLLVMTy = CGM.getTypes().ConvertTypeForMem(T: FieldType);
1635 CharUnits Align = CharUnits::fromQuantity(
1636 Quantity: CGF.CGM.getDataLayout().getABITypeAlign(Ty: FieldLLVMTy));
1637
1638 Value *Ptr = CGF.getLangOpts().EmitStructuredGEP
1639 ? CGF.Builder.CreateStructuredGEP(
1640 BaseType: LayoutTy, PtrBase: Base.getPointer(CGF),
1641 Indices: llvm::ConstantInt::get(Ty: CGM.IntTy, V: FieldIdx))
1642 : CGF.Builder.CreateStructGEP(Ty: LayoutTy, Ptr: Base.getPointer(CGF),
1643 Idx: FieldIdx, Name: Field->getName());
1644 Address Addr(Ptr, FieldLLVMTy, Align, KnownNonNull);
1645
1646 LValue LV = LValue::MakeAddr(Addr, type: FieldType, Context&: CGM.getContext(),
1647 BaseInfo: LValueBaseInfo(AlignmentSource::Type),
1648 TBAAInfo: CGM.getTBAAAccessInfo(AccessType: FieldType));
1649 LV.getQuals().addCVRQualifiers(mask: Base.getVRQualifiers());
1650
1651 return LV;
1652}
1653