1//===- SemaHLSL.cpp - Semantic Analysis for HLSL constructs ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This implements Semantic Analysis for HLSL constructs.
9//===----------------------------------------------------------------------===//
10
11#include "clang/Sema/SemaHLSL.h"
12#include "clang/AST/ASTConsumer.h"
13#include "clang/AST/ASTContext.h"
14#include "clang/AST/Attr.h"
15#include "clang/AST/Decl.h"
16#include "clang/AST/DeclBase.h"
17#include "clang/AST/DeclCXX.h"
18#include "clang/AST/DeclarationName.h"
19#include "clang/AST/DynamicRecursiveASTVisitor.h"
20#include "clang/AST/Expr.h"
21#include "clang/AST/HLSLResource.h"
22#include "clang/AST/Type.h"
23#include "clang/AST/TypeBase.h"
24#include "clang/AST/TypeLoc.h"
25#include "clang/Basic/Builtins.h"
26#include "clang/Basic/DiagnosticSema.h"
27#include "clang/Basic/IdentifierTable.h"
28#include "clang/Basic/LLVM.h"
29#include "clang/Basic/SourceLocation.h"
30#include "clang/Basic/Specifiers.h"
31#include "clang/Basic/TargetInfo.h"
32#include "clang/Sema/Initialization.h"
33#include "clang/Sema/Lookup.h"
34#include "clang/Sema/ParsedAttr.h"
35#include "clang/Sema/Sema.h"
36#include "clang/Sema/Template.h"
37#include "llvm/ADT/ArrayRef.h"
38#include "llvm/ADT/STLExtras.h"
39#include "llvm/ADT/SmallVector.h"
40#include "llvm/ADT/StringExtras.h"
41#include "llvm/ADT/StringRef.h"
42#include "llvm/ADT/Twine.h"
43#include "llvm/Frontend/HLSL/HLSLBinding.h"
44#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
45#include "llvm/Support/Casting.h"
46#include "llvm/Support/DXILABI.h"
47#include "llvm/Support/ErrorHandling.h"
48#include "llvm/Support/FormatVariadic.h"
49#include "llvm/TargetParser/Triple.h"
50#include <cmath>
51#include <cstddef>
52#include <iterator>
53#include <utility>
54
55using namespace clang;
56using namespace clang::hlsl;
57using RegisterType = HLSLResourceBindingAttr::RegisterType;
58
// Forward declaration: createHostLayoutStruct (defined below) is reached from
// createHostLayoutType, and itself calls back into createHostLayoutType via
// createFieldForHostLayoutStruct, so the two are mutually recursive.
static CXXRecordDecl *createHostLayoutStruct(Sema &S,
                                             CXXRecordDecl *StructDecl);
61
62static RegisterType getRegisterType(ResourceClass RC) {
63 switch (RC) {
64 case ResourceClass::SRV:
65 return RegisterType::SRV;
66 case ResourceClass::UAV:
67 return RegisterType::UAV;
68 case ResourceClass::CBuffer:
69 return RegisterType::CBuffer;
70 case ResourceClass::Sampler:
71 return RegisterType::Sampler;
72 }
73 llvm_unreachable("unexpected ResourceClass value");
74}
75
76static RegisterType getRegisterType(const HLSLAttributedResourceType *ResTy) {
77 return getRegisterType(RC: ResTy->getAttrs().ResourceClass);
78}
79
80// Converts the first letter of string Slot to RegisterType.
81// Returns false if the letter does not correspond to a valid register type.
82static bool convertToRegisterType(StringRef Slot, RegisterType *RT) {
83 assert(RT != nullptr);
84 switch (Slot[0]) {
85 case 't':
86 case 'T':
87 *RT = RegisterType::SRV;
88 return true;
89 case 'u':
90 case 'U':
91 *RT = RegisterType::UAV;
92 return true;
93 case 'b':
94 case 'B':
95 *RT = RegisterType::CBuffer;
96 return true;
97 case 's':
98 case 'S':
99 *RT = RegisterType::Sampler;
100 return true;
101 case 'c':
102 case 'C':
103 *RT = RegisterType::C;
104 return true;
105 case 'i':
106 case 'I':
107 *RT = RegisterType::I;
108 return true;
109 default:
110 return false;
111 }
112}
113
114static char getRegisterTypeChar(RegisterType RT) {
115 switch (RT) {
116 case RegisterType::SRV:
117 return 't';
118 case RegisterType::UAV:
119 return 'u';
120 case RegisterType::CBuffer:
121 return 'b';
122 case RegisterType::Sampler:
123 return 's';
124 case RegisterType::C:
125 return 'c';
126 case RegisterType::I:
127 return 'i';
128 }
129 llvm_unreachable("unexpected RegisterType value");
130}
131
132static ResourceClass getResourceClass(RegisterType RT) {
133 switch (RT) {
134 case RegisterType::SRV:
135 return ResourceClass::SRV;
136 case RegisterType::UAV:
137 return ResourceClass::UAV;
138 case RegisterType::CBuffer:
139 return ResourceClass::CBuffer;
140 case RegisterType::Sampler:
141 return ResourceClass::Sampler;
142 case RegisterType::C:
143 case RegisterType::I:
144 // Deliberately falling through to the unreachable below.
145 break;
146 }
147 llvm_unreachable("unexpected RegisterType value");
148}
149
150static Builtin::ID getSpecConstBuiltinId(const Type *Type) {
151 const auto *BT = dyn_cast<BuiltinType>(Val: Type);
152 if (!BT) {
153 if (!Type->isEnumeralType())
154 return Builtin::NotBuiltin;
155 return Builtin::BI__builtin_get_spirv_spec_constant_int;
156 }
157
158 switch (BT->getKind()) {
159 case BuiltinType::Bool:
160 return Builtin::BI__builtin_get_spirv_spec_constant_bool;
161 case BuiltinType::Short:
162 return Builtin::BI__builtin_get_spirv_spec_constant_short;
163 case BuiltinType::Int:
164 return Builtin::BI__builtin_get_spirv_spec_constant_int;
165 case BuiltinType::LongLong:
166 return Builtin::BI__builtin_get_spirv_spec_constant_longlong;
167 case BuiltinType::UShort:
168 return Builtin::BI__builtin_get_spirv_spec_constant_ushort;
169 case BuiltinType::UInt:
170 return Builtin::BI__builtin_get_spirv_spec_constant_uint;
171 case BuiltinType::ULongLong:
172 return Builtin::BI__builtin_get_spirv_spec_constant_ulonglong;
173 case BuiltinType::Half:
174 return Builtin::BI__builtin_get_spirv_spec_constant_half;
175 case BuiltinType::Float:
176 return Builtin::BI__builtin_get_spirv_spec_constant_float;
177 case BuiltinType::Double:
178 return Builtin::BI__builtin_get_spirv_spec_constant_double;
179 default:
180 return Builtin::NotBuiltin;
181 }
182}
183
184static StringRef createRegisterString(ASTContext &AST, RegisterType RegType,
185 unsigned N) {
186 llvm::SmallString<16> Buffer;
187 llvm::raw_svector_ostream OS(Buffer);
188 OS << getRegisterTypeChar(RT: RegType);
189 OS << N;
190 return AST.backupStr(S: OS.str());
191}
192
193DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD,
194 ResourceClass ResClass) {
195 assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
196 "DeclBindingInfo already added");
197 assert(!hasBindingInfoForDecl(VD) || BindingsList.back().Decl == VD);
198 // VarDecl may have multiple entries for different resource classes.
199 // DeclToBindingListIndex stores the index of the first binding we saw
200 // for this decl. If there are any additional ones then that index
201 // shouldn't be updated.
202 DeclToBindingListIndex.try_emplace(Key: VD, Args: BindingsList.size());
203 return &BindingsList.emplace_back(Args&: VD, Args&: ResClass);
204}
205
206DeclBindingInfo *ResourceBindings::getDeclBindingInfo(const VarDecl *VD,
207 ResourceClass ResClass) {
208 auto Entry = DeclToBindingListIndex.find(Val: VD);
209 if (Entry != DeclToBindingListIndex.end()) {
210 for (unsigned Index = Entry->getSecond();
211 Index < BindingsList.size() && BindingsList[Index].Decl == VD;
212 ++Index) {
213 if (BindingsList[Index].ResClass == ResClass)
214 return &BindingsList[Index];
215 }
216 }
217 return nullptr;
218}
219
220bool ResourceBindings::hasBindingInfoForDecl(const VarDecl *VD) const {
221 return DeclToBindingListIndex.contains(Val: VD);
222}
223
// The constructor only forwards the owning Sema to SemaBase; its body is empty.
SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}
225
226Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
227 SourceLocation KwLoc, IdentifierInfo *Ident,
228 SourceLocation IdentLoc,
229 SourceLocation LBrace) {
230 // For anonymous namespace, take the location of the left brace.
231 DeclContext *LexicalParent = SemaRef.getCurLexicalContext();
232 HLSLBufferDecl *Result = HLSLBufferDecl::Create(
233 C&: getASTContext(), LexicalParent, CBuffer, KwLoc, ID: Ident, IDLoc: IdentLoc, LBrace);
234
235 // if CBuffer is false, then it's a TBuffer
236 auto RC = CBuffer ? llvm::hlsl::ResourceClass::CBuffer
237 : llvm::hlsl::ResourceClass::SRV;
238 Result->addAttr(A: HLSLResourceClassAttr::CreateImplicit(Ctx&: getASTContext(), ResourceClass: RC));
239
240 SemaRef.PushOnScopeChains(D: Result, S: BufferScope);
241 SemaRef.PushDeclContext(S: BufferScope, DC: Result);
242
243 return Result;
244}
245
246static unsigned calculateLegacyCbufferFieldAlign(const ASTContext &Context,
247 QualType T) {
248 // Arrays, Matrices, and Structs are always aligned to new buffer rows
249 if (T->isArrayType() || T->isStructureType() || T->isConstantMatrixType())
250 return 16;
251
252 // Vectors are aligned to the type they contain
253 if (const VectorType *VT = T->getAs<VectorType>())
254 return calculateLegacyCbufferFieldAlign(Context, T: VT->getElementType());
255
256 assert(Context.getTypeSize(T) <= 64 &&
257 "Scalar bit widths larger than 64 not supported");
258
259 // Scalar types are aligned to their byte width
260 return Context.getTypeSize(T) / 8;
261}
262
263// Calculate the size of a legacy cbuffer type in bytes based on
264// https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-packing-rules
265static unsigned calculateLegacyCbufferSize(const ASTContext &Context,
266 QualType T) {
267 constexpr unsigned CBufferAlign = 16;
268 if (const auto *RD = T->getAsRecordDecl()) {
269 unsigned Size = 0;
270 for (const FieldDecl *Field : RD->fields()) {
271 QualType Ty = Field->getType();
272 unsigned FieldSize = calculateLegacyCbufferSize(Context, T: Ty);
273 unsigned FieldAlign = calculateLegacyCbufferFieldAlign(Context, T: Ty);
274
275 // If the field crosses the row boundary after alignment it drops to the
276 // next row
277 unsigned AlignSize = llvm::alignTo(Value: Size, Align: FieldAlign);
278 if ((AlignSize % CBufferAlign) + FieldSize > CBufferAlign) {
279 FieldAlign = CBufferAlign;
280 }
281
282 Size = llvm::alignTo(Value: Size, Align: FieldAlign);
283 Size += FieldSize;
284 }
285 return Size;
286 }
287
288 if (const ConstantArrayType *AT = Context.getAsConstantArrayType(T)) {
289 unsigned ElementCount = AT->getSize().getZExtValue();
290 if (ElementCount == 0)
291 return 0;
292
293 unsigned ElementSize =
294 calculateLegacyCbufferSize(Context, T: AT->getElementType());
295 unsigned AlignedElementSize = llvm::alignTo(Value: ElementSize, Align: CBufferAlign);
296 return AlignedElementSize * (ElementCount - 1) + ElementSize;
297 }
298
299 if (const VectorType *VT = T->getAs<VectorType>()) {
300 unsigned ElementCount = VT->getNumElements();
301 unsigned ElementSize =
302 calculateLegacyCbufferSize(Context, T: VT->getElementType());
303 return ElementSize * ElementCount;
304 }
305
306 return Context.getTypeSize(T) / 8;
307}
308
309// Validate packoffset:
310// - if packoffset it used it must be set on all declarations inside the buffer
311// - packoffset ranges must not overlap
312static void validatePackoffset(Sema &S, HLSLBufferDecl *BufDecl) {
313 llvm::SmallVector<std::pair<VarDecl *, HLSLPackOffsetAttr *>> PackOffsetVec;
314
315 // Make sure the packoffset annotations are either on all declarations
316 // or on none.
317 bool HasPackOffset = false;
318 bool HasNonPackOffset = false;
319 for (auto *Field : BufDecl->buffer_decls()) {
320 VarDecl *Var = dyn_cast<VarDecl>(Val: Field);
321 if (!Var)
322 continue;
323 if (Field->hasAttr<HLSLPackOffsetAttr>()) {
324 PackOffsetVec.emplace_back(Args&: Var, Args: Field->getAttr<HLSLPackOffsetAttr>());
325 HasPackOffset = true;
326 } else {
327 HasNonPackOffset = true;
328 }
329 }
330
331 if (!HasPackOffset)
332 return;
333
334 if (HasNonPackOffset)
335 S.Diag(Loc: BufDecl->getLocation(), DiagID: diag::warn_hlsl_packoffset_mix);
336
337 // Make sure there is no overlap in packoffset - sort PackOffsetVec by offset
338 // and compare adjacent values.
339 bool IsValid = true;
340 ASTContext &Context = S.getASTContext();
341 std::sort(first: PackOffsetVec.begin(), last: PackOffsetVec.end(),
342 comp: [](const std::pair<VarDecl *, HLSLPackOffsetAttr *> &LHS,
343 const std::pair<VarDecl *, HLSLPackOffsetAttr *> &RHS) {
344 return LHS.second->getOffsetInBytes() <
345 RHS.second->getOffsetInBytes();
346 });
347 for (unsigned i = 0; i < PackOffsetVec.size() - 1; i++) {
348 VarDecl *Var = PackOffsetVec[i].first;
349 HLSLPackOffsetAttr *Attr = PackOffsetVec[i].second;
350 unsigned Size = calculateLegacyCbufferSize(Context, T: Var->getType());
351 unsigned Begin = Attr->getOffsetInBytes();
352 unsigned End = Begin + Size;
353 unsigned NextBegin = PackOffsetVec[i + 1].second->getOffsetInBytes();
354 if (End > NextBegin) {
355 VarDecl *NextVar = PackOffsetVec[i + 1].first;
356 S.Diag(Loc: NextVar->getLocation(), DiagID: diag::err_hlsl_packoffset_overlap)
357 << NextVar << Var;
358 IsValid = false;
359 }
360 }
361 BufDecl->setHasValidPackoffset(IsValid);
362}
363
364// Returns true if the array has a zero size = if any of the dimensions is 0
365static bool isZeroSizedArray(const ConstantArrayType *CAT) {
366 while (CAT && !CAT->isZeroSize())
367 CAT = dyn_cast<ConstantArrayType>(
368 Val: CAT->getElementType()->getUnqualifiedDesugaredType());
369 return CAT != nullptr;
370}
371
372static bool isResourceRecordTypeOrArrayOf(QualType Ty) {
373 return Ty->isHLSLResourceRecord() || Ty->isHLSLResourceRecordArray();
374}
375
376static bool isResourceRecordTypeOrArrayOf(VarDecl *VD) {
377 return isResourceRecordTypeOrArrayOf(Ty: VD->getType());
378}
379
380static const HLSLAttributedResourceType *
381getResourceArrayHandleType(QualType QT) {
382 assert(QT->isHLSLResourceRecordArray() &&
383 "expected array of resource records");
384 const Type *Ty = QT->getUnqualifiedDesugaredType();
385 while (const ArrayType *AT = dyn_cast<ArrayType>(Val: Ty))
386 Ty = AT->getArrayElementTypeNoTypeQual()->getUnqualifiedDesugaredType();
387 return HLSLAttributedResourceType::findHandleTypeOnResource(RT: Ty);
388}
389
390static const HLSLAttributedResourceType *
391getResourceArrayHandleType(VarDecl *VD) {
392 return getResourceArrayHandleType(QT: VD->getType());
393}
394
395// Returns true if the type is a leaf element type that is not valid to be
396// included in HLSL Buffer, such as a resource class, empty struct, zero-sized
397// array, or a builtin intangible type. Returns false it is a valid leaf element
398// type or if it is a record type that needs to be inspected further.
399static bool isInvalidConstantBufferLeafElementType(const Type *Ty) {
400 Ty = Ty->getUnqualifiedDesugaredType();
401 if (Ty->isHLSLResourceRecord() || Ty->isHLSLResourceRecordArray())
402 return true;
403 if (const auto *RD = Ty->getAsCXXRecordDecl())
404 return RD->isEmpty();
405 if (Ty->isConstantArrayType() &&
406 isZeroSizedArray(CAT: cast<ConstantArrayType>(Val: Ty)))
407 return true;
408 if (Ty->isHLSLBuiltinIntangibleType() || Ty->isHLSLAttributedResourceType())
409 return true;
410 return false;
411}
412
413// Returns true if the struct contains at least one element that prevents it
414// from being included inside HLSL Buffer as is, such as an intangible type,
415// empty struct, or zero-sized array. If it does, a new implicit layout struct
416// needs to be created for HLSL Buffer use that will exclude these unwanted
417// declarations (see createHostLayoutStruct function).
418static bool requiresImplicitBufferLayoutStructure(const CXXRecordDecl *RD) {
419 if (RD->isHLSLIntangible() || RD->isEmpty())
420 return true;
421 // check fields
422 for (const FieldDecl *Field : RD->fields()) {
423 QualType Ty = Field->getType();
424 if (isInvalidConstantBufferLeafElementType(Ty: Ty.getTypePtr()))
425 return true;
426 if (const auto *RD = Ty->getAsCXXRecordDecl();
427 RD && requiresImplicitBufferLayoutStructure(RD))
428 return true;
429 }
430 // check bases
431 for (const CXXBaseSpecifier &Base : RD->bases())
432 if (requiresImplicitBufferLayoutStructure(
433 RD: Base.getType()->castAsCXXRecordDecl()))
434 return true;
435 return false;
436}
437
438static CXXRecordDecl *findRecordDeclInContext(IdentifierInfo *II,
439 DeclContext *DC) {
440 CXXRecordDecl *RD = nullptr;
441 for (NamedDecl *Decl :
442 DC->getNonTransparentContext()->lookup(Name: DeclarationName(II))) {
443 if (CXXRecordDecl *FoundRD = dyn_cast<CXXRecordDecl>(Val: Decl)) {
444 assert(RD == nullptr &&
445 "there should be at most 1 record by a given name in a scope");
446 RD = FoundRD;
447 }
448 }
449 return RD;
450}
451
452// Creates a name for buffer layout struct using the provide name base.
453// If the name must be unique (not previously defined), a suffix is added
454// until a unique name is found.
455static IdentifierInfo *getHostLayoutStructName(Sema &S, NamedDecl *BaseDecl,
456 bool MustBeUnique) {
457 ASTContext &AST = S.getASTContext();
458
459 IdentifierInfo *NameBaseII = BaseDecl->getIdentifier();
460 llvm::SmallString<64> Name("__cblayout_");
461 if (NameBaseII) {
462 Name.append(RHS: NameBaseII->getName());
463 } else {
464 // anonymous struct
465 Name.append(RHS: "anon");
466 MustBeUnique = true;
467 }
468
469 size_t NameLength = Name.size();
470 IdentifierInfo *II = &AST.Idents.get(Name, TokenCode: tok::TokenKind::identifier);
471 if (!MustBeUnique)
472 return II;
473
474 unsigned suffix = 0;
475 while (true) {
476 if (suffix != 0) {
477 Name.append(RHS: "_");
478 Name.append(RHS: llvm::Twine(suffix).str());
479 II = &AST.Idents.get(Name, TokenCode: tok::TokenKind::identifier);
480 }
481 if (!findRecordDeclInContext(II, DC: BaseDecl->getDeclContext()))
482 return II;
483 // declaration with that name already exists - increment suffix and try
484 // again until unique name is found
485 suffix++;
486 Name.truncate(N: NameLength);
487 };
488}
489
490static const Type *createHostLayoutType(Sema &S, const Type *Ty) {
491 ASTContext &AST = S.getASTContext();
492 if (auto *RD = Ty->getAsCXXRecordDecl()) {
493 if (!requiresImplicitBufferLayoutStructure(RD))
494 return Ty;
495 RD = createHostLayoutStruct(S, StructDecl: RD);
496 if (!RD)
497 return nullptr;
498 return AST.getCanonicalTagType(TD: RD)->getTypePtr();
499 }
500
501 if (const auto *CAT = dyn_cast<ConstantArrayType>(Val: Ty)) {
502 const Type *ElementTy = createHostLayoutType(
503 S, Ty: CAT->getElementType()->getUnqualifiedDesugaredType());
504 if (!ElementTy)
505 return nullptr;
506 return AST
507 .getConstantArrayType(EltTy: QualType(ElementTy, 0), ArySize: CAT->getSize(), SizeExpr: nullptr,
508 ASM: CAT->getSizeModifier(),
509 IndexTypeQuals: CAT->getIndexTypeCVRQualifiers())
510 .getTypePtr();
511 }
512 return Ty;
513}
514
515// Creates a field declaration of given name and type for HLSL buffer layout
516// struct. Returns nullptr if the type cannot be use in HLSL Buffer layout.
517static FieldDecl *createFieldForHostLayoutStruct(Sema &S, const Type *Ty,
518 IdentifierInfo *II,
519 CXXRecordDecl *LayoutStruct) {
520 if (isInvalidConstantBufferLeafElementType(Ty))
521 return nullptr;
522
523 Ty = createHostLayoutType(S, Ty);
524 if (!Ty)
525 return nullptr;
526
527 QualType QT = QualType(Ty, 0);
528 ASTContext &AST = S.getASTContext();
529 TypeSourceInfo *TSI = AST.getTrivialTypeSourceInfo(T: QT, Loc: SourceLocation());
530 auto *Field = FieldDecl::Create(C: AST, DC: LayoutStruct, StartLoc: SourceLocation(),
531 IdLoc: SourceLocation(), Id: II, T: QT, TInfo: TSI, BW: nullptr, Mutable: false,
532 InitStyle: InClassInitStyle::ICIS_NoInit);
533 Field->setAccess(AccessSpecifier::AS_public);
534 return Field;
535}
536
537// Creates host layout struct for a struct included in HLSL Buffer.
538// The layout struct will include only fields that are allowed in HLSL buffer.
539// These fields will be filtered out:
540// - resource classes
541// - empty structs
542// - zero-sized arrays
543// Returns nullptr if the resulting layout struct would be empty.
544static CXXRecordDecl *createHostLayoutStruct(Sema &S,
545 CXXRecordDecl *StructDecl) {
546 assert(requiresImplicitBufferLayoutStructure(StructDecl) &&
547 "struct is already HLSL buffer compatible");
548
549 ASTContext &AST = S.getASTContext();
550 DeclContext *DC = StructDecl->getDeclContext();
551 IdentifierInfo *II = getHostLayoutStructName(S, BaseDecl: StructDecl, MustBeUnique: false);
552
553 // reuse existing if the layout struct if it already exists
554 if (CXXRecordDecl *RD = findRecordDeclInContext(II, DC))
555 return RD;
556
557 CXXRecordDecl *LS =
558 CXXRecordDecl::Create(C: AST, TK: TagDecl::TagKind::Struct, DC, StartLoc: SourceLocation(),
559 IdLoc: SourceLocation(), Id: II);
560 LS->setImplicit(true);
561 LS->addAttr(A: PackedAttr::CreateImplicit(Ctx&: AST));
562 LS->startDefinition();
563
564 // copy base struct, create HLSL Buffer compatible version if needed
565 if (unsigned NumBases = StructDecl->getNumBases()) {
566 assert(NumBases == 1 && "HLSL supports only one base type");
567 (void)NumBases;
568 CXXBaseSpecifier Base = *StructDecl->bases_begin();
569 CXXRecordDecl *BaseDecl = Base.getType()->castAsCXXRecordDecl();
570 if (requiresImplicitBufferLayoutStructure(RD: BaseDecl)) {
571 BaseDecl = createHostLayoutStruct(S, StructDecl: BaseDecl);
572 if (BaseDecl) {
573 TypeSourceInfo *TSI =
574 AST.getTrivialTypeSourceInfo(T: AST.getCanonicalTagType(TD: BaseDecl));
575 Base = CXXBaseSpecifier(SourceRange(), false, StructDecl->isClass(),
576 AS_none, TSI, SourceLocation());
577 }
578 }
579 if (BaseDecl) {
580 const CXXBaseSpecifier *BasesArray[1] = {&Base};
581 LS->setBases(Bases: BasesArray, NumBases: 1);
582 }
583 }
584
585 // filter struct fields
586 for (const FieldDecl *FD : StructDecl->fields()) {
587 const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
588 if (FieldDecl *NewFD =
589 createFieldForHostLayoutStruct(S, Ty, II: FD->getIdentifier(), LayoutStruct: LS))
590 LS->addDecl(D: NewFD);
591 }
592 LS->completeDefinition();
593
594 if (LS->field_empty() && LS->getNumBases() == 0)
595 return nullptr;
596
597 DC->addDecl(D: LS);
598 return LS;
599}
600
601// Creates host layout struct for HLSL Buffer. The struct will include only
602// fields of types that are allowed in HLSL buffer and it will filter out:
603// - static or groupshared variable declarations
604// - resource classes
605// - empty structs
606// - zero-sized arrays
607// - non-variable declarations
608// The layout struct will be added to the HLSLBufferDecl declarations.
609static void createHostLayoutStructForBuffer(Sema &S, HLSLBufferDecl *BufDecl) {
610 ASTContext &AST = S.getASTContext();
611 IdentifierInfo *II = getHostLayoutStructName(S, BaseDecl: BufDecl, MustBeUnique: true);
612
613 CXXRecordDecl *LS =
614 CXXRecordDecl::Create(C: AST, TK: TagDecl::TagKind::Struct, DC: BufDecl,
615 StartLoc: SourceLocation(), IdLoc: SourceLocation(), Id: II);
616 LS->addAttr(A: PackedAttr::CreateImplicit(Ctx&: AST));
617 LS->setImplicit(true);
618 LS->startDefinition();
619
620 for (Decl *D : BufDecl->buffer_decls()) {
621 VarDecl *VD = dyn_cast<VarDecl>(Val: D);
622 if (!VD || VD->getStorageClass() == SC_Static ||
623 VD->getType().getAddressSpace() == LangAS::hlsl_groupshared)
624 continue;
625 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
626
627 FieldDecl *FD =
628 createFieldForHostLayoutStruct(S, Ty, II: VD->getIdentifier(), LayoutStruct: LS);
629 // Declarations collected for the default $Globals constant buffer have
630 // already been checked to have non-empty cbuffer layout, so
631 // createFieldForHostLayoutStruct should always succeed. These declarations
632 // already have their address space set to hlsl_constant.
633 // For declarations in a named cbuffer block
634 // createFieldForHostLayoutStruct can still return nullptr if the type
635 // is empty (does not have a cbuffer layout).
636 assert((FD || VD->getType().getAddressSpace() != LangAS::hlsl_constant) &&
637 "host layout field for $Globals decl failed to be created");
638 if (FD) {
639 // Add the field decl to the layout struct.
640 LS->addDecl(D: FD);
641 if (VD->getType().getAddressSpace() != LangAS::hlsl_constant) {
642 // Update address space of the original decl to hlsl_constant.
643 QualType NewTy =
644 AST.getAddrSpaceQualType(T: VD->getType(), AddressSpace: LangAS::hlsl_constant);
645 VD->setType(NewTy);
646 }
647 }
648 }
649 LS->completeDefinition();
650 BufDecl->addLayoutStruct(LS);
651}
652
653static void addImplicitBindingAttrToDecl(Sema &S, Decl *D, RegisterType RT,
654 uint32_t ImplicitBindingOrderID) {
655 auto *Attr =
656 HLSLResourceBindingAttr::CreateImplicit(Ctx&: S.getASTContext(), Slot: "", Space: "0", Range: {});
657 Attr->setBinding(RT, SlotNum: std::nullopt, SpaceNum: 0);
658 Attr->setImplicitBindingOrderID(ImplicitBindingOrderID);
659 D->addAttr(A: Attr);
660}
661
662// Handle end of cbuffer/tbuffer declaration
663void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) {
664 auto *BufDecl = cast<HLSLBufferDecl>(Val: Dcl);
665 BufDecl->setRBraceLoc(RBrace);
666
667 validatePackoffset(S&: SemaRef, BufDecl);
668
669 createHostLayoutStructForBuffer(S&: SemaRef, BufDecl);
670
671 // Handle implicit binding if needed.
672 ResourceBindingAttrs ResourceAttrs(Dcl);
673 if (!ResourceAttrs.isExplicit()) {
674 SemaRef.Diag(Loc: Dcl->getLocation(), DiagID: diag::warn_hlsl_implicit_binding);
675 // Use HLSLResourceBindingAttr to transfer implicit binding order_ID
676 // to codegen. If it does not exist, create an implicit attribute.
677 uint32_t OrderID = getNextImplicitBindingOrderID();
678 if (ResourceAttrs.hasBinding())
679 ResourceAttrs.setImplicitOrderID(OrderID);
680 else
681 addImplicitBindingAttrToDecl(S&: SemaRef, D: BufDecl,
682 RT: BufDecl->isCBuffer() ? RegisterType::CBuffer
683 : RegisterType::SRV,
684 ImplicitBindingOrderID: OrderID);
685 }
686
687 SemaRef.PopDeclContext();
688}
689
690HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
691 const AttributeCommonInfo &AL,
692 int X, int Y, int Z) {
693 if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
694 if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
695 Diag(Loc: NT->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL;
696 Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute);
697 }
698 return nullptr;
699 }
700 return ::new (getASTContext())
701 HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z);
702}
703
704HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D,
705 const AttributeCommonInfo &AL,
706 int Min, int Max, int Preferred,
707 int SpelledArgsCount) {
708 if (HLSLWaveSizeAttr *WS = D->getAttr<HLSLWaveSizeAttr>()) {
709 if (WS->getMin() != Min || WS->getMax() != Max ||
710 WS->getPreferred() != Preferred ||
711 WS->getSpelledArgsCount() != SpelledArgsCount) {
712 Diag(Loc: WS->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL;
713 Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute);
714 }
715 return nullptr;
716 }
717 HLSLWaveSizeAttr *Result = ::new (getASTContext())
718 HLSLWaveSizeAttr(getASTContext(), AL, Min, Max, Preferred);
719 Result->setSpelledArgsCount(SpelledArgsCount);
720 return Result;
721}
722
723HLSLVkConstantIdAttr *
724SemaHLSL::mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL,
725 int Id) {
726
727 auto &TargetInfo = getASTContext().getTargetInfo();
728 if (TargetInfo.getTriple().getArch() != llvm::Triple::spirv) {
729 Diag(Loc: AL.getLoc(), DiagID: diag::warn_attribute_ignored) << AL;
730 return nullptr;
731 }
732
733 auto *VD = cast<VarDecl>(Val: D);
734
735 if (getSpecConstBuiltinId(Type: VD->getType()->getUnqualifiedDesugaredType()) ==
736 Builtin::NotBuiltin) {
737 Diag(Loc: VD->getLocation(), DiagID: diag::err_specialization_const);
738 return nullptr;
739 }
740
741 if (!VD->getType().isConstQualified()) {
742 Diag(Loc: VD->getLocation(), DiagID: diag::err_specialization_const);
743 return nullptr;
744 }
745
746 if (HLSLVkConstantIdAttr *CI = D->getAttr<HLSLVkConstantIdAttr>()) {
747 if (CI->getId() != Id) {
748 Diag(Loc: CI->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL;
749 Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute);
750 }
751 return nullptr;
752 }
753
754 HLSLVkConstantIdAttr *Result =
755 ::new (getASTContext()) HLSLVkConstantIdAttr(getASTContext(), AL, Id);
756 return Result;
757}
758
759HLSLShaderAttr *
760SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
761 llvm::Triple::EnvironmentType ShaderType) {
762 if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
763 if (NT->getType() != ShaderType) {
764 Diag(Loc: NT->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL;
765 Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute);
766 }
767 return nullptr;
768 }
769 return HLSLShaderAttr::Create(Ctx&: getASTContext(), Type: ShaderType, CommonInfo: AL);
770}
771
772HLSLParamModifierAttr *
773SemaHLSL::mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
774 HLSLParamModifierAttr::Spelling Spelling) {
775 // We can only merge an `in` attribute with an `out` attribute. All other
776 // combinations of duplicated attributes are ill-formed.
777 if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
778 if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
779 (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
780 D->dropAttr<HLSLParamModifierAttr>();
781 SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
782 return HLSLParamModifierAttr::Create(
783 Ctx&: getASTContext(), /*MergedSpelling=*/true, Range: AdjustedRange,
784 S: HLSLParamModifierAttr::Keyword_inout);
785 }
786 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_duplicate_parameter_modifier) << AL;
787 Diag(Loc: PA->getLocation(), DiagID: diag::note_conflicting_attribute);
788 return nullptr;
789 }
790 return HLSLParamModifierAttr::Create(Ctx&: getASTContext(), CommonInfo: AL);
791}
792
793void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
794 auto &TargetInfo = getASTContext().getTargetInfo();
795
796 if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
797 return;
798
799 // If we have specified a root signature to override the entry function then
800 // attach it now
801 HLSLRootSignatureDecl *SignatureDecl =
802 lookupRootSignatureOverrideDecl(DC: FD->getDeclContext());
803 if (SignatureDecl) {
804 FD->dropAttr<RootSignatureAttr>();
805 // We could look up the SourceRange of the macro here as well
806 AttributeCommonInfo AL(RootSigOverrideIdent, AttributeScopeInfo(),
807 SourceRange(), ParsedAttr::Form::Microsoft());
808 FD->addAttr(A: ::new (getASTContext()) RootSignatureAttr(
809 getASTContext(), AL, RootSigOverrideIdent, SignatureDecl));
810 }
811
812 llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment();
813 if (HLSLShaderAttr::isValidShaderType(ShaderType: Env) && Env != llvm::Triple::Library) {
814 if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
815 // The entry point is already annotated - check that it matches the
816 // triple.
817 if (Shader->getType() != Env) {
818 Diag(Loc: Shader->getLocation(), DiagID: diag::err_hlsl_entry_shader_attr_mismatch)
819 << Shader;
820 FD->setInvalidDecl();
821 }
822 } else {
823 // Implicitly add the shader attribute if the entry function isn't
824 // explicitly annotated.
825 FD->addAttr(A: HLSLShaderAttr::CreateImplicit(Ctx&: getASTContext(), Type: Env,
826 Range: FD->getBeginLoc()));
827 }
828 } else {
829 switch (Env) {
830 case llvm::Triple::UnknownEnvironment:
831 case llvm::Triple::Library:
832 break;
833 case llvm::Triple::RootSignature:
834 llvm_unreachable("rootsig environment has no functions");
835 default:
836 llvm_unreachable("Unhandled environment in triple");
837 }
838 }
839}
840
841static bool isVkPipelineBuiltin(const ASTContext &AstContext, FunctionDecl *FD,
842 HLSLAppliedSemanticAttr *Semantic,
843 bool IsInput) {
844 if (AstContext.getTargetInfo().getTriple().getOS() != llvm::Triple::Vulkan)
845 return false;
846
847 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
848 assert(ShaderAttr && "Entry point has no shader attribute");
849 llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
850 auto SemanticName = Semantic->getSemanticName().upper();
851
852 // The SV_Position semantic is lowered to:
853 // - Position built-in for vertex output.
854 // - FragCoord built-in for fragment input.
855 if (SemanticName == "SV_POSITION") {
856 return (ST == llvm::Triple::Vertex && !IsInput) ||
857 (ST == llvm::Triple::Pixel && IsInput);
858 }
859 if (SemanticName == "SV_VERTEXID")
860 return true;
861
862 return false;
863}
864
865bool SemaHLSL::determineActiveSemanticOnScalar(FunctionDecl *FD,
866 DeclaratorDecl *OutputDecl,
867 DeclaratorDecl *D,
868 SemanticInfo &ActiveSemantic,
869 SemaHLSL::SemanticContext &SC) {
870 if (ActiveSemantic.Semantic == nullptr) {
871 ActiveSemantic.Semantic = D->getAttr<HLSLParsedSemanticAttr>();
872 if (ActiveSemantic.Semantic)
873 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
874 }
875
876 if (!ActiveSemantic.Semantic) {
877 Diag(Loc: D->getLocation(), DiagID: diag::err_hlsl_missing_semantic_annotation);
878 return false;
879 }
880
881 auto *A = ::new (getASTContext())
882 HLSLAppliedSemanticAttr(getASTContext(), *ActiveSemantic.Semantic,
883 ActiveSemantic.Semantic->getAttrName()->getName(),
884 ActiveSemantic.Index.value_or(u: 0));
885 if (!A)
886 return false;
887
888 checkSemanticAnnotation(EntryPoint: FD, Param: D, SemanticAttr: A, SC);
889 OutputDecl->addAttr(A);
890
891 unsigned Location = ActiveSemantic.Index.value_or(u: 0);
892
893 if (!isVkPipelineBuiltin(AstContext: getASTContext(), FD, Semantic: A,
894 IsInput: SC.CurrentIOType & IOType::In)) {
895 bool HasVkLocation = false;
896 if (auto *A = D->getAttr<HLSLVkLocationAttr>()) {
897 HasVkLocation = true;
898 Location = A->getLocation();
899 }
900
901 if (SC.UsesExplicitVkLocations.value_or(u&: HasVkLocation) != HasVkLocation) {
902 Diag(Loc: D->getLocation(), DiagID: diag::err_hlsl_semantic_partial_explicit_indexing);
903 return false;
904 }
905 SC.UsesExplicitVkLocations = HasVkLocation;
906 }
907
908 const ConstantArrayType *AT = dyn_cast<ConstantArrayType>(Val: D->getType());
909 unsigned ElementCount = AT ? AT->getZExtSize() : 1;
910 ActiveSemantic.Index = Location + ElementCount;
911
912 Twine BaseName = Twine(ActiveSemantic.Semantic->getAttrName()->getName());
913 for (unsigned I = 0; I < ElementCount; ++I) {
914 Twine VariableName = BaseName.concat(Suffix: Twine(Location + I));
915
916 auto [_, Inserted] = SC.ActiveSemantics.insert(key: VariableName.str());
917 if (!Inserted) {
918 Diag(Loc: D->getLocation(), DiagID: diag::err_hlsl_semantic_index_overlap)
919 << VariableName.str();
920 return false;
921 }
922 }
923
924 return true;
925}
926
927bool SemaHLSL::determineActiveSemantic(FunctionDecl *FD,
928 DeclaratorDecl *OutputDecl,
929 DeclaratorDecl *D,
930 SemanticInfo &ActiveSemantic,
931 SemaHLSL::SemanticContext &SC) {
932 if (ActiveSemantic.Semantic == nullptr) {
933 ActiveSemantic.Semantic = D->getAttr<HLSLParsedSemanticAttr>();
934 if (ActiveSemantic.Semantic)
935 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
936 }
937
938 const Type *T = D == FD ? &*FD->getReturnType() : &*D->getType();
939 T = T->getUnqualifiedDesugaredType();
940
941 const RecordType *RT = dyn_cast<RecordType>(Val: T);
942 if (!RT)
943 return determineActiveSemanticOnScalar(FD, OutputDecl, D, ActiveSemantic,
944 SC);
945
946 const RecordDecl *RD = RT->getDecl();
947 for (FieldDecl *Field : RD->fields()) {
948 SemanticInfo Info = ActiveSemantic;
949 if (!determineActiveSemantic(FD, OutputDecl, D: Field, ActiveSemantic&: Info, SC)) {
950 Diag(Loc: Field->getLocation(), DiagID: diag::note_hlsl_semantic_used_here) << Field;
951 return false;
952 }
953 if (ActiveSemantic.Semantic)
954 ActiveSemantic = Info;
955 }
956
957 return true;
958}
959
960void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
961 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
962 assert(ShaderAttr && "Entry point has no shader attribute");
963 llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
964 auto &TargetInfo = getASTContext().getTargetInfo();
965 VersionTuple Ver = TargetInfo.getTriple().getOSVersion();
966 switch (ST) {
967 case llvm::Triple::Pixel:
968 case llvm::Triple::Vertex:
969 case llvm::Triple::Geometry:
970 case llvm::Triple::Hull:
971 case llvm::Triple::Domain:
972 case llvm::Triple::RayGeneration:
973 case llvm::Triple::Intersection:
974 case llvm::Triple::AnyHit:
975 case llvm::Triple::ClosestHit:
976 case llvm::Triple::Miss:
977 case llvm::Triple::Callable:
978 if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
979 diagnoseAttrStageMismatch(A: NT, Stage: ST,
980 AllowedStages: {llvm::Triple::Compute,
981 llvm::Triple::Amplification,
982 llvm::Triple::Mesh});
983 FD->setInvalidDecl();
984 }
985 if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
986 diagnoseAttrStageMismatch(A: WS, Stage: ST,
987 AllowedStages: {llvm::Triple::Compute,
988 llvm::Triple::Amplification,
989 llvm::Triple::Mesh});
990 FD->setInvalidDecl();
991 }
992 break;
993
994 case llvm::Triple::Compute:
995 case llvm::Triple::Amplification:
996 case llvm::Triple::Mesh:
997 if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
998 Diag(Loc: FD->getLocation(), DiagID: diag::err_hlsl_missing_numthreads)
999 << llvm::Triple::getEnvironmentTypeName(Kind: ST);
1000 FD->setInvalidDecl();
1001 }
1002 if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
1003 if (Ver < VersionTuple(6, 6)) {
1004 Diag(Loc: WS->getLocation(), DiagID: diag::err_hlsl_attribute_in_wrong_shader_model)
1005 << WS << "6.6";
1006 FD->setInvalidDecl();
1007 } else if (WS->getSpelledArgsCount() > 1 && Ver < VersionTuple(6, 8)) {
1008 Diag(
1009 Loc: WS->getLocation(),
1010 DiagID: diag::err_hlsl_attribute_number_arguments_insufficient_shader_model)
1011 << WS << WS->getSpelledArgsCount() << "6.8";
1012 FD->setInvalidDecl();
1013 }
1014 }
1015 break;
1016 case llvm::Triple::RootSignature:
1017 llvm_unreachable("rootsig environment has no function entry point");
1018 default:
1019 llvm_unreachable("Unhandled environment in triple");
1020 }
1021
1022 SemaHLSL::SemanticContext InputSC = {};
1023 InputSC.CurrentIOType = IOType::In;
1024
1025 for (ParmVarDecl *Param : FD->parameters()) {
1026 SemanticInfo ActiveSemantic;
1027 ActiveSemantic.Semantic = Param->getAttr<HLSLParsedSemanticAttr>();
1028 if (ActiveSemantic.Semantic)
1029 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
1030
1031 // FIXME: Verify output semantics in parameters.
1032 if (!determineActiveSemantic(FD, OutputDecl: Param, D: Param, ActiveSemantic, SC&: InputSC)) {
1033 Diag(Loc: Param->getLocation(), DiagID: diag::note_previous_decl) << Param;
1034 FD->setInvalidDecl();
1035 }
1036 }
1037
1038 SemanticInfo ActiveSemantic;
1039 SemaHLSL::SemanticContext OutputSC = {};
1040 OutputSC.CurrentIOType = IOType::Out;
1041 ActiveSemantic.Semantic = FD->getAttr<HLSLParsedSemanticAttr>();
1042 if (ActiveSemantic.Semantic)
1043 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
1044 if (!FD->getReturnType()->isVoidType())
1045 determineActiveSemantic(FD, OutputDecl: FD, D: FD, ActiveSemantic, SC&: OutputSC);
1046}
1047
1048void SemaHLSL::checkSemanticAnnotation(
1049 FunctionDecl *EntryPoint, const Decl *Param,
1050 const HLSLAppliedSemanticAttr *SemanticAttr, const SemanticContext &SC) {
1051 auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
1052 assert(ShaderAttr && "Entry point has no shader attribute");
1053 llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
1054
1055 auto SemanticName = SemanticAttr->getSemanticName().upper();
1056 if (SemanticName == "SV_DISPATCHTHREADID" ||
1057 SemanticName == "SV_GROUPINDEX" || SemanticName == "SV_GROUPTHREADID" ||
1058 SemanticName == "SV_GROUPID") {
1059
1060 if (ST != llvm::Triple::Compute)
1061 diagnoseSemanticStageMismatch(A: SemanticAttr, Stage: ST, CurrentIOType: SC.CurrentIOType,
1062 AllowedStages: {{.Stage: llvm::Triple::Compute, .AllowedIOTypesMask: IOType::In}});
1063
1064 if (SemanticAttr->getSemanticIndex() != 0) {
1065 std::string PrettyName =
1066 "'" + SemanticAttr->getSemanticName().str() + "'";
1067 Diag(Loc: SemanticAttr->getLoc(),
1068 DiagID: diag::err_hlsl_semantic_indexing_not_supported)
1069 << PrettyName;
1070 }
1071 return;
1072 }
1073
1074 if (SemanticName == "SV_POSITION") {
1075 // SV_Position can be an input or output in vertex shaders,
1076 // but only an input in pixel shaders.
1077 diagnoseSemanticStageMismatch(A: SemanticAttr, Stage: ST, CurrentIOType: SC.CurrentIOType,
1078 AllowedStages: {{.Stage: llvm::Triple::Vertex, .AllowedIOTypesMask: IOType::InOut},
1079 {.Stage: llvm::Triple::Pixel, .AllowedIOTypesMask: IOType::In}});
1080 return;
1081 }
1082 if (SemanticName == "SV_VERTEXID") {
1083 diagnoseSemanticStageMismatch(A: SemanticAttr, Stage: ST, CurrentIOType: SC.CurrentIOType,
1084 AllowedStages: {{.Stage: llvm::Triple::Vertex, .AllowedIOTypesMask: IOType::In}});
1085 return;
1086 }
1087
1088 if (SemanticName == "SV_TARGET") {
1089 diagnoseSemanticStageMismatch(A: SemanticAttr, Stage: ST, CurrentIOType: SC.CurrentIOType,
1090 AllowedStages: {{.Stage: llvm::Triple::Pixel, .AllowedIOTypesMask: IOType::Out}});
1091 return;
1092 }
1093
1094 // FIXME: catch-all for non-implemented system semantics reaching this
1095 // location.
1096 if (SemanticAttr->getAttrName()->getName().starts_with_insensitive(Prefix: "SV_"))
1097 llvm_unreachable("Unknown SemanticAttr");
1098}
1099
1100void SemaHLSL::diagnoseAttrStageMismatch(
1101 const Attr *A, llvm::Triple::EnvironmentType Stage,
1102 std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) {
1103 SmallVector<StringRef, 8> StageStrings;
1104 llvm::transform(Range&: AllowedStages, d_first: std::back_inserter(x&: StageStrings),
1105 F: [](llvm::Triple::EnvironmentType ST) {
1106 return StringRef(
1107 HLSLShaderAttr::ConvertEnvironmentTypeToStr(Val: ST));
1108 });
1109 Diag(Loc: A->getLoc(), DiagID: diag::err_hlsl_attr_unsupported_in_stage)
1110 << A->getAttrName() << llvm::Triple::getEnvironmentTypeName(Kind: Stage)
1111 << (AllowedStages.size() != 1) << join(R&: StageStrings, Separator: ", ");
1112}
1113
1114void SemaHLSL::diagnoseSemanticStageMismatch(
1115 const Attr *A, llvm::Triple::EnvironmentType Stage, IOType CurrentIOType,
1116 std::initializer_list<SemanticStageInfo> Allowed) {
1117
1118 for (auto &Case : Allowed) {
1119 if (Case.Stage != Stage)
1120 continue;
1121
1122 if (CurrentIOType & Case.AllowedIOTypesMask)
1123 return;
1124
1125 SmallVector<std::string, 8> ValidCases;
1126 llvm::transform(
1127 Range&: Allowed, d_first: std::back_inserter(x&: ValidCases), F: [](SemanticStageInfo Case) {
1128 SmallVector<std::string, 2> ValidType;
1129 if (Case.AllowedIOTypesMask & IOType::In)
1130 ValidType.push_back(Elt: "input");
1131 if (Case.AllowedIOTypesMask & IOType::Out)
1132 ValidType.push_back(Elt: "output");
1133 return std::string(
1134 HLSLShaderAttr::ConvertEnvironmentTypeToStr(Val: Case.Stage)) +
1135 " " + join(R&: ValidType, Separator: "/");
1136 });
1137 Diag(Loc: A->getLoc(), DiagID: diag::err_hlsl_semantic_unsupported_iotype_for_stage)
1138 << A->getAttrName() << (CurrentIOType & IOType::In ? "input" : "output")
1139 << llvm::Triple::getEnvironmentTypeName(Kind: Case.Stage)
1140 << join(R&: ValidCases, Separator: ", ");
1141 return;
1142 }
1143
1144 SmallVector<StringRef, 8> StageStrings;
1145 llvm::transform(
1146 Range&: Allowed, d_first: std::back_inserter(x&: StageStrings), F: [](SemanticStageInfo Case) {
1147 return StringRef(
1148 HLSLShaderAttr::ConvertEnvironmentTypeToStr(Val: Case.Stage));
1149 });
1150
1151 Diag(Loc: A->getLoc(), DiagID: diag::err_hlsl_attr_unsupported_in_stage)
1152 << A->getAttrName() << llvm::Triple::getEnvironmentTypeName(Kind: Stage)
1153 << (Allowed.size() != 1) << join(R&: StageStrings, Separator: ", ");
1154}
1155
1156template <CastKind Kind>
1157static void castVector(Sema &S, ExprResult &E, QualType &Ty, unsigned Sz) {
1158 if (const auto *VTy = Ty->getAs<VectorType>())
1159 Ty = VTy->getElementType();
1160 Ty = S.getASTContext().getExtVectorType(VectorType: Ty, NumElts: Sz);
1161 E = S.ImpCastExprToType(E: E.get(), Type: Ty, CK: Kind);
1162}
1163
1164template <CastKind Kind>
1165static QualType castElement(Sema &S, ExprResult &E, QualType Ty) {
1166 E = S.ImpCastExprToType(E: E.get(), Type: Ty, CK: Kind);
1167 return Ty;
1168}
1169
1170static QualType handleFloatVectorBinOpConversion(
1171 Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
1172 QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
1173 bool LHSFloat = LElTy->isRealFloatingType();
1174 bool RHSFloat = RElTy->isRealFloatingType();
1175
1176 if (LHSFloat && RHSFloat) {
1177 if (IsCompAssign ||
1178 SemaRef.getASTContext().getFloatingTypeOrder(LHS: LElTy, RHS: RElTy) > 0)
1179 return castElement<CK_FloatingCast>(S&: SemaRef, E&: RHS, Ty: LHSType);
1180
1181 return castElement<CK_FloatingCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
1182 }
1183
1184 if (LHSFloat)
1185 return castElement<CK_IntegralToFloating>(S&: SemaRef, E&: RHS, Ty: LHSType);
1186
1187 assert(RHSFloat);
1188 if (IsCompAssign)
1189 return castElement<clang::CK_FloatingToIntegral>(S&: SemaRef, E&: RHS, Ty: LHSType);
1190
1191 return castElement<CK_IntegralToFloating>(S&: SemaRef, E&: LHS, Ty: RHSType);
1192}
1193
1194static QualType handleIntegerVectorBinOpConversion(
1195 Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
1196 QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
1197
1198 int IntOrder = SemaRef.Context.getIntegerTypeOrder(LHS: LElTy, RHS: RElTy);
1199 bool LHSSigned = LElTy->hasSignedIntegerRepresentation();
1200 bool RHSSigned = RElTy->hasSignedIntegerRepresentation();
1201 auto &Ctx = SemaRef.getASTContext();
1202
1203 // If both types have the same signedness, use the higher ranked type.
1204 if (LHSSigned == RHSSigned) {
1205 if (IsCompAssign || IntOrder >= 0)
1206 return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);
1207
1208 return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
1209 }
1210
1211 // If the unsigned type has greater than or equal rank of the signed type, use
1212 // the unsigned type.
1213 if (IntOrder != (LHSSigned ? 1 : -1)) {
1214 if (IsCompAssign || RHSSigned)
1215 return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);
1216 return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
1217 }
1218
1219 // At this point the signed type has higher rank than the unsigned type, which
1220 // means it will be the same size or bigger. If the signed type is bigger, it
1221 // can represent all the values of the unsigned type, so select it.
1222 if (Ctx.getIntWidth(T: LElTy) != Ctx.getIntWidth(T: RElTy)) {
1223 if (IsCompAssign || LHSSigned)
1224 return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);
1225 return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
1226 }
1227
1228 // This is a bit of an odd duck case in HLSL. It shouldn't happen, but can due
1229 // to C/C++ leaking through. The place this happens today is long vs long
1230 // long. When arguments are vector<unsigned long, N> and vector<long long, N>,
1231 // the long long has higher rank than long even though they are the same size.
1232
1233 // If this is a compound assignment cast the right hand side to the left hand
1234 // side's type.
1235 if (IsCompAssign)
1236 return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);
1237
1238 // If this isn't a compound assignment we convert to unsigned long long.
1239 QualType ElTy = Ctx.getCorrespondingUnsignedType(T: LHSSigned ? LElTy : RElTy);
1240 QualType NewTy = Ctx.getExtVectorType(
1241 VectorType: ElTy, NumElts: RHSType->castAs<VectorType>()->getNumElements());
1242 (void)castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: NewTy);
1243
1244 return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: NewTy);
1245}
1246
1247static CastKind getScalarCastKind(ASTContext &Ctx, QualType DestTy,
1248 QualType SrcTy) {
1249 if (DestTy->isRealFloatingType() && SrcTy->isRealFloatingType())
1250 return CK_FloatingCast;
1251 if (DestTy->isIntegralType(Ctx) && SrcTy->isIntegralType(Ctx))
1252 return CK_IntegralCast;
1253 if (DestTy->isRealFloatingType())
1254 return CK_IntegralToFloating;
1255 assert(SrcTy->isRealFloatingType() && DestTy->isIntegralType(Ctx));
1256 return CK_FloatingToIntegral;
1257}
1258
1259QualType SemaHLSL::handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS,
1260 QualType LHSType,
1261 QualType RHSType,
1262 bool IsCompAssign) {
1263 const auto *LVecTy = LHSType->getAs<VectorType>();
1264 const auto *RVecTy = RHSType->getAs<VectorType>();
1265 auto &Ctx = getASTContext();
1266
1267 // If the LHS is not a vector and this is a compound assignment, we truncate
1268 // the argument to a scalar then convert it to the LHS's type.
1269 if (!LVecTy && IsCompAssign) {
1270 QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
1271 RHS = SemaRef.ImpCastExprToType(E: RHS.get(), Type: RElTy, CK: CK_HLSLVectorTruncation);
1272 RHSType = RHS.get()->getType();
1273 if (Ctx.hasSameUnqualifiedType(T1: LHSType, T2: RHSType))
1274 return LHSType;
1275 RHS = SemaRef.ImpCastExprToType(E: RHS.get(), Type: LHSType,
1276 CK: getScalarCastKind(Ctx, DestTy: LHSType, SrcTy: RHSType));
1277 return LHSType;
1278 }
1279
1280 unsigned EndSz = std::numeric_limits<unsigned>::max();
1281 unsigned LSz = 0;
1282 if (LVecTy)
1283 LSz = EndSz = LVecTy->getNumElements();
1284 if (RVecTy)
1285 EndSz = std::min(a: RVecTy->getNumElements(), b: EndSz);
1286 assert(EndSz != std::numeric_limits<unsigned>::max() &&
1287 "one of the above should have had a value");
1288
1289 // In a compound assignment, the left operand does not change type, the right
1290 // operand is converted to the type of the left operand.
1291 if (IsCompAssign && LSz != EndSz) {
1292 Diag(Loc: LHS.get()->getBeginLoc(),
1293 DiagID: diag::err_hlsl_vector_compound_assignment_truncation)
1294 << LHSType << RHSType;
1295 return QualType();
1296 }
1297
1298 if (RVecTy && RVecTy->getNumElements() > EndSz)
1299 castVector<CK_HLSLVectorTruncation>(S&: SemaRef, E&: RHS, Ty&: RHSType, Sz: EndSz);
1300 if (!IsCompAssign && LVecTy && LVecTy->getNumElements() > EndSz)
1301 castVector<CK_HLSLVectorTruncation>(S&: SemaRef, E&: LHS, Ty&: LHSType, Sz: EndSz);
1302
1303 if (!RVecTy)
1304 castVector<CK_VectorSplat>(S&: SemaRef, E&: RHS, Ty&: RHSType, Sz: EndSz);
1305 if (!IsCompAssign && !LVecTy)
1306 castVector<CK_VectorSplat>(S&: SemaRef, E&: LHS, Ty&: LHSType, Sz: EndSz);
1307
1308 // If we're at the same type after resizing we can stop here.
1309 if (Ctx.hasSameUnqualifiedType(T1: LHSType, T2: RHSType))
1310 return Ctx.getCommonSugaredType(X: LHSType, Y: RHSType);
1311
1312 QualType LElTy = LHSType->castAs<VectorType>()->getElementType();
1313 QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
1314
1315 // Handle conversion for floating point vectors.
1316 if (LElTy->isRealFloatingType() || RElTy->isRealFloatingType())
1317 return handleFloatVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
1318 LElTy, RElTy, IsCompAssign);
1319
1320 assert(LElTy->isIntegralType(Ctx) && RElTy->isIntegralType(Ctx) &&
1321 "HLSL Vectors can only contain integer or floating point types");
1322 return handleIntegerVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
1323 LElTy, RElTy, IsCompAssign);
1324}
1325
1326void SemaHLSL::emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS,
1327 BinaryOperatorKind Opc) {
1328 assert((Opc == BO_LOr || Opc == BO_LAnd) &&
1329 "Called with non-logical operator");
1330 llvm::SmallVector<char, 256> Buff;
1331 llvm::raw_svector_ostream OS(Buff);
1332 PrintingPolicy PP(SemaRef.getLangOpts());
1333 StringRef NewFnName = Opc == BO_LOr ? "or" : "and";
1334 OS << NewFnName << "(";
1335 LHS->printPretty(OS, Helper: nullptr, Policy: PP);
1336 OS << ", ";
1337 RHS->printPretty(OS, Helper: nullptr, Policy: PP);
1338 OS << ")";
1339 SourceRange FullRange = SourceRange(LHS->getBeginLoc(), RHS->getEndLoc());
1340 SemaRef.Diag(Loc: LHS->getBeginLoc(), DiagID: diag::note_function_suggestion)
1341 << NewFnName << FixItHint::CreateReplacement(RemoveRange: FullRange, Code: OS.str());
1342}
1343
1344std::pair<IdentifierInfo *, bool>
1345SemaHLSL::ActOnStartRootSignatureDecl(StringRef Signature) {
1346 llvm::hash_code Hash = llvm::hash_value(S: Signature);
1347 std::string IdStr = "__hlsl_rootsig_decl_" + std::to_string(val: Hash);
1348 IdentifierInfo *DeclIdent = &(getASTContext().Idents.get(Name: IdStr));
1349
1350 // Check if we have already found a decl of the same name.
1351 LookupResult R(SemaRef, DeclIdent, SourceLocation(),
1352 Sema::LookupOrdinaryName);
1353 bool Found = SemaRef.LookupQualifiedName(R, LookupCtx: SemaRef.CurContext);
1354 return {DeclIdent, Found};
1355}
1356
1357void SemaHLSL::ActOnFinishRootSignatureDecl(
1358 SourceLocation Loc, IdentifierInfo *DeclIdent,
1359 ArrayRef<hlsl::RootSignatureElement> RootElements) {
1360
1361 if (handleRootSignatureElements(Elements: RootElements))
1362 return;
1363
1364 SmallVector<llvm::hlsl::rootsig::RootElement> Elements;
1365 for (auto &RootSigElement : RootElements)
1366 Elements.push_back(Elt: RootSigElement.getElement());
1367
1368 auto *SignatureDecl = HLSLRootSignatureDecl::Create(
1369 C&: SemaRef.getASTContext(), /*DeclContext=*/DC: SemaRef.CurContext, Loc,
1370 ID: DeclIdent, Version: SemaRef.getLangOpts().HLSLRootSigVer, RootElements: Elements);
1371
1372 SignatureDecl->setImplicit();
1373 SemaRef.PushOnScopeChains(D: SignatureDecl, S: SemaRef.getCurScope());
1374}
1375
1376HLSLRootSignatureDecl *
1377SemaHLSL::lookupRootSignatureOverrideDecl(DeclContext *DC) const {
1378 if (RootSigOverrideIdent) {
1379 LookupResult R(SemaRef, RootSigOverrideIdent, SourceLocation(),
1380 Sema::LookupOrdinaryName);
1381 if (SemaRef.LookupQualifiedName(R, LookupCtx: DC))
1382 return dyn_cast<HLSLRootSignatureDecl>(Val: R.getFoundDecl());
1383 }
1384
1385 return nullptr;
1386}
1387
1388namespace {
1389
1390struct PerVisibilityBindingChecker {
1391 SemaHLSL *S;
1392 // We need one builder per `llvm::dxbc::ShaderVisibility` value.
1393 std::array<llvm::hlsl::BindingInfoBuilder, 8> Builders;
1394
1395 struct ElemInfo {
1396 const hlsl::RootSignatureElement *Elem;
1397 llvm::dxbc::ShaderVisibility Vis;
1398 bool Diagnosed;
1399 };
1400 llvm::SmallVector<ElemInfo> ElemInfoMap;
1401
1402 PerVisibilityBindingChecker(SemaHLSL *S) : S(S) {}
1403
1404 void trackBinding(llvm::dxbc::ShaderVisibility Visibility,
1405 llvm::dxil::ResourceClass RC, uint32_t Space,
1406 uint32_t LowerBound, uint32_t UpperBound,
1407 const hlsl::RootSignatureElement *Elem) {
1408 uint32_t BuilderIndex = llvm::to_underlying(E: Visibility);
1409 assert(BuilderIndex < Builders.size() &&
1410 "Not enough builders for visibility type");
1411 Builders[BuilderIndex].trackBinding(RC, Space, LowerBound, UpperBound,
1412 Cookie: static_cast<const void *>(Elem));
1413
1414 static_assert(llvm::to_underlying(E: llvm::dxbc::ShaderVisibility::All) == 0,
1415 "'All' visibility must come first");
1416 if (Visibility == llvm::dxbc::ShaderVisibility::All)
1417 for (size_t I = 1, E = Builders.size(); I < E; ++I)
1418 Builders[I].trackBinding(RC, Space, LowerBound, UpperBound,
1419 Cookie: static_cast<const void *>(Elem));
1420
1421 ElemInfoMap.push_back(Elt: {.Elem: Elem, .Vis: Visibility, .Diagnosed: false});
1422 }
1423
1424 ElemInfo &getInfo(const hlsl::RootSignatureElement *Elem) {
1425 auto It = llvm::lower_bound(
1426 Range&: ElemInfoMap, Value&: Elem,
1427 C: [](const auto &LHS, const auto &RHS) { return LHS.Elem < RHS; });
1428 assert(It->Elem == Elem && "Element not in map");
1429 return *It;
1430 }
1431
1432 bool checkOverlap() {
1433 llvm::sort(C&: ElemInfoMap, Comp: [](const auto &LHS, const auto &RHS) {
1434 return LHS.Elem < RHS.Elem;
1435 });
1436
1437 bool HadOverlap = false;
1438
1439 using llvm::hlsl::BindingInfoBuilder;
1440 auto ReportOverlap = [this,
1441 &HadOverlap](const BindingInfoBuilder &Builder,
1442 const llvm::hlsl::Binding &Reported) {
1443 HadOverlap = true;
1444
1445 const auto *Elem =
1446 static_cast<const hlsl::RootSignatureElement *>(Reported.Cookie);
1447 const llvm::hlsl::Binding &Previous = Builder.findOverlapping(ReportedBinding: Reported);
1448 const auto *PrevElem =
1449 static_cast<const hlsl::RootSignatureElement *>(Previous.Cookie);
1450
1451 ElemInfo &Info = getInfo(Elem);
1452 // We will have already diagnosed this binding if there's overlap in the
1453 // "All" visibility as well as any particular visibility.
1454 if (Info.Diagnosed)
1455 return;
1456 Info.Diagnosed = true;
1457
1458 ElemInfo &PrevInfo = getInfo(Elem: PrevElem);
1459 llvm::dxbc::ShaderVisibility CommonVis =
1460 Info.Vis == llvm::dxbc::ShaderVisibility::All ? PrevInfo.Vis
1461 : Info.Vis;
1462
1463 this->S->Diag(Loc: Elem->getLocation(), DiagID: diag::err_hlsl_resource_range_overlap)
1464 << llvm::to_underlying(E: Reported.RC) << Reported.LowerBound
1465 << Reported.isUnbounded() << Reported.UpperBound
1466 << llvm::to_underlying(E: Previous.RC) << Previous.LowerBound
1467 << Previous.isUnbounded() << Previous.UpperBound << Reported.Space
1468 << CommonVis;
1469
1470 this->S->Diag(Loc: PrevElem->getLocation(),
1471 DiagID: diag::note_hlsl_resource_range_here);
1472 };
1473
1474 for (BindingInfoBuilder &Builder : Builders)
1475 Builder.calculateBindingInfo(ReportOverlap);
1476
1477 return HadOverlap;
1478 }
1479};
1480
1481static CXXMethodDecl *lookupMethod(Sema &S, CXXRecordDecl *RecordDecl,
1482 StringRef Name, SourceLocation Loc) {
1483 DeclarationName DeclName(&S.getASTContext().Idents.get(Name));
1484 LookupResult Result(S, DeclName, Loc, Sema::LookupMemberName);
1485 if (!S.LookupQualifiedName(R&: Result, LookupCtx: static_cast<DeclContext *>(RecordDecl)))
1486 return nullptr;
1487 return cast<CXXMethodDecl>(Val: Result.getFoundDecl());
1488}
1489
1490} // end anonymous namespace
1491
1492static bool hasCounterHandle(const CXXRecordDecl *RD) {
1493 if (RD->field_empty())
1494 return false;
1495 auto It = std::next(x: RD->field_begin());
1496 if (It == RD->field_end())
1497 return false;
1498 const FieldDecl *SecondField = *It;
1499 if (const auto *ResTy =
1500 SecondField->getType()->getAs<HLSLAttributedResourceType>()) {
1501 return ResTy->getAttrs().IsCounter;
1502 }
1503 return false;
1504}
1505
1506bool SemaHLSL::handleRootSignatureElements(
1507 ArrayRef<hlsl::RootSignatureElement> Elements) {
1508 // Define some common error handling functions
1509 bool HadError = false;
1510 auto ReportError = [this, &HadError](SourceLocation Loc, uint32_t LowerBound,
1511 uint32_t UpperBound) {
1512 HadError = true;
1513 this->Diag(Loc, DiagID: diag::err_hlsl_invalid_rootsig_value)
1514 << LowerBound << UpperBound;
1515 };
1516
1517 auto ReportFloatError = [this, &HadError](SourceLocation Loc,
1518 float LowerBound,
1519 float UpperBound) {
1520 HadError = true;
1521 this->Diag(Loc, DiagID: diag::err_hlsl_invalid_rootsig_value)
1522 << llvm::formatv(Fmt: "{0:f}", Vals&: LowerBound).sstr<6>()
1523 << llvm::formatv(Fmt: "{0:f}", Vals&: UpperBound).sstr<6>();
1524 };
1525
1526 auto VerifyRegister = [ReportError](SourceLocation Loc, uint32_t Register) {
1527 if (!llvm::hlsl::rootsig::verifyRegisterValue(RegisterValue: Register))
1528 ReportError(Loc, 0, 0xfffffffe);
1529 };
1530
1531 auto VerifySpace = [ReportError](SourceLocation Loc, uint32_t Space) {
1532 if (!llvm::hlsl::rootsig::verifyRegisterSpace(RegisterSpace: Space))
1533 ReportError(Loc, 0, 0xffffffef);
1534 };
1535
1536 const uint32_t Version =
1537 llvm::to_underlying(E: SemaRef.getLangOpts().HLSLRootSigVer);
1538 const uint32_t VersionEnum = Version - 1;
1539 auto ReportFlagError = [this, &HadError, VersionEnum](SourceLocation Loc) {
1540 HadError = true;
1541 this->Diag(Loc, DiagID: diag::err_hlsl_invalid_rootsig_flag)
1542 << /*version minor*/ VersionEnum;
1543 };
1544
1545 // Iterate through the elements and do basic validations
1546 for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
1547 SourceLocation Loc = RootSigElem.getLocation();
1548 const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
1549 if (const auto *Descriptor =
1550 std::get_if<llvm::hlsl::rootsig::RootDescriptor>(ptr: &Elem)) {
1551 VerifyRegister(Loc, Descriptor->Reg.Number);
1552 VerifySpace(Loc, Descriptor->Space);
1553
1554 if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(Version,
1555 Flags: Descriptor->Flags))
1556 ReportFlagError(Loc);
1557 } else if (const auto *Constants =
1558 std::get_if<llvm::hlsl::rootsig::RootConstants>(ptr: &Elem)) {
1559 VerifyRegister(Loc, Constants->Reg.Number);
1560 VerifySpace(Loc, Constants->Space);
1561 } else if (const auto *Sampler =
1562 std::get_if<llvm::hlsl::rootsig::StaticSampler>(ptr: &Elem)) {
1563 VerifyRegister(Loc, Sampler->Reg.Number);
1564 VerifySpace(Loc, Sampler->Space);
1565
1566 assert(!std::isnan(Sampler->MaxLOD) && !std::isnan(Sampler->MinLOD) &&
1567 "By construction, parseFloatParam can't produce a NaN from a "
1568 "float_literal token");
1569
1570 if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(MaxAnisotropy: Sampler->MaxAnisotropy))
1571 ReportError(Loc, 0, 16);
1572 if (!llvm::hlsl::rootsig::verifyMipLODBias(MipLODBias: Sampler->MipLODBias))
1573 ReportFloatError(Loc, -16.f, 15.99f);
1574 } else if (const auto *Clause =
1575 std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
1576 ptr: &Elem)) {
1577 VerifyRegister(Loc, Clause->Reg.Number);
1578 VerifySpace(Loc, Clause->Space);
1579
1580 if (!llvm::hlsl::rootsig::verifyNumDescriptors(NumDescriptors: Clause->NumDescriptors)) {
1581 // NumDescriptor could techincally be ~0u but that is reserved for
1582 // unbounded, so the diagnostic will not report that as a valid int
1583 // value
1584 ReportError(Loc, 1, 0xfffffffe);
1585 }
1586
1587 if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(Version, Type: Clause->Type,
1588 Flags: Clause->Flags))
1589 ReportFlagError(Loc);
1590 }
1591 }
1592
1593 PerVisibilityBindingChecker BindingChecker(this);
1594 SmallVector<std::pair<const llvm::hlsl::rootsig::DescriptorTableClause *,
1595 const hlsl::RootSignatureElement *>>
1596 UnboundClauses;
1597
1598 for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
1599 const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
1600 if (const auto *Descriptor =
1601 std::get_if<llvm::hlsl::rootsig::RootDescriptor>(ptr: &Elem)) {
1602 uint32_t LowerBound(Descriptor->Reg.Number);
1603 uint32_t UpperBound(LowerBound); // inclusive range
1604
1605 BindingChecker.trackBinding(
1606 Visibility: Descriptor->Visibility,
1607 RC: static_cast<llvm::dxil::ResourceClass>(Descriptor->Type),
1608 Space: Descriptor->Space, LowerBound, UpperBound, Elem: &RootSigElem);
1609 } else if (const auto *Constants =
1610 std::get_if<llvm::hlsl::rootsig::RootConstants>(ptr: &Elem)) {
1611 uint32_t LowerBound(Constants->Reg.Number);
1612 uint32_t UpperBound(LowerBound); // inclusive range
1613
1614 BindingChecker.trackBinding(
1615 Visibility: Constants->Visibility, RC: llvm::dxil::ResourceClass::CBuffer,
1616 Space: Constants->Space, LowerBound, UpperBound, Elem: &RootSigElem);
1617 } else if (const auto *Sampler =
1618 std::get_if<llvm::hlsl::rootsig::StaticSampler>(ptr: &Elem)) {
1619 uint32_t LowerBound(Sampler->Reg.Number);
1620 uint32_t UpperBound(LowerBound); // inclusive range
1621
1622 BindingChecker.trackBinding(
1623 Visibility: Sampler->Visibility, RC: llvm::dxil::ResourceClass::Sampler,
1624 Space: Sampler->Space, LowerBound, UpperBound, Elem: &RootSigElem);
1625 } else if (const auto *Clause =
1626 std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
1627 ptr: &Elem)) {
1628 // We'll process these once we see the table element.
1629 UnboundClauses.emplace_back(Args&: Clause, Args: &RootSigElem);
1630 } else if (const auto *Table =
1631 std::get_if<llvm::hlsl::rootsig::DescriptorTable>(ptr: &Elem)) {
1632 assert(UnboundClauses.size() == Table->NumClauses &&
1633 "Number of unbound elements must match the number of clauses");
1634 bool HasAnySampler = false;
1635 bool HasAnyNonSampler = false;
1636 uint64_t Offset = 0;
1637 bool IsPrevUnbound = false;
1638 for (const auto &[Clause, ClauseElem] : UnboundClauses) {
1639 SourceLocation Loc = ClauseElem->getLocation();
1640 if (Clause->Type == llvm::dxil::ResourceClass::Sampler)
1641 HasAnySampler = true;
1642 else
1643 HasAnyNonSampler = true;
1644
1645 if (HasAnySampler && HasAnyNonSampler)
1646 Diag(Loc, DiagID: diag::err_hlsl_invalid_mixed_resources);
1647
1648 // Relevant error will have already been reported above and needs to be
1649 // fixed before we can conduct further analysis, so shortcut error
1650 // return
1651 if (Clause->NumDescriptors == 0)
1652 return true;
1653
1654 bool IsAppending =
1655 Clause->Offset == llvm::hlsl::rootsig::DescriptorTableOffsetAppend;
1656 if (!IsAppending)
1657 Offset = Clause->Offset;
1658
1659 uint64_t RangeBound = llvm::hlsl::rootsig::computeRangeBound(
1660 Offset, Size: Clause->NumDescriptors);
1661
1662 if (IsPrevUnbound && IsAppending)
1663 Diag(Loc, DiagID: diag::err_hlsl_appending_onto_unbound);
1664 else if (!llvm::hlsl::rootsig::verifyNoOverflowedOffset(Offset: RangeBound))
1665 Diag(Loc, DiagID: diag::err_hlsl_offset_overflow) << Offset << RangeBound;
1666
1667 // Update offset to be 1 past this range's bound
1668 Offset = RangeBound + 1;
1669 IsPrevUnbound = Clause->NumDescriptors ==
1670 llvm::hlsl::rootsig::NumDescriptorsUnbounded;
1671
1672 // Compute the register bounds and track resource binding
1673 uint32_t LowerBound(Clause->Reg.Number);
1674 uint32_t UpperBound = llvm::hlsl::rootsig::computeRangeBound(
1675 Offset: LowerBound, Size: Clause->NumDescriptors);
1676
1677 BindingChecker.trackBinding(
1678 Visibility: Table->Visibility,
1679 RC: static_cast<llvm::dxil::ResourceClass>(Clause->Type), Space: Clause->Space,
1680 LowerBound, UpperBound, Elem: ClauseElem);
1681 }
1682 UnboundClauses.clear();
1683 }
1684 }
1685
1686 return BindingChecker.checkOverlap();
1687}
1688
1689void SemaHLSL::handleRootSignatureAttr(Decl *D, const ParsedAttr &AL) {
1690 if (AL.getNumArgs() != 1) {
1691 Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_wrong_number_arguments) << AL << 1;
1692 return;
1693 }
1694
1695 IdentifierInfo *Ident = AL.getArgAsIdent(Arg: 0)->getIdentifierInfo();
1696 if (auto *RS = D->getAttr<RootSignatureAttr>()) {
1697 if (RS->getSignatureIdent() != Ident) {
1698 Diag(Loc: AL.getLoc(), DiagID: diag::err_disallowed_duplicate_attribute) << RS;
1699 return;
1700 }
1701
1702 Diag(Loc: AL.getLoc(), DiagID: diag::warn_duplicate_attribute_exact) << RS;
1703 return;
1704 }
1705
1706 LookupResult R(SemaRef, Ident, SourceLocation(), Sema::LookupOrdinaryName);
1707 if (SemaRef.LookupQualifiedName(R, LookupCtx: D->getDeclContext()))
1708 if (auto *SignatureDecl =
1709 dyn_cast<HLSLRootSignatureDecl>(Val: R.getFoundDecl())) {
1710 D->addAttr(A: ::new (getASTContext()) RootSignatureAttr(
1711 getASTContext(), AL, Ident, SignatureDecl));
1712 }
1713}
1714
1715void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) {
1716 llvm::VersionTuple SMVersion =
1717 getASTContext().getTargetInfo().getTriple().getOSVersion();
1718 bool IsDXIL = getASTContext().getTargetInfo().getTriple().getArch() ==
1719 llvm::Triple::dxil;
1720
1721 uint32_t ZMax = 1024;
1722 uint32_t ThreadMax = 1024;
1723 if (IsDXIL && SMVersion.getMajor() <= 4) {
1724 ZMax = 1;
1725 ThreadMax = 768;
1726 } else if (IsDXIL && SMVersion.getMajor() == 5) {
1727 ZMax = 64;
1728 ThreadMax = 1024;
1729 }
1730
1731 uint32_t X;
1732 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: X))
1733 return;
1734 if (X > 1024) {
1735 Diag(Loc: AL.getArgAsExpr(Arg: 0)->getExprLoc(),
1736 DiagID: diag::err_hlsl_numthreads_argument_oor)
1737 << 0 << 1024;
1738 return;
1739 }
1740 uint32_t Y;
1741 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Y))
1742 return;
1743 if (Y > 1024) {
1744 Diag(Loc: AL.getArgAsExpr(Arg: 1)->getExprLoc(),
1745 DiagID: diag::err_hlsl_numthreads_argument_oor)
1746 << 1 << 1024;
1747 return;
1748 }
1749 uint32_t Z;
1750 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 2), Val&: Z))
1751 return;
1752 if (Z > ZMax) {
1753 SemaRef.Diag(Loc: AL.getArgAsExpr(Arg: 2)->getExprLoc(),
1754 DiagID: diag::err_hlsl_numthreads_argument_oor)
1755 << 2 << ZMax;
1756 return;
1757 }
1758
1759 if (X * Y * Z > ThreadMax) {
1760 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_numthreads_invalid) << ThreadMax;
1761 return;
1762 }
1763
1764 HLSLNumThreadsAttr *NewAttr = mergeNumThreadsAttr(D, AL, X, Y, Z);
1765 if (NewAttr)
1766 D->addAttr(A: NewAttr);
1767}
1768
1769static bool isValidWaveSizeValue(unsigned Value) {
1770 return llvm::isPowerOf2_32(Value) && Value >= 4 && Value <= 128;
1771}
1772
1773void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) {
1774 // validate that the wavesize argument is a power of 2 between 4 and 128
1775 // inclusive
1776 unsigned SpelledArgsCount = AL.getNumArgs();
1777 if (SpelledArgsCount == 0 || SpelledArgsCount > 3)
1778 return;
1779
1780 uint32_t Min;
1781 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: Min))
1782 return;
1783
1784 uint32_t Max = 0;
1785 if (SpelledArgsCount > 1 &&
1786 !SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Max))
1787 return;
1788
1789 uint32_t Preferred = 0;
1790 if (SpelledArgsCount > 2 &&
1791 !SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 2), Val&: Preferred))
1792 return;
1793
1794 if (SpelledArgsCount > 2) {
1795 if (!isValidWaveSizeValue(Value: Preferred)) {
1796 Diag(Loc: AL.getArgAsExpr(Arg: 2)->getExprLoc(),
1797 DiagID: diag::err_attribute_power_of_two_in_range)
1798 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize
1799 << Preferred;
1800 return;
1801 }
1802 // Preferred not in range.
1803 if (Preferred < Min || Preferred > Max) {
1804 Diag(Loc: AL.getArgAsExpr(Arg: 2)->getExprLoc(),
1805 DiagID: diag::err_attribute_power_of_two_in_range)
1806 << AL << Min << Max << Preferred;
1807 return;
1808 }
1809 } else if (SpelledArgsCount > 1) {
1810 if (!isValidWaveSizeValue(Value: Max)) {
1811 Diag(Loc: AL.getArgAsExpr(Arg: 1)->getExprLoc(),
1812 DiagID: diag::err_attribute_power_of_two_in_range)
1813 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Max;
1814 return;
1815 }
1816 if (Max < Min) {
1817 Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_invalid) << AL << 1;
1818 return;
1819 } else if (Max == Min) {
1820 Diag(Loc: AL.getLoc(), DiagID: diag::warn_attr_min_eq_max) << AL;
1821 }
1822 } else {
1823 if (!isValidWaveSizeValue(Value: Min)) {
1824 Diag(Loc: AL.getArgAsExpr(Arg: 0)->getExprLoc(),
1825 DiagID: diag::err_attribute_power_of_two_in_range)
1826 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Min;
1827 return;
1828 }
1829 }
1830
1831 HLSLWaveSizeAttr *NewAttr =
1832 mergeWaveSizeAttr(D, AL, Min, Max, Preferred, SpelledArgsCount);
1833 if (NewAttr)
1834 D->addAttr(A: NewAttr);
1835}
1836
1837void SemaHLSL::handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL) {
1838 uint32_t ID;
1839 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: ID))
1840 return;
1841 D->addAttr(A: ::new (getASTContext())
1842 HLSLVkExtBuiltinInputAttr(getASTContext(), AL, ID));
1843}
1844
1845void SemaHLSL::handleVkExtBuiltinOutputAttr(Decl *D, const ParsedAttr &AL) {
1846 uint32_t ID;
1847 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: ID))
1848 return;
1849 D->addAttr(A: ::new (getASTContext())
1850 HLSLVkExtBuiltinOutputAttr(getASTContext(), AL, ID));
1851}
1852
1853void SemaHLSL::handleVkPushConstantAttr(Decl *D, const ParsedAttr &AL) {
1854 D->addAttr(A: ::new (getASTContext())
1855 HLSLVkPushConstantAttr(getASTContext(), AL));
1856}
1857
1858void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) {
1859 uint32_t Id;
1860 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: Id))
1861 return;
1862 HLSLVkConstantIdAttr *NewAttr = mergeVkConstantIdAttr(D, AL, Id);
1863 if (NewAttr)
1864 D->addAttr(A: NewAttr);
1865}
1866
1867void SemaHLSL::handleVkBindingAttr(Decl *D, const ParsedAttr &AL) {
1868 uint32_t Binding = 0;
1869 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: Binding))
1870 return;
1871 uint32_t Set = 0;
1872 if (AL.getNumArgs() > 1 &&
1873 !SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Set))
1874 return;
1875
1876 D->addAttr(A: ::new (getASTContext())
1877 HLSLVkBindingAttr(getASTContext(), AL, Binding, Set));
1878}
1879
1880void SemaHLSL::handleVkLocationAttr(Decl *D, const ParsedAttr &AL) {
1881 uint32_t Location;
1882 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: Location))
1883 return;
1884
1885 D->addAttr(A: ::new (getASTContext())
1886 HLSLVkLocationAttr(getASTContext(), AL, Location));
1887}
1888
1889bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) {
1890 const auto *VT = T->getAs<VectorType>();
1891
1892 if (!T->hasUnsignedIntegerRepresentation() ||
1893 (VT && VT->getNumElements() > 3)) {
1894 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type)
1895 << AL << "uint/uint2/uint3";
1896 return false;
1897 }
1898
1899 return true;
1900}
1901
1902bool SemaHLSL::diagnosePositionType(QualType T, const ParsedAttr &AL) {
1903 const auto *VT = T->getAs<VectorType>();
1904 if (!T->hasFloatingRepresentation() || (VT && VT->getNumElements() > 4)) {
1905 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type)
1906 << AL << "float/float1/float2/float3/float4";
1907 return false;
1908 }
1909
1910 return true;
1911}
1912
1913void SemaHLSL::diagnoseSystemSemanticAttr(Decl *D, const ParsedAttr &AL,
1914 std::optional<unsigned> Index) {
1915 std::string SemanticName = AL.getAttrName()->getName().upper();
1916
1917 auto *VD = cast<ValueDecl>(Val: D);
1918 QualType ValueType = VD->getType();
1919 if (auto *FD = dyn_cast<FunctionDecl>(Val: D))
1920 ValueType = FD->getReturnType();
1921
1922 bool IsOutput = false;
1923 if (HLSLParamModifierAttr *MA = D->getAttr<HLSLParamModifierAttr>()) {
1924 if (MA->isOut()) {
1925 IsOutput = true;
1926 ValueType = cast<ReferenceType>(Val&: ValueType)->getPointeeType();
1927 }
1928 }
1929
1930 if (SemanticName == "SV_DISPATCHTHREADID") {
1931 diagnoseInputIDType(T: ValueType, AL);
1932 if (IsOutput)
1933 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_output_not_supported) << AL;
1934 if (Index.has_value())
1935 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_indexing_not_supported) << AL;
1936 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1937 return;
1938 }
1939
1940 if (SemanticName == "SV_GROUPINDEX") {
1941 if (IsOutput)
1942 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_output_not_supported) << AL;
1943 if (Index.has_value())
1944 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_indexing_not_supported) << AL;
1945 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1946 return;
1947 }
1948
1949 if (SemanticName == "SV_GROUPTHREADID") {
1950 diagnoseInputIDType(T: ValueType, AL);
1951 if (IsOutput)
1952 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_output_not_supported) << AL;
1953 if (Index.has_value())
1954 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_indexing_not_supported) << AL;
1955 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1956 return;
1957 }
1958
1959 if (SemanticName == "SV_GROUPID") {
1960 diagnoseInputIDType(T: ValueType, AL);
1961 if (IsOutput)
1962 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_output_not_supported) << AL;
1963 if (Index.has_value())
1964 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_indexing_not_supported) << AL;
1965 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1966 return;
1967 }
1968
1969 if (SemanticName == "SV_POSITION") {
1970 const auto *VT = ValueType->getAs<VectorType>();
1971 if (!ValueType->hasFloatingRepresentation() ||
1972 (VT && VT->getNumElements() > 4))
1973 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type)
1974 << AL << "float/float1/float2/float3/float4";
1975 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1976 return;
1977 }
1978
1979 if (SemanticName == "SV_VERTEXID") {
1980 uint64_t SizeInBits = SemaRef.Context.getTypeSize(T: ValueType);
1981 if (!ValueType->isUnsignedIntegerType() || SizeInBits != 32)
1982 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type) << AL << "uint";
1983 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1984 return;
1985 }
1986
1987 if (SemanticName == "SV_TARGET") {
1988 const auto *VT = ValueType->getAs<VectorType>();
1989 if (!ValueType->hasFloatingRepresentation() ||
1990 (VT && VT->getNumElements() > 4))
1991 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type)
1992 << AL << "float/float1/float2/float3/float4";
1993 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1994 return;
1995 }
1996
1997 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_unknown_semantic) << AL;
1998}
1999
2000void SemaHLSL::handleSemanticAttr(Decl *D, const ParsedAttr &AL) {
2001 uint32_t IndexValue(0), ExplicitIndex(0);
2002 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: IndexValue) ||
2003 !SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: ExplicitIndex)) {
2004 assert(0 && "HLSLUnparsedSemantic is expected to have 2 int arguments.");
2005 }
2006 assert(IndexValue > 0 ? ExplicitIndex : true);
2007 std::optional<unsigned> Index =
2008 ExplicitIndex ? std::optional<unsigned>(IndexValue) : std::nullopt;
2009
2010 if (AL.getAttrName()->getName().starts_with_insensitive(Prefix: "SV_"))
2011 diagnoseSystemSemanticAttr(D, AL, Index);
2012 else
2013 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
2014}
2015
2016void SemaHLSL::handlePackOffsetAttr(Decl *D, const ParsedAttr &AL) {
2017 if (!isa<VarDecl>(Val: D) || !isa<HLSLBufferDecl>(Val: D->getDeclContext())) {
2018 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_ast_node)
2019 << AL << "shader constant in a constant buffer";
2020 return;
2021 }
2022
2023 uint32_t SubComponent;
2024 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: SubComponent))
2025 return;
2026 uint32_t Component;
2027 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Component))
2028 return;
2029
2030 QualType T = cast<VarDecl>(Val: D)->getType().getCanonicalType();
2031 // Check if T is an array or struct type.
2032 // TODO: mark matrix type as aggregate type.
2033 bool IsAggregateTy = (T->isArrayType() || T->isStructureType());
2034
2035 // Check Component is valid for T.
2036 if (Component) {
2037 unsigned Size = getASTContext().getTypeSize(T);
2038 if (IsAggregateTy) {
2039 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_invalid_register_or_packoffset);
2040 return;
2041 } else {
2042 // Make sure Component + sizeof(T) <= 4.
2043 if ((Component * 32 + Size) > 128) {
2044 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_packoffset_cross_reg_boundary);
2045 return;
2046 }
2047 QualType EltTy = T;
2048 if (const auto *VT = T->getAs<VectorType>())
2049 EltTy = VT->getElementType();
2050 unsigned Align = getASTContext().getTypeAlign(T: EltTy);
2051 if (Align > 32 && Component == 1) {
2052 // NOTE: Component 3 will hit err_hlsl_packoffset_cross_reg_boundary.
2053 // So we only need to check Component 1 here.
2054 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_packoffset_alignment_mismatch)
2055 << Align << EltTy;
2056 return;
2057 }
2058 }
2059 }
2060
2061 D->addAttr(A: ::new (getASTContext()) HLSLPackOffsetAttr(
2062 getASTContext(), AL, SubComponent, Component));
2063}
2064
2065void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) {
2066 StringRef Str;
2067 SourceLocation ArgLoc;
2068 if (!SemaRef.checkStringLiteralArgumentAttr(Attr: AL, ArgNum: 0, Str, ArgLocation: &ArgLoc))
2069 return;
2070
2071 llvm::Triple::EnvironmentType ShaderType;
2072 if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Val: Str, Out&: ShaderType)) {
2073 Diag(Loc: AL.getLoc(), DiagID: diag::warn_attribute_type_not_supported)
2074 << AL << Str << ArgLoc;
2075 return;
2076 }
2077
2078 // FIXME: check function match the shader stage.
2079
2080 HLSLShaderAttr *NewAttr = mergeShaderAttr(D, AL, ShaderType);
2081 if (NewAttr)
2082 D->addAttr(A: NewAttr);
2083}
2084
2085bool clang::CreateHLSLAttributedResourceType(
2086 Sema &S, QualType Wrapped, ArrayRef<const Attr *> AttrList,
2087 QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo) {
2088 assert(AttrList.size() && "expected list of resource attributes");
2089
2090 QualType ContainedTy = QualType();
2091 TypeSourceInfo *ContainedTyInfo = nullptr;
2092 SourceLocation LocBegin = AttrList[0]->getRange().getBegin();
2093 SourceLocation LocEnd = AttrList[0]->getRange().getEnd();
2094
2095 HLSLAttributedResourceType::Attributes ResAttrs;
2096
2097 bool HasResourceClass = false;
2098 bool HasResourceDimension = false;
2099 for (const Attr *A : AttrList) {
2100 if (!A)
2101 continue;
2102 LocEnd = A->getRange().getEnd();
2103 switch (A->getKind()) {
2104 case attr::HLSLResourceClass: {
2105 ResourceClass RC = cast<HLSLResourceClassAttr>(Val: A)->getResourceClass();
2106 if (HasResourceClass) {
2107 S.Diag(Loc: A->getLocation(), DiagID: ResAttrs.ResourceClass == RC
2108 ? diag::warn_duplicate_attribute_exact
2109 : diag::warn_duplicate_attribute)
2110 << A;
2111 return false;
2112 }
2113 ResAttrs.ResourceClass = RC;
2114 HasResourceClass = true;
2115 break;
2116 }
2117 case attr::HLSLResourceDimension: {
2118 llvm::dxil::ResourceDimension RD =
2119 cast<HLSLResourceDimensionAttr>(Val: A)->getDimension();
2120 if (HasResourceDimension) {
2121 S.Diag(Loc: A->getLocation(), DiagID: ResAttrs.ResourceDimension == RD
2122 ? diag::warn_duplicate_attribute_exact
2123 : diag::warn_duplicate_attribute)
2124 << A;
2125 return false;
2126 }
2127 ResAttrs.ResourceDimension = RD;
2128 HasResourceDimension = true;
2129 break;
2130 }
2131 case attr::HLSLROV:
2132 if (ResAttrs.IsROV) {
2133 S.Diag(Loc: A->getLocation(), DiagID: diag::warn_duplicate_attribute_exact) << A;
2134 return false;
2135 }
2136 ResAttrs.IsROV = true;
2137 break;
2138 case attr::HLSLRawBuffer:
2139 if (ResAttrs.RawBuffer) {
2140 S.Diag(Loc: A->getLocation(), DiagID: diag::warn_duplicate_attribute_exact) << A;
2141 return false;
2142 }
2143 ResAttrs.RawBuffer = true;
2144 break;
2145 case attr::HLSLIsCounter:
2146 if (ResAttrs.IsCounter) {
2147 S.Diag(Loc: A->getLocation(), DiagID: diag::warn_duplicate_attribute_exact) << A;
2148 return false;
2149 }
2150 ResAttrs.IsCounter = true;
2151 break;
2152 case attr::HLSLContainedType: {
2153 const HLSLContainedTypeAttr *CTAttr = cast<HLSLContainedTypeAttr>(Val: A);
2154 QualType Ty = CTAttr->getType();
2155 if (!ContainedTy.isNull()) {
2156 S.Diag(Loc: A->getLocation(), DiagID: ContainedTy == Ty
2157 ? diag::warn_duplicate_attribute_exact
2158 : diag::warn_duplicate_attribute)
2159 << A;
2160 return false;
2161 }
2162 ContainedTy = Ty;
2163 ContainedTyInfo = CTAttr->getTypeLoc();
2164 break;
2165 }
2166 default:
2167 llvm_unreachable("unhandled resource attribute type");
2168 }
2169 }
2170
2171 if (!HasResourceClass) {
2172 S.Diag(Loc: AttrList.back()->getRange().getEnd(),
2173 DiagID: diag::err_hlsl_missing_resource_class);
2174 return false;
2175 }
2176
2177 ResType = S.getASTContext().getHLSLAttributedResourceType(
2178 Wrapped, Contained: ContainedTy, Attrs: ResAttrs);
2179
2180 if (LocInfo && ContainedTyInfo) {
2181 LocInfo->Range = SourceRange(LocBegin, LocEnd);
2182 LocInfo->ContainedTyInfo = ContainedTyInfo;
2183 }
2184 return true;
2185}
2186
2187// Validates and creates an HLSL attribute that is applied as type attribute on
2188// HLSL resource. The attributes are collected in HLSLResourcesTypeAttrs and at
2189// the end of the declaration they are applied to the declaration type by
2190// wrapping it in HLSLAttributedResourceType.
2191bool SemaHLSL::handleResourceTypeAttr(QualType T, const ParsedAttr &AL) {
2192 // only allow resource type attributes on intangible types
2193 if (!T->isHLSLResourceType()) {
2194 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attribute_needs_intangible_type)
2195 << AL << getASTContext().HLSLResourceTy;
2196 return false;
2197 }
2198
2199 // validate number of arguments
2200 if (!AL.checkExactlyNumArgs(S&: SemaRef, Num: AL.getMinArgs()))
2201 return false;
2202
2203 Attr *A = nullptr;
2204
2205 AttributeCommonInfo ACI(
2206 AL.getLoc(), AttributeScopeInfo(AL.getScopeName(), AL.getScopeLoc()),
2207 AttributeCommonInfo::NoSemaHandlerAttribute,
2208 {
2209 AttributeCommonInfo::AS_CXX11, 0, false /*IsAlignas*/,
2210 false /*IsRegularKeywordAttribute*/
2211 });
2212
2213 switch (AL.getKind()) {
2214 case ParsedAttr::AT_HLSLResourceClass: {
2215 if (!AL.isArgIdent(Arg: 0)) {
2216 Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_type)
2217 << AL << AANT_ArgumentIdentifier;
2218 return false;
2219 }
2220
2221 IdentifierLoc *Loc = AL.getArgAsIdent(Arg: 0);
2222 StringRef Identifier = Loc->getIdentifierInfo()->getName();
2223 SourceLocation ArgLoc = Loc->getLoc();
2224
2225 // Validate resource class value
2226 ResourceClass RC;
2227 if (!HLSLResourceClassAttr::ConvertStrToResourceClass(Val: Identifier, Out&: RC)) {
2228 Diag(Loc: ArgLoc, DiagID: diag::warn_attribute_type_not_supported)
2229 << "ResourceClass" << Identifier;
2230 return false;
2231 }
2232 A = HLSLResourceClassAttr::Create(Ctx&: getASTContext(), ResourceClass: RC, CommonInfo: ACI);
2233 break;
2234 }
2235
2236 case ParsedAttr::AT_HLSLResourceDimension: {
2237 StringRef Identifier;
2238 SourceLocation ArgLoc;
2239 if (!SemaRef.checkStringLiteralArgumentAttr(Attr: AL, ArgNum: 0, Str&: Identifier, ArgLocation: &ArgLoc))
2240 return false;
2241
2242 // Validate resource dimension value
2243 llvm::dxil::ResourceDimension RD;
2244 if (!HLSLResourceDimensionAttr::ConvertStrToResourceDimension(Val: Identifier,
2245 Out&: RD)) {
2246 Diag(Loc: ArgLoc, DiagID: diag::warn_attribute_type_not_supported)
2247 << "ResourceDimension" << Identifier;
2248 return false;
2249 }
2250 A = HLSLResourceDimensionAttr::Create(Ctx&: getASTContext(), Dimension: RD, CommonInfo: ACI);
2251 break;
2252 }
2253
2254 case ParsedAttr::AT_HLSLROV:
2255 A = HLSLROVAttr::Create(Ctx&: getASTContext(), CommonInfo: ACI);
2256 break;
2257
2258 case ParsedAttr::AT_HLSLRawBuffer:
2259 A = HLSLRawBufferAttr::Create(Ctx&: getASTContext(), CommonInfo: ACI);
2260 break;
2261
2262 case ParsedAttr::AT_HLSLIsCounter:
2263 A = HLSLIsCounterAttr::Create(Ctx&: getASTContext(), CommonInfo: ACI);
2264 break;
2265
2266 case ParsedAttr::AT_HLSLContainedType: {
2267 if (AL.getNumArgs() != 1 && !AL.hasParsedType()) {
2268 Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_wrong_number_arguments) << AL << 1;
2269 return false;
2270 }
2271
2272 TypeSourceInfo *TSI = nullptr;
2273 QualType QT = SemaRef.GetTypeFromParser(Ty: AL.getTypeArg(), TInfo: &TSI);
2274 assert(TSI && "no type source info for attribute argument");
2275 if (SemaRef.RequireCompleteType(Loc: TSI->getTypeLoc().getBeginLoc(), T: QT,
2276 DiagID: diag::err_incomplete_type))
2277 return false;
2278 A = HLSLContainedTypeAttr::Create(Ctx&: getASTContext(), Type: TSI, CommonInfo: ACI);
2279 break;
2280 }
2281
2282 default:
2283 llvm_unreachable("unhandled HLSL attribute");
2284 }
2285
2286 HLSLResourcesTypeAttrs.emplace_back(Args&: A);
2287 return true;
2288}
2289
2290// Combines all resource type attributes and creates HLSLAttributedResourceType.
2291QualType SemaHLSL::ProcessResourceTypeAttributes(QualType CurrentType) {
2292 if (!HLSLResourcesTypeAttrs.size())
2293 return CurrentType;
2294
2295 QualType QT = CurrentType;
2296 HLSLAttributedResourceLocInfo LocInfo;
2297 if (CreateHLSLAttributedResourceType(S&: SemaRef, Wrapped: CurrentType,
2298 AttrList: HLSLResourcesTypeAttrs, ResType&: QT, LocInfo: &LocInfo)) {
2299 const HLSLAttributedResourceType *RT =
2300 cast<HLSLAttributedResourceType>(Val: QT.getTypePtr());
2301
2302 // Temporarily store TypeLoc information for the new type.
2303 // It will be transferred to HLSLAttributesResourceTypeLoc
2304 // shortly after the type is created by TypeSpecLocFiller which
2305 // will call the TakeLocForHLSLAttribute method below.
2306 LocsForHLSLAttributedResources.insert(KV: std::pair(RT, LocInfo));
2307 }
2308 HLSLResourcesTypeAttrs.clear();
2309 return QT;
2310}
2311
2312// Returns source location for the HLSLAttributedResourceType
2313HLSLAttributedResourceLocInfo
2314SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
2315 HLSLAttributedResourceLocInfo LocInfo = {};
2316 auto I = LocsForHLSLAttributedResources.find(Val: RT);
2317 if (I != LocsForHLSLAttributedResources.end()) {
2318 LocInfo = I->second;
2319 LocsForHLSLAttributedResources.erase(I);
2320 return LocInfo;
2321 }
2322 LocInfo.Range = SourceRange();
2323 return LocInfo;
2324}
2325
2326// Walks though the global variable declaration, collects all resource binding
2327// requirements and adds them to Bindings
2328void SemaHLSL::collectResourceBindingsOnUserRecordDecl(const VarDecl *VD,
2329 const RecordType *RT) {
2330 const RecordDecl *RD = RT->getDecl()->getDefinitionOrSelf();
2331 for (FieldDecl *FD : RD->fields()) {
2332 const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
2333
2334 // Unwrap arrays
2335 // FIXME: Calculate array size while unwrapping
2336 assert(!Ty->isIncompleteArrayType() &&
2337 "incomplete arrays inside user defined types are not supported");
2338 while (Ty->isConstantArrayType()) {
2339 const ConstantArrayType *CAT = cast<ConstantArrayType>(Val: Ty);
2340 Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
2341 }
2342
2343 if (!Ty->isRecordType())
2344 continue;
2345
2346 if (const HLSLAttributedResourceType *AttrResType =
2347 HLSLAttributedResourceType::findHandleTypeOnResource(RT: Ty)) {
2348 // Add a new DeclBindingInfo to Bindings if it does not already exist
2349 ResourceClass RC = AttrResType->getAttrs().ResourceClass;
2350 DeclBindingInfo *DBI = Bindings.getDeclBindingInfo(VD, ResClass: RC);
2351 if (!DBI)
2352 Bindings.addDeclBindingInfo(VD, ResClass: RC);
2353 } else if (const RecordType *RT = dyn_cast<RecordType>(Val: Ty)) {
2354 // Recursively scan embedded struct or class; it would be nice to do this
2355 // without recursion, but tricky to correctly calculate the size of the
2356 // binding, which is something we are probably going to need to do later
2357 // on. Hopefully nesting of structs in structs too many levels is
2358 // unlikely.
2359 collectResourceBindingsOnUserRecordDecl(VD, RT);
2360 }
2361 }
2362}
2363
2364// Diagnose localized register binding errors for a single binding; does not
2365// diagnose resource binding on user record types, that will be done later
2366// in processResourceBindingOnDecl based on the information collected in
2367// collectResourceBindingsOnVarDecl.
2368// Returns false if the register binding is not valid.
2369static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
2370 Decl *D, RegisterType RegType,
2371 bool SpecifiedSpace) {
2372 int RegTypeNum = static_cast<int>(RegType);
2373
2374 // check if the decl type is groupshared
2375 if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
2376 S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2377 return false;
2378 }
2379
2380 // Cbuffers and Tbuffers are HLSLBufferDecl types
2381 if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(Val: D)) {
2382 ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer
2383 : ResourceClass::SRV;
2384 if (RegType == getRegisterType(RC))
2385 return true;
2386
2387 S.Diag(Loc: D->getLocation(), DiagID: diag::err_hlsl_binding_type_mismatch)
2388 << RegTypeNum;
2389 return false;
2390 }
2391
2392 // Samplers, UAVs, and SRVs are VarDecl types
2393 assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
2394 VarDecl *VD = cast<VarDecl>(Val: D);
2395
2396 // Resource
2397 if (const HLSLAttributedResourceType *AttrResType =
2398 HLSLAttributedResourceType::findHandleTypeOnResource(
2399 RT: VD->getType().getTypePtr())) {
2400 if (RegType == getRegisterType(ResTy: AttrResType))
2401 return true;
2402
2403 S.Diag(Loc: D->getLocation(), DiagID: diag::err_hlsl_binding_type_mismatch)
2404 << RegTypeNum;
2405 return false;
2406 }
2407
2408 const clang::Type *Ty = VD->getType().getTypePtr();
2409 while (Ty->isArrayType())
2410 Ty = Ty->getArrayElementTypeNoTypeQual();
2411
2412 // Basic types
2413 if (Ty->isArithmeticType() || Ty->isVectorType()) {
2414 bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(Val: D->getDeclContext());
2415 if (SpecifiedSpace && !DeclaredInCOrTBuffer)
2416 S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_space_on_global_constant);
2417
2418 if (!DeclaredInCOrTBuffer && (Ty->isIntegralType(Ctx: S.getASTContext()) ||
2419 Ty->isFloatingType() || Ty->isVectorType())) {
2420 // Register annotation on default constant buffer declaration ($Globals)
2421 if (RegType == RegisterType::CBuffer)
2422 S.Diag(Loc: ArgLoc, DiagID: diag::warn_hlsl_deprecated_register_type_b);
2423 else if (RegType != RegisterType::C)
2424 S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2425 else
2426 return true;
2427 } else {
2428 if (RegType == RegisterType::C)
2429 S.Diag(Loc: ArgLoc, DiagID: diag::warn_hlsl_register_type_c_packoffset);
2430 else
2431 S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2432 }
2433 return false;
2434 }
2435 if (Ty->isRecordType())
2436 // RecordTypes will be diagnosed in processResourceBindingOnDecl
2437 // that is called from ActOnVariableDeclarator
2438 return true;
2439
2440 // Anything else is an error
2441 S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
2442 return false;
2443}
2444
2445static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
2446 RegisterType regType) {
2447 // make sure that there are no two register annotations
2448 // applied to the decl with the same register type
2449 bool RegisterTypesDetected[5] = {false};
2450 RegisterTypesDetected[static_cast<int>(regType)] = true;
2451
2452 for (auto it = TheDecl->attr_begin(); it != TheDecl->attr_end(); ++it) {
2453 if (HLSLResourceBindingAttr *attr =
2454 dyn_cast<HLSLResourceBindingAttr>(Val: *it)) {
2455
2456 RegisterType otherRegType = attr->getRegisterType();
2457 if (RegisterTypesDetected[static_cast<int>(otherRegType)]) {
2458 int otherRegTypeNum = static_cast<int>(otherRegType);
2459 S.Diag(Loc: TheDecl->getLocation(),
2460 DiagID: diag::err_hlsl_duplicate_register_annotation)
2461 << otherRegTypeNum;
2462 return false;
2463 }
2464 RegisterTypesDetected[static_cast<int>(otherRegType)] = true;
2465 }
2466 }
2467 return true;
2468}
2469
2470static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
2471 Decl *D, RegisterType RegType,
2472 bool SpecifiedSpace) {
2473
2474 // exactly one of these two types should be set
2475 assert(((isa<VarDecl>(D) && !isa<HLSLBufferDecl>(D)) ||
2476 (!isa<VarDecl>(D) && isa<HLSLBufferDecl>(D))) &&
2477 "expecting VarDecl or HLSLBufferDecl");
2478
2479 // check if the declaration contains resource matching the register type
2480 if (!DiagnoseLocalRegisterBinding(S, ArgLoc, D, RegType, SpecifiedSpace))
2481 return false;
2482
2483 // next, if multiple register annotations exist, check that none conflict.
2484 return ValidateMultipleRegisterAnnotations(S, TheDecl: D, regType: RegType);
2485}
2486
2487// return false if the slot count exceeds the limit, true otherwise
2488static bool AccumulateHLSLResourceSlots(QualType Ty, uint64_t &StartSlot,
2489 const uint64_t &Limit,
2490 const ResourceClass ResClass,
2491 ASTContext &Ctx,
2492 uint64_t ArrayCount = 1) {
2493 Ty = Ty.getCanonicalType();
2494 const Type *T = Ty.getTypePtr();
2495
2496 // Early exit if already overflowed
2497 if (StartSlot > Limit)
2498 return false;
2499
2500 // Case 1: array type
2501 if (const auto *AT = dyn_cast<ArrayType>(Val: T)) {
2502 uint64_t Count = 1;
2503
2504 if (const auto *CAT = dyn_cast<ConstantArrayType>(Val: AT))
2505 Count = CAT->getSize().getZExtValue();
2506
2507 QualType ElemTy = AT->getElementType();
2508 return AccumulateHLSLResourceSlots(Ty: ElemTy, StartSlot, Limit, ResClass, Ctx,
2509 ArrayCount: ArrayCount * Count);
2510 }
2511
2512 // Case 2: resource leaf
2513 if (auto ResTy = dyn_cast<HLSLAttributedResourceType>(Val: T)) {
2514 // First ensure this resource counts towards the corresponding
2515 // register type limit.
2516 if (ResTy->getAttrs().ResourceClass != ResClass)
2517 return true;
2518
2519 // Validate highest slot used
2520 uint64_t EndSlot = StartSlot + ArrayCount - 1;
2521 if (EndSlot > Limit)
2522 return false;
2523
2524 // Advance SlotCount past the consumed range
2525 StartSlot = EndSlot + 1;
2526 return true;
2527 }
2528
2529 // Case 3: struct / record
2530 if (const auto *RT = dyn_cast<RecordType>(Val: T)) {
2531 const RecordDecl *RD = RT->getDecl();
2532
2533 if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(Val: RD)) {
2534 for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
2535 if (!AccumulateHLSLResourceSlots(Ty: Base.getType(), StartSlot, Limit,
2536 ResClass, Ctx, ArrayCount))
2537 return false;
2538 }
2539 }
2540
2541 for (const FieldDecl *Field : RD->fields()) {
2542 if (!AccumulateHLSLResourceSlots(Ty: Field->getType(), StartSlot, Limit,
2543 ResClass, Ctx, ArrayCount))
2544 return false;
2545 }
2546
2547 return true;
2548 }
2549
2550 // Case 4: everything else
2551 return true;
2552}
2553
2554// return true if there is something invalid, false otherwise
2555static bool ValidateRegisterNumber(uint64_t SlotNum, Decl *TheDecl,
2556 ASTContext &Ctx, RegisterType RegTy) {
2557 const uint64_t Limit = UINT32_MAX;
2558 if (SlotNum > Limit)
2559 return true;
2560
2561 // after verifying the number doesn't exceed uint32max, we don't need
2562 // to look further into c or i register types
2563 if (RegTy == RegisterType::C || RegTy == RegisterType::I)
2564 return false;
2565
2566 if (VarDecl *VD = dyn_cast<VarDecl>(Val: TheDecl)) {
2567 uint64_t BaseSlot = SlotNum;
2568
2569 if (!AccumulateHLSLResourceSlots(Ty: VD->getType(), StartSlot&: SlotNum, Limit,
2570 ResClass: getResourceClass(RT: RegTy), Ctx))
2571 return true;
2572
2573 // After AccumulateHLSLResourceSlots runs, SlotNum is now
2574 // the first free slot; last used was SlotNum - 1
2575 return (BaseSlot > Limit);
2576 }
2577 // handle the cbuffer/tbuffer case
2578 if (isa<HLSLBufferDecl>(Val: TheDecl))
2579 // resources cannot be put within a cbuffer, so no need
2580 // to analyze the structure since the register number
2581 // won't be pushed any higher.
2582 return (SlotNum > Limit);
2583
2584 // we don't expect any other decl type, so fail
2585 llvm_unreachable("unexpected decl type");
2586}
2587
2588void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
2589 if (VarDecl *VD = dyn_cast<VarDecl>(Val: TheDecl)) {
2590 QualType Ty = VD->getType();
2591 if (const auto *IAT = dyn_cast<IncompleteArrayType>(Val&: Ty))
2592 Ty = IAT->getElementType();
2593 if (SemaRef.RequireCompleteType(Loc: TheDecl->getBeginLoc(), T: Ty,
2594 DiagID: diag::err_incomplete_type))
2595 return;
2596 }
2597
2598 StringRef Slot = "";
2599 StringRef Space = "";
2600 SourceLocation SlotLoc, SpaceLoc;
2601
2602 if (!AL.isArgIdent(Arg: 0)) {
2603 Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_type)
2604 << AL << AANT_ArgumentIdentifier;
2605 return;
2606 }
2607 IdentifierLoc *Loc = AL.getArgAsIdent(Arg: 0);
2608
2609 if (AL.getNumArgs() == 2) {
2610 Slot = Loc->getIdentifierInfo()->getName();
2611 SlotLoc = Loc->getLoc();
2612 if (!AL.isArgIdent(Arg: 1)) {
2613 Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_type)
2614 << AL << AANT_ArgumentIdentifier;
2615 return;
2616 }
2617 Loc = AL.getArgAsIdent(Arg: 1);
2618 Space = Loc->getIdentifierInfo()->getName();
2619 SpaceLoc = Loc->getLoc();
2620 } else {
2621 StringRef Str = Loc->getIdentifierInfo()->getName();
2622 if (Str.starts_with(Prefix: "space")) {
2623 Space = Str;
2624 SpaceLoc = Loc->getLoc();
2625 } else {
2626 Slot = Str;
2627 SlotLoc = Loc->getLoc();
2628 Space = "space0";
2629 }
2630 }
2631
2632 RegisterType RegType = RegisterType::SRV;
2633 std::optional<unsigned> SlotNum;
2634 unsigned SpaceNum = 0;
2635
2636 // Validate slot
2637 if (!Slot.empty()) {
2638 if (!convertToRegisterType(Slot, RT: &RegType)) {
2639 Diag(Loc: SlotLoc, DiagID: diag::err_hlsl_binding_type_invalid) << Slot.substr(Start: 0, N: 1);
2640 return;
2641 }
2642 if (RegType == RegisterType::I) {
2643 Diag(Loc: SlotLoc, DiagID: diag::warn_hlsl_deprecated_register_type_i);
2644 return;
2645 }
2646 const StringRef SlotNumStr = Slot.substr(Start: 1);
2647
2648 uint64_t N;
2649
2650 // validate that the slot number is a non-empty number
2651 if (SlotNumStr.getAsInteger(Radix: 10, Result&: N)) {
2652 Diag(Loc: SlotLoc, DiagID: diag::err_hlsl_unsupported_register_number);
2653 return;
2654 }
2655
2656 // Validate register number. It should not exceed UINT32_MAX,
2657 // including if the resource type is an array that starts
2658 // before UINT32_MAX, but ends afterwards.
2659 if (ValidateRegisterNumber(SlotNum: N, TheDecl, Ctx&: getASTContext(), RegTy: RegType)) {
2660 Diag(Loc: SlotLoc, DiagID: diag::err_hlsl_register_number_too_large);
2661 return;
2662 }
2663
2664 // the slot number has been validated and does not exceed UINT32_MAX
2665 SlotNum = (unsigned)N;
2666 }
2667
2668 // Validate space
2669 if (!Space.starts_with(Prefix: "space")) {
2670 Diag(Loc: SpaceLoc, DiagID: diag::err_hlsl_expected_space) << Space;
2671 return;
2672 }
2673 StringRef SpaceNumStr = Space.substr(Start: 5);
2674 if (SpaceNumStr.getAsInteger(Radix: 10, Result&: SpaceNum)) {
2675 Diag(Loc: SpaceLoc, DiagID: diag::err_hlsl_expected_space) << Space;
2676 return;
2677 }
2678
2679 // If we have slot, diagnose it is the right register type for the decl
2680 if (SlotNum.has_value())
2681 if (!DiagnoseHLSLRegisterAttribute(S&: SemaRef, ArgLoc&: SlotLoc, D: TheDecl, RegType,
2682 SpecifiedSpace: !SpaceLoc.isInvalid()))
2683 return;
2684
2685 HLSLResourceBindingAttr *NewAttr =
2686 HLSLResourceBindingAttr::Create(Ctx&: getASTContext(), Slot, Space, CommonInfo: AL);
2687 if (NewAttr) {
2688 NewAttr->setBinding(RT: RegType, SlotNum, SpaceNum);
2689 TheDecl->addAttr(A: NewAttr);
2690 }
2691}
2692
2693void SemaHLSL::handleParamModifierAttr(Decl *D, const ParsedAttr &AL) {
2694 HLSLParamModifierAttr *NewAttr = mergeParamModifierAttr(
2695 D, AL,
2696 Spelling: static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling()));
2697 if (NewAttr)
2698 D->addAttr(A: NewAttr);
2699}
2700
2701namespace {
2702
2703/// This class implements HLSL availability diagnostics for default
2704/// and relaxed mode
2705///
2706/// The goal of this diagnostic is to emit an error or warning when an
2707/// unavailable API is found in code that is reachable from the shader
2708/// entry function or from an exported function (when compiling a shader
2709/// library).
2710///
2711/// This is done by traversing the AST of all shader entry point functions
2712/// and of all exported functions, and any functions that are referenced
2713/// from this AST. In other words, any functions that are reachable from
2714/// the entry points.
2715class DiagnoseHLSLAvailability : public DynamicRecursiveASTVisitor {
2716 Sema &SemaRef;
2717
2718 // Stack of functions to be scaned
2719 llvm::SmallVector<const FunctionDecl *, 8> DeclsToScan;
2720
2721 // Tracks which environments functions have been scanned in.
2722 //
2723 // Maps FunctionDecl to an unsigned number that represents the set of shader
2724 // environments the function has been scanned for.
2725 // The llvm::Triple::EnvironmentType enum values for shader stages guaranteed
2726 // to be numbered from llvm::Triple::Pixel to llvm::Triple::Amplification
2727 // (verified by static_asserts in Triple.cpp), we can use it to index
2728 // individual bits in the set, as long as we shift the values to start with 0
2729 // by subtracting the value of llvm::Triple::Pixel first.
2730 //
2731 // The N'th bit in the set will be set if the function has been scanned
2732 // in shader environment whose llvm::Triple::EnvironmentType integer value
2733 // equals (llvm::Triple::Pixel + N).
2734 //
2735 // For example, if a function has been scanned in compute and pixel stage
2736 // environment, the value will be 0x21 (100001 binary) because:
2737 //
2738 // (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0
2739 // (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5
2740 //
2741 // A FunctionDecl is mapped to 0 (or not included in the map) if it has not
2742 // been scanned in any environment.
2743 llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls;
2744
2745 // Do not access these directly, use the get/set methods below to make
2746 // sure the values are in sync
2747 llvm::Triple::EnvironmentType CurrentShaderEnvironment;
2748 unsigned CurrentShaderStageBit;
2749
2750 // True if scanning a function that was already scanned in a different
2751 // shader stage context, and therefore we should not report issues that
2752 // depend only on shader model version because they would be duplicate.
2753 bool ReportOnlyShaderStageIssues;
2754
2755 // Helper methods for dealing with current stage context / environment
2756 void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) {
2757 static_assert(sizeof(unsigned) >= 4);
2758 assert(HLSLShaderAttr::isValidShaderType(ShaderType));
2759 assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 &&
2760 "ShaderType is too big for this bitmap"); // 31 is reserved for
2761 // "unknown"
2762
2763 unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel;
2764 CurrentShaderEnvironment = ShaderType;
2765 CurrentShaderStageBit = (1 << bitmapIndex);
2766 }
2767
2768 void SetUnknownShaderStageContext() {
2769 CurrentShaderEnvironment = llvm::Triple::UnknownEnvironment;
2770 CurrentShaderStageBit = (1 << 31);
2771 }
2772
2773 llvm::Triple::EnvironmentType GetCurrentShaderEnvironment() const {
2774 return CurrentShaderEnvironment;
2775 }
2776
2777 bool InUnknownShaderStageContext() const {
2778 return CurrentShaderEnvironment == llvm::Triple::UnknownEnvironment;
2779 }
2780
2781 // Helper methods for dealing with shader stage bitmap
2782 void AddToScannedFunctions(const FunctionDecl *FD) {
2783 unsigned &ScannedStages = ScannedDecls[FD];
2784 ScannedStages |= CurrentShaderStageBit;
2785 }
2786
2787 unsigned GetScannedStages(const FunctionDecl *FD) { return ScannedDecls[FD]; }
2788
2789 bool WasAlreadyScannedInCurrentStage(const FunctionDecl *FD) {
2790 return WasAlreadyScannedInCurrentStage(ScannerStages: GetScannedStages(FD));
2791 }
2792
2793 bool WasAlreadyScannedInCurrentStage(unsigned ScannerStages) {
2794 return ScannerStages & CurrentShaderStageBit;
2795 }
2796
2797 static bool NeverBeenScanned(unsigned ScannedStages) {
2798 return ScannedStages == 0;
2799 }
2800
2801 // Scanning methods
2802 void HandleFunctionOrMethodRef(FunctionDecl *FD, Expr *RefExpr);
2803 void CheckDeclAvailability(NamedDecl *D, const AvailabilityAttr *AA,
2804 SourceRange Range);
2805 const AvailabilityAttr *FindAvailabilityAttr(const Decl *D);
2806 bool HasMatchingEnvironmentOrNone(const AvailabilityAttr *AA);
2807
2808public:
2809 DiagnoseHLSLAvailability(Sema &SemaRef)
2810 : SemaRef(SemaRef),
2811 CurrentShaderEnvironment(llvm::Triple::UnknownEnvironment),
2812 CurrentShaderStageBit(0), ReportOnlyShaderStageIssues(false) {}
2813
2814 // AST traversal methods
2815 void RunOnTranslationUnit(const TranslationUnitDecl *TU);
2816 void RunOnFunction(const FunctionDecl *FD);
2817
2818 bool VisitDeclRefExpr(DeclRefExpr *DRE) override {
2819 FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(Val: DRE->getDecl());
2820 if (FD)
2821 HandleFunctionOrMethodRef(FD, RefExpr: DRE);
2822 return true;
2823 }
2824
2825 bool VisitMemberExpr(MemberExpr *ME) override {
2826 FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(Val: ME->getMemberDecl());
2827 if (FD)
2828 HandleFunctionOrMethodRef(FD, RefExpr: ME);
2829 return true;
2830 }
2831};
2832
2833void DiagnoseHLSLAvailability::HandleFunctionOrMethodRef(FunctionDecl *FD,
2834 Expr *RefExpr) {
2835 assert((isa<DeclRefExpr>(RefExpr) || isa<MemberExpr>(RefExpr)) &&
2836 "expected DeclRefExpr or MemberExpr");
2837
2838 // has a definition -> add to stack to be scanned
2839 const FunctionDecl *FDWithBody = nullptr;
2840 if (FD->hasBody(Definition&: FDWithBody)) {
2841 if (!WasAlreadyScannedInCurrentStage(FD: FDWithBody))
2842 DeclsToScan.push_back(Elt: FDWithBody);
2843 return;
2844 }
2845
2846 // no body -> diagnose availability
2847 const AvailabilityAttr *AA = FindAvailabilityAttr(D: FD);
2848 if (AA)
2849 CheckDeclAvailability(
2850 D: FD, AA, Range: SourceRange(RefExpr->getBeginLoc(), RefExpr->getEndLoc()));
2851}
2852
2853void DiagnoseHLSLAvailability::RunOnTranslationUnit(
2854 const TranslationUnitDecl *TU) {
2855
2856 // Iterate over all shader entry functions and library exports, and for those
2857 // that have a body (definiton), run diag scan on each, setting appropriate
2858 // shader environment context based on whether it is a shader entry function
2859 // or an exported function. Exported functions can be in namespaces and in
2860 // export declarations so we need to scan those declaration contexts as well.
2861 llvm::SmallVector<const DeclContext *, 8> DeclContextsToScan;
2862 DeclContextsToScan.push_back(Elt: TU);
2863
2864 while (!DeclContextsToScan.empty()) {
2865 const DeclContext *DC = DeclContextsToScan.pop_back_val();
2866 for (auto &D : DC->decls()) {
2867 // do not scan implicit declaration generated by the implementation
2868 if (D->isImplicit())
2869 continue;
2870
2871 // for namespace or export declaration add the context to the list to be
2872 // scanned later
2873 if (llvm::dyn_cast<NamespaceDecl>(Val: D) || llvm::dyn_cast<ExportDecl>(Val: D)) {
2874 DeclContextsToScan.push_back(Elt: llvm::dyn_cast<DeclContext>(Val: D));
2875 continue;
2876 }
2877
2878 // skip over other decls or function decls without body
2879 const FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(Val: D);
2880 if (!FD || !FD->isThisDeclarationADefinition())
2881 continue;
2882
2883 // shader entry point
2884 if (HLSLShaderAttr *ShaderAttr = FD->getAttr<HLSLShaderAttr>()) {
2885 SetShaderStageContext(ShaderAttr->getType());
2886 RunOnFunction(FD);
2887 continue;
2888 }
2889 // exported library function
2890 // FIXME: replace this loop with external linkage check once issue #92071
2891 // is resolved
2892 bool isExport = FD->isInExportDeclContext();
2893 if (!isExport) {
2894 for (const auto *Redecl : FD->redecls()) {
2895 if (Redecl->isInExportDeclContext()) {
2896 isExport = true;
2897 break;
2898 }
2899 }
2900 }
2901 if (isExport) {
2902 SetUnknownShaderStageContext();
2903 RunOnFunction(FD);
2904 continue;
2905 }
2906 }
2907 }
2908}
2909
2910void DiagnoseHLSLAvailability::RunOnFunction(const FunctionDecl *FD) {
2911 assert(DeclsToScan.empty() && "DeclsToScan should be empty");
2912 DeclsToScan.push_back(Elt: FD);
2913
2914 while (!DeclsToScan.empty()) {
2915 // Take one decl from the stack and check it by traversing its AST.
2916 // For any CallExpr found during the traversal add it's callee to the top of
2917 // the stack to be processed next. Functions already processed are stored in
2918 // ScannedDecls.
2919 const FunctionDecl *FD = DeclsToScan.pop_back_val();
2920
2921 // Decl was already scanned
2922 const unsigned ScannedStages = GetScannedStages(FD);
2923 if (WasAlreadyScannedInCurrentStage(ScannerStages: ScannedStages))
2924 continue;
2925
2926 ReportOnlyShaderStageIssues = !NeverBeenScanned(ScannedStages);
2927
2928 AddToScannedFunctions(FD);
2929 TraverseStmt(S: FD->getBody());
2930 }
2931}
2932
2933bool DiagnoseHLSLAvailability::HasMatchingEnvironmentOrNone(
2934 const AvailabilityAttr *AA) {
2935 const IdentifierInfo *IIEnvironment = AA->getEnvironment();
2936 if (!IIEnvironment)
2937 return true;
2938
2939 llvm::Triple::EnvironmentType CurrentEnv = GetCurrentShaderEnvironment();
2940 if (CurrentEnv == llvm::Triple::UnknownEnvironment)
2941 return false;
2942
2943 llvm::Triple::EnvironmentType AttrEnv =
2944 AvailabilityAttr::getEnvironmentType(Environment: IIEnvironment->getName());
2945
2946 return CurrentEnv == AttrEnv;
2947}
2948
2949const AvailabilityAttr *
2950DiagnoseHLSLAvailability::FindAvailabilityAttr(const Decl *D) {
2951 AvailabilityAttr const *PartialMatch = nullptr;
2952 // Check each AvailabilityAttr to find the one for this platform.
2953 // For multiple attributes with the same platform try to find one for this
2954 // environment.
2955 for (const auto *A : D->attrs()) {
2956 if (const auto *Avail = dyn_cast<AvailabilityAttr>(Val: A)) {
2957 StringRef AttrPlatform = Avail->getPlatform()->getName();
2958 StringRef TargetPlatform =
2959 SemaRef.getASTContext().getTargetInfo().getPlatformName();
2960
2961 // Match the platform name.
2962 if (AttrPlatform == TargetPlatform) {
2963 // Find the best matching attribute for this environment
2964 if (HasMatchingEnvironmentOrNone(AA: Avail))
2965 return Avail;
2966 PartialMatch = Avail;
2967 }
2968 }
2969 }
2970 return PartialMatch;
2971}
2972
2973// Check availability against target shader model version and current shader
2974// stage and emit diagnostic
2975void DiagnoseHLSLAvailability::CheckDeclAvailability(NamedDecl *D,
2976 const AvailabilityAttr *AA,
2977 SourceRange Range) {
2978
2979 const IdentifierInfo *IIEnv = AA->getEnvironment();
2980
2981 if (!IIEnv) {
2982 // The availability attribute does not have environment -> it depends only
2983 // on shader model version and not on specific the shader stage.
2984
2985 // Skip emitting the diagnostics if the diagnostic mode is set to
2986 // strict (-fhlsl-strict-availability) because all relevant diagnostics
2987 // were already emitted in the DiagnoseUnguardedAvailability scan
2988 // (SemaAvailability.cpp).
2989 if (SemaRef.getLangOpts().HLSLStrictAvailability)
2990 return;
2991
2992 // Do not report shader-stage-independent issues if scanning a function
2993 // that was already scanned in a different shader stage context (they would
2994 // be duplicate)
2995 if (ReportOnlyShaderStageIssues)
2996 return;
2997
2998 } else {
2999 // The availability attribute has environment -> we need to know
3000 // the current stage context to property diagnose it.
3001 if (InUnknownShaderStageContext())
3002 return;
3003 }
3004
3005 // Check introduced version and if environment matches
3006 bool EnvironmentMatches = HasMatchingEnvironmentOrNone(AA);
3007 VersionTuple Introduced = AA->getIntroduced();
3008 VersionTuple TargetVersion =
3009 SemaRef.Context.getTargetInfo().getPlatformMinVersion();
3010
3011 if (TargetVersion >= Introduced && EnvironmentMatches)
3012 return;
3013
3014 // Emit diagnostic message
3015 const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
3016 llvm::StringRef PlatformName(
3017 AvailabilityAttr::getPrettyPlatformName(Platform: TI.getPlatformName()));
3018
3019 llvm::StringRef CurrentEnvStr =
3020 llvm::Triple::getEnvironmentTypeName(Kind: GetCurrentShaderEnvironment());
3021
3022 llvm::StringRef AttrEnvStr =
3023 AA->getEnvironment() ? AA->getEnvironment()->getName() : "";
3024 bool UseEnvironment = !AttrEnvStr.empty();
3025
3026 if (EnvironmentMatches) {
3027 SemaRef.Diag(Loc: Range.getBegin(), DiagID: diag::warn_hlsl_availability)
3028 << Range << D << PlatformName << Introduced.getAsString()
3029 << UseEnvironment << CurrentEnvStr;
3030 } else {
3031 SemaRef.Diag(Loc: Range.getBegin(), DiagID: diag::warn_hlsl_availability_unavailable)
3032 << Range << D;
3033 }
3034
3035 SemaRef.Diag(Loc: D->getLocation(), DiagID: diag::note_partial_availability_specified_here)
3036 << D << PlatformName << Introduced.getAsString()
3037 << SemaRef.Context.getTargetInfo().getPlatformMinVersion().getAsString()
3038 << UseEnvironment << AttrEnvStr << CurrentEnvStr;
3039}
3040
3041} // namespace
3042
3043void SemaHLSL::ActOnEndOfTranslationUnit(TranslationUnitDecl *TU) {
3044 // process default CBuffer - create buffer layout struct and invoke codegenCGH
3045 if (!DefaultCBufferDecls.empty()) {
3046 HLSLBufferDecl *DefaultCBuffer = HLSLBufferDecl::CreateDefaultCBuffer(
3047 C&: SemaRef.getASTContext(), LexicalParent: SemaRef.getCurLexicalContext(),
3048 DefaultCBufferDecls);
3049 addImplicitBindingAttrToDecl(S&: SemaRef, D: DefaultCBuffer, RT: RegisterType::CBuffer,
3050 ImplicitBindingOrderID: getNextImplicitBindingOrderID());
3051 SemaRef.getCurLexicalContext()->addDecl(D: DefaultCBuffer);
3052 createHostLayoutStructForBuffer(S&: SemaRef, BufDecl: DefaultCBuffer);
3053
3054 // Set HasValidPackoffset if any of the decls has a register(c#) annotation;
3055 for (const Decl *VD : DefaultCBufferDecls) {
3056 const HLSLResourceBindingAttr *RBA =
3057 VD->getAttr<HLSLResourceBindingAttr>();
3058 if (RBA && RBA->hasRegisterSlot() &&
3059 RBA->getRegisterType() == HLSLResourceBindingAttr::RegisterType::C) {
3060 DefaultCBuffer->setHasValidPackoffset(true);
3061 break;
3062 }
3063 }
3064
3065 DeclGroupRef DG(DefaultCBuffer);
3066 SemaRef.Consumer.HandleTopLevelDecl(D: DG);
3067 }
3068 diagnoseAvailabilityViolations(TU);
3069}
3070
3071void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
3072 // Skip running the diagnostics scan if the diagnostic mode is
3073 // strict (-fhlsl-strict-availability) and the target shader stage is known
3074 // because all relevant diagnostics were already emitted in the
3075 // DiagnoseUnguardedAvailability scan (SemaAvailability.cpp).
3076 const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
3077 if (SemaRef.getLangOpts().HLSLStrictAvailability &&
3078 TI.getTriple().getEnvironment() != llvm::Triple::EnvironmentType::Library)
3079 return;
3080
3081 DiagnoseHLSLAvailability(SemaRef).RunOnTranslationUnit(TU);
3082}
3083
3084static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
3085 assert(TheCall->getNumArgs() > 1);
3086 QualType ArgTy0 = TheCall->getArg(Arg: 0)->getType();
3087
3088 for (unsigned I = 1, N = TheCall->getNumArgs(); I < N; ++I) {
3089 if (!S->getASTContext().hasSameUnqualifiedType(
3090 T1: ArgTy0, T2: TheCall->getArg(Arg: I)->getType())) {
3091 S->Diag(Loc: TheCall->getBeginLoc(), DiagID: diag::err_vec_builtin_incompatible_vector)
3092 << TheCall->getDirectCallee() << /*useAllTerminology*/ true
3093 << SourceRange(TheCall->getArg(Arg: 0)->getBeginLoc(),
3094 TheCall->getArg(Arg: N - 1)->getEndLoc());
3095 return true;
3096 }
3097 }
3098 return false;
3099}
3100
3101static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
3102 QualType ArgType = Arg->getType();
3103 if (!S->getASTContext().hasSameUnqualifiedType(T1: ArgType, T2: ExpectedType)) {
3104 S->Diag(Loc: Arg->getBeginLoc(), DiagID: diag::err_typecheck_convert_incompatible)
3105 << ArgType << ExpectedType << 1 << 0 << 0;
3106 return true;
3107 }
3108 return false;
3109}
3110
3111static bool CheckAllArgTypesAreCorrect(
3112 Sema *S, CallExpr *TheCall,
3113 llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
3114 clang::QualType PassedType)>
3115 Check) {
3116 for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
3117 Expr *Arg = TheCall->getArg(Arg: I);
3118 if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
3119 return true;
3120 }
3121 return false;
3122}
3123
3124static bool CheckFloatRepresentation(Sema *S, SourceLocation Loc,
3125 int ArgOrdinal,
3126 clang::QualType PassedType) {
3127 clang::QualType BaseType =
3128 PassedType->isVectorType()
3129 ? PassedType->castAs<clang::VectorType>()->getElementType()
3130 : PassedType;
3131 if (!BaseType->isFloat32Type())
3132 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3133 << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
3134 << /* float */ 1 << PassedType;
3135 return false;
3136}
3137
3138static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
3139 int ArgOrdinal,
3140 clang::QualType PassedType) {
3141 clang::QualType BaseType =
3142 PassedType->isVectorType()
3143 ? PassedType->castAs<clang::VectorType>()->getElementType()
3144 : PassedType;
3145 if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
3146 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3147 << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
3148 << /* half or float */ 2 << PassedType;
3149 return false;
3150}
3151
3152static bool CheckAnyDoubleRepresentation(Sema *S, SourceLocation Loc,
3153 int ArgOrdinal,
3154 clang::QualType PassedType) {
3155 clang::QualType BaseType =
3156 PassedType->isVectorType()
3157 ? PassedType->castAs<clang::VectorType>()->getElementType()
3158 : PassedType->isMatrixType()
3159 ? PassedType->castAs<clang::MatrixType>()->getElementType()
3160 : PassedType;
3161 if (!BaseType->isDoubleType()) {
3162 // FIXME: adopt standard `err_builtin_invalid_arg_type` instead of using
3163 // this custom error.
3164 return S->Diag(Loc, DiagID: diag::err_builtin_requires_double_type)
3165 << ArgOrdinal << PassedType;
3166 }
3167
3168 return false;
3169}
3170
3171static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
3172 unsigned ArgIndex) {
3173 auto *Arg = TheCall->getArg(Arg: ArgIndex);
3174 SourceLocation OrigLoc = Arg->getExprLoc();
3175 if (Arg->IgnoreCasts()->isModifiableLvalue(Ctx&: S->Context, Loc: &OrigLoc) ==
3176 Expr::MLV_Valid)
3177 return false;
3178 S->Diag(Loc: OrigLoc, DiagID: diag::error_hlsl_inout_lvalue) << Arg << 0;
3179 return true;
3180}
3181
3182static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal,
3183 clang::QualType PassedType) {
3184 const auto *VecTy = PassedType->getAs<VectorType>();
3185 if (!VecTy)
3186 return false;
3187
3188 if (VecTy->getElementType()->isDoubleType())
3189 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3190 << ArgOrdinal << /* scalar */ 1 << /* no int */ 0 << /* fp */ 1
3191 << PassedType;
3192 return false;
3193}
3194
3195static bool CheckFloatingOrIntRepresentation(Sema *S, SourceLocation Loc,
3196 int ArgOrdinal,
3197 clang::QualType PassedType) {
3198 if (!PassedType->hasIntegerRepresentation() &&
3199 !PassedType->hasFloatingRepresentation())
3200 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3201 << ArgOrdinal << /* scalar or vector of */ 5 << /* integer */ 1
3202 << /* fp */ 1 << PassedType;
3203 return false;
3204}
3205
3206static bool CheckUnsignedIntVecRepresentation(Sema *S, SourceLocation Loc,
3207 int ArgOrdinal,
3208 clang::QualType PassedType) {
3209 if (auto *VecTy = PassedType->getAs<VectorType>())
3210 if (VecTy->getElementType()->isUnsignedIntegerType())
3211 return false;
3212
3213 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3214 << ArgOrdinal << /* vector of */ 4 << /* uint */ 3 << /* no fp */ 0
3215 << PassedType;
3216}
3217
3218// checks for unsigned ints of all sizes
3219static bool CheckUnsignedIntRepresentation(Sema *S, SourceLocation Loc,
3220 int ArgOrdinal,
3221 clang::QualType PassedType) {
3222 if (!PassedType->hasUnsignedIntegerRepresentation())
3223 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3224 << ArgOrdinal << /* scalar or vector of */ 5 << /* unsigned int */ 3
3225 << /* no fp */ 0 << PassedType;
3226 return false;
3227}
3228
3229static bool CheckExpectedBitWidth(Sema *S, CallExpr *TheCall,
3230 unsigned ArgOrdinal, unsigned Width) {
3231 QualType ArgTy = TheCall->getArg(Arg: 0)->getType();
3232 if (auto *VTy = ArgTy->getAs<VectorType>())
3233 ArgTy = VTy->getElementType();
3234 // ensure arg type has expected bit width
3235 uint64_t ElementBitCount =
3236 S->getASTContext().getTypeSizeInChars(T: ArgTy).getQuantity() * 8;
3237 if (ElementBitCount != Width) {
3238 S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
3239 DiagID: diag::err_integer_incorrect_bit_count)
3240 << Width << ElementBitCount;
3241 return true;
3242 }
3243 return false;
3244}
3245
3246static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
3247 QualType ReturnType) {
3248 auto *VecTyA = TheCall->getArg(Arg: 0)->getType()->getAs<VectorType>();
3249 if (VecTyA)
3250 ReturnType =
3251 S->Context.getExtVectorType(VectorType: ReturnType, NumElts: VecTyA->getNumElements());
3252
3253 TheCall->setType(ReturnType);
3254}
3255
3256static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar,
3257 unsigned ArgIndex) {
3258 assert(TheCall->getNumArgs() >= ArgIndex);
3259 QualType ArgType = TheCall->getArg(Arg: ArgIndex)->getType();
3260 auto *VTy = ArgType->getAs<VectorType>();
3261 // not the scalar or vector<scalar>
3262 if (!(S->Context.hasSameUnqualifiedType(T1: ArgType, T2: Scalar) ||
3263 (VTy &&
3264 S->Context.hasSameUnqualifiedType(T1: VTy->getElementType(), T2: Scalar)))) {
3265 S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
3266 DiagID: diag::err_typecheck_expect_scalar_or_vector)
3267 << ArgType << Scalar;
3268 return true;
3269 }
3270 return false;
3271}
3272
3273static bool CheckScalarOrVectorOrMatrix(Sema *S, CallExpr *TheCall,
3274 QualType Scalar, unsigned ArgIndex) {
3275 assert(TheCall->getNumArgs() > ArgIndex);
3276
3277 Expr *Arg = TheCall->getArg(Arg: ArgIndex);
3278 QualType ArgType = Arg->getType();
3279
3280 // Scalar: T
3281 if (S->Context.hasSameUnqualifiedType(T1: ArgType, T2: Scalar))
3282 return false;
3283
3284 // Vector: vector<T>
3285 if (const auto *VTy = ArgType->getAs<VectorType>()) {
3286 if (S->Context.hasSameUnqualifiedType(T1: VTy->getElementType(), T2: Scalar))
3287 return false;
3288 }
3289
3290 // Matrix: ConstantMatrixType with element type T
3291 if (const auto *MTy = ArgType->getAs<ConstantMatrixType>()) {
3292 if (S->Context.hasSameUnqualifiedType(T1: MTy->getElementType(), T2: Scalar))
3293 return false;
3294 }
3295
3296 // Not a scalar/vector/matrix-of-scalar
3297 S->Diag(Loc: Arg->getBeginLoc(),
3298 DiagID: diag::err_typecheck_expect_scalar_or_vector_or_matrix)
3299 << ArgType << Scalar;
3300 return true;
3301}
3302
3303static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall,
3304 unsigned ArgIndex) {
3305 assert(TheCall->getNumArgs() >= ArgIndex);
3306 QualType ArgType = TheCall->getArg(Arg: ArgIndex)->getType();
3307 auto *VTy = ArgType->getAs<VectorType>();
3308 // not the scalar or vector<scalar>
3309 if (!(ArgType->isScalarType() ||
3310 (VTy && VTy->getElementType()->isScalarType()))) {
3311 S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
3312 DiagID: diag::err_typecheck_expect_any_scalar_or_vector)
3313 << ArgType << 1;
3314 return true;
3315 }
3316 return false;
3317}
3318
3319// Check that the argument is not a bool or vector<bool>
3320// Returns true on error
3321static bool CheckNotBoolScalarOrVector(Sema *S, CallExpr *TheCall,
3322 unsigned ArgIndex) {
3323 QualType BoolType = S->getASTContext().BoolTy;
3324 assert(ArgIndex < TheCall->getNumArgs());
3325 QualType ArgType = TheCall->getArg(Arg: ArgIndex)->getType();
3326 auto *VTy = ArgType->getAs<VectorType>();
3327 // is the bool or vector<bool>
3328 if (S->Context.hasSameUnqualifiedType(T1: ArgType, T2: BoolType) ||
3329 (VTy &&
3330 S->Context.hasSameUnqualifiedType(T1: VTy->getElementType(), T2: BoolType))) {
3331 S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
3332 DiagID: diag::err_typecheck_expect_any_scalar_or_vector)
3333 << ArgType << 0;
3334 return true;
3335 }
3336 return false;
3337}
3338
3339static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
3340 if (CheckNotBoolScalarOrVector(S, TheCall, ArgIndex: 0))
3341 return true;
3342 return false;
3343}
3344
3345static bool CheckWavePrefix(Sema *S, CallExpr *TheCall) {
3346 if (CheckNotBoolScalarOrVector(S, TheCall, ArgIndex: 0))
3347 return true;
3348 return false;
3349}
3350
3351static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
3352 assert(TheCall->getNumArgs() == 3);
3353 Expr *Arg1 = TheCall->getArg(Arg: 1);
3354 Expr *Arg2 = TheCall->getArg(Arg: 2);
3355 if (!S->Context.hasSameUnqualifiedType(T1: Arg1->getType(), T2: Arg2->getType())) {
3356 S->Diag(Loc: TheCall->getBeginLoc(),
3357 DiagID: diag::err_typecheck_call_different_arg_types)
3358 << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
3359 << Arg2->getSourceRange();
3360 return true;
3361 }
3362
3363 TheCall->setType(Arg1->getType());
3364 return false;
3365}
3366
3367static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
3368 assert(TheCall->getNumArgs() == 3);
3369 Expr *Arg1 = TheCall->getArg(Arg: 1);
3370 QualType Arg1Ty = Arg1->getType();
3371 Expr *Arg2 = TheCall->getArg(Arg: 2);
3372 QualType Arg2Ty = Arg2->getType();
3373
3374 QualType Arg1ScalarTy = Arg1Ty;
3375 if (auto VTy = Arg1ScalarTy->getAs<VectorType>())
3376 Arg1ScalarTy = VTy->getElementType();
3377
3378 QualType Arg2ScalarTy = Arg2Ty;
3379 if (auto VTy = Arg2ScalarTy->getAs<VectorType>())
3380 Arg2ScalarTy = VTy->getElementType();
3381
3382 if (!S->Context.hasSameUnqualifiedType(T1: Arg1ScalarTy, T2: Arg2ScalarTy))
3383 S->Diag(Loc: Arg1->getBeginLoc(), DiagID: diag::err_hlsl_builtin_scalar_vector_mismatch)
3384 << /* second and third */ 1 << TheCall->getCallee() << Arg1Ty << Arg2Ty;
3385
3386 QualType Arg0Ty = TheCall->getArg(Arg: 0)->getType();
3387 unsigned Arg0Length = Arg0Ty->getAs<VectorType>()->getNumElements();
3388 unsigned Arg1Length = Arg1Ty->isVectorType()
3389 ? Arg1Ty->getAs<VectorType>()->getNumElements()
3390 : 0;
3391 unsigned Arg2Length = Arg2Ty->isVectorType()
3392 ? Arg2Ty->getAs<VectorType>()->getNumElements()
3393 : 0;
3394 if (Arg1Length > 0 && Arg0Length != Arg1Length) {
3395 S->Diag(Loc: TheCall->getBeginLoc(),
3396 DiagID: diag::err_typecheck_vector_lengths_not_equal)
3397 << Arg0Ty << Arg1Ty << TheCall->getArg(Arg: 0)->getSourceRange()
3398 << Arg1->getSourceRange();
3399 return true;
3400 }
3401
3402 if (Arg2Length > 0 && Arg0Length != Arg2Length) {
3403 S->Diag(Loc: TheCall->getBeginLoc(),
3404 DiagID: diag::err_typecheck_vector_lengths_not_equal)
3405 << Arg0Ty << Arg2Ty << TheCall->getArg(Arg: 0)->getSourceRange()
3406 << Arg2->getSourceRange();
3407 return true;
3408 }
3409
3410 TheCall->setType(
3411 S->getASTContext().getExtVectorType(VectorType: Arg1ScalarTy, NumElts: Arg0Length));
3412 return false;
3413}
3414
3415static bool CheckIndexType(Sema *S, CallExpr *TheCall, unsigned IndexArgIndex) {
3416 assert(TheCall->getNumArgs() > IndexArgIndex && "Index argument missing");
3417 QualType ArgType = TheCall->getArg(Arg: IndexArgIndex)->getType();
3418 QualType IndexTy = ArgType;
3419 unsigned int ActualDim = 1;
3420 if (const auto *VTy = IndexTy->getAs<VectorType>()) {
3421 ActualDim = VTy->getNumElements();
3422 IndexTy = VTy->getElementType();
3423 }
3424 if (!IndexTy->isIntegerType()) {
3425 S->Diag(Loc: TheCall->getArg(Arg: IndexArgIndex)->getBeginLoc(),
3426 DiagID: diag::err_typecheck_expect_int)
3427 << ArgType;
3428 return true;
3429 }
3430
3431 QualType ResourceArgTy = TheCall->getArg(Arg: 0)->getType();
3432 const HLSLAttributedResourceType *ResTy =
3433 ResourceArgTy.getTypePtr()->getAs<HLSLAttributedResourceType>();
3434 assert(ResTy && "Resource argument must be a resource");
3435 HLSLAttributedResourceType::Attributes ResAttrs = ResTy->getAttrs();
3436
3437 unsigned int ExpectedDim = 1;
3438 if (ResAttrs.ResourceDimension != llvm::dxil::ResourceDimension::Unknown)
3439 ExpectedDim = getResourceDimensions(Dim: ResAttrs.ResourceDimension);
3440
3441 if (ActualDim != ExpectedDim) {
3442 S->Diag(Loc: TheCall->getArg(Arg: IndexArgIndex)->getBeginLoc(),
3443 DiagID: diag::err_hlsl_builtin_resource_coordinate_dimension_mismatch)
3444 << cast<NamedDecl>(Val: TheCall->getCalleeDecl()) << ExpectedDim
3445 << ActualDim;
3446 return true;
3447 }
3448
3449 return false;
3450}
3451
3452static bool CheckResourceHandle(
3453 Sema *S, CallExpr *TheCall, unsigned ArgIndex,
3454 llvm::function_ref<bool(const HLSLAttributedResourceType *ResType)> Check =
3455 nullptr) {
3456 assert(TheCall->getNumArgs() >= ArgIndex);
3457 QualType ArgType = TheCall->getArg(Arg: ArgIndex)->getType();
3458 const HLSLAttributedResourceType *ResTy =
3459 ArgType.getTypePtr()->getAs<HLSLAttributedResourceType>();
3460 if (!ResTy) {
3461 S->Diag(Loc: TheCall->getArg(Arg: ArgIndex)->getBeginLoc(),
3462 DiagID: diag::err_typecheck_expect_hlsl_resource)
3463 << ArgType;
3464 return true;
3465 }
3466 if (Check && Check(ResTy)) {
3467 S->Diag(Loc: TheCall->getArg(Arg: ArgIndex)->getExprLoc(),
3468 DiagID: diag::err_invalid_hlsl_resource_type)
3469 << ArgType;
3470 return true;
3471 }
3472 return false;
3473}
3474
3475static bool CheckVectorElementCount(Sema *S, QualType PassedType,
3476 QualType BaseType, unsigned ExpectedCount,
3477 SourceLocation Loc) {
3478 unsigned PassedCount = 1;
3479 if (const auto *VecTy = PassedType->getAs<VectorType>())
3480 PassedCount = VecTy->getNumElements();
3481
3482 if (PassedCount != ExpectedCount) {
3483 QualType ExpectedType =
3484 S->Context.getExtVectorType(VectorType: BaseType, NumElts: ExpectedCount);
3485 S->Diag(Loc, DiagID: diag::err_typecheck_convert_incompatible)
3486 << PassedType << ExpectedType << 1 << 0 << 0;
3487 return true;
3488 }
3489 return false;
3490}
3491
// Which sampling builtin CheckSamplingBuiltin is validating. The kind selects
// the accepted argument-count range and the operands that follow the location.
enum class SampleKind {
  Sample,      // no extra operand
  Bias,        // scalar float bias
  Grad,        // DDX and DDY operands
  Level,       // scalar float LOD
  Cmp,         // scalar float compare value
  CmpLevelZero // scalar float compare value
};
3493
3494static bool CheckTextureSamplerAndLocation(Sema &S, CallExpr *TheCall) {
3495 // Check the texture handle.
3496 if (CheckResourceHandle(S: &S, TheCall, ArgIndex: 0,
3497 Check: [](const HLSLAttributedResourceType *ResType) {
3498 return ResType->getAttrs().ResourceDimension ==
3499 llvm::dxil::ResourceDimension::Unknown;
3500 }))
3501 return true;
3502
3503 // Check the sampler handle.
3504 if (CheckResourceHandle(S: &S, TheCall, ArgIndex: 1,
3505 Check: [](const HLSLAttributedResourceType *ResType) {
3506 return ResType->getAttrs().ResourceClass !=
3507 llvm::hlsl::ResourceClass::Sampler;
3508 }))
3509 return true;
3510
3511 auto *ResourceTy =
3512 TheCall->getArg(Arg: 0)->getType()->castAs<HLSLAttributedResourceType>();
3513
3514 // Check the location.
3515 unsigned ExpectedDim =
3516 getResourceDimensions(Dim: ResourceTy->getAttrs().ResourceDimension);
3517 if (CheckVectorElementCount(S: &S, PassedType: TheCall->getArg(Arg: 2)->getType(),
3518 BaseType: S.Context.FloatTy, ExpectedCount: ExpectedDim,
3519 Loc: TheCall->getBeginLoc()))
3520 return true;
3521
3522 return false;
3523}
3524
3525static bool CheckCalculateLodBuiltin(Sema &S, CallExpr *TheCall) {
3526 if (S.checkArgCount(Call: TheCall, DesiredArgCount: 3))
3527 return true;
3528
3529 if (CheckTextureSamplerAndLocation(S, TheCall))
3530 return true;
3531
3532 TheCall->setType(S.Context.FloatTy);
3533 return false;
3534}
3535
3536static bool CheckGatherBuiltin(Sema &S, CallExpr *TheCall, bool IsCmp) {
3537 if (S.checkArgCountRange(Call: TheCall, MinArgCount: IsCmp ? 5 : 4, MaxArgCount: IsCmp ? 6 : 5))
3538 return true;
3539
3540 if (CheckTextureSamplerAndLocation(S, TheCall))
3541 return true;
3542
3543 unsigned NextIdx = 3;
3544 if (IsCmp) {
3545 // Check the compare value.
3546 QualType CmpTy = TheCall->getArg(Arg: NextIdx)->getType();
3547 if (!CmpTy->isFloatingType() || CmpTy->isVectorType()) {
3548 S.Diag(Loc: TheCall->getArg(Arg: NextIdx)->getBeginLoc(),
3549 DiagID: diag::err_typecheck_convert_incompatible)
3550 << CmpTy << S.Context.FloatTy << 1 << 0 << 0;
3551 return true;
3552 }
3553 NextIdx++;
3554 }
3555
3556 // Check the component operand.
3557 Expr *ComponentArg = TheCall->getArg(Arg: NextIdx);
3558 QualType ComponentTy = ComponentArg->getType();
3559 if (!ComponentTy->isIntegerType() || ComponentTy->isVectorType()) {
3560 S.Diag(Loc: ComponentArg->getBeginLoc(),
3561 DiagID: diag::err_typecheck_convert_incompatible)
3562 << ComponentTy << S.Context.UnsignedIntTy << 1 << 0 << 0;
3563 return true;
3564 }
3565
3566 // GatherCmp operations on Vulkan target must use component 0 (Red).
3567 if (IsCmp && S.getASTContext().getTargetInfo().getTriple().isSPIRV()) {
3568 std::optional<llvm::APSInt> ComponentOpt =
3569 ComponentArg->getIntegerConstantExpr(Ctx: S.getASTContext());
3570 if (ComponentOpt) {
3571 int64_t ComponentVal = ComponentOpt->getSExtValue();
3572 if (ComponentVal != 0) {
3573 // Issue an error if the component is not 0 (Red).
3574 // 0 -> Red, 1 -> Green, 2 -> Blue, 3 -> Alpha
3575 assert(ComponentVal >= 0 && ComponentVal <= 3 &&
3576 "The component is not in the expected range.");
3577 S.Diag(Loc: ComponentArg->getBeginLoc(),
3578 DiagID: diag::err_hlsl_gathercmp_invalid_component)
3579 << ComponentVal;
3580 return true;
3581 }
3582 }
3583 }
3584
3585 NextIdx++;
3586
3587 // Check the offset operand.
3588 const HLSLAttributedResourceType *ResourceTy =
3589 TheCall->getArg(Arg: 0)->getType()->castAs<HLSLAttributedResourceType>();
3590 if (TheCall->getNumArgs() > NextIdx) {
3591 unsigned ExpectedDim =
3592 getResourceDimensions(Dim: ResourceTy->getAttrs().ResourceDimension);
3593 if (CheckVectorElementCount(S: &S, PassedType: TheCall->getArg(Arg: NextIdx)->getType(),
3594 BaseType: S.Context.IntTy, ExpectedCount: ExpectedDim,
3595 Loc: TheCall->getArg(Arg: NextIdx)->getBeginLoc()))
3596 return true;
3597 NextIdx++;
3598 }
3599
3600 assert(ResourceTy->hasContainedType() &&
3601 "Expecting a contained type for resource with a dimension "
3602 "attribute.");
3603 QualType ReturnType = ResourceTy->getContainedType();
3604
3605 if (IsCmp) {
3606 if (!ReturnType->hasFloatingRepresentation()) {
3607 S.Diag(Loc: TheCall->getBeginLoc(), DiagID: diag::err_hlsl_samplecmp_requires_float);
3608 return true;
3609 }
3610 }
3611
3612 if (const auto *VecTy = ReturnType->getAs<VectorType>())
3613 ReturnType = VecTy->getElementType();
3614 ReturnType = S.Context.getExtVectorType(VectorType: ReturnType, NumElts: 4);
3615
3616 TheCall->setType(ReturnType);
3617
3618 return false;
3619}
3620static bool CheckLoadLevelBuiltin(Sema &S, CallExpr *TheCall) {
3621 if (S.checkArgCountRange(Call: TheCall, MinArgCount: 2, MaxArgCount: 3))
3622 return true;
3623
3624 // Check the texture handle.
3625 if (CheckResourceHandle(S: &S, TheCall, ArgIndex: 0,
3626 Check: [](const HLSLAttributedResourceType *ResType) {
3627 return ResType->getAttrs().ResourceDimension ==
3628 llvm::dxil::ResourceDimension::Unknown;
3629 }))
3630 return true;
3631
3632 auto *ResourceTy =
3633 TheCall->getArg(Arg: 0)->getType()->castAs<HLSLAttributedResourceType>();
3634
3635 // Check the location + lod (int3 for Texture2D).
3636 unsigned ExpectedDim =
3637 getResourceDimensions(Dim: ResourceTy->getAttrs().ResourceDimension);
3638 QualType CoordLODTy = TheCall->getArg(Arg: 1)->getType();
3639 if (CheckVectorElementCount(S: &S, PassedType: CoordLODTy, BaseType: S.Context.IntTy, ExpectedCount: ExpectedDim + 1,
3640 Loc: TheCall->getArg(Arg: 1)->getBeginLoc()))
3641 return true;
3642
3643 QualType EltTy = CoordLODTy;
3644 if (const auto *VTy = EltTy->getAs<VectorType>())
3645 EltTy = VTy->getElementType();
3646 if (!EltTy->isIntegerType()) {
3647 S.Diag(Loc: TheCall->getArg(Arg: 1)->getBeginLoc(), DiagID: diag::err_typecheck_expect_int)
3648 << CoordLODTy;
3649 return true;
3650 }
3651
3652 // Check the offset operand.
3653 if (TheCall->getNumArgs() > 2) {
3654 if (CheckVectorElementCount(S: &S, PassedType: TheCall->getArg(Arg: 2)->getType(),
3655 BaseType: S.Context.IntTy, ExpectedCount: ExpectedDim,
3656 Loc: TheCall->getArg(Arg: 2)->getBeginLoc()))
3657 return true;
3658 }
3659
3660 TheCall->setType(ResourceTy->getContainedType());
3661 return false;
3662}
3663
3664static bool CheckSamplingBuiltin(Sema &S, CallExpr *TheCall, SampleKind Kind) {
3665 unsigned MinArgs, MaxArgs;
3666 if (Kind == SampleKind::Sample) {
3667 MinArgs = 3;
3668 MaxArgs = 5;
3669 } else if (Kind == SampleKind::Bias) {
3670 MinArgs = 4;
3671 MaxArgs = 6;
3672 } else if (Kind == SampleKind::Grad) {
3673 MinArgs = 5;
3674 MaxArgs = 7;
3675 } else if (Kind == SampleKind::Level) {
3676 MinArgs = 4;
3677 MaxArgs = 5;
3678 } else if (Kind == SampleKind::Cmp) {
3679 MinArgs = 4;
3680 MaxArgs = 6;
3681 } else {
3682 assert(Kind == SampleKind::CmpLevelZero);
3683 MinArgs = 4;
3684 MaxArgs = 5;
3685 }
3686
3687 if (S.checkArgCountRange(Call: TheCall, MinArgCount: MinArgs, MaxArgCount: MaxArgs))
3688 return true;
3689
3690 if (CheckTextureSamplerAndLocation(S, TheCall))
3691 return true;
3692
3693 const HLSLAttributedResourceType *ResourceTy =
3694 TheCall->getArg(Arg: 0)->getType()->castAs<HLSLAttributedResourceType>();
3695 unsigned ExpectedDim =
3696 getResourceDimensions(Dim: ResourceTy->getAttrs().ResourceDimension);
3697
3698 unsigned NextIdx = 3;
3699 if (Kind == SampleKind::Bias || Kind == SampleKind::Level ||
3700 Kind == SampleKind::Cmp || Kind == SampleKind::CmpLevelZero) {
3701 // Check the bias, lod level, or compare value, depending on the kind.
3702 // All of them must be a scalar float value.
3703 QualType BiasOrLODOrCmpTy = TheCall->getArg(Arg: NextIdx)->getType();
3704 if (!BiasOrLODOrCmpTy->isFloatingType() ||
3705 BiasOrLODOrCmpTy->isVectorType()) {
3706 S.Diag(Loc: TheCall->getArg(Arg: NextIdx)->getBeginLoc(),
3707 DiagID: diag::err_typecheck_convert_incompatible)
3708 << BiasOrLODOrCmpTy << S.Context.FloatTy << 1 << 0 << 0;
3709 return true;
3710 }
3711 NextIdx++;
3712 } else if (Kind == SampleKind::Grad) {
3713 // Check the DDX operand.
3714 if (CheckVectorElementCount(S: &S, PassedType: TheCall->getArg(Arg: NextIdx)->getType(),
3715 BaseType: S.Context.FloatTy, ExpectedCount: ExpectedDim,
3716 Loc: TheCall->getArg(Arg: NextIdx)->getBeginLoc()))
3717 return true;
3718
3719 // Check the DDY operand.
3720 if (CheckVectorElementCount(S: &S, PassedType: TheCall->getArg(Arg: NextIdx + 1)->getType(),
3721 BaseType: S.Context.FloatTy, ExpectedCount: ExpectedDim,
3722 Loc: TheCall->getArg(Arg: NextIdx + 1)->getBeginLoc()))
3723 return true;
3724 NextIdx += 2;
3725 }
3726
3727 // Check the offset operand.
3728 if (TheCall->getNumArgs() > NextIdx) {
3729 if (CheckVectorElementCount(S: &S, PassedType: TheCall->getArg(Arg: NextIdx)->getType(),
3730 BaseType: S.Context.IntTy, ExpectedCount: ExpectedDim,
3731 Loc: TheCall->getArg(Arg: NextIdx)->getBeginLoc()))
3732 return true;
3733 NextIdx++;
3734 }
3735
3736 // Check the clamp operand.
3737 if (Kind != SampleKind::Level && Kind != SampleKind::CmpLevelZero &&
3738 TheCall->getNumArgs() > NextIdx) {
3739 QualType ClampTy = TheCall->getArg(Arg: NextIdx)->getType();
3740 if (!ClampTy->isFloatingType() || ClampTy->isVectorType()) {
3741 S.Diag(Loc: TheCall->getArg(Arg: NextIdx)->getBeginLoc(),
3742 DiagID: diag::err_typecheck_convert_incompatible)
3743 << ClampTy << S.Context.FloatTy << 1 << 0 << 0;
3744 return true;
3745 }
3746 }
3747
3748 assert(ResourceTy->hasContainedType() &&
3749 "Expecting a contained type for resource with a dimension "
3750 "attribute.");
3751 QualType ReturnType = ResourceTy->getContainedType();
3752 if (Kind == SampleKind::Cmp || Kind == SampleKind::CmpLevelZero) {
3753 if (!ReturnType->hasFloatingRepresentation()) {
3754 S.Diag(Loc: TheCall->getBeginLoc(), DiagID: diag::err_hlsl_samplecmp_requires_float);
3755 return true;
3756 }
3757 ReturnType = S.Context.FloatTy;
3758 }
3759 TheCall->setType(ReturnType);
3760
3761 return false;
3762}
3763
3764// Note: returning true in this case results in CheckBuiltinFunctionCall
3765// returning an ExprError
3766bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
3767 switch (BuiltinID) {
3768 case Builtin::BI__builtin_hlsl_adduint64: {
3769 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
3770 return true;
3771
3772 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
3773 Check: CheckUnsignedIntVecRepresentation))
3774 return true;
3775
3776 // ensure arg integers are 32-bits
3777 if (CheckExpectedBitWidth(S: &SemaRef, TheCall, ArgOrdinal: 0, Width: 32))
3778 return true;
3779
3780 // ensure both args are vectors of total bit size of a multiple of 64
3781 auto *VTy = TheCall->getArg(Arg: 0)->getType()->getAs<VectorType>();
3782 int NumElementsArg = VTy->getNumElements();
3783 if (NumElementsArg != 2 && NumElementsArg != 4) {
3784 SemaRef.Diag(Loc: TheCall->getBeginLoc(), DiagID: diag::err_vector_incorrect_bit_count)
3785 << 1 /*a multiple of*/ << 64 << NumElementsArg * 32;
3786 return true;
3787 }
3788
3789 // ensure first arg and second arg have the same type
3790 if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
3791 return true;
3792
3793 ExprResult A = TheCall->getArg(Arg: 0);
3794 QualType ArgTyA = A.get()->getType();
3795 // return type is the same as the input type
3796 TheCall->setType(ArgTyA);
3797 break;
3798 }
3799 case Builtin::BI__builtin_hlsl_resource_getpointer: {
3800 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2) ||
3801 CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0) ||
3802 CheckIndexType(S: &SemaRef, TheCall, IndexArgIndex: 1))
3803 return true;
3804
3805 auto *ResourceTy =
3806 TheCall->getArg(Arg: 0)->getType()->castAs<HLSLAttributedResourceType>();
3807 QualType ContainedTy = ResourceTy->getContainedType();
3808 auto ReturnType =
3809 SemaRef.Context.getAddrSpaceQualType(T: ContainedTy, AddressSpace: LangAS::hlsl_device);
3810 ReturnType = SemaRef.Context.getPointerType(T: ReturnType);
3811 TheCall->setType(ReturnType);
3812 TheCall->setValueKind(VK_LValue);
3813
3814 break;
3815 }
3816 case Builtin::BI__builtin_hlsl_resource_getpointer_typed: {
3817 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3) ||
3818 CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0) ||
3819 CheckIndexType(S: &SemaRef, TheCall, IndexArgIndex: 1))
3820 return true;
3821
3822 QualType ElementTy = TheCall->getArg(Arg: 2)->getType();
3823 assert(ElementTy->isPointerType() &&
3824 "expected pointer type for second argument");
3825 ElementTy = ElementTy->getPointeeType();
3826
3827 // Reject array types
3828 if (ElementTy->isArrayType())
3829 return SemaRef.Diag(
3830 Loc: cast<FunctionDecl>(Val: SemaRef.CurContext)->getPointOfInstantiation(),
3831 DiagID: diag::err_invalid_use_of_array_type);
3832
3833 auto ReturnType =
3834 SemaRef.Context.getAddrSpaceQualType(T: ElementTy, AddressSpace: LangAS::hlsl_device);
3835 ReturnType = SemaRef.Context.getPointerType(T: ReturnType);
3836 TheCall->setType(ReturnType);
3837
3838 break;
3839 }
3840 case Builtin::BI__builtin_hlsl_resource_load_with_status: {
3841 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3) ||
3842 CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0) ||
3843 CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 1),
3844 ExpectedType: SemaRef.getASTContext().UnsignedIntTy) ||
3845 CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 2),
3846 ExpectedType: SemaRef.getASTContext().UnsignedIntTy) ||
3847 CheckModifiableLValue(S: &SemaRef, TheCall, ArgIndex: 2))
3848 return true;
3849
3850 auto *ResourceTy =
3851 TheCall->getArg(Arg: 0)->getType()->castAs<HLSLAttributedResourceType>();
3852 QualType ReturnType = ResourceTy->getContainedType();
3853 TheCall->setType(ReturnType);
3854
3855 break;
3856 }
3857 case Builtin::BI__builtin_hlsl_resource_load_with_status_typed: {
3858 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 4) ||
3859 CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0) ||
3860 CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 1),
3861 ExpectedType: SemaRef.getASTContext().UnsignedIntTy) ||
3862 CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 2),
3863 ExpectedType: SemaRef.getASTContext().UnsignedIntTy) ||
3864 CheckModifiableLValue(S: &SemaRef, TheCall, ArgIndex: 2))
3865 return true;
3866
3867 QualType ReturnType = TheCall->getArg(Arg: 3)->getType();
3868 assert(ReturnType->isPointerType() &&
3869 "expected pointer type for second argument");
3870 ReturnType = ReturnType->getPointeeType();
3871
3872 // Reject array types
3873 if (ReturnType->isArrayType())
3874 return SemaRef.Diag(
3875 Loc: cast<FunctionDecl>(Val: SemaRef.CurContext)->getPointOfInstantiation(),
3876 DiagID: diag::err_invalid_use_of_array_type);
3877
3878 TheCall->setType(ReturnType);
3879
3880 break;
3881 }
3882 case Builtin::BI__builtin_hlsl_resource_load_level:
3883 return CheckLoadLevelBuiltin(S&: SemaRef, TheCall);
3884 case Builtin::BI__builtin_hlsl_resource_sample:
3885 return CheckSamplingBuiltin(S&: SemaRef, TheCall, Kind: SampleKind::Sample);
3886 case Builtin::BI__builtin_hlsl_resource_sample_bias:
3887 return CheckSamplingBuiltin(S&: SemaRef, TheCall, Kind: SampleKind::Bias);
3888 case Builtin::BI__builtin_hlsl_resource_sample_grad:
3889 return CheckSamplingBuiltin(S&: SemaRef, TheCall, Kind: SampleKind::Grad);
3890 case Builtin::BI__builtin_hlsl_resource_sample_level:
3891 return CheckSamplingBuiltin(S&: SemaRef, TheCall, Kind: SampleKind::Level);
3892 case Builtin::BI__builtin_hlsl_resource_sample_cmp:
3893 return CheckSamplingBuiltin(S&: SemaRef, TheCall, Kind: SampleKind::Cmp);
3894 case Builtin::BI__builtin_hlsl_resource_sample_cmp_level_zero:
3895 return CheckSamplingBuiltin(S&: SemaRef, TheCall, Kind: SampleKind::CmpLevelZero);
3896 case Builtin::BI__builtin_hlsl_resource_calculate_lod:
3897 case Builtin::BI__builtin_hlsl_resource_calculate_lod_unclamped:
3898 return CheckCalculateLodBuiltin(S&: SemaRef, TheCall);
3899 case Builtin::BI__builtin_hlsl_resource_gather:
3900 return CheckGatherBuiltin(S&: SemaRef, TheCall, /*IsCmp=*/false);
3901 case Builtin::BI__builtin_hlsl_resource_gather_cmp:
3902 return CheckGatherBuiltin(S&: SemaRef, TheCall, /*IsCmp=*/true);
3903 case Builtin::BI__builtin_hlsl_resource_uninitializedhandle: {
3904 assert(TheCall->getNumArgs() == 1 && "expected 1 arg");
3905 // Update return type to be the attributed resource type from arg0.
3906 QualType ResourceTy = TheCall->getArg(Arg: 0)->getType();
3907 TheCall->setType(ResourceTy);
3908 break;
3909 }
3910 case Builtin::BI__builtin_hlsl_resource_handlefrombinding: {
3911 assert(TheCall->getNumArgs() == 6 && "expected 6 args");
3912 // Update return type to be the attributed resource type from arg0.
3913 QualType ResourceTy = TheCall->getArg(Arg: 0)->getType();
3914 TheCall->setType(ResourceTy);
3915 break;
3916 }
3917 case Builtin::BI__builtin_hlsl_resource_handlefromimplicitbinding: {
3918 assert(TheCall->getNumArgs() == 6 && "expected 6 args");
3919 // Update return type to be the attributed resource type from arg0.
3920 QualType ResourceTy = TheCall->getArg(Arg: 0)->getType();
3921 TheCall->setType(ResourceTy);
3922 break;
3923 }
3924 case Builtin::BI__builtin_hlsl_resource_counterhandlefromimplicitbinding: {
3925 assert(TheCall->getNumArgs() == 3 && "expected 3 args");
3926 ASTContext &AST = SemaRef.getASTContext();
3927 QualType MainHandleTy = TheCall->getArg(Arg: 0)->getType();
3928 auto *MainResType = MainHandleTy->getAs<HLSLAttributedResourceType>();
3929 auto MainAttrs = MainResType->getAttrs();
3930 assert(!MainAttrs.IsCounter && "cannot create a counter from a counter");
3931 MainAttrs.IsCounter = true;
3932 QualType CounterHandleTy = AST.getHLSLAttributedResourceType(
3933 Wrapped: MainResType->getWrappedType(), Contained: MainResType->getContainedType(),
3934 Attrs: MainAttrs);
3935 // Update return type to be the attributed resource type from arg0
3936 // with added IsCounter flag.
3937 TheCall->setType(CounterHandleTy);
3938 break;
3939 }
3940 case Builtin::BI__builtin_hlsl_and:
3941 case Builtin::BI__builtin_hlsl_or: {
3942 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
3943 return true;
3944 if (CheckScalarOrVectorOrMatrix(S: &SemaRef, TheCall, Scalar: getASTContext().BoolTy,
3945 ArgIndex: 0))
3946 return true;
3947 if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
3948 return true;
3949
3950 ExprResult A = TheCall->getArg(Arg: 0);
3951 QualType ArgTyA = A.get()->getType();
3952 // return type is the same as the input type
3953 TheCall->setType(ArgTyA);
3954 break;
3955 }
3956 case Builtin::BI__builtin_hlsl_all:
3957 case Builtin::BI__builtin_hlsl_any: {
3958 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
3959 return true;
3960 if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
3961 return true;
3962 break;
3963 }
3964 case Builtin::BI__builtin_hlsl_asdouble: {
3965 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
3966 return true;
3967 if (CheckScalarOrVector(
3968 S: &SemaRef, TheCall,
3969 /*only check for uint*/ Scalar: SemaRef.Context.UnsignedIntTy,
3970 /* arg index */ ArgIndex: 0))
3971 return true;
3972 if (CheckScalarOrVector(
3973 S: &SemaRef, TheCall,
3974 /*only check for uint*/ Scalar: SemaRef.Context.UnsignedIntTy,
3975 /* arg index */ ArgIndex: 1))
3976 return true;
3977 if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
3978 return true;
3979
3980 SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().DoubleTy);
3981 break;
3982 }
3983 case Builtin::BI__builtin_hlsl_elementwise_clamp: {
3984 if (SemaRef.BuiltinElementwiseTernaryMath(
3985 TheCall, /*ArgTyRestr=*/
3986 Sema::EltwiseBuiltinArgTyRestriction::None))
3987 return true;
3988 break;
3989 }
3990 case Builtin::BI__builtin_hlsl_dot: {
3991 // arg count is checked by BuiltinVectorToScalarMath
3992 if (SemaRef.BuiltinVectorToScalarMath(TheCall))
3993 return true;
3994 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall, Check: CheckNoDoubleVectors))
3995 return true;
3996 break;
3997 }
3998 case Builtin::BI__builtin_hlsl_elementwise_firstbithigh:
3999 case Builtin::BI__builtin_hlsl_elementwise_firstbitlow: {
4000 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
4001 return true;
4002
4003 const Expr *Arg = TheCall->getArg(Arg: 0);
4004 QualType ArgTy = Arg->getType();
4005 QualType EltTy = ArgTy;
4006
4007 QualType ResTy = SemaRef.Context.UnsignedIntTy;
4008
4009 if (auto *VecTy = EltTy->getAs<VectorType>()) {
4010 EltTy = VecTy->getElementType();
4011 ResTy = SemaRef.Context.getExtVectorType(VectorType: ResTy, NumElts: VecTy->getNumElements());
4012 }
4013
4014 if (!EltTy->isIntegerType()) {
4015 Diag(Loc: Arg->getBeginLoc(), DiagID: diag::err_builtin_invalid_arg_type)
4016 << 1 << /* scalar or vector of */ 5 << /* integer ty */ 1
4017 << /* no fp */ 0 << ArgTy;
4018 return true;
4019 }
4020
4021 TheCall->setType(ResTy);
4022 break;
4023 }
4024 case Builtin::BI__builtin_hlsl_select: {
4025 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3))
4026 return true;
4027 if (CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: getASTContext().BoolTy, ArgIndex: 0))
4028 return true;
4029 QualType ArgTy = TheCall->getArg(Arg: 0)->getType();
4030 if (ArgTy->isBooleanType() && CheckBoolSelect(S: &SemaRef, TheCall))
4031 return true;
4032 auto *VTy = ArgTy->getAs<VectorType>();
4033 if (VTy && VTy->getElementType()->isBooleanType() &&
4034 CheckVectorSelect(S: &SemaRef, TheCall))
4035 return true;
4036 break;
4037 }
4038 case Builtin::BI__builtin_hlsl_elementwise_saturate:
4039 case Builtin::BI__builtin_hlsl_elementwise_rcp: {
4040 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4041 return true;
4042 if (!TheCall->getArg(Arg: 0)
4043 ->getType()
4044 ->hasFloatingRepresentation()) // half or float or double
4045 return SemaRef.Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
4046 DiagID: diag::err_builtin_invalid_arg_type)
4047 << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
4048 << /* fp */ 1 << TheCall->getArg(Arg: 0)->getType();
4049 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
4050 return true;
4051 break;
4052 }
4053 case Builtin::BI__builtin_hlsl_elementwise_degrees:
4054 case Builtin::BI__builtin_hlsl_elementwise_radians:
4055 case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
4056 case Builtin::BI__builtin_hlsl_elementwise_frac:
4057 case Builtin::BI__builtin_hlsl_elementwise_ddx_coarse:
4058 case Builtin::BI__builtin_hlsl_elementwise_ddy_coarse:
4059 case Builtin::BI__builtin_hlsl_elementwise_ddx_fine:
4060 case Builtin::BI__builtin_hlsl_elementwise_ddy_fine: {
4061 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4062 return true;
4063 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4064 Check: CheckFloatOrHalfRepresentation))
4065 return true;
4066 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
4067 return true;
4068 break;
4069 }
4070 case Builtin::BI__builtin_hlsl_elementwise_isinf:
4071 case Builtin::BI__builtin_hlsl_elementwise_isnan: {
4072 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4073 return true;
4074 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4075 Check: CheckFloatOrHalfRepresentation))
4076 return true;
4077 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
4078 return true;
4079 SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().BoolTy);
4080 break;
4081 }
4082 case Builtin::BI__builtin_hlsl_lerp: {
4083 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3))
4084 return true;
4085 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4086 Check: CheckFloatOrHalfRepresentation))
4087 return true;
4088 if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
4089 return true;
4090 if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
4091 return true;
4092 break;
4093 }
4094 case Builtin::BI__builtin_hlsl_mad: {
4095 if (SemaRef.BuiltinElementwiseTernaryMath(
4096 TheCall, /*ArgTyRestr=*/
4097 Sema::EltwiseBuiltinArgTyRestriction::None))
4098 return true;
4099 break;
4100 }
4101 case Builtin::BI__builtin_hlsl_mul: {
4102 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
4103 return true;
4104
4105 Expr *Arg0 = TheCall->getArg(Arg: 0);
4106 Expr *Arg1 = TheCall->getArg(Arg: 1);
4107 QualType Ty0 = Arg0->getType();
4108 QualType Ty1 = Arg1->getType();
4109
4110 auto getElemType = [](QualType T) -> QualType {
4111 if (const auto *VTy = T->getAs<VectorType>())
4112 return VTy->getElementType();
4113 if (const auto *MTy = T->getAs<ConstantMatrixType>())
4114 return MTy->getElementType();
4115 return T;
4116 };
4117
4118 QualType EltTy0 = getElemType(Ty0);
4119
4120 bool IsVec0 = Ty0->isVectorType();
4121 bool IsMat0 = Ty0->isConstantMatrixType();
4122 bool IsVec1 = Ty1->isVectorType();
4123 bool IsMat1 = Ty1->isConstantMatrixType();
4124
4125 QualType RetTy;
4126
4127 if (IsVec0 && IsMat1) {
4128 auto *MatTy = Ty1->castAs<ConstantMatrixType>();
4129 RetTy = getASTContext().getExtVectorType(VectorType: EltTy0, NumElts: MatTy->getNumColumns());
4130 } else if (IsMat0 && IsVec1) {
4131 auto *MatTy = Ty0->castAs<ConstantMatrixType>();
4132 RetTy = getASTContext().getExtVectorType(VectorType: EltTy0, NumElts: MatTy->getNumRows());
4133 } else {
4134 assert(IsMat0 && IsMat1);
4135 auto *MatTy0 = Ty0->castAs<ConstantMatrixType>();
4136 auto *MatTy1 = Ty1->castAs<ConstantMatrixType>();
4137 RetTy = getASTContext().getConstantMatrixType(
4138 ElementType: EltTy0, NumRows: MatTy0->getNumRows(), NumColumns: MatTy1->getNumColumns());
4139 }
4140
4141 TheCall->setType(RetTy);
4142 break;
4143 }
4144 case Builtin::BI__builtin_hlsl_normalize: {
4145 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4146 return true;
4147 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4148 Check: CheckFloatOrHalfRepresentation))
4149 return true;
4150 ExprResult A = TheCall->getArg(Arg: 0);
4151 QualType ArgTyA = A.get()->getType();
4152 // return type is the same as the input type
4153 TheCall->setType(ArgTyA);
4154 break;
4155 }
4156 case Builtin::BI__builtin_elementwise_fma: {
4157 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3) ||
4158 CheckAllArgsHaveSameType(S: &SemaRef, TheCall)) {
4159 return true;
4160 }
4161
4162 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4163 Check: CheckAnyDoubleRepresentation))
4164 return true;
4165
4166 ExprResult A = TheCall->getArg(Arg: 0);
4167 QualType ArgTyA = A.get()->getType();
4168 // return type is the same as input type
4169 TheCall->setType(ArgTyA);
4170 break;
4171 }
4172 case Builtin::BI__builtin_hlsl_transpose: {
4173 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4174 return true;
4175
4176 Expr *Arg = TheCall->getArg(Arg: 0);
4177 QualType ArgTy = Arg->getType();
4178
4179 const auto *MatTy = ArgTy->getAs<ConstantMatrixType>();
4180 if (!MatTy) {
4181 SemaRef.Diag(Loc: Arg->getBeginLoc(), DiagID: diag::err_builtin_invalid_arg_type)
4182 << 1 << /* matrix */ 3 << /* no int */ 0 << /* no fp */ 0 << ArgTy;
4183 return true;
4184 }
4185
4186 QualType RetTy = getASTContext().getConstantMatrixType(
4187 ElementType: MatTy->getElementType(), NumRows: MatTy->getNumColumns(), NumColumns: MatTy->getNumRows());
4188 TheCall->setType(RetTy);
4189 break;
4190 }
4191 case Builtin::BI__builtin_hlsl_elementwise_sign: {
4192 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
4193 return true;
4194 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4195 Check: CheckFloatingOrIntRepresentation))
4196 return true;
4197 SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().IntTy);
4198 break;
4199 }
4200 case Builtin::BI__builtin_hlsl_step: {
4201 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
4202 return true;
4203 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4204 Check: CheckFloatOrHalfRepresentation))
4205 return true;
4206
4207 ExprResult A = TheCall->getArg(Arg: 0);
4208 QualType ArgTyA = A.get()->getType();
4209 // return type is the same as the input type
4210 TheCall->setType(ArgTyA);
4211 break;
4212 }
4213 case Builtin::BI__builtin_hlsl_wave_active_all_equal: {
4214 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4215 return true;
4216
4217 // Ensure input expr type is a scalar/vector
4218 if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
4219 return true;
4220
4221 QualType InputTy = TheCall->getArg(Arg: 0)->getType();
4222 ASTContext &Ctx = getASTContext();
4223
4224 QualType RetTy;
4225
4226 // If vector, construct bool vector of same size
4227 if (const auto *VecTy = InputTy->getAs<ExtVectorType>()) {
4228 unsigned NumElts = VecTy->getNumElements();
4229 RetTy = Ctx.getExtVectorType(VectorType: Ctx.BoolTy, NumElts);
4230 } else {
4231 // Scalar case
4232 RetTy = Ctx.BoolTy;
4233 }
4234
4235 TheCall->setType(RetTy);
4236 break;
4237 }
4238 case Builtin::BI__builtin_hlsl_wave_active_max:
4239 case Builtin::BI__builtin_hlsl_wave_active_min:
4240 case Builtin::BI__builtin_hlsl_wave_active_sum:
4241 case Builtin::BI__builtin_hlsl_wave_active_product: {
4242 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4243 return true;
4244
4245 // Ensure input expr type is a scalar/vector and the same as the return type
4246 if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
4247 return true;
4248 if (CheckWaveActive(S: &SemaRef, TheCall))
4249 return true;
4250 ExprResult Expr = TheCall->getArg(Arg: 0);
4251 QualType ArgTyExpr = Expr.get()->getType();
4252 TheCall->setType(ArgTyExpr);
4253 break;
4254 }
4255 case Builtin::BI__builtin_hlsl_wave_active_bit_or:
4256 case Builtin::BI__builtin_hlsl_wave_active_bit_xor:
4257 case Builtin::BI__builtin_hlsl_wave_active_bit_and: {
4258 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4259 return true;
4260
4261 // Ensure input expr type is a scalar/vector
4262 if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
4263 return true;
4264
4265 if (CheckWaveActive(S: &SemaRef, TheCall))
4266 return true;
4267
4268 // Ensure the expr type is interpretable as a uint or vector<uint>
4269 ExprResult Expr = TheCall->getArg(Arg: 0);
4270 QualType ArgTyExpr = Expr.get()->getType();
4271 auto *VTy = ArgTyExpr->getAs<VectorType>();
4272 if (!(ArgTyExpr->isIntegerType() ||
4273 (VTy && VTy->getElementType()->isIntegerType()))) {
4274 SemaRef.Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
4275 DiagID: diag::err_builtin_invalid_arg_type)
4276 << ArgTyExpr << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
4277 return true;
4278 }
4279
4280 // Ensure input expr type is the same as the return type
4281 TheCall->setType(ArgTyExpr);
4282 break;
4283 }
4284 // Note these are llvm builtins that we want to catch invalid intrinsic
4285 // generation. Normal handling of these builtins will occur elsewhere.
4286 case Builtin::BI__builtin_elementwise_bitreverse: {
4287 // does not include a check for number of arguments
4288 // because that is done previously
4289 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4290 Check: CheckUnsignedIntRepresentation))
4291 return true;
4292 break;
4293 }
4294 case Builtin::BI__builtin_hlsl_wave_prefix_count_bits: {
4295 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4296 return true;
4297
4298 QualType ArgType = TheCall->getArg(Arg: 0)->getType();
4299
4300 if (!(ArgType->isScalarType())) {
4301 SemaRef.Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
4302 DiagID: diag::err_typecheck_expect_any_scalar_or_vector)
4303 << ArgType << 0;
4304 return true;
4305 }
4306
4307 if (!(ArgType->isBooleanType())) {
4308 SemaRef.Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
4309 DiagID: diag::err_typecheck_expect_any_scalar_or_vector)
4310 << ArgType << 0;
4311 return true;
4312 }
4313
4314 break;
4315 }
4316 case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
4317 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
4318 return true;
4319
4320 // Ensure index parameter type can be interpreted as a uint
4321 ExprResult Index = TheCall->getArg(Arg: 1);
4322 QualType ArgTyIndex = Index.get()->getType();
4323 if (!ArgTyIndex->isIntegerType()) {
4324 SemaRef.Diag(Loc: TheCall->getArg(Arg: 1)->getBeginLoc(),
4325 DiagID: diag::err_typecheck_convert_incompatible)
4326 << ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
4327 return true;
4328 }
4329
4330 // Ensure input expr type is a scalar/vector and the same as the return type
4331 if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
4332 return true;
4333
4334 ExprResult Expr = TheCall->getArg(Arg: 0);
4335 QualType ArgTyExpr = Expr.get()->getType();
4336 TheCall->setType(ArgTyExpr);
4337 break;
4338 }
4339 case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
4340 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 0))
4341 return true;
4342 break;
4343 }
4344 case Builtin::BI__builtin_hlsl_wave_prefix_sum:
4345 case Builtin::BI__builtin_hlsl_wave_prefix_product: {
4346 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4347 return true;
4348
4349 // Ensure input expr type is a scalar/vector and the same as the return type
4350 if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
4351 return true;
4352 if (CheckWavePrefix(S: &SemaRef, TheCall))
4353 return true;
4354 ExprResult Expr = TheCall->getArg(Arg: 0);
4355 QualType ArgTyExpr = Expr.get()->getType();
4356 TheCall->setType(ArgTyExpr);
4357 break;
4358 }
4359 case Builtin::BI__builtin_hlsl_quad_read_across_x:
4360 case Builtin::BI__builtin_hlsl_quad_read_across_y: {
4361 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4362 return true;
4363
4364 if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
4365 return true;
4366 if (CheckNotBoolScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
4367 return true;
4368 ExprResult Expr = TheCall->getArg(Arg: 0);
4369 QualType ArgTyExpr = Expr.get()->getType();
4370 TheCall->setType(ArgTyExpr);
4371 break;
4372 }
4373 case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
4374 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3))
4375 return true;
4376
4377 if (CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: SemaRef.Context.DoubleTy, ArgIndex: 0) ||
4378 CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: SemaRef.Context.UnsignedIntTy,
4379 ArgIndex: 1) ||
4380 CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: SemaRef.Context.UnsignedIntTy,
4381 ArgIndex: 2))
4382 return true;
4383
4384 if (CheckModifiableLValue(S: &SemaRef, TheCall, ArgIndex: 1) ||
4385 CheckModifiableLValue(S: &SemaRef, TheCall, ArgIndex: 2))
4386 return true;
4387 break;
4388 }
4389 case Builtin::BI__builtin_hlsl_elementwise_clip: {
4390 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4391 return true;
4392
4393 if (CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: SemaRef.Context.FloatTy, ArgIndex: 0))
4394 return true;
4395 break;
4396 }
4397 case Builtin::BI__builtin_elementwise_acos:
4398 case Builtin::BI__builtin_elementwise_asin:
4399 case Builtin::BI__builtin_elementwise_atan:
4400 case Builtin::BI__builtin_elementwise_atan2:
4401 case Builtin::BI__builtin_elementwise_ceil:
4402 case Builtin::BI__builtin_elementwise_cos:
4403 case Builtin::BI__builtin_elementwise_cosh:
4404 case Builtin::BI__builtin_elementwise_exp:
4405 case Builtin::BI__builtin_elementwise_exp2:
4406 case Builtin::BI__builtin_elementwise_exp10:
4407 case Builtin::BI__builtin_elementwise_floor:
4408 case Builtin::BI__builtin_elementwise_fmod:
4409 case Builtin::BI__builtin_elementwise_log:
4410 case Builtin::BI__builtin_elementwise_log2:
4411 case Builtin::BI__builtin_elementwise_log10:
4412 case Builtin::BI__builtin_elementwise_pow:
4413 case Builtin::BI__builtin_elementwise_roundeven:
4414 case Builtin::BI__builtin_elementwise_sin:
4415 case Builtin::BI__builtin_elementwise_sinh:
4416 case Builtin::BI__builtin_elementwise_sqrt:
4417 case Builtin::BI__builtin_elementwise_tan:
4418 case Builtin::BI__builtin_elementwise_tanh:
4419 case Builtin::BI__builtin_elementwise_trunc: {
4420 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4421 Check: CheckFloatOrHalfRepresentation))
4422 return true;
4423 break;
4424 }
4425 case Builtin::BI__builtin_hlsl_buffer_update_counter: {
4426 assert(TheCall->getNumArgs() == 2 && "expected 2 args");
4427 auto checkResTy = [](const HLSLAttributedResourceType *ResTy) -> bool {
4428 return !(ResTy->getAttrs().ResourceClass == ResourceClass::UAV &&
4429 ResTy->getAttrs().RawBuffer && ResTy->hasContainedType());
4430 };
4431 if (CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0, Check: checkResTy))
4432 return true;
4433 Expr *OffsetExpr = TheCall->getArg(Arg: 1);
4434 std::optional<llvm::APSInt> Offset =
4435 OffsetExpr->getIntegerConstantExpr(Ctx: SemaRef.getASTContext());
4436 if (!Offset.has_value() || std::abs(i: Offset->getExtValue()) != 1) {
4437 SemaRef.Diag(Loc: TheCall->getArg(Arg: 1)->getBeginLoc(),
4438 DiagID: diag::err_hlsl_expect_arg_const_int_one_or_neg_one)
4439 << 1;
4440 return true;
4441 }
4442 break;
4443 }
4444 case Builtin::BI__builtin_hlsl_elementwise_f16tof32: {
4445 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4446 return true;
4447 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4448 Check: CheckUnsignedIntRepresentation))
4449 return true;
4450 // ensure arg integers are 32 bits
4451 if (CheckExpectedBitWidth(S: &SemaRef, TheCall, ArgOrdinal: 0, Width: 32))
4452 return true;
4453 // check it wasn't a bool type
4454 QualType ArgTy = TheCall->getArg(Arg: 0)->getType();
4455 if (auto *VTy = ArgTy->getAs<VectorType>())
4456 ArgTy = VTy->getElementType();
4457 if (ArgTy->isBooleanType()) {
4458 SemaRef.Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
4459 DiagID: diag::err_builtin_invalid_arg_type)
4460 << 1 << /* scalar or vector of */ 5 << /* unsigned int */ 3
4461 << /* no fp */ 0 << TheCall->getArg(Arg: 0)->getType();
4462 return true;
4463 }
4464
4465 SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().FloatTy);
4466 break;
4467 }
4468 case Builtin::BI__builtin_hlsl_elementwise_f32tof16: {
4469 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4470 return true;
4471 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall, Check: CheckFloatRepresentation))
4472 return true;
4473 SetElementTypeAsReturnType(S: &SemaRef, TheCall,
4474 ReturnType: getASTContext().UnsignedIntTy);
4475 break;
4476 }
4477 }
4478 return false;
4479}
4480
4481static void BuildFlattenedTypeList(QualType BaseTy,
4482 llvm::SmallVectorImpl<QualType> &List) {
4483 llvm::SmallVector<QualType, 16> WorkList;
4484 WorkList.push_back(Elt: BaseTy);
4485 while (!WorkList.empty()) {
4486 QualType T = WorkList.pop_back_val();
4487 T = T.getCanonicalType().getUnqualifiedType();
4488 if (const auto *AT = dyn_cast<ConstantArrayType>(Val&: T)) {
4489 llvm::SmallVector<QualType, 16> ElementFields;
4490 // Generally I've avoided recursion in this algorithm, but arrays of
4491 // structs could be time-consuming to flatten and churn through on the
4492 // work list. Hopefully nesting arrays of structs containing arrays
4493 // of structs too many levels deep is unlikely.
4494 BuildFlattenedTypeList(BaseTy: AT->getElementType(), List&: ElementFields);
4495 // Repeat the element's field list n times.
4496 for (uint64_t Ct = 0; Ct < AT->getZExtSize(); ++Ct)
4497 llvm::append_range(C&: List, R&: ElementFields);
4498 continue;
4499 }
4500 // Vectors can only have element types that are builtin types, so this can
4501 // add directly to the list instead of to the WorkList.
4502 if (const auto *VT = dyn_cast<VectorType>(Val&: T)) {
4503 List.insert(I: List.end(), NumToInsert: VT->getNumElements(), Elt: VT->getElementType());
4504 continue;
4505 }
4506 if (const auto *MT = dyn_cast<ConstantMatrixType>(Val&: T)) {
4507 List.insert(I: List.end(), NumToInsert: MT->getNumElementsFlattened(),
4508 Elt: MT->getElementType());
4509 continue;
4510 }
4511 if (const auto *RD = T->getAsCXXRecordDecl()) {
4512 if (RD->isStandardLayout())
4513 RD = RD->getStandardLayoutBaseWithFields();
4514
4515 // For types that we shouldn't decompose (unions and non-aggregates), just
4516 // add the type itself to the list.
4517 if (RD->isUnion() || !RD->isAggregate()) {
4518 List.push_back(Elt: T);
4519 continue;
4520 }
4521
4522 llvm::SmallVector<QualType, 16> FieldTypes;
4523 for (const auto *FD : RD->fields())
4524 if (!FD->isUnnamedBitField())
4525 FieldTypes.push_back(Elt: FD->getType());
4526 // Reverse the newly added sub-range.
4527 std::reverse(first: FieldTypes.begin(), last: FieldTypes.end());
4528 llvm::append_range(C&: WorkList, R&: FieldTypes);
4529
4530 // If this wasn't a standard layout type we may also have some base
4531 // classes to deal with.
4532 if (!RD->isStandardLayout()) {
4533 FieldTypes.clear();
4534 for (const auto &Base : RD->bases())
4535 FieldTypes.push_back(Elt: Base.getType());
4536 std::reverse(first: FieldTypes.begin(), last: FieldTypes.end());
4537 llvm::append_range(C&: WorkList, R&: FieldTypes);
4538 }
4539 continue;
4540 }
4541 List.push_back(Elt: T);
4542 }
4543}
4544
4545bool SemaHLSL::IsTypedResourceElementCompatible(clang::QualType QT) {
4546 // null and array types are not allowed.
4547 if (QT.isNull() || QT->isArrayType())
4548 return false;
4549
4550 // UDT types are not allowed
4551 if (QT->isRecordType())
4552 return false;
4553
4554 if (QT->isBooleanType() || QT->isEnumeralType())
4555 return false;
4556
4557 // the only other valid builtin types are scalars or vectors
4558 if (QT->isArithmeticType()) {
4559 if (SemaRef.Context.getTypeSize(T: QT) / 8 > 16)
4560 return false;
4561 return true;
4562 }
4563
4564 if (const VectorType *VT = QT->getAs<VectorType>()) {
4565 int ArraySize = VT->getNumElements();
4566
4567 if (ArraySize > 4)
4568 return false;
4569
4570 QualType ElTy = VT->getElementType();
4571 if (ElTy->isBooleanType())
4572 return false;
4573
4574 if (SemaRef.Context.getTypeSize(T: QT) / 8 > 16)
4575 return false;
4576 return true;
4577 }
4578
4579 return false;
4580}
4581
4582bool SemaHLSL::IsScalarizedLayoutCompatible(QualType T1, QualType T2) const {
4583 if (T1.isNull() || T2.isNull())
4584 return false;
4585
4586 T1 = T1.getCanonicalType().getUnqualifiedType();
4587 T2 = T2.getCanonicalType().getUnqualifiedType();
4588
4589 // If both types are the same canonical type, they're obviously compatible.
4590 if (SemaRef.getASTContext().hasSameType(T1, T2))
4591 return true;
4592
4593 llvm::SmallVector<QualType, 16> T1Types;
4594 BuildFlattenedTypeList(BaseTy: T1, List&: T1Types);
4595 llvm::SmallVector<QualType, 16> T2Types;
4596 BuildFlattenedTypeList(BaseTy: T2, List&: T2Types);
4597
4598 // Check the flattened type list
4599 return llvm::equal(LRange&: T1Types, RRange&: T2Types,
4600 P: [this](QualType LHS, QualType RHS) -> bool {
4601 return SemaRef.IsLayoutCompatible(T1: LHS, T2: RHS);
4602 });
4603}
4604
4605bool SemaHLSL::CheckCompatibleParameterABI(FunctionDecl *New,
4606 FunctionDecl *Old) {
4607 if (New->getNumParams() != Old->getNumParams())
4608 return true;
4609
4610 bool HadError = false;
4611
4612 for (unsigned i = 0, e = New->getNumParams(); i != e; ++i) {
4613 ParmVarDecl *NewParam = New->getParamDecl(i);
4614 ParmVarDecl *OldParam = Old->getParamDecl(i);
4615
4616 // HLSL parameter declarations for inout and out must match between
4617 // declarations. In HLSL inout and out are ambiguous at the call site,
4618 // but have different calling behavior, so you cannot overload a
4619 // method based on a difference between inout and out annotations.
4620 const auto *NDAttr = NewParam->getAttr<HLSLParamModifierAttr>();
4621 unsigned NSpellingIdx = (NDAttr ? NDAttr->getSpellingListIndex() : 0);
4622 const auto *ODAttr = OldParam->getAttr<HLSLParamModifierAttr>();
4623 unsigned OSpellingIdx = (ODAttr ? ODAttr->getSpellingListIndex() : 0);
4624
4625 if (NSpellingIdx != OSpellingIdx) {
4626 SemaRef.Diag(Loc: NewParam->getLocation(),
4627 DiagID: diag::err_hlsl_param_qualifier_mismatch)
4628 << NDAttr << NewParam;
4629 SemaRef.Diag(Loc: OldParam->getLocation(), DiagID: diag::note_previous_declaration_as)
4630 << ODAttr;
4631 HadError = true;
4632 }
4633 }
4634 return HadError;
4635}
4636
4637// Generally follows PerformScalarCast, with cases reordered for
4638// clarity of what types are supported
4639bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) {
4640
4641 if (!SrcTy->isScalarType() || !DestTy->isScalarType())
4642 return false;
4643
4644 if (SemaRef.getASTContext().hasSameUnqualifiedType(T1: SrcTy, T2: DestTy))
4645 return true;
4646
4647 switch (SrcTy->getScalarTypeKind()) {
4648 case Type::STK_Bool: // casting from bool is like casting from an integer
4649 case Type::STK_Integral:
4650 switch (DestTy->getScalarTypeKind()) {
4651 case Type::STK_Bool:
4652 case Type::STK_Integral:
4653 case Type::STK_Floating:
4654 return true;
4655 case Type::STK_CPointer:
4656 case Type::STK_ObjCObjectPointer:
4657 case Type::STK_BlockPointer:
4658 case Type::STK_MemberPointer:
4659 llvm_unreachable("HLSL doesn't support pointers.");
4660 case Type::STK_IntegralComplex:
4661 case Type::STK_FloatingComplex:
4662 llvm_unreachable("HLSL doesn't support complex types.");
4663 case Type::STK_FixedPoint:
4664 llvm_unreachable("HLSL doesn't support fixed point types.");
4665 }
4666 llvm_unreachable("Should have returned before this");
4667
4668 case Type::STK_Floating:
4669 switch (DestTy->getScalarTypeKind()) {
4670 case Type::STK_Floating:
4671 case Type::STK_Bool:
4672 case Type::STK_Integral:
4673 return true;
4674 case Type::STK_FloatingComplex:
4675 case Type::STK_IntegralComplex:
4676 llvm_unreachable("HLSL doesn't support complex types.");
4677 case Type::STK_FixedPoint:
4678 llvm_unreachable("HLSL doesn't support fixed point types.");
4679 case Type::STK_CPointer:
4680 case Type::STK_ObjCObjectPointer:
4681 case Type::STK_BlockPointer:
4682 case Type::STK_MemberPointer:
4683 llvm_unreachable("HLSL doesn't support pointers.");
4684 }
4685 llvm_unreachable("Should have returned before this");
4686
4687 case Type::STK_MemberPointer:
4688 case Type::STK_CPointer:
4689 case Type::STK_BlockPointer:
4690 case Type::STK_ObjCObjectPointer:
4691 llvm_unreachable("HLSL doesn't support pointers.");
4692
4693 case Type::STK_FixedPoint:
4694 llvm_unreachable("HLSL doesn't support fixed point types.");
4695
4696 case Type::STK_FloatingComplex:
4697 case Type::STK_IntegralComplex:
4698 llvm_unreachable("HLSL doesn't support complex types.");
4699 }
4700
4701 llvm_unreachable("Unhandled scalar cast");
4702}
4703
4704// Can perform an HLSL Aggregate splat cast if the Dest is an aggregate and the
4705// Src is a scalar, a vector of length 1, or a 1x1 matrix
4706// Or if Dest is a vector and Src is a vector of length 1 or a 1x1 matrix
4707bool SemaHLSL::CanPerformAggregateSplatCast(Expr *Src, QualType DestTy) {
4708
4709 QualType SrcTy = Src->getType();
4710 // Not a valid HLSL Aggregate Splat cast if Dest is a scalar or if this is
4711 // going to be a vector splat from a scalar.
4712 if ((SrcTy->isScalarType() && DestTy->isVectorType()) ||
4713 DestTy->isScalarType())
4714 return false;
4715
4716 const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
4717 const ConstantMatrixType *SrcMatTy = SrcTy->getAs<ConstantMatrixType>();
4718
4719 // Src isn't a scalar, a vector of length 1, or a 1x1 matrix
4720 if (!SrcTy->isScalarType() &&
4721 !(SrcVecTy && SrcVecTy->getNumElements() == 1) &&
4722 !(SrcMatTy && SrcMatTy->getNumElementsFlattened() == 1))
4723 return false;
4724
4725 if (SrcVecTy)
4726 SrcTy = SrcVecTy->getElementType();
4727 else if (SrcMatTy)
4728 SrcTy = SrcMatTy->getElementType();
4729
4730 llvm::SmallVector<QualType> DestTypes;
4731 BuildFlattenedTypeList(BaseTy: DestTy, List&: DestTypes);
4732
4733 for (unsigned I = 0, Size = DestTypes.size(); I < Size; ++I) {
4734 if (DestTypes[I]->isUnionType())
4735 return false;
4736 if (!CanPerformScalarCast(SrcTy, DestTy: DestTypes[I]))
4737 return false;
4738 }
4739 return true;
4740}
4741
4742// Can we perform an HLSL Elementwise cast?
4743bool SemaHLSL::CanPerformElementwiseCast(Expr *Src, QualType DestTy) {
4744
4745 // Don't handle casts where LHS and RHS are any combination of scalar/vector
4746 // There must be an aggregate somewhere
4747 QualType SrcTy = Src->getType();
4748 if (SrcTy->isScalarType()) // always a splat and this cast doesn't handle that
4749 return false;
4750
4751 if (SrcTy->isVectorType() &&
4752 (DestTy->isScalarType() || DestTy->isVectorType()))
4753 return false;
4754
4755 if (SrcTy->isConstantMatrixType() &&
4756 (DestTy->isScalarType() || DestTy->isConstantMatrixType()))
4757 return false;
4758
4759 llvm::SmallVector<QualType> DestTypes;
4760 BuildFlattenedTypeList(BaseTy: DestTy, List&: DestTypes);
4761 llvm::SmallVector<QualType> SrcTypes;
4762 BuildFlattenedTypeList(BaseTy: SrcTy, List&: SrcTypes);
4763
4764 // Usually the size of SrcTypes must be greater than or equal to the size of
4765 // DestTypes.
4766 if (SrcTypes.size() < DestTypes.size())
4767 return false;
4768
4769 unsigned SrcSize = SrcTypes.size();
4770 unsigned DstSize = DestTypes.size();
4771 unsigned I;
4772 for (I = 0; I < DstSize && I < SrcSize; I++) {
4773 if (SrcTypes[I]->isUnionType() || DestTypes[I]->isUnionType())
4774 return false;
4775 if (!CanPerformScalarCast(SrcTy: SrcTypes[I], DestTy: DestTypes[I])) {
4776 return false;
4777 }
4778 }
4779
4780 // check the rest of the source type for unions.
4781 for (; I < SrcSize; I++) {
4782 if (SrcTypes[I]->isUnionType())
4783 return false;
4784 }
4785 return true;
4786}
4787
4788ExprResult SemaHLSL::ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg) {
4789 assert(Param->hasAttr<HLSLParamModifierAttr>() &&
4790 "We should not get here without a parameter modifier expression");
4791 const auto *Attr = Param->getAttr<HLSLParamModifierAttr>();
4792 if (Attr->getABI() == ParameterABI::Ordinary)
4793 return ExprResult(Arg);
4794
4795 bool IsInOut = Attr->getABI() == ParameterABI::HLSLInOut;
4796 if (!Arg->isLValue()) {
4797 SemaRef.Diag(Loc: Arg->getBeginLoc(), DiagID: diag::error_hlsl_inout_lvalue)
4798 << Arg << (IsInOut ? 1 : 0);
4799 return ExprError();
4800 }
4801
4802 ASTContext &Ctx = SemaRef.getASTContext();
4803
4804 QualType Ty = Param->getType().getNonLValueExprType(Context: Ctx);
4805
4806 // HLSL allows implicit conversions from scalars to vectors, but not the
4807 // inverse, so we need to disallow `inout` with scalar->vector or
4808 // scalar->matrix conversions.
4809 if (Arg->getType()->isScalarType() != Ty->isScalarType()) {
4810 SemaRef.Diag(Loc: Arg->getBeginLoc(), DiagID: diag::error_hlsl_inout_scalar_extension)
4811 << Arg << (IsInOut ? 1 : 0);
4812 return ExprError();
4813 }
4814
4815 auto *ArgOpV = new (Ctx) OpaqueValueExpr(Param->getBeginLoc(), Arg->getType(),
4816 VK_LValue, OK_Ordinary, Arg);
4817
4818 // Parameters are initialized via copy initialization. This allows for
4819 // overload resolution of argument constructors.
4820 InitializedEntity Entity =
4821 InitializedEntity::InitializeParameter(Context&: Ctx, Type: Ty, Consumed: false);
4822 ExprResult Res =
4823 SemaRef.PerformCopyInitialization(Entity, EqualLoc: Param->getBeginLoc(), Init: ArgOpV);
4824 if (Res.isInvalid())
4825 return ExprError();
4826 Expr *Base = Res.get();
4827 // After the cast, drop the reference type when creating the exprs.
4828 Ty = Ty.getNonLValueExprType(Context: Ctx);
4829 auto *OpV = new (Ctx)
4830 OpaqueValueExpr(Param->getBeginLoc(), Ty, VK_LValue, OK_Ordinary, Base);
4831
4832 // Writebacks are performed with `=` binary operator, which allows for
4833 // overload resolution on writeback result expressions.
4834 Res = SemaRef.ActOnBinOp(S: SemaRef.getCurScope(), TokLoc: Param->getBeginLoc(),
4835 Kind: tok::equal, LHSExpr: ArgOpV, RHSExpr: OpV);
4836
4837 if (Res.isInvalid())
4838 return ExprError();
4839 Expr *Writeback = Res.get();
4840 auto *OutExpr =
4841 HLSLOutArgExpr::Create(C: Ctx, Ty, Base: ArgOpV, OpV, WB: Writeback, IsInOut);
4842
4843 return ExprResult(OutExpr);
4844}
4845
4846QualType SemaHLSL::getInoutParameterType(QualType Ty) {
4847 // If HLSL gains support for references, all the cites that use this will need
4848 // to be updated with semantic checking to produce errors for
4849 // pointers/references.
4850 assert(!Ty->isReferenceType() &&
4851 "Pointer and reference types cannot be inout or out parameters");
4852 Ty = SemaRef.getASTContext().getLValueReferenceType(T: Ty);
4853 Ty.addRestrict();
4854 return Ty;
4855}
4856
4857// Returns true if the type has a non-empty constant buffer layout (if it is
4858// scalar, vector or matrix, or if it contains any of these.
4859static bool hasConstantBufferLayout(QualType QT) {
4860 const Type *Ty = QT->getUnqualifiedDesugaredType();
4861 if (Ty->isScalarType() || Ty->isVectorType() || Ty->isMatrixType())
4862 return true;
4863
4864 if (Ty->isHLSLResourceRecord() || Ty->isHLSLResourceRecordArray())
4865 return false;
4866
4867 if (const auto *RD = Ty->getAsCXXRecordDecl()) {
4868 for (const auto *FD : RD->fields()) {
4869 if (hasConstantBufferLayout(QT: FD->getType()))
4870 return true;
4871 }
4872 assert(RD->getNumBases() <= 1 &&
4873 "HLSL doesn't support multiple inheritance");
4874 return RD->getNumBases()
4875 ? hasConstantBufferLayout(QT: RD->bases_begin()->getType())
4876 : false;
4877 }
4878
4879 if (const auto *AT = dyn_cast<ArrayType>(Val: Ty)) {
4880 if (const auto *CAT = dyn_cast<ConstantArrayType>(Val: AT))
4881 if (isZeroSizedArray(CAT))
4882 return false;
4883 return hasConstantBufferLayout(QT: AT->getElementType());
4884 }
4885
4886 return false;
4887}
4888
4889static bool IsDefaultBufferConstantDecl(const ASTContext &Ctx, VarDecl *VD) {
4890 bool IsVulkan =
4891 Ctx.getTargetInfo().getTriple().getOS() == llvm::Triple::Vulkan;
4892 bool IsVKPushConstant = IsVulkan && VD->hasAttr<HLSLVkPushConstantAttr>();
4893 QualType QT = VD->getType();
4894 return VD->getDeclContext()->isTranslationUnit() &&
4895 QT.getAddressSpace() == LangAS::Default &&
4896 VD->getStorageClass() != SC_Static &&
4897 !VD->hasAttr<HLSLVkConstantIdAttr>() && !IsVKPushConstant &&
4898 hasConstantBufferLayout(QT);
4899}
4900
4901void SemaHLSL::deduceAddressSpace(VarDecl *Decl) {
4902 // The variable already has an address space (groupshared for ex).
4903 if (Decl->getType().hasAddressSpace())
4904 return;
4905
4906 if (Decl->getType()->isDependentType())
4907 return;
4908
4909 QualType Type = Decl->getType();
4910
4911 if (Decl->hasAttr<HLSLVkExtBuiltinInputAttr>()) {
4912 LangAS ImplAS = LangAS::hlsl_input;
4913 Type = SemaRef.getASTContext().getAddrSpaceQualType(T: Type, AddressSpace: ImplAS);
4914 Decl->setType(Type);
4915 return;
4916 }
4917
4918 if (Decl->hasAttr<HLSLVkExtBuiltinOutputAttr>()) {
4919 LangAS ImplAS = LangAS::hlsl_output;
4920 Type = SemaRef.getASTContext().getAddrSpaceQualType(T: Type, AddressSpace: ImplAS);
4921 Decl->setType(Type);
4922
4923 // HLSL uses `static` differently than C++. For BuiltIn output, the static
4924 // does not imply private to the module scope.
4925 // Marking it as external to reflect the semantic this attribute brings.
4926 // See https://github.com/microsoft/hlsl-specs/issues/350
4927 Decl->setStorageClass(SC_Extern);
4928 return;
4929 }
4930
4931 bool IsVulkan = getASTContext().getTargetInfo().getTriple().getOS() ==
4932 llvm::Triple::Vulkan;
4933 if (IsVulkan && Decl->hasAttr<HLSLVkPushConstantAttr>()) {
4934 if (HasDeclaredAPushConstant)
4935 SemaRef.Diag(Loc: Decl->getLocation(), DiagID: diag::err_hlsl_push_constant_unique);
4936
4937 LangAS ImplAS = LangAS::hlsl_push_constant;
4938 Type = SemaRef.getASTContext().getAddrSpaceQualType(T: Type, AddressSpace: ImplAS);
4939 Decl->setType(Type);
4940 HasDeclaredAPushConstant = true;
4941 return;
4942 }
4943
4944 if (Type->isSamplerT() || Type->isVoidType())
4945 return;
4946
4947 // Resource handles.
4948 if (Type->isHLSLResourceRecord() || Type->isHLSLResourceRecordArray())
4949 return;
4950
4951 // Only static globals belong to the Private address space.
4952 // Non-static globals belongs to the cbuffer.
4953 if (Decl->getStorageClass() != SC_Static && !Decl->isStaticDataMember())
4954 return;
4955
4956 LangAS ImplAS = LangAS::hlsl_private;
4957 Type = SemaRef.getASTContext().getAddrSpaceQualType(T: Type, AddressSpace: ImplAS);
4958 Decl->setType(Type);
4959}
4960
4961namespace {
4962
4963// Helper class for assigning bindings to resources declared within a struct.
4964// It keeps track of all binding attributes declared on a struct instance, and
4965// the offsets for each register type that have been assigned so far.
4966// Handles both explicit and implicit bindings.
4967class StructBindingContext {
4968 // Bindings and offsets per register type. We only need to support four
4969 // register types - SRV (u), UAV (t), CBuffer (c), and Sampler (s).
4970 HLSLResourceBindingAttr *RegBindingsAttrs[4];
4971 unsigned RegBindingOffset[4];
4972
4973 // Make sure the RegisterType values are what we expect
4974 static_assert(static_cast<unsigned>(RegisterType::SRV) == 0 &&
4975 static_cast<unsigned>(RegisterType::UAV) == 1 &&
4976 static_cast<unsigned>(RegisterType::CBuffer) == 2 &&
4977 static_cast<unsigned>(RegisterType::Sampler) == 3,
4978 "unexpected register type values");
4979
4980 // Vulkan binding attribute does not vary by register type.
4981 HLSLVkBindingAttr *VkBindingAttr;
4982 unsigned VkBindingOffset;
4983
4984public:
4985 // Constructor: gather all binding attributes on a struct instance and
4986 // initialize offsets.
4987 StructBindingContext(VarDecl *VD) {
4988 for (unsigned i = 0; i < 4; ++i) {
4989 RegBindingsAttrs[i] = nullptr;
4990 RegBindingOffset[i] = 0;
4991 }
4992 VkBindingAttr = nullptr;
4993 VkBindingOffset = 0;
4994
4995 ASTContext &AST = VD->getASTContext();
4996 bool IsSpirv = AST.getTargetInfo().getTriple().isSPIRV();
4997
4998 for (Attr *A : VD->attrs()) {
4999 if (auto *RBA = dyn_cast<HLSLResourceBindingAttr>(Val: A)) {
5000 RegisterType RegType = RBA->getRegisterType();
5001 unsigned RegTypeIdx = static_cast<unsigned>(RegType);
5002 // Ignore unsupported register annotations, such as 'c' or 'i'.
5003 if (RegTypeIdx < 4)
5004 RegBindingsAttrs[RegTypeIdx] = RBA;
5005 continue;
5006 }
5007 // Gather the Vulkan binding attributes only if the target is SPIR-V.
5008 if (IsSpirv) {
5009 if (auto *VBA = dyn_cast<HLSLVkBindingAttr>(Val: A))
5010 VkBindingAttr = VBA;
5011 }
5012 }
5013 }
5014
5015 // Creates a binding attribute for a resource based on the gathered attributes
5016 // and the required register type and range.
5017 Attr *createBindingAttr(SemaHLSL &S, ASTContext &AST, RegisterType RegType,
5018 unsigned Range) {
5019 assert(static_cast<unsigned>(RegType) < 4 && "unexpected register type");
5020
5021 if (VkBindingAttr) {
5022 unsigned Offset = VkBindingOffset;
5023 VkBindingOffset += Range;
5024 return HLSLVkBindingAttr::CreateImplicit(
5025 Ctx&: AST, Binding: VkBindingAttr->getBinding() + Offset, Set: VkBindingAttr->getSet(),
5026 Range: VkBindingAttr->getRange());
5027 }
5028
5029 HLSLResourceBindingAttr *RBA =
5030 RegBindingsAttrs[static_cast<unsigned>(RegType)];
5031 HLSLResourceBindingAttr *NewAttr = nullptr;
5032
5033 if (RBA && RBA->hasRegisterSlot()) {
5034 // Explicit binding - create a new attribute with offseted slot number
5035 // based on the required register type.
5036 unsigned Offset = RegBindingOffset[static_cast<unsigned>(RegType)];
5037 RegBindingOffset[static_cast<unsigned>(RegType)] += Range;
5038
5039 unsigned NewSlotNumber = RBA->getSlotNumber() + Offset;
5040 StringRef NewSlotNumberStr =
5041 createRegisterString(AST, RegType: RBA->getRegisterType(), N: NewSlotNumber);
5042 NewAttr = HLSLResourceBindingAttr::CreateImplicit(
5043 Ctx&: AST, Slot: NewSlotNumberStr, Space: RBA->getSpace(), Range: RBA->getRange());
5044 NewAttr->setBinding(RT: RegType, SlotNum: NewSlotNumber, SpaceNum: RBA->getSpaceNumber());
5045 } else {
5046 // No binding attribute or space-only binding - create a binding
5047 // attribute for implicit binding.
5048 NewAttr = HLSLResourceBindingAttr::CreateImplicit(Ctx&: AST, Slot: "", Space: "0", Range: {});
5049 NewAttr->setBinding(RT: RegType, SlotNum: std::nullopt,
5050 SpaceNum: RBA ? RBA->getSpaceNumber() : 0);
5051 NewAttr->setImplicitBindingOrderID(S.getNextImplicitBindingOrderID());
5052 }
5053 return NewAttr;
5054 }
5055};
5056
5057// Creates a global variable declaration for a resource field embedded in a
5058// struct, assigns it a binding, initializes it, and associates it with the
5059// struct declaration via an HLSLAssociatedResourceDeclAttr.
5060static void createGlobalResourceDeclForStruct(
5061 Sema &S, VarDecl *ParentVD, SourceLocation Loc, IdentifierInfo *Id,
5062 QualType ResTy, StructBindingContext &BindingCtx) {
5063 assert(isResourceRecordTypeOrArrayOf(ResTy) &&
5064 "expected resource type or array of resources");
5065
5066 DeclContext *DC = ParentVD->getNonTransparentDeclContext();
5067 assert(DC->isTranslationUnit() && "expected translation unit decl context");
5068
5069 ASTContext &AST = S.getASTContext();
5070 VarDecl *ResDecl =
5071 VarDecl::Create(C&: AST, DC, StartLoc: Loc, IdLoc: Loc, Id, T: ResTy, TInfo: nullptr, S: SC_None);
5072
5073 unsigned Range = 1;
5074 const HLSLAttributedResourceType *ResHandleTy = nullptr;
5075 if (const auto *AT = dyn_cast<ArrayType>(Val: ResTy.getTypePtr())) {
5076 const auto *CAT = dyn_cast<ConstantArrayType>(Val: AT);
5077 Range = CAT ? CAT->getSize().getZExtValue() : 0;
5078 ResHandleTy = getResourceArrayHandleType(QT: ResTy);
5079 } else {
5080 ResHandleTy = HLSLAttributedResourceType::findHandleTypeOnResource(
5081 RT: ResTy.getTypePtr());
5082 }
5083 // Add a binding attribute to the global resource declaration.
5084 Attr *BindingAttr = BindingCtx.createBindingAttr(
5085 S&: S.HLSL(), AST, RegType: getRegisterType(ResTy: ResHandleTy), Range);
5086 ResDecl->addAttr(A: BindingAttr);
5087 ResDecl->addAttr(A: InternalLinkageAttr::CreateImplicit(Ctx&: AST));
5088 ResDecl->setImplicit();
5089
5090 if (Range == 1)
5091 S.HLSL().initGlobalResourceDecl(VD: ResDecl);
5092 else
5093 S.HLSL().initGlobalResourceArrayDecl(VD: ResDecl);
5094
5095 ParentVD->addAttr(
5096 A: HLSLAssociatedResourceDeclAttr::CreateImplicit(Ctx&: AST, ResDecl));
5097 DC->addDecl(D: ResDecl);
5098
5099 DeclGroupRef DG(ResDecl);
5100 S.Consumer.HandleTopLevelDecl(D: DG);
5101}
5102
5103static void handleArrayOfStructWithResources(
5104 Sema &S, VarDecl *ParentVD, const ConstantArrayType *CAT,
5105 EmbeddedResourceNameBuilder &NameBuilder, StructBindingContext &BindingCtx);
5106
5107// Scans base and all fields of a struct/class type to find all embedded
5108// resources or resource arrays. Creates a global variable for each resource
5109// found.
5110static void handleStructWithResources(Sema &S, VarDecl *ParentVD,
5111 const CXXRecordDecl *RD,
5112 EmbeddedResourceNameBuilder &NameBuilder,
5113 StructBindingContext &BindingCtx) {
5114
5115 // Scan the base classes.
5116 assert(RD->getNumBases() <= 1 && "HLSL doesn't support multiple inheritance");
5117 const auto *BasesIt = RD->bases_begin();
5118 if (BasesIt != RD->bases_end()) {
5119 QualType QT = BasesIt->getType();
5120 if (QT->isHLSLIntangibleType()) {
5121 CXXRecordDecl *BaseRD = QT->getAsCXXRecordDecl();
5122 NameBuilder.pushBaseName(N: BaseRD->getName());
5123 handleStructWithResources(S, ParentVD, RD: BaseRD, NameBuilder, BindingCtx);
5124 NameBuilder.pop();
5125 }
5126 }
5127 // Process this class fields.
5128 for (const FieldDecl *FD : RD->fields()) {
5129 QualType FDTy = FD->getType().getCanonicalType();
5130 if (!FDTy->isHLSLIntangibleType())
5131 continue;
5132
5133 NameBuilder.pushName(N: FD->getName());
5134
5135 if (isResourceRecordTypeOrArrayOf(Ty: FDTy)) {
5136 IdentifierInfo *II = NameBuilder.getNameAsIdentifier(AST&: S.getASTContext());
5137 createGlobalResourceDeclForStruct(S, ParentVD, Loc: FD->getLocation(), Id: II,
5138 ResTy: FDTy, BindingCtx);
5139 } else if (const auto *RD = FDTy->getAsCXXRecordDecl()) {
5140 handleStructWithResources(S, ParentVD, RD, NameBuilder, BindingCtx);
5141
5142 } else if (const auto *ArrayTy = dyn_cast<ConstantArrayType>(Val&: FDTy)) {
5143 assert(!FDTy->isHLSLResourceRecordArray() &&
5144 "resource arrays should have been already handled");
5145 handleArrayOfStructWithResources(S, ParentVD, CAT: ArrayTy, NameBuilder,
5146 BindingCtx);
5147 }
5148 NameBuilder.pop();
5149 }
5150}
5151
5152// Processes array of structs with resources.
5153static void
5154handleArrayOfStructWithResources(Sema &S, VarDecl *ParentVD,
5155 const ConstantArrayType *CAT,
5156 EmbeddedResourceNameBuilder &NameBuilder,
5157 StructBindingContext &BindingCtx) {
5158
5159 QualType ElementTy = CAT->getElementType().getCanonicalType();
5160 assert(ElementTy->isHLSLIntangibleType() && "Expected HLSL intangible type");
5161
5162 const ConstantArrayType *SubCAT = dyn_cast<ConstantArrayType>(Val&: ElementTy);
5163 const CXXRecordDecl *ElementRD = ElementTy->getAsCXXRecordDecl();
5164
5165 if (!SubCAT && !ElementRD)
5166 return;
5167
5168 for (unsigned I = 0, E = CAT->getSize().getZExtValue(); I < E; ++I) {
5169 NameBuilder.pushArrayIndex(Index: I);
5170 if (ElementRD)
5171 handleStructWithResources(S, ParentVD, RD: ElementRD, NameBuilder,
5172 BindingCtx);
5173 else
5174 handleArrayOfStructWithResources(S, ParentVD, CAT: SubCAT, NameBuilder,
5175 BindingCtx);
5176 NameBuilder.pop();
5177 }
5178}
5179
5180} // namespace
5181
5182// Scans all fields of a user-defined struct (or array of structs)
5183// to find all embedded resources or resource arrays. For each resource
5184// a global variable of the resource type is created and associated
5185// with the parent declaration (VD) through a HLSLAssociatedResourceDeclAttr
5186// attribute.
5187void SemaHLSL::handleGlobalStructOrArrayOfWithResources(VarDecl *VD) {
5188 EmbeddedResourceNameBuilder NameBuilder(VD->getName());
5189 StructBindingContext BindingCtx(VD);
5190
5191 const Type *VDTy = VD->getType().getTypePtr();
5192 assert(VDTy->isHLSLIntangibleType() && !isResourceRecordTypeOrArrayOf(VD) &&
5193 "Expected non-resource struct or array type");
5194
5195 if (const CXXRecordDecl *RD = VDTy->getAsCXXRecordDecl()) {
5196 handleStructWithResources(S&: SemaRef, ParentVD: VD, RD, NameBuilder, BindingCtx);
5197 return;
5198 }
5199
5200 if (const auto *CAT = dyn_cast<ConstantArrayType>(Val: VDTy)) {
5201 handleArrayOfStructWithResources(S&: SemaRef, ParentVD: VD, CAT, NameBuilder, BindingCtx);
5202 return;
5203 }
5204}
5205
5206void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
5207 if (VD->hasGlobalStorage()) {
5208 // make sure the declaration has a complete type
5209 if (SemaRef.RequireCompleteType(
5210 Loc: VD->getLocation(),
5211 T: SemaRef.getASTContext().getBaseElementType(QT: VD->getType()),
5212 DiagID: diag::err_typecheck_decl_incomplete_type)) {
5213 VD->setInvalidDecl();
5214 deduceAddressSpace(Decl: VD);
5215 return;
5216 }
5217
5218 // Global variables outside a cbuffer block that are not a resource, static,
5219 // groupshared, or an empty array or struct belong to the default constant
5220 // buffer $Globals (to be created at the end of the translation unit).
5221 if (IsDefaultBufferConstantDecl(Ctx: getASTContext(), VD)) {
5222 // update address space to hlsl_constant
5223 QualType NewTy = getASTContext().getAddrSpaceQualType(
5224 T: VD->getType(), AddressSpace: LangAS::hlsl_constant);
5225 VD->setType(NewTy);
5226 DefaultCBufferDecls.push_back(Elt: VD);
5227 }
5228
5229 // find all resources bindings on decl
5230 if (VD->getType()->isHLSLIntangibleType())
5231 collectResourceBindingsOnVarDecl(D: VD);
5232
5233 if (VD->hasAttr<HLSLVkConstantIdAttr>())
5234 VD->setStorageClass(StorageClass::SC_Static);
5235
5236 if (isResourceRecordTypeOrArrayOf(VD) &&
5237 VD->getStorageClass() != SC_Static) {
5238 // Add internal linkage attribute to non-static resource variables. The
5239 // global externally visible storage is accessed through the handle, which
5240 // is a member. The variable itself is not externally visible.
5241 VD->addAttr(A: InternalLinkageAttr::CreateImplicit(Ctx&: getASTContext()));
5242 }
5243
5244 // process explicit bindings
5245 processExplicitBindingsOnDecl(D: VD);
5246
5247 // Add implicit binding attribute to non-static resource arrays.
5248 if (VD->getType()->isHLSLResourceRecordArray() &&
5249 VD->getStorageClass() != SC_Static) {
5250 // If the resource array does not have an explicit binding attribute,
5251 // create an implicit one. It will be used to transfer implicit binding
5252 // order_ID to codegen.
5253 ResourceBindingAttrs Binding(VD);
5254 if (!Binding.isExplicit()) {
5255 uint32_t OrderID = getNextImplicitBindingOrderID();
5256 if (Binding.hasBinding())
5257 Binding.setImplicitOrderID(OrderID);
5258 else {
5259 addImplicitBindingAttrToDecl(
5260 S&: SemaRef, D: VD, RT: getRegisterType(ResTy: getResourceArrayHandleType(VD)),
5261 ImplicitBindingOrderID: OrderID);
5262 // Re-create the binding object to pick up the new attribute.
5263 Binding = ResourceBindingAttrs(VD);
5264 }
5265 }
5266
5267 // Get to the base type of a potentially multi-dimensional array.
5268 QualType Ty = getASTContext().getBaseElementType(QT: VD->getType());
5269
5270 const CXXRecordDecl *RD = Ty->getAsCXXRecordDecl();
5271 if (hasCounterHandle(RD)) {
5272 if (!Binding.hasCounterImplicitOrderID()) {
5273 uint32_t OrderID = getNextImplicitBindingOrderID();
5274 Binding.setCounterImplicitOrderID(OrderID);
5275 }
5276 }
5277 }
5278
5279 // Process resources in user-defined structs, or arrays of such structs.
5280 const Type *VDTy = VD->getType().getTypePtr();
5281 if (VD->getStorageClass() != SC_Static && VDTy->isHLSLIntangibleType() &&
5282 !isResourceRecordTypeOrArrayOf(VD))
5283 handleGlobalStructOrArrayOfWithResources(VD);
5284
5285 // Mark groupshared variables as extern so they will have
5286 // external storage and won't be default initialized
5287 if (VD->hasAttr<HLSLGroupSharedAddressSpaceAttr>())
5288 VD->setStorageClass(StorageClass::SC_Extern);
5289 }
5290
5291 deduceAddressSpace(Decl: VD);
5292}
5293
5294bool SemaHLSL::initGlobalResourceDecl(VarDecl *VD) {
5295 assert(VD->getType()->isHLSLResourceRecord() &&
5296 "expected resource record type");
5297
5298 ASTContext &AST = SemaRef.getASTContext();
5299 uint64_t UIntTySize = AST.getTypeSize(T: AST.UnsignedIntTy);
5300 uint64_t IntTySize = AST.getTypeSize(T: AST.IntTy);
5301
5302 // Gather resource binding attributes.
5303 ResourceBindingAttrs Binding(VD);
5304
5305 // Find correct initialization method and create its arguments.
5306 QualType ResourceTy = VD->getType();
5307 CXXRecordDecl *ResourceDecl = ResourceTy->getAsCXXRecordDecl();
5308 CXXMethodDecl *CreateMethod = nullptr;
5309 llvm::SmallVector<Expr *> Args;
5310
5311 bool HasCounter = hasCounterHandle(RD: ResourceDecl);
5312 const char *CreateMethodName;
5313 if (Binding.isExplicit())
5314 CreateMethodName = HasCounter ? "__createFromBindingWithImplicitCounter"
5315 : "__createFromBinding";
5316 else
5317 CreateMethodName = HasCounter
5318 ? "__createFromImplicitBindingWithImplicitCounter"
5319 : "__createFromImplicitBinding";
5320
5321 CreateMethod =
5322 lookupMethod(S&: SemaRef, RecordDecl: ResourceDecl, Name: CreateMethodName, Loc: VD->getLocation());
5323
5324 if (!CreateMethod)
5325 // This can happen if someone creates a struct that looks like an HLSL
5326 // resource record but does not have the required static create method.
5327 // No binding will be generated for it.
5328 return false;
5329
5330 if (Binding.isExplicit()) {
5331 IntegerLiteral *RegSlot =
5332 IntegerLiteral::Create(C: AST, V: llvm::APInt(UIntTySize, Binding.getSlot()),
5333 type: AST.UnsignedIntTy, l: SourceLocation());
5334 Args.push_back(Elt: RegSlot);
5335 } else {
5336 uint32_t OrderID = (Binding.hasImplicitOrderID())
5337 ? Binding.getImplicitOrderID()
5338 : getNextImplicitBindingOrderID();
5339 IntegerLiteral *OrderId =
5340 IntegerLiteral::Create(C: AST, V: llvm::APInt(UIntTySize, OrderID),
5341 type: AST.UnsignedIntTy, l: SourceLocation());
5342 Args.push_back(Elt: OrderId);
5343 }
5344
5345 IntegerLiteral *Space =
5346 IntegerLiteral::Create(C: AST, V: llvm::APInt(UIntTySize, Binding.getSpace()),
5347 type: AST.UnsignedIntTy, l: SourceLocation());
5348 Args.push_back(Elt: Space);
5349
5350 IntegerLiteral *RangeSize = IntegerLiteral::Create(
5351 C: AST, V: llvm::APInt(IntTySize, 1), type: AST.IntTy, l: SourceLocation());
5352 Args.push_back(Elt: RangeSize);
5353
5354 IntegerLiteral *Index = IntegerLiteral::Create(
5355 C: AST, V: llvm::APInt(UIntTySize, 0), type: AST.UnsignedIntTy, l: SourceLocation());
5356 Args.push_back(Elt: Index);
5357
5358 StringRef VarName = VD->getName();
5359 StringLiteral *Name = StringLiteral::Create(
5360 Ctx: AST, Str: VarName, Kind: StringLiteralKind::Ordinary, Pascal: false,
5361 Ty: AST.getStringLiteralArrayType(EltTy: AST.CharTy.withConst(), Length: VarName.size()),
5362 Locs: SourceLocation());
5363 ImplicitCastExpr *NameCast = ImplicitCastExpr::Create(
5364 Context: AST, T: AST.getPointerType(T: AST.CharTy.withConst()), Kind: CK_ArrayToPointerDecay,
5365 Operand: Name, BasePath: nullptr, Cat: VK_PRValue, FPO: FPOptionsOverride());
5366 Args.push_back(Elt: NameCast);
5367
5368 if (HasCounter) {
5369 // Will this be in the correct order?
5370 uint32_t CounterOrderID = getNextImplicitBindingOrderID();
5371 IntegerLiteral *CounterId =
5372 IntegerLiteral::Create(C: AST, V: llvm::APInt(UIntTySize, CounterOrderID),
5373 type: AST.UnsignedIntTy, l: SourceLocation());
5374 Args.push_back(Elt: CounterId);
5375 }
5376
5377 // Make sure the create method template is instantiated and emitted.
5378 if (!CreateMethod->isDefined() && CreateMethod->isTemplateInstantiation())
5379 SemaRef.InstantiateFunctionDefinition(PointOfInstantiation: VD->getLocation(), Function: CreateMethod,
5380 Recursive: true);
5381
5382 // Create CallExpr with a call to the static method and set it as the decl
5383 // initialization.
5384 DeclRefExpr *DRE = DeclRefExpr::Create(
5385 Context: AST, QualifierLoc: NestedNameSpecifierLoc(), TemplateKWLoc: SourceLocation(), D: CreateMethod, RefersToEnclosingVariableOrCapture: false,
5386 NameInfo: CreateMethod->getNameInfo(), T: CreateMethod->getType(), VK: VK_PRValue);
5387
5388 auto *ImpCast = ImplicitCastExpr::Create(
5389 Context: AST, T: AST.getPointerType(T: CreateMethod->getType()),
5390 Kind: CK_FunctionToPointerDecay, Operand: DRE, BasePath: nullptr, Cat: VK_PRValue, FPO: FPOptionsOverride());
5391
5392 CallExpr *InitExpr =
5393 CallExpr::Create(Ctx: AST, Fn: ImpCast, Args, Ty: ResourceTy, VK: VK_PRValue,
5394 RParenLoc: SourceLocation(), FPFeatures: FPOptionsOverride());
5395 VD->setInit(InitExpr);
5396 VD->setInitStyle(VarDecl::CallInit);
5397 SemaRef.CheckCompleteVariableDeclaration(VD);
5398 return true;
5399}
5400
5401bool SemaHLSL::initGlobalResourceArrayDecl(VarDecl *VD) {
5402 assert(VD->getType()->isHLSLResourceRecordArray() &&
5403 "expected array of resource records");
5404
5405 // Individual resources in a resource array are not initialized here. They
5406 // are initialized later on during codegen when the individual resources are
5407 // accessed. Codegen will emit a call to the resource initialization method
5408 // with the specified array index. We need to make sure though that the method
5409 // for the specific resource type is instantiated, so codegen can emit a call
5410 // to it when the array element is accessed.
5411
5412 // Find correct initialization method based on the resource binding
5413 // information.
5414 ASTContext &AST = SemaRef.getASTContext();
5415 QualType ResElementTy = AST.getBaseElementType(QT: VD->getType());
5416 CXXRecordDecl *ResourceDecl = ResElementTy->getAsCXXRecordDecl();
5417 CXXMethodDecl *CreateMethod = nullptr;
5418
5419 bool HasCounter = hasCounterHandle(RD: ResourceDecl);
5420 ResourceBindingAttrs ResourceAttrs(VD);
5421 if (ResourceAttrs.isExplicit())
5422 // Resource has explicit binding.
5423 CreateMethod =
5424 lookupMethod(S&: SemaRef, RecordDecl: ResourceDecl,
5425 Name: HasCounter ? "__createFromBindingWithImplicitCounter"
5426 : "__createFromBinding",
5427 Loc: VD->getLocation());
5428 else
5429 // Resource has implicit binding.
5430 CreateMethod = lookupMethod(
5431 S&: SemaRef, RecordDecl: ResourceDecl,
5432 Name: HasCounter ? "__createFromImplicitBindingWithImplicitCounter"
5433 : "__createFromImplicitBinding",
5434 Loc: VD->getLocation());
5435
5436 if (!CreateMethod)
5437 return false;
5438
5439 // Make sure the create method template is instantiated and emitted.
5440 if (!CreateMethod->isDefined() && CreateMethod->isTemplateInstantiation())
5441 SemaRef.InstantiateFunctionDefinition(PointOfInstantiation: VD->getLocation(), Function: CreateMethod,
5442 Recursive: true);
5443 return true;
5444}
5445
5446// Returns true if the initialization has been handled.
5447// Returns false to use default initialization.
5448bool SemaHLSL::ActOnUninitializedVarDecl(VarDecl *VD) {
5449 // Objects in the hlsl_constant address space are initialized
5450 // externally, so don't synthesize an implicit initializer.
5451 if (VD->getType().getAddressSpace() == LangAS::hlsl_constant)
5452 return true;
5453
5454 // Initialize non-static resources at the global scope.
5455 if (VD->hasGlobalStorage() && VD->getStorageClass() != SC_Static) {
5456 const Type *Ty = VD->getType().getTypePtr();
5457 if (Ty->isHLSLResourceRecord())
5458 return initGlobalResourceDecl(VD);
5459 if (Ty->isHLSLResourceRecordArray())
5460 return initGlobalResourceArrayDecl(VD);
5461 }
5462 return false;
5463}
5464
5465std::optional<const DeclBindingInfo *> SemaHLSL::inferGlobalBinding(Expr *E) {
5466 if (auto *Ternary = dyn_cast<ConditionalOperator>(Val: E)) {
5467 auto TrueInfo = inferGlobalBinding(E: Ternary->getTrueExpr());
5468 auto FalseInfo = inferGlobalBinding(E: Ternary->getFalseExpr());
5469 if (!TrueInfo || !FalseInfo)
5470 return std::nullopt;
5471 if (*TrueInfo != *FalseInfo)
5472 return std::nullopt;
5473 return TrueInfo;
5474 }
5475
5476 if (auto *ASE = dyn_cast<ArraySubscriptExpr>(Val: E))
5477 E = ASE->getBase()->IgnoreParenImpCasts();
5478
5479 if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(Val: E->IgnoreParens()))
5480 if (VarDecl *VD = dyn_cast<VarDecl>(Val: DRE->getDecl())) {
5481 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
5482 if (Ty->isArrayType())
5483 Ty = Ty->getArrayElementTypeNoTypeQual();
5484
5485 if (const auto *AttrResType =
5486 HLSLAttributedResourceType::findHandleTypeOnResource(RT: Ty)) {
5487 ResourceClass RC = AttrResType->getAttrs().ResourceClass;
5488 return Bindings.getDeclBindingInfo(VD, ResClass: RC);
5489 }
5490 }
5491
5492 return nullptr;
5493}
5494
5495void SemaHLSL::trackLocalResource(VarDecl *VD, Expr *E) {
5496 std::optional<const DeclBindingInfo *> ExprBinding = inferGlobalBinding(E);
5497 if (!ExprBinding) {
5498 SemaRef.Diag(Loc: E->getBeginLoc(),
5499 DiagID: diag::warn_hlsl_assigning_local_resource_is_not_unique)
5500 << E << VD;
5501 return; // Expr use multiple resources
5502 }
5503
5504 if (*ExprBinding == nullptr)
5505 return; // No binding could be inferred to track, return without error
5506
5507 auto PrevBinding = Assigns.find(Val: VD);
5508 if (PrevBinding == Assigns.end()) {
5509 // No previous binding recorded, simply record the new assignment
5510 Assigns.insert(KV: {VD, *ExprBinding});
5511 return;
5512 }
5513
5514 // Otherwise, warn if the assignment implies different resource bindings
5515 if (*ExprBinding != PrevBinding->second) {
5516 SemaRef.Diag(Loc: E->getBeginLoc(),
5517 DiagID: diag::warn_hlsl_assigning_local_resource_is_not_unique)
5518 << E << VD;
5519 SemaRef.Diag(Loc: VD->getLocation(), DiagID: diag::note_var_declared_here) << VD;
5520 return;
5521 }
5522
5523 return;
5524}
5525
5526bool SemaHLSL::CheckResourceBinOp(BinaryOperatorKind Opc, Expr *LHSExpr,
5527 Expr *RHSExpr, SourceLocation Loc) {
5528 assert((LHSExpr->getType()->isHLSLResourceRecord() ||
5529 LHSExpr->getType()->isHLSLResourceRecordArray()) &&
5530 "expected LHS to be a resource record or array of resource records");
5531 if (Opc != BO_Assign)
5532 return true;
5533
5534 // If LHS is an array subscript, get the underlying declaration.
5535 Expr *E = LHSExpr;
5536 while (auto *ASE = dyn_cast<ArraySubscriptExpr>(Val: E))
5537 E = ASE->getBase()->IgnoreParenImpCasts();
5538
5539 // Report error if LHS is a non-static resource declared at a global scope.
5540 if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(Val: E->IgnoreParens())) {
5541 if (VarDecl *VD = dyn_cast<VarDecl>(Val: DRE->getDecl())) {
5542 if (VD->hasGlobalStorage() && VD->getStorageClass() != SC_Static) {
5543 // assignment to global resource is not allowed
5544 SemaRef.Diag(Loc, DiagID: diag::err_hlsl_assign_to_global_resource) << VD;
5545 SemaRef.Diag(Loc: VD->getLocation(), DiagID: diag::note_var_declared_here) << VD;
5546 return false;
5547 }
5548
5549 trackLocalResource(VD, E: RHSExpr);
5550 }
5551 }
5552 return true;
5553}
5554
5555// Walks though the global variable declaration, collects all resource binding
5556// requirements and adds them to Bindings
5557void SemaHLSL::collectResourceBindingsOnVarDecl(VarDecl *VD) {
5558 assert(VD->hasGlobalStorage() && VD->getType()->isHLSLIntangibleType() &&
5559 "expected global variable that contains HLSL resource");
5560
5561 // Cbuffers and Tbuffers are HLSLBufferDecl types
5562 if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(Val: VD)) {
5563 Bindings.addDeclBindingInfo(VD, ResClass: CBufferOrTBuffer->isCBuffer()
5564 ? ResourceClass::CBuffer
5565 : ResourceClass::SRV);
5566 return;
5567 }
5568
5569 // Unwrap arrays
5570 // FIXME: Calculate array size while unwrapping
5571 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
5572 while (Ty->isArrayType()) {
5573 const ArrayType *AT = cast<ArrayType>(Val: Ty);
5574 Ty = AT->getElementType()->getUnqualifiedDesugaredType();
5575 }
5576
5577 // Resource (or array of resources)
5578 if (const HLSLAttributedResourceType *AttrResType =
5579 HLSLAttributedResourceType::findHandleTypeOnResource(RT: Ty)) {
5580 Bindings.addDeclBindingInfo(VD, ResClass: AttrResType->getAttrs().ResourceClass);
5581 return;
5582 }
5583
5584 // User defined record type
5585 if (const RecordType *RT = dyn_cast<RecordType>(Val: Ty))
5586 collectResourceBindingsOnUserRecordDecl(VD, RT);
5587}
5588
5589// Walks though the explicit resource binding attributes on the declaration,
5590// and makes sure there is a resource that matched the binding and updates
5591// DeclBindingInfoLists
5592void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) {
5593 assert(VD->hasGlobalStorage() && "expected global variable");
5594
5595 bool HasBinding = false;
5596 for (Attr *A : VD->attrs()) {
5597 if (isa<HLSLVkBindingAttr>(Val: A)) {
5598 HasBinding = true;
5599 if (auto PA = VD->getAttr<HLSLVkPushConstantAttr>())
5600 Diag(Loc: PA->getLoc(), DiagID: diag::err_hlsl_attr_incompatible) << A << PA;
5601 }
5602
5603 HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(Val: A);
5604 if (!RBA || !RBA->hasRegisterSlot())
5605 continue;
5606 HasBinding = true;
5607
5608 RegisterType RT = RBA->getRegisterType();
5609 assert(RT != RegisterType::I && "invalid or obsolete register type should "
5610 "never have an attribute created");
5611
5612 if (RT == RegisterType::C) {
5613 if (Bindings.hasBindingInfoForDecl(VD))
5614 SemaRef.Diag(Loc: VD->getLocation(),
5615 DiagID: diag::warn_hlsl_user_defined_type_missing_member)
5616 << static_cast<int>(RT);
5617 continue;
5618 }
5619
5620 // Find DeclBindingInfo for this binding and update it, or report error
5621 // if it does not exist (user type does to contain resources with the
5622 // expected resource class).
5623 ResourceClass RC = getResourceClass(RT);
5624 if (DeclBindingInfo *BI = Bindings.getDeclBindingInfo(VD, ResClass: RC)) {
5625 // update binding info
5626 BI->setBindingAttribute(A: RBA, BT: BindingType::Explicit);
5627 } else {
5628 SemaRef.Diag(Loc: VD->getLocation(),
5629 DiagID: diag::warn_hlsl_user_defined_type_missing_member)
5630 << static_cast<int>(RT);
5631 }
5632 }
5633
5634 if (!HasBinding && isResourceRecordTypeOrArrayOf(VD))
5635 SemaRef.Diag(Loc: VD->getLocation(), DiagID: diag::warn_hlsl_implicit_binding);
5636}
5637namespace {
5638class InitListTransformer {
5639 Sema &S;
5640 ASTContext &Ctx;
5641 QualType InitTy;
5642 QualType *DstIt = nullptr;
5643 Expr **ArgIt = nullptr;
5644 // Is wrapping the destination type iterator required? This is only used for
5645 // incomplete array types where we loop over the destination type since we
5646 // don't know the full number of elements from the declaration.
5647 bool Wrap;
5648
5649 bool castInitializer(Expr *E) {
5650 assert(DstIt && "This should always be something!");
5651 if (DstIt == DestTypes.end()) {
5652 if (!Wrap) {
5653 ArgExprs.push_back(Elt: E);
5654 // This is odd, but it isn't technically a failure due to conversion, we
5655 // handle mismatched counts of arguments differently.
5656 return true;
5657 }
5658 DstIt = DestTypes.begin();
5659 }
5660 InitializedEntity Entity = InitializedEntity::InitializeParameter(
5661 Context&: Ctx, Type: *DstIt, /* Consumed (ObjC) */ Consumed: false);
5662 ExprResult Res = S.PerformCopyInitialization(Entity, EqualLoc: E->getBeginLoc(), Init: E);
5663 if (Res.isInvalid())
5664 return false;
5665 Expr *Init = Res.get();
5666 ArgExprs.push_back(Elt: Init);
5667 DstIt++;
5668 return true;
5669 }
5670
  // Recursively flattens initializer expression E into ArgExprs. Each leaf
  // element is converted to the next destination type via castInitializer.
  // Vectors, matrices and constant arrays are expanded element by element
  // through subscript expressions, and records field by field (bases first).
  // Returns false as soon as building or converting any element fails.
  bool buildInitializerListImpl(Expr *E) {
    // If this is an initialization list, traverse the sub initializers.
    if (auto *Init = dyn_cast<InitListExpr>(Val: E)) {
      for (auto *SubInit : Init->inits())
        if (!buildInitializerListImpl(E: SubInit))
          return false;
      return true;
    }

    // Leaf elements - scalars, non-aggregate records and HLSL attributed
    // resource types - are converted and enqueued directly.
    QualType Ty = E->getType().getDesugaredType(Context: Ctx);

    if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()) ||
        Ty->isHLSLAttributedResourceType())
      return castInitializer(E);

    // Vectors: enqueue each component E[I] as a leaf.
    if (auto *VecTy = Ty->getAs<VectorType>()) {
      uint64_t Size = VecTy->getNumElements();

      QualType SizeTy = Ctx.getSizeType();
      uint64_t SizeTySize = Ctx.getTypeSize(T: SizeTy);
      for (uint64_t I = 0; I < Size; ++I) {
        auto *Idx = IntegerLiteral::Create(C: Ctx, V: llvm::APInt(SizeTySize, I),
                                           type: SizeTy, l: SourceLocation());

        ExprResult ElExpr = S.CreateBuiltinArraySubscriptExpr(
            Base: E, LLoc: E->getBeginLoc(), Idx, RLoc: E->getEndLoc());
        if (ElExpr.isInvalid())
          return false;
        if (!castInitializer(E: ElExpr.get()))
          return false;
      }
      return true;
    }
    // Matrices: enqueue each element E[R][C] as a leaf, iterating rows in the
    // outer loop and columns in the inner loop.
    if (auto *MTy = Ty->getAs<ConstantMatrixType>()) {
      unsigned Rows = MTy->getNumRows();
      unsigned Cols = MTy->getNumColumns();
      QualType ElemTy = MTy->getElementType();

      for (unsigned R = 0; R < Rows; ++R) {
        for (unsigned C = 0; C < Cols; ++C) {
          // row index literal
          Expr *RowIdx = IntegerLiteral::Create(
              C: Ctx, V: llvm::APInt(Ctx.getIntWidth(T: Ctx.IntTy), R), type: Ctx.IntTy,
              l: E->getBeginLoc());
          // column index literal
          Expr *ColIdx = IntegerLiteral::Create(
              C: Ctx, V: llvm::APInt(Ctx.getIntWidth(T: Ctx.IntTy), C), type: Ctx.IntTy,
              l: E->getBeginLoc());
          ExprResult ElExpr = S.CreateBuiltinMatrixSubscriptExpr(
              Base: E, RowIdx, ColumnIdx: ColIdx, RBLoc: E->getEndLoc());
          if (ElExpr.isInvalid())
            return false;
          if (!castInitializer(E: ElExpr.get()))
            return false;
          // NOTE(review): the element type is set after castInitializer has
          // already consumed the expression; confirm whether this is meant to
          // happen before the conversion.
          ElExpr.get()->setType(ElemTy);
        }
      }
      return true;
    }

    // Constant arrays: elements may themselves be aggregates, so recurse
    // instead of treating them as leaves.
    if (auto *ArrTy = dyn_cast<ConstantArrayType>(Val: Ty.getTypePtr())) {
      uint64_t Size = ArrTy->getZExtSize();
      QualType SizeTy = Ctx.getSizeType();
      uint64_t SizeTySize = Ctx.getTypeSize(T: SizeTy);
      for (uint64_t I = 0; I < Size; ++I) {
        auto *Idx = IntegerLiteral::Create(C: Ctx, V: llvm::APInt(SizeTySize, I),
                                           type: SizeTy, l: SourceLocation());
        ExprResult ElExpr = S.CreateBuiltinArraySubscriptExpr(
            Base: E, LLoc: E->getBeginLoc(), Idx, RLoc: E->getEndLoc());
        if (ElExpr.isInvalid())
          return false;
        if (!buildInitializerListImpl(E: ElExpr.get()))
          return false;
      }
      return true;
    }

    // Aggregate records: visit the fields of the outermost base first, then
    // each derived class in turn, skipping unnamed bit-fields.
    if (auto *RD = Ty->getAsCXXRecordDecl()) {
      llvm::SmallVector<CXXRecordDecl *> RecordDecls;
      RecordDecls.push_back(Elt: RD);
      // If this is a prvalue create an xvalue so the member accesses
      // will be xvalues.
      if (E->isPRValue())
        E = new (Ctx)
            MaterializeTemporaryExpr(Ty, E, /*BoundToLvalueReference=*/false);
      // Collect the single-inheritance chain; the last entry is the root base.
      while (RecordDecls.back()->getNumBases()) {
        CXXRecordDecl *D = RecordDecls.back();
        assert(D->getNumBases() == 1 &&
               "HLSL doesn't support multiple inheritance");
        RecordDecls.push_back(
            Elt: D->bases_begin()->getType()->castAsCXXRecordDecl());
      }
      // Pop from the back so base-class fields are flattened first.
      while (!RecordDecls.empty()) {
        CXXRecordDecl *RD = RecordDecls.pop_back_val();
        for (auto *FD : RD->fields()) {
          if (FD->isUnnamedBitField())
            continue;
          DeclAccessPair Found = DeclAccessPair::make(D: FD, AS: FD->getAccess());
          DeclarationNameInfo NameInfo(FD->getDeclName(), E->getBeginLoc());
          ExprResult Res = S.BuildFieldReferenceExpr(
              BaseExpr: E, IsArrow: false, OpLoc: E->getBeginLoc(), SS: CXXScopeSpec(), Field: FD, FoundDecl: Found, MemberNameInfo: NameInfo);
          if (Res.isInvalid())
            return false;
          if (!buildInitializerListImpl(E: Res.get()))
            return false;
        }
      }
    }
    // Types not matched above contribute no elements.
    return true;
  }
5782
5783 Expr *generateInitListsImpl(QualType Ty) {
5784 Ty = Ty.getDesugaredType(Context: Ctx);
5785 assert(ArgIt != ArgExprs.end() && "Something is off in iteration!");
5786 if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()) ||
5787 Ty->isHLSLAttributedResourceType())
5788 return *(ArgIt++);
5789
5790 llvm::SmallVector<Expr *> Inits;
5791 if (Ty->isVectorType() || Ty->isConstantArrayType() ||
5792 Ty->isConstantMatrixType()) {
5793 QualType ElTy;
5794 uint64_t Size = 0;
5795 if (auto *ATy = Ty->getAs<VectorType>()) {
5796 ElTy = ATy->getElementType();
5797 Size = ATy->getNumElements();
5798 } else if (auto *CMTy = Ty->getAs<ConstantMatrixType>()) {
5799 ElTy = CMTy->getElementType();
5800 Size = CMTy->getNumElementsFlattened();
5801 } else {
5802 auto *VTy = cast<ConstantArrayType>(Val: Ty.getTypePtr());
5803 ElTy = VTy->getElementType();
5804 Size = VTy->getZExtSize();
5805 }
5806 for (uint64_t I = 0; I < Size; ++I)
5807 Inits.push_back(Elt: generateInitListsImpl(Ty: ElTy));
5808 }
5809 if (auto *RD = Ty->getAsCXXRecordDecl()) {
5810 llvm::SmallVector<CXXRecordDecl *> RecordDecls;
5811 RecordDecls.push_back(Elt: RD);
5812 while (RecordDecls.back()->getNumBases()) {
5813 CXXRecordDecl *D = RecordDecls.back();
5814 assert(D->getNumBases() == 1 &&
5815 "HLSL doesn't support multiple inheritance");
5816 RecordDecls.push_back(
5817 Elt: D->bases_begin()->getType()->castAsCXXRecordDecl());
5818 }
5819 while (!RecordDecls.empty()) {
5820 CXXRecordDecl *RD = RecordDecls.pop_back_val();
5821 for (auto *FD : RD->fields())
5822 if (!FD->isUnnamedBitField())
5823 Inits.push_back(Elt: generateInitListsImpl(Ty: FD->getType()));
5824 }
5825 }
5826 auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
5827 Inits, Inits.back()->getEndLoc());
5828 NewInit->setType(Ty);
5829 return NewInit;
5830 }
5831
5832public:
5833 llvm::SmallVector<QualType, 16> DestTypes;
5834 llvm::SmallVector<Expr *, 16> ArgExprs;
5835 InitListTransformer(Sema &SemaRef, const InitializedEntity &Entity)
5836 : S(SemaRef), Ctx(SemaRef.getASTContext()),
5837 Wrap(Entity.getType()->isIncompleteArrayType()) {
5838 InitTy = Entity.getType().getNonReferenceType();
5839 // When we're generating initializer lists for incomplete array types we
5840 // need to wrap around both when building the initializers and when
5841 // generating the final initializer lists.
5842 if (Wrap) {
5843 assert(InitTy->isIncompleteArrayType());
5844 const IncompleteArrayType *IAT = Ctx.getAsIncompleteArrayType(T: InitTy);
5845 InitTy = IAT->getElementType();
5846 }
5847 BuildFlattenedTypeList(BaseTy: InitTy, List&: DestTypes);
5848 DstIt = DestTypes.begin();
5849 }
5850
5851 bool buildInitializerList(Expr *E) { return buildInitializerListImpl(E); }
5852
5853 Expr *generateInitLists() {
5854 assert(!ArgExprs.empty() &&
5855 "Call buildInitializerList to generate argument expressions.");
5856 ArgIt = ArgExprs.begin();
5857 if (!Wrap)
5858 return generateInitListsImpl(Ty: InitTy);
5859 llvm::SmallVector<Expr *> Inits;
5860 while (ArgIt != ArgExprs.end())
5861 Inits.push_back(Elt: generateInitListsImpl(Ty: InitTy));
5862
5863 auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
5864 Inits, Inits.back()->getEndLoc());
5865 llvm::APInt ArySize(64, Inits.size());
5866 NewInit->setType(Ctx.getConstantArrayType(EltTy: InitTy, ArySize, SizeExpr: nullptr,
5867 ASM: ArraySizeModifier::Normal, IndexTypeQuals: 0));
5868 return NewInit;
5869 }
5870};
5871} // namespace
5872
5873// Recursively detect any incomplete array anywhere in the type graph,
5874// including arrays, struct fields, and base classes.
5875static bool containsIncompleteArrayType(QualType Ty) {
5876 Ty = Ty.getCanonicalType();
5877
5878 // Array types
5879 if (const ArrayType *AT = dyn_cast<ArrayType>(Val&: Ty)) {
5880 if (isa<IncompleteArrayType>(Val: AT))
5881 return true;
5882 return containsIncompleteArrayType(Ty: AT->getElementType());
5883 }
5884
5885 // Record (struct/class) types
5886 if (const auto *RT = Ty->getAs<RecordType>()) {
5887 const RecordDecl *RD = RT->getDecl();
5888
5889 // Walk base classes (for C++ / HLSL structs with inheritance)
5890 if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(Val: RD)) {
5891 for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
5892 if (containsIncompleteArrayType(Ty: Base.getType()))
5893 return true;
5894 }
5895 }
5896
5897 // Walk fields
5898 for (const FieldDecl *F : RD->fields()) {
5899 if (containsIncompleteArrayType(Ty: F->getType()))
5900 return true;
5901 }
5902 }
5903
5904 return false;
5905}
5906
5907bool SemaHLSL::transformInitList(const InitializedEntity &Entity,
5908 InitListExpr *Init) {
5909 // If the initializer is a scalar, just return it.
5910 if (Init->getType()->isScalarType())
5911 return true;
5912 ASTContext &Ctx = SemaRef.getASTContext();
5913 InitListTransformer ILT(SemaRef, Entity);
5914
5915 for (unsigned I = 0; I < Init->getNumInits(); ++I) {
5916 Expr *E = Init->getInit(Init: I);
5917 if (E->HasSideEffects(Ctx)) {
5918 QualType Ty = E->getType();
5919 if (Ty->isRecordType())
5920 E = new (Ctx) MaterializeTemporaryExpr(Ty, E, E->isLValue());
5921 E = new (Ctx) OpaqueValueExpr(E->getBeginLoc(), Ty, E->getValueKind(),
5922 E->getObjectKind(), E);
5923 Init->setInit(Init: I, expr: E);
5924 }
5925 if (!ILT.buildInitializerList(E))
5926 return false;
5927 }
5928 size_t ExpectedSize = ILT.DestTypes.size();
5929 size_t ActualSize = ILT.ArgExprs.size();
5930 if (ExpectedSize == 0 && ActualSize == 0)
5931 return true;
5932
5933 // Reject empty initializer if *any* incomplete array exists structurally
5934 if (ActualSize == 0 && containsIncompleteArrayType(Ty: Entity.getType())) {
5935 QualType InitTy = Entity.getType().getNonReferenceType();
5936 if (InitTy.hasAddressSpace())
5937 InitTy = SemaRef.getASTContext().removeAddrSpaceQualType(T: InitTy);
5938
5939 SemaRef.Diag(Loc: Init->getBeginLoc(), DiagID: diag::err_hlsl_incorrect_num_initializers)
5940 << /*TooManyOrFew=*/(int)(ExpectedSize < ActualSize) << InitTy
5941 << /*ExpectedSize=*/ExpectedSize << /*ActualSize=*/ActualSize;
5942 return false;
5943 }
5944
5945 // We infer size after validating legality.
5946 // For incomplete arrays it is completely arbitrary to choose whether we think
5947 // the user intended fewer or more elements. This implementation assumes that
5948 // the user intended more, and errors that there are too few initializers to
5949 // complete the final element.
5950 if (Entity.getType()->isIncompleteArrayType()) {
5951 assert(ExpectedSize > 0 &&
5952 "The expected size of an incomplete array type must be at least 1.");
5953 ExpectedSize =
5954 ((ActualSize + ExpectedSize - 1) / ExpectedSize) * ExpectedSize;
5955 }
5956
5957 // An initializer list might be attempting to initialize a reference or
5958 // rvalue-reference. When checking the initializer we should look through
5959 // the reference.
5960 QualType InitTy = Entity.getType().getNonReferenceType();
5961 if (InitTy.hasAddressSpace())
5962 InitTy = SemaRef.getASTContext().removeAddrSpaceQualType(T: InitTy);
5963 if (ExpectedSize != ActualSize) {
5964 int TooManyOrFew = ActualSize > ExpectedSize ? 1 : 0;
5965 SemaRef.Diag(Loc: Init->getBeginLoc(), DiagID: diag::err_hlsl_incorrect_num_initializers)
5966 << TooManyOrFew << InitTy << ExpectedSize << ActualSize;
5967 return false;
5968 }
5969
5970 // generateInitListsImpl will always return an InitListExpr here, because the
5971 // scalar case is handled above.
5972 auto *NewInit = cast<InitListExpr>(Val: ILT.generateInitLists());
5973 Init->resizeInits(Context: Ctx, NumInits: NewInit->getNumInits());
5974 for (unsigned I = 0; I < NewInit->getNumInits(); ++I)
5975 Init->updateInit(C: Ctx, Init: I, expr: NewInit->getInit(Init: I));
5976 return true;
5977}
5978
5979static QualType ReportMatrixInvalidMember(Sema &S, StringRef Name,
5980 StringRef Expected,
5981 SourceLocation OpLoc,
5982 SourceLocation CompLoc) {
5983 S.Diag(Loc: OpLoc, DiagID: diag::err_builtin_matrix_invalid_member)
5984 << Name << Expected << SourceRange(CompLoc);
5985 return QualType();
5986}
5987
5988QualType SemaHLSL::checkMatrixComponent(Sema &S, QualType baseType,
5989 ExprValueKind &VK, SourceLocation OpLoc,
5990 const IdentifierInfo *CompName,
5991 SourceLocation CompLoc) {
5992 const auto *MT = baseType->castAs<ConstantMatrixType>();
5993 StringRef AccessorName = CompName->getName();
5994 assert(!AccessorName.empty() && "Matrix Accessor must have a name");
5995
5996 unsigned Rows = MT->getNumRows();
5997 unsigned Cols = MT->getNumColumns();
5998 bool IsZeroBasedAccessor = false;
5999 unsigned ChunkLen = 0;
6000 if (AccessorName.size() < 2)
6001 return ReportMatrixInvalidMember(S, Name: AccessorName,
6002 Expected: "length 4 for zero based: \'_mRC\' or "
6003 "length 3 for one-based: \'_RC\' accessor",
6004 OpLoc, CompLoc);
6005
6006 if (AccessorName[0] == '_') {
6007 if (AccessorName[1] == 'm') {
6008 IsZeroBasedAccessor = true;
6009 ChunkLen = 4; // zero-based: "_mRC"
6010 } else {
6011 ChunkLen = 3; // one-based: "_RC"
6012 }
6013 } else
6014 return ReportMatrixInvalidMember(
6015 S, Name: AccessorName, Expected: "zero based: \'_mRC\' or one-based: \'_RC\' accessor",
6016 OpLoc, CompLoc);
6017
6018 if (AccessorName.size() % ChunkLen != 0) {
6019 const llvm::StringRef Expected = IsZeroBasedAccessor
6020 ? "zero based: '_mRC' accessor"
6021 : "one-based: '_RC' accessor";
6022
6023 return ReportMatrixInvalidMember(S, Name: AccessorName, Expected, OpLoc, CompLoc);
6024 }
6025
6026 auto isDigit = [](char c) { return c >= '0' && c <= '9'; };
6027 auto isZeroBasedIndex = [](unsigned i) { return i <= 3; };
6028 auto isOneBasedIndex = [](unsigned i) { return i >= 1 && i <= 4; };
6029
6030 bool HasRepeated = false;
6031 SmallVector<bool, 16> Seen(Rows * Cols, false);
6032 unsigned NumComponents = 0;
6033 const char *Begin = AccessorName.data();
6034
6035 for (unsigned I = 0, E = AccessorName.size(); I < E; I += ChunkLen) {
6036 const char *Chunk = Begin + I;
6037 char RowChar = 0, ColChar = 0;
6038 if (IsZeroBasedAccessor) {
6039 // Zero-based: "_mRC"
6040 if (Chunk[0] != '_' || Chunk[1] != 'm') {
6041 char Bad = (Chunk[0] != '_') ? Chunk[0] : Chunk[1];
6042 return ReportMatrixInvalidMember(
6043 S, Name: StringRef(&Bad, 1), Expected: "\'_m\' prefix",
6044 OpLoc: OpLoc.getLocWithOffset(Offset: I + (Bad == Chunk[0] ? 1 : 2)), CompLoc);
6045 }
6046 RowChar = Chunk[2];
6047 ColChar = Chunk[3];
6048 } else {
6049 // One-based: "_RC"
6050 if (Chunk[0] != '_')
6051 return ReportMatrixInvalidMember(
6052 S, Name: StringRef(&Chunk[0], 1), Expected: "\'_\' prefix",
6053 OpLoc: OpLoc.getLocWithOffset(Offset: I + 1), CompLoc);
6054 RowChar = Chunk[1];
6055 ColChar = Chunk[2];
6056 }
6057
6058 // Must be digits.
6059 bool IsDigitsError = false;
6060 if (!isDigit(RowChar)) {
6061 unsigned BadPos = IsZeroBasedAccessor ? 2 : 1;
6062 ReportMatrixInvalidMember(S, Name: StringRef(&RowChar, 1), Expected: "row as integer",
6063 OpLoc: OpLoc.getLocWithOffset(Offset: I + BadPos + 1),
6064 CompLoc);
6065 IsDigitsError = true;
6066 }
6067
6068 if (!isDigit(ColChar)) {
6069 unsigned BadPos = IsZeroBasedAccessor ? 3 : 2;
6070 ReportMatrixInvalidMember(S, Name: StringRef(&ColChar, 1), Expected: "column as integer",
6071 OpLoc: OpLoc.getLocWithOffset(Offset: I + BadPos + 1),
6072 CompLoc);
6073 IsDigitsError = true;
6074 }
6075 if (IsDigitsError)
6076 return QualType();
6077
6078 unsigned Row = RowChar - '0';
6079 unsigned Col = ColChar - '0';
6080
6081 bool HasIndexingError = false;
6082 if (IsZeroBasedAccessor) {
6083 // 0-based [0..3]
6084 if (!isZeroBasedIndex(Row)) {
6085 S.Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_element_not_in_bounds)
6086 << /*row*/ 0 << /*zero-based*/ 0 << SourceRange(CompLoc);
6087 HasIndexingError = true;
6088 }
6089 if (!isZeroBasedIndex(Col)) {
6090 S.Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_element_not_in_bounds)
6091 << /*col*/ 1 << /*zero-based*/ 0 << SourceRange(CompLoc);
6092 HasIndexingError = true;
6093 }
6094 } else {
6095 // 1-based [1..4]
6096 if (!isOneBasedIndex(Row)) {
6097 S.Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_element_not_in_bounds)
6098 << /*row*/ 0 << /*one-based*/ 1 << SourceRange(CompLoc);
6099 HasIndexingError = true;
6100 }
6101 if (!isOneBasedIndex(Col)) {
6102 S.Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_element_not_in_bounds)
6103 << /*col*/ 1 << /*one-based*/ 1 << SourceRange(CompLoc);
6104 HasIndexingError = true;
6105 }
6106 // Convert to 0-based after range checking.
6107 --Row;
6108 --Col;
6109 }
6110
6111 if (HasIndexingError)
6112 return QualType();
6113
6114 // Note: matrix swizzle index is hard coded. That means Row and Col can
6115 // potentially be larger than Rows and Cols if matrix size is less than
6116 // the max index size.
6117 bool HasBoundsError = false;
6118 if (Row >= Rows) {
6119 Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_index_out_of_bounds)
6120 << /*Row*/ 0 << Row << Rows << SourceRange(CompLoc);
6121 HasBoundsError = true;
6122 }
6123 if (Col >= Cols) {
6124 Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_index_out_of_bounds)
6125 << /*Col*/ 1 << Col << Cols << SourceRange(CompLoc);
6126 HasBoundsError = true;
6127 }
6128 if (HasBoundsError)
6129 return QualType();
6130
6131 unsigned FlatIndex = Row * Cols + Col;
6132 if (Seen[FlatIndex])
6133 HasRepeated = true;
6134 Seen[FlatIndex] = true;
6135 ++NumComponents;
6136 }
6137 if (NumComponents == 0 || NumComponents > 4) {
6138 S.Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_swizzle_invalid_length)
6139 << NumComponents << SourceRange(CompLoc);
6140 return QualType();
6141 }
6142
6143 QualType ElemTy = MT->getElementType();
6144 if (NumComponents == 1)
6145 return ElemTy;
6146 QualType VT = S.Context.getExtVectorType(VectorType: ElemTy, NumElts: NumComponents);
6147 if (HasRepeated)
6148 VK = VK_PRValue;
6149
6150 for (Sema::ExtVectorDeclsType::iterator
6151 I = S.ExtVectorDecls.begin(source: S.getExternalSource()),
6152 E = S.ExtVectorDecls.end();
6153 I != E; ++I) {
6154 if ((*I)->getUnderlyingType() == VT)
6155 return S.Context.getTypedefType(Keyword: ElaboratedTypeKeyword::None,
6156 /*Qualifier=*/std::nullopt, Decl: *I);
6157 }
6158
6159 return VT;
6160}
6161
6162bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) {
6163 // If initializing a local resource, track the resource binding it is using
6164 if (VDecl->getType()->isHLSLResourceRecord() && !VDecl->hasGlobalStorage())
6165 trackLocalResource(VD: VDecl, E: Init);
6166
6167 const HLSLVkConstantIdAttr *ConstIdAttr =
6168 VDecl->getAttr<HLSLVkConstantIdAttr>();
6169 if (!ConstIdAttr)
6170 return true;
6171
6172 ASTContext &Context = SemaRef.getASTContext();
6173
6174 APValue InitValue;
6175 if (!Init->isCXX11ConstantExpr(Ctx: Context, Result: &InitValue)) {
6176 Diag(Loc: VDecl->getLocation(), DiagID: diag::err_specialization_const);
6177 VDecl->setInvalidDecl();
6178 return false;
6179 }
6180
6181 Builtin::ID BID =
6182 getSpecConstBuiltinId(Type: VDecl->getType()->getUnqualifiedDesugaredType());
6183
6184 // Argument 1: The ID from the attribute
6185 int ConstantID = ConstIdAttr->getId();
6186 llvm::APInt IDVal(Context.getIntWidth(T: Context.IntTy), ConstantID);
6187 Expr *IdExpr = IntegerLiteral::Create(C: Context, V: IDVal, type: Context.IntTy,
6188 l: ConstIdAttr->getLocation());
6189
6190 SmallVector<Expr *, 2> Args = {IdExpr, Init};
6191 Expr *C = SemaRef.BuildBuiltinCallExpr(Loc: Init->getExprLoc(), Id: BID, CallArgs: Args);
6192 if (C->getType()->getCanonicalTypeUnqualified() !=
6193 VDecl->getType()->getCanonicalTypeUnqualified()) {
6194 C = SemaRef
6195 .BuildCStyleCastExpr(LParenLoc: SourceLocation(),
6196 Ty: Context.getTrivialTypeSourceInfo(
6197 T: Init->getType(), Loc: Init->getExprLoc()),
6198 RParenLoc: SourceLocation(), Op: C)
6199 .get();
6200 }
6201 Init = C;
6202 return true;
6203}
6204
6205QualType SemaHLSL::ActOnTemplateShorthand(TemplateDecl *Template,
6206 SourceLocation NameLoc) {
6207 if (!Template)
6208 return QualType();
6209
6210 DeclContext *DC = Template->getDeclContext();
6211 if (!DC->isNamespace() || !cast<NamespaceDecl>(Val: DC)->getIdentifier() ||
6212 cast<NamespaceDecl>(Val: DC)->getName() != "hlsl")
6213 return QualType();
6214
6215 TemplateParameterList *Params = Template->getTemplateParameters();
6216 if (!Params || Params->size() != 1)
6217 return QualType();
6218
6219 if (!Template->isImplicit())
6220 return QualType();
6221
6222 // We manually extract default arguments here instead of letting
6223 // CheckTemplateIdType handle it. This ensures that for resource types that
6224 // lack a default argument (like Buffer), we return a null QualType, which
6225 // triggers the "requires template arguments" error rather than a less
6226 // descriptive "too few template arguments" error.
6227 TemplateArgumentListInfo TemplateArgs(NameLoc, NameLoc);
6228 for (NamedDecl *P : *Params) {
6229 if (auto *TTP = dyn_cast<TemplateTypeParmDecl>(Val: P)) {
6230 if (TTP->hasDefaultArgument()) {
6231 TemplateArgs.addArgument(Loc: TTP->getDefaultArgument());
6232 continue;
6233 }
6234 } else if (auto *NTTP = dyn_cast<NonTypeTemplateParmDecl>(Val: P)) {
6235 if (NTTP->hasDefaultArgument()) {
6236 TemplateArgs.addArgument(Loc: NTTP->getDefaultArgument());
6237 continue;
6238 }
6239 } else if (auto *TTPD = dyn_cast<TemplateTemplateParmDecl>(Val: P)) {
6240 if (TTPD->hasDefaultArgument()) {
6241 TemplateArgs.addArgument(Loc: TTPD->getDefaultArgument());
6242 continue;
6243 }
6244 }
6245 return QualType();
6246 }
6247
6248 return SemaRef.CheckTemplateIdType(
6249 Keyword: ElaboratedTypeKeyword::None, Template: TemplateName(Template), TemplateLoc: NameLoc,
6250 TemplateArgs, Scope: nullptr, /*ForNestedNameSpecifier=*/false);
6251}
6252