1//===- SemaHLSL.cpp - Semantic Analysis for HLSL constructs ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This implements Semantic Analysis for HLSL constructs.
9//===----------------------------------------------------------------------===//
10
11#include "clang/Sema/SemaHLSL.h"
12#include "clang/AST/ASTConsumer.h"
13#include "clang/AST/ASTContext.h"
14#include "clang/AST/Attr.h"
15#include "clang/AST/Decl.h"
16#include "clang/AST/DeclBase.h"
17#include "clang/AST/DeclCXX.h"
18#include "clang/AST/DeclarationName.h"
19#include "clang/AST/DynamicRecursiveASTVisitor.h"
20#include "clang/AST/Expr.h"
21#include "clang/AST/HLSLResource.h"
22#include "clang/AST/Type.h"
23#include "clang/AST/TypeBase.h"
24#include "clang/AST/TypeLoc.h"
25#include "clang/Basic/Builtins.h"
26#include "clang/Basic/DiagnosticSema.h"
27#include "clang/Basic/IdentifierTable.h"
28#include "clang/Basic/LLVM.h"
29#include "clang/Basic/SourceLocation.h"
30#include "clang/Basic/Specifiers.h"
31#include "clang/Basic/TargetInfo.h"
32#include "clang/Sema/Initialization.h"
33#include "clang/Sema/Lookup.h"
34#include "clang/Sema/ParsedAttr.h"
35#include "clang/Sema/Sema.h"
36#include "clang/Sema/Template.h"
37#include "llvm/ADT/ArrayRef.h"
38#include "llvm/ADT/STLExtras.h"
39#include "llvm/ADT/SmallVector.h"
40#include "llvm/ADT/StringExtras.h"
41#include "llvm/ADT/StringRef.h"
42#include "llvm/ADT/Twine.h"
43#include "llvm/Frontend/HLSL/HLSLBinding.h"
44#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
45#include "llvm/Support/Casting.h"
46#include "llvm/Support/DXILABI.h"
47#include "llvm/Support/ErrorHandling.h"
48#include "llvm/Support/FormatVariadic.h"
49#include "llvm/TargetParser/Triple.h"
50#include <cmath>
51#include <cstddef>
52#include <iterator>
53#include <utility>
54
55using namespace clang;
56using namespace clang::hlsl;
57using RegisterType = HLSLResourceBindingAttr::RegisterType;
58
59static CXXRecordDecl *createHostLayoutStruct(Sema &S,
60 CXXRecordDecl *StructDecl);
61
62static RegisterType getRegisterType(ResourceClass RC) {
63 switch (RC) {
64 case ResourceClass::SRV:
65 return RegisterType::SRV;
66 case ResourceClass::UAV:
67 return RegisterType::UAV;
68 case ResourceClass::CBuffer:
69 return RegisterType::CBuffer;
70 case ResourceClass::Sampler:
71 return RegisterType::Sampler;
72 }
73 llvm_unreachable("unexpected ResourceClass value");
74}
75
76static RegisterType getRegisterType(const HLSLAttributedResourceType *ResTy) {
77 return getRegisterType(RC: ResTy->getAttrs().ResourceClass);
78}
79
80static LangAS getLangASFromResourceClass(ResourceClass RC) {
81 switch (RC) {
82 case ResourceClass::SRV:
83 case ResourceClass::UAV:
84 return LangAS::hlsl_device;
85 case ResourceClass::CBuffer:
86 return LangAS::hlsl_constant;
87 case ResourceClass::Sampler:
88 return LangAS::hlsl_device;
89 }
90 llvm_unreachable("unexpected ResourceClass value");
91}
92
93// Converts the first letter of string Slot to RegisterType.
94// Returns false if the letter does not correspond to a valid register type.
95static bool convertToRegisterType(StringRef Slot, RegisterType *RT) {
96 assert(RT != nullptr);
97 switch (Slot[0]) {
98 case 't':
99 case 'T':
100 *RT = RegisterType::SRV;
101 return true;
102 case 'u':
103 case 'U':
104 *RT = RegisterType::UAV;
105 return true;
106 case 'b':
107 case 'B':
108 *RT = RegisterType::CBuffer;
109 return true;
110 case 's':
111 case 'S':
112 *RT = RegisterType::Sampler;
113 return true;
114 case 'c':
115 case 'C':
116 *RT = RegisterType::C;
117 return true;
118 case 'i':
119 case 'I':
120 *RT = RegisterType::I;
121 return true;
122 default:
123 return false;
124 }
125}
126
127static char getRegisterTypeChar(RegisterType RT) {
128 switch (RT) {
129 case RegisterType::SRV:
130 return 't';
131 case RegisterType::UAV:
132 return 'u';
133 case RegisterType::CBuffer:
134 return 'b';
135 case RegisterType::Sampler:
136 return 's';
137 case RegisterType::C:
138 return 'c';
139 case RegisterType::I:
140 return 'i';
141 }
142 llvm_unreachable("unexpected RegisterType value");
143}
144
145static ResourceClass getResourceClass(RegisterType RT) {
146 switch (RT) {
147 case RegisterType::SRV:
148 return ResourceClass::SRV;
149 case RegisterType::UAV:
150 return ResourceClass::UAV;
151 case RegisterType::CBuffer:
152 return ResourceClass::CBuffer;
153 case RegisterType::Sampler:
154 return ResourceClass::Sampler;
155 case RegisterType::C:
156 case RegisterType::I:
157 // Deliberately falling through to the unreachable below.
158 break;
159 }
160 llvm_unreachable("unexpected RegisterType value");
161}
162
163static Builtin::ID getSpecConstBuiltinId(const Type *Type) {
164 const auto *BT = dyn_cast<BuiltinType>(Val: Type);
165 if (!BT) {
166 if (!Type->isEnumeralType())
167 return Builtin::NotBuiltin;
168 return Builtin::BI__builtin_get_spirv_spec_constant_int;
169 }
170
171 switch (BT->getKind()) {
172 case BuiltinType::Bool:
173 return Builtin::BI__builtin_get_spirv_spec_constant_bool;
174 case BuiltinType::Short:
175 return Builtin::BI__builtin_get_spirv_spec_constant_short;
176 case BuiltinType::Int:
177 return Builtin::BI__builtin_get_spirv_spec_constant_int;
178 case BuiltinType::LongLong:
179 return Builtin::BI__builtin_get_spirv_spec_constant_longlong;
180 case BuiltinType::UShort:
181 return Builtin::BI__builtin_get_spirv_spec_constant_ushort;
182 case BuiltinType::UInt:
183 return Builtin::BI__builtin_get_spirv_spec_constant_uint;
184 case BuiltinType::ULongLong:
185 return Builtin::BI__builtin_get_spirv_spec_constant_ulonglong;
186 case BuiltinType::Half:
187 return Builtin::BI__builtin_get_spirv_spec_constant_half;
188 case BuiltinType::Float:
189 return Builtin::BI__builtin_get_spirv_spec_constant_float;
190 case BuiltinType::Double:
191 return Builtin::BI__builtin_get_spirv_spec_constant_double;
192 default:
193 return Builtin::NotBuiltin;
194 }
195}
196
197static StringRef createRegisterString(ASTContext &AST, RegisterType RegType,
198 unsigned N) {
199 llvm::SmallString<16> Buffer;
200 llvm::raw_svector_ostream OS(Buffer);
201 OS << getRegisterTypeChar(RT: RegType);
202 OS << N;
203 return AST.backupStr(S: OS.str());
204}
205
206DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD,
207 ResourceClass ResClass) {
208 assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
209 "DeclBindingInfo already added");
210 assert(!hasBindingInfoForDecl(VD) || BindingsList.back().Decl == VD);
211 // VarDecl may have multiple entries for different resource classes.
212 // DeclToBindingListIndex stores the index of the first binding we saw
213 // for this decl. If there are any additional ones then that index
214 // shouldn't be updated.
215 DeclToBindingListIndex.try_emplace(Key: VD, Args: BindingsList.size());
216 return &BindingsList.emplace_back(Args&: VD, Args&: ResClass);
217}
218
219DeclBindingInfo *ResourceBindings::getDeclBindingInfo(const VarDecl *VD,
220 ResourceClass ResClass) {
221 auto Entry = DeclToBindingListIndex.find(Val: VD);
222 if (Entry != DeclToBindingListIndex.end()) {
223 for (unsigned Index = Entry->getSecond();
224 Index < BindingsList.size() && BindingsList[Index].Decl == VD;
225 ++Index) {
226 if (BindingsList[Index].ResClass == ResClass)
227 return &BindingsList[Index];
228 }
229 }
230 return nullptr;
231}
232
233bool ResourceBindings::hasBindingInfoForDecl(const VarDecl *VD) const {
234 return DeclToBindingListIndex.contains(Val: VD);
235}
236
237SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}
238
239Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
240 SourceLocation KwLoc, IdentifierInfo *Ident,
241 SourceLocation IdentLoc,
242 SourceLocation LBrace) {
243 // For anonymous namespace, take the location of the left brace.
244 DeclContext *LexicalParent = SemaRef.getCurLexicalContext();
245 HLSLBufferDecl *Result = HLSLBufferDecl::Create(
246 C&: getASTContext(), LexicalParent, CBuffer, KwLoc, ID: Ident, IDLoc: IdentLoc, LBrace);
247
248 // if CBuffer is false, then it's a TBuffer
249 auto RC = CBuffer ? llvm::hlsl::ResourceClass::CBuffer
250 : llvm::hlsl::ResourceClass::SRV;
251 Result->addAttr(A: HLSLResourceClassAttr::CreateImplicit(Ctx&: getASTContext(), ResourceClass: RC));
252
253 SemaRef.PushOnScopeChains(D: Result, S: BufferScope);
254 SemaRef.PushDeclContext(S: BufferScope, DC: Result);
255
256 return Result;
257}
258
259static unsigned calculateLegacyCbufferFieldAlign(const ASTContext &Context,
260 QualType T) {
261 // Arrays, Matrices, and Structs are always aligned to new buffer rows
262 if (T->isArrayType() || T->isStructureType() || T->isConstantMatrixType())
263 return 16;
264
265 // Vectors are aligned to the type they contain
266 if (const VectorType *VT = T->getAs<VectorType>())
267 return calculateLegacyCbufferFieldAlign(Context, T: VT->getElementType());
268
269 assert(Context.getTypeSize(T) <= 64 &&
270 "Scalar bit widths larger than 64 not supported");
271
272 // Scalar types are aligned to their byte width
273 return Context.getTypeSize(T) / 8;
274}
275
276// Calculate the size of a legacy cbuffer type in bytes based on
277// https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-packing-rules
278static unsigned calculateLegacyCbufferSize(const ASTContext &Context,
279 QualType T) {
280 constexpr unsigned CBufferAlign = 16;
281 if (const auto *RD = T->getAsRecordDecl()) {
282 unsigned Size = 0;
283 for (const FieldDecl *Field : RD->fields()) {
284 QualType Ty = Field->getType();
285 unsigned FieldSize = calculateLegacyCbufferSize(Context, T: Ty);
286 unsigned FieldAlign = calculateLegacyCbufferFieldAlign(Context, T: Ty);
287
288 // If the field crosses the row boundary after alignment it drops to the
289 // next row
290 unsigned AlignSize = llvm::alignTo(Value: Size, Align: FieldAlign);
291 if ((AlignSize % CBufferAlign) + FieldSize > CBufferAlign) {
292 FieldAlign = CBufferAlign;
293 }
294
295 Size = llvm::alignTo(Value: Size, Align: FieldAlign);
296 Size += FieldSize;
297 }
298 return Size;
299 }
300
301 if (const ConstantArrayType *AT = Context.getAsConstantArrayType(T)) {
302 unsigned ElementCount = AT->getSize().getZExtValue();
303 if (ElementCount == 0)
304 return 0;
305
306 unsigned ElementSize =
307 calculateLegacyCbufferSize(Context, T: AT->getElementType());
308 unsigned AlignedElementSize = llvm::alignTo(Value: ElementSize, Align: CBufferAlign);
309 return AlignedElementSize * (ElementCount - 1) + ElementSize;
310 }
311
312 if (const VectorType *VT = T->getAs<VectorType>()) {
313 unsigned ElementCount = VT->getNumElements();
314 unsigned ElementSize =
315 calculateLegacyCbufferSize(Context, T: VT->getElementType());
316 return ElementSize * ElementCount;
317 }
318
319 return Context.getTypeSize(T) / 8;
320}
321
322// Validate packoffset:
323// - if packoffset it used it must be set on all declarations inside the buffer
324// - packoffset ranges must not overlap
325static void validatePackoffset(Sema &S, HLSLBufferDecl *BufDecl) {
326 llvm::SmallVector<std::pair<VarDecl *, HLSLPackOffsetAttr *>> PackOffsetVec;
327
328 // Make sure the packoffset annotations are either on all declarations
329 // or on none.
330 bool HasPackOffset = false;
331 bool HasNonPackOffset = false;
332 for (auto *Field : BufDecl->buffer_decls()) {
333 VarDecl *Var = dyn_cast<VarDecl>(Val: Field);
334 if (!Var)
335 continue;
336 if (Field->hasAttr<HLSLPackOffsetAttr>()) {
337 PackOffsetVec.emplace_back(Args&: Var, Args: Field->getAttr<HLSLPackOffsetAttr>());
338 HasPackOffset = true;
339 } else {
340 HasNonPackOffset = true;
341 }
342 }
343
344 if (!HasPackOffset)
345 return;
346
347 if (HasNonPackOffset)
348 S.Diag(Loc: BufDecl->getLocation(), DiagID: diag::warn_hlsl_packoffset_mix);
349
350 // Make sure there is no overlap in packoffset - sort PackOffsetVec by offset
351 // and compare adjacent values.
352 bool IsValid = true;
353 ASTContext &Context = S.getASTContext();
354 std::sort(first: PackOffsetVec.begin(), last: PackOffsetVec.end(),
355 comp: [](const std::pair<VarDecl *, HLSLPackOffsetAttr *> &LHS,
356 const std::pair<VarDecl *, HLSLPackOffsetAttr *> &RHS) {
357 return LHS.second->getOffsetInBytes() <
358 RHS.second->getOffsetInBytes();
359 });
360 for (unsigned i = 0; i < PackOffsetVec.size() - 1; i++) {
361 VarDecl *Var = PackOffsetVec[i].first;
362 HLSLPackOffsetAttr *Attr = PackOffsetVec[i].second;
363 unsigned Size = calculateLegacyCbufferSize(Context, T: Var->getType());
364 unsigned Begin = Attr->getOffsetInBytes();
365 unsigned End = Begin + Size;
366 unsigned NextBegin = PackOffsetVec[i + 1].second->getOffsetInBytes();
367 if (End > NextBegin) {
368 VarDecl *NextVar = PackOffsetVec[i + 1].first;
369 S.Diag(Loc: NextVar->getLocation(), DiagID: diag::err_hlsl_packoffset_overlap)
370 << NextVar << Var;
371 IsValid = false;
372 }
373 }
374 BufDecl->setHasValidPackoffset(IsValid);
375}
376
377// Returns true if the array has a zero size = if any of the dimensions is 0
378static bool isZeroSizedArray(const ConstantArrayType *CAT) {
379 while (CAT && !CAT->isZeroSize())
380 CAT = dyn_cast<ConstantArrayType>(
381 Val: CAT->getElementType()->getUnqualifiedDesugaredType());
382 return CAT != nullptr;
383}
384
385static bool isResourceRecordTypeOrArrayOf(QualType Ty) {
386 return Ty->isHLSLResourceRecord() || Ty->isHLSLResourceRecordArray();
387}
388
389static bool isResourceRecordTypeOrArrayOf(VarDecl *VD) {
390 return isResourceRecordTypeOrArrayOf(Ty: VD->getType());
391}
392
393static const HLSLAttributedResourceType *
394getResourceArrayHandleType(QualType QT) {
395 assert(QT->isHLSLResourceRecordArray() &&
396 "expected array of resource records");
397 const Type *Ty = QT->getUnqualifiedDesugaredType();
398 while (const ArrayType *AT = dyn_cast<ArrayType>(Val: Ty))
399 Ty = AT->getArrayElementTypeNoTypeQual()->getUnqualifiedDesugaredType();
400 return HLSLAttributedResourceType::findHandleTypeOnResource(RT: Ty);
401}
402
403static const HLSLAttributedResourceType *
404getResourceArrayHandleType(VarDecl *VD) {
405 return getResourceArrayHandleType(QT: VD->getType());
406}
407
408// Returns true if the type is a leaf element type that is not valid to be
409// included in HLSL Buffer, such as a resource class, empty struct, zero-sized
410// array, or a builtin intangible type. Returns false it is a valid leaf element
411// type or if it is a record type that needs to be inspected further.
412static bool isInvalidConstantBufferLeafElementType(const Type *Ty) {
413 Ty = Ty->getUnqualifiedDesugaredType();
414 if (Ty->isHLSLResourceRecord() || Ty->isHLSLResourceRecordArray())
415 return true;
416 if (const auto *RD = Ty->getAsCXXRecordDecl())
417 return RD->isEmpty();
418 if (Ty->isConstantArrayType() &&
419 isZeroSizedArray(CAT: cast<ConstantArrayType>(Val: Ty)))
420 return true;
421 if (Ty->isHLSLBuiltinIntangibleType() || Ty->isHLSLAttributedResourceType())
422 return true;
423 return false;
424}
425
426// Returns true if the struct contains at least one element that prevents it
427// from being included inside HLSL Buffer as is, such as an intangible type,
428// empty struct, or zero-sized array. If it does, a new implicit layout struct
429// needs to be created for HLSL Buffer use that will exclude these unwanted
430// declarations (see createHostLayoutStruct function).
431static bool requiresImplicitBufferLayoutStructure(const CXXRecordDecl *RD) {
432 if (RD->isHLSLIntangible() || RD->isEmpty())
433 return true;
434 // check fields
435 for (const FieldDecl *Field : RD->fields()) {
436 QualType Ty = Field->getType();
437 if (isInvalidConstantBufferLeafElementType(Ty: Ty.getTypePtr()))
438 return true;
439 if (const auto *RD = Ty->getAsCXXRecordDecl();
440 RD && requiresImplicitBufferLayoutStructure(RD))
441 return true;
442 }
443 // check bases
444 for (const CXXBaseSpecifier &Base : RD->bases())
445 if (requiresImplicitBufferLayoutStructure(
446 RD: Base.getType()->castAsCXXRecordDecl()))
447 return true;
448 return false;
449}
450
451static CXXRecordDecl *findRecordDeclInContext(IdentifierInfo *II,
452 DeclContext *DC) {
453 CXXRecordDecl *RD = nullptr;
454 for (NamedDecl *Decl :
455 DC->getNonTransparentContext()->lookup(Name: DeclarationName(II))) {
456 if (CXXRecordDecl *FoundRD = dyn_cast<CXXRecordDecl>(Val: Decl)) {
457 assert(RD == nullptr &&
458 "there should be at most 1 record by a given name in a scope");
459 RD = FoundRD;
460 }
461 }
462 return RD;
463}
464
465// Creates a name for buffer layout struct using the provide name base.
466// If the name must be unique (not previously defined), a suffix is added
467// until a unique name is found.
468static IdentifierInfo *getHostLayoutStructName(Sema &S, NamedDecl *BaseDecl,
469 bool MustBeUnique) {
470 ASTContext &AST = S.getASTContext();
471
472 IdentifierInfo *NameBaseII = BaseDecl->getIdentifier();
473 llvm::SmallString<64> Name("__cblayout_");
474 if (NameBaseII) {
475 Name.append(RHS: NameBaseII->getName());
476 } else {
477 // anonymous struct
478 Name.append(RHS: "anon");
479 MustBeUnique = true;
480 }
481
482 size_t NameLength = Name.size();
483 IdentifierInfo *II = &AST.Idents.get(Name, TokenCode: tok::TokenKind::identifier);
484 if (!MustBeUnique)
485 return II;
486
487 unsigned suffix = 0;
488 while (true) {
489 if (suffix != 0) {
490 Name.append(RHS: "_");
491 Name.append(RHS: llvm::Twine(suffix).str());
492 II = &AST.Idents.get(Name, TokenCode: tok::TokenKind::identifier);
493 }
494 if (!findRecordDeclInContext(II, DC: BaseDecl->getDeclContext()))
495 return II;
496 // declaration with that name already exists - increment suffix and try
497 // again until unique name is found
498 suffix++;
499 Name.truncate(N: NameLength);
500 };
501}
502
503static const Type *createHostLayoutType(Sema &S, const Type *Ty) {
504 ASTContext &AST = S.getASTContext();
505 if (auto *RD = Ty->getAsCXXRecordDecl()) {
506 if (!requiresImplicitBufferLayoutStructure(RD))
507 return Ty;
508 RD = createHostLayoutStruct(S, StructDecl: RD);
509 if (!RD)
510 return nullptr;
511 return AST.getCanonicalTagType(TD: RD)->getTypePtr();
512 }
513
514 if (const auto *CAT = dyn_cast<ConstantArrayType>(Val: Ty)) {
515 const Type *ElementTy = createHostLayoutType(
516 S, Ty: CAT->getElementType()->getUnqualifiedDesugaredType());
517 if (!ElementTy)
518 return nullptr;
519 return AST
520 .getConstantArrayType(EltTy: QualType(ElementTy, 0), ArySize: CAT->getSize(), SizeExpr: nullptr,
521 ASM: CAT->getSizeModifier(),
522 IndexTypeQuals: CAT->getIndexTypeCVRQualifiers())
523 .getTypePtr();
524 }
525 return Ty;
526}
527
528// Returns the type to use for a host layout struct field. For most types this
529// is the unqualified desugared type. Matrix types, however, retain their sugar
530// so that the row_major/column_major orientation (carried as an AttributedType)
531// is preserved; the orientation determines the in-memory cbuffer layout.
532static const Type *getHostLayoutFieldType(QualType QT) {
533 const Type *Desugared = QT->getUnqualifiedDesugaredType();
534 if (Desugared->isConstantMatrixType())
535 return QT.getTypePtr();
536 return Desugared;
537}
538
539// Creates a field declaration of given name and type for HLSL buffer layout
540// struct. Returns nullptr if the type cannot be use in HLSL Buffer layout.
541static FieldDecl *createFieldForHostLayoutStruct(Sema &S, const Type *Ty,
542 IdentifierInfo *II,
543 CXXRecordDecl *LayoutStruct) {
544 if (isInvalidConstantBufferLeafElementType(Ty))
545 return nullptr;
546
547 Ty = createHostLayoutType(S, Ty);
548 if (!Ty)
549 return nullptr;
550
551 QualType QT = QualType(Ty, 0);
552 ASTContext &AST = S.getASTContext();
553 TypeSourceInfo *TSI = AST.getTrivialTypeSourceInfo(T: QT, Loc: SourceLocation());
554 auto *Field = FieldDecl::Create(C: AST, DC: LayoutStruct, StartLoc: SourceLocation(),
555 IdLoc: SourceLocation(), Id: II, T: QT, TInfo: TSI, BW: nullptr, Mutable: false,
556 InitStyle: InClassInitStyle::ICIS_NoInit);
557 Field->setAccess(AccessSpecifier::AS_public);
558 return Field;
559}
560
561// Creates host layout struct for a struct included in HLSL Buffer.
562// The layout struct will include only fields that are allowed in HLSL buffer.
563// These fields will be filtered out:
564// - resource classes
565// - empty structs
566// - zero-sized arrays
567// Returns nullptr if the resulting layout struct would be empty.
568static CXXRecordDecl *createHostLayoutStruct(Sema &S,
569 CXXRecordDecl *StructDecl) {
570 assert(requiresImplicitBufferLayoutStructure(StructDecl) &&
571 "struct is already HLSL buffer compatible");
572
573 ASTContext &AST = S.getASTContext();
574 DeclContext *DC = StructDecl->getDeclContext();
575 IdentifierInfo *II = getHostLayoutStructName(S, BaseDecl: StructDecl, MustBeUnique: false);
576
577 // reuse existing if the layout struct if it already exists
578 if (CXXRecordDecl *RD = findRecordDeclInContext(II, DC))
579 return RD;
580
581 CXXRecordDecl *LS =
582 CXXRecordDecl::Create(C: AST, TK: TagDecl::TagKind::Struct, DC, StartLoc: SourceLocation(),
583 IdLoc: SourceLocation(), Id: II);
584 LS->setImplicit(true);
585 LS->addAttr(A: PackedAttr::CreateImplicit(Ctx&: AST));
586 LS->startDefinition();
587
588 // copy base struct, create HLSL Buffer compatible version if needed
589 if (unsigned NumBases = StructDecl->getNumBases()) {
590 assert(NumBases == 1 && "HLSL supports only one base type");
591 (void)NumBases;
592 CXXBaseSpecifier Base = *StructDecl->bases_begin();
593 CXXRecordDecl *BaseDecl = Base.getType()->castAsCXXRecordDecl();
594 if (requiresImplicitBufferLayoutStructure(RD: BaseDecl)) {
595 BaseDecl = createHostLayoutStruct(S, StructDecl: BaseDecl);
596 if (BaseDecl) {
597 TypeSourceInfo *TSI =
598 AST.getTrivialTypeSourceInfo(T: AST.getCanonicalTagType(TD: BaseDecl));
599 Base = CXXBaseSpecifier(SourceRange(), false, StructDecl->isClass(),
600 AS_none, TSI, SourceLocation());
601 }
602 }
603 if (BaseDecl) {
604 const CXXBaseSpecifier *BasesArray[1] = {&Base};
605 LS->setBases(Bases: BasesArray, NumBases: 1);
606 }
607 }
608
609 // filter struct fields
610 for (const FieldDecl *FD : StructDecl->fields()) {
611 const Type *Ty = getHostLayoutFieldType(QT: FD->getType());
612 if (FieldDecl *NewFD =
613 createFieldForHostLayoutStruct(S, Ty, II: FD->getIdentifier(), LayoutStruct: LS))
614 LS->addDecl(D: NewFD);
615 }
616 LS->completeDefinition();
617
618 if (LS->field_empty() && LS->getNumBases() == 0)
619 return nullptr;
620
621 DC->addDecl(D: LS);
622 return LS;
623}
624
625// Creates host layout struct for HLSL Buffer. The struct will include only
626// fields of types that are allowed in HLSL buffer and it will filter out:
627// - static or groupshared variable declarations
628// - resource classes
629// - empty structs
630// - zero-sized arrays
631// - non-variable declarations
632// The layout struct will be added to the HLSLBufferDecl declarations.
633static void createHostLayoutStructForBuffer(Sema &S, HLSLBufferDecl *BufDecl) {
634 ASTContext &AST = S.getASTContext();
635 IdentifierInfo *II = getHostLayoutStructName(S, BaseDecl: BufDecl, MustBeUnique: true);
636
637 CXXRecordDecl *LS =
638 CXXRecordDecl::Create(C: AST, TK: TagDecl::TagKind::Struct, DC: BufDecl,
639 StartLoc: SourceLocation(), IdLoc: SourceLocation(), Id: II);
640 LS->addAttr(A: PackedAttr::CreateImplicit(Ctx&: AST));
641 LS->setImplicit(true);
642 LS->startDefinition();
643
644 for (Decl *D : BufDecl->buffer_decls()) {
645 VarDecl *VD = dyn_cast<VarDecl>(Val: D);
646 if (!VD || VD->getStorageClass() == SC_Static ||
647 VD->getType().getAddressSpace() == LangAS::hlsl_groupshared)
648 continue;
649 const Type *Ty = getHostLayoutFieldType(QT: VD->getType());
650
651 FieldDecl *FD =
652 createFieldForHostLayoutStruct(S, Ty, II: VD->getIdentifier(), LayoutStruct: LS);
653 // Declarations collected for the default $Globals constant buffer have
654 // already been checked to have non-empty cbuffer layout, so
655 // createFieldForHostLayoutStruct should always succeed. These declarations
656 // already have their address space set to hlsl_constant.
657 // For declarations in a named cbuffer block
658 // createFieldForHostLayoutStruct can still return nullptr if the type
659 // is empty (does not have a cbuffer layout).
660 assert((FD || VD->getType().getAddressSpace() != LangAS::hlsl_constant) &&
661 "host layout field for $Globals decl failed to be created");
662 if (FD) {
663 // Add the field decl to the layout struct.
664 LS->addDecl(D: FD);
665 if (VD->getType().getAddressSpace() != LangAS::hlsl_constant) {
666 // Update address space of the original decl to hlsl_constant.
667 QualType NewTy =
668 AST.getAddrSpaceQualType(T: VD->getType(), AddressSpace: LangAS::hlsl_constant);
669 VD->setType(NewTy);
670 }
671 }
672 }
673 LS->completeDefinition();
674 BufDecl->addLayoutStruct(LS);
675}
676
677static void addImplicitBindingAttrToDecl(Sema &S, Decl *D, RegisterType RT,
678 uint32_t ImplicitBindingOrderID) {
679 auto *Attr =
680 HLSLResourceBindingAttr::CreateImplicit(Ctx&: S.getASTContext(), Slot: "", Space: "0", Range: {});
681 Attr->setBinding(RT, SlotNum: std::nullopt, SpaceNum: 0);
682 Attr->setImplicitBindingOrderID(ImplicitBindingOrderID);
683 D->addAttr(A: Attr);
684}
685
686// Handle end of cbuffer/tbuffer declaration
687void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) {
688 auto *BufDecl = cast<HLSLBufferDecl>(Val: Dcl);
689 BufDecl->setRBraceLoc(RBrace);
690
691 validatePackoffset(S&: SemaRef, BufDecl);
692
693 createHostLayoutStructForBuffer(S&: SemaRef, BufDecl);
694
695 // Handle implicit binding if needed.
696 ResourceBindingAttrs ResourceAttrs(Dcl);
697 if (!ResourceAttrs.isExplicit()) {
698 SemaRef.Diag(Loc: Dcl->getLocation(), DiagID: diag::warn_hlsl_implicit_binding);
699 // Use HLSLResourceBindingAttr to transfer implicit binding order_ID
700 // to codegen. If it does not exist, create an implicit attribute.
701 uint32_t OrderID = getNextImplicitBindingOrderID();
702 if (ResourceAttrs.hasBinding())
703 ResourceAttrs.setImplicitOrderID(OrderID);
704 else
705 addImplicitBindingAttrToDecl(S&: SemaRef, D: BufDecl,
706 RT: BufDecl->isCBuffer() ? RegisterType::CBuffer
707 : RegisterType::SRV,
708 ImplicitBindingOrderID: OrderID);
709 }
710
711 SemaRef.PopDeclContext();
712}
713
714HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
715 const AttributeCommonInfo &AL,
716 int X, int Y, int Z) {
717 if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
718 if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
719 Diag(Loc: NT->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL;
720 Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute);
721 }
722 return nullptr;
723 }
724 return ::new (getASTContext())
725 HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z);
726}
727
728HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D,
729 const AttributeCommonInfo &AL,
730 int Min, int Max, int Preferred,
731 int SpelledArgsCount) {
732 if (HLSLWaveSizeAttr *WS = D->getAttr<HLSLWaveSizeAttr>()) {
733 if (WS->getMin() != Min || WS->getMax() != Max ||
734 WS->getPreferred() != Preferred ||
735 WS->getSpelledArgsCount() != SpelledArgsCount) {
736 Diag(Loc: WS->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL;
737 Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute);
738 }
739 return nullptr;
740 }
741 HLSLWaveSizeAttr *Result = ::new (getASTContext())
742 HLSLWaveSizeAttr(getASTContext(), AL, Min, Max, Preferred);
743 Result->setSpelledArgsCount(SpelledArgsCount);
744 return Result;
745}
746
747HLSLVkConstantIdAttr *
748SemaHLSL::mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL,
749 int Id) {
750
751 auto &TargetInfo = getASTContext().getTargetInfo();
752 if (TargetInfo.getTriple().getArch() != llvm::Triple::spirv) {
753 Diag(Loc: AL.getLoc(), DiagID: diag::warn_attribute_ignored) << AL;
754 return nullptr;
755 }
756
757 auto *VD = cast<VarDecl>(Val: D);
758
759 if (getSpecConstBuiltinId(Type: VD->getType()->getUnqualifiedDesugaredType()) ==
760 Builtin::NotBuiltin) {
761 Diag(Loc: VD->getLocation(), DiagID: diag::err_specialization_const);
762 return nullptr;
763 }
764
765 if (!VD->getType().isConstQualified()) {
766 Diag(Loc: VD->getLocation(), DiagID: diag::err_specialization_const);
767 return nullptr;
768 }
769
770 if (HLSLVkConstantIdAttr *CI = D->getAttr<HLSLVkConstantIdAttr>()) {
771 if (CI->getId() != Id) {
772 Diag(Loc: CI->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL;
773 Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute);
774 }
775 return nullptr;
776 }
777
778 HLSLVkConstantIdAttr *Result =
779 ::new (getASTContext()) HLSLVkConstantIdAttr(getASTContext(), AL, Id);
780 return Result;
781}
782
783HLSLShaderAttr *
784SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
785 llvm::Triple::EnvironmentType ShaderType) {
786 if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
787 if (NT->getType() != ShaderType) {
788 Diag(Loc: NT->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL;
789 Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute);
790 }
791 return nullptr;
792 }
793 return HLSLShaderAttr::Create(Ctx&: getASTContext(), Type: ShaderType, CommonInfo: AL);
794}
795
796HLSLParamModifierAttr *
797SemaHLSL::mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
798 HLSLParamModifierAttr::Spelling Spelling) {
799 // We can only merge an `in` attribute with an `out` attribute. All other
800 // combinations of duplicated attributes are ill-formed.
801 if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
802 if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
803 (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
804 D->dropAttr<HLSLParamModifierAttr>();
805 SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
806 return HLSLParamModifierAttr::Create(
807 Ctx&: getASTContext(), /*MergedSpelling=*/true, Range: AdjustedRange,
808 S: HLSLParamModifierAttr::Keyword_inout);
809 }
810 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_duplicate_parameter_modifier) << AL;
811 Diag(Loc: PA->getLocation(), DiagID: diag::note_conflicting_attribute);
812 return nullptr;
813 }
814 return HLSLParamModifierAttr::Create(Ctx&: getASTContext(), CommonInfo: AL);
815}
816
817void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
818 auto &TargetInfo = getASTContext().getTargetInfo();
819
820 if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
821 return;
822
823 // If we have specified a root signature to override the entry function then
824 // attach it now
825 HLSLRootSignatureDecl *SignatureDecl =
826 lookupRootSignatureOverrideDecl(DC: FD->getDeclContext());
827 if (SignatureDecl) {
828 FD->dropAttr<RootSignatureAttr>();
829 // We could look up the SourceRange of the macro here as well
830 AttributeCommonInfo AL(RootSigOverrideIdent, AttributeScopeInfo(),
831 SourceRange(), ParsedAttr::Form::Microsoft());
832 FD->addAttr(A: ::new (getASTContext()) RootSignatureAttr(
833 getASTContext(), AL, RootSigOverrideIdent, SignatureDecl));
834 }
835
836 llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment();
837 if (HLSLShaderAttr::isValidShaderType(ShaderType: Env) && Env != llvm::Triple::Library) {
838 if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
839 // The entry point is already annotated - check that it matches the
840 // triple.
841 if (Shader->getType() != Env) {
842 Diag(Loc: Shader->getLocation(), DiagID: diag::err_hlsl_entry_shader_attr_mismatch)
843 << Shader;
844 FD->setInvalidDecl();
845 }
846 } else {
847 // Implicitly add the shader attribute if the entry function isn't
848 // explicitly annotated.
849 FD->addAttr(A: HLSLShaderAttr::CreateImplicit(Ctx&: getASTContext(), Type: Env,
850 Range: FD->getBeginLoc()));
851 }
852 } else {
853 switch (Env) {
854 case llvm::Triple::UnknownEnvironment:
855 case llvm::Triple::Library:
856 break;
857 case llvm::Triple::RootSignature:
858 llvm_unreachable("rootsig environment has no functions");
859 default:
860 llvm_unreachable("Unhandled environment in triple");
861 }
862 }
863}
864
865static bool isVkPipelineBuiltin(const ASTContext &AstContext, FunctionDecl *FD,
866 HLSLAppliedSemanticAttr *Semantic,
867 bool IsInput) {
868 if (AstContext.getTargetInfo().getTriple().getOS() != llvm::Triple::Vulkan)
869 return false;
870
871 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
872 assert(ShaderAttr && "Entry point has no shader attribute");
873 llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
874 auto SemanticName = Semantic->getSemanticName().upper();
875
876 // The SV_Position semantic is lowered to:
877 // - Position built-in for vertex output.
878 // - FragCoord built-in for fragment input.
879 if (SemanticName == "SV_POSITION") {
880 return (ST == llvm::Triple::Vertex && !IsInput) ||
881 (ST == llvm::Triple::Pixel && IsInput);
882 }
883 if (SemanticName == "SV_VERTEXID")
884 return true;
885
886 return false;
887}
888
889bool SemaHLSL::determineActiveSemanticOnScalar(FunctionDecl *FD,
890 DeclaratorDecl *OutputDecl,
891 DeclaratorDecl *D,
892 SemanticInfo &ActiveSemantic,
893 SemaHLSL::SemanticContext &SC) {
894 if (ActiveSemantic.Semantic == nullptr) {
895 ActiveSemantic.Semantic = D->getAttr<HLSLParsedSemanticAttr>();
896 if (ActiveSemantic.Semantic)
897 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
898 }
899
900 if (!ActiveSemantic.Semantic) {
901 Diag(Loc: D->getLocation(), DiagID: diag::err_hlsl_missing_semantic_annotation);
902 return false;
903 }
904
905 auto *A = ::new (getASTContext())
906 HLSLAppliedSemanticAttr(getASTContext(), *ActiveSemantic.Semantic,
907 ActiveSemantic.Semantic->getAttrName()->getName(),
908 ActiveSemantic.Index.value_or(u: 0));
909 if (!A)
910 return false;
911
912 checkSemanticAnnotation(EntryPoint: FD, Param: D, SemanticAttr: A, SC);
913 OutputDecl->addAttr(A);
914
915 unsigned Location = ActiveSemantic.Index.value_or(u: 0);
916
917 if (!isVkPipelineBuiltin(AstContext: getASTContext(), FD, Semantic: A,
918 IsInput: SC.CurrentIOType & IOType::In)) {
919 bool HasVkLocation = false;
920 if (auto *A = D->getAttr<HLSLVkLocationAttr>()) {
921 HasVkLocation = true;
922 Location = A->getLocation();
923 }
924
925 if (SC.UsesExplicitVkLocations.value_or(u&: HasVkLocation) != HasVkLocation) {
926 Diag(Loc: D->getLocation(), DiagID: diag::err_hlsl_semantic_partial_explicit_indexing);
927 return false;
928 }
929 SC.UsesExplicitVkLocations = HasVkLocation;
930 }
931
932 const ConstantArrayType *AT = dyn_cast<ConstantArrayType>(Val: D->getType());
933 unsigned ElementCount = AT ? AT->getZExtSize() : 1;
934 ActiveSemantic.Index = Location + ElementCount;
935
936 Twine BaseName = Twine(ActiveSemantic.Semantic->getAttrName()->getName());
937 for (unsigned I = 0; I < ElementCount; ++I) {
938 Twine VariableName = BaseName.concat(Suffix: Twine(Location + I));
939
940 auto [_, Inserted] = SC.ActiveSemantics.insert(key: VariableName.str());
941 if (!Inserted) {
942 Diag(Loc: D->getLocation(), DiagID: diag::err_hlsl_semantic_index_overlap)
943 << VariableName.str();
944 return false;
945 }
946 }
947
948 return true;
949}
950
951bool SemaHLSL::determineActiveSemantic(FunctionDecl *FD,
952 DeclaratorDecl *OutputDecl,
953 DeclaratorDecl *D,
954 SemanticInfo &ActiveSemantic,
955 SemaHLSL::SemanticContext &SC) {
956 if (ActiveSemantic.Semantic == nullptr) {
957 ActiveSemantic.Semantic = D->getAttr<HLSLParsedSemanticAttr>();
958 if (ActiveSemantic.Semantic)
959 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
960 }
961
962 const Type *T = D == FD ? &*FD->getReturnType() : &*D->getType();
963 T = T->getUnqualifiedDesugaredType();
964
965 const RecordType *RT = dyn_cast<RecordType>(Val: T);
966 if (!RT)
967 return determineActiveSemanticOnScalar(FD, OutputDecl, D, ActiveSemantic,
968 SC);
969
970 const RecordDecl *RD = RT->getDecl();
971 for (FieldDecl *Field : RD->fields()) {
972 SemanticInfo Info = ActiveSemantic;
973 if (!determineActiveSemantic(FD, OutputDecl, D: Field, ActiveSemantic&: Info, SC)) {
974 Diag(Loc: Field->getLocation(), DiagID: diag::note_hlsl_semantic_used_here) << Field;
975 return false;
976 }
977 if (ActiveSemantic.Semantic)
978 ActiveSemantic = Info;
979 }
980
981 return true;
982}
983
984void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
985 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
986 assert(ShaderAttr && "Entry point has no shader attribute");
987 llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
988 auto &TargetInfo = getASTContext().getTargetInfo();
989 VersionTuple Ver = TargetInfo.getTriple().getOSVersion();
990 switch (ST) {
991 case llvm::Triple::Pixel:
992 case llvm::Triple::Vertex:
993 case llvm::Triple::Geometry:
994 case llvm::Triple::Hull:
995 case llvm::Triple::Domain:
996 case llvm::Triple::RayGeneration:
997 case llvm::Triple::Intersection:
998 case llvm::Triple::AnyHit:
999 case llvm::Triple::ClosestHit:
1000 case llvm::Triple::Miss:
1001 case llvm::Triple::Callable:
1002 if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
1003 diagnoseAttrStageMismatch(A: NT, Stage: ST,
1004 AllowedStages: {llvm::Triple::Compute,
1005 llvm::Triple::Amplification,
1006 llvm::Triple::Mesh});
1007 FD->setInvalidDecl();
1008 }
1009 if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
1010 diagnoseAttrStageMismatch(A: WS, Stage: ST,
1011 AllowedStages: {llvm::Triple::Compute,
1012 llvm::Triple::Amplification,
1013 llvm::Triple::Mesh});
1014 FD->setInvalidDecl();
1015 }
1016 break;
1017
1018 case llvm::Triple::Compute:
1019 case llvm::Triple::Amplification:
1020 case llvm::Triple::Mesh:
1021 if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
1022 Diag(Loc: FD->getLocation(), DiagID: diag::err_hlsl_missing_numthreads)
1023 << llvm::Triple::getEnvironmentTypeName(Kind: ST);
1024 FD->setInvalidDecl();
1025 }
1026 if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
1027 if (TargetInfo.getTriple().isSPIRV()) {
1028 Diag(Loc: WS->getLocation(), DiagID: diag::warn_hlsl_wavesize_unsupported_spirv);
1029 } else if (Ver < VersionTuple(6, 6)) {
1030 Diag(Loc: WS->getLocation(), DiagID: diag::err_hlsl_attribute_in_wrong_shader_model)
1031 << WS << "6.6";
1032 FD->setInvalidDecl();
1033 } else if (WS->getSpelledArgsCount() > 1 && Ver < VersionTuple(6, 8)) {
1034 Diag(
1035 Loc: WS->getLocation(),
1036 DiagID: diag::err_hlsl_attribute_number_arguments_insufficient_shader_model)
1037 << WS << WS->getSpelledArgsCount() << "6.8";
1038 FD->setInvalidDecl();
1039 }
1040 }
1041 break;
1042 case llvm::Triple::RootSignature:
1043 llvm_unreachable("rootsig environment has no function entry point");
1044 default:
1045 llvm_unreachable("Unhandled environment in triple");
1046 }
1047
1048 SemaHLSL::SemanticContext InputSC = {};
1049 InputSC.CurrentIOType = IOType::In;
1050
1051 for (ParmVarDecl *Param : FD->parameters()) {
1052 SemanticInfo ActiveSemantic;
1053 ActiveSemantic.Semantic = Param->getAttr<HLSLParsedSemanticAttr>();
1054 if (ActiveSemantic.Semantic)
1055 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
1056
1057 // FIXME: Verify output semantics in parameters.
1058 if (!determineActiveSemantic(FD, OutputDecl: Param, D: Param, ActiveSemantic, SC&: InputSC)) {
1059 Diag(Loc: Param->getLocation(), DiagID: diag::note_previous_decl) << Param;
1060 FD->setInvalidDecl();
1061 }
1062 }
1063
1064 SemanticInfo ActiveSemantic;
1065 SemaHLSL::SemanticContext OutputSC = {};
1066 OutputSC.CurrentIOType = IOType::Out;
1067 ActiveSemantic.Semantic = FD->getAttr<HLSLParsedSemanticAttr>();
1068 if (ActiveSemantic.Semantic)
1069 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
1070 if (!FD->getReturnType()->isVoidType())
1071 determineActiveSemantic(FD, OutputDecl: FD, D: FD, ActiveSemantic, SC&: OutputSC);
1072}
1073
1074void SemaHLSL::checkSemanticAnnotation(
1075 FunctionDecl *EntryPoint, const Decl *Param,
1076 const HLSLAppliedSemanticAttr *SemanticAttr, const SemanticContext &SC) {
1077 auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
1078 assert(ShaderAttr && "Entry point has no shader attribute");
1079 llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
1080
1081 auto SemanticName = SemanticAttr->getSemanticName().upper();
1082 if (SemanticName == "SV_DISPATCHTHREADID" ||
1083 SemanticName == "SV_GROUPINDEX" || SemanticName == "SV_GROUPTHREADID" ||
1084 SemanticName == "SV_GROUPID") {
1085
1086 if (ST != llvm::Triple::Compute)
1087 diagnoseSemanticStageMismatch(A: SemanticAttr, Stage: ST, CurrentIOType: SC.CurrentIOType,
1088 AllowedStages: {{.Stage: llvm::Triple::Compute, .AllowedIOTypesMask: IOType::In}});
1089
1090 if (SemanticAttr->getSemanticIndex() != 0) {
1091 std::string PrettyName =
1092 "'" + SemanticAttr->getSemanticName().str() + "'";
1093 Diag(Loc: SemanticAttr->getLoc(),
1094 DiagID: diag::err_hlsl_semantic_indexing_not_supported)
1095 << PrettyName;
1096 }
1097 return;
1098 }
1099
1100 if (SemanticName == "SV_POSITION") {
1101 // SV_Position can be an input or output in vertex shaders,
1102 // but only an input in pixel shaders.
1103 diagnoseSemanticStageMismatch(A: SemanticAttr, Stage: ST, CurrentIOType: SC.CurrentIOType,
1104 AllowedStages: {{.Stage: llvm::Triple::Vertex, .AllowedIOTypesMask: IOType::InOut},
1105 {.Stage: llvm::Triple::Pixel, .AllowedIOTypesMask: IOType::In}});
1106 return;
1107 }
1108 if (SemanticName == "SV_VERTEXID") {
1109 diagnoseSemanticStageMismatch(A: SemanticAttr, Stage: ST, CurrentIOType: SC.CurrentIOType,
1110 AllowedStages: {{.Stage: llvm::Triple::Vertex, .AllowedIOTypesMask: IOType::In}});
1111 return;
1112 }
1113
1114 if (SemanticName == "SV_TARGET") {
1115 diagnoseSemanticStageMismatch(A: SemanticAttr, Stage: ST, CurrentIOType: SC.CurrentIOType,
1116 AllowedStages: {{.Stage: llvm::Triple::Pixel, .AllowedIOTypesMask: IOType::Out}});
1117 return;
1118 }
1119
1120 // FIXME: catch-all for non-implemented system semantics reaching this
1121 // location.
1122 if (SemanticAttr->getAttrName()->getName().starts_with_insensitive(Prefix: "SV_"))
1123 llvm_unreachable("Unknown SemanticAttr");
1124}
1125
1126void SemaHLSL::diagnoseAttrStageMismatch(
1127 const Attr *A, llvm::Triple::EnvironmentType Stage,
1128 std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) {
1129 SmallVector<StringRef, 8> StageStrings;
1130 llvm::transform(Range&: AllowedStages, d_first: std::back_inserter(x&: StageStrings),
1131 F: [](llvm::Triple::EnvironmentType ST) {
1132 return StringRef(
1133 HLSLShaderAttr::ConvertEnvironmentTypeToStr(Val: ST));
1134 });
1135 Diag(Loc: A->getLoc(), DiagID: diag::err_hlsl_attr_unsupported_in_stage)
1136 << A->getAttrName() << llvm::Triple::getEnvironmentTypeName(Kind: Stage)
1137 << (AllowedStages.size() != 1) << join(R&: StageStrings, Separator: ", ");
1138}
1139
1140void SemaHLSL::diagnoseSemanticStageMismatch(
1141 const Attr *A, llvm::Triple::EnvironmentType Stage, IOType CurrentIOType,
1142 std::initializer_list<SemanticStageInfo> Allowed) {
1143
1144 for (auto &Case : Allowed) {
1145 if (Case.Stage != Stage)
1146 continue;
1147
1148 if (CurrentIOType & Case.AllowedIOTypesMask)
1149 return;
1150
1151 SmallVector<std::string, 8> ValidCases;
1152 llvm::transform(
1153 Range&: Allowed, d_first: std::back_inserter(x&: ValidCases), F: [](SemanticStageInfo Case) {
1154 SmallVector<std::string, 2> ValidType;
1155 if (Case.AllowedIOTypesMask & IOType::In)
1156 ValidType.push_back(Elt: "input");
1157 if (Case.AllowedIOTypesMask & IOType::Out)
1158 ValidType.push_back(Elt: "output");
1159 return std::string(
1160 HLSLShaderAttr::ConvertEnvironmentTypeToStr(Val: Case.Stage)) +
1161 " " + join(R&: ValidType, Separator: "/");
1162 });
1163 Diag(Loc: A->getLoc(), DiagID: diag::err_hlsl_semantic_unsupported_iotype_for_stage)
1164 << A->getAttrName() << (CurrentIOType & IOType::In ? "input" : "output")
1165 << llvm::Triple::getEnvironmentTypeName(Kind: Case.Stage)
1166 << join(R&: ValidCases, Separator: ", ");
1167 return;
1168 }
1169
1170 SmallVector<StringRef, 8> StageStrings;
1171 llvm::transform(
1172 Range&: Allowed, d_first: std::back_inserter(x&: StageStrings), F: [](SemanticStageInfo Case) {
1173 return StringRef(
1174 HLSLShaderAttr::ConvertEnvironmentTypeToStr(Val: Case.Stage));
1175 });
1176
1177 Diag(Loc: A->getLoc(), DiagID: diag::err_hlsl_attr_unsupported_in_stage)
1178 << A->getAttrName() << llvm::Triple::getEnvironmentTypeName(Kind: Stage)
1179 << (Allowed.size() != 1) << join(R&: StageStrings, Separator: ", ");
1180}
1181
1182template <CastKind Kind>
1183static void castVector(Sema &S, ExprResult &E, QualType &Ty, unsigned Sz) {
1184 if (const auto *VTy = Ty->getAs<VectorType>())
1185 Ty = VTy->getElementType();
1186 Ty = S.getASTContext().getExtVectorType(VectorType: Ty, NumElts: Sz);
1187 E = S.ImpCastExprToType(E: E.get(), Type: Ty, CK: Kind);
1188}
1189
1190template <CastKind Kind>
1191static QualType castElement(Sema &S, ExprResult &E, QualType Ty) {
1192 E = S.ImpCastExprToType(E: E.get(), Type: Ty, CK: Kind);
1193 return Ty;
1194}
1195
1196static QualType handleFloatVectorBinOpConversion(
1197 Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
1198 QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
1199 bool LHSFloat = LElTy->isRealFloatingType();
1200 bool RHSFloat = RElTy->isRealFloatingType();
1201
1202 if (LHSFloat && RHSFloat) {
1203 if (IsCompAssign ||
1204 SemaRef.getASTContext().getFloatingTypeOrder(LHS: LElTy, RHS: RElTy) > 0)
1205 return castElement<CK_FloatingCast>(S&: SemaRef, E&: RHS, Ty: LHSType);
1206
1207 return castElement<CK_FloatingCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
1208 }
1209
1210 if (LHSFloat)
1211 return castElement<CK_IntegralToFloating>(S&: SemaRef, E&: RHS, Ty: LHSType);
1212
1213 assert(RHSFloat);
1214 if (IsCompAssign)
1215 return castElement<clang::CK_FloatingToIntegral>(S&: SemaRef, E&: RHS, Ty: LHSType);
1216
1217 return castElement<CK_IntegralToFloating>(S&: SemaRef, E&: LHS, Ty: RHSType);
1218}
1219
1220static QualType handleIntegerVectorBinOpConversion(
1221 Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
1222 QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
1223
1224 int IntOrder = SemaRef.Context.getIntegerTypeOrder(LHS: LElTy, RHS: RElTy);
1225 bool LHSSigned = LElTy->hasSignedIntegerRepresentation();
1226 bool RHSSigned = RElTy->hasSignedIntegerRepresentation();
1227 auto &Ctx = SemaRef.getASTContext();
1228
1229 // If both types have the same signedness, use the higher ranked type.
1230 if (LHSSigned == RHSSigned) {
1231 if (IsCompAssign || IntOrder >= 0)
1232 return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);
1233
1234 return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
1235 }
1236
1237 // If the unsigned type has greater than or equal rank of the signed type, use
1238 // the unsigned type.
1239 if (IntOrder != (LHSSigned ? 1 : -1)) {
1240 if (IsCompAssign || RHSSigned)
1241 return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);
1242 return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
1243 }
1244
1245 // At this point the signed type has higher rank than the unsigned type, which
1246 // means it will be the same size or bigger. If the signed type is bigger, it
1247 // can represent all the values of the unsigned type, so select it.
1248 if (Ctx.getIntWidth(T: LElTy) != Ctx.getIntWidth(T: RElTy)) {
1249 if (IsCompAssign || LHSSigned)
1250 return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);
1251 return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
1252 }
1253
1254 // This is a bit of an odd duck case in HLSL. It shouldn't happen, but can due
1255 // to C/C++ leaking through. The place this happens today is long vs long
1256 // long. When arguments are vector<unsigned long, N> and vector<long long, N>,
1257 // the long long has higher rank than long even though they are the same size.
1258
1259 // If this is a compound assignment cast the right hand side to the left hand
1260 // side's type.
1261 if (IsCompAssign)
1262 return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);
1263
1264 // If this isn't a compound assignment we convert to unsigned long long.
1265 QualType ElTy = Ctx.getCorrespondingUnsignedType(T: LHSSigned ? LElTy : RElTy);
1266 QualType NewTy = Ctx.getExtVectorType(
1267 VectorType: ElTy, NumElts: RHSType->castAs<VectorType>()->getNumElements());
1268 (void)castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: NewTy);
1269
1270 return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: NewTy);
1271}
1272
1273static CastKind getScalarCastKind(ASTContext &Ctx, QualType DestTy,
1274 QualType SrcTy) {
1275 if (DestTy->isRealFloatingType() && SrcTy->isRealFloatingType())
1276 return CK_FloatingCast;
1277 if (DestTy->isIntegralType(Ctx) && SrcTy->isIntegralType(Ctx))
1278 return CK_IntegralCast;
1279 if (DestTy->isRealFloatingType())
1280 return CK_IntegralToFloating;
1281 assert(SrcTy->isRealFloatingType() && DestTy->isIntegralType(Ctx));
1282 return CK_FloatingToIntegral;
1283}
1284
1285QualType SemaHLSL::handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS,
1286 QualType LHSType,
1287 QualType RHSType,
1288 bool IsCompAssign) {
1289 const auto *LVecTy = LHSType->getAs<VectorType>();
1290 const auto *RVecTy = RHSType->getAs<VectorType>();
1291 auto &Ctx = getASTContext();
1292
1293 // If the LHS is not a vector and this is a compound assignment, we truncate
1294 // the argument to a scalar then convert it to the LHS's type.
1295 if (!LVecTy && IsCompAssign) {
1296 QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
1297 RHS = SemaRef.ImpCastExprToType(E: RHS.get(), Type: RElTy, CK: CK_HLSLVectorTruncation);
1298 RHSType = RHS.get()->getType();
1299 if (Ctx.hasSameUnqualifiedType(T1: LHSType, T2: RHSType))
1300 return LHSType;
1301 RHS = SemaRef.ImpCastExprToType(E: RHS.get(), Type: LHSType,
1302 CK: getScalarCastKind(Ctx, DestTy: LHSType, SrcTy: RHSType));
1303 return LHSType;
1304 }
1305
1306 unsigned EndSz = std::numeric_limits<unsigned>::max();
1307 unsigned LSz = 0;
1308 if (LVecTy)
1309 LSz = EndSz = LVecTy->getNumElements();
1310 if (RVecTy)
1311 EndSz = std::min(a: RVecTy->getNumElements(), b: EndSz);
1312 assert(EndSz != std::numeric_limits<unsigned>::max() &&
1313 "one of the above should have had a value");
1314
1315 // In a compound assignment, the left operand does not change type, the right
1316 // operand is converted to the type of the left operand.
1317 if (IsCompAssign && LSz != EndSz) {
1318 Diag(Loc: LHS.get()->getBeginLoc(),
1319 DiagID: diag::err_hlsl_vector_compound_assignment_truncation)
1320 << LHSType << RHSType;
1321 return QualType();
1322 }
1323
1324 if (RVecTy && RVecTy->getNumElements() > EndSz)
1325 castVector<CK_HLSLVectorTruncation>(S&: SemaRef, E&: RHS, Ty&: RHSType, Sz: EndSz);
1326 if (!IsCompAssign && LVecTy && LVecTy->getNumElements() > EndSz)
1327 castVector<CK_HLSLVectorTruncation>(S&: SemaRef, E&: LHS, Ty&: LHSType, Sz: EndSz);
1328
1329 if (!RVecTy)
1330 castVector<CK_VectorSplat>(S&: SemaRef, E&: RHS, Ty&: RHSType, Sz: EndSz);
1331 if (!IsCompAssign && !LVecTy)
1332 castVector<CK_VectorSplat>(S&: SemaRef, E&: LHS, Ty&: LHSType, Sz: EndSz);
1333
1334 // If we're at the same type after resizing we can stop here.
1335 if (Ctx.hasSameUnqualifiedType(T1: LHSType, T2: RHSType))
1336 return Ctx.getCommonSugaredType(X: LHSType, Y: RHSType);
1337
1338 QualType LElTy = LHSType->castAs<VectorType>()->getElementType();
1339 QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
1340
1341 // Handle conversion for floating point vectors.
1342 if (LElTy->isRealFloatingType() || RElTy->isRealFloatingType())
1343 return handleFloatVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
1344 LElTy, RElTy, IsCompAssign);
1345
1346 assert(LElTy->isIntegralType(Ctx) && RElTy->isIntegralType(Ctx) &&
1347 "HLSL Vectors can only contain integer or floating point types");
1348 return handleIntegerVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
1349 LElTy, RElTy, IsCompAssign);
1350}
1351
1352void SemaHLSL::emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS,
1353 BinaryOperatorKind Opc) {
1354 assert((Opc == BO_LOr || Opc == BO_LAnd) &&
1355 "Called with non-logical operator");
1356 llvm::SmallVector<char, 256> Buff;
1357 llvm::raw_svector_ostream OS(Buff);
1358 PrintingPolicy PP(SemaRef.getLangOpts());
1359 StringRef NewFnName = Opc == BO_LOr ? "or" : "and";
1360 OS << NewFnName << "(";
1361 LHS->printPretty(OS, Helper: nullptr, Policy: PP);
1362 OS << ", ";
1363 RHS->printPretty(OS, Helper: nullptr, Policy: PP);
1364 OS << ")";
1365 SourceRange FullRange = SourceRange(LHS->getBeginLoc(), RHS->getEndLoc());
1366 SemaRef.Diag(Loc: LHS->getBeginLoc(), DiagID: diag::note_function_suggestion)
1367 << NewFnName << FixItHint::CreateReplacement(RemoveRange: FullRange, Code: OS.str());
1368}
1369
1370std::pair<IdentifierInfo *, bool>
1371SemaHLSL::ActOnStartRootSignatureDecl(StringRef Signature) {
1372 llvm::hash_code Hash = llvm::hash_value(S: Signature);
1373 std::string IdStr = "__hlsl_rootsig_decl_" + std::to_string(val: Hash);
1374 IdentifierInfo *DeclIdent = &(getASTContext().Idents.get(Name: IdStr));
1375
1376 // Check if we have already found a decl of the same name.
1377 LookupResult R(SemaRef, DeclIdent, SourceLocation(),
1378 Sema::LookupOrdinaryName);
1379 bool Found = SemaRef.LookupQualifiedName(R, LookupCtx: SemaRef.CurContext);
1380 return {DeclIdent, Found};
1381}
1382
1383void SemaHLSL::ActOnFinishRootSignatureDecl(
1384 SourceLocation Loc, IdentifierInfo *DeclIdent,
1385 ArrayRef<hlsl::RootSignatureElement> RootElements) {
1386
1387 if (handleRootSignatureElements(Elements: RootElements))
1388 return;
1389
1390 SmallVector<llvm::hlsl::rootsig::RootElement> Elements;
1391 for (auto &RootSigElement : RootElements)
1392 Elements.push_back(Elt: RootSigElement.getElement());
1393
1394 auto *SignatureDecl = HLSLRootSignatureDecl::Create(
1395 C&: SemaRef.getASTContext(), /*DeclContext=*/DC: SemaRef.CurContext, Loc,
1396 ID: DeclIdent, Version: SemaRef.getLangOpts().HLSLRootSigVer, RootElements: Elements);
1397
1398 SignatureDecl->setImplicit();
1399 SemaRef.PushOnScopeChains(D: SignatureDecl, S: SemaRef.getCurScope());
1400}
1401
1402HLSLRootSignatureDecl *
1403SemaHLSL::lookupRootSignatureOverrideDecl(DeclContext *DC) const {
1404 if (RootSigOverrideIdent) {
1405 LookupResult R(SemaRef, RootSigOverrideIdent, SourceLocation(),
1406 Sema::LookupOrdinaryName);
1407 if (SemaRef.LookupQualifiedName(R, LookupCtx: DC))
1408 return dyn_cast<HLSLRootSignatureDecl>(Val: R.getFoundDecl());
1409 }
1410
1411 return nullptr;
1412}
1413
1414namespace {
1415
1416struct PerVisibilityBindingChecker {
1417 SemaHLSL *S;
1418 // We need one builder per `llvm::dxbc::ShaderVisibility` value.
1419 std::array<llvm::hlsl::BindingInfoBuilder, 8> Builders;
1420
1421 struct ElemInfo {
1422 const hlsl::RootSignatureElement *Elem;
1423 llvm::dxbc::ShaderVisibility Vis;
1424 bool Diagnosed;
1425 };
1426 llvm::SmallVector<ElemInfo> ElemInfoMap;
1427
1428 PerVisibilityBindingChecker(SemaHLSL *S) : S(S) {}
1429
1430 void trackBinding(llvm::dxbc::ShaderVisibility Visibility,
1431 llvm::dxil::ResourceClass RC, uint32_t Space,
1432 uint32_t LowerBound, uint32_t UpperBound,
1433 const hlsl::RootSignatureElement *Elem) {
1434 uint32_t BuilderIndex = llvm::to_underlying(E: Visibility);
1435 assert(BuilderIndex < Builders.size() &&
1436 "Not enough builders for visibility type");
1437 Builders[BuilderIndex].trackBinding(RC, Space, LowerBound, UpperBound,
1438 Cookie: static_cast<const void *>(Elem));
1439
1440 static_assert(llvm::to_underlying(E: llvm::dxbc::ShaderVisibility::All) == 0,
1441 "'All' visibility must come first");
1442 if (Visibility == llvm::dxbc::ShaderVisibility::All)
1443 for (size_t I = 1, E = Builders.size(); I < E; ++I)
1444 Builders[I].trackBinding(RC, Space, LowerBound, UpperBound,
1445 Cookie: static_cast<const void *>(Elem));
1446
1447 ElemInfoMap.push_back(Elt: {.Elem: Elem, .Vis: Visibility, .Diagnosed: false});
1448 }
1449
1450 ElemInfo &getInfo(const hlsl::RootSignatureElement *Elem) {
1451 auto It = llvm::lower_bound(
1452 Range&: ElemInfoMap, Value&: Elem,
1453 C: [](const auto &LHS, const auto &RHS) { return LHS.Elem < RHS; });
1454 assert(It->Elem == Elem && "Element not in map");
1455 return *It;
1456 }
1457
1458 bool checkOverlap() {
1459 llvm::sort(C&: ElemInfoMap, Comp: [](const auto &LHS, const auto &RHS) {
1460 return LHS.Elem < RHS.Elem;
1461 });
1462
1463 bool HadOverlap = false;
1464
1465 using llvm::hlsl::BindingInfoBuilder;
1466 auto ReportOverlap = [this,
1467 &HadOverlap](const BindingInfoBuilder &Builder,
1468 const llvm::hlsl::Binding &Reported) {
1469 HadOverlap = true;
1470
1471 const auto *Elem =
1472 static_cast<const hlsl::RootSignatureElement *>(Reported.Cookie);
1473 const llvm::hlsl::Binding &Previous = Builder.findOverlapping(ReportedBinding: Reported);
1474 const auto *PrevElem =
1475 static_cast<const hlsl::RootSignatureElement *>(Previous.Cookie);
1476
1477 ElemInfo &Info = getInfo(Elem);
1478 // We will have already diagnosed this binding if there's overlap in the
1479 // "All" visibility as well as any particular visibility.
1480 if (Info.Diagnosed)
1481 return;
1482 Info.Diagnosed = true;
1483
1484 ElemInfo &PrevInfo = getInfo(Elem: PrevElem);
1485 llvm::dxbc::ShaderVisibility CommonVis =
1486 Info.Vis == llvm::dxbc::ShaderVisibility::All ? PrevInfo.Vis
1487 : Info.Vis;
1488
1489 this->S->Diag(Loc: Elem->getLocation(), DiagID: diag::err_hlsl_resource_range_overlap)
1490 << llvm::to_underlying(E: Reported.RC) << Reported.LowerBound
1491 << Reported.isUnbounded() << Reported.UpperBound
1492 << llvm::to_underlying(E: Previous.RC) << Previous.LowerBound
1493 << Previous.isUnbounded() << Previous.UpperBound << Reported.Space
1494 << CommonVis;
1495
1496 this->S->Diag(Loc: PrevElem->getLocation(),
1497 DiagID: diag::note_hlsl_resource_range_here);
1498 };
1499
1500 for (BindingInfoBuilder &Builder : Builders)
1501 Builder.calculateBindingInfo(ReportOverlap);
1502
1503 return HadOverlap;
1504 }
1505};
1506
1507static CXXMethodDecl *lookupMethod(Sema &S, CXXRecordDecl *RecordDecl,
1508 StringRef Name, SourceLocation Loc) {
1509 DeclarationName DeclName(&S.getASTContext().Idents.get(Name));
1510 LookupResult Result(S, DeclName, Loc, Sema::LookupMemberName);
1511 if (!S.LookupQualifiedName(R&: Result, LookupCtx: static_cast<DeclContext *>(RecordDecl)))
1512 return nullptr;
1513 return cast<CXXMethodDecl>(Val: Result.getFoundDecl());
1514}
1515
1516} // end anonymous namespace
1517
1518bool SemaHLSL::handleRootSignatureElements(
1519 ArrayRef<hlsl::RootSignatureElement> Elements) {
1520 // Define some common error handling functions
1521 bool HadError = false;
1522 auto ReportError = [this, &HadError](SourceLocation Loc, uint32_t LowerBound,
1523 uint32_t UpperBound) {
1524 HadError = true;
1525 this->Diag(Loc, DiagID: diag::err_hlsl_invalid_rootsig_value)
1526 << LowerBound << UpperBound;
1527 };
1528
1529 auto ReportFloatError = [this, &HadError](SourceLocation Loc,
1530 float LowerBound,
1531 float UpperBound) {
1532 HadError = true;
1533 this->Diag(Loc, DiagID: diag::err_hlsl_invalid_rootsig_value)
1534 << llvm::formatv(Fmt: "{0:f}", Vals&: LowerBound).sstr<6>()
1535 << llvm::formatv(Fmt: "{0:f}", Vals&: UpperBound).sstr<6>();
1536 };
1537
1538 auto VerifyRegister = [ReportError](SourceLocation Loc, uint32_t Register) {
1539 if (!llvm::hlsl::rootsig::verifyRegisterValue(RegisterValue: Register))
1540 ReportError(Loc, 0, 0xfffffffe);
1541 };
1542
1543 auto VerifySpace = [ReportError](SourceLocation Loc, uint32_t Space) {
1544 if (!llvm::hlsl::rootsig::verifyRegisterSpace(RegisterSpace: Space))
1545 ReportError(Loc, 0, 0xffffffef);
1546 };
1547
1548 const uint32_t Version =
1549 llvm::to_underlying(E: SemaRef.getLangOpts().HLSLRootSigVer);
1550 const uint32_t VersionEnum = Version - 1;
1551 auto ReportFlagError = [this, &HadError, VersionEnum](SourceLocation Loc) {
1552 HadError = true;
1553 this->Diag(Loc, DiagID: diag::err_hlsl_invalid_rootsig_flag)
1554 << /*version minor*/ VersionEnum;
1555 };
1556
1557 // Iterate through the elements and do basic validations
1558 for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
1559 SourceLocation Loc = RootSigElem.getLocation();
1560 const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
1561 if (const auto *Descriptor =
1562 std::get_if<llvm::hlsl::rootsig::RootDescriptor>(ptr: &Elem)) {
1563 VerifyRegister(Loc, Descriptor->Reg.Number);
1564 VerifySpace(Loc, Descriptor->Space);
1565
1566 if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(Version,
1567 Flags: Descriptor->Flags))
1568 ReportFlagError(Loc);
1569 } else if (const auto *Constants =
1570 std::get_if<llvm::hlsl::rootsig::RootConstants>(ptr: &Elem)) {
1571 VerifyRegister(Loc, Constants->Reg.Number);
1572 VerifySpace(Loc, Constants->Space);
1573 } else if (const auto *Sampler =
1574 std::get_if<llvm::hlsl::rootsig::StaticSampler>(ptr: &Elem)) {
1575 VerifyRegister(Loc, Sampler->Reg.Number);
1576 VerifySpace(Loc, Sampler->Space);
1577
1578 assert(!std::isnan(Sampler->MaxLOD) && !std::isnan(Sampler->MinLOD) &&
1579 "By construction, parseFloatParam can't produce a NaN from a "
1580 "float_literal token");
1581
1582 if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(MaxAnisotropy: Sampler->MaxAnisotropy))
1583 ReportError(Loc, 0, 16);
1584 if (!llvm::hlsl::rootsig::verifyMipLODBias(MipLODBias: Sampler->MipLODBias))
1585 ReportFloatError(Loc, -16.f, 15.99f);
1586 } else if (const auto *Clause =
1587 std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
1588 ptr: &Elem)) {
1589 VerifyRegister(Loc, Clause->Reg.Number);
1590 VerifySpace(Loc, Clause->Space);
1591
1592 if (!llvm::hlsl::rootsig::verifyNumDescriptors(NumDescriptors: Clause->NumDescriptors)) {
1593 // NumDescriptor could techincally be ~0u but that is reserved for
1594 // unbounded, so the diagnostic will not report that as a valid int
1595 // value
1596 ReportError(Loc, 1, 0xfffffffe);
1597 }
1598
1599 if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(Version, Type: Clause->Type,
1600 Flags: Clause->Flags))
1601 ReportFlagError(Loc);
1602 }
1603 }
1604
1605 PerVisibilityBindingChecker BindingChecker(this);
1606 SmallVector<std::pair<const llvm::hlsl::rootsig::DescriptorTableClause *,
1607 const hlsl::RootSignatureElement *>>
1608 UnboundClauses;
1609
1610 for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
1611 const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
1612 if (const auto *Descriptor =
1613 std::get_if<llvm::hlsl::rootsig::RootDescriptor>(ptr: &Elem)) {
1614 uint32_t LowerBound(Descriptor->Reg.Number);
1615 uint32_t UpperBound(LowerBound); // inclusive range
1616
1617 BindingChecker.trackBinding(
1618 Visibility: Descriptor->Visibility,
1619 RC: static_cast<llvm::dxil::ResourceClass>(Descriptor->Type),
1620 Space: Descriptor->Space, LowerBound, UpperBound, Elem: &RootSigElem);
1621 } else if (const auto *Constants =
1622 std::get_if<llvm::hlsl::rootsig::RootConstants>(ptr: &Elem)) {
1623 uint32_t LowerBound(Constants->Reg.Number);
1624 uint32_t UpperBound(LowerBound); // inclusive range
1625
1626 BindingChecker.trackBinding(
1627 Visibility: Constants->Visibility, RC: llvm::dxil::ResourceClass::CBuffer,
1628 Space: Constants->Space, LowerBound, UpperBound, Elem: &RootSigElem);
1629 } else if (const auto *Sampler =
1630 std::get_if<llvm::hlsl::rootsig::StaticSampler>(ptr: &Elem)) {
1631 uint32_t LowerBound(Sampler->Reg.Number);
1632 uint32_t UpperBound(LowerBound); // inclusive range
1633
1634 BindingChecker.trackBinding(
1635 Visibility: Sampler->Visibility, RC: llvm::dxil::ResourceClass::Sampler,
1636 Space: Sampler->Space, LowerBound, UpperBound, Elem: &RootSigElem);
1637 } else if (const auto *Clause =
1638 std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
1639 ptr: &Elem)) {
1640 // We'll process these once we see the table element.
1641 UnboundClauses.emplace_back(Args&: Clause, Args: &RootSigElem);
1642 } else if (const auto *Table =
1643 std::get_if<llvm::hlsl::rootsig::DescriptorTable>(ptr: &Elem)) {
1644 assert(UnboundClauses.size() == Table->NumClauses &&
1645 "Number of unbound elements must match the number of clauses");
1646 bool HasAnySampler = false;
1647 bool HasAnyNonSampler = false;
1648 uint64_t Offset = 0;
1649 bool IsPrevUnbound = false;
1650 for (const auto &[Clause, ClauseElem] : UnboundClauses) {
1651 SourceLocation Loc = ClauseElem->getLocation();
1652 if (Clause->Type == llvm::dxil::ResourceClass::Sampler)
1653 HasAnySampler = true;
1654 else
1655 HasAnyNonSampler = true;
1656
1657 if (HasAnySampler && HasAnyNonSampler)
1658 Diag(Loc, DiagID: diag::err_hlsl_invalid_mixed_resources);
1659
1660 // Relevant error will have already been reported above and needs to be
1661 // fixed before we can conduct further analysis, so shortcut error
1662 // return
1663 if (Clause->NumDescriptors == 0)
1664 return true;
1665
1666 bool IsAppending =
1667 Clause->Offset == llvm::hlsl::rootsig::DescriptorTableOffsetAppend;
1668 if (!IsAppending)
1669 Offset = Clause->Offset;
1670
1671 uint64_t RangeBound = llvm::hlsl::rootsig::computeRangeBound(
1672 Offset, Size: Clause->NumDescriptors);
1673
1674 if (IsPrevUnbound && IsAppending)
1675 Diag(Loc, DiagID: diag::err_hlsl_appending_onto_unbound);
1676 else if (!llvm::hlsl::rootsig::verifyNoOverflowedOffset(Offset: RangeBound))
1677 Diag(Loc, DiagID: diag::err_hlsl_offset_overflow) << Offset << RangeBound;
1678
1679 // Update offset to be 1 past this range's bound
1680 Offset = RangeBound + 1;
1681 IsPrevUnbound = Clause->NumDescriptors ==
1682 llvm::hlsl::rootsig::NumDescriptorsUnbounded;
1683
1684 // Compute the register bounds and track resource binding
1685 uint32_t LowerBound(Clause->Reg.Number);
1686 uint32_t UpperBound = llvm::hlsl::rootsig::computeRangeBound(
1687 Offset: LowerBound, Size: Clause->NumDescriptors);
1688
1689 BindingChecker.trackBinding(
1690 Visibility: Table->Visibility,
1691 RC: static_cast<llvm::dxil::ResourceClass>(Clause->Type), Space: Clause->Space,
1692 LowerBound, UpperBound, Elem: ClauseElem);
1693 }
1694 UnboundClauses.clear();
1695 }
1696 }
1697
1698 return BindingChecker.checkOverlap();
1699}
1700
1701void SemaHLSL::handleRootSignatureAttr(Decl *D, const ParsedAttr &AL) {
1702 if (AL.getNumArgs() != 1) {
1703 Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_wrong_number_arguments) << AL << 1;
1704 return;
1705 }
1706
1707 IdentifierInfo *Ident = AL.getArgAsIdent(Arg: 0)->getIdentifierInfo();
1708 if (auto *RS = D->getAttr<RootSignatureAttr>()) {
1709 if (RS->getSignatureIdent() != Ident) {
1710 Diag(Loc: AL.getLoc(), DiagID: diag::err_disallowed_duplicate_attribute) << RS;
1711 return;
1712 }
1713
1714 Diag(Loc: AL.getLoc(), DiagID: diag::warn_duplicate_attribute_exact) << RS;
1715 return;
1716 }
1717
1718 LookupResult R(SemaRef, Ident, SourceLocation(), Sema::LookupOrdinaryName);
1719 if (SemaRef.LookupQualifiedName(R, LookupCtx: D->getDeclContext()))
1720 if (auto *SignatureDecl =
1721 dyn_cast<HLSLRootSignatureDecl>(Val: R.getFoundDecl())) {
1722 D->addAttr(A: ::new (getASTContext()) RootSignatureAttr(
1723 getASTContext(), AL, Ident, SignatureDecl));
1724 }
1725}
1726
1727void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) {
1728 llvm::VersionTuple SMVersion =
1729 getASTContext().getTargetInfo().getTriple().getOSVersion();
1730 bool IsDXIL = getASTContext().getTargetInfo().getTriple().getArch() ==
1731 llvm::Triple::dxil;
1732
1733 uint32_t ZMax = 1024;
1734 uint32_t ThreadMax = 1024;
1735 if (IsDXIL && SMVersion.getMajor() <= 4) {
1736 ZMax = 1;
1737 ThreadMax = 768;
1738 } else if (IsDXIL && SMVersion.getMajor() == 5) {
1739 ZMax = 64;
1740 ThreadMax = 1024;
1741 }
1742
1743 uint32_t X;
1744 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: X))
1745 return;
1746 if (X > 1024) {
1747 Diag(Loc: AL.getArgAsExpr(Arg: 0)->getExprLoc(),
1748 DiagID: diag::err_hlsl_numthreads_argument_oor)
1749 << 0 << 1024;
1750 return;
1751 }
1752 uint32_t Y;
1753 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Y))
1754 return;
1755 if (Y > 1024) {
1756 Diag(Loc: AL.getArgAsExpr(Arg: 1)->getExprLoc(),
1757 DiagID: diag::err_hlsl_numthreads_argument_oor)
1758 << 1 << 1024;
1759 return;
1760 }
1761 uint32_t Z;
1762 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 2), Val&: Z))
1763 return;
1764 if (Z > ZMax) {
1765 SemaRef.Diag(Loc: AL.getArgAsExpr(Arg: 2)->getExprLoc(),
1766 DiagID: diag::err_hlsl_numthreads_argument_oor)
1767 << 2 << ZMax;
1768 return;
1769 }
1770
1771 if (X * Y * Z > ThreadMax) {
1772 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_numthreads_invalid) << ThreadMax;
1773 return;
1774 }
1775
1776 HLSLNumThreadsAttr *NewAttr = mergeNumThreadsAttr(D, AL, X, Y, Z);
1777 if (NewAttr)
1778 D->addAttr(A: NewAttr);
1779}
1780
1781static bool isValidWaveSizeValue(unsigned Value) {
1782 return llvm::isPowerOf2_32(Value) && Value >= 4 && Value <= 128;
1783}
1784
1785void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) {
1786 // validate that the wavesize argument is a power of 2 between 4 and 128
1787 // inclusive
1788 unsigned SpelledArgsCount = AL.getNumArgs();
1789 if (SpelledArgsCount == 0 || SpelledArgsCount > 3)
1790 return;
1791
1792 uint32_t Min;
1793 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: Min))
1794 return;
1795
1796 uint32_t Max = 0;
1797 if (SpelledArgsCount > 1 &&
1798 !SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Max))
1799 return;
1800
1801 uint32_t Preferred = 0;
1802 if (SpelledArgsCount > 2 &&
1803 !SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 2), Val&: Preferred))
1804 return;
1805
1806 if (SpelledArgsCount > 2) {
1807 if (!isValidWaveSizeValue(Value: Preferred)) {
1808 Diag(Loc: AL.getArgAsExpr(Arg: 2)->getExprLoc(),
1809 DiagID: diag::err_attribute_power_of_two_in_range)
1810 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize
1811 << Preferred;
1812 return;
1813 }
1814 // Preferred not in range.
1815 if (Preferred < Min || Preferred > Max) {
1816 Diag(Loc: AL.getArgAsExpr(Arg: 2)->getExprLoc(),
1817 DiagID: diag::err_attribute_power_of_two_in_range)
1818 << AL << Min << Max << Preferred;
1819 return;
1820 }
1821 } else if (SpelledArgsCount > 1) {
1822 if (!isValidWaveSizeValue(Value: Max)) {
1823 Diag(Loc: AL.getArgAsExpr(Arg: 1)->getExprLoc(),
1824 DiagID: diag::err_attribute_power_of_two_in_range)
1825 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Max;
1826 return;
1827 }
1828 if (Max < Min) {
1829 Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_invalid) << AL << 1;
1830 return;
1831 } else if (Max == Min) {
1832 Diag(Loc: AL.getLoc(), DiagID: diag::warn_attr_min_eq_max) << AL;
1833 }
1834 } else {
1835 if (!isValidWaveSizeValue(Value: Min)) {
1836 Diag(Loc: AL.getArgAsExpr(Arg: 0)->getExprLoc(),
1837 DiagID: diag::err_attribute_power_of_two_in_range)
1838 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Min;
1839 return;
1840 }
1841 }
1842
1843 HLSLWaveSizeAttr *NewAttr =
1844 mergeWaveSizeAttr(D, AL, Min, Max, Preferred, SpelledArgsCount);
1845 if (NewAttr)
1846 D->addAttr(A: NewAttr);
1847}
1848
1849void SemaHLSL::handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL) {
1850 uint32_t ID;
1851 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: ID))
1852 return;
1853 D->addAttr(A: ::new (getASTContext())
1854 HLSLVkExtBuiltinInputAttr(getASTContext(), AL, ID));
1855}
1856
1857void SemaHLSL::handleVkExtBuiltinOutputAttr(Decl *D, const ParsedAttr &AL) {
1858 uint32_t ID;
1859 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: ID))
1860 return;
1861 D->addAttr(A: ::new (getASTContext())
1862 HLSLVkExtBuiltinOutputAttr(getASTContext(), AL, ID));
1863}
1864
1865void SemaHLSL::handleVkPushConstantAttr(Decl *D, const ParsedAttr &AL) {
1866 D->addAttr(A: ::new (getASTContext())
1867 HLSLVkPushConstantAttr(getASTContext(), AL));
1868}
1869
1870void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) {
1871 uint32_t Id;
1872 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: Id))
1873 return;
1874 HLSLVkConstantIdAttr *NewAttr = mergeVkConstantIdAttr(D, AL, Id);
1875 if (NewAttr)
1876 D->addAttr(A: NewAttr);
1877}
1878
1879void SemaHLSL::handleVkBindingAttr(Decl *D, const ParsedAttr &AL) {
1880 uint32_t Binding = 0;
1881 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: Binding))
1882 return;
1883 uint32_t Set = 0;
1884 if (AL.getNumArgs() > 1 &&
1885 !SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Set))
1886 return;
1887
1888 D->addAttr(A: ::new (getASTContext())
1889 HLSLVkBindingAttr(getASTContext(), AL, Binding, Set));
1890}
1891
1892void SemaHLSL::handleVkLocationAttr(Decl *D, const ParsedAttr &AL) {
1893 uint32_t Location;
1894 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: Location))
1895 return;
1896
1897 D->addAttr(A: ::new (getASTContext())
1898 HLSLVkLocationAttr(getASTContext(), AL, Location));
1899}
1900
1901bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) {
1902 const auto *VT = T->getAs<VectorType>();
1903
1904 if (!T->hasUnsignedIntegerRepresentation() ||
1905 (VT && VT->getNumElements() > 3)) {
1906 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type)
1907 << AL << "uint/uint2/uint3";
1908 return false;
1909 }
1910
1911 return true;
1912}
1913
1914bool SemaHLSL::diagnosePositionType(QualType T, const ParsedAttr &AL) {
1915 const auto *VT = T->getAs<VectorType>();
1916 if (!T->hasFloatingRepresentation() || (VT && VT->getNumElements() > 4)) {
1917 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type)
1918 << AL << "float/float1/float2/float3/float4";
1919 return false;
1920 }
1921
1922 return true;
1923}
1924
1925void SemaHLSL::diagnoseSystemSemanticAttr(Decl *D, const ParsedAttr &AL,
1926 std::optional<unsigned> Index) {
1927 std::string SemanticName = AL.getAttrName()->getName().upper();
1928
1929 auto *VD = cast<ValueDecl>(Val: D);
1930 QualType ValueType = VD->getType();
1931 if (auto *FD = dyn_cast<FunctionDecl>(Val: D))
1932 ValueType = FD->getReturnType();
1933
1934 bool IsOutput = false;
1935 if (HLSLParamModifierAttr *MA = D->getAttr<HLSLParamModifierAttr>()) {
1936 if (MA->isOut()) {
1937 IsOutput = true;
1938 ValueType = cast<ReferenceType>(Val&: ValueType)->getPointeeType();
1939 }
1940 }
1941
1942 if (SemanticName == "SV_DISPATCHTHREADID") {
1943 diagnoseInputIDType(T: ValueType, AL);
1944 if (IsOutput)
1945 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_output_not_supported) << AL;
1946 if (Index.has_value())
1947 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_indexing_not_supported) << AL;
1948 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1949 return;
1950 }
1951
1952 if (SemanticName == "SV_GROUPINDEX") {
1953 if (IsOutput)
1954 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_output_not_supported) << AL;
1955 if (Index.has_value())
1956 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_indexing_not_supported) << AL;
1957 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1958 return;
1959 }
1960
1961 if (SemanticName == "SV_GROUPTHREADID") {
1962 diagnoseInputIDType(T: ValueType, AL);
1963 if (IsOutput)
1964 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_output_not_supported) << AL;
1965 if (Index.has_value())
1966 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_indexing_not_supported) << AL;
1967 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1968 return;
1969 }
1970
1971 if (SemanticName == "SV_GROUPID") {
1972 diagnoseInputIDType(T: ValueType, AL);
1973 if (IsOutput)
1974 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_output_not_supported) << AL;
1975 if (Index.has_value())
1976 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_indexing_not_supported) << AL;
1977 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1978 return;
1979 }
1980
1981 if (SemanticName == "SV_POSITION") {
1982 const auto *VT = ValueType->getAs<VectorType>();
1983 if (!ValueType->hasFloatingRepresentation() ||
1984 (VT && VT->getNumElements() > 4))
1985 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type)
1986 << AL << "float/float1/float2/float3/float4";
1987 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1988 return;
1989 }
1990
1991 if (SemanticName == "SV_VERTEXID") {
1992 uint64_t SizeInBits = SemaRef.Context.getTypeSize(T: ValueType);
1993 if (!ValueType->isUnsignedIntegerType() || SizeInBits != 32)
1994 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type) << AL << "uint";
1995 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1996 return;
1997 }
1998
1999 if (SemanticName == "SV_TARGET") {
2000 const auto *VT = ValueType->getAs<VectorType>();
2001 if (!ValueType->hasFloatingRepresentation() ||
2002 (VT && VT->getNumElements() > 4))
2003 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type)
2004 << AL << "float/float1/float2/float3/float4";
2005 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
2006 return;
2007 }
2008
2009 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_unknown_semantic) << AL;
2010}
2011
2012void SemaHLSL::handleSemanticAttr(Decl *D, const ParsedAttr &AL) {
2013 uint32_t IndexValue(0), ExplicitIndex(0);
2014 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: IndexValue) ||
2015 !SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: ExplicitIndex)) {
2016 assert(0 && "HLSLUnparsedSemantic is expected to have 2 int arguments.");
2017 }
2018 assert(IndexValue > 0 ? ExplicitIndex : true);
2019 std::optional<unsigned> Index =
2020 ExplicitIndex ? std::optional<unsigned>(IndexValue) : std::nullopt;
2021
2022 if (AL.getAttrName()->getName().starts_with_insensitive(Prefix: "SV_"))
2023 diagnoseSystemSemanticAttr(D, AL, Index);
2024 else
2025 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
2026}
2027
2028void SemaHLSL::handlePackOffsetAttr(Decl *D, const ParsedAttr &AL) {
2029 if (!isa<VarDecl>(Val: D) || !isa<HLSLBufferDecl>(Val: D->getDeclContext())) {
2030 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_ast_node)
2031 << AL << "shader constant in a constant buffer";
2032 return;
2033 }
2034
2035 uint32_t SubComponent;
2036 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: SubComponent))
2037 return;
2038 uint32_t Component;
2039 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Component))
2040 return;
2041
2042 QualType T = cast<VarDecl>(Val: D)->getType().getCanonicalType();
2043 // Check if T is an array or struct type.
2044 // TODO: mark matrix type as aggregate type.
2045 bool IsAggregateTy = (T->isArrayType() || T->isStructureType());
2046
2047 // Check Component is valid for T.
2048 if (Component) {
2049 unsigned Size = getASTContext().getTypeSize(T);
2050 if (IsAggregateTy) {
2051 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_invalid_register_or_packoffset);
2052 return;
2053 } else {
2054 // Make sure Component + sizeof(T) <= 4.
2055 if ((Component * 32 + Size) > 128) {
2056 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_packoffset_cross_reg_boundary);
2057 return;
2058 }
2059 QualType EltTy = T;
2060 if (const auto *VT = T->getAs<VectorType>())
2061 EltTy = VT->getElementType();
2062 unsigned Align = getASTContext().getTypeAlign(T: EltTy);
2063 if (Align > 32 && Component == 1) {
2064 // NOTE: Component 3 will hit err_hlsl_packoffset_cross_reg_boundary.
2065 // So we only need to check Component 1 here.
2066 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_packoffset_alignment_mismatch)
2067 << Align << EltTy;
2068 return;
2069 }
2070 }
2071 }
2072
2073 D->addAttr(A: ::new (getASTContext()) HLSLPackOffsetAttr(
2074 getASTContext(), AL, SubComponent, Component));
2075}
2076
2077void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) {
2078 StringRef Str;
2079 SourceLocation ArgLoc;
2080 if (!SemaRef.checkStringLiteralArgumentAttr(Attr: AL, ArgNum: 0, Str, ArgLocation: &ArgLoc))
2081 return;
2082
2083 llvm::Triple::EnvironmentType ShaderType;
2084 if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Val: Str, Out&: ShaderType)) {
2085 Diag(Loc: AL.getLoc(), DiagID: diag::warn_attribute_type_not_supported)
2086 << AL << Str << ArgLoc;
2087 return;
2088 }
2089
2090 // FIXME: check function match the shader stage.
2091
2092 HLSLShaderAttr *NewAttr = mergeShaderAttr(D, AL, ShaderType);
2093 if (NewAttr)
2094 D->addAttr(A: NewAttr);
2095}
2096
2097bool clang::CreateHLSLAttributedResourceType(
2098 Sema &S, QualType Wrapped, ArrayRef<const Attr *> AttrList,
2099 QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo) {
2100 assert(AttrList.size() && "expected list of resource attributes");
2101
2102 QualType ContainedTy = QualType();
2103 TypeSourceInfo *ContainedTyInfo = nullptr;
2104 SourceLocation LocBegin = AttrList[0]->getRange().getBegin();
2105 SourceLocation LocEnd = AttrList[0]->getRange().getEnd();
2106
2107 HLSLAttributedResourceType::Attributes ResAttrs;
2108
2109 bool HasResourceClass = false;
2110 bool HasResourceDimension = false;
2111 for (const Attr *A : AttrList) {
2112 if (!A)
2113 continue;
2114 LocEnd = A->getRange().getEnd();
2115 switch (A->getKind()) {
2116 case attr::HLSLResourceClass: {
2117 ResourceClass RC = cast<HLSLResourceClassAttr>(Val: A)->getResourceClass();
2118 if (HasResourceClass) {
2119 S.Diag(Loc: A->getLocation(), DiagID: ResAttrs.ResourceClass == RC
2120 ? diag::warn_duplicate_attribute_exact
2121 : diag::warn_duplicate_attribute)
2122 << A;
2123 return false;
2124 }
2125 ResAttrs.ResourceClass = RC;
2126 HasResourceClass = true;
2127 break;
2128 }
2129 case attr::HLSLResourceDimension: {
2130 llvm::dxil::ResourceDimension RD =
2131 cast<HLSLResourceDimensionAttr>(Val: A)->getDimension();
2132 if (HasResourceDimension) {
2133 S.Diag(Loc: A->getLocation(), DiagID: ResAttrs.ResourceDimension == RD
2134 ? diag::warn_duplicate_attribute_exact
2135 : diag::warn_duplicate_attribute)
2136 << A;
2137 return false;
2138 }
2139 ResAttrs.ResourceDimension = RD;
2140 HasResourceDimension = true;
2141 break;
2142 }
2143 case attr::HLSLROV:
2144 if (ResAttrs.IsROV) {
2145 S.Diag(Loc: A->getLocation(), DiagID: diag::warn_duplicate_attribute_exact) << A;
2146 return false;
2147 }
2148 ResAttrs.IsROV = true;
2149 break;
2150 case attr::HLSLRawBuffer:
2151 if (ResAttrs.RawBuffer) {
2152 S.Diag(Loc: A->getLocation(), DiagID: diag::warn_duplicate_attribute_exact) << A;
2153 return false;
2154 }
2155 ResAttrs.RawBuffer = true;
2156 break;
2157 case attr::HLSLIsArray:
2158 if (ResAttrs.IsArray) {
2159 S.Diag(Loc: A->getLocation(), DiagID: diag::warn_duplicate_attribute_exact) << A;
2160 return false;
2161 }
2162 ResAttrs.IsArray = true;
2163 break;
2164 case attr::HLSLIsCounter:
2165 if (ResAttrs.IsCounter) {
2166 S.Diag(Loc: A->getLocation(), DiagID: diag::warn_duplicate_attribute_exact) << A;
2167 return false;
2168 }
2169 ResAttrs.IsCounter = true;
2170 break;
2171 case attr::HLSLContainedType: {
2172 const HLSLContainedTypeAttr *CTAttr = cast<HLSLContainedTypeAttr>(Val: A);
2173 QualType Ty = CTAttr->getType();
2174 if (!ContainedTy.isNull()) {
2175 S.Diag(Loc: A->getLocation(), DiagID: ContainedTy == Ty
2176 ? diag::warn_duplicate_attribute_exact
2177 : diag::warn_duplicate_attribute)
2178 << A;
2179 return false;
2180 }
2181 ContainedTy = Ty;
2182 ContainedTyInfo = CTAttr->getTypeLoc();
2183 break;
2184 }
2185 default:
2186 llvm_unreachable("unhandled resource attribute type");
2187 }
2188 }
2189
2190 if (!HasResourceClass) {
2191 S.Diag(Loc: AttrList.back()->getRange().getEnd(),
2192 DiagID: diag::err_hlsl_missing_resource_class);
2193 return false;
2194 }
2195
2196 ResType = S.getASTContext().getHLSLAttributedResourceType(
2197 Wrapped, Contained: ContainedTy, Attrs: ResAttrs);
2198
2199 if (LocInfo && ContainedTyInfo) {
2200 LocInfo->Range = SourceRange(LocBegin, LocEnd);
2201 LocInfo->ContainedTyInfo = ContainedTyInfo;
2202 }
2203 return true;
2204}
2205
2206// Validates and creates an HLSL attribute that is applied as type attribute on
2207// HLSL resource. The attributes are collected in HLSLResourcesTypeAttrs and at
2208// the end of the declaration they are applied to the declaration type by
2209// wrapping it in HLSLAttributedResourceType.
2210bool SemaHLSL::handleResourceTypeAttr(QualType T, const ParsedAttr &AL) {
2211 // only allow resource type attributes on intangible types
2212 if (!T->isHLSLResourceType()) {
2213 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attribute_needs_intangible_type)
2214 << AL << getASTContext().HLSLResourceTy;
2215 return false;
2216 }
2217
2218 // validate number of arguments
2219 if (!AL.checkExactlyNumArgs(S&: SemaRef, Num: AL.getMinArgs()))
2220 return false;
2221
2222 Attr *A = nullptr;
2223
2224 AttributeCommonInfo ACI(
2225 AL.getLoc(), AttributeScopeInfo(AL.getScopeName(), AL.getScopeLoc()),
2226 AttributeCommonInfo::NoSemaHandlerAttribute,
2227 {
2228 AttributeCommonInfo::AS_CXX11, 0, false /*IsAlignas*/,
2229 false /*IsRegularKeywordAttribute*/
2230 });
2231
2232 switch (AL.getKind()) {
2233 case ParsedAttr::AT_HLSLResourceClass: {
2234 if (!AL.isArgIdent(Arg: 0)) {
2235 Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_type)
2236 << AL << AANT_ArgumentIdentifier;
2237 return false;
2238 }
2239
2240 IdentifierLoc *Loc = AL.getArgAsIdent(Arg: 0);
2241 StringRef Identifier = Loc->getIdentifierInfo()->getName();
2242 SourceLocation ArgLoc = Loc->getLoc();
2243
2244 // Validate resource class value
2245 ResourceClass RC;
2246 if (!HLSLResourceClassAttr::ConvertStrToResourceClass(Val: Identifier, Out&: RC)) {
2247 Diag(Loc: ArgLoc, DiagID: diag::warn_attribute_type_not_supported)
2248 << "ResourceClass" << Identifier;
2249 return false;
2250 }
2251 A = HLSLResourceClassAttr::Create(Ctx&: getASTContext(), ResourceClass: RC, CommonInfo: ACI);
2252 break;
2253 }
2254
2255 case ParsedAttr::AT_HLSLResourceDimension: {
2256 StringRef Identifier;
2257 SourceLocation ArgLoc;
2258 if (!SemaRef.checkStringLiteralArgumentAttr(Attr: AL, ArgNum: 0, Str&: Identifier, ArgLocation: &ArgLoc))
2259 return false;
2260
2261 // Validate resource dimension value
2262 llvm::dxil::ResourceDimension RD;
2263 if (!HLSLResourceDimensionAttr::ConvertStrToResourceDimension(Val: Identifier,
2264 Out&: RD)) {
2265 Diag(Loc: ArgLoc, DiagID: diag::warn_attribute_type_not_supported)
2266 << "ResourceDimension" << Identifier;
2267 return false;
2268 }
2269 A = HLSLResourceDimensionAttr::Create(Ctx&: getASTContext(), Dimension: RD, CommonInfo: ACI);
2270 break;
2271 }
2272
2273 case ParsedAttr::AT_HLSLROV:
2274 A = HLSLROVAttr::Create(Ctx&: getASTContext(), CommonInfo: ACI);
2275 break;
2276
2277 case ParsedAttr::AT_HLSLRawBuffer:
2278 A = HLSLRawBufferAttr::Create(Ctx&: getASTContext(), CommonInfo: ACI);
2279 break;
2280
2281 case ParsedAttr::AT_HLSLIsCounter:
2282 A = HLSLIsCounterAttr::Create(Ctx&: getASTContext(), CommonInfo: ACI);
2283 break;
2284
2285 case ParsedAttr::AT_HLSLIsArray:
2286 A = HLSLIsArrayAttr::Create(Ctx&: getASTContext(), CommonInfo: ACI);
2287 break;
2288
2289 case ParsedAttr::AT_HLSLContainedType: {
2290 if (AL.getNumArgs() != 1 && !AL.hasParsedType()) {
2291 Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_wrong_number_arguments) << AL << 1;
2292 return false;
2293 }
2294
2295 TypeSourceInfo *TSI = nullptr;
2296 QualType QT = SemaRef.GetTypeFromParser(Ty: AL.getTypeArg(), TInfo: &TSI);
2297 assert(TSI && "no type source info for attribute argument");
2298 if (SemaRef.RequireCompleteType(Loc: TSI->getTypeLoc().getBeginLoc(), T: QT,
2299 DiagID: diag::err_incomplete_type))
2300 return false;
2301 A = HLSLContainedTypeAttr::Create(Ctx&: getASTContext(), Type: TSI, CommonInfo: ACI);
2302 break;
2303 }
2304
2305 default:
2306 llvm_unreachable("unhandled HLSL attribute");
2307 }
2308
2309 HLSLResourcesTypeAttrs.emplace_back(Args&: A);
2310 return true;
2311}
2312
2313// Combines all resource type attributes and creates HLSLAttributedResourceType.
2314QualType SemaHLSL::ProcessResourceTypeAttributes(QualType CurrentType) {
2315 if (!HLSLResourcesTypeAttrs.size())
2316 return CurrentType;
2317
2318 QualType QT = CurrentType;
2319 HLSLAttributedResourceLocInfo LocInfo;
2320 if (CreateHLSLAttributedResourceType(S&: SemaRef, Wrapped: CurrentType,
2321 AttrList: HLSLResourcesTypeAttrs, ResType&: QT, LocInfo: &LocInfo)) {
2322 const HLSLAttributedResourceType *RT =
2323 cast<HLSLAttributedResourceType>(Val: QT.getTypePtr());
2324
2325 // Temporarily store TypeLoc information for the new type.
2326 // It will be transferred to HLSLAttributesResourceTypeLoc
2327 // shortly after the type is created by TypeSpecLocFiller which
2328 // will call the TakeLocForHLSLAttribute method below.
2329 LocsForHLSLAttributedResources.insert(KV: std::pair(RT, LocInfo));
2330 }
2331 HLSLResourcesTypeAttrs.clear();
2332 return QT;
2333}
2334
2335// Returns source location for the HLSLAttributedResourceType
2336HLSLAttributedResourceLocInfo
2337SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
2338 HLSLAttributedResourceLocInfo LocInfo = {};
2339 auto I = LocsForHLSLAttributedResources.find(Val: RT);
2340 if (I != LocsForHLSLAttributedResources.end()) {
2341 LocInfo = I->second;
2342 LocsForHLSLAttributedResources.erase(I);
2343 return LocInfo;
2344 }
2345 LocInfo.Range = SourceRange();
2346 return LocInfo;
2347}
2348
2349// Walks though the global variable declaration, collects all resource binding
2350// requirements and adds them to Bindings
2351void SemaHLSL::collectResourceBindingsOnUserRecordDecl(const VarDecl *VD,
2352 const RecordType *RT) {
2353 const RecordDecl *RD = RT->getDecl()->getDefinitionOrSelf();
2354 for (FieldDecl *FD : RD->fields()) {
2355 const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
2356
2357 // Unwrap arrays
2358 // FIXME: Calculate array size while unwrapping
2359 assert(!Ty->isIncompleteArrayType() &&
2360 "incomplete arrays inside user defined types are not supported");
2361 while (Ty->isConstantArrayType()) {
2362 const ConstantArrayType *CAT = cast<ConstantArrayType>(Val: Ty);
2363 Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
2364 }
2365
2366 if (!Ty->isRecordType())
2367 continue;
2368
2369 if (const HLSLAttributedResourceType *AttrResType =
2370 HLSLAttributedResourceType::findHandleTypeOnResource(RT: Ty)) {
2371 // Add a new DeclBindingInfo to Bindings if it does not already exist
2372 ResourceClass RC = AttrResType->getAttrs().ResourceClass;
2373 DeclBindingInfo *DBI = Bindings.getDeclBindingInfo(VD, ResClass: RC);
2374 if (!DBI)
2375 Bindings.addDeclBindingInfo(VD, ResClass: RC);
2376 } else if (const RecordType *RT = dyn_cast<RecordType>(Val: Ty)) {
2377 // Recursively scan embedded struct or class; it would be nice to do this
2378 // without recursion, but tricky to correctly calculate the size of the
2379 // binding, which is something we are probably going to need to do later
2380 // on. Hopefully nesting of structs in structs too many levels is
2381 // unlikely.
2382 collectResourceBindingsOnUserRecordDecl(VD, RT);
2383 }
2384 }
2385}
2386
2387// Diagnose localized register binding errors for a single binding; does not
2388// diagnose resource binding on user record types, that will be done later
2389// in processResourceBindingOnDecl based on the information collected in
2390// collectResourceBindingsOnVarDecl.
2391// Returns false if the register binding is not valid.
2392static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
2393 Decl *D, RegisterType RegType,
2394 bool SpecifiedSpace) {
2395 int RegTypeNum = static_cast<int>(RegType);
2396
2397 // check if the decl type is groupshared
2398 if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
2399 S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2400 return false;
2401 }
2402
2403 // Cbuffers and Tbuffers are HLSLBufferDecl types
2404 if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(Val: D)) {
2405 ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer
2406 : ResourceClass::SRV;
2407 if (RegType == getRegisterType(RC))
2408 return true;
2409
2410 S.Diag(Loc: D->getLocation(), DiagID: diag::err_hlsl_binding_type_mismatch)
2411 << RegTypeNum;
2412 return false;
2413 }
2414
2415 // Samplers, UAVs, and SRVs are VarDecl types
2416 assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
2417 VarDecl *VD = cast<VarDecl>(Val: D);
2418
2419 // Resource
2420 if (const HLSLAttributedResourceType *AttrResType =
2421 HLSLAttributedResourceType::findHandleTypeOnResource(
2422 RT: VD->getType().getTypePtr())) {
2423 if (RegType == getRegisterType(ResTy: AttrResType))
2424 return true;
2425
2426 S.Diag(Loc: D->getLocation(), DiagID: diag::err_hlsl_binding_type_mismatch)
2427 << RegTypeNum;
2428 return false;
2429 }
2430
2431 const clang::Type *Ty = VD->getType().getTypePtr();
2432 while (Ty->isArrayType())
2433 Ty = Ty->getArrayElementTypeNoTypeQual();
2434
2435 // Basic types
2436 if (Ty->isArithmeticType() || Ty->isVectorType()) {
2437 bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(Val: D->getDeclContext());
2438 if (SpecifiedSpace && !DeclaredInCOrTBuffer)
2439 S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_space_on_global_constant);
2440
2441 if (!DeclaredInCOrTBuffer && (Ty->isIntegralType(Ctx: S.getASTContext()) ||
2442 Ty->isFloatingType() || Ty->isVectorType())) {
2443 // Register annotation on default constant buffer declaration ($Globals)
2444 if (RegType == RegisterType::CBuffer)
2445 S.Diag(Loc: ArgLoc, DiagID: diag::warn_hlsl_deprecated_register_type_b);
2446 else if (RegType != RegisterType::C)
2447 S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2448 else
2449 return true;
2450 } else {
2451 if (RegType == RegisterType::C)
2452 S.Diag(Loc: ArgLoc, DiagID: diag::warn_hlsl_register_type_c_packoffset);
2453 else
2454 S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2455 }
2456 return false;
2457 }
2458 if (Ty->isRecordType())
2459 // RecordTypes will be diagnosed in processResourceBindingOnDecl
2460 // that is called from ActOnVariableDeclarator
2461 return true;
2462
2463 // Anything else is an error
2464 S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2465 return false;
2466}
2467
2468static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
2469 RegisterType regType) {
2470 // make sure that there are no two register annotations
2471 // applied to the decl with the same register type
2472 bool RegisterTypesDetected[5] = {false};
2473 RegisterTypesDetected[static_cast<int>(regType)] = true;
2474
2475 for (auto it = TheDecl->attr_begin(); it != TheDecl->attr_end(); ++it) {
2476 if (HLSLResourceBindingAttr *attr =
2477 dyn_cast<HLSLResourceBindingAttr>(Val: *it)) {
2478
2479 RegisterType otherRegType = attr->getRegisterType();
2480 if (RegisterTypesDetected[static_cast<int>(otherRegType)]) {
2481 int otherRegTypeNum = static_cast<int>(otherRegType);
2482 S.Diag(Loc: TheDecl->getLocation(),
2483 DiagID: diag::err_hlsl_duplicate_register_annotation)
2484 << otherRegTypeNum;
2485 return false;
2486 }
2487 RegisterTypesDetected[static_cast<int>(otherRegType)] = true;
2488 }
2489 }
2490 return true;
2491}
2492
2493static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
2494 Decl *D, RegisterType RegType,
2495 bool SpecifiedSpace) {
2496
2497 // exactly one of these two types should be set
2498 assert(((isa<VarDecl>(D) && !isa<HLSLBufferDecl>(D)) ||
2499 (!isa<VarDecl>(D) && isa<HLSLBufferDecl>(D))) &&
2500 "expecting VarDecl or HLSLBufferDecl");
2501
2502 // check if the declaration contains resource matching the register type
2503 if (!DiagnoseLocalRegisterBinding(S, ArgLoc, D, RegType, SpecifiedSpace))
2504 return false;
2505
2506 // next, if multiple register annotations exist, check that none conflict.
2507 return ValidateMultipleRegisterAnnotations(S, TheDecl: D, regType: RegType);
2508}
2509
2510// return false if the slot count exceeds the limit, true otherwise
2511static bool AccumulateHLSLResourceSlots(QualType Ty, uint64_t &StartSlot,
2512 const uint64_t &Limit,
2513 const ResourceClass ResClass,
2514 ASTContext &Ctx,
2515 uint64_t ArrayCount = 1) {
2516 Ty = Ty.getCanonicalType();
2517 const Type *T = Ty.getTypePtr();
2518
2519 // Early exit if already overflowed
2520 if (StartSlot > Limit)
2521 return false;
2522
2523 // Case 1: array type
2524 if (const auto *AT = dyn_cast<ArrayType>(Val: T)) {
2525 uint64_t Count = 1;
2526
2527 if (const auto *CAT = dyn_cast<ConstantArrayType>(Val: AT))
2528 Count = CAT->getSize().getZExtValue();
2529
2530 QualType ElemTy = AT->getElementType();
2531 return AccumulateHLSLResourceSlots(Ty: ElemTy, StartSlot, Limit, ResClass, Ctx,
2532 ArrayCount: ArrayCount * Count);
2533 }
2534
2535 // Case 2: resource leaf
2536 if (auto ResTy = dyn_cast<HLSLAttributedResourceType>(Val: T)) {
2537 // First ensure this resource counts towards the corresponding
2538 // register type limit.
2539 if (ResTy->getAttrs().ResourceClass != ResClass)
2540 return true;
2541
2542 // Validate highest slot used
2543 uint64_t EndSlot = StartSlot + ArrayCount - 1;
2544 if (EndSlot > Limit)
2545 return false;
2546
2547 // Advance SlotCount past the consumed range
2548 StartSlot = EndSlot + 1;
2549 return true;
2550 }
2551
2552 // Case 3: struct / record
2553 if (const auto *RT = dyn_cast<RecordType>(Val: T)) {
2554 const RecordDecl *RD = RT->getDecl();
2555
2556 if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(Val: RD)) {
2557 for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
2558 if (!AccumulateHLSLResourceSlots(Ty: Base.getType(), StartSlot, Limit,
2559 ResClass, Ctx, ArrayCount))
2560 return false;
2561 }
2562 }
2563
2564 for (const FieldDecl *Field : RD->fields()) {
2565 if (!AccumulateHLSLResourceSlots(Ty: Field->getType(), StartSlot, Limit,
2566 ResClass, Ctx, ArrayCount))
2567 return false;
2568 }
2569
2570 return true;
2571 }
2572
2573 // Case 4: everything else
2574 return true;
2575}
2576
2577// return true if there is something invalid, false otherwise
2578static bool ValidateRegisterNumber(uint64_t SlotNum, Decl *TheDecl,
2579 ASTContext &Ctx, RegisterType RegTy) {
2580 const uint64_t Limit = UINT32_MAX;
2581 if (SlotNum > Limit)
2582 return true;
2583
2584 // after verifying the number doesn't exceed uint32max, we don't need
2585 // to look further into c or i register types
2586 if (RegTy == RegisterType::C || RegTy == RegisterType::I)
2587 return false;
2588
2589 if (VarDecl *VD = dyn_cast<VarDecl>(Val: TheDecl)) {
2590 uint64_t BaseSlot = SlotNum;
2591
2592 if (!AccumulateHLSLResourceSlots(Ty: VD->getType(), StartSlot&: SlotNum, Limit,
2593 ResClass: getResourceClass(RT: RegTy), Ctx))
2594 return true;
2595
2596 // After AccumulateHLSLResourceSlots runs, SlotNum is now
2597 // the first free slot; last used was SlotNum - 1
2598 return (BaseSlot > Limit);
2599 }
2600 // handle the cbuffer/tbuffer case
2601 if (isa<HLSLBufferDecl>(Val: TheDecl))
2602 // resources cannot be put within a cbuffer, so no need
2603 // to analyze the structure since the register number
2604 // won't be pushed any higher.
2605 return (SlotNum > Limit);
2606
2607 // we don't expect any other decl type, so fail
2608 llvm_unreachable("unexpected decl type");
2609}
2610
2611void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
2612 if (VarDecl *VD = dyn_cast<VarDecl>(Val: TheDecl)) {
2613 QualType Ty = VD->getType();
2614 if (const auto *IAT = dyn_cast<IncompleteArrayType>(Val&: Ty))
2615 Ty = IAT->getElementType();
2616 if (SemaRef.RequireCompleteType(Loc: TheDecl->getBeginLoc(), T: Ty,
2617 DiagID: diag::err_incomplete_type))
2618 return;
2619 }
2620
2621 StringRef Slot = "";
2622 StringRef Space = "";
2623 SourceLocation SlotLoc, SpaceLoc;
2624
2625 if (!AL.isArgIdent(Arg: 0)) {
2626 Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_type)
2627 << AL << AANT_ArgumentIdentifier;
2628 return;
2629 }
2630 IdentifierLoc *Loc = AL.getArgAsIdent(Arg: 0);
2631
2632 if (AL.getNumArgs() == 2) {
2633 Slot = Loc->getIdentifierInfo()->getName();
2634 SlotLoc = Loc->getLoc();
2635 if (!AL.isArgIdent(Arg: 1)) {
2636 Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_type)
2637 << AL << AANT_ArgumentIdentifier;
2638 return;
2639 }
2640 Loc = AL.getArgAsIdent(Arg: 1);
2641 Space = Loc->getIdentifierInfo()->getName();
2642 SpaceLoc = Loc->getLoc();
2643 } else {
2644 StringRef Str = Loc->getIdentifierInfo()->getName();
2645 if (Str.starts_with(Prefix: "space")) {
2646 Space = Str;
2647 SpaceLoc = Loc->getLoc();
2648 } else {
2649 Slot = Str;
2650 SlotLoc = Loc->getLoc();
2651 Space = "space0";
2652 }
2653 }
2654
2655 RegisterType RegType = RegisterType::SRV;
2656 std::optional<unsigned> SlotNum;
2657 unsigned SpaceNum = 0;
2658
2659 // Validate slot
2660 if (!Slot.empty()) {
2661 if (!convertToRegisterType(Slot, RT: &RegType)) {
2662 Diag(Loc: SlotLoc, DiagID: diag::err_hlsl_binding_type_invalid) << Slot.substr(Start: 0, N: 1);
2663 return;
2664 }
2665 if (RegType == RegisterType::I) {
2666 Diag(Loc: SlotLoc, DiagID: diag::warn_hlsl_deprecated_register_type_i);
2667 return;
2668 }
2669 const StringRef SlotNumStr = Slot.substr(Start: 1);
2670
2671 uint64_t N;
2672
2673 // validate that the slot number is a non-empty number
2674 if (SlotNumStr.getAsInteger(Radix: 10, Result&: N)) {
2675 Diag(Loc: SlotLoc, DiagID: diag::err_hlsl_unsupported_register_number);
2676 return;
2677 }
2678
2679 // Validate register number. It should not exceed UINT32_MAX,
2680 // including if the resource type is an array that starts
2681 // before UINT32_MAX, but ends afterwards.
2682 if (ValidateRegisterNumber(SlotNum: N, TheDecl, Ctx&: getASTContext(), RegTy: RegType)) {
2683 Diag(Loc: SlotLoc, DiagID: diag::err_hlsl_register_number_too_large);
2684 return;
2685 }
2686
2687 // the slot number has been validated and does not exceed UINT32_MAX
2688 SlotNum = (unsigned)N;
2689 }
2690
2691 // Validate space
2692 if (!Space.starts_with(Prefix: "space")) {
2693 Diag(Loc: SpaceLoc, DiagID: diag::err_hlsl_expected_space) << Space;
2694 return;
2695 }
2696 StringRef SpaceNumStr = Space.substr(Start: 5);
2697 if (SpaceNumStr.getAsInteger(Radix: 10, Result&: SpaceNum)) {
2698 Diag(Loc: SpaceLoc, DiagID: diag::err_hlsl_expected_space) << Space;
2699 return;
2700 }
2701
2702 // If we have slot, diagnose it is the right register type for the decl
2703 if (SlotNum.has_value())
2704 if (!DiagnoseHLSLRegisterAttribute(S&: SemaRef, ArgLoc&: SlotLoc, D: TheDecl, RegType,
2705 SpecifiedSpace: !SpaceLoc.isInvalid()))
2706 return;
2707
2708 HLSLResourceBindingAttr *NewAttr =
2709 HLSLResourceBindingAttr::Create(Ctx&: getASTContext(), Slot, Space, CommonInfo: AL);
2710 if (NewAttr) {
2711 NewAttr->setBinding(RT: RegType, SlotNum, SpaceNum);
2712 TheDecl->addAttr(A: NewAttr);
2713 }
2714}
2715
2716void SemaHLSL::handleParamModifierAttr(Decl *D, const ParsedAttr &AL) {
2717 HLSLParamModifierAttr *NewAttr = mergeParamModifierAttr(
2718 D, AL,
2719 Spelling: static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling()));
2720 if (NewAttr)
2721 D->addAttr(A: NewAttr);
2722}
2723
2724static bool isMatrixOrArrayOfMatrix(const ASTContext &Ctx, QualType QT) {
2725 const Type *Ty = QT->getUnqualifiedDesugaredType();
2726 while (isa<ArrayType>(Val: Ty))
2727 Ty = Ty->getArrayElementTypeNoTypeQual();
2728 return Ty->isDependentType() || Ty->isConstantMatrixType();
2729}
2730
2731/// Walks the existing AttributedType sugar of \p T looking for a previously
2732/// applied HLSLRowMajor/HLSLColumnMajor marker. If one is found, populates
2733/// \p ExistingKind with its attr::Kind and returns true.
2734static bool findExistingMatrixLayoutMarker(QualType T,
2735 attr::Kind &ExistingKind) {
2736 QualType Cur = T;
2737 while (const auto *AT = Cur->getAs<AttributedType>()) {
2738 attr::Kind K = AT->getAttrKind();
2739 if (K == attr::HLSLRowMajor || K == attr::HLSLColumnMajor) {
2740 ExistingKind = K;
2741 return true;
2742 }
2743 Cur = AT->getModifiedType();
2744 }
2745 return false;
2746}
2747
2748Attr *SemaHLSL::buildMatrixLayoutTypeAttr(QualType T, const ParsedAttr &AL) {
2749 if (T.isNull())
2750 return nullptr;
2751
2752 ASTContext &Ctx = getASTContext();
2753 attr::Kind AttrK = AL.getKind() == ParsedAttr::AT_HLSLRowMajor
2754 ? attr::HLSLRowMajor
2755 : attr::HLSLColumnMajor;
2756
2757 // For non-dependent types, the operand must be a matrix (or array of
2758 // matrices).
2759 if (!T->isDependentType() && !isMatrixOrArrayOfMatrix(Ctx, QT: T)) {
2760 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_matrix_layout_non_matrix)
2761 << AL.getAttrName();
2762 AL.setInvalid();
2763 return nullptr;
2764 }
2765
2766 // Conflict / duplicate detection by walking existing sugar.
2767 attr::Kind ExistingKind;
2768 if (findExistingMatrixLayoutMarker(T, ExistingKind)) {
2769 if (ExistingKind == AttrK) {
2770 Diag(Loc: AL.getLoc(), DiagID: diag::warn_duplicate_attribute_exact)
2771 << AL.getAttrName();
2772 Diag(Loc: AL.getLoc(), DiagID: diag::note_previous_attribute);
2773 return nullptr;
2774 }
2775 IdentifierInfo *ExistingII = &Ctx.Idents.get(
2776 Name: ExistingKind == attr::HLSLRowMajor ? "row_major" : "column_major");
2777 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_matrix_layout_conflict)
2778 << AL.getAttrName() << ExistingII;
2779 Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute);
2780 AL.setInvalid();
2781 return nullptr;
2782 }
2783
2784 if (AttrK == attr::HLSLRowMajor)
2785 return ::new (Ctx) HLSLRowMajorAttr(Ctx, AL);
2786 return ::new (Ctx) HLSLColumnMajorAttr(Ctx, AL);
2787}
2788
2789// Re-validates an HLSL `row_major` / `column_major` attribute after template
2790// substitution. The parse-time check in `buildMatrixLayoutTypeAttr` is skipped
2791// for dependent types; `TransformAttributedType` calls this once the type is
2792// concrete. Returns `true` (and emits a diagnostic) if the substituted type is
2793// not a matrix or array of matrices, signaling the caller to abort the
2794// transform.
2795bool SemaHLSL::diagnoseMatrixLayoutInstantiation(attr::Kind K, QualType T,
2796 SourceLocation Loc) {
2797 if (K != attr::HLSLRowMajor && K != attr::HLSLColumnMajor)
2798 return false;
2799 if (T.isNull() || T->isDependentType())
2800 return false;
2801 if (isMatrixOrArrayOfMatrix(Ctx: getASTContext(), QT: T))
2802 return false;
2803 IdentifierInfo *II = &getASTContext().Idents.get(
2804 Name: K == attr::HLSLRowMajor ? "row_major" : "column_major");
2805 Diag(Loc, DiagID: diag::err_hlsl_matrix_layout_non_matrix) << II;
2806 return true;
2807}
2808
2809// Transpose and matrix mul need to read the destination layout.
2810// Elementwise builtins reuse the operand layout instead.
2811static bool isLayoutAdaptingMatrixBuiltin(unsigned BuiltinID) {
2812 switch (BuiltinID) {
2813 case Builtin::BI__builtin_hlsl_mul:
2814 case Builtin::BI__builtin_hlsl_transpose:
2815 return true;
2816 default:
2817 return false;
2818 }
2819}
2820
2821void SemaHLSL::propagateContextualMatrixLayout(Expr *E, QualType DestType) {
2822 if (!E || DestType.isNull())
2823 return;
2824 const auto *DestMat = DestType->getAs<ConstantMatrixType>();
2825 if (!DestMat)
2826 return;
2827 auto *Call = dyn_cast<CallExpr>(Val: E->IgnoreParenImpCasts());
2828 if (!Call)
2829 return;
2830 const FunctionDecl *Callee = Call->getDirectCallee();
2831 if (!Callee || !isLayoutAdaptingMatrixBuiltin(BuiltinID: Callee->getBuiltinID()))
2832 return;
2833 const auto *CallMat = Call->getType()->getAs<ConstantMatrixType>();
2834 if (!CallMat || CallMat->getNumRows() != DestMat->getNumRows() ||
2835 CallMat->getNumColumns() != DestMat->getNumColumns())
2836 return;
2837 // Re-type the call with the destination sugar so CodeGen lowers into that
2838 // layout, not the TU default.
2839 Call->setType(DestType.getUnqualifiedType());
2840}
2841
2842namespace {
2843
2844/// This class implements HLSL availability diagnostics for default
2845/// and relaxed mode
2846///
2847/// The goal of this diagnostic is to emit an error or warning when an
2848/// unavailable API is found in code that is reachable from the shader
2849/// entry function or from an exported function (when compiling a shader
2850/// library).
2851///
2852/// This is done by traversing the AST of all shader entry point functions
2853/// and of all exported functions, and any functions that are referenced
2854/// from this AST. In other words, any functions that are reachable from
2855/// the entry points.
2856class DiagnoseHLSLAvailability : public DynamicRecursiveASTVisitor {
2857 Sema &SemaRef;
2858
2859 // Stack of functions to be scaned
2860 llvm::SmallVector<const FunctionDecl *, 8> DeclsToScan;
2861
2862 // Tracks which environments functions have been scanned in.
2863 //
2864 // Maps FunctionDecl to an unsigned number that represents the set of shader
2865 // environments the function has been scanned for.
2866 // The llvm::Triple::EnvironmentType enum values for shader stages guaranteed
2867 // to be numbered from llvm::Triple::Pixel to llvm::Triple::Amplification
2868 // (verified by static_asserts in Triple.cpp), we can use it to index
2869 // individual bits in the set, as long as we shift the values to start with 0
2870 // by subtracting the value of llvm::Triple::Pixel first.
2871 //
2872 // The N'th bit in the set will be set if the function has been scanned
2873 // in shader environment whose llvm::Triple::EnvironmentType integer value
2874 // equals (llvm::Triple::Pixel + N).
2875 //
2876 // For example, if a function has been scanned in compute and pixel stage
2877 // environment, the value will be 0x21 (100001 binary) because:
2878 //
2879 // (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0
2880 // (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5
2881 //
2882 // A FunctionDecl is mapped to 0 (or not included in the map) if it has not
2883 // been scanned in any environment.
2884 llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls;
2885
2886 // Do not access these directly, use the get/set methods below to make
2887 // sure the values are in sync
2888 llvm::Triple::EnvironmentType CurrentShaderEnvironment;
2889 unsigned CurrentShaderStageBit;
2890
2891 // True if scanning a function that was already scanned in a different
2892 // shader stage context, and therefore we should not report issues that
2893 // depend only on shader model version because they would be duplicate.
2894 bool ReportOnlyShaderStageIssues;
2895
2896 // Helper methods for dealing with current stage context / environment
2897 void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) {
2898 static_assert(sizeof(unsigned) >= 4);
2899 assert(HLSLShaderAttr::isValidShaderType(ShaderType));
2900 assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 &&
2901 "ShaderType is too big for this bitmap"); // 31 is reserved for
2902 // "unknown"
2903
2904 unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel;
2905 CurrentShaderEnvironment = ShaderType;
2906 CurrentShaderStageBit = (1 << bitmapIndex);
2907 }
2908
2909 void SetUnknownShaderStageContext() {
2910 CurrentShaderEnvironment = llvm::Triple::UnknownEnvironment;
2911 CurrentShaderStageBit = (1 << 31);
2912 }
2913
2914 llvm::Triple::EnvironmentType GetCurrentShaderEnvironment() const {
2915 return CurrentShaderEnvironment;
2916 }
2917
2918 bool InUnknownShaderStageContext() const {
2919 return CurrentShaderEnvironment == llvm::Triple::UnknownEnvironment;
2920 }
2921
2922 // Helper methods for dealing with shader stage bitmap
2923 void AddToScannedFunctions(const FunctionDecl *FD) {
2924 unsigned &ScannedStages = ScannedDecls[FD];
2925 ScannedStages |= CurrentShaderStageBit;
2926 }
2927
2928 unsigned GetScannedStages(const FunctionDecl *FD) { return ScannedDecls[FD]; }
2929
2930 bool WasAlreadyScannedInCurrentStage(const FunctionDecl *FD) {
2931 return WasAlreadyScannedInCurrentStage(ScannerStages: GetScannedStages(FD));
2932 }
2933
2934 bool WasAlreadyScannedInCurrentStage(unsigned ScannerStages) {
2935 return ScannerStages & CurrentShaderStageBit;
2936 }
2937
2938 static bool NeverBeenScanned(unsigned ScannedStages) {
2939 return ScannedStages == 0;
2940 }
2941
2942 // Scanning methods
2943 void HandleFunctionOrMethodRef(FunctionDecl *FD, Expr *RefExpr);
2944 void CheckDeclAvailability(NamedDecl *D, const AvailabilityAttr *AA,
2945 SourceRange Range);
2946 const AvailabilityAttr *FindAvailabilityAttr(const Decl *D);
2947 bool HasMatchingEnvironmentOrNone(const AvailabilityAttr *AA);
2948
2949public:
2950 DiagnoseHLSLAvailability(Sema &SemaRef)
2951 : SemaRef(SemaRef),
2952 CurrentShaderEnvironment(llvm::Triple::UnknownEnvironment),
2953 CurrentShaderStageBit(0), ReportOnlyShaderStageIssues(false) {}
2954
2955 // AST traversal methods
2956 void RunOnTranslationUnit(const TranslationUnitDecl *TU);
2957 void RunOnFunction(const FunctionDecl *FD);
2958
2959 bool VisitDeclRefExpr(DeclRefExpr *DRE) override {
2960 FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(Val: DRE->getDecl());
2961 if (FD)
2962 HandleFunctionOrMethodRef(FD, RefExpr: DRE);
2963 return true;
2964 }
2965
2966 bool VisitMemberExpr(MemberExpr *ME) override {
2967 FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(Val: ME->getMemberDecl());
2968 if (FD)
2969 HandleFunctionOrMethodRef(FD, RefExpr: ME);
2970 return true;
2971 }
2972};
2973
2974void DiagnoseHLSLAvailability::HandleFunctionOrMethodRef(FunctionDecl *FD,
2975 Expr *RefExpr) {
2976 assert((isa<DeclRefExpr>(RefExpr) || isa<MemberExpr>(RefExpr)) &&
2977 "expected DeclRefExpr or MemberExpr");
2978
2979 // has a definition -> add to stack to be scanned
2980 const FunctionDecl *FDWithBody = nullptr;
2981 if (FD->hasBody(Definition&: FDWithBody)) {
2982 if (!WasAlreadyScannedInCurrentStage(FD: FDWithBody))
2983 DeclsToScan.push_back(Elt: FDWithBody);
2984 return;
2985 }
2986
2987 // no body -> diagnose availability
2988 const AvailabilityAttr *AA = FindAvailabilityAttr(D: FD);
2989 if (AA)
2990 CheckDeclAvailability(
2991 D: FD, AA, Range: SourceRange(RefExpr->getBeginLoc(), RefExpr->getEndLoc()));
2992}
2993
2994void DiagnoseHLSLAvailability::RunOnTranslationUnit(
2995 const TranslationUnitDecl *TU) {
2996
2997 // Iterate over all shader entry functions and library exports, and for those
2998 // that have a body (definiton), run diag scan on each, setting appropriate
2999 // shader environment context based on whether it is a shader entry function
3000 // or an exported function. Exported functions can be in namespaces and in
3001 // export declarations so we need to scan those declaration contexts as well.
3002 llvm::SmallVector<const DeclContext *, 8> DeclContextsToScan;
3003 DeclContextsToScan.push_back(Elt: TU);
3004
3005 while (!DeclContextsToScan.empty()) {
3006 const DeclContext *DC = DeclContextsToScan.pop_back_val();
3007 for (auto &D : DC->decls()) {
3008 // do not scan implicit declaration generated by the implementation
3009 if (D->isImplicit())
3010 continue;
3011
3012 // for namespace or export declaration add the context to the list to be
3013 // scanned later
3014 if (llvm::dyn_cast<NamespaceDecl>(Val: D) || llvm::dyn_cast<ExportDecl>(Val: D)) {
3015 DeclContextsToScan.push_back(Elt: llvm::dyn_cast<DeclContext>(Val: D));
3016 continue;
3017 }
3018
3019 // skip over other decls or function decls without body
3020 const FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(Val: D);
3021 if (!FD || !FD->isThisDeclarationADefinition())
3022 continue;
3023
3024 // shader entry point
3025 if (HLSLShaderAttr *ShaderAttr = FD->getAttr<HLSLShaderAttr>()) {
3026 SetShaderStageContext(ShaderAttr->getType());
3027 RunOnFunction(FD);
3028 continue;
3029 }
3030 // exported library function
3031 // FIXME: replace this loop with external linkage check once issue #92071
3032 // is resolved
3033 bool isExport = FD->isInExportDeclContext();
3034 if (!isExport) {
3035 for (const auto *Redecl : FD->redecls()) {
3036 if (Redecl->isInExportDeclContext()) {
3037 isExport = true;
3038 break;
3039 }
3040 }
3041 }
3042 if (isExport) {
3043 SetUnknownShaderStageContext();
3044 RunOnFunction(FD);
3045 continue;
3046 }
3047 }
3048 }
3049}
3050
3051void DiagnoseHLSLAvailability::RunOnFunction(const FunctionDecl *FD) {
3052 assert(DeclsToScan.empty() && "DeclsToScan should be empty");
3053 DeclsToScan.push_back(Elt: FD);
3054
3055 while (!DeclsToScan.empty()) {
3056 // Take one decl from the stack and check it by traversing its AST.
3057 // For any CallExpr found during the traversal add it's callee to the top of
3058 // the stack to be processed next. Functions already processed are stored in
3059 // ScannedDecls.
3060 const FunctionDecl *FD = DeclsToScan.pop_back_val();
3061
3062 // Decl was already scanned
3063 const unsigned ScannedStages = GetScannedStages(FD);
3064 if (WasAlreadyScannedInCurrentStage(ScannerStages: ScannedStages))
3065 continue;
3066
3067 ReportOnlyShaderStageIssues = !NeverBeenScanned(ScannedStages);
3068
3069 AddToScannedFunctions(FD);
3070 TraverseStmt(S: FD->getBody());
3071 }
3072}
3073
3074bool DiagnoseHLSLAvailability::HasMatchingEnvironmentOrNone(
3075 const AvailabilityAttr *AA) {
3076 const IdentifierInfo *IIEnvironment = AA->getEnvironment();
3077 if (!IIEnvironment)
3078 return true;
3079
3080 llvm::Triple::EnvironmentType CurrentEnv = GetCurrentShaderEnvironment();
3081 if (CurrentEnv == llvm::Triple::UnknownEnvironment)
3082 return false;
3083
3084 llvm::Triple::EnvironmentType AttrEnv =
3085 AvailabilityAttr::getEnvironmentType(Environment: IIEnvironment->getName());
3086
3087 return CurrentEnv == AttrEnv;
3088}
3089
3090const AvailabilityAttr *
3091DiagnoseHLSLAvailability::FindAvailabilityAttr(const Decl *D) {
3092 AvailabilityAttr const *PartialMatch = nullptr;
3093 // Check each AvailabilityAttr to find the one for this platform.
3094 // For multiple attributes with the same platform try to find one for this
3095 // environment.
3096 for (const auto *A : D->attrs()) {
3097 if (const auto *Avail = dyn_cast<AvailabilityAttr>(Val: A)) {
3098 const AvailabilityAttr *EffectiveAvail = Avail->getEffectiveAttr();
3099 StringRef AttrPlatform = EffectiveAvail->getPlatform()->getName();
3100 StringRef TargetPlatform =
3101 SemaRef.getASTContext().getTargetInfo().getPlatformName();
3102
3103 // Match the platform name.
3104 if (AttrPlatform == TargetPlatform) {
3105 // Find the best matching attribute for this environment
3106 if (HasMatchingEnvironmentOrNone(AA: EffectiveAvail))
3107 return Avail;
3108 PartialMatch = Avail;
3109 }
3110 }
3111 }
3112 return PartialMatch;
3113}
3114
3115// Check availability against target shader model version and current shader
3116// stage and emit diagnostic
3117void DiagnoseHLSLAvailability::CheckDeclAvailability(NamedDecl *D,
3118 const AvailabilityAttr *AA,
3119 SourceRange Range) {
3120
3121 const IdentifierInfo *IIEnv = AA->getEnvironment();
3122
3123 if (!IIEnv) {
3124 // The availability attribute does not have environment -> it depends only
3125 // on shader model version and not on specific the shader stage.
3126
3127 // Skip emitting the diagnostics if the diagnostic mode is set to
3128 // strict (-fhlsl-strict-availability) because all relevant diagnostics
3129 // were already emitted in the DiagnoseUnguardedAvailability scan
3130 // (SemaAvailability.cpp).
3131 if (SemaRef.getLangOpts().HLSLStrictAvailability)
3132 return;
3133
3134 // Do not report shader-stage-independent issues if scanning a function
3135 // that was already scanned in a different shader stage context (they would
3136 // be duplicate)
3137 if (ReportOnlyShaderStageIssues)
3138 return;
3139
3140 } else {
3141 // The availability attribute has environment -> we need to know
3142 // the current stage context to property diagnose it.
3143 if (InUnknownShaderStageContext())
3144 return;
3145 }
3146
3147 // Check introduced version and if environment matches
3148 bool EnvironmentMatches = HasMatchingEnvironmentOrNone(AA);
3149 VersionTuple Introduced = AA->getIntroduced();
3150 VersionTuple TargetVersion =
3151 SemaRef.Context.getTargetInfo().getPlatformMinVersion();
3152
3153 if (TargetVersion >= Introduced && EnvironmentMatches)
3154 return;
3155
3156 // Emit diagnostic message
3157 const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
3158 llvm::StringRef PlatformName(
3159 AvailabilityAttr::getPrettyPlatformName(Platform: TI.getPlatformName()));
3160
3161 llvm::StringRef CurrentEnvStr =
3162 llvm::Triple::getEnvironmentTypeName(Kind: GetCurrentShaderEnvironment());
3163
3164 llvm::StringRef AttrEnvStr =
3165 AA->getEnvironment() ? AA->getEnvironment()->getName() : "";
3166 bool UseEnvironment = !AttrEnvStr.empty();
3167
3168 if (EnvironmentMatches) {
3169 SemaRef.Diag(Loc: Range.getBegin(), DiagID: diag::warn_hlsl_availability)
3170 << Range << D << PlatformName << Introduced.getAsString()
3171 << UseEnvironment << CurrentEnvStr;
3172 } else {
3173 SemaRef.Diag(Loc: Range.getBegin(), DiagID: diag::warn_hlsl_availability_unavailable)
3174 << Range << D;
3175 }
3176
3177 SemaRef.Diag(Loc: D->getLocation(), DiagID: diag::note_partial_availability_specified_here)
3178 << D << PlatformName << Introduced.getAsString()
3179 << SemaRef.Context.getTargetInfo().getPlatformMinVersion().getAsString()
3180 << UseEnvironment << AttrEnvStr << CurrentEnvStr;
3181}
3182
3183} // namespace
3184
3185void SemaHLSL::ActOnEndOfTranslationUnit(TranslationUnitDecl *TU) {
3186 // process default CBuffer - create buffer layout struct and invoke codegenCGH
3187 if (!DefaultCBufferDecls.empty()) {
3188 HLSLBufferDecl *DefaultCBuffer = HLSLBufferDecl::CreateDefaultCBuffer(
3189 C&: SemaRef.getASTContext(), LexicalParent: SemaRef.getCurLexicalContext(),
3190 DefaultCBufferDecls);
3191 addImplicitBindingAttrToDecl(S&: SemaRef, D: DefaultCBuffer, RT: RegisterType::CBuffer,
3192 ImplicitBindingOrderID: getNextImplicitBindingOrderID());
3193 SemaRef.getCurLexicalContext()->addDecl(D: DefaultCBuffer);
3194 createHostLayoutStructForBuffer(S&: SemaRef, BufDecl: DefaultCBuffer);
3195
3196 // Set HasValidPackoffset if any of the decls has a register(c#) annotation;
3197 for (const Decl *VD : DefaultCBufferDecls) {
3198 const HLSLResourceBindingAttr *RBA =
3199 VD->getAttr<HLSLResourceBindingAttr>();
3200 if (RBA && RBA->hasRegisterSlot() &&
3201 RBA->getRegisterType() == HLSLResourceBindingAttr::RegisterType::C) {
3202 DefaultCBuffer->setHasValidPackoffset(true);
3203 break;
3204 }
3205 }
3206
3207 DeclGroupRef DG(DefaultCBuffer);
3208 SemaRef.Consumer.HandleTopLevelDecl(D: DG);
3209 }
3210 diagnoseAvailabilityViolations(TU);
3211}
3212
3213// For resource member access through a global struct array, verify that the
3214// array index selecting the struct element is a constant integer expression.
3215// Returns false if the member expression is invalid.
3216bool SemaHLSL::ActOnResourceMemberAccessExpr(MemberExpr *ME) {
3217 assert((ME->getType()->isHLSLResourceRecord() ||
3218 ME->getType()->isHLSLResourceRecordArray()) &&
3219 "expected member expr to have resource record type or array of them");
3220
3221 // Walk the AST from MemberExpr to the VarDecl of the parent struct instance
3222 // and take note of any non-constant array indexing along the way. If the
3223 // VarDecl we find is a global variable, report error if there was any
3224 // non-constant array index in the resource member access along the way.
3225 const Expr *NonConstIndexExpr = nullptr;
3226 const Expr *E = ME->getBase();
3227 while (E) {
3228 if (const DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(Val: E)) {
3229 if (!NonConstIndexExpr)
3230 return true;
3231
3232 const VarDecl *VD = cast<VarDecl>(Val: DRE->getDecl());
3233 if (!VD->hasGlobalStorage())
3234 return true;
3235
3236 SemaRef.Diag(Loc: NonConstIndexExpr->getExprLoc(),
3237 DiagID: diag::err_hlsl_resource_member_array_access_not_constant);
3238 return false;
3239 }
3240
3241 if (const auto *ASE = dyn_cast<ArraySubscriptExpr>(Val: E)) {
3242 const Expr *IdxExpr = ASE->getIdx();
3243 if (!IdxExpr->isIntegerConstantExpr(Ctx: SemaRef.getASTContext()))
3244 NonConstIndexExpr = IdxExpr;
3245 E = ASE->getBase();
3246 } else if (const auto *SubME = dyn_cast<MemberExpr>(Val: E)) {
3247 E = SubME->getBase();
3248 } else if (const auto *ICE = dyn_cast<ImplicitCastExpr>(Val: E)) {
3249 E = ICE->getSubExpr();
3250 } else {
3251 llvm_unreachable("unexpected expr type in resource member access");
3252 }
3253 }
3254 return true;
3255}
3256
3257NamedDecl *SemaHLSL::getConstantBufferConversionFunction(QualType Type,
3258 CXXRecordDecl *RD) {
3259 QualType AddrSpaceType =
3260 SemaRef.Context.getCanonicalType(T: SemaRef.Context.getAddrSpaceQualType(
3261 T: Type.withConst(), AddressSpace: LangAS::hlsl_constant));
3262 QualType ReturnTy = SemaRef.Context.getCanonicalType(
3263 T: SemaRef.Context.getLValueReferenceType(T: AddrSpaceType));
3264
3265 DeclarationName ConvName =
3266 SemaRef.Context.DeclarationNames.getCXXConversionFunctionName(
3267 Ty: CanQualType::CreateUnsafe(Other: ReturnTy));
3268 LookupResult ConvR(SemaRef, ConvName, SourceLocation(),
3269 Sema::LookupOrdinaryName);
3270 [[maybe_unused]] bool LookupSucceeded =
3271 SemaRef.LookupQualifiedName(R&: ConvR, LookupCtx: RD);
3272 assert(LookupSucceeded);
3273
3274 for (NamedDecl *D : ConvR) {
3275 if (isa<CXXConversionDecl>(Val: D->getUnderlyingDecl()))
3276 return D;
3277 }
3278 return nullptr;
3279}
3280
3281std::optional<ExprResult>
3282SemaHLSL::tryPerformConstantBufferConversion(ExprResult &BaseExpr) {
3283 QualType BaseType = BaseExpr.get()->getType();
3284 const HLSLAttributedResourceType *ResTy =
3285 HLSLAttributedResourceType::findHandleTypeOnResource(
3286 RT: BaseType.getTypePtr());
3287 if (!ResTy ||
3288 ResTy->getAttrs().ResourceClass != llvm::dxil::ResourceClass::CBuffer)
3289 return std::nullopt;
3290
3291 QualType TemplateType = ResTy->getContainedType();
3292
3293 NamedDecl *NamedConversionDecl = getConstantBufferConversionFunction(
3294 Type: TemplateType, RD: BaseType->getAsCXXRecordDecl());
3295 assert(NamedConversionDecl &&
3296 "Could not find conversion function for ConstantBuffer.");
3297 auto *ConversionDecl =
3298 cast<CXXConversionDecl>(Val: NamedConversionDecl->getUnderlyingDecl());
3299
3300 return SemaRef.BuildCXXMemberCallExpr(Exp: BaseExpr.get(), FoundDecl: NamedConversionDecl,
3301 Method: ConversionDecl,
3302 /*HadMultipleCandidates=*/false);
3303}
3304
3305void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
3306 // Skip running the diagnostics scan if the diagnostic mode is
3307 // strict (-fhlsl-strict-availability) and the target shader stage is known
3308 // because all relevant diagnostics were already emitted in the
3309 // DiagnoseUnguardedAvailability scan (SemaAvailability.cpp).
3310 const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
3311 if (SemaRef.getLangOpts().HLSLStrictAvailability &&
3312 TI.getTriple().getEnvironment() != llvm::Triple::EnvironmentType::Library)
3313 return;
3314
3315 DiagnoseHLSLAvailability(SemaRef).RunOnTranslationUnit(TU);
3316}
3317
3318static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
3319 assert(TheCall->getNumArgs() > 1);
3320 QualType ArgTy0 = TheCall->getArg(Arg: 0)->getType();
3321
3322 for (unsigned I = 1, N = TheCall->getNumArgs(); I < N; ++I) {
3323 if (!S->getASTContext().hasSameUnqualifiedType(
3324 T1: ArgTy0, T2: TheCall->getArg(Arg: I)->getType())) {
3325 S->Diag(Loc: TheCall->getBeginLoc(), DiagID: diag::err_vec_builtin_incompatible_vector)
3326 << TheCall->getDirectCallee() << /*useAllTerminology*/ true
3327 << SourceRange(TheCall->getArg(Arg: 0)->getBeginLoc(),
3328 TheCall->getArg(Arg: N - 1)->getEndLoc());
3329 return true;
3330 }
3331 }
3332 return false;
3333}
3334
3335static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
3336 QualType ArgType = Arg->getType();
3337 if (!S->getASTContext().hasSameUnqualifiedType(T1: ArgType, T2: ExpectedType)) {
3338 S->Diag(Loc: Arg->getBeginLoc(), DiagID: diag::err_typecheck_convert_incompatible)
3339 << ArgType << ExpectedType << 1 << 0 << 0;
3340 return true;
3341 }
3342 return false;
3343}
3344
3345static bool CheckAllArgTypesAreCorrect(
3346 Sema *S, CallExpr *TheCall,
3347 llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
3348 clang::QualType PassedType)>
3349 Check) {
3350 for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
3351 Expr *Arg = TheCall->getArg(Arg: I);
3352 if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
3353 return true;
3354 }
3355 return false;
3356}
3357
3358static bool CheckFloatRepresentation(Sema *S, SourceLocation Loc,
3359 int ArgOrdinal,
3360 clang::QualType PassedType) {
3361 clang::QualType BaseType =
3362 PassedType->isVectorType()
3363 ? PassedType->castAs<clang::VectorType>()->getElementType()
3364 : PassedType;
3365 if (!BaseType->isFloat32Type())
3366 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3367 << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
3368 << /* float */ 1 << PassedType;
3369 return false;
3370}
3371
3372static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
3373 int ArgOrdinal,
3374 clang::QualType PassedType) {
3375 clang::QualType BaseType = PassedType;
3376 if (const auto *VT = PassedType->getAs<clang::VectorType>())
3377 BaseType = VT->getElementType();
3378 else if (const auto *MT = PassedType->getAs<clang::MatrixType>())
3379 BaseType = MT->getElementType();
3380
3381 if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
3382 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3383 << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
3384 << /* half or float */ 2 << PassedType;
3385 return false;
3386}
3387
3388static bool CheckAnyDoubleRepresentation(Sema *S, SourceLocation Loc,
3389 int ArgOrdinal,
3390 clang::QualType PassedType) {
3391 clang::QualType BaseType =
3392 PassedType->isVectorType()
3393 ? PassedType->castAs<clang::VectorType>()->getElementType()
3394 : PassedType->isMatrixType()
3395 ? PassedType->castAs<clang::MatrixType>()->getElementType()
3396 : PassedType;
3397 if (!BaseType->isDoubleType()) {
3398 // FIXME: adopt standard `err_builtin_invalid_arg_type` instead of using
3399 // this custom error.
3400 return S->Diag(Loc, DiagID: diag::err_builtin_requires_double_type)
3401 << ArgOrdinal << PassedType;
3402 }
3403
3404 return false;
3405}
3406
3407static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
3408 unsigned ArgIndex) {
3409 auto *Arg = TheCall->getArg(Arg: ArgIndex);
3410 SourceLocation OrigLoc = Arg->getExprLoc();
3411 if (Arg->IgnoreCasts()->isModifiableLvalue(Ctx&: S->Context, Loc: &OrigLoc) ==
3412 Expr::MLV_Valid)
3413 return false;
3414 S->Diag(Loc: OrigLoc, DiagID: diag::error_hlsl_inout_lvalue) << Arg << 0;
3415 return true;
3416}
3417
3418// Verifies that the argument at `ArgIndex` of `TheCall` refers to memory in
3419// one of `AllowedSpaces`. Intended for HLSL builtins (e.g. atomics).
3420static bool CheckArgAddrSpaceOneOf(Sema *S, CallExpr *TheCall,
3421 unsigned ArgIndex,
3422 ArrayRef<LangAS> AllowedSpaces) {
3423 Expr *Arg = TheCall->getArg(Arg: ArgIndex);
3424 QualType LValueTy = Arg->IgnoreCasts()->getType();
3425 if (llvm::is_contained(Range&: AllowedSpaces, Element: LValueTy.getAddressSpace()))
3426 return false;
3427 S->Diag(Loc: Arg->getBeginLoc(), DiagID: diag::err_hlsl_atomic_arg_addr_space)
3428 << (ArgIndex + 1) << LValueTy;
3429 return true;
3430}
3431
3432static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal,
3433 clang::QualType PassedType) {
3434 const auto *VecTy = PassedType->getAs<VectorType>();
3435 if (!VecTy)
3436 return false;
3437
3438 if (VecTy->getElementType()->isDoubleType())
3439 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3440 << ArgOrdinal << /* scalar */ 1 << /* no int */ 0 << /* fp */ 1
3441 << PassedType;
3442 return false;
3443}
3444
3445static bool CheckFloatingOrIntRepresentation(Sema *S, SourceLocation Loc,
3446 int ArgOrdinal,
3447 clang::QualType PassedType) {
3448 if (!PassedType->hasIntegerRepresentation() &&
3449 !PassedType->hasFloatingRepresentation())
3450 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3451 << ArgOrdinal << /* scalar or vector of */ 5 << /* integer */ 1
3452 << /* fp */ 1 << PassedType;
3453 return false;
3454}
3455
3456static bool CheckUnsignedIntVecRepresentation(Sema *S, SourceLocation Loc,
3457 int ArgOrdinal,
3458 clang::QualType PassedType) {
3459 if (auto *VecTy = PassedType->getAs<VectorType>())
3460 if (VecTy->getElementType()->isUnsignedIntegerType())
3461 return false;
3462
3463 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3464 << ArgOrdinal << /* vector of */ 4 << /* uint */ 3 << /* no fp */ 0
3465 << PassedType;
3466}
3467
3468// checks for unsigned ints of all sizes
3469static bool CheckUnsignedIntRepresentation(Sema *S, SourceLocation Loc,
3470 int ArgOrdinal,
3471 clang::QualType PassedType) {
3472 if (!PassedType->hasUnsignedIntegerRepresentation())
3473 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3474 << ArgOrdinal << /* scalar or vector of */ 5 << /* unsigned int */ 3
3475 << /* no fp */ 0 << PassedType;
3476 return false;
3477}
3478
3479static bool CheckExpectedBitWidth(Sema *S, CallExpr *TheCall,
3480 unsigned ArgOrdinal, unsigned Width) {
3481 QualType ArgTy = TheCall->getArg(Arg: 0)->getType();
3482 if (auto *VTy = ArgTy->getAs<VectorType>())
3483 ArgTy = VTy->getElementType();
3484 // ensure arg type has expected bit width
3485 uint64_t ElementBitCount =
3486 S->getASTContext().getTypeSizeInChars(T: ArgTy).getQuantity() * 8;
3487 if (ElementBitCount != Width) {
3488 S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
3489 DiagID: diag::err_integer_incorrect_bit_count)
3490 << Width << ElementBitCount;
3491 return true;
3492 }
3493 return false;
3494}
3495
3496static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
3497 QualType ReturnType) {
3498 auto *VecTyA = TheCall->getArg(Arg: 0)->getType()->getAs<VectorType>();
3499 if (VecTyA)
3500 ReturnType =
3501 S->Context.getExtVectorType(VectorType: ReturnType, NumElts: VecTyA->getNumElements());
3502
3503 TheCall->setType(ReturnType);
3504}
3505
3506static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar,
3507 unsigned ArgIndex) {
3508 assert(TheCall->getNumArgs() >= ArgIndex);
3509 QualType ArgType = TheCall->getArg(Arg: ArgIndex)->getType();
3510 auto *VTy = ArgType->getAs<VectorType>();
3511 // not the scalar or vector<scalar>
3512 if (!(S->Context.hasSameUnqualifiedType(T1: ArgType, T2: Scalar) ||
3513 (VTy &&
3514 S->Context.hasSameUnqualifiedType(T1: VTy->getElementType(), T2: Scalar)))) {
3515 S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
3516 DiagID: diag::err_typecheck_expect_scalar_or_vector)
3517 << ArgType << Scalar;
3518 return true;
3519 }
3520 return false;
3521}
3522
3523static bool CheckScalarOrVectorOrMatrix(Sema *S, CallExpr *TheCall,
3524 QualType Scalar, unsigned ArgIndex) {
3525 assert(TheCall->getNumArgs() > ArgIndex);
3526
3527 Expr *Arg = TheCall->getArg(Arg: ArgIndex);
3528 QualType ArgType = Arg->getType();
3529
3530 // Scalar: T
3531 if (S->Context.hasSameUnqualifiedType(T1: ArgType, T2: Scalar))
3532 return false;
3533
3534 // Vector: vector<T>
3535 if (const auto *VTy = ArgType->getAs<VectorType>()) {
3536 if (S->Context.hasSameUnqualifiedType(T1: VTy->getElementType(), T2: Scalar))
3537 return false;
3538 }
3539
3540 // Matrix: ConstantMatrixType with element type T
3541 if (const auto *MTy = ArgType->getAs<ConstantMatrixType>()) {
3542 if (S->Context.hasSameUnqualifiedType(T1: MTy->getElementType(), T2: Scalar))
3543 return false;
3544 }
3545
3546 // Not a scalar/vector/matrix-of-scalar
3547 S->Diag(Loc: Arg->getBeginLoc(),
3548 DiagID: diag::err_typecheck_expect_scalar_or_vector_or_matrix)
3549 << ArgType << Scalar;
3550 return true;
3551}
3552
3553static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall,
3554 unsigned ArgIndex) {
3555 assert(TheCall->getNumArgs() >= ArgIndex);
3556 QualType ArgType = TheCall->getArg(Arg: ArgIndex)->getType();
3557 auto *VTy = ArgType->getAs<VectorType>();
3558 // not the scalar or vector<scalar>
3559 if (!(ArgType->isScalarType() ||
3560 (VTy && VTy->getElementType()->isScalarType()))) {
3561 S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
3562 DiagID: diag::err_typecheck_expect_any_scalar_or_vector)
3563 << ArgType << 1;
3564 return true;
3565 }
3566 return false;
3567}
3568
3569// Check that the argument is not a bool or vector<bool>
3570// Returns true on error
3571static bool CheckNotBoolScalarOrVector(Sema *S, CallExpr *TheCall,
3572 unsigned ArgIndex) {
3573 QualType BoolType = S->getASTContext().BoolTy;
3574 assert(ArgIndex < TheCall->getNumArgs());
3575 QualType ArgType = TheCall->getArg(Arg: ArgIndex)->getType();
3576 auto *VTy = ArgType->getAs<VectorType>();
3577 // is the bool or vector<bool>
3578 if (S->Context.hasSameUnqualifiedType(T1: ArgType, T2: BoolType) ||
3579 (VTy &&
3580 S->Context.hasSameUnqualifiedType(T1: VTy->getElementType(), T2: BoolType))) {
3581 S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
3582 DiagID: diag::err_typecheck_expect_any_scalar_or_vector)
3583 << ArgType << 0;
3584 return true;
3585 }
3586 return false;
3587}
3588
3589static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
3590 if (CheckNotBoolScalarOrVector(S, TheCall, ArgIndex: 0))
3591 return true;
3592 return false;
3593}
3594
3595static bool CheckWavePrefix(Sema *S, CallExpr *TheCall) {
3596 if (CheckNotBoolScalarOrVector(S, TheCall, ArgIndex: 0))
3597 return true;
3598 return false;
3599}
3600
3601static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
3602 assert(TheCall->getNumArgs() == 3);
3603 Expr *Arg1 = TheCall->getArg(Arg: 1);
3604 Expr *Arg2 = TheCall->getArg(Arg: 2);
3605 if (!S->Context.hasSameUnqualifiedType(T1: Arg1->getType(), T2: Arg2->getType())) {
3606 S->Diag(Loc: TheCall->getBeginLoc(),
3607 DiagID: diag::err_typecheck_call_different_arg_types)
3608 << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
3609 << Arg2->getSourceRange();
3610 return true;
3611 }
3612
3613 TheCall->setType(Arg1->getType());
3614 return false;
3615}
3616
3617static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
3618 assert(TheCall->getNumArgs() == 3);
3619 Expr *Arg1 = TheCall->getArg(Arg: 1);
3620 QualType Arg1Ty = Arg1->getType();
3621 Expr *Arg2 = TheCall->getArg(Arg: 2);
3622 QualType Arg2Ty = Arg2->getType();
3623
3624 QualType Arg1ScalarTy = Arg1Ty;
3625 if (auto VTy = Arg1ScalarTy->getAs<VectorType>())
3626 Arg1ScalarTy = VTy->getElementType();
3627
3628 QualType Arg2ScalarTy = Arg2Ty;
3629 if (auto VTy = Arg2ScalarTy->getAs<VectorType>())
3630 Arg2ScalarTy = VTy->getElementType();
3631
3632 if (!S->Context.hasSameUnqualifiedType(T1: Arg1ScalarTy, T2: Arg2ScalarTy))
3633 S->Diag(Loc: Arg1->getBeginLoc(), DiagID: diag::err_hlsl_builtin_scalar_vector_mismatch)
3634 << /* second and third */ 1 << TheCall->getCallee() << Arg1Ty << Arg2Ty;
3635
3636 QualType Arg0Ty = TheCall->getArg(Arg: 0)->getType();
3637 unsigned Arg0Length = Arg0Ty->getAs<VectorType>()->getNumElements();
3638 unsigned Arg1Length = Arg1Ty->isVectorType()
3639 ? Arg1Ty->getAs<VectorType>()->getNumElements()
3640 : 0;
3641 unsigned Arg2Length = Arg2Ty->isVectorType()
3642 ? Arg2Ty->getAs<VectorType>()->getNumElements()
3643 : 0;
3644 if (Arg1Length > 0 && Arg0Length != Arg1Length) {
3645 S->Diag(Loc: TheCall->getBeginLoc(),
3646 DiagID: diag::err_typecheck_vector_lengths_not_equal)
3647 << Arg0Ty << Arg1Ty << TheCall->getArg(Arg: 0)->getSourceRange()
3648 << Arg1->getSourceRange();
3649 return true;
3650 }
3651
3652 if (Arg2Length > 0 && Arg0Length != Arg2Length) {
3653 S->Diag(Loc: TheCall->getBeginLoc(),
3654 DiagID: diag::err_typecheck_vector_lengths_not_equal)
3655 << Arg0Ty << Arg2Ty << TheCall->getArg(Arg: 0)->getSourceRange()
3656 << Arg2->getSourceRange();
3657 return true;
3658 }
3659
3660 TheCall->setType(
3661 S->getASTContext().getExtVectorType(VectorType: Arg1ScalarTy, NumElts: Arg0Length));
3662 return false;
3663}
3664
3665static bool CheckIndexType(Sema *S, CallExpr *TheCall, unsigned IndexArgIndex) {
3666 assert(TheCall->getNumArgs() > IndexArgIndex && "Index argument missing");
3667 QualType ArgType = TheCall->getArg(Arg: IndexArgIndex)->getType();
3668 QualType IndexTy = ArgType;
3669 unsigned int ActualDim = 1;
3670 if (const auto *VTy = IndexTy->getAs<VectorType>()) {
3671 ActualDim = VTy->getNumElements();
3672 IndexTy = VTy->getElementType();
3673 }
3674 if (!IndexTy->isIntegerType()) {
3675 S->Diag(Loc: TheCall->getArg(Arg: IndexArgIndex)->getBeginLoc(),
3676 DiagID: diag::err_typecheck_expect_int)
3677 << ArgType;
3678 return true;
3679 }
3680
3681 QualType ResourceArgTy = TheCall->getArg(Arg: 0)->getType();
3682 const HLSLAttributedResourceType *ResTy =
3683 ResourceArgTy.getTypePtr()->getAs<HLSLAttributedResourceType>();
3684 assert(ResTy && "Resource argument must be a resource");
3685 HLSLAttributedResourceType::Attributes ResAttrs = ResTy->getAttrs();
3686
3687 unsigned int ExpectedDim = 1;
3688 if (ResAttrs.ResourceDimension != llvm::dxil::ResourceDimension::Unknown)
3689 ExpectedDim = getResourceDimensions(Dim: ResAttrs.ResourceDimension) +
3690 (ResAttrs.IsArray ? 1 : 0);
3691
3692 if (ActualDim != ExpectedDim) {
3693 S->Diag(Loc: TheCall->getArg(Arg: IndexArgIndex)->getBeginLoc(),
3694 DiagID: diag::err_hlsl_builtin_resource_coordinate_dimension_mismatch)
3695 << cast<NamedDecl>(Val: TheCall->getCalleeDecl()) << ExpectedDim
3696 << ActualDim;
3697 return true;
3698 }
3699
3700 return false;
3701}
3702
3703static bool CheckResourceHandle(
3704 Sema *S, CallExpr *TheCall, unsigned ArgIndex,
3705 llvm::function_ref<bool(const HLSLAttributedResourceType *ResType)> Check =
3706 nullptr) {
3707 assert(TheCall->getNumArgs() >= ArgIndex);
3708 QualType ArgType = TheCall->getArg(Arg: ArgIndex)->getType();
3709 const HLSLAttributedResourceType *ResTy =
3710 ArgType.getTypePtr()->getAs<HLSLAttributedResourceType>();
3711 if (!ResTy) {
3712 S->Diag(Loc: TheCall->getArg(Arg: ArgIndex)->getBeginLoc(),
3713 DiagID: diag::err_typecheck_expect_hlsl_resource)
3714 << ArgType;
3715 return true;
3716 }
3717 if (Check && Check(ResTy)) {
3718 S->Diag(Loc: TheCall->getArg(Arg: ArgIndex)->getExprLoc(),
3719 DiagID: diag::err_invalid_hlsl_resource_type)
3720 << ArgType;
3721 return true;
3722 }
3723 return false;
3724}
3725
3726static bool CheckVectorElementCount(Sema *S, QualType PassedType,
3727 QualType BaseType, unsigned ExpectedCount,
3728 SourceLocation Loc) {
3729 unsigned PassedCount = 1;
3730 if (const auto *VecTy = PassedType->getAs<VectorType>())
3731 PassedCount = VecTy->getNumElements();
3732
3733 if (PassedCount != ExpectedCount) {
3734 QualType ExpectedType =
3735 S->Context.getExtVectorType(VectorType: BaseType, NumElts: ExpectedCount);
3736 S->Diag(Loc, DiagID: diag::err_typecheck_convert_incompatible)
3737 << PassedType << ExpectedType << 1 << 0 << 0;
3738 return true;
3739 }
3740 return false;
3741}
3742
/// Variants of the texture sampling builtins validated by CheckSamplingBuiltin.
/// The kind determines which optional operands follow the location argument.
enum class SampleKind {
  Sample,
  Bias,
  Grad,
  Level,
  Cmp,
  CmpLevelZero,
};
3744
3745static bool CheckTextureSamplerAndLocation(Sema &S, CallExpr *TheCall,
3746 bool IncludeArraySlice = true) {
3747 // Check the texture handle.
3748 if (CheckResourceHandle(S: &S, TheCall, ArgIndex: 0,
3749 Check: [](const HLSLAttributedResourceType *ResType) {
3750 return ResType->getAttrs().ResourceDimension ==
3751 llvm::dxil::ResourceDimension::Unknown;
3752 }))
3753 return true;
3754
3755 // Check the sampler handle.
3756 if (CheckResourceHandle(S: &S, TheCall, ArgIndex: 1,
3757 Check: [](const HLSLAttributedResourceType *ResType) {
3758 return ResType->getAttrs().ResourceClass !=
3759 llvm::hlsl::ResourceClass::Sampler;
3760 }))
3761 return true;
3762
3763 auto *ResourceTy =
3764 TheCall->getArg(Arg: 0)->getType()->castAs<HLSLAttributedResourceType>();
3765
3766 // Check the location.
3767 unsigned ExpectedDim =
3768 getResourceDimensions(Dim: ResourceTy->getAttrs().ResourceDimension) +
3769 (IncludeArraySlice && ResourceTy->getAttrs().IsArray ? 1 : 0);
3770 if (CheckVectorElementCount(S: &S, PassedType: TheCall->getArg(Arg: 2)->getType(),
3771 BaseType: S.Context.FloatTy, ExpectedCount: ExpectedDim,
3772 Loc: TheCall->getBeginLoc()))
3773 return true;
3774
3775 return false;
3776}
3777
3778static bool CheckCalculateLodBuiltin(Sema &S, CallExpr *TheCall) {
3779 if (S.checkArgCount(Call: TheCall, DesiredArgCount: 3))
3780 return true;
3781
3782 // CalculateLevelOfDetail location uses resource dimension only (e.g. float2
3783 // for 2D), not an extra array slice component like Sample/Gather.
3784 if (CheckTextureSamplerAndLocation(S, TheCall, /*IncludeArraySlice=*/false))
3785 return true;
3786
3787 TheCall->setType(S.Context.FloatTy);
3788 return false;
3789}
3790
3791static bool CheckGatherBuiltin(Sema &S, CallExpr *TheCall, bool IsCmp) {
3792 if (S.checkArgCountRange(Call: TheCall, MinArgCount: IsCmp ? 5 : 4, MaxArgCount: IsCmp ? 6 : 5))
3793 return true;
3794
3795 if (CheckTextureSamplerAndLocation(S, TheCall))
3796 return true;
3797
3798 unsigned NextIdx = 3;
3799 if (IsCmp) {
3800 // Check the compare value.
3801 QualType CmpTy = TheCall->getArg(Arg: NextIdx)->getType();
3802 if (!CmpTy->isFloatingType() || CmpTy->isVectorType()) {
3803 S.Diag(Loc: TheCall->getArg(Arg: NextIdx)->getBeginLoc(),
3804 DiagID: diag::err_typecheck_convert_incompatible)
3805 << CmpTy << S.Context.FloatTy << 1 << 0 << 0;
3806 return true;
3807 }
3808 NextIdx++;
3809 }
3810
3811 // Check the component operand.
3812 Expr *ComponentArg = TheCall->getArg(Arg: NextIdx);
3813 QualType ComponentTy = ComponentArg->getType();
3814 if (!ComponentTy->isIntegerType() || ComponentTy->isVectorType()) {
3815 S.Diag(Loc: ComponentArg->getBeginLoc(),
3816 DiagID: diag::err_typecheck_convert_incompatible)
3817 << ComponentTy << S.Context.UnsignedIntTy << 1 << 0 << 0;
3818 return true;
3819 }
3820
3821 // GatherCmp operations on Vulkan target must use component 0 (Red).
3822 if (IsCmp && S.getASTContext().getTargetInfo().getTriple().isSPIRV()) {
3823 std::optional<llvm::APSInt> ComponentOpt =
3824 ComponentArg->getIntegerConstantExpr(Ctx: S.getASTContext());
3825 if (ComponentOpt) {
3826 int64_t ComponentVal = ComponentOpt->getSExtValue();
3827 if (ComponentVal != 0) {
3828 // Issue an error if the component is not 0 (Red).
3829 // 0 -> Red, 1 -> Green, 2 -> Blue, 3 -> Alpha
3830 assert(ComponentVal >= 0 && ComponentVal <= 3 &&
3831 "The component is not in the expected range.");
3832 S.Diag(Loc: ComponentArg->getBeginLoc(),
3833 DiagID: diag::err_hlsl_gathercmp_invalid_component)
3834 << ComponentVal;
3835 return true;
3836 }
3837 }
3838 }
3839
3840 NextIdx++;
3841
3842 // Check the offset operand.
3843 const HLSLAttributedResourceType *ResourceTy =
3844 TheCall->getArg(Arg: 0)->getType()->castAs<HLSLAttributedResourceType>();
3845 if (TheCall->getNumArgs() > NextIdx) {
3846 unsigned ExpectedDim =
3847 getResourceDimensions(Dim: ResourceTy->getAttrs().ResourceDimension);
3848 if (CheckVectorElementCount(S: &S, PassedType: TheCall->getArg(Arg: NextIdx)->getType(),
3849 BaseType: S.Context.IntTy, ExpectedCount: ExpectedDim,
3850 Loc: TheCall->getArg(Arg: NextIdx)->getBeginLoc()))
3851 return true;
3852 NextIdx++;
3853 }
3854
3855 assert(ResourceTy->hasContainedType() &&
3856 "Expecting a contained type for resource with a dimension "
3857 "attribute.");
3858 QualType ReturnType = ResourceTy->getContainedType();
3859
3860 if (IsCmp) {
3861 if (!ReturnType->hasFloatingRepresentation()) {
3862 S.Diag(Loc: TheCall->getBeginLoc(), DiagID: diag::err_hlsl_samplecmp_requires_float);
3863 return true;
3864 }
3865 }
3866
3867 if (const auto *VecTy = ReturnType->getAs<VectorType>())
3868 ReturnType = VecTy->getElementType();
3869 ReturnType = S.Context.getExtVectorType(VectorType: ReturnType, NumElts: 4);
3870
3871 TheCall->setType(ReturnType);
3872
3873 return false;
3874}
3875static bool CheckLoadLevelBuiltin(Sema &S, CallExpr *TheCall) {
3876 if (S.checkArgCountRange(Call: TheCall, MinArgCount: 2, MaxArgCount: 3))
3877 return true;
3878
3879 // Check the texture handle.
3880 if (CheckResourceHandle(S: &S, TheCall, ArgIndex: 0,
3881 Check: [](const HLSLAttributedResourceType *ResType) {
3882 return ResType->getAttrs().ResourceDimension ==
3883 llvm::dxil::ResourceDimension::Unknown;
3884 }))
3885 return true;
3886
3887 auto *ResourceTy =
3888 TheCall->getArg(Arg: 0)->getType()->castAs<HLSLAttributedResourceType>();
3889
3890 // Check the location + lod (int3 for Texture2D, int4 for Texture2DArray).
3891 unsigned ResourceDim =
3892 getResourceDimensions(Dim: ResourceTy->getAttrs().ResourceDimension);
3893 unsigned LocationDim = ResourceDim + (ResourceTy->getAttrs().IsArray ? 1 : 0);
3894 QualType CoordLODTy = TheCall->getArg(Arg: 1)->getType();
3895 if (CheckVectorElementCount(S: &S, PassedType: CoordLODTy, BaseType: S.Context.IntTy, ExpectedCount: LocationDim + 1,
3896 Loc: TheCall->getArg(Arg: 1)->getBeginLoc()))
3897 return true;
3898
3899 QualType EltTy = CoordLODTy;
3900 if (const auto *VTy = EltTy->getAs<VectorType>())
3901 EltTy = VTy->getElementType();
3902 if (!EltTy->isIntegerType()) {
3903 S.Diag(Loc: TheCall->getArg(Arg: 1)->getBeginLoc(), DiagID: diag::err_typecheck_expect_int)
3904 << CoordLODTy;
3905 return true;
3906 }
3907
3908 // Check the offset operand (int2 for 2D textures; no array slice).
3909 if (TheCall->getNumArgs() > 2) {
3910 if (CheckVectorElementCount(S: &S, PassedType: TheCall->getArg(Arg: 2)->getType(),
3911 BaseType: S.Context.IntTy, ExpectedCount: ResourceDim,
3912 Loc: TheCall->getArg(Arg: 2)->getBeginLoc()))
3913 return true;
3914 }
3915
3916 TheCall->setType(ResourceTy->getContainedType());
3917 return false;
3918}
3919
3920static bool CheckSamplingBuiltin(Sema &S, CallExpr *TheCall, SampleKind Kind) {
3921 unsigned MinArgs, MaxArgs;
3922 if (Kind == SampleKind::Sample) {
3923 MinArgs = 3;
3924 MaxArgs = 5;
3925 } else if (Kind == SampleKind::Bias) {
3926 MinArgs = 4;
3927 MaxArgs = 6;
3928 } else if (Kind == SampleKind::Grad) {
3929 MinArgs = 5;
3930 MaxArgs = 7;
3931 } else if (Kind == SampleKind::Level) {
3932 MinArgs = 4;
3933 MaxArgs = 5;
3934 } else if (Kind == SampleKind::Cmp) {
3935 MinArgs = 4;
3936 MaxArgs = 6;
3937 } else {
3938 assert(Kind == SampleKind::CmpLevelZero);
3939 MinArgs = 4;
3940 MaxArgs = 5;
3941 }
3942
3943 if (S.checkArgCountRange(Call: TheCall, MinArgCount: MinArgs, MaxArgCount: MaxArgs))
3944 return true;
3945
3946 if (CheckTextureSamplerAndLocation(S, TheCall))
3947 return true;
3948
3949 const HLSLAttributedResourceType *ResourceTy =
3950 TheCall->getArg(Arg: 0)->getType()->castAs<HLSLAttributedResourceType>();
3951 unsigned ExpectedDim =
3952 getResourceDimensions(Dim: ResourceTy->getAttrs().ResourceDimension);
3953
3954 unsigned NextIdx = 3;
3955 if (Kind == SampleKind::Bias || Kind == SampleKind::Level ||
3956 Kind == SampleKind::Cmp || Kind == SampleKind::CmpLevelZero) {
3957 // Check the bias, lod level, or compare value, depending on the kind.
3958 // All of them must be a scalar float value.
3959 QualType BiasOrLODOrCmpTy = TheCall->getArg(Arg: NextIdx)->getType();
3960 if (!BiasOrLODOrCmpTy->isFloatingType() ||
3961 BiasOrLODOrCmpTy->isVectorType()) {
3962 S.Diag(Loc: TheCall->getArg(Arg: NextIdx)->getBeginLoc(),
3963 DiagID: diag::err_typecheck_convert_incompatible)
3964 << BiasOrLODOrCmpTy << S.Context.FloatTy << 1 << 0 << 0;
3965 return true;
3966 }
3967 NextIdx++;
3968 } else if (Kind == SampleKind::Grad) {
3969 // Check the DDX operand.
3970 if (CheckVectorElementCount(S: &S, PassedType: TheCall->getArg(Arg: NextIdx)->getType(),
3971 BaseType: S.Context.FloatTy, ExpectedCount: ExpectedDim,
3972 Loc: TheCall->getArg(Arg: NextIdx)->getBeginLoc()))
3973 return true;
3974
3975 // Check the DDY operand.
3976 if (CheckVectorElementCount(S: &S, PassedType: TheCall->getArg(Arg: NextIdx + 1)->getType(),
3977 BaseType: S.Context.FloatTy, ExpectedCount: ExpectedDim,
3978 Loc: TheCall->getArg(Arg: NextIdx + 1)->getBeginLoc()))
3979 return true;
3980 NextIdx += 2;
3981 }
3982
3983 // Check the offset operand.
3984 if (TheCall->getNumArgs() > NextIdx) {
3985 if (CheckVectorElementCount(S: &S, PassedType: TheCall->getArg(Arg: NextIdx)->getType(),
3986 BaseType: S.Context.IntTy, ExpectedCount: ExpectedDim,
3987 Loc: TheCall->getArg(Arg: NextIdx)->getBeginLoc()))
3988 return true;
3989 NextIdx++;
3990 }
3991
3992 // Check the clamp operand.
3993 if (Kind != SampleKind::Level && Kind != SampleKind::CmpLevelZero &&
3994 TheCall->getNumArgs() > NextIdx) {
3995 QualType ClampTy = TheCall->getArg(Arg: NextIdx)->getType();
3996 if (!ClampTy->isFloatingType() || ClampTy->isVectorType()) {
3997 S.Diag(Loc: TheCall->getArg(Arg: NextIdx)->getBeginLoc(),
3998 DiagID: diag::err_typecheck_convert_incompatible)
3999 << ClampTy << S.Context.FloatTy << 1 << 0 << 0;
4000 return true;
4001 }
4002 }
4003
4004 assert(ResourceTy->hasContainedType() &&
4005 "Expecting a contained type for resource with a dimension "
4006 "attribute.");
4007 QualType ReturnType = ResourceTy->getContainedType();
4008 if (Kind == SampleKind::Cmp || Kind == SampleKind::CmpLevelZero) {
4009 if (!ReturnType->hasFloatingRepresentation()) {
4010 S.Diag(Loc: TheCall->getBeginLoc(), DiagID: diag::err_hlsl_samplecmp_requires_float);
4011 return true;
4012 }
4013 ReturnType = S.Context.FloatTy;
4014 }
4015 TheCall->setType(ReturnType);
4016
4017 return false;
4018}
4019
4020// Note: returning true in this case results in CheckBuiltinFunctionCall
4021// returning an ExprError
4022bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
4023 switch (BuiltinID) {
4024 case Builtin::BI__builtin_hlsl_adduint64: {
4025 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
4026 return true;
4027
4028 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4029 Check: CheckUnsignedIntVecRepresentation))
4030 return true;
4031
4032 // ensure arg integers are 32-bits
4033 if (CheckExpectedBitWidth(S: &SemaRef, TheCall, ArgOrdinal: 0, Width: 32))
4034 return true;
4035
4036 // ensure both args are vectors of total bit size of a multiple of 64
4037 auto *VTy = TheCall->getArg(Arg: 0)->getType()->getAs<VectorType>();
4038 int NumElementsArg = VTy->getNumElements();
4039 if (NumElementsArg != 2 && NumElementsArg != 4) {
4040 SemaRef.Diag(Loc: TheCall->getBeginLoc(), DiagID: diag::err_vector_incorrect_bit_count)
4041 << 1 /*a multiple of*/ << 64 << NumElementsArg * 32;
4042 return true;
4043 }
4044
4045 // ensure first arg and second arg have the same type
4046 if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
4047 return true;
4048
4049 ExprResult A = TheCall->getArg(Arg: 0);
4050 QualType ArgTyA = A.get()->getType();
4051 // return type is the same as the input type
4052 TheCall->setType(ArgTyA);
4053 break;
4054 }
4055 case Builtin::BI__builtin_hlsl_resource_getpointer: {
4056 if (SemaRef.checkArgCountRange(Call: TheCall, MinArgCount: 1, MaxArgCount: 2) ||
4057 CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0) ||
4058 (TheCall->getNumArgs() == 2 && CheckIndexType(S: &SemaRef, TheCall, IndexArgIndex: 1)))
4059 return true;
4060
4061 auto *ResourceTy =
4062 TheCall->getArg(Arg: 0)->getType()->castAs<HLSLAttributedResourceType>();
4063 QualType ContainedTy = ResourceTy->getContainedType();
4064 auto ReturnType = SemaRef.Context.getAddrSpaceQualType(
4065 T: ContainedTy,
4066 AddressSpace: getLangASFromResourceClass(RC: ResourceTy->getAttrs().ResourceClass));
4067 ReturnType = SemaRef.Context.getPointerType(T: ReturnType);
4068 TheCall->setType(ReturnType);
4069
4070 break;
4071 }
4072 case Builtin::BI__builtin_hlsl_resource_getpointer_typed: {
4073 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3) ||
4074 CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0) ||
4075 CheckIndexType(S: &SemaRef, TheCall, IndexArgIndex: 1))
4076 return true;
4077
4078 QualType ElementTy = TheCall->getArg(Arg: 2)->getType();
4079 assert(ElementTy->isPointerType() &&
4080 "expected pointer type for second argument");
4081 ElementTy = ElementTy->getPointeeType();
4082
4083 // Reject array types
4084 if (ElementTy->isArrayType())
4085 return SemaRef.Diag(
4086 Loc: cast<FunctionDecl>(Val: SemaRef.CurContext)->getPointOfInstantiation(),
4087 DiagID: diag::err_invalid_use_of_array_type);
4088
4089 auto *ResourceTy =
4090 TheCall->getArg(Arg: 0)->getType()->castAs<HLSLAttributedResourceType>();
4091 auto ReturnType = SemaRef.Context.getAddrSpaceQualType(
4092 T: ElementTy,
4093 AddressSpace: getLangASFromResourceClass(RC: ResourceTy->getAttrs().ResourceClass));
4094 ReturnType = SemaRef.Context.getPointerType(T: ReturnType);
4095 TheCall->setType(ReturnType);
4096
4097 break;
4098 }
4099 case Builtin::BI__builtin_hlsl_resource_load_with_status: {
4100 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3) ||
4101 CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0) ||
4102 CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 1),
4103 ExpectedType: SemaRef.getASTContext().UnsignedIntTy) ||
4104 CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 2),
4105 ExpectedType: SemaRef.getASTContext().UnsignedIntTy) ||
4106 CheckModifiableLValue(S: &SemaRef, TheCall, ArgIndex: 2))
4107 return true;
4108
4109 auto *ResourceTy =
4110 TheCall->getArg(Arg: 0)->getType()->castAs<HLSLAttributedResourceType>();
4111 QualType ReturnType = ResourceTy->getContainedType();
4112 TheCall->setType(ReturnType);
4113
4114 break;
4115 }
4116 case Builtin::BI__builtin_hlsl_resource_load_with_status_typed: {
4117 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 4) ||
4118 CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0) ||
4119 CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 1),
4120 ExpectedType: SemaRef.getASTContext().UnsignedIntTy) ||
4121 CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 2),
4122 ExpectedType: SemaRef.getASTContext().UnsignedIntTy) ||
4123 CheckModifiableLValue(S: &SemaRef, TheCall, ArgIndex: 2))
4124 return true;
4125
4126 QualType ReturnType = TheCall->getArg(Arg: 3)->getType();
4127 assert(ReturnType->isPointerType() &&
4128 "expected pointer type for second argument");
4129 ReturnType = ReturnType->getPointeeType();
4130
4131 // Reject array types
4132 if (ReturnType->isArrayType())
4133 return SemaRef.Diag(
4134 Loc: cast<FunctionDecl>(Val: SemaRef.CurContext)->getPointOfInstantiation(),
4135 DiagID: diag::err_invalid_use_of_array_type);
4136
4137 TheCall->setType(ReturnType);
4138
4139 break;
4140 }
4141 case Builtin::BI__builtin_hlsl_resource_load_level:
4142 return CheckLoadLevelBuiltin(S&: SemaRef, TheCall);
4143 case Builtin::BI__builtin_hlsl_resource_sample:
4144 return CheckSamplingBuiltin(S&: SemaRef, TheCall, Kind: SampleKind::Sample);
4145 case Builtin::BI__builtin_hlsl_resource_sample_bias:
4146 return CheckSamplingBuiltin(S&: SemaRef, TheCall, Kind: SampleKind::Bias);
4147 case Builtin::BI__builtin_hlsl_resource_sample_grad:
4148 return CheckSamplingBuiltin(S&: SemaRef, TheCall, Kind: SampleKind::Grad);
4149 case Builtin::BI__builtin_hlsl_resource_sample_level:
4150 return CheckSamplingBuiltin(S&: SemaRef, TheCall, Kind: SampleKind::Level);
4151 case Builtin::BI__builtin_hlsl_resource_sample_cmp:
4152 return CheckSamplingBuiltin(S&: SemaRef, TheCall, Kind: SampleKind::Cmp);
4153 case Builtin::BI__builtin_hlsl_resource_sample_cmp_level_zero:
4154 return CheckSamplingBuiltin(S&: SemaRef, TheCall, Kind: SampleKind::CmpLevelZero);
4155 case Builtin::BI__builtin_hlsl_resource_calculate_lod:
4156 case Builtin::BI__builtin_hlsl_resource_calculate_lod_unclamped:
4157 return CheckCalculateLodBuiltin(S&: SemaRef, TheCall);
4158 case Builtin::BI__builtin_hlsl_resource_gather:
4159 return CheckGatherBuiltin(S&: SemaRef, TheCall, /*IsCmp=*/false);
4160 case Builtin::BI__builtin_hlsl_resource_gather_cmp:
4161 return CheckGatherBuiltin(S&: SemaRef, TheCall, /*IsCmp=*/true);
4162 case Builtin::BI__builtin_hlsl_resource_uninitializedhandle: {
4163 assert(TheCall->getNumArgs() == 1 && "expected 1 arg");
4164 // Update return type to be the attributed resource type from arg0.
4165 QualType ResourceTy = TheCall->getArg(Arg: 0)->getType();
4166 TheCall->setType(ResourceTy);
4167 break;
4168 }
4169 case Builtin::BI__builtin_hlsl_resource_handlefrombinding: {
4170 assert(TheCall->getNumArgs() == 6 && "expected 6 args");
4171 // Update return type to be the attributed resource type from arg0.
4172 QualType ResourceTy = TheCall->getArg(Arg: 0)->getType();
4173 TheCall->setType(ResourceTy);
4174 break;
4175 }
4176 case Builtin::BI__builtin_hlsl_resource_handlefromimplicitbinding: {
4177 assert(TheCall->getNumArgs() == 6 && "expected 6 args");
4178 // Update return type to be the attributed resource type from arg0.
4179 QualType ResourceTy = TheCall->getArg(Arg: 0)->getType();
4180 TheCall->setType(ResourceTy);
4181 break;
4182 }
4183 case Builtin::BI__builtin_hlsl_resource_counterhandlefromimplicitbinding: {
4184 assert(TheCall->getNumArgs() == 3 && "expected 3 args");
4185 ASTContext &AST = SemaRef.getASTContext();
4186 QualType MainHandleTy = TheCall->getArg(Arg: 0)->getType();
4187 auto *MainResType = MainHandleTy->getAs<HLSLAttributedResourceType>();
4188 auto MainAttrs = MainResType->getAttrs();
4189 assert(!MainAttrs.IsCounter && "cannot create a counter from a counter");
4190 MainAttrs.IsCounter = true;
4191 QualType CounterHandleTy = AST.getHLSLAttributedResourceType(
4192 Wrapped: MainResType->getWrappedType(), Contained: MainResType->getContainedType(),
4193 Attrs: MainAttrs);
4194 // Update return type to be the attributed resource type from arg0
4195 // with added IsCounter flag.
4196 TheCall->setType(CounterHandleTy);
4197 break;
4198 }
4199 case Builtin::BI__builtin_hlsl_and:
4200 case Builtin::BI__builtin_hlsl_or: {
4201 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
4202 return true;
4203 if (CheckScalarOrVectorOrMatrix(S: &SemaRef, TheCall, Scalar: getASTContext().BoolTy,
4204 ArgIndex: 0))
4205 return true;
4206 if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
4207 return true;
4208
4209 ExprResult A = TheCall->getArg(Arg: 0);
4210 QualType ArgTyA = A.get()->getType();
4211 // return type is the same as the input type
4212 TheCall->setType(ArgTyA);
4213 break;
4214 }
4215 case Builtin::BI__builtin_hlsl_all:
4216 case Builtin::BI__builtin_hlsl_any: {
4217 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4218 return true;
4219 if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
4220 return true;
4221 break;
4222 }
4223 case Builtin::BI__builtin_hlsl_asdouble: {
4224 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
4225 return true;
4226 if (CheckScalarOrVector(
4227 S: &SemaRef, TheCall,
4228 /*only check for uint*/ Scalar: SemaRef.Context.UnsignedIntTy,
4229 /* arg index */ ArgIndex: 0))
4230 return true;
4231 if (CheckScalarOrVector(
4232 S: &SemaRef, TheCall,
4233 /*only check for uint*/ Scalar: SemaRef.Context.UnsignedIntTy,
4234 /* arg index */ ArgIndex: 1))
4235 return true;
4236 if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
4237 return true;
4238
4239 SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().DoubleTy);
4240 break;
4241 }
4242 case Builtin::BI__builtin_hlsl_elementwise_clamp: {
4243 if (SemaRef.BuiltinElementwiseTernaryMath(
4244 TheCall, /*ArgTyRestr=*/
4245 Sema::EltwiseBuiltinArgTyRestriction::None))
4246 return true;
4247 break;
4248 }
4249 case Builtin::BI__builtin_hlsl_dot: {
4250 // arg count is checked by BuiltinVectorToScalarMath
4251 if (SemaRef.BuiltinVectorToScalarMath(TheCall))
4252 return true;
4253 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall, Check: CheckNoDoubleVectors))
4254 return true;
4255 break;
4256 }
4257 case Builtin::BI__builtin_hlsl_elementwise_firstbithigh:
4258 case Builtin::BI__builtin_hlsl_elementwise_firstbitlow: {
4259 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
4260 return true;
4261
4262 const Expr *Arg = TheCall->getArg(Arg: 0);
4263 QualType ArgTy = Arg->getType();
4264 QualType EltTy = ArgTy;
4265
4266 QualType ResTy = SemaRef.Context.UnsignedIntTy;
4267
4268 if (auto *VecTy = EltTy->getAs<VectorType>()) {
4269 EltTy = VecTy->getElementType();
4270 ResTy = SemaRef.Context.getExtVectorType(VectorType: ResTy, NumElts: VecTy->getNumElements());
4271 }
4272
4273 if (!EltTy->isIntegerType()) {
4274 Diag(Loc: Arg->getBeginLoc(), DiagID: diag::err_builtin_invalid_arg_type)
4275 << 1 << /* scalar or vector of */ 5 << /* integer ty */ 1
4276 << /* no fp */ 0 << ArgTy;
4277 return true;
4278 }
4279
4280 TheCall->setType(ResTy);
4281 break;
4282 }
4283 case Builtin::BI__builtin_hlsl_select: {
4284 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3))
4285 return true;
4286 if (CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: getASTContext().BoolTy, ArgIndex: 0))
4287 return true;
4288 QualType ArgTy = TheCall->getArg(Arg: 0)->getType();
4289 if (ArgTy->isBooleanType() && CheckBoolSelect(S: &SemaRef, TheCall))
4290 return true;
4291 auto *VTy = ArgTy->getAs<VectorType>();
4292 if (VTy && VTy->getElementType()->isBooleanType() &&
4293 CheckVectorSelect(S: &SemaRef, TheCall))
4294 return true;
4295 break;
4296 }
4297 case Builtin::BI__builtin_hlsl_elementwise_saturate:
4298 case Builtin::BI__builtin_hlsl_elementwise_rcp: {
4299 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4300 return true;
4301 if (!TheCall->getArg(Arg: 0)
4302 ->getType()
4303 ->hasFloatingRepresentation()) // half or float or double
4304 return SemaRef.Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
4305 DiagID: diag::err_builtin_invalid_arg_type)
4306 << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
4307 << /* fp */ 1 << TheCall->getArg(Arg: 0)->getType();
4308 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
4309 return true;
4310 break;
4311 }
4312 case Builtin::BI__builtin_hlsl_elementwise_degrees:
4313 case Builtin::BI__builtin_hlsl_elementwise_radians:
4314 case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
4315 case Builtin::BI__builtin_hlsl_elementwise_frac:
4316 case Builtin::BI__builtin_hlsl_elementwise_ddx_coarse:
4317 case Builtin::BI__builtin_hlsl_elementwise_ddy_coarse:
4318 case Builtin::BI__builtin_hlsl_elementwise_ddx_fine:
4319 case Builtin::BI__builtin_hlsl_elementwise_ddy_fine: {
4320 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4321 return true;
4322 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4323 Check: CheckFloatOrHalfRepresentation))
4324 return true;
4325 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
4326 return true;
4327 break;
4328 }
4329 case Builtin::BI__builtin_hlsl_elementwise_isinf:
4330 case Builtin::BI__builtin_hlsl_elementwise_isnan: {
4331 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4332 return true;
4333 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4334 Check: CheckFloatOrHalfRepresentation))
4335 return true;
4336 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
4337 return true;
4338 SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().BoolTy);
4339 break;
4340 }
4341 case Builtin::BI__builtin_hlsl_lerp: {
4342 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3))
4343 return true;
4344 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4345 Check: CheckFloatOrHalfRepresentation))
4346 return true;
4347 if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
4348 return true;
4349 if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
4350 return true;
4351 break;
4352 }
4353 case Builtin::BI__builtin_hlsl_mad: {
4354 if (SemaRef.BuiltinElementwiseTernaryMath(
4355 TheCall, /*ArgTyRestr=*/
4356 Sema::EltwiseBuiltinArgTyRestriction::None))
4357 return true;
4358 break;
4359 }
4360 case Builtin::BI__builtin_hlsl_mul: {
4361 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
4362 return true;
4363
4364 Expr *Arg0 = TheCall->getArg(Arg: 0);
4365 Expr *Arg1 = TheCall->getArg(Arg: 1);
4366 QualType Ty0 = Arg0->getType();
4367 QualType Ty1 = Arg1->getType();
4368
4369 auto getElemType = [](QualType T) -> QualType {
4370 if (const auto *VTy = T->getAs<VectorType>())
4371 return VTy->getElementType();
4372 if (const auto *MTy = T->getAs<ConstantMatrixType>())
4373 return MTy->getElementType();
4374 return T;
4375 };
4376
4377 QualType EltTy0 = getElemType(Ty0);
4378
4379 bool IsVec0 = Ty0->isVectorType();
4380 bool IsMat0 = Ty0->isConstantMatrixType();
4381 bool IsVec1 = Ty1->isVectorType();
4382 bool IsMat1 = Ty1->isConstantMatrixType();
4383
4384 QualType RetTy;
4385
4386 if (IsVec0 && IsMat1) {
4387 auto *MatTy = Ty1->castAs<ConstantMatrixType>();
4388 RetTy = getASTContext().getExtVectorType(VectorType: EltTy0, NumElts: MatTy->getNumColumns());
4389 } else if (IsMat0 && IsVec1) {
4390 auto *MatTy = Ty0->castAs<ConstantMatrixType>();
4391 RetTy = getASTContext().getExtVectorType(VectorType: EltTy0, NumElts: MatTy->getNumRows());
4392 } else {
4393 assert(IsMat0 && IsMat1);
4394 auto *MatTy0 = Ty0->castAs<ConstantMatrixType>();
4395 auto *MatTy1 = Ty1->castAs<ConstantMatrixType>();
4396 RetTy = getASTContext().getConstantMatrixType(
4397 ElementType: EltTy0, NumRows: MatTy0->getNumRows(), NumColumns: MatTy1->getNumColumns());
4398 }
4399
4400 TheCall->setType(RetTy);
4401 break;
4402 }
4403 case Builtin::BI__builtin_hlsl_normalize: {
4404 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4405 return true;
4406 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4407 Check: CheckFloatOrHalfRepresentation))
4408 return true;
4409 ExprResult A = TheCall->getArg(Arg: 0);
4410 QualType ArgTyA = A.get()->getType();
4411 // return type is the same as the input type
4412 TheCall->setType(ArgTyA);
4413 break;
4414 }
4415 case Builtin::BI__builtin_elementwise_fma: {
4416 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3) ||
4417 CheckAllArgsHaveSameType(S: &SemaRef, TheCall)) {
4418 return true;
4419 }
4420
4421 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4422 Check: CheckAnyDoubleRepresentation))
4423 return true;
4424
4425 ExprResult A = TheCall->getArg(Arg: 0);
4426 QualType ArgTyA = A.get()->getType();
4427 // return type is the same as input type
4428 TheCall->setType(ArgTyA);
4429 break;
4430 }
4431 case Builtin::BI__builtin_hlsl_transpose: {
4432 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4433 return true;
4434
4435 Expr *Arg = TheCall->getArg(Arg: 0);
4436 QualType ArgTy = Arg->getType();
4437
4438 const auto *MatTy = ArgTy->getAs<ConstantMatrixType>();
4439 if (!MatTy) {
4440 SemaRef.Diag(Loc: Arg->getBeginLoc(), DiagID: diag::err_builtin_invalid_arg_type)
4441 << 1 << /* matrix */ 3 << /* no int */ 0 << /* no fp */ 0 << ArgTy;
4442 return true;
4443 }
4444
4445 QualType RetTy = getASTContext().getConstantMatrixType(
4446 ElementType: MatTy->getElementType(), NumRows: MatTy->getNumColumns(), NumColumns: MatTy->getNumRows());
4447 TheCall->setType(RetTy);
4448 break;
4449 }
4450 case Builtin::BI__builtin_hlsl_elementwise_sign: {
4451 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
4452 return true;
4453 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4454 Check: CheckFloatingOrIntRepresentation))
4455 return true;
4456 SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().IntTy);
4457 break;
4458 }
4459 case Builtin::BI__builtin_hlsl_step: {
4460 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
4461 return true;
4462 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4463 Check: CheckFloatOrHalfRepresentation))
4464 return true;
4465
4466 ExprResult A = TheCall->getArg(Arg: 0);
4467 QualType ArgTyA = A.get()->getType();
4468 // return type is the same as the input type
4469 TheCall->setType(ArgTyA);
4470 break;
4471 }
4472 case Builtin::BI__builtin_hlsl_wave_active_all_equal: {
4473 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4474 return true;
4475
4476 // Ensure input expr type is a scalar/vector
4477 if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
4478 return true;
4479
4480 QualType InputTy = TheCall->getArg(Arg: 0)->getType();
4481 ASTContext &Ctx = getASTContext();
4482
4483 QualType RetTy;
4484
4485 // If vector, construct bool vector of same size
4486 if (const auto *VecTy = InputTy->getAs<ExtVectorType>()) {
4487 unsigned NumElts = VecTy->getNumElements();
4488 RetTy = Ctx.getExtVectorType(VectorType: Ctx.BoolTy, NumElts);
4489 } else {
4490 // Scalar case
4491 RetTy = Ctx.BoolTy;
4492 }
4493
4494 TheCall->setType(RetTy);
4495 break;
4496 }
4497 case Builtin::BI__builtin_hlsl_wave_active_max:
4498 case Builtin::BI__builtin_hlsl_wave_active_min:
4499 case Builtin::BI__builtin_hlsl_wave_active_sum:
4500 case Builtin::BI__builtin_hlsl_wave_active_product: {
4501 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4502 return true;
4503
4504 // Ensure input expr type is a scalar/vector and the same as the return type
4505 if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
4506 return true;
4507 if (CheckWaveActive(S: &SemaRef, TheCall))
4508 return true;
4509 ExprResult Expr = TheCall->getArg(Arg: 0);
4510 QualType ArgTyExpr = Expr.get()->getType();
4511 TheCall->setType(ArgTyExpr);
4512 break;
4513 }
4514 case Builtin::BI__builtin_hlsl_wave_active_bit_or:
4515 case Builtin::BI__builtin_hlsl_wave_active_bit_xor:
4516 case Builtin::BI__builtin_hlsl_wave_active_bit_and: {
4517 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4518 return true;
4519
4520 // Ensure input expr type is a scalar/vector
4521 if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
4522 return true;
4523
4524 if (CheckWaveActive(S: &SemaRef, TheCall))
4525 return true;
4526
4527 // Ensure the expr type is interpretable as a uint or vector<uint>
4528 ExprResult Expr = TheCall->getArg(Arg: 0);
4529 QualType ArgTyExpr = Expr.get()->getType();
4530 auto *VTy = ArgTyExpr->getAs<VectorType>();
4531 if (!(ArgTyExpr->isIntegerType() ||
4532 (VTy && VTy->getElementType()->isIntegerType()))) {
4533 SemaRef.Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
4534 DiagID: diag::err_builtin_invalid_arg_type)
4535 << ArgTyExpr << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
4536 return true;
4537 }
4538
4539 // Ensure input expr type is the same as the return type
4540 TheCall->setType(ArgTyExpr);
4541 break;
4542 }
4543 case Builtin::BI__builtin_hlsl_interlocked_add: {
4544 // The builtin's prototype in Builtins.td is `void (...)`, so direct calls
4545 // to `__builtin_hlsl_interlocked_add` bypass argument checking entirely.
4546 // When reached via the synthesized `InterlockedAdd` overload set in
4547 // HLSLExternalSemaSource, overload resolution has already enforced the
4548 // argument count, integer-type matching, and the address-space requirement
4549 // on `dest`. The checks below are a safety net for callers that invoke the
4550 // builtin by its mangled name and would otherwise reach CodeGen unchecked.
4551 if (TheCall->getNumArgs() < 2) {
4552 SemaRef.Diag(Loc: TheCall->getEndLoc(),
4553 DiagID: diag::err_typecheck_call_too_few_args_at_least)
4554 << /*callee_type=*/0 << /*min_arg_count=*/2 << TheCall->getNumArgs()
4555 << /*is_non_object=*/0 << TheCall->getSourceRange();
4556 return true;
4557 }
4558 if (SemaRef.checkArgCountAtMost(Call: TheCall, MaxArgCount: 3))
4559 return true;
4560
4561 QualType DestTy = TheCall->getArg(Arg: 0)->getType().getUnqualifiedType();
4562 if (!DestTy->isIntegerType()) {
4563 SemaRef.Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
4564 DiagID: diag::err_builtin_invalid_arg_type)
4565 << /*ordinal=*/1 << /*scalar*/ 1 << /*integer*/ 1 << /*no float*/ 0
4566 << DestTy;
4567 return true;
4568 }
4569
4570 if (CheckModifiableLValue(S: &SemaRef, TheCall, ArgIndex: 0))
4571 return true;
4572
4573 if (CheckArgAddrSpaceOneOf(S: &SemaRef, TheCall, ArgIndex: 0,
4574 AllowedSpaces: {LangAS::hlsl_groupshared, LangAS::hlsl_device}))
4575 return true;
4576
4577 if (CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 1), ExpectedType: DestTy))
4578 return true;
4579
4580 if (TheCall->getNumArgs() == 3) {
4581 if (CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 2), ExpectedType: DestTy))
4582 return true;
4583 if (CheckModifiableLValue(S: &SemaRef, TheCall, ArgIndex: 2))
4584 return true;
4585 }
4586
4587 TheCall->setType(SemaRef.Context.VoidTy);
4588 break;
4589 }
4590 // Note these are llvm builtins that we want to catch invalid intrinsic
4591 // generation. Normal handling of these builtins will occur elsewhere.
4592 case Builtin::BI__builtin_elementwise_bitreverse: {
4593 // does not include a check for number of arguments
4594 // because that is done previously
4595 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4596 Check: CheckUnsignedIntRepresentation))
4597 return true;
4598 break;
4599 }
4600 case Builtin::BI__builtin_hlsl_wave_prefix_count_bits: {
4601 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4602 return true;
4603
4604 QualType ArgType = TheCall->getArg(Arg: 0)->getType();
4605
4606 if (!(ArgType->isScalarType())) {
4607 SemaRef.Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
4608 DiagID: diag::err_typecheck_expect_any_scalar_or_vector)
4609 << ArgType << 0;
4610 return true;
4611 }
4612
4613 if (!(ArgType->isBooleanType())) {
4614 SemaRef.Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
4615 DiagID: diag::err_typecheck_expect_any_scalar_or_vector)
4616 << ArgType << 0;
4617 return true;
4618 }
4619
4620 break;
4621 }
4622 case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
4623 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
4624 return true;
4625
4626 // Ensure index parameter type can be interpreted as a uint
4627 ExprResult Index = TheCall->getArg(Arg: 1);
4628 QualType ArgTyIndex = Index.get()->getType();
4629 if (!ArgTyIndex->isIntegerType()) {
4630 SemaRef.Diag(Loc: TheCall->getArg(Arg: 1)->getBeginLoc(),
4631 DiagID: diag::err_typecheck_convert_incompatible)
4632 << ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
4633 return true;
4634 }
4635
4636 // Ensure input expr type is a scalar/vector and the same as the return type
4637 if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
4638 return true;
4639
4640 ExprResult Expr = TheCall->getArg(Arg: 0);
4641 QualType ArgTyExpr = Expr.get()->getType();
4642 TheCall->setType(ArgTyExpr);
4643 break;
4644 }
4645 case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
4646 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 0))
4647 return true;
4648 break;
4649 }
4650 case Builtin::BI__builtin_hlsl_wave_prefix_sum:
4651 case Builtin::BI__builtin_hlsl_wave_prefix_product: {
4652 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4653 return true;
4654
4655 // Ensure input expr type is a scalar/vector and the same as the return type
4656 if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
4657 return true;
4658 if (CheckWavePrefix(S: &SemaRef, TheCall))
4659 return true;
4660 ExprResult Expr = TheCall->getArg(Arg: 0);
4661 QualType ArgTyExpr = Expr.get()->getType();
4662 TheCall->setType(ArgTyExpr);
4663 break;
4664 }
4665 case Builtin::BI__builtin_hlsl_quad_read_across_x:
4666 case Builtin::BI__builtin_hlsl_quad_read_across_y:
4667 case Builtin::BI__builtin_hlsl_quad_read_across_diagonal: {
4668 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4669 return true;
4670
4671 if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
4672 return true;
4673 if (CheckNotBoolScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
4674 return true;
4675 ExprResult Expr = TheCall->getArg(Arg: 0);
4676 QualType ArgTyExpr = Expr.get()->getType();
4677 TheCall->setType(ArgTyExpr);
4678 break;
4679 }
4680 case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
4681 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3))
4682 return true;
4683
4684 if (CheckScalarOrVectorOrMatrix(S: &SemaRef, TheCall, Scalar: SemaRef.Context.DoubleTy,
4685 ArgIndex: 0) ||
4686 CheckScalarOrVectorOrMatrix(S: &SemaRef, TheCall,
4687 Scalar: SemaRef.Context.UnsignedIntTy, ArgIndex: 1) ||
4688 CheckScalarOrVectorOrMatrix(S: &SemaRef, TheCall,
4689 Scalar: SemaRef.Context.UnsignedIntTy, ArgIndex: 2))
4690 return true;
4691
4692 if (CheckModifiableLValue(S: &SemaRef, TheCall, ArgIndex: 1) ||
4693 CheckModifiableLValue(S: &SemaRef, TheCall, ArgIndex: 2))
4694 return true;
4695 break;
4696 }
4697 case Builtin::BI__builtin_hlsl_elementwise_clip: {
4698 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4699 return true;
4700
4701 if (CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: SemaRef.Context.FloatTy, ArgIndex: 0))
4702 return true;
4703 break;
4704 }
4705 case Builtin::BI__builtin_elementwise_acos:
4706 case Builtin::BI__builtin_elementwise_asin:
4707 case Builtin::BI__builtin_elementwise_atan:
4708 case Builtin::BI__builtin_elementwise_atan2:
4709 case Builtin::BI__builtin_elementwise_ceil:
4710 case Builtin::BI__builtin_elementwise_cos:
4711 case Builtin::BI__builtin_elementwise_cosh:
4712 case Builtin::BI__builtin_elementwise_exp:
4713 case Builtin::BI__builtin_elementwise_exp2:
4714 case Builtin::BI__builtin_elementwise_exp10:
4715 case Builtin::BI__builtin_elementwise_floor:
4716 case Builtin::BI__builtin_elementwise_fmod:
4717 case Builtin::BI__builtin_elementwise_log:
4718 case Builtin::BI__builtin_elementwise_log2:
4719 case Builtin::BI__builtin_elementwise_log10:
4720 case Builtin::BI__builtin_elementwise_pow:
4721 case Builtin::BI__builtin_elementwise_roundeven:
4722 case Builtin::BI__builtin_elementwise_sin:
4723 case Builtin::BI__builtin_elementwise_sinh:
4724 case Builtin::BI__builtin_elementwise_sqrt:
4725 case Builtin::BI__builtin_elementwise_tan:
4726 case Builtin::BI__builtin_elementwise_tanh:
4727 case Builtin::BI__builtin_elementwise_trunc: {
4728 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4729 Check: CheckFloatOrHalfRepresentation))
4730 return true;
4731 break;
4732 }
4733 case Builtin::BI__builtin_hlsl_buffer_update_counter: {
4734 assert(TheCall->getNumArgs() == 2 && "expected 2 args");
4735 auto checkResTy = [](const HLSLAttributedResourceType *ResTy) -> bool {
4736 return !(ResTy->getAttrs().ResourceClass == ResourceClass::UAV &&
4737 ResTy->getAttrs().RawBuffer && ResTy->hasContainedType());
4738 };
4739 if (CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0, Check: checkResTy))
4740 return true;
4741 Expr *OffsetExpr = TheCall->getArg(Arg: 1);
4742 std::optional<llvm::APSInt> Offset =
4743 OffsetExpr->getIntegerConstantExpr(Ctx: SemaRef.getASTContext());
4744 if (!Offset.has_value() || std::abs(i: Offset->getExtValue()) != 1) {
4745 SemaRef.Diag(Loc: TheCall->getArg(Arg: 1)->getBeginLoc(),
4746 DiagID: diag::err_hlsl_expect_arg_const_int_one_or_neg_one)
4747 << 1;
4748 return true;
4749 }
4750 break;
4751 }
4752 case Builtin::BI__builtin_hlsl_elementwise_f16tof32: {
4753 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4754 return true;
4755 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4756 Check: CheckUnsignedIntRepresentation))
4757 return true;
4758 // ensure arg integers are 32 bits
4759 if (CheckExpectedBitWidth(S: &SemaRef, TheCall, ArgOrdinal: 0, Width: 32))
4760 return true;
4761 // check it wasn't a bool type
4762 QualType ArgTy = TheCall->getArg(Arg: 0)->getType();
4763 if (auto *VTy = ArgTy->getAs<VectorType>())
4764 ArgTy = VTy->getElementType();
4765 if (ArgTy->isBooleanType()) {
4766 SemaRef.Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
4767 DiagID: diag::err_builtin_invalid_arg_type)
4768 << 1 << /* scalar or vector of */ 5 << /* unsigned int */ 3
4769 << /* no fp */ 0 << TheCall->getArg(Arg: 0)->getType();
4770 return true;
4771 }
4772
4773 SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().FloatTy);
4774 break;
4775 }
4776 case Builtin::BI__builtin_hlsl_elementwise_f32tof16: {
4777 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4778 return true;
4779 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall, Check: CheckFloatRepresentation))
4780 return true;
4781 SetElementTypeAsReturnType(S: &SemaRef, TheCall,
4782 ReturnType: getASTContext().UnsignedIntTy);
4783 break;
4784 }
4785 }
4786 return false;
4787}
4788
4789static void BuildFlattenedTypeList(QualType BaseTy,
4790 llvm::SmallVectorImpl<QualType> &List) {
4791 llvm::SmallVector<QualType, 16> WorkList;
4792 WorkList.push_back(Elt: BaseTy);
4793 while (!WorkList.empty()) {
4794 QualType T = WorkList.pop_back_val();
4795 T = T.getCanonicalType().getUnqualifiedType();
4796 if (const auto *AT = dyn_cast<ConstantArrayType>(Val&: T)) {
4797 llvm::SmallVector<QualType, 16> ElementFields;
4798 // Generally I've avoided recursion in this algorithm, but arrays of
4799 // structs could be time-consuming to flatten and churn through on the
4800 // work list. Hopefully nesting arrays of structs containing arrays
4801 // of structs too many levels deep is unlikely.
4802 BuildFlattenedTypeList(BaseTy: AT->getElementType(), List&: ElementFields);
4803 // Repeat the element's field list n times.
4804 for (uint64_t Ct = 0; Ct < AT->getZExtSize(); ++Ct)
4805 llvm::append_range(C&: List, R&: ElementFields);
4806 continue;
4807 }
4808 // Vectors can only have element types that are builtin types, so this can
4809 // add directly to the list instead of to the WorkList.
4810 if (const auto *VT = dyn_cast<VectorType>(Val&: T)) {
4811 List.insert(I: List.end(), NumToInsert: VT->getNumElements(), Elt: VT->getElementType());
4812 continue;
4813 }
4814 if (const auto *MT = dyn_cast<ConstantMatrixType>(Val&: T)) {
4815 List.insert(I: List.end(), NumToInsert: MT->getNumElementsFlattened(),
4816 Elt: MT->getElementType());
4817 continue;
4818 }
4819 if (const auto *RD = T->getAsCXXRecordDecl()) {
4820 if (RD->isStandardLayout())
4821 RD = RD->getStandardLayoutBaseWithFields();
4822
4823 // For types that we shouldn't decompose (unions and non-aggregates), just
4824 // add the type itself to the list.
4825 if (RD->isUnion() || !RD->isAggregate()) {
4826 List.push_back(Elt: T);
4827 continue;
4828 }
4829
4830 llvm::SmallVector<QualType, 16> FieldTypes;
4831 for (const auto *FD : RD->fields())
4832 if (!FD->isUnnamedBitField())
4833 FieldTypes.push_back(Elt: FD->getType());
4834 // Reverse the newly added sub-range.
4835 std::reverse(first: FieldTypes.begin(), last: FieldTypes.end());
4836 llvm::append_range(C&: WorkList, R&: FieldTypes);
4837
4838 // If this wasn't a standard layout type we may also have some base
4839 // classes to deal with.
4840 if (!RD->isStandardLayout()) {
4841 FieldTypes.clear();
4842 for (const auto &Base : RD->bases())
4843 FieldTypes.push_back(Elt: Base.getType());
4844 std::reverse(first: FieldTypes.begin(), last: FieldTypes.end());
4845 llvm::append_range(C&: WorkList, R&: FieldTypes);
4846 }
4847 continue;
4848 }
4849 List.push_back(Elt: T);
4850 }
4851}
4852
4853bool SemaHLSL::IsConstantBufferElementCompatible(clang::QualType QT) {
4854 if (QT.isNull())
4855 return false;
4856
4857 // Must be a class/struct.
4858 const auto *RD = QT->getAsCXXRecordDecl();
4859 if (!RD || RD->isUnion())
4860 return false;
4861
4862 // Cannot be a resource type or contain one.
4863 return !QT->isHLSLIntangibleType();
4864}
4865
4866bool SemaHLSL::IsTypedResourceElementCompatible(clang::QualType QT) {
4867 // null and array types are not allowed.
4868 if (QT.isNull() || QT->isArrayType())
4869 return false;
4870
4871 // UDT types are not allowed
4872 if (QT->isRecordType())
4873 return false;
4874
4875 if (QT->isBooleanType() || QT->isEnumeralType())
4876 return false;
4877
4878 // the only other valid builtin types are scalars or vectors
4879 if (QT->isArithmeticType()) {
4880 if (SemaRef.Context.getTypeSize(T: QT) / 8 > 16)
4881 return false;
4882 return true;
4883 }
4884
4885 if (const VectorType *VT = QT->getAs<VectorType>()) {
4886 int ArraySize = VT->getNumElements();
4887
4888 if (ArraySize > 4)
4889 return false;
4890
4891 QualType ElTy = VT->getElementType();
4892 if (ElTy->isBooleanType())
4893 return false;
4894
4895 if (SemaRef.Context.getTypeSize(T: QT) / 8 > 16)
4896 return false;
4897 return true;
4898 }
4899
4900 return false;
4901}
4902
4903bool SemaHLSL::IsScalarizedLayoutCompatible(QualType T1, QualType T2) const {
4904 if (T1.isNull() || T2.isNull())
4905 return false;
4906
4907 T1 = T1.getCanonicalType().getUnqualifiedType();
4908 T2 = T2.getCanonicalType().getUnqualifiedType();
4909
4910 // If both types are the same canonical type, they're obviously compatible.
4911 if (SemaRef.getASTContext().hasSameType(T1, T2))
4912 return true;
4913
4914 llvm::SmallVector<QualType, 16> T1Types;
4915 BuildFlattenedTypeList(BaseTy: T1, List&: T1Types);
4916 llvm::SmallVector<QualType, 16> T2Types;
4917 BuildFlattenedTypeList(BaseTy: T2, List&: T2Types);
4918
4919 // Check the flattened type list
4920 return llvm::equal(LRange&: T1Types, RRange&: T2Types,
4921 P: [this](QualType LHS, QualType RHS) -> bool {
4922 return SemaRef.IsLayoutCompatible(T1: LHS, T2: RHS);
4923 });
4924}
4925
4926bool SemaHLSL::CheckCompatibleParameterABI(FunctionDecl *New,
4927 FunctionDecl *Old) {
4928 if (New->getNumParams() != Old->getNumParams())
4929 return true;
4930
4931 bool HadError = false;
4932
4933 for (unsigned i = 0, e = New->getNumParams(); i != e; ++i) {
4934 ParmVarDecl *NewParam = New->getParamDecl(i);
4935 ParmVarDecl *OldParam = Old->getParamDecl(i);
4936
4937 // HLSL parameter declarations for inout and out must match between
4938 // declarations. In HLSL inout and out are ambiguous at the call site,
4939 // but have different calling behavior, so you cannot overload a
4940 // method based on a difference between inout and out annotations.
4941 const auto *NDAttr = NewParam->getAttr<HLSLParamModifierAttr>();
4942 unsigned NSpellingIdx = (NDAttr ? NDAttr->getSpellingListIndex() : 0);
4943 const auto *ODAttr = OldParam->getAttr<HLSLParamModifierAttr>();
4944 unsigned OSpellingIdx = (ODAttr ? ODAttr->getSpellingListIndex() : 0);
4945
4946 if (NSpellingIdx != OSpellingIdx) {
4947 SemaRef.Diag(Loc: NewParam->getLocation(),
4948 DiagID: diag::err_hlsl_param_qualifier_mismatch)
4949 << NDAttr << NewParam;
4950 SemaRef.Diag(Loc: OldParam->getLocation(), DiagID: diag::note_previous_declaration_as)
4951 << ODAttr;
4952 HadError = true;
4953 }
4954 }
4955 return HadError;
4956}
4957
4958// Generally follows PerformScalarCast, with cases reordered for
4959// clarity of what types are supported
4960bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) {
4961
4962 if (!SrcTy->isScalarType() || !DestTy->isScalarType())
4963 return false;
4964
4965 if (SemaRef.getASTContext().hasSameUnqualifiedType(T1: SrcTy, T2: DestTy))
4966 return true;
4967
4968 switch (SrcTy->getScalarTypeKind()) {
4969 case Type::STK_Bool: // casting from bool is like casting from an integer
4970 case Type::STK_Integral:
4971 switch (DestTy->getScalarTypeKind()) {
4972 case Type::STK_Bool:
4973 case Type::STK_Integral:
4974 case Type::STK_Floating:
4975 return true;
4976 case Type::STK_CPointer:
4977 case Type::STK_ObjCObjectPointer:
4978 case Type::STK_BlockPointer:
4979 case Type::STK_MemberPointer:
4980 llvm_unreachable("HLSL doesn't support pointers.");
4981 case Type::STK_IntegralComplex:
4982 case Type::STK_FloatingComplex:
4983 llvm_unreachable("HLSL doesn't support complex types.");
4984 case Type::STK_FixedPoint:
4985 llvm_unreachable("HLSL doesn't support fixed point types.");
4986 }
4987 llvm_unreachable("Should have returned before this");
4988
4989 case Type::STK_Floating:
4990 switch (DestTy->getScalarTypeKind()) {
4991 case Type::STK_Floating:
4992 case Type::STK_Bool:
4993 case Type::STK_Integral:
4994 return true;
4995 case Type::STK_FloatingComplex:
4996 case Type::STK_IntegralComplex:
4997 llvm_unreachable("HLSL doesn't support complex types.");
4998 case Type::STK_FixedPoint:
4999 llvm_unreachable("HLSL doesn't support fixed point types.");
5000 case Type::STK_CPointer:
5001 case Type::STK_ObjCObjectPointer:
5002 case Type::STK_BlockPointer:
5003 case Type::STK_MemberPointer:
5004 llvm_unreachable("HLSL doesn't support pointers.");
5005 }
5006 llvm_unreachable("Should have returned before this");
5007
5008 case Type::STK_MemberPointer:
5009 case Type::STK_CPointer:
5010 case Type::STK_BlockPointer:
5011 case Type::STK_ObjCObjectPointer:
5012 llvm_unreachable("HLSL doesn't support pointers.");
5013
5014 case Type::STK_FixedPoint:
5015 llvm_unreachable("HLSL doesn't support fixed point types.");
5016
5017 case Type::STK_FloatingComplex:
5018 case Type::STK_IntegralComplex:
5019 llvm_unreachable("HLSL doesn't support complex types.");
5020 }
5021
5022 llvm_unreachable("Unhandled scalar cast");
5023}
5024
5025// Can perform an HLSL Aggregate splat cast if the Dest is an aggregate and the
5026// Src is a scalar, a vector of length 1, or a 1x1 matrix
5027// Or if Dest is a vector and Src is a vector of length 1 or a 1x1 matrix
5028bool SemaHLSL::CanPerformAggregateSplatCast(Expr *Src, QualType DestTy) {
5029
5030 QualType SrcTy = Src->getType();
5031 // Not a valid HLSL Aggregate Splat cast if Dest is a scalar or if this is
5032 // going to be a vector splat from a scalar.
5033 if ((SrcTy->isScalarType() && DestTy->isVectorType()) ||
5034 DestTy->isScalarType())
5035 return false;
5036
5037 const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
5038 const ConstantMatrixType *SrcMatTy = SrcTy->getAs<ConstantMatrixType>();
5039
5040 // Src isn't a scalar, a vector of length 1, or a 1x1 matrix
5041 if (!SrcTy->isScalarType() &&
5042 !(SrcVecTy && SrcVecTy->getNumElements() == 1) &&
5043 !(SrcMatTy && SrcMatTy->getNumElementsFlattened() == 1))
5044 return false;
5045
5046 if (SrcVecTy)
5047 SrcTy = SrcVecTy->getElementType();
5048 else if (SrcMatTy)
5049 SrcTy = SrcMatTy->getElementType();
5050
5051 llvm::SmallVector<QualType> DestTypes;
5052 BuildFlattenedTypeList(BaseTy: DestTy, List&: DestTypes);
5053
5054 for (unsigned I = 0, Size = DestTypes.size(); I < Size; ++I) {
5055 if (DestTypes[I]->isUnionType())
5056 return false;
5057 if (!CanPerformScalarCast(SrcTy, DestTy: DestTypes[I]))
5058 return false;
5059 }
5060 return true;
5061}
5062
5063// Can we perform an HLSL Elementwise cast?
5064bool SemaHLSL::CanPerformElementwiseCast(Expr *Src, QualType DestTy) {
5065
5066 // Don't handle casts where LHS and RHS are any combination of scalar/vector
5067 // There must be an aggregate somewhere
5068 QualType SrcTy = Src->getType();
5069 if (SrcTy->isScalarType()) // always a splat and this cast doesn't handle that
5070 return false;
5071
5072 if (SrcTy->isVectorType() &&
5073 (DestTy->isScalarType() || DestTy->isVectorType()))
5074 return false;
5075
5076 if (SrcTy->isConstantMatrixType() &&
5077 (DestTy->isScalarType() || DestTy->isConstantMatrixType()))
5078 return false;
5079
5080 llvm::SmallVector<QualType> DestTypes;
5081 BuildFlattenedTypeList(BaseTy: DestTy, List&: DestTypes);
5082 llvm::SmallVector<QualType> SrcTypes;
5083 BuildFlattenedTypeList(BaseTy: SrcTy, List&: SrcTypes);
5084
5085 // Usually the size of SrcTypes must be greater than or equal to the size of
5086 // DestTypes.
5087 if (SrcTypes.size() < DestTypes.size())
5088 return false;
5089
5090 unsigned SrcSize = SrcTypes.size();
5091 unsigned DstSize = DestTypes.size();
5092 unsigned I;
5093 for (I = 0; I < DstSize && I < SrcSize; I++) {
5094 if (SrcTypes[I]->isUnionType() || DestTypes[I]->isUnionType())
5095 return false;
5096 if (!CanPerformScalarCast(SrcTy: SrcTypes[I], DestTy: DestTypes[I])) {
5097 return false;
5098 }
5099 }
5100
5101 // check the rest of the source type for unions.
5102 for (; I < SrcSize; I++) {
5103 if (SrcTypes[I]->isUnionType())
5104 return false;
5105 }
5106 return true;
5107}
5108
// Builds the argument expression for a call parameter carrying an
// HLSLParamModifierAttr. Parameters whose modifier has the Ordinary ABI get
// Arg back unchanged. For `out`/`inout` parameters the result is an
// HLSLOutArgExpr combining:
//  - ArgOpV: an OpaqueValueExpr wrapping the caller's lvalue argument,
//  - OpV: an OpaqueValueExpr wrapping the argument copy-initialized to the
//    parameter type,
//  - Writeback: the assignment `ArgOpV = OpV`.
// Returns ExprError() if Arg is not an lvalue, if exactly one of the argument
// and parameter types is scalar (both emit a diagnostic here), or if the
// copy-initialization or the writeback assignment is invalid.
ExprResult SemaHLSL::ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg) {
  assert(Param->hasAttr<HLSLParamModifierAttr>() &&
         "We should not get here without a parameter modifier expression");
  const auto *Attr = Param->getAttr<HLSLParamModifierAttr>();
  if (Attr->getABI() == ParameterABI::Ordinary)
    return ExprResult(Arg);

  bool IsInOut = Attr->getABI() == ParameterABI::HLSLInOut;
  // The result is written back to the argument, so it must be an lvalue.
  if (!Arg->isLValue()) {
    SemaRef.Diag(Loc: Arg->getBeginLoc(), DiagID: diag::error_hlsl_inout_lvalue)
        << Arg << (IsInOut ? 1 : 0);
    return ExprError();
  }

  ASTContext &Ctx = SemaRef.getASTContext();

  // The parameter type with the reference removed.
  QualType Ty = Param->getType().getNonLValueExprType(Context: Ctx);

  // The argument is converted to the parameter type on the way in and the
  // writeback converts it back on the way out. Reject any mismatch in
  // scalar-ness between the two types, in either direction (e.g. a scalar
  // argument for a vector or matrix parameter).
  if (Arg->getType()->isScalarType() != Ty->isScalarType()) {
    SemaRef.Diag(Loc: Arg->getBeginLoc(), DiagID: diag::error_hlsl_inout_scalar_extension)
        << Arg << (IsInOut ? 1 : 0);
    return ExprError();
  }

  // Opaque handle on the caller's argument; it is reused both as the source
  // of the copy-in and as the target of the writeback.
  auto *ArgOpV = new (Ctx) OpaqueValueExpr(Param->getBeginLoc(), Arg->getType(),
                                           VK_LValue, OK_Ordinary, Arg);

  // Parameters are initialized via copy initialization. This allows for
  // overload resolution of argument constructors.
  InitializedEntity Entity =
      InitializedEntity::InitializeParameter(Context&: Ctx, Type: Ty, Consumed: false);
  ExprResult Res =
      SemaRef.PerformCopyInitialization(Entity, EqualLoc: Param->getBeginLoc(), Init: ArgOpV);
  if (Res.isInvalid())
    return ExprError();
  Expr *Base = Res.get();
  // Ty is already a non-reference type at this point. NOTE(review): applying
  // getNonLValueExprType again presumably only strips remaining cv-qualifiers
  // (for non-class types) - confirm this is the intent.
  Ty = Ty.getNonLValueExprType(Context: Ctx);
  auto *OpV = new (Ctx)
      OpaqueValueExpr(Param->getBeginLoc(), Ty, VK_LValue, OK_Ordinary, Base);

  // Writebacks are performed with `=` binary operator, which allows for
  // overload resolution on writeback result expressions.
  Res = SemaRef.ActOnBinOp(S: SemaRef.getCurScope(), TokLoc: Param->getBeginLoc(),
                           Kind: tok::equal, LHSExpr: ArgOpV, RHSExpr: OpV);

  if (Res.isInvalid())
    return ExprError();
  Expr *Writeback = Res.get();
  auto *OutExpr =
      HLSLOutArgExpr::Create(C: Ctx, Ty, Base: ArgOpV, OpV, WB: Writeback, IsInOut);

  return ExprResult(OutExpr);
}
5166
5167QualType SemaHLSL::getInoutParameterType(QualType Ty) {
5168 // If HLSL gains support for references, all the cites that use this will need
5169 // to be updated with semantic checking to produce errors for
5170 // pointers/references.
5171 assert(!Ty->isReferenceType() &&
5172 "Pointer and reference types cannot be inout or out parameters");
5173 Ty = SemaRef.getASTContext().getLValueReferenceType(T: Ty);
5174 Ty.addRestrict();
5175 return Ty;
5176}
5177
5178// Returns true if the type has a non-empty constant buffer layout (if it is
5179// scalar, vector or matrix, or if it contains any of these.
5180static bool hasConstantBufferLayout(QualType QT) {
5181 const Type *Ty = QT->getUnqualifiedDesugaredType();
5182 if (Ty->isScalarType() || Ty->isVectorType() || Ty->isMatrixType())
5183 return true;
5184
5185 if (Ty->isHLSLResourceRecord() || Ty->isHLSLResourceRecordArray())
5186 return false;
5187
5188 if (const auto *RD = Ty->getAsCXXRecordDecl()) {
5189 for (const auto *FD : RD->fields()) {
5190 if (hasConstantBufferLayout(QT: FD->getType()))
5191 return true;
5192 }
5193 assert(RD->getNumBases() <= 1 &&
5194 "HLSL doesn't support multiple inheritance");
5195 return RD->getNumBases()
5196 ? hasConstantBufferLayout(QT: RD->bases_begin()->getType())
5197 : false;
5198 }
5199
5200 if (const auto *AT = dyn_cast<ArrayType>(Val: Ty)) {
5201 if (const auto *CAT = dyn_cast<ConstantArrayType>(Val: AT))
5202 if (isZeroSizedArray(CAT))
5203 return false;
5204 return hasConstantBufferLayout(QT: AT->getElementType());
5205 }
5206
5207 return false;
5208}
5209
5210static bool IsDefaultBufferConstantDecl(const ASTContext &Ctx, VarDecl *VD) {
5211 bool IsVulkan =
5212 Ctx.getTargetInfo().getTriple().getOS() == llvm::Triple::Vulkan;
5213 bool IsVKPushConstant = IsVulkan && VD->hasAttr<HLSLVkPushConstantAttr>();
5214 QualType QT = VD->getType();
5215 return VD->getDeclContext()->isTranslationUnit() &&
5216 QT.getAddressSpace() == LangAS::Default &&
5217 VD->getStorageClass() != SC_Static &&
5218 !VD->hasAttr<HLSLVkConstantIdAttr>() && !IsVKPushConstant &&
5219 hasConstantBufferLayout(QT);
5220}
5221
5222void SemaHLSL::deduceAddressSpace(VarDecl *Decl) {
5223 // The variable already has an address space (groupshared for ex).
5224 if (Decl->getType().hasAddressSpace())
5225 return;
5226
5227 if (Decl->getType()->isDependentType())
5228 return;
5229
5230 QualType Type = Decl->getType();
5231
5232 if (Decl->hasAttr<HLSLVkExtBuiltinInputAttr>()) {
5233 LangAS ImplAS = LangAS::hlsl_input;
5234 Type = SemaRef.getASTContext().getAddrSpaceQualType(T: Type, AddressSpace: ImplAS);
5235 Decl->setType(Type);
5236 return;
5237 }
5238
5239 if (Decl->hasAttr<HLSLVkExtBuiltinOutputAttr>()) {
5240 LangAS ImplAS = LangAS::hlsl_output;
5241 Type = SemaRef.getASTContext().getAddrSpaceQualType(T: Type, AddressSpace: ImplAS);
5242 Decl->setType(Type);
5243
5244 // HLSL uses `static` differently than C++. For BuiltIn output, the static
5245 // does not imply private to the module scope.
5246 // Marking it as external to reflect the semantic this attribute brings.
5247 // See https://github.com/microsoft/hlsl-specs/issues/350
5248 Decl->setStorageClass(SC_Extern);
5249 return;
5250 }
5251
5252 bool IsVulkan = getASTContext().getTargetInfo().getTriple().getOS() ==
5253 llvm::Triple::Vulkan;
5254 if (IsVulkan && Decl->hasAttr<HLSLVkPushConstantAttr>()) {
5255 if (HasDeclaredAPushConstant)
5256 SemaRef.Diag(Loc: Decl->getLocation(), DiagID: diag::err_hlsl_push_constant_unique);
5257
5258 LangAS ImplAS = LangAS::hlsl_push_constant;
5259 Type = SemaRef.getASTContext().getAddrSpaceQualType(T: Type, AddressSpace: ImplAS);
5260 Decl->setType(Type);
5261 HasDeclaredAPushConstant = true;
5262 return;
5263 }
5264
5265 if (Type->isSamplerT() || Type->isVoidType())
5266 return;
5267
5268 // Resource handles.
5269 if (Type->isHLSLResourceRecord() || Type->isHLSLResourceRecordArray())
5270 return;
5271
5272 // Only static globals belong to the Private address space.
5273 // Non-static globals belongs to the cbuffer.
5274 if (Decl->getStorageClass() != SC_Static && !Decl->isStaticDataMember())
5275 return;
5276
5277 LangAS ImplAS = LangAS::hlsl_private;
5278 Type = SemaRef.getASTContext().getAddrSpaceQualType(T: Type, AddressSpace: ImplAS);
5279 Decl->setType(Type);
5280}
5281
5282namespace {
5283
5284// Helper class for assigning bindings to resources declared within a struct.
5285// It keeps track of all binding attributes declared on a struct instance, and
5286// the offsets for each register type that have been assigned so far.
5287// Handles both explicit and implicit bindings.
5288class StructBindingContext {
5289 // Bindings and offsets per register type. We only need to support four
5290 // register types - SRV (u), UAV (t), CBuffer (c), and Sampler (s).
5291 HLSLResourceBindingAttr *RegBindingsAttrs[4];
5292 unsigned RegBindingOffset[4];
5293
5294 // Make sure the RegisterType values are what we expect
5295 static_assert(static_cast<unsigned>(RegisterType::SRV) == 0 &&
5296 static_cast<unsigned>(RegisterType::UAV) == 1 &&
5297 static_cast<unsigned>(RegisterType::CBuffer) == 2 &&
5298 static_cast<unsigned>(RegisterType::Sampler) == 3,
5299 "unexpected register type values");
5300
5301 // Vulkan binding attribute does not vary by register type.
5302 HLSLVkBindingAttr *VkBindingAttr;
5303 unsigned VkBindingOffset;
5304
5305public:
5306 // Constructor: gather all binding attributes on a struct instance and
5307 // initialize offsets.
5308 StructBindingContext(VarDecl *VD) {
5309 for (unsigned i = 0; i < 4; ++i) {
5310 RegBindingsAttrs[i] = nullptr;
5311 RegBindingOffset[i] = 0;
5312 }
5313 VkBindingAttr = nullptr;
5314 VkBindingOffset = 0;
5315
5316 ASTContext &AST = VD->getASTContext();
5317 bool IsSpirv = AST.getTargetInfo().getTriple().isSPIRV();
5318
5319 for (Attr *A : VD->attrs()) {
5320 if (auto *RBA = dyn_cast<HLSLResourceBindingAttr>(Val: A)) {
5321 RegisterType RegType = RBA->getRegisterType();
5322 unsigned RegTypeIdx = static_cast<unsigned>(RegType);
5323 // Ignore unsupported register annotations, such as 'c' or 'i'.
5324 if (RegTypeIdx < 4)
5325 RegBindingsAttrs[RegTypeIdx] = RBA;
5326 continue;
5327 }
5328 // Gather the Vulkan binding attributes only if the target is SPIR-V.
5329 if (IsSpirv) {
5330 if (auto *VBA = dyn_cast<HLSLVkBindingAttr>(Val: A))
5331 VkBindingAttr = VBA;
5332 }
5333 }
5334 }
5335
5336 // Creates a binding attribute for a resource based on the gathered attributes
5337 // and the required register type and range.
5338 Attr *createBindingAttr(SemaHLSL &S, ASTContext &AST, RegisterType RegType,
5339 unsigned Range, bool HasCounter) {
5340 assert(static_cast<unsigned>(RegType) < 4 && "unexpected register type");
5341
5342 if (VkBindingAttr) {
5343 unsigned Offset = VkBindingOffset;
5344 VkBindingOffset += Range;
5345 return HLSLVkBindingAttr::CreateImplicit(
5346 Ctx&: AST, Binding: VkBindingAttr->getBinding() + Offset, Set: VkBindingAttr->getSet(),
5347 Range: VkBindingAttr->getRange());
5348 }
5349
5350 HLSLResourceBindingAttr *RBA =
5351 RegBindingsAttrs[static_cast<unsigned>(RegType)];
5352 HLSLResourceBindingAttr *NewAttr = nullptr;
5353
5354 if (RBA && RBA->hasRegisterSlot()) {
5355 // Explicit binding - create a new attribute with offseted slot number
5356 // based on the required register type.
5357 unsigned Offset = RegBindingOffset[static_cast<unsigned>(RegType)];
5358 RegBindingOffset[static_cast<unsigned>(RegType)] += Range;
5359
5360 unsigned NewSlotNumber = RBA->getSlotNumber() + Offset;
5361 StringRef NewSlotNumberStr =
5362 createRegisterString(AST, RegType: RBA->getRegisterType(), N: NewSlotNumber);
5363 NewAttr = HLSLResourceBindingAttr::CreateImplicit(
5364 Ctx&: AST, Slot: NewSlotNumberStr, Space: RBA->getSpace(), Range: RBA->getRange());
5365 NewAttr->setBinding(RT: RegType, SlotNum: NewSlotNumber, SpaceNum: RBA->getSpaceNumber());
5366 } else {
5367 // No binding attribute or space-only binding - create a binding
5368 // attribute for implicit binding.
5369 NewAttr = HLSLResourceBindingAttr::CreateImplicit(Ctx&: AST, Slot: "", Space: "0", Range: {});
5370 NewAttr->setBinding(RT: RegType, SlotNum: std::nullopt,
5371 SpaceNum: RBA ? RBA->getSpaceNumber() : 0);
5372 NewAttr->setImplicitBindingOrderID(S.getNextImplicitBindingOrderID());
5373 }
5374 if (HasCounter)
5375 NewAttr->setImplicitCounterBindingOrderID(
5376 S.getNextImplicitBindingOrderID());
5377 return NewAttr;
5378 }
5379};
5380
5381// Creates a global variable declaration for a resource field embedded in a
5382// struct, assigns it a binding, initializes it, and associates it with the
5383// struct declaration via an HLSLAssociatedResourceDeclAttr.
5384static void createGlobalResourceDeclForStruct(
5385 Sema &S, VarDecl *ParentVD, SourceLocation Loc, IdentifierInfo *Id,
5386 QualType ResTy, StructBindingContext &BindingCtx) {
5387 assert(isResourceRecordTypeOrArrayOf(ResTy) &&
5388 "expected resource type or array of resources");
5389
5390 DeclContext *DC = ParentVD->getNonTransparentDeclContext();
5391 assert(DC->isTranslationUnit() && "expected translation unit decl context");
5392
5393 ASTContext &AST = S.getASTContext();
5394 VarDecl *ResDecl =
5395 VarDecl::Create(C&: AST, DC, StartLoc: Loc, IdLoc: Loc, Id, T: ResTy, TInfo: nullptr, S: SC_None);
5396
5397 unsigned Range = 1;
5398 const Type *SingleResTy = ResTy.getTypePtr()->getUnqualifiedDesugaredType();
5399 while (const auto *AT = dyn_cast<ArrayType>(Val: SingleResTy)) {
5400 const auto *CAT = dyn_cast<ConstantArrayType>(Val: AT);
5401 Range = CAT ? (Range * CAT->getSize().getZExtValue()) : 0;
5402 SingleResTy =
5403 AT->getArrayElementTypeNoTypeQual()->getUnqualifiedDesugaredType();
5404 }
5405 const HLSLAttributedResourceType *ResHandleTy =
5406 HLSLAttributedResourceType::findHandleTypeOnResource(RT: SingleResTy);
5407
5408 // Add a binding attribute to the global resource declaration.
5409 bool HasCounter = hasCounterHandle(RD: SingleResTy->getAsCXXRecordDecl());
5410 Attr *BindingAttr = BindingCtx.createBindingAttr(
5411 S&: S.HLSL(), AST, RegType: getRegisterType(ResTy: ResHandleTy), Range, HasCounter);
5412 ResDecl->addAttr(A: BindingAttr);
5413 ResDecl->addAttr(A: InternalLinkageAttr::CreateImplicit(Ctx&: AST));
5414 ResDecl->setImplicit();
5415
5416 if (Range == 1)
5417 S.HLSL().initGlobalResourceDecl(VD: ResDecl);
5418 else
5419 S.HLSL().initGlobalResourceArrayDecl(VD: ResDecl);
5420
5421 ParentVD->addAttr(
5422 A: HLSLAssociatedResourceDeclAttr::CreateImplicit(Ctx&: AST, ResDecl));
5423 DC->addDecl(D: ResDecl);
5424
5425 DeclGroupRef DG(ResDecl);
5426 S.Consumer.HandleTopLevelDecl(D: DG);
5427}
5428
5429static void handleArrayOfStructWithResources(
5430 Sema &S, VarDecl *ParentVD, const ConstantArrayType *CAT,
5431 EmbeddedResourceNameBuilder &NameBuilder, StructBindingContext &BindingCtx);
5432
5433// Scans base and all fields of a struct/class type to find all embedded
5434// resources or resource arrays. Creates a global variable for each resource
5435// found.
5436static void handleStructWithResources(Sema &S, VarDecl *ParentVD,
5437 const CXXRecordDecl *RD,
5438 EmbeddedResourceNameBuilder &NameBuilder,
5439 StructBindingContext &BindingCtx) {
5440
5441 // Scan the base classes.
5442 assert(RD->getNumBases() <= 1 && "HLSL doesn't support multiple inheritance");
5443 const auto *BasesIt = RD->bases_begin();
5444 if (BasesIt != RD->bases_end()) {
5445 QualType QT = BasesIt->getType();
5446 if (QT->isHLSLIntangibleType()) {
5447 CXXRecordDecl *BaseRD = QT->getAsCXXRecordDecl();
5448 NameBuilder.pushBaseName(N: BaseRD->getName());
5449 handleStructWithResources(S, ParentVD, RD: BaseRD, NameBuilder, BindingCtx);
5450 NameBuilder.pop();
5451 }
5452 }
5453 // Process this class fields.
5454 for (const FieldDecl *FD : RD->fields()) {
5455 QualType FDTy = FD->getType().getCanonicalType();
5456 if (!FDTy->isHLSLIntangibleType())
5457 continue;
5458
5459 NameBuilder.pushName(N: FD->getName());
5460
5461 if (isResourceRecordTypeOrArrayOf(Ty: FDTy)) {
5462 IdentifierInfo *II = NameBuilder.getNameAsIdentifier(AST&: S.getASTContext());
5463 createGlobalResourceDeclForStruct(S, ParentVD, Loc: FD->getLocation(), Id: II,
5464 ResTy: FDTy, BindingCtx);
5465 } else if (const auto *RD = FDTy->getAsCXXRecordDecl()) {
5466 handleStructWithResources(S, ParentVD, RD, NameBuilder, BindingCtx);
5467
5468 } else if (const auto *ArrayTy = dyn_cast<ConstantArrayType>(Val&: FDTy)) {
5469 assert(!FDTy->isHLSLResourceRecordArray() &&
5470 "resource arrays should have been already handled");
5471 handleArrayOfStructWithResources(S, ParentVD, CAT: ArrayTy, NameBuilder,
5472 BindingCtx);
5473 }
5474 NameBuilder.pop();
5475 }
5476}
5477
5478// Processes array of structs with resources.
5479static void
5480handleArrayOfStructWithResources(Sema &S, VarDecl *ParentVD,
5481 const ConstantArrayType *CAT,
5482 EmbeddedResourceNameBuilder &NameBuilder,
5483 StructBindingContext &BindingCtx) {
5484
5485 QualType ElementTy = CAT->getElementType().getCanonicalType();
5486 assert(ElementTy->isHLSLIntangibleType() && "Expected HLSL intangible type");
5487
5488 const ConstantArrayType *SubCAT = dyn_cast<ConstantArrayType>(Val&: ElementTy);
5489 const CXXRecordDecl *ElementRD = ElementTy->getAsCXXRecordDecl();
5490
5491 if (!SubCAT && !ElementRD)
5492 return;
5493
5494 for (unsigned I = 0, E = CAT->getSize().getZExtValue(); I < E; ++I) {
5495 NameBuilder.pushArrayIndex(Index: I);
5496 if (ElementRD)
5497 handleStructWithResources(S, ParentVD, RD: ElementRD, NameBuilder,
5498 BindingCtx);
5499 else
5500 handleArrayOfStructWithResources(S, ParentVD, CAT: SubCAT, NameBuilder,
5501 BindingCtx);
5502 NameBuilder.pop();
5503 }
5504}
5505
5506} // namespace
5507
5508// Scans all fields of a user-defined struct (or array of structs)
5509// to find all embedded resources or resource arrays. For each resource
5510// a global variable of the resource type is created and associated
5511// with the parent declaration (VD) through a HLSLAssociatedResourceDeclAttr
5512// attribute.
5513void SemaHLSL::handleGlobalStructOrArrayOfWithResources(VarDecl *VD) {
5514 EmbeddedResourceNameBuilder NameBuilder(VD->getName());
5515 StructBindingContext BindingCtx(VD);
5516
5517 const Type *VDTy = VD->getType().getTypePtr();
5518 assert(VDTy->isHLSLIntangibleType() && !isResourceRecordTypeOrArrayOf(VD) &&
5519 "Expected non-resource struct or array type");
5520
5521 if (const CXXRecordDecl *RD = VDTy->getAsCXXRecordDecl()) {
5522 handleStructWithResources(S&: SemaRef, ParentVD: VD, RD, NameBuilder, BindingCtx);
5523 return;
5524 }
5525
5526 if (const auto *CAT = dyn_cast<ConstantArrayType>(Val: VDTy)) {
5527 handleArrayOfStructWithResources(S&: SemaRef, ParentVD: VD, CAT, NameBuilder, BindingCtx);
5528 return;
5529 }
5530}
5531
// HLSL-specific processing of a variable declarator. For variables with
// global storage this:
//  - requires a complete (base element) type, marking VD invalid otherwise;
//  - moves eligible globals into the default constant buffer ($Globals);
//  - collects and processes resource binding attributes;
//  - gives non-static resource arrays implicit binding order IDs (and
//    counter order IDs when the element type has a counter handle);
//  - creates global declarations for resources embedded in user structs;
//  - marks groupshared variables extern.
// The address space of VD is deduced on every path.
void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
  if (VD->hasGlobalStorage()) {
    // make sure the declaration has a complete type
    if (SemaRef.RequireCompleteType(
            Loc: VD->getLocation(),
            T: SemaRef.getASTContext().getBaseElementType(QT: VD->getType()),
            DiagID: diag::err_typecheck_decl_incomplete_type)) {
      VD->setInvalidDecl();
      deduceAddressSpace(Decl: VD);
      return;
    }

    // Global variables outside a cbuffer block that are not a resource, static,
    // groupshared, or an empty array or struct belong to the default constant
    // buffer $Globals (to be created at the end of the translation unit).
    if (IsDefaultBufferConstantDecl(Ctx: getASTContext(), VD)) {
      // update address space to hlsl_constant
      QualType NewTy = getASTContext().getAddrSpaceQualType(
          T: VD->getType(), AddressSpace: LangAS::hlsl_constant);
      VD->setType(NewTy);
      DefaultCBufferDecls.push_back(Elt: VD);
    }

    // find all resources bindings on decl
    if (VD->getType()->isHLSLIntangibleType())
      collectResourceBindingsOnVarDecl(D: VD);

    // Variables with a Vulkan constant_id attribute are made static. This
    // must happen before the SC_Static checks below, which then skip them.
    if (VD->hasAttr<HLSLVkConstantIdAttr>())
      VD->setStorageClass(StorageClass::SC_Static);

    if (isResourceRecordTypeOrArrayOf(VD) &&
        VD->getStorageClass() != SC_Static) {
      // Add internal linkage attribute to non-static resource variables. The
      // global externally visible storage is accessed through the handle, which
      // is a member. The variable itself is not externally visible.
      VD->addAttr(A: InternalLinkageAttr::CreateImplicit(Ctx&: getASTContext()));
    }

    // process explicit bindings
    processExplicitBindingsOnDecl(D: VD);

    // Add implicit binding attribute to non-static resource arrays.
    if (VD->getType()->isHLSLResourceRecordArray() &&
        VD->getStorageClass() != SC_Static) {
      // If the resource array does not have an explicit binding attribute,
      // create an implicit one. It will be used to transfer implicit binding
      // order_ID to codegen.
      ResourceBindingAttrs Binding(VD);
      if (!Binding.isExplicit()) {
        uint32_t OrderID = getNextImplicitBindingOrderID();
        // An existing (non-explicit) binding attribute just receives the
        // order ID; otherwise a new implicit binding attribute is added.
        if (Binding.hasBinding())
          Binding.setImplicitOrderID(OrderID);
        else {
          addImplicitBindingAttrToDecl(
              S&: SemaRef, D: VD, RT: getRegisterType(ResTy: getResourceArrayHandleType(VD)),
              ImplicitBindingOrderID: OrderID);
          // Re-create the binding object to pick up the new attribute.
          Binding = ResourceBindingAttrs(VD);
        }
      }

      // Get to the base type of a potentially multi-dimensional array.
      QualType Ty = getASTContext().getBaseElementType(QT: VD->getType());

      // Resources with a counter handle also need an order ID for the counter,
      // unless one was already assigned.
      const CXXRecordDecl *RD = Ty->getAsCXXRecordDecl();
      if (hasCounterHandle(RD)) {
        if (!Binding.hasCounterImplicitOrderID()) {
          uint32_t OrderID = getNextImplicitBindingOrderID();
          Binding.setCounterImplicitOrderID(OrderID);
        }
      }
    }

    // Process resources in user-defined structs, or arrays of such structs.
    const Type *VDTy = VD->getType().getTypePtr();
    if (VD->getStorageClass() != SC_Static && VDTy->isHLSLIntangibleType() &&
        !isResourceRecordTypeOrArrayOf(VD))
      handleGlobalStructOrArrayOfWithResources(VD);

    // Mark groupshared variables as extern so they will have
    // external storage and won't be default initialized
    if (VD->hasAttr<HLSLGroupSharedAddressSpaceAttr>())
      VD->setStorageClass(StorageClass::SC_Extern);
  }

  deduceAddressSpace(Decl: VD);
}
5619
// Initializes the global resource variable VD. Its initializer becomes a call
// to one of the resource record's static create methods:
//  - __createFromBinding[WithImplicitCounter] when VD has an explicit binding,
//  - __createFromImplicitBinding[WithImplicitCounter] otherwise,
// where the WithImplicitCounter variants are used when the record has a
// counter handle. The call arguments are, in order: register slot (explicit)
// or implicit binding order ID, space, range size (1), index (0), the
// variable name, and, for resources with a counter, the counter's implicit
// binding order ID.
// Returns false, leaving VD without an initializer, if the create method is
// not found. Returns true otherwise.
bool SemaHLSL::initGlobalResourceDecl(VarDecl *VD) {
  assert(VD->getType()->isHLSLResourceRecord() &&
         "expected resource record type");

  ASTContext &AST = SemaRef.getASTContext();
  uint64_t UIntTySize = AST.getTypeSize(T: AST.UnsignedIntTy);
  uint64_t IntTySize = AST.getTypeSize(T: AST.IntTy);

  // Gather resource binding attributes.
  ResourceBindingAttrs Binding(VD);

  // Find correct initialization method and create its arguments.
  QualType ResourceTy = VD->getType();
  CXXRecordDecl *ResourceDecl = ResourceTy->getAsCXXRecordDecl();
  CXXMethodDecl *CreateMethod = nullptr;
  llvm::SmallVector<Expr *> Args;

  bool HasCounter = hasCounterHandle(RD: ResourceDecl);
  const char *CreateMethodName;
  if (Binding.isExplicit())
    CreateMethodName = HasCounter ? "__createFromBindingWithImplicitCounter"
                                  : "__createFromBinding";
  else
    CreateMethodName = HasCounter
                           ? "__createFromImplicitBindingWithImplicitCounter"
                           : "__createFromImplicitBinding";

  CreateMethod =
      lookupMethod(S&: SemaRef, RecordDecl: ResourceDecl, Name: CreateMethodName, Loc: VD->getLocation());

  if (!CreateMethod) {
    // This can happen if someone creates a struct that looks like an HLSL
    // resource record but does not have the required static create method.
    // No binding will be generated for it.
    assert(!ResourceDecl->isImplicit() &&
           "create method lookup should always succeed for built-in resource "
           "records");
    return false;
  }

  // First argument: the register slot for explicit bindings, otherwise the
  // implicit binding order ID (reusing one already stored on the binding
  // attribute, or allocating a new one).
  if (Binding.isExplicit()) {
    IntegerLiteral *RegSlot =
        IntegerLiteral::Create(C: AST, V: llvm::APInt(UIntTySize, Binding.getSlot()),
                               type: AST.UnsignedIntTy, l: SourceLocation());
    Args.push_back(Elt: RegSlot);
  } else {
    uint32_t OrderID = (Binding.hasImplicitOrderID())
                           ? Binding.getImplicitOrderID()
                           : getNextImplicitBindingOrderID();
    IntegerLiteral *OrderId =
        IntegerLiteral::Create(C: AST, V: llvm::APInt(UIntTySize, OrderID),
                               type: AST.UnsignedIntTy, l: SourceLocation());
    Args.push_back(Elt: OrderId);
  }

  IntegerLiteral *Space =
      IntegerLiteral::Create(C: AST, V: llvm::APInt(UIntTySize, Binding.getSpace()),
                             type: AST.UnsignedIntTy, l: SourceLocation());
  Args.push_back(Elt: Space);

  // A single (non-array) resource: range size 1, index 0.
  IntegerLiteral *RangeSize = IntegerLiteral::Create(
      C: AST, V: llvm::APInt(IntTySize, 1), type: AST.IntTy, l: SourceLocation());
  Args.push_back(Elt: RangeSize);

  IntegerLiteral *Index = IntegerLiteral::Create(
      C: AST, V: llvm::APInt(UIntTySize, 0), type: AST.UnsignedIntTy, l: SourceLocation());
  Args.push_back(Elt: Index);

  // The variable name, passed as a decayed `const char *` string literal.
  StringRef VarName = VD->getName();
  StringLiteral *Name = StringLiteral::Create(
      Ctx: AST, Str: VarName, Kind: StringLiteralKind::Ordinary, Pascal: false,
      Ty: AST.getStringLiteralArrayType(EltTy: AST.CharTy.withConst(), Length: VarName.size()),
      Locs: SourceLocation());
  ImplicitCastExpr *NameCast = ImplicitCastExpr::Create(
      Context: AST, T: AST.getPointerType(T: AST.CharTy.withConst()), Kind: CK_ArrayToPointerDecay,
      Operand: Name, BasePath: nullptr, Cat: VK_PRValue, FPO: FPOptionsOverride());
  Args.push_back(Elt: NameCast);

  if (HasCounter) {
    // NOTE(review): this always allocates a fresh order ID, even when the
    // binding attribute already carries a counter order ID (e.g. the one set
    // by StructBindingContext::createBindingAttr for resources embedded in
    // structs). Confirm whether that stored ID should be reused here, as is
    // done for the main implicit order ID above.
    uint32_t CounterOrderID = getNextImplicitBindingOrderID();
    IntegerLiteral *CounterId =
        IntegerLiteral::Create(C: AST, V: llvm::APInt(UIntTySize, CounterOrderID),
                               type: AST.UnsignedIntTy, l: SourceLocation());
    Args.push_back(Elt: CounterId);
  }

  // Make sure the create method template is instantiated and emitted.
  if (!CreateMethod->isDefined() && CreateMethod->isTemplateInstantiation())
    SemaRef.InstantiateFunctionDefinition(PointOfInstantiation: VD->getLocation(), Function: CreateMethod,
                                          Recursive: true);

  // Create CallExpr with a call to the static method and set it as the decl
  // initialization.
  DeclRefExpr *DRE = DeclRefExpr::Create(
      Context: AST, QualifierLoc: NestedNameSpecifierLoc(), TemplateKWLoc: SourceLocation(), D: CreateMethod, RefersToEnclosingVariableOrCapture: false,
      NameInfo: CreateMethod->getNameInfo(), T: CreateMethod->getType(), VK: VK_PRValue);

  auto *ImpCast = ImplicitCastExpr::Create(
      Context: AST, T: AST.getPointerType(T: CreateMethod->getType()),
      Kind: CK_FunctionToPointerDecay, Operand: DRE, BasePath: nullptr, Cat: VK_PRValue, FPO: FPOptionsOverride());

  CallExpr *InitExpr =
      CallExpr::Create(Ctx: AST, Fn: ImpCast, Args, Ty: ResourceTy, VK: VK_PRValue,
                       RParenLoc: SourceLocation(), FPFeatures: FPOptionsOverride());
  VD->setInit(InitExpr);
  VD->setInitStyle(VarDecl::CallInit);
  SemaRef.CheckCompleteVariableDeclaration(VD);
  return true;
}
5730
5731bool SemaHLSL::initGlobalResourceArrayDecl(VarDecl *VD) {
5732 assert(VD->getType()->isHLSLResourceRecordArray() &&
5733 "expected array of resource records");
5734
5735 // Individual resources in a resource array are not initialized here. They
5736 // are initialized later on during codegen when the individual resources are
5737 // accessed. Codegen will emit a call to the resource initialization method
5738 // with the specified array index. We need to make sure though that the method
5739 // for the specific resource type is instantiated, so codegen can emit a call
5740 // to it when the array element is accessed.
5741
5742 // Find correct initialization method based on the resource binding
5743 // information.
5744 ASTContext &AST = SemaRef.getASTContext();
5745 QualType ResElementTy = AST.getBaseElementType(QT: VD->getType());
5746 CXXRecordDecl *ResourceDecl = ResElementTy->getAsCXXRecordDecl();
5747 CXXMethodDecl *CreateMethod = nullptr;
5748
5749 bool HasCounter = hasCounterHandle(RD: ResourceDecl);
5750 ResourceBindingAttrs ResourceAttrs(VD);
5751 if (ResourceAttrs.isExplicit())
5752 // Resource has explicit binding.
5753 CreateMethod =
5754 lookupMethod(S&: SemaRef, RecordDecl: ResourceDecl,
5755 Name: HasCounter ? "__createFromBindingWithImplicitCounter"
5756 : "__createFromBinding",
5757 Loc: VD->getLocation());
5758 else
5759 // Resource has implicit binding.
5760 CreateMethod = lookupMethod(
5761 S&: SemaRef, RecordDecl: ResourceDecl,
5762 Name: HasCounter ? "__createFromImplicitBindingWithImplicitCounter"
5763 : "__createFromImplicitBinding",
5764 Loc: VD->getLocation());
5765
5766 if (!CreateMethod)
5767 return false;
5768
5769 // Make sure the create method template is instantiated and emitted.
5770 if (!CreateMethod->isDefined() && CreateMethod->isTemplateInstantiation())
5771 SemaRef.InstantiateFunctionDefinition(PointOfInstantiation: VD->getLocation(), Function: CreateMethod,
5772 Recursive: true);
5773 return true;
5774}
5775
5776// Returns true if the initialization has been handled.
5777// Returns false to use default initialization.
5778bool SemaHLSL::ActOnUninitializedVarDecl(VarDecl *VD) {
5779 // Objects in the hlsl_constant address space are initialized
5780 // externally, so don't synthesize an implicit initializer.
5781 if (VD->getType().getAddressSpace() == LangAS::hlsl_constant)
5782 return true;
5783
5784 if (VD->hasGlobalStorage() && VD->getStorageClass() != SC_Static) {
5785 const Type *Ty = VD->getType().getTypePtr();
5786 if (Ty->isHLSLResourceRecord() && initGlobalResourceDecl(VD))
5787 return true;
5788 if (Ty->isHLSLResourceRecordArray() && initGlobalResourceArrayDecl(VD))
5789 return true;
5790 }
5791
5792 // User-defined structs/classes do not have constructors.
5793 // When declared at a global scope, they are part of the constant buffer
5794 // and should not be initialized by the compiler.
5795 // When declared at a local scope, they are not initialized.
5796 // Also applies to arrays of user-defined structs/classes.
5797 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
5798 while (Ty->isArrayType())
5799 Ty = Ty->getArrayElementTypeNoTypeQual()->getUnqualifiedDesugaredType();
5800 if (CXXRecordDecl *RD = Ty->getAsCXXRecordDecl())
5801 return !RD->isHLSLBuiltinRecord();
5802
5803 return false;
5804}
5805
5806std::optional<const DeclBindingInfo *> SemaHLSL::inferGlobalBinding(Expr *E) {
5807 if (auto *Ternary = dyn_cast<ConditionalOperator>(Val: E)) {
5808 auto TrueInfo = inferGlobalBinding(E: Ternary->getTrueExpr());
5809 auto FalseInfo = inferGlobalBinding(E: Ternary->getFalseExpr());
5810 if (!TrueInfo || !FalseInfo)
5811 return std::nullopt;
5812 if (*TrueInfo != *FalseInfo)
5813 return std::nullopt;
5814 return TrueInfo;
5815 }
5816
5817 if (auto *ASE = dyn_cast<ArraySubscriptExpr>(Val: E))
5818 E = ASE->getBase()->IgnoreParenImpCasts();
5819
5820 if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(Val: E->IgnoreParens()))
5821 if (VarDecl *VD = dyn_cast<VarDecl>(Val: DRE->getDecl())) {
5822 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
5823 if (Ty->isArrayType())
5824 Ty = Ty->getArrayElementTypeNoTypeQual();
5825
5826 if (const auto *AttrResType =
5827 HLSLAttributedResourceType::findHandleTypeOnResource(RT: Ty)) {
5828 ResourceClass RC = AttrResType->getAttrs().ResourceClass;
5829 return Bindings.getDeclBindingInfo(VD, ResClass: RC);
5830 }
5831 }
5832
5833 return nullptr;
5834}
5835
5836void SemaHLSL::trackLocalResource(VarDecl *VD, Expr *E) {
5837 std::optional<const DeclBindingInfo *> ExprBinding = inferGlobalBinding(E);
5838 if (!ExprBinding) {
5839 SemaRef.Diag(Loc: E->getBeginLoc(),
5840 DiagID: diag::warn_hlsl_assigning_local_resource_is_not_unique)
5841 << E << VD;
5842 return; // Expr use multiple resources
5843 }
5844
5845 if (*ExprBinding == nullptr)
5846 return; // No binding could be inferred to track, return without error
5847
5848 auto PrevBinding = Assigns.find(Val: VD);
5849 if (PrevBinding == Assigns.end()) {
5850 // No previous binding recorded, simply record the new assignment
5851 Assigns.insert(KV: {VD, *ExprBinding});
5852 return;
5853 }
5854
5855 // Otherwise, warn if the assignment implies different resource bindings
5856 if (*ExprBinding != PrevBinding->second) {
5857 SemaRef.Diag(Loc: E->getBeginLoc(),
5858 DiagID: diag::warn_hlsl_assigning_local_resource_is_not_unique)
5859 << E << VD;
5860 SemaRef.Diag(Loc: VD->getLocation(), DiagID: diag::note_var_declared_here) << VD;
5861 return;
5862 }
5863
5864 return;
5865}
5866
5867bool SemaHLSL::CheckResourceBinOp(BinaryOperatorKind Opc, Expr *LHSExpr,
5868 Expr *RHSExpr, SourceLocation Loc) {
5869 assert((LHSExpr->getType()->isHLSLResourceRecord() ||
5870 LHSExpr->getType()->isHLSLResourceRecordArray()) &&
5871 "expected LHS to be a resource record or array of resource records");
5872 if (Opc != BO_Assign)
5873 return true;
5874
5875 // If LHS is an array subscript, get the underlying declaration.
5876 Expr *E = LHSExpr;
5877 while (auto *ASE = dyn_cast<ArraySubscriptExpr>(Val: E))
5878 E = ASE->getBase()->IgnoreParenImpCasts();
5879
5880 // Report error if LHS is a non-static resource declared at a global scope.
5881 if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(Val: E->IgnoreParens())) {
5882 if (VarDecl *VD = dyn_cast<VarDecl>(Val: DRE->getDecl())) {
5883 if (VD->hasGlobalStorage() && VD->getStorageClass() != SC_Static) {
5884 // assignment to global resource is not allowed
5885 SemaRef.Diag(Loc, DiagID: diag::err_hlsl_assign_to_global_resource) << VD;
5886 SemaRef.Diag(Loc: VD->getLocation(), DiagID: diag::note_var_declared_here) << VD;
5887 return false;
5888 }
5889
5890 trackLocalResource(VD, E: RHSExpr);
5891 }
5892 }
5893 return true;
5894}
5895
5896// Returns true if the given type can have an overload of the given
5897// binary operator.
5898bool SemaHLSL::canHaveOverloadedBinOp(QualType LHSTy, BinaryOperatorKind Opc) {
5899 CXXRecordDecl *RD = LHSTy->getAsCXXRecordDecl();
5900 if (!RD)
5901 return true;
5902 return RD->isHLSLBuiltinRecord() || Opc != BO_Assign;
5903}
5904
5905// Walks though the global variable declaration, collects all resource binding
5906// requirements and adds them to Bindings
5907void SemaHLSL::collectResourceBindingsOnVarDecl(VarDecl *VD) {
5908 assert(VD->hasGlobalStorage() && VD->getType()->isHLSLIntangibleType() &&
5909 "expected global variable that contains HLSL resource");
5910
5911 // Cbuffers and Tbuffers are HLSLBufferDecl types
5912 if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(Val: VD)) {
5913 Bindings.addDeclBindingInfo(VD, ResClass: CBufferOrTBuffer->isCBuffer()
5914 ? ResourceClass::CBuffer
5915 : ResourceClass::SRV);
5916 return;
5917 }
5918
5919 // Unwrap arrays
5920 // FIXME: Calculate array size while unwrapping
5921 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
5922 while (Ty->isArrayType()) {
5923 const ArrayType *AT = cast<ArrayType>(Val: Ty);
5924 Ty = AT->getElementType()->getUnqualifiedDesugaredType();
5925 }
5926
5927 // Resource (or array of resources)
5928 if (const HLSLAttributedResourceType *AttrResType =
5929 HLSLAttributedResourceType::findHandleTypeOnResource(RT: Ty)) {
5930 Bindings.addDeclBindingInfo(VD, ResClass: AttrResType->getAttrs().ResourceClass);
5931 return;
5932 }
5933
5934 // User defined record type
5935 if (const RecordType *RT = dyn_cast<RecordType>(Val: Ty))
5936 collectResourceBindingsOnUserRecordDecl(VD, RT);
5937}
5938
5939// Walks though the explicit resource binding attributes on the declaration,
5940// and makes sure there is a resource that matched the binding and updates
5941// DeclBindingInfoLists
5942void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) {
5943 assert(VD->hasGlobalStorage() && "expected global variable");
5944
5945 bool HasBinding = false;
5946 for (Attr *A : VD->attrs()) {
5947 if (isa<HLSLVkBindingAttr>(Val: A)) {
5948 HasBinding = true;
5949 if (auto PA = VD->getAttr<HLSLVkPushConstantAttr>())
5950 Diag(Loc: PA->getLoc(), DiagID: diag::err_hlsl_attr_incompatible) << A << PA;
5951 }
5952
5953 HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(Val: A);
5954 if (!RBA || !RBA->hasRegisterSlot())
5955 continue;
5956 HasBinding = true;
5957
5958 RegisterType RT = RBA->getRegisterType();
5959 assert(RT != RegisterType::I && "invalid or obsolete register type should "
5960 "never have an attribute created");
5961
5962 if (RT == RegisterType::C) {
5963 if (Bindings.hasBindingInfoForDecl(VD))
5964 SemaRef.Diag(Loc: VD->getLocation(),
5965 DiagID: diag::warn_hlsl_user_defined_type_missing_member)
5966 << static_cast<int>(RT);
5967 continue;
5968 }
5969
5970 // Find DeclBindingInfo for this binding and update it, or report error
5971 // if it does not exist (user type does to contain resources with the
5972 // expected resource class).
5973 ResourceClass RC = getResourceClass(RT);
5974 if (DeclBindingInfo *BI = Bindings.getDeclBindingInfo(VD, ResClass: RC)) {
5975 // update binding info
5976 BI->setBindingAttribute(A: RBA, BT: BindingType::Explicit);
5977 } else {
5978 SemaRef.Diag(Loc: VD->getLocation(),
5979 DiagID: diag::warn_hlsl_user_defined_type_missing_member)
5980 << static_cast<int>(RT);
5981 }
5982 }
5983
5984 if (!HasBinding && isResourceRecordTypeOrArrayOf(VD))
5985 SemaRef.Diag(Loc: VD->getLocation(), DiagID: diag::warn_hlsl_implicit_binding);
5986}
5987namespace {
5988class InitListTransformer {
5989 Sema &S;
5990 ASTContext &Ctx;
5991 QualType InitTy;
5992 QualType *DstIt = nullptr;
5993 Expr **ArgIt = nullptr;
5994 // Is wrapping the destination type iterator required? This is only used for
5995 // incomplete array types where we loop over the destination type since we
5996 // don't know the full number of elements from the declaration.
5997 bool Wrap;
5998
5999 bool castInitializer(Expr *E) {
6000 assert(DstIt && "This should always be something!");
6001 if (DstIt == DestTypes.end()) {
6002 if (!Wrap) {
6003 ArgExprs.push_back(Elt: E);
6004 // This is odd, but it isn't technically a failure due to conversion, we
6005 // handle mismatched counts of arguments differently.
6006 return true;
6007 }
6008 DstIt = DestTypes.begin();
6009 }
6010 InitializedEntity Entity = InitializedEntity::InitializeParameter(
6011 Context&: Ctx, Type: *DstIt, /* Consumed (ObjC) */ Consumed: false);
6012 ExprResult Res = S.PerformCopyInitialization(Entity, EqualLoc: E->getBeginLoc(), Init: E);
6013 if (Res.isInvalid())
6014 return false;
6015 Expr *Init = Res.get();
6016 ArgExprs.push_back(Elt: Init);
6017 DstIt++;
6018 return true;
6019 }
6020
6021 bool buildInitializerListImpl(Expr *E) {
6022 // If this is an initialization list, traverse the sub initializers.
6023 if (auto *Init = dyn_cast<InitListExpr>(Val: E)) {
6024 for (auto *SubInit : Init->inits())
6025 if (!buildInitializerListImpl(E: SubInit))
6026 return false;
6027 return true;
6028 }
6029
6030 // If this is a scalar type, just enqueue the expression.
6031 QualType Ty = E->getType().getDesugaredType(Context: Ctx);
6032
6033 if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()) ||
6034 Ty->isHLSLAttributedResourceType())
6035 return castInitializer(E);
6036
6037 // If this is an aggregate type and a prvalue, create an xvalue temporary
6038 // so the member accesses will be xvalues. Wrap it in OpaqueExpr to make
6039 // sure codegen will not generate duplicate copies.
6040 if (E->isPRValue() && Ty->isAggregateType()) {
6041 ExprResult TmpExpr = S.TemporaryMaterializationConversion(E);
6042 if (TmpExpr.isInvalid())
6043 return false;
6044 E = TmpExpr.get();
6045 E = new (Ctx) OpaqueValueExpr(E->getBeginLoc(), E->getType(),
6046 E->getValueKind(), E->getObjectKind(), E);
6047 }
6048
6049 if (auto *VecTy = Ty->getAs<VectorType>()) {
6050 uint64_t Size = VecTy->getNumElements();
6051
6052 QualType SizeTy = Ctx.getSizeType();
6053 uint64_t SizeTySize = Ctx.getTypeSize(T: SizeTy);
6054 for (uint64_t I = 0; I < Size; ++I) {
6055 auto *Idx = IntegerLiteral::Create(C: Ctx, V: llvm::APInt(SizeTySize, I),
6056 type: SizeTy, l: SourceLocation());
6057
6058 ExprResult ElExpr = S.CreateBuiltinArraySubscriptExpr(
6059 Base: E, LLoc: E->getBeginLoc(), Idx, RLoc: E->getEndLoc());
6060 if (ElExpr.isInvalid())
6061 return false;
6062 if (!castInitializer(E: ElExpr.get()))
6063 return false;
6064 }
6065 return true;
6066 }
6067 if (auto *MTy = Ty->getAs<ConstantMatrixType>()) {
6068 unsigned Rows = MTy->getNumRows();
6069 unsigned Cols = MTy->getNumColumns();
6070 QualType ElemTy = MTy->getElementType();
6071
6072 for (unsigned R = 0; R < Rows; ++R) {
6073 for (unsigned C = 0; C < Cols; ++C) {
6074 // row index literal
6075 Expr *RowIdx = IntegerLiteral::Create(
6076 C: Ctx, V: llvm::APInt(Ctx.getIntWidth(T: Ctx.IntTy), R), type: Ctx.IntTy,
6077 l: E->getBeginLoc());
6078 // column index literal
6079 Expr *ColIdx = IntegerLiteral::Create(
6080 C: Ctx, V: llvm::APInt(Ctx.getIntWidth(T: Ctx.IntTy), C), type: Ctx.IntTy,
6081 l: E->getBeginLoc());
6082 ExprResult ElExpr = S.CreateBuiltinMatrixSubscriptExpr(
6083 Base: E, RowIdx, ColumnIdx: ColIdx, RBLoc: E->getEndLoc());
6084 if (ElExpr.isInvalid())
6085 return false;
6086 if (!castInitializer(E: ElExpr.get()))
6087 return false;
6088 ElExpr.get()->setType(ElemTy);
6089 }
6090 }
6091 return true;
6092 }
6093
6094 if (auto *ArrTy = dyn_cast<ConstantArrayType>(Val: Ty.getTypePtr())) {
6095 uint64_t Size = ArrTy->getZExtSize();
6096 QualType SizeTy = Ctx.getSizeType();
6097 uint64_t SizeTySize = Ctx.getTypeSize(T: SizeTy);
6098 for (uint64_t I = 0; I < Size; ++I) {
6099 auto *Idx = IntegerLiteral::Create(C: Ctx, V: llvm::APInt(SizeTySize, I),
6100 type: SizeTy, l: SourceLocation());
6101 ExprResult ElExpr = S.CreateBuiltinArraySubscriptExpr(
6102 Base: E, LLoc: E->getBeginLoc(), Idx, RLoc: E->getEndLoc());
6103 if (ElExpr.isInvalid())
6104 return false;
6105 if (!buildInitializerListImpl(E: ElExpr.get()))
6106 return false;
6107 }
6108 return true;
6109 }
6110
6111 if (auto *RD = Ty->getAsCXXRecordDecl()) {
6112 llvm::SmallVector<CXXRecordDecl *> RecordDecls;
6113 RecordDecls.push_back(Elt: RD);
6114 while (RecordDecls.back()->getNumBases()) {
6115 CXXRecordDecl *D = RecordDecls.back();
6116 assert(D->getNumBases() == 1 &&
6117 "HLSL doesn't support multiple inheritance");
6118 RecordDecls.push_back(
6119 Elt: D->bases_begin()->getType()->castAsCXXRecordDecl());
6120 }
6121 while (!RecordDecls.empty()) {
6122 CXXRecordDecl *RD = RecordDecls.pop_back_val();
6123 for (auto *FD : RD->fields()) {
6124 if (FD->isUnnamedBitField())
6125 continue;
6126 DeclAccessPair Found = DeclAccessPair::make(D: FD, AS: FD->getAccess());
6127 DeclarationNameInfo NameInfo(FD->getDeclName(), E->getBeginLoc());
6128 ExprResult Res = S.BuildFieldReferenceExpr(
6129 BaseExpr: E, IsArrow: false, OpLoc: E->getBeginLoc(), SS: CXXScopeSpec(), Field: FD, FoundDecl: Found, MemberNameInfo: NameInfo);
6130 if (Res.isInvalid())
6131 return false;
6132 if (!buildInitializerListImpl(E: Res.get()))
6133 return false;
6134 }
6135 }
6136 }
6137 return true;
6138 }
6139
6140 Expr *generateInitListsImpl(QualType Ty) {
6141 Ty = Ty.getDesugaredType(Context: Ctx);
6142 assert(ArgIt != ArgExprs.end() && "Something is off in iteration!");
6143 if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()) ||
6144 Ty->isHLSLAttributedResourceType())
6145 return *(ArgIt++);
6146
6147 llvm::SmallVector<Expr *> Inits;
6148 if (Ty->isVectorType() || Ty->isConstantArrayType() ||
6149 Ty->isConstantMatrixType()) {
6150 QualType ElTy;
6151 uint64_t Size = 0;
6152 if (auto *ATy = Ty->getAs<VectorType>()) {
6153 ElTy = ATy->getElementType();
6154 Size = ATy->getNumElements();
6155 } else if (auto *CMTy = Ty->getAs<ConstantMatrixType>()) {
6156 ElTy = CMTy->getElementType();
6157 Size = CMTy->getNumElementsFlattened();
6158 } else {
6159 auto *VTy = cast<ConstantArrayType>(Val: Ty.getTypePtr());
6160 ElTy = VTy->getElementType();
6161 Size = VTy->getZExtSize();
6162 }
6163 for (uint64_t I = 0; I < Size; ++I)
6164 Inits.push_back(Elt: generateInitListsImpl(Ty: ElTy));
6165 }
6166 if (auto *RD = Ty->getAsCXXRecordDecl()) {
6167 llvm::SmallVector<CXXRecordDecl *> RecordDecls;
6168 RecordDecls.push_back(Elt: RD);
6169 while (RecordDecls.back()->getNumBases()) {
6170 CXXRecordDecl *D = RecordDecls.back();
6171 assert(D->getNumBases() == 1 &&
6172 "HLSL doesn't support multiple inheritance");
6173 RecordDecls.push_back(
6174 Elt: D->bases_begin()->getType()->castAsCXXRecordDecl());
6175 }
6176 while (!RecordDecls.empty()) {
6177 CXXRecordDecl *RD = RecordDecls.pop_back_val();
6178 for (auto *FD : RD->fields())
6179 if (!FD->isUnnamedBitField())
6180 Inits.push_back(Elt: generateInitListsImpl(Ty: FD->getType()));
6181 }
6182 }
6183 auto *NewInit =
6184 new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(), Inits,
6185 Inits.back()->getEndLoc(), /*isExplicit=*/false);
6186 NewInit->setType(Ty);
6187 return NewInit;
6188 }
6189
6190public:
6191 llvm::SmallVector<QualType, 16> DestTypes;
6192 llvm::SmallVector<Expr *, 16> ArgExprs;
6193 InitListTransformer(Sema &SemaRef, const InitializedEntity &Entity)
6194 : S(SemaRef), Ctx(SemaRef.getASTContext()),
6195 Wrap(Entity.getType()->isIncompleteArrayType()) {
6196 InitTy = Entity.getType().getNonReferenceType();
6197 // When we're generating initializer lists for incomplete array types we
6198 // need to wrap around both when building the initializers and when
6199 // generating the final initializer lists.
6200 if (Wrap) {
6201 assert(InitTy->isIncompleteArrayType());
6202 const IncompleteArrayType *IAT = Ctx.getAsIncompleteArrayType(T: InitTy);
6203 InitTy = IAT->getElementType();
6204 }
6205 BuildFlattenedTypeList(BaseTy: InitTy, List&: DestTypes);
6206 DstIt = DestTypes.begin();
6207 }
6208
6209 bool buildInitializerList(Expr *E) { return buildInitializerListImpl(E); }
6210
6211 Expr *generateInitLists() {
6212 assert(!ArgExprs.empty() &&
6213 "Call buildInitializerList to generate argument expressions.");
6214 ArgIt = ArgExprs.begin();
6215 if (!Wrap)
6216 return generateInitListsImpl(Ty: InitTy);
6217 llvm::SmallVector<Expr *> Inits;
6218 while (ArgIt != ArgExprs.end())
6219 Inits.push_back(Elt: generateInitListsImpl(Ty: InitTy));
6220
6221 auto *NewInit =
6222 new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(), Inits,
6223 Inits.back()->getEndLoc(), /*isExplicit=*/false);
6224 llvm::APInt ArySize(64, Inits.size());
6225 NewInit->setType(Ctx.getConstantArrayType(EltTy: InitTy, ArySize, SizeExpr: nullptr,
6226 ASM: ArraySizeModifier::Normal, IndexTypeQuals: 0));
6227 return NewInit;
6228 }
6229};
6230} // namespace
6231
6232// Recursively detect any incomplete array anywhere in the type graph,
6233// including arrays, struct fields, and base classes.
6234static bool containsIncompleteArrayType(QualType Ty) {
6235 Ty = Ty.getCanonicalType();
6236
6237 // Array types
6238 if (const ArrayType *AT = dyn_cast<ArrayType>(Val&: Ty)) {
6239 if (isa<IncompleteArrayType>(Val: AT))
6240 return true;
6241 return containsIncompleteArrayType(Ty: AT->getElementType());
6242 }
6243
6244 // Record (struct/class) types
6245 if (const auto *RT = Ty->getAs<RecordType>()) {
6246 const RecordDecl *RD = RT->getDecl();
6247
6248 // Walk base classes (for C++ / HLSL structs with inheritance)
6249 if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(Val: RD)) {
6250 for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
6251 if (containsIncompleteArrayType(Ty: Base.getType()))
6252 return true;
6253 }
6254 }
6255
6256 // Walk fields
6257 for (const FieldDecl *F : RD->fields()) {
6258 if (containsIncompleteArrayType(Ty: F->getType()))
6259 return true;
6260 }
6261 }
6262
6263 return false;
6264}
6265
6266bool SemaHLSL::transformInitList(const InitializedEntity &Entity,
6267 InitListExpr *Init) {
6268 // If the initializer is a scalar, just return it.
6269 if (Init->getType()->isScalarType())
6270 return true;
6271 ASTContext &Ctx = SemaRef.getASTContext();
6272 InitListTransformer ILT(SemaRef, Entity);
6273
6274 for (unsigned I = 0; I < Init->getNumInits(); ++I) {
6275 Expr *E = Init->getInit(Init: I);
6276 if (E->HasSideEffects(Ctx)) {
6277 QualType Ty = E->getType();
6278 if (Ty->isRecordType())
6279 E = new (Ctx) MaterializeTemporaryExpr(Ty, E, E->isLValue());
6280 E = new (Ctx) OpaqueValueExpr(E->getBeginLoc(), Ty, E->getValueKind(),
6281 E->getObjectKind(), E);
6282 Init->setInit(Init: I, expr: E);
6283 }
6284 if (!ILT.buildInitializerList(E))
6285 return false;
6286 }
6287 size_t ExpectedSize = ILT.DestTypes.size();
6288 size_t ActualSize = ILT.ArgExprs.size();
6289 if (ExpectedSize == 0 && ActualSize == 0)
6290 return true;
6291
6292 // Reject empty initializer if *any* incomplete array exists structurally
6293 if (ActualSize == 0 && containsIncompleteArrayType(Ty: Entity.getType())) {
6294 QualType InitTy = Entity.getType().getNonReferenceType();
6295 if (InitTy.hasAddressSpace())
6296 InitTy = SemaRef.getASTContext().removeAddrSpaceQualType(T: InitTy);
6297
6298 SemaRef.Diag(Loc: Init->getBeginLoc(), DiagID: diag::err_hlsl_incorrect_num_initializers)
6299 << /*TooManyOrFew=*/(int)(ExpectedSize < ActualSize) << InitTy
6300 << /*ExpectedSize=*/ExpectedSize << /*ActualSize=*/ActualSize;
6301 return false;
6302 }
6303
6304 // We infer size after validating legality.
6305 // For incomplete arrays it is completely arbitrary to choose whether we think
6306 // the user intended fewer or more elements. This implementation assumes that
6307 // the user intended more, and errors that there are too few initializers to
6308 // complete the final element.
6309 if (Entity.getType()->isIncompleteArrayType()) {
6310 assert(ExpectedSize > 0 &&
6311 "The expected size of an incomplete array type must be at least 1.");
6312 ExpectedSize =
6313 ((ActualSize + ExpectedSize - 1) / ExpectedSize) * ExpectedSize;
6314 }
6315
6316 // An initializer list might be attempting to initialize a reference or
6317 // rvalue-reference. When checking the initializer we should look through
6318 // the reference.
6319 QualType InitTy = Entity.getType().getNonReferenceType();
6320 if (InitTy.hasAddressSpace())
6321 InitTy = SemaRef.getASTContext().removeAddrSpaceQualType(T: InitTy);
6322 if (ExpectedSize != ActualSize) {
6323 int TooManyOrFew = ActualSize > ExpectedSize ? 1 : 0;
6324 SemaRef.Diag(Loc: Init->getBeginLoc(), DiagID: diag::err_hlsl_incorrect_num_initializers)
6325 << TooManyOrFew << InitTy << ExpectedSize << ActualSize;
6326 return false;
6327 }
6328
6329 // generateInitListsImpl will always return an InitListExpr here, because the
6330 // scalar case is handled above.
6331 auto *NewInit = cast<InitListExpr>(Val: ILT.generateInitLists());
6332 Init->resizeInits(Context: Ctx, NumInits: NewInit->getNumInits());
6333 for (unsigned I = 0; I < NewInit->getNumInits(); ++I)
6334 Init->updateInit(C: Ctx, Init: I, expr: NewInit->getInit(Init: I));
6335 return true;
6336}
6337
6338static QualType ReportMatrixInvalidMember(Sema &S, StringRef Name,
6339 StringRef Expected,
6340 SourceLocation OpLoc,
6341 SourceLocation CompLoc) {
6342 S.Diag(Loc: OpLoc, DiagID: diag::err_builtin_matrix_invalid_member)
6343 << Name << Expected << SourceRange(CompLoc);
6344 return QualType();
6345}
6346
6347QualType SemaHLSL::checkMatrixComponent(Sema &S, QualType baseType,
6348 ExprValueKind &VK, SourceLocation OpLoc,
6349 const IdentifierInfo *CompName,
6350 SourceLocation CompLoc) {
6351 const auto *MT = baseType->castAs<ConstantMatrixType>();
6352 StringRef AccessorName = CompName->getName();
6353 assert(!AccessorName.empty() && "Matrix Accessor must have a name");
6354
6355 unsigned Rows = MT->getNumRows();
6356 unsigned Cols = MT->getNumColumns();
6357 bool IsZeroBasedAccessor = false;
6358 unsigned ChunkLen = 0;
6359 if (AccessorName.size() < 2)
6360 return ReportMatrixInvalidMember(S, Name: AccessorName,
6361 Expected: "length 4 for zero based: \'_mRC\' or "
6362 "length 3 for one-based: \'_RC\' accessor",
6363 OpLoc, CompLoc);
6364
6365 if (AccessorName[0] == '_') {
6366 if (AccessorName[1] == 'm') {
6367 IsZeroBasedAccessor = true;
6368 ChunkLen = 4; // zero-based: "_mRC"
6369 } else {
6370 ChunkLen = 3; // one-based: "_RC"
6371 }
6372 } else
6373 return ReportMatrixInvalidMember(
6374 S, Name: AccessorName, Expected: "zero based: \'_mRC\' or one-based: \'_RC\' accessor",
6375 OpLoc, CompLoc);
6376
6377 if (AccessorName.size() % ChunkLen != 0) {
6378 const llvm::StringRef Expected = IsZeroBasedAccessor
6379 ? "zero based: '_mRC' accessor"
6380 : "one-based: '_RC' accessor";
6381
6382 return ReportMatrixInvalidMember(S, Name: AccessorName, Expected, OpLoc, CompLoc);
6383 }
6384
6385 auto isDigit = [](char c) { return c >= '0' && c <= '9'; };
6386 auto isZeroBasedIndex = [](unsigned i) { return i <= 3; };
6387 auto isOneBasedIndex = [](unsigned i) { return i >= 1 && i <= 4; };
6388
6389 bool HasRepeated = false;
6390 SmallVector<bool, 16> Seen(Rows * Cols, false);
6391 unsigned NumComponents = 0;
6392 const char *Begin = AccessorName.data();
6393
6394 for (unsigned I = 0, E = AccessorName.size(); I < E; I += ChunkLen) {
6395 const char *Chunk = Begin + I;
6396 char RowChar = 0, ColChar = 0;
6397 if (IsZeroBasedAccessor) {
6398 // Zero-based: "_mRC"
6399 if (Chunk[0] != '_' || Chunk[1] != 'm') {
6400 char Bad = (Chunk[0] != '_') ? Chunk[0] : Chunk[1];
6401 return ReportMatrixInvalidMember(
6402 S, Name: StringRef(&Bad, 1), Expected: "\'_m\' prefix",
6403 OpLoc: OpLoc.getLocWithOffset(Offset: I + (Bad == Chunk[0] ? 1 : 2)), CompLoc);
6404 }
6405 RowChar = Chunk[2];
6406 ColChar = Chunk[3];
6407 } else {
6408 // One-based: "_RC"
6409 if (Chunk[0] != '_')
6410 return ReportMatrixInvalidMember(
6411 S, Name: StringRef(&Chunk[0], 1), Expected: "\'_\' prefix",
6412 OpLoc: OpLoc.getLocWithOffset(Offset: I + 1), CompLoc);
6413 RowChar = Chunk[1];
6414 ColChar = Chunk[2];
6415 }
6416
6417 // Must be digits.
6418 bool IsDigitsError = false;
6419 if (!isDigit(RowChar)) {
6420 unsigned BadPos = IsZeroBasedAccessor ? 2 : 1;
6421 ReportMatrixInvalidMember(S, Name: StringRef(&RowChar, 1), Expected: "row as integer",
6422 OpLoc: OpLoc.getLocWithOffset(Offset: I + BadPos + 1),
6423 CompLoc);
6424 IsDigitsError = true;
6425 }
6426
6427 if (!isDigit(ColChar)) {
6428 unsigned BadPos = IsZeroBasedAccessor ? 3 : 2;
6429 ReportMatrixInvalidMember(S, Name: StringRef(&ColChar, 1), Expected: "column as integer",
6430 OpLoc: OpLoc.getLocWithOffset(Offset: I + BadPos + 1),
6431 CompLoc);
6432 IsDigitsError = true;
6433 }
6434 if (IsDigitsError)
6435 return QualType();
6436
6437 unsigned Row = RowChar - '0';
6438 unsigned Col = ColChar - '0';
6439
6440 bool HasIndexingError = false;
6441 if (IsZeroBasedAccessor) {
6442 // 0-based [0..3]
6443 if (!isZeroBasedIndex(Row)) {
6444 S.Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_element_not_in_bounds)
6445 << /*row*/ 0 << /*zero-based*/ 0 << SourceRange(CompLoc);
6446 HasIndexingError = true;
6447 }
6448 if (!isZeroBasedIndex(Col)) {
6449 S.Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_element_not_in_bounds)
6450 << /*col*/ 1 << /*zero-based*/ 0 << SourceRange(CompLoc);
6451 HasIndexingError = true;
6452 }
6453 } else {
6454 // 1-based [1..4]
6455 if (!isOneBasedIndex(Row)) {
6456 S.Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_element_not_in_bounds)
6457 << /*row*/ 0 << /*one-based*/ 1 << SourceRange(CompLoc);
6458 HasIndexingError = true;
6459 }
6460 if (!isOneBasedIndex(Col)) {
6461 S.Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_element_not_in_bounds)
6462 << /*col*/ 1 << /*one-based*/ 1 << SourceRange(CompLoc);
6463 HasIndexingError = true;
6464 }
6465 // Convert to 0-based after range checking.
6466 --Row;
6467 --Col;
6468 }
6469
6470 if (HasIndexingError)
6471 return QualType();
6472
6473 // Note: matrix swizzle index is hard coded. That means Row and Col can
6474 // potentially be larger than Rows and Cols if matrix size is less than
6475 // the max index size.
6476 bool HasBoundsError = false;
6477 if (Row >= Rows) {
6478 Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_index_out_of_bounds)
6479 << /*Row*/ 0 << Row << Rows << SourceRange(CompLoc);
6480 HasBoundsError = true;
6481 }
6482 if (Col >= Cols) {
6483 Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_index_out_of_bounds)
6484 << /*Col*/ 1 << Col << Cols << SourceRange(CompLoc);
6485 HasBoundsError = true;
6486 }
6487 if (HasBoundsError)
6488 return QualType();
6489
6490 unsigned FlatIndex = Row * Cols + Col;
6491 if (Seen[FlatIndex])
6492 HasRepeated = true;
6493 Seen[FlatIndex] = true;
6494 ++NumComponents;
6495 }
6496 if (NumComponents == 0 || NumComponents > 4) {
6497 S.Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_swizzle_invalid_length)
6498 << NumComponents << SourceRange(CompLoc);
6499 return QualType();
6500 }
6501
6502 QualType ElemTy = MT->getElementType();
6503 if (NumComponents == 1)
6504 return ElemTy;
6505 QualType VT = S.Context.getExtVectorType(VectorType: ElemTy, NumElts: NumComponents);
6506 if (HasRepeated)
6507 VK = VK_PRValue;
6508
6509 for (Sema::ExtVectorDeclsType::iterator
6510 I = S.ExtVectorDecls.begin(source: S.getExternalSource()),
6511 E = S.ExtVectorDecls.end();
6512 I != E; ++I) {
6513 if ((*I)->getUnderlyingType() == VT)
6514 return S.Context.getTypedefType(Keyword: ElaboratedTypeKeyword::None,
6515 /*Qualifier=*/std::nullopt, Decl: *I);
6516 }
6517
6518 return VT;
6519}
6520
6521bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) {
6522 // If initializing a local resource, track the resource binding it is using
6523 if (VDecl->getType()->isHLSLResourceRecord() && !VDecl->hasGlobalStorage())
6524 trackLocalResource(VD: VDecl, E: Init);
6525
6526 const HLSLVkConstantIdAttr *ConstIdAttr =
6527 VDecl->getAttr<HLSLVkConstantIdAttr>();
6528 if (!ConstIdAttr)
6529 return true;
6530
6531 ASTContext &Context = SemaRef.getASTContext();
6532
6533 APValue InitValue;
6534 if (!Init->isCXX11ConstantExpr(Ctx: Context, Result: &InitValue)) {
6535 Diag(Loc: VDecl->getLocation(), DiagID: diag::err_specialization_const);
6536 VDecl->setInvalidDecl();
6537 return false;
6538 }
6539
6540 Builtin::ID BID =
6541 getSpecConstBuiltinId(Type: VDecl->getType()->getUnqualifiedDesugaredType());
6542
6543 // Argument 1: The ID from the attribute
6544 int ConstantID = ConstIdAttr->getId();
6545 llvm::APInt IDVal(Context.getIntWidth(T: Context.IntTy), ConstantID);
6546 Expr *IdExpr = IntegerLiteral::Create(C: Context, V: IDVal, type: Context.IntTy,
6547 l: ConstIdAttr->getLocation());
6548
6549 SmallVector<Expr *, 2> Args = {IdExpr, Init};
6550 Expr *C = SemaRef.BuildBuiltinCallExpr(Loc: Init->getExprLoc(), Id: BID, CallArgs: Args);
6551 if (C->getType()->getCanonicalTypeUnqualified() !=
6552 VDecl->getType()->getCanonicalTypeUnqualified()) {
6553 C = SemaRef
6554 .BuildCStyleCastExpr(LParenLoc: SourceLocation(),
6555 Ty: Context.getTrivialTypeSourceInfo(
6556 T: Init->getType(), Loc: Init->getExprLoc()),
6557 RParenLoc: SourceLocation(), Op: C)
6558 .get();
6559 }
6560 Init = C;
6561 return true;
6562}
6563
6564QualType SemaHLSL::ActOnTemplateShorthand(TemplateDecl *Template,
6565 SourceLocation NameLoc) {
6566 if (!Template)
6567 return QualType();
6568
6569 DeclContext *DC = Template->getDeclContext();
6570 if (!DC->isNamespace() || !cast<NamespaceDecl>(Val: DC)->getIdentifier() ||
6571 cast<NamespaceDecl>(Val: DC)->getName() != "hlsl")
6572 return QualType();
6573
6574 TemplateParameterList *Params = Template->getTemplateParameters();
6575 if (!Params || Params->size() != 1)
6576 return QualType();
6577
6578 if (!Template->isImplicit())
6579 return QualType();
6580
6581 // We manually extract default arguments here instead of letting
6582 // CheckTemplateIdType handle it. This ensures that for resource types that
6583 // lack a default argument (like Buffer), we return a null QualType, which
6584 // triggers the "requires template arguments" error rather than a less
6585 // descriptive "too few template arguments" error.
6586 TemplateArgumentListInfo TemplateArgs(NameLoc, NameLoc);
6587 for (NamedDecl *P : *Params) {
6588 if (auto *TTP = dyn_cast<TemplateTypeParmDecl>(Val: P)) {
6589 if (TTP->hasDefaultArgument()) {
6590 TemplateArgs.addArgument(Loc: TTP->getDefaultArgument());
6591 continue;
6592 }
6593 } else if (auto *NTTP = dyn_cast<NonTypeTemplateParmDecl>(Val: P)) {
6594 if (NTTP->hasDefaultArgument()) {
6595 TemplateArgs.addArgument(Loc: NTTP->getDefaultArgument());
6596 continue;
6597 }
6598 } else if (auto *TTPD = dyn_cast<TemplateTemplateParmDecl>(Val: P)) {
6599 if (TTPD->hasDefaultArgument()) {
6600 TemplateArgs.addArgument(Loc: TTPD->getDefaultArgument());
6601 continue;
6602 }
6603 }
6604 return QualType();
6605 }
6606
6607 return SemaRef.CheckTemplateIdType(
6608 Keyword: ElaboratedTypeKeyword::None, Template: TemplateName(Template), TemplateLoc: NameLoc,
6609 TemplateArgs, Scope: nullptr, /*ForNestedNameSpecifier=*/false);
6610}
6611