1//===- SemaHLSL.cpp - Semantic Analysis for HLSL constructs ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This implements Semantic Analysis for HLSL constructs.
9//===----------------------------------------------------------------------===//
10
11#include "clang/Sema/SemaHLSL.h"
12#include "clang/AST/ASTConsumer.h"
13#include "clang/AST/ASTContext.h"
14#include "clang/AST/Attr.h"
15#include "clang/AST/Decl.h"
16#include "clang/AST/DeclBase.h"
17#include "clang/AST/DeclCXX.h"
18#include "clang/AST/DeclarationName.h"
19#include "clang/AST/DynamicRecursiveASTVisitor.h"
20#include "clang/AST/Expr.h"
21#include "clang/AST/HLSLResource.h"
22#include "clang/AST/Type.h"
23#include "clang/AST/TypeBase.h"
24#include "clang/AST/TypeLoc.h"
25#include "clang/Basic/Builtins.h"
26#include "clang/Basic/DiagnosticSema.h"
27#include "clang/Basic/IdentifierTable.h"
28#include "clang/Basic/LLVM.h"
29#include "clang/Basic/SourceLocation.h"
30#include "clang/Basic/Specifiers.h"
31#include "clang/Basic/TargetInfo.h"
32#include "clang/Sema/Initialization.h"
33#include "clang/Sema/Lookup.h"
34#include "clang/Sema/ParsedAttr.h"
35#include "clang/Sema/Sema.h"
36#include "clang/Sema/Template.h"
37#include "llvm/ADT/ArrayRef.h"
38#include "llvm/ADT/STLExtras.h"
39#include "llvm/ADT/SmallVector.h"
40#include "llvm/ADT/StringExtras.h"
41#include "llvm/ADT/StringRef.h"
42#include "llvm/ADT/Twine.h"
43#include "llvm/Frontend/HLSL/HLSLBinding.h"
44#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
45#include "llvm/Support/Casting.h"
46#include "llvm/Support/DXILABI.h"
47#include "llvm/Support/ErrorHandling.h"
48#include "llvm/Support/FormatVariadic.h"
49#include "llvm/TargetParser/Triple.h"
50#include <cmath>
51#include <cstddef>
52#include <iterator>
53#include <utility>
54
55using namespace clang;
56using namespace clang::hlsl;
57using RegisterType = HLSLResourceBindingAttr::RegisterType;
58
59static CXXRecordDecl *createHostLayoutStruct(Sema &S,
60 CXXRecordDecl *StructDecl);
61
62static RegisterType getRegisterType(ResourceClass RC) {
63 switch (RC) {
64 case ResourceClass::SRV:
65 return RegisterType::SRV;
66 case ResourceClass::UAV:
67 return RegisterType::UAV;
68 case ResourceClass::CBuffer:
69 return RegisterType::CBuffer;
70 case ResourceClass::Sampler:
71 return RegisterType::Sampler;
72 }
73 llvm_unreachable("unexpected ResourceClass value");
74}
75
76static RegisterType getRegisterType(const HLSLAttributedResourceType *ResTy) {
77 return getRegisterType(RC: ResTy->getAttrs().ResourceClass);
78}
79
80// Converts the first letter of string Slot to RegisterType.
81// Returns false if the letter does not correspond to a valid register type.
82static bool convertToRegisterType(StringRef Slot, RegisterType *RT) {
83 assert(RT != nullptr);
84 switch (Slot[0]) {
85 case 't':
86 case 'T':
87 *RT = RegisterType::SRV;
88 return true;
89 case 'u':
90 case 'U':
91 *RT = RegisterType::UAV;
92 return true;
93 case 'b':
94 case 'B':
95 *RT = RegisterType::CBuffer;
96 return true;
97 case 's':
98 case 'S':
99 *RT = RegisterType::Sampler;
100 return true;
101 case 'c':
102 case 'C':
103 *RT = RegisterType::C;
104 return true;
105 case 'i':
106 case 'I':
107 *RT = RegisterType::I;
108 return true;
109 default:
110 return false;
111 }
112}
113
114static ResourceClass getResourceClass(RegisterType RT) {
115 switch (RT) {
116 case RegisterType::SRV:
117 return ResourceClass::SRV;
118 case RegisterType::UAV:
119 return ResourceClass::UAV;
120 case RegisterType::CBuffer:
121 return ResourceClass::CBuffer;
122 case RegisterType::Sampler:
123 return ResourceClass::Sampler;
124 case RegisterType::C:
125 case RegisterType::I:
126 // Deliberately falling through to the unreachable below.
127 break;
128 }
129 llvm_unreachable("unexpected RegisterType value");
130}
131
132static Builtin::ID getSpecConstBuiltinId(const Type *Type) {
133 const auto *BT = dyn_cast<BuiltinType>(Val: Type);
134 if (!BT) {
135 if (!Type->isEnumeralType())
136 return Builtin::NotBuiltin;
137 return Builtin::BI__builtin_get_spirv_spec_constant_int;
138 }
139
140 switch (BT->getKind()) {
141 case BuiltinType::Bool:
142 return Builtin::BI__builtin_get_spirv_spec_constant_bool;
143 case BuiltinType::Short:
144 return Builtin::BI__builtin_get_spirv_spec_constant_short;
145 case BuiltinType::Int:
146 return Builtin::BI__builtin_get_spirv_spec_constant_int;
147 case BuiltinType::LongLong:
148 return Builtin::BI__builtin_get_spirv_spec_constant_longlong;
149 case BuiltinType::UShort:
150 return Builtin::BI__builtin_get_spirv_spec_constant_ushort;
151 case BuiltinType::UInt:
152 return Builtin::BI__builtin_get_spirv_spec_constant_uint;
153 case BuiltinType::ULongLong:
154 return Builtin::BI__builtin_get_spirv_spec_constant_ulonglong;
155 case BuiltinType::Half:
156 return Builtin::BI__builtin_get_spirv_spec_constant_half;
157 case BuiltinType::Float:
158 return Builtin::BI__builtin_get_spirv_spec_constant_float;
159 case BuiltinType::Double:
160 return Builtin::BI__builtin_get_spirv_spec_constant_double;
161 default:
162 return Builtin::NotBuiltin;
163 }
164}
165
// Appends a new (VarDecl, ResourceClass) binding record and returns it.
// Preconditions (asserted): no record for this exact pair exists yet, and all
// records for a given VarDecl are appended consecutively, so any existing
// record for VD must be the last element of BindingsList.
DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD,
                                                      ResourceClass ResClass) {
  assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
         "DeclBindingInfo already added");
  assert(!hasBindingInfoForDecl(VD) || BindingsList.back().Decl == VD);
  // VarDecl may have multiple entries for different resource classes.
  // DeclToBindingListIndex stores the index of the first binding we saw
  // for this decl. If there are any additional ones then that index
  // shouldn't be updated.
  DeclToBindingListIndex.try_emplace(VD, BindingsList.size());
  return &BindingsList.emplace_back(VD, ResClass);
}
178
179DeclBindingInfo *ResourceBindings::getDeclBindingInfo(const VarDecl *VD,
180 ResourceClass ResClass) {
181 auto Entry = DeclToBindingListIndex.find(Val: VD);
182 if (Entry != DeclToBindingListIndex.end()) {
183 for (unsigned Index = Entry->getSecond();
184 Index < BindingsList.size() && BindingsList[Index].Decl == VD;
185 ++Index) {
186 if (BindingsList[Index].ResClass == ResClass)
187 return &BindingsList[Index];
188 }
189 }
190 return nullptr;
191}
192
193bool ResourceBindings::hasBindingInfoForDecl(const VarDecl *VD) const {
194 return DeclToBindingListIndex.contains(Val: VD);
195}
196
// HLSL-specific Sema subsystem; all state beyond the Sema reference lives in
// members declared in SemaHLSL.h.
SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}
198
199Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
200 SourceLocation KwLoc, IdentifierInfo *Ident,
201 SourceLocation IdentLoc,
202 SourceLocation LBrace) {
203 // For anonymous namespace, take the location of the left brace.
204 DeclContext *LexicalParent = SemaRef.getCurLexicalContext();
205 HLSLBufferDecl *Result = HLSLBufferDecl::Create(
206 C&: getASTContext(), LexicalParent, CBuffer, KwLoc, ID: Ident, IDLoc: IdentLoc, LBrace);
207
208 // if CBuffer is false, then it's a TBuffer
209 auto RC = CBuffer ? llvm::hlsl::ResourceClass::CBuffer
210 : llvm::hlsl::ResourceClass::SRV;
211 Result->addAttr(A: HLSLResourceClassAttr::CreateImplicit(Ctx&: getASTContext(), ResourceClass: RC));
212
213 SemaRef.PushOnScopeChains(D: Result, S: BufferScope);
214 SemaRef.PushDeclContext(S: BufferScope, DC: Result);
215
216 return Result;
217}
218
219static unsigned calculateLegacyCbufferFieldAlign(const ASTContext &Context,
220 QualType T) {
221 // Arrays, Matrices, and Structs are always aligned to new buffer rows
222 if (T->isArrayType() || T->isStructureType() || T->isConstantMatrixType())
223 return 16;
224
225 // Vectors are aligned to the type they contain
226 if (const VectorType *VT = T->getAs<VectorType>())
227 return calculateLegacyCbufferFieldAlign(Context, T: VT->getElementType());
228
229 assert(Context.getTypeSize(T) <= 64 &&
230 "Scalar bit widths larger than 64 not supported");
231
232 // Scalar types are aligned to their byte width
233 return Context.getTypeSize(T) / 8;
234}
235
// Calculate the size of a legacy cbuffer type in bytes based on
// https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-packing-rules
static unsigned calculateLegacyCbufferSize(const ASTContext &Context,
                                           QualType T) {
  // Legacy cbuffer rows are 16 bytes (four 32-bit components).
  constexpr unsigned CBufferAlign = 16;
  if (const auto *RD = T->getAsRecordDecl()) {
    // Struct: lay out fields sequentially, honoring per-field alignment and
    // the rule that a field may not straddle a 16-byte row boundary.
    unsigned Size = 0;
    for (const FieldDecl *Field : RD->fields()) {
      QualType Ty = Field->getType();
      unsigned FieldSize = calculateLegacyCbufferSize(Context, Ty);
      unsigned FieldAlign = calculateLegacyCbufferFieldAlign(Context, Ty);

      // If the field crosses the row boundary after alignment it drops to the
      // next row
      unsigned AlignSize = llvm::alignTo(Size, FieldAlign);
      if ((AlignSize % CBufferAlign) + FieldSize > CBufferAlign) {
        FieldAlign = CBufferAlign;
      }

      Size = llvm::alignTo(Size, FieldAlign);
      Size += FieldSize;
    }
    return Size;
  }

  if (const ConstantArrayType *AT = Context.getAsConstantArrayType(T)) {
    unsigned ElementCount = AT->getSize().getZExtValue();
    if (ElementCount == 0)
      return 0;

    // Each array element starts on a fresh 16-byte row, but the final
    // element contributes only its own size (no trailing row padding).
    unsigned ElementSize =
        calculateLegacyCbufferSize(Context, AT->getElementType());
    unsigned AlignedElementSize = llvm::alignTo(ElementSize, CBufferAlign);
    return AlignedElementSize * (ElementCount - 1) + ElementSize;
  }

  // Vectors pack their elements contiguously with no padding.
  if (const VectorType *VT = T->getAs<VectorType>()) {
    unsigned ElementCount = VT->getNumElements();
    unsigned ElementSize =
        calculateLegacyCbufferSize(Context, VT->getElementType());
    return ElementSize * ElementCount;
  }

  // Scalars: size in bytes.
  return Context.getTypeSize(T) / 8;
}
281
282// Validate packoffset:
283// - if packoffset it used it must be set on all declarations inside the buffer
284// - packoffset ranges must not overlap
285static void validatePackoffset(Sema &S, HLSLBufferDecl *BufDecl) {
286 llvm::SmallVector<std::pair<VarDecl *, HLSLPackOffsetAttr *>> PackOffsetVec;
287
288 // Make sure the packoffset annotations are either on all declarations
289 // or on none.
290 bool HasPackOffset = false;
291 bool HasNonPackOffset = false;
292 for (auto *Field : BufDecl->buffer_decls()) {
293 VarDecl *Var = dyn_cast<VarDecl>(Val: Field);
294 if (!Var)
295 continue;
296 if (Field->hasAttr<HLSLPackOffsetAttr>()) {
297 PackOffsetVec.emplace_back(Args&: Var, Args: Field->getAttr<HLSLPackOffsetAttr>());
298 HasPackOffset = true;
299 } else {
300 HasNonPackOffset = true;
301 }
302 }
303
304 if (!HasPackOffset)
305 return;
306
307 if (HasNonPackOffset)
308 S.Diag(Loc: BufDecl->getLocation(), DiagID: diag::warn_hlsl_packoffset_mix);
309
310 // Make sure there is no overlap in packoffset - sort PackOffsetVec by offset
311 // and compare adjacent values.
312 bool IsValid = true;
313 ASTContext &Context = S.getASTContext();
314 std::sort(first: PackOffsetVec.begin(), last: PackOffsetVec.end(),
315 comp: [](const std::pair<VarDecl *, HLSLPackOffsetAttr *> &LHS,
316 const std::pair<VarDecl *, HLSLPackOffsetAttr *> &RHS) {
317 return LHS.second->getOffsetInBytes() <
318 RHS.second->getOffsetInBytes();
319 });
320 for (unsigned i = 0; i < PackOffsetVec.size() - 1; i++) {
321 VarDecl *Var = PackOffsetVec[i].first;
322 HLSLPackOffsetAttr *Attr = PackOffsetVec[i].second;
323 unsigned Size = calculateLegacyCbufferSize(Context, T: Var->getType());
324 unsigned Begin = Attr->getOffsetInBytes();
325 unsigned End = Begin + Size;
326 unsigned NextBegin = PackOffsetVec[i + 1].second->getOffsetInBytes();
327 if (End > NextBegin) {
328 VarDecl *NextVar = PackOffsetVec[i + 1].first;
329 S.Diag(Loc: NextVar->getLocation(), DiagID: diag::err_hlsl_packoffset_overlap)
330 << NextVar << Var;
331 IsValid = false;
332 }
333 }
334 BufDecl->setHasValidPackoffset(IsValid);
335}
336
337// Returns true if the array has a zero size = if any of the dimensions is 0
338static bool isZeroSizedArray(const ConstantArrayType *CAT) {
339 while (CAT && !CAT->isZeroSize())
340 CAT = dyn_cast<ConstantArrayType>(
341 Val: CAT->getElementType()->getUnqualifiedDesugaredType());
342 return CAT != nullptr;
343}
344
345static bool isResourceRecordTypeOrArrayOf(VarDecl *VD) {
346 const Type *Ty = VD->getType().getTypePtr();
347 return Ty->isHLSLResourceRecord() || Ty->isHLSLResourceRecordArray();
348}
349
350static const HLSLAttributedResourceType *
351getResourceArrayHandleType(VarDecl *VD) {
352 assert(VD->getType()->isHLSLResourceRecordArray() &&
353 "expected array of resource records");
354 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
355 while (const ArrayType *AT = dyn_cast<ArrayType>(Val: Ty))
356 Ty = AT->getArrayElementTypeNoTypeQual()->getUnqualifiedDesugaredType();
357 return HLSLAttributedResourceType::findHandleTypeOnResource(RT: Ty);
358}
359
360// Returns true if the type is a leaf element type that is not valid to be
361// included in HLSL Buffer, such as a resource class, empty struct, zero-sized
362// array, or a builtin intangible type. Returns false it is a valid leaf element
363// type or if it is a record type that needs to be inspected further.
364static bool isInvalidConstantBufferLeafElementType(const Type *Ty) {
365 Ty = Ty->getUnqualifiedDesugaredType();
366 if (Ty->isHLSLResourceRecord() || Ty->isHLSLResourceRecordArray())
367 return true;
368 if (const auto *RD = Ty->getAsCXXRecordDecl())
369 return RD->isEmpty();
370 if (Ty->isConstantArrayType() &&
371 isZeroSizedArray(CAT: cast<ConstantArrayType>(Val: Ty)))
372 return true;
373 if (Ty->isHLSLBuiltinIntangibleType() || Ty->isHLSLAttributedResourceType())
374 return true;
375 return false;
376}
377
378// Returns true if the struct contains at least one element that prevents it
379// from being included inside HLSL Buffer as is, such as an intangible type,
380// empty struct, or zero-sized array. If it does, a new implicit layout struct
381// needs to be created for HLSL Buffer use that will exclude these unwanted
382// declarations (see createHostLayoutStruct function).
383static bool requiresImplicitBufferLayoutStructure(const CXXRecordDecl *RD) {
384 if (RD->isHLSLIntangible() || RD->isEmpty())
385 return true;
386 // check fields
387 for (const FieldDecl *Field : RD->fields()) {
388 QualType Ty = Field->getType();
389 if (isInvalidConstantBufferLeafElementType(Ty: Ty.getTypePtr()))
390 return true;
391 if (const auto *RD = Ty->getAsCXXRecordDecl();
392 RD && requiresImplicitBufferLayoutStructure(RD))
393 return true;
394 }
395 // check bases
396 for (const CXXBaseSpecifier &Base : RD->bases())
397 if (requiresImplicitBufferLayoutStructure(
398 RD: Base.getType()->castAsCXXRecordDecl()))
399 return true;
400 return false;
401}
402
403static CXXRecordDecl *findRecordDeclInContext(IdentifierInfo *II,
404 DeclContext *DC) {
405 CXXRecordDecl *RD = nullptr;
406 for (NamedDecl *Decl :
407 DC->getNonTransparentContext()->lookup(Name: DeclarationName(II))) {
408 if (CXXRecordDecl *FoundRD = dyn_cast<CXXRecordDecl>(Val: Decl)) {
409 assert(RD == nullptr &&
410 "there should be at most 1 record by a given name in a scope");
411 RD = FoundRD;
412 }
413 }
414 return RD;
415}
416
417// Creates a name for buffer layout struct using the provide name base.
418// If the name must be unique (not previously defined), a suffix is added
419// until a unique name is found.
420static IdentifierInfo *getHostLayoutStructName(Sema &S, NamedDecl *BaseDecl,
421 bool MustBeUnique) {
422 ASTContext &AST = S.getASTContext();
423
424 IdentifierInfo *NameBaseII = BaseDecl->getIdentifier();
425 llvm::SmallString<64> Name("__cblayout_");
426 if (NameBaseII) {
427 Name.append(RHS: NameBaseII->getName());
428 } else {
429 // anonymous struct
430 Name.append(RHS: "anon");
431 MustBeUnique = true;
432 }
433
434 size_t NameLength = Name.size();
435 IdentifierInfo *II = &AST.Idents.get(Name, TokenCode: tok::TokenKind::identifier);
436 if (!MustBeUnique)
437 return II;
438
439 unsigned suffix = 0;
440 while (true) {
441 if (suffix != 0) {
442 Name.append(RHS: "_");
443 Name.append(RHS: llvm::Twine(suffix).str());
444 II = &AST.Idents.get(Name, TokenCode: tok::TokenKind::identifier);
445 }
446 if (!findRecordDeclInContext(II, DC: BaseDecl->getDeclContext()))
447 return II;
448 // declaration with that name already exists - increment suffix and try
449 // again until unique name is found
450 suffix++;
451 Name.truncate(N: NameLength);
452 };
453}
454
455// Creates a field declaration of given name and type for HLSL buffer layout
456// struct. Returns nullptr if the type cannot be use in HLSL Buffer layout.
457static FieldDecl *createFieldForHostLayoutStruct(Sema &S, const Type *Ty,
458 IdentifierInfo *II,
459 CXXRecordDecl *LayoutStruct) {
460 if (isInvalidConstantBufferLeafElementType(Ty))
461 return nullptr;
462
463 if (auto *RD = Ty->getAsCXXRecordDecl()) {
464 if (requiresImplicitBufferLayoutStructure(RD)) {
465 RD = createHostLayoutStruct(S, StructDecl: RD);
466 if (!RD)
467 return nullptr;
468 Ty = S.Context.getCanonicalTagType(TD: RD)->getTypePtr();
469 }
470 }
471
472 QualType QT = QualType(Ty, 0);
473 ASTContext &AST = S.getASTContext();
474 TypeSourceInfo *TSI = AST.getTrivialTypeSourceInfo(T: QT, Loc: SourceLocation());
475 auto *Field = FieldDecl::Create(C: AST, DC: LayoutStruct, StartLoc: SourceLocation(),
476 IdLoc: SourceLocation(), Id: II, T: QT, TInfo: TSI, BW: nullptr, Mutable: false,
477 InitStyle: InClassInitStyle::ICIS_NoInit);
478 Field->setAccess(AccessSpecifier::AS_public);
479 return Field;
480}
481
// Creates host layout struct for a struct included in HLSL Buffer.
// The layout struct will include only fields that are allowed in HLSL buffer.
// These fields will be filtered out:
// - resource classes
// - empty structs
// - zero-sized arrays
// Returns nullptr if the resulting layout struct would be empty.
static CXXRecordDecl *createHostLayoutStruct(Sema &S,
                                             CXXRecordDecl *StructDecl) {
  assert(requiresImplicitBufferLayoutStructure(StructDecl) &&
         "struct is already HLSL buffer compatible");

  ASTContext &AST = S.getASTContext();
  DeclContext *DC = StructDecl->getDeclContext();
  IdentifierInfo *II = getHostLayoutStructName(S, StructDecl, false);

  // Reuse the layout struct if it already exists (the name is deterministic,
  // so a previous use of the same struct created the same layout record).
  if (CXXRecordDecl *RD = findRecordDeclInContext(II, DC))
    return RD;

  CXXRecordDecl *LS =
      CXXRecordDecl::Create(AST, TagDecl::TagKind::Struct, DC, SourceLocation(),
                            SourceLocation(), II);
  LS->setImplicit(true);
  // Packed so host layout matches the legacy cbuffer packing, not natural
  // C++ layout.
  LS->addAttr(PackedAttr::CreateImplicit(AST));
  LS->startDefinition();

  // Copy base struct, creating an HLSL-buffer-compatible version if needed.
  // If the base reduces to nothing (createHostLayoutStruct returns nullptr),
  // the base is dropped entirely.
  if (unsigned NumBases = StructDecl->getNumBases()) {
    assert(NumBases == 1 && "HLSL supports only one base type");
    (void)NumBases;
    CXXBaseSpecifier Base = *StructDecl->bases_begin();
    CXXRecordDecl *BaseDecl = Base.getType()->castAsCXXRecordDecl();
    if (requiresImplicitBufferLayoutStructure(BaseDecl)) {
      BaseDecl = createHostLayoutStruct(S, BaseDecl);
      if (BaseDecl) {
        TypeSourceInfo *TSI =
            AST.getTrivialTypeSourceInfo(AST.getCanonicalTagType(BaseDecl));
        Base = CXXBaseSpecifier(SourceRange(), false, StructDecl->isClass(),
                                AS_none, TSI, SourceLocation());
      }
    }
    if (BaseDecl) {
      const CXXBaseSpecifier *BasesArray[1] = {&Base};
      LS->setBases(BasesArray, 1);
    }
  }

  // Filter struct fields: only those representable in a host layout survive.
  for (const FieldDecl *FD : StructDecl->fields()) {
    const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
    if (FieldDecl *NewFD =
            createFieldForHostLayoutStruct(S, Ty, FD->getIdentifier(), LS))
      LS->addDecl(NewFD);
  }
  LS->completeDefinition();

  // An entirely-filtered-out struct yields no layout struct at all.
  if (LS->field_empty() && LS->getNumBases() == 0)
    return nullptr;

  DC->addDecl(LS);
  return LS;
}
545
// Creates host layout struct for HLSL Buffer. The struct will include only
// fields of types that are allowed in HLSL buffer and it will filter out:
// - static or groupshared variable declarations
// - resource classes
// - empty structs
// - zero-sized arrays
// - non-variable declarations
// The layout struct will be added to the HLSLBufferDecl declarations.
void createHostLayoutStructForBuffer(Sema &S, HLSLBufferDecl *BufDecl) {
  ASTContext &AST = S.getASTContext();
  // Buffer layout names are always uniqued (MustBeUnique = true).
  IdentifierInfo *II = getHostLayoutStructName(S, BufDecl, true);

  CXXRecordDecl *LS =
      CXXRecordDecl::Create(AST, TagDecl::TagKind::Struct, BufDecl,
                            SourceLocation(), SourceLocation(), II);
  // Packed so the host layout mirrors cbuffer packing rather than natural
  // C++ layout.
  LS->addAttr(PackedAttr::CreateImplicit(AST));
  LS->setImplicit(true);
  LS->startDefinition();

  for (Decl *D : BufDecl->buffer_decls()) {
    // Skip non-variables, statics, and groupshared variables - they do not
    // occupy constant buffer storage.
    VarDecl *VD = dyn_cast<VarDecl>(D);
    if (!VD || VD->getStorageClass() == SC_Static ||
        VD->getType().getAddressSpace() == LangAS::hlsl_groupshared)
      continue;
    const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
    if (FieldDecl *FD =
            createFieldForHostLayoutStruct(S, Ty, VD->getIdentifier(), LS)) {
      // add the field decl to the layout struct
      LS->addDecl(FD);
      // update address space of the original decl to hlsl_constant
      QualType NewTy =
          AST.getAddrSpaceQualType(VD->getType(), LangAS::hlsl_constant);
      VD->setType(NewTy);
    }
  }
  LS->completeDefinition();
  BufDecl->addLayoutStruct(LS);
}
584
585static void addImplicitBindingAttrToDecl(Sema &S, Decl *D, RegisterType RT,
586 uint32_t ImplicitBindingOrderID) {
587 auto *Attr =
588 HLSLResourceBindingAttr::CreateImplicit(Ctx&: S.getASTContext(), Slot: "", Space: "0", Range: {});
589 Attr->setBinding(RT, SlotNum: std::nullopt, SpaceNum: 0);
590 Attr->setImplicitBindingOrderID(ImplicitBindingOrderID);
591 D->addAttr(A: Attr);
592}
593
// Handle end of cbuffer/tbuffer declaration: validates packoffset usage,
// builds the host layout struct, and assigns an implicit binding order id
// when the buffer has no explicit register binding.
void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) {
  auto *BufDecl = cast<HLSLBufferDecl>(Dcl);
  BufDecl->setRBraceLoc(RBrace);

  validatePackoffset(SemaRef, BufDecl);

  createHostLayoutStructForBuffer(SemaRef, BufDecl);

  // Handle implicit binding if needed.
  ResourceBindingAttrs ResourceAttrs(Dcl);
  if (!ResourceAttrs.isExplicit()) {
    SemaRef.Diag(Dcl->getLocation(), diag::warn_hlsl_implicit_binding);
    // Use HLSLResourceBindingAttr to transfer implicit binding order_ID
    // to codegen. If it does not exist, create an implicit attribute.
    uint32_t OrderID = getNextImplicitBindingOrderID();
    if (ResourceAttrs.hasBinding())
      ResourceAttrs.setImplicitOrderID(OrderID);
    else
      // cbuffers bind to b registers (CBuffer); tbuffers bind to t (SRV).
      addImplicitBindingAttrToDecl(SemaRef, BufDecl,
                                   BufDecl->isCBuffer() ? RegisterType::CBuffer
                                                        : RegisterType::SRV,
                                   OrderID);
  }

  SemaRef.PopDeclContext();
}
621
622HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
623 const AttributeCommonInfo &AL,
624 int X, int Y, int Z) {
625 if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
626 if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
627 Diag(Loc: NT->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL;
628 Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute);
629 }
630 return nullptr;
631 }
632 return ::new (getASTContext())
633 HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z);
634}
635
636HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D,
637 const AttributeCommonInfo &AL,
638 int Min, int Max, int Preferred,
639 int SpelledArgsCount) {
640 if (HLSLWaveSizeAttr *WS = D->getAttr<HLSLWaveSizeAttr>()) {
641 if (WS->getMin() != Min || WS->getMax() != Max ||
642 WS->getPreferred() != Preferred ||
643 WS->getSpelledArgsCount() != SpelledArgsCount) {
644 Diag(Loc: WS->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL;
645 Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute);
646 }
647 return nullptr;
648 }
649 HLSLWaveSizeAttr *Result = ::new (getASTContext())
650 HLSLWaveSizeAttr(getASTContext(), AL, Min, Max, Preferred);
651 Result->setSpelledArgsCount(SpelledArgsCount);
652 return Result;
653}
654
655HLSLVkConstantIdAttr *
656SemaHLSL::mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL,
657 int Id) {
658
659 auto &TargetInfo = getASTContext().getTargetInfo();
660 if (TargetInfo.getTriple().getArch() != llvm::Triple::spirv) {
661 Diag(Loc: AL.getLoc(), DiagID: diag::warn_attribute_ignored) << AL;
662 return nullptr;
663 }
664
665 auto *VD = cast<VarDecl>(Val: D);
666
667 if (getSpecConstBuiltinId(Type: VD->getType()->getUnqualifiedDesugaredType()) ==
668 Builtin::NotBuiltin) {
669 Diag(Loc: VD->getLocation(), DiagID: diag::err_specialization_const);
670 return nullptr;
671 }
672
673 if (!VD->getType().isConstQualified()) {
674 Diag(Loc: VD->getLocation(), DiagID: diag::err_specialization_const);
675 return nullptr;
676 }
677
678 if (HLSLVkConstantIdAttr *CI = D->getAttr<HLSLVkConstantIdAttr>()) {
679 if (CI->getId() != Id) {
680 Diag(Loc: CI->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL;
681 Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute);
682 }
683 return nullptr;
684 }
685
686 HLSLVkConstantIdAttr *Result =
687 ::new (getASTContext()) HLSLVkConstantIdAttr(getASTContext(), AL, Id);
688 return Result;
689}
690
691HLSLShaderAttr *
692SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
693 llvm::Triple::EnvironmentType ShaderType) {
694 if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
695 if (NT->getType() != ShaderType) {
696 Diag(Loc: NT->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL;
697 Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute);
698 }
699 return nullptr;
700 }
701 return HLSLShaderAttr::Create(Ctx&: getASTContext(), Type: ShaderType, CommonInfo: AL);
702}
703
704HLSLParamModifierAttr *
705SemaHLSL::mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
706 HLSLParamModifierAttr::Spelling Spelling) {
707 // We can only merge an `in` attribute with an `out` attribute. All other
708 // combinations of duplicated attributes are ill-formed.
709 if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
710 if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
711 (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
712 D->dropAttr<HLSLParamModifierAttr>();
713 SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
714 return HLSLParamModifierAttr::Create(
715 Ctx&: getASTContext(), /*MergedSpelling=*/true, Range: AdjustedRange,
716 S: HLSLParamModifierAttr::Keyword_inout);
717 }
718 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_duplicate_parameter_modifier) << AL;
719 Diag(Loc: PA->getLocation(), DiagID: diag::note_conflicting_attribute);
720 return nullptr;
721 }
722 return HLSLParamModifierAttr::Create(Ctx&: getASTContext(), CommonInfo: AL);
723}
724
// Post-processing for a top-level function. Only acts on the function whose
// name matches the configured HLSL entry point: applies a root-signature
// override if one was supplied, and reconciles the shader attribute with the
// target triple's environment.
void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
  auto &TargetInfo = getASTContext().getTargetInfo();

  if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
    return;

  // If we have specified a root signature to override the entry function then
  // attach it now
  HLSLRootSignatureDecl *SignatureDecl =
      lookupRootSignatureOverrideDecl(FD->getDeclContext());
  if (SignatureDecl) {
    // The override replaces any root signature spelled on the function.
    FD->dropAttr<RootSignatureAttr>();
    // We could look up the SourceRange of the macro here as well
    AttributeCommonInfo AL(RootSigOverrideIdent, AttributeScopeInfo(),
                           SourceRange(), ParsedAttr::Form::Microsoft());
    FD->addAttr(::new (getASTContext()) RootSignatureAttr(
        getASTContext(), AL, RootSigOverrideIdent, SignatureDecl));
  }

  llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment();
  if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) {
    if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
      // The entry point is already annotated - check that it matches the
      // triple.
      if (Shader->getType() != Env) {
        Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
            << Shader;
        FD->setInvalidDecl();
      }
    } else {
      // Implicitly add the shader attribute if the entry function isn't
      // explicitly annotated.
      FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), Env,
                                                 FD->getBeginLoc()));
    }
  } else {
    // Library and unknown environments leave the entry unannotated; anything
    // else reaching here is a frontend invariant violation.
    switch (Env) {
    case llvm::Triple::UnknownEnvironment:
    case llvm::Triple::Library:
      break;
    case llvm::Triple::RootSignature:
      llvm_unreachable("rootsig environment has no functions");
    default:
      llvm_unreachable("Unhandled environment in triple");
    }
  }
}
772
773static bool isVkPipelineBuiltin(const ASTContext &AstContext, FunctionDecl *FD,
774 HLSLAppliedSemanticAttr *Semantic,
775 bool IsInput) {
776 if (AstContext.getTargetInfo().getTriple().getOS() != llvm::Triple::Vulkan)
777 return false;
778
779 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
780 assert(ShaderAttr && "Entry point has no shader attribute");
781 llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
782 auto SemanticName = Semantic->getSemanticName().upper();
783
784 // The SV_Position semantic is lowered to:
785 // - Position built-in for vertex output.
786 // - FragCoord built-in for fragment input.
787 if (SemanticName == "SV_POSITION") {
788 return (ST == llvm::Triple::Vertex && !IsInput) ||
789 (ST == llvm::Triple::Pixel && IsInput);
790 }
791
792 return false;
793}
794
// Resolves the active semantic for a scalar (non-record) declaration: attaches
// an HLSLAppliedSemanticAttr to OutputDecl, assigns the location index range
// this declaration consumes, and records the used semantic names in SC to
// detect overlaps. Returns false (after diagnosing) on any error.
bool SemaHLSL::determineActiveSemanticOnScalar(FunctionDecl *FD,
                                               DeclaratorDecl *OutputDecl,
                                               DeclaratorDecl *D,
                                               SemanticInfo &ActiveSemantic,
                                               SemaHLSL::SemanticContext &SC) {
  // If no semantic was inherited from an enclosing declaration, take the one
  // parsed directly on D.
  if (ActiveSemantic.Semantic == nullptr) {
    ActiveSemantic.Semantic = D->getAttr<HLSLParsedSemanticAttr>();
    if (ActiveSemantic.Semantic)
      ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
  }

  // A scalar entry-point parameter/field with no semantic anywhere in its
  // chain is an error.
  if (!ActiveSemantic.Semantic) {
    Diag(D->getLocation(), diag::err_hlsl_missing_semantic_annotation);
    return false;
  }

  auto *A = ::new (getASTContext())
      HLSLAppliedSemanticAttr(getASTContext(), *ActiveSemantic.Semantic,
                              ActiveSemantic.Semantic->getAttrName()->getName(),
                              ActiveSemantic.Index.value_or(0));
  if (!A)
    return false;

  checkSemanticAnnotation(FD, D, A, SC);
  OutputDecl->addAttr(A);

  unsigned Location = ActiveSemantic.Index.value_or(0);

  // Vulkan pipeline built-ins (e.g. SV_Position) don't consume a Location,
  // so vk::location handling only applies to everything else.
  if (!isVkPipelineBuiltin(getASTContext(), FD, A,
                           SC.CurrentIOType & IOType::In)) {
    bool HasVkLocation = false;
    if (auto *A = D->getAttr<HLSLVkLocationAttr>()) {
      HasVkLocation = true;
      Location = A->getLocation();
    }

    // Either every element uses explicit [[vk::location]] or none does;
    // UsesExplicitVkLocations latches the mode on first use.
    if (SC.UsesExplicitVkLocations.value_or(HasVkLocation) != HasVkLocation) {
      Diag(D->getLocation(), diag::err_hlsl_semantic_partial_explicit_indexing);
      return false;
    }
    SC.UsesExplicitVkLocations = HasVkLocation;
  }

  // An array consumes one semantic index per element.
  const ConstantArrayType *AT = dyn_cast<ConstantArrayType>(D->getType());
  unsigned ElementCount = AT ? AT->getZExtSize() : 1;
  ActiveSemantic.Index = Location + ElementCount;

  // Register each consumed "<name><index>" pair; a repeat means two
  // declarations claim the same semantic slot.
  Twine BaseName = Twine(ActiveSemantic.Semantic->getAttrName()->getName());
  for (unsigned I = 0; I < ElementCount; ++I) {
    Twine VariableName = BaseName.concat(Twine(Location + I));

    auto [_, Inserted] = SC.ActiveSemantics.insert(VariableName.str());
    if (!Inserted) {
      Diag(D->getLocation(), diag::err_hlsl_semantic_index_overlap)
          << VariableName.str();
      return false;
    }
  }

  return true;
}
856
bool SemaHLSL::determineActiveSemantic(FunctionDecl *FD,
                                       DeclaratorDecl *OutputDecl,
                                       DeclaratorDecl *D,
                                       SemanticInfo &ActiveSemantic,
                                       SemaHLSL::SemanticContext &SC) {
  // Recursively resolves semantics for the declaration D (a parameter, a
  // record field, or the function itself for return-value semantics),
  // delegating scalars to determineActiveSemanticOnScalar and walking the
  // fields of record types. Returns false if any nested check fails.
  //
  // Establish the active semantic from this declaration's parsed semantic
  // attribute when nothing was inherited from an outer declaration.
  if (ActiveSemantic.Semantic == nullptr) {
    ActiveSemantic.Semantic = D->getAttr<HLSLParsedSemanticAttr>();
    if (ActiveSemantic.Semantic)
      ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
  }

  // When D is the function itself, the semantic applies to the return value.
  const Type *T = D == FD ? &*FD->getReturnType() : &*D->getType();
  T = T->getUnqualifiedDesugaredType();

  // Non-record types terminate the recursion.
  const RecordType *RT = dyn_cast<RecordType>(Val: T);
  if (!RT)
    return determineActiveSemanticOnScalar(FD, OutputDecl, D, ActiveSemantic,
                                           SC);

  // Recurse into each field of the record. Every field starts from a copy of
  // the current state; the advanced index is only carried forward to the
  // next field while a semantic inherited from the enclosing declaration is
  // being distributed across the fields.
  const RecordDecl *RD = RT->getDecl();
  for (FieldDecl *Field : RD->fields()) {
    SemanticInfo Info = ActiveSemantic;
    if (!determineActiveSemantic(FD, OutputDecl, D: Field, ActiveSemantic&: Info, SC)) {
      Diag(Loc: Field->getLocation(), DiagID: diag::note_hlsl_semantic_used_here) << Field;
      return false;
    }
    if (ActiveSemantic.Semantic)
      ActiveSemantic = Info;
  }

  return true;
}
889
void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
  // Validates a shader entry point: stage-specific attribute rules
  // (numthreads, wavesize) and semantic annotations on parameters and the
  // return value. Problems are diagnosed and mark FD invalid rather than
  // returning early, so multiple issues surface in one pass.
  const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
  assert(ShaderAttr && "Entry point has no shader attribute");
  llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
  auto &TargetInfo = getASTContext().getTargetInfo();
  // The triple's OS version encodes the shader model version.
  VersionTuple Ver = TargetInfo.getTriple().getOSVersion();
  switch (ST) {
  // Stages below may not carry numthreads or wavesize; those attributes are
  // only meaningful for compute-like stages.
  case llvm::Triple::Pixel:
  case llvm::Triple::Vertex:
  case llvm::Triple::Geometry:
  case llvm::Triple::Hull:
  case llvm::Triple::Domain:
  case llvm::Triple::RayGeneration:
  case llvm::Triple::Intersection:
  case llvm::Triple::AnyHit:
  case llvm::Triple::ClosestHit:
  case llvm::Triple::Miss:
  case llvm::Triple::Callable:
    if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
      diagnoseAttrStageMismatch(A: NT, Stage: ST,
                                AllowedStages: {llvm::Triple::Compute,
                                 llvm::Triple::Amplification,
                                 llvm::Triple::Mesh});
      FD->setInvalidDecl();
    }
    if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
      diagnoseAttrStageMismatch(A: WS, Stage: ST,
                                AllowedStages: {llvm::Triple::Compute,
                                 llvm::Triple::Amplification,
                                 llvm::Triple::Mesh});
      FD->setInvalidDecl();
    }
    break;

  // Compute-like stages require numthreads, and restrict wavesize by shader
  // model version: the attribute itself needs SM 6.6, and its multi-argument
  // form needs SM 6.8.
  case llvm::Triple::Compute:
  case llvm::Triple::Amplification:
  case llvm::Triple::Mesh:
    if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
      Diag(Loc: FD->getLocation(), DiagID: diag::err_hlsl_missing_numthreads)
          << llvm::Triple::getEnvironmentTypeName(Kind: ST);
      FD->setInvalidDecl();
    }
    if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
      if (Ver < VersionTuple(6, 6)) {
        Diag(Loc: WS->getLocation(), DiagID: diag::err_hlsl_attribute_in_wrong_shader_model)
            << WS << "6.6";
        FD->setInvalidDecl();
      } else if (WS->getSpelledArgsCount() > 1 && Ver < VersionTuple(6, 8)) {
        Diag(
            Loc: WS->getLocation(),
            DiagID: diag::err_hlsl_attribute_number_arguments_insufficient_shader_model)
            << WS << WS->getSpelledArgsCount() << "6.8";
        FD->setInvalidDecl();
      }
    }
    break;
  case llvm::Triple::RootSignature:
    llvm_unreachable("rootsig environment has no function entry point");
  default:
    llvm_unreachable("Unhandled environment in triple");
  }

  // Resolve input semantics: each parameter starts its own semantic chain,
  // seeded from the parameter's own parsed semantic if present.
  SemaHLSL::SemanticContext InputSC = {};
  InputSC.CurrentIOType = IOType::In;

  for (ParmVarDecl *Param : FD->parameters()) {
    SemanticInfo ActiveSemantic;
    ActiveSemantic.Semantic = Param->getAttr<HLSLParsedSemanticAttr>();
    if (ActiveSemantic.Semantic)
      ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();

    // FIXME: Verify output semantics in parameters.
    if (!determineActiveSemantic(FD, OutputDecl: Param, D: Param, ActiveSemantic, SC&: InputSC)) {
      Diag(Loc: Param->getLocation(), DiagID: diag::note_previous_decl) << Param;
      FD->setInvalidDecl();
    }
  }

  // Resolve output semantics for a non-void return value, seeded from a
  // semantic attached to the function declaration itself.
  SemanticInfo ActiveSemantic;
  SemaHLSL::SemanticContext OutputSC = {};
  OutputSC.CurrentIOType = IOType::Out;
  ActiveSemantic.Semantic = FD->getAttr<HLSLParsedSemanticAttr>();
  if (ActiveSemantic.Semantic)
    ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
  if (!FD->getReturnType()->isVoidType())
    determineActiveSemantic(FD, OutputDecl: FD, D: FD, ActiveSemantic, SC&: OutputSC);
}
977
void SemaHLSL::checkSemanticAnnotation(
    FunctionDecl *EntryPoint, const Decl *Param,
    const HLSLAppliedSemanticAttr *SemanticAttr, const SemanticContext &SC) {
  // Checks a system-value semantic against the current shader stage and
  // input/output direction, diagnosing unsupported combinations. Semantic
  // name matching is case-insensitive (HLSL semantics are).
  auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
  assert(ShaderAttr && "Entry point has no shader attribute");
  llvm::Triple::EnvironmentType ST = ShaderAttr->getType();

  auto SemanticName = SemanticAttr->getSemanticName().upper();
  // Compute-stage thread/group identifiers: input-only, compute-only, and
  // indexing (e.g. SV_GroupID1) is not meaningful for them.
  if (SemanticName == "SV_DISPATCHTHREADID" ||
      SemanticName == "SV_GROUPINDEX" || SemanticName == "SV_GROUPTHREADID" ||
      SemanticName == "SV_GROUPID") {

    if (ST != llvm::Triple::Compute)
      diagnoseSemanticStageMismatch(A: SemanticAttr, Stage: ST, CurrentIOType: SC.CurrentIOType,
                                    AllowedStages: {{.Stage: llvm::Triple::Compute, .AllowedIOTypesMask: IOType::In}});

    if (SemanticAttr->getSemanticIndex() != 0) {
      std::string PrettyName =
          "'" + SemanticAttr->getSemanticName().str() + "'";
      Diag(Loc: SemanticAttr->getLoc(),
           DiagID: diag::err_hlsl_semantic_indexing_not_supported)
          << PrettyName;
    }
    return;
  }

  if (SemanticName == "SV_POSITION") {
    // SV_Position can be an input or output in vertex shaders,
    // but only an input in pixel shaders.
    diagnoseSemanticStageMismatch(A: SemanticAttr, Stage: ST, CurrentIOType: SC.CurrentIOType,
                                  AllowedStages: {{.Stage: llvm::Triple::Vertex, .AllowedIOTypesMask: IOType::InOut},
                                   {.Stage: llvm::Triple::Pixel, .AllowedIOTypesMask: IOType::In}});
    return;
  }

  // SV_Target is a pixel-shader output only.
  if (SemanticName == "SV_TARGET") {
    diagnoseSemanticStageMismatch(A: SemanticAttr, Stage: ST, CurrentIOType: SC.CurrentIOType,
                                  AllowedStages: {{.Stage: llvm::Triple::Pixel, .AllowedIOTypesMask: IOType::Out}});
    return;
  }

  // FIXME: catch-all for non-implemented system semantics reaching this
  // location.
  if (SemanticAttr->getAttrName()->getName().starts_with_insensitive(Prefix: "SV_"))
    llvm_unreachable("Unknown SemanticAttr");
}
1024
1025void SemaHLSL::diagnoseAttrStageMismatch(
1026 const Attr *A, llvm::Triple::EnvironmentType Stage,
1027 std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) {
1028 SmallVector<StringRef, 8> StageStrings;
1029 llvm::transform(Range&: AllowedStages, d_first: std::back_inserter(x&: StageStrings),
1030 F: [](llvm::Triple::EnvironmentType ST) {
1031 return StringRef(
1032 HLSLShaderAttr::ConvertEnvironmentTypeToStr(Val: ST));
1033 });
1034 Diag(Loc: A->getLoc(), DiagID: diag::err_hlsl_attr_unsupported_in_stage)
1035 << A->getAttrName() << llvm::Triple::getEnvironmentTypeName(Kind: Stage)
1036 << (AllowedStages.size() != 1) << join(R&: StageStrings, Separator: ", ");
1037}
1038
void SemaHLSL::diagnoseSemanticStageMismatch(
    const Attr *A, llvm::Triple::EnvironmentType Stage, IOType CurrentIOType,
    std::initializer_list<SemanticStageInfo> Allowed) {
  // Diagnoses a semantic used in an unsupported stage or direction. Two
  // cases: the stage is allowed but the input/output direction is not
  // (first loop), or the stage itself is not in the allowed set (fallthrough
  // after the loop).

  for (auto &Case : Allowed) {
    if (Case.Stage != Stage)
      continue;

    // Stage matches and the current direction is permitted: no diagnostic.
    if (CurrentIOType & Case.AllowedIOTypesMask)
      return;

    // Stage matches but direction does not: list each allowed stage with the
    // directions it supports, e.g. "vertex input/output".
    SmallVector<std::string, 8> ValidCases;
    llvm::transform(
        Range&: Allowed, d_first: std::back_inserter(x&: ValidCases), F: [](SemanticStageInfo Case) {
          SmallVector<std::string, 2> ValidType;
          if (Case.AllowedIOTypesMask & IOType::In)
            ValidType.push_back(Elt: "input");
          if (Case.AllowedIOTypesMask & IOType::Out)
            ValidType.push_back(Elt: "output");
          return std::string(
                     HLSLShaderAttr::ConvertEnvironmentTypeToStr(Val: Case.Stage)) +
                 " " + join(R&: ValidType, Separator: "/");
        });
    Diag(Loc: A->getLoc(), DiagID: diag::err_hlsl_semantic_unsupported_iotype_for_stage)
        << A->getAttrName() << (CurrentIOType & IOType::In ? "input" : "output")
        << llvm::Triple::getEnvironmentTypeName(Kind: Case.Stage)
        << join(R&: ValidCases, Separator: ", ");
    return;
  }

  // The current stage is not in the allowed set at all: report the plain
  // stage mismatch with the list of allowed stage names.
  SmallVector<StringRef, 8> StageStrings;
  llvm::transform(
      Range&: Allowed, d_first: std::back_inserter(x&: StageStrings), F: [](SemanticStageInfo Case) {
        return StringRef(
            HLSLShaderAttr::ConvertEnvironmentTypeToStr(Val: Case.Stage));
      });

  Diag(Loc: A->getLoc(), DiagID: diag::err_hlsl_attr_unsupported_in_stage)
      << A->getAttrName() << llvm::Triple::getEnvironmentTypeName(Kind: Stage)
      << (Allowed.size() != 1) << join(R&: StageStrings, Separator: ", ");
}
1080
1081template <CastKind Kind>
1082static void castVector(Sema &S, ExprResult &E, QualType &Ty, unsigned Sz) {
1083 if (const auto *VTy = Ty->getAs<VectorType>())
1084 Ty = VTy->getElementType();
1085 Ty = S.getASTContext().getExtVectorType(VectorType: Ty, NumElts: Sz);
1086 E = S.ImpCastExprToType(E: E.get(), Type: Ty, CK: Kind);
1087}
1088
1089template <CastKind Kind>
1090static QualType castElement(Sema &S, ExprResult &E, QualType Ty) {
1091 E = S.ImpCastExprToType(E: E.get(), Type: Ty, CK: Kind);
1092 return Ty;
1093}
1094
1095static QualType handleFloatVectorBinOpConversion(
1096 Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
1097 QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
1098 bool LHSFloat = LElTy->isRealFloatingType();
1099 bool RHSFloat = RElTy->isRealFloatingType();
1100
1101 if (LHSFloat && RHSFloat) {
1102 if (IsCompAssign ||
1103 SemaRef.getASTContext().getFloatingTypeOrder(LHS: LElTy, RHS: RElTy) > 0)
1104 return castElement<CK_FloatingCast>(S&: SemaRef, E&: RHS, Ty: LHSType);
1105
1106 return castElement<CK_FloatingCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
1107 }
1108
1109 if (LHSFloat)
1110 return castElement<CK_IntegralToFloating>(S&: SemaRef, E&: RHS, Ty: LHSType);
1111
1112 assert(RHSFloat);
1113 if (IsCompAssign)
1114 return castElement<clang::CK_FloatingToIntegral>(S&: SemaRef, E&: RHS, Ty: LHSType);
1115
1116 return castElement<CK_IntegralToFloating>(S&: SemaRef, E&: LHS, Ty: RHSType);
1117}
1118
static QualType handleIntegerVectorBinOpConversion(
    Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
    QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
  // Applies usual-arithmetic-conversion rules to two integer vector operands
  // with differing element types, casting one (or both) operands and
  // returning the common result type. Compound assignment always keeps the
  // left-hand type.

  int IntOrder = SemaRef.Context.getIntegerTypeOrder(LHS: LElTy, RHS: RElTy);
  bool LHSSigned = LElTy->hasSignedIntegerRepresentation();
  bool RHSSigned = RElTy->hasSignedIntegerRepresentation();
  auto &Ctx = SemaRef.getASTContext();

  // If both types have the same signedness, use the higher ranked type.
  if (LHSSigned == RHSSigned) {
    if (IsCompAssign || IntOrder >= 0)
      return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);

    return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
  }

  // If the unsigned type has greater than or equal rank of the signed type, use
  // the unsigned type.
  if (IntOrder != (LHSSigned ? 1 : -1)) {
    if (IsCompAssign || RHSSigned)
      return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);
    return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
  }

  // At this point the signed type has higher rank than the unsigned type, which
  // means it will be the same size or bigger. If the signed type is bigger, it
  // can represent all the values of the unsigned type, so select it.
  if (Ctx.getIntWidth(T: LElTy) != Ctx.getIntWidth(T: RElTy)) {
    if (IsCompAssign || LHSSigned)
      return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);
    return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
  }

  // This is a bit of an odd duck case in HLSL. It shouldn't happen, but can due
  // to C/C++ leaking through. The place this happens today is long vs long
  // long. When arguments are vector<unsigned long, N> and vector<long long, N>,
  // the long long has higher rank than long even though they are the same size.

  // If this is a compound assignment cast the right hand side to the left hand
  // side's type.
  if (IsCompAssign)
    return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);

  // If this isn't a compound assignment we convert to unsigned long long.
  QualType ElTy = Ctx.getCorrespondingUnsignedType(T: LHSSigned ? LElTy : RElTy);
  QualType NewTy = Ctx.getExtVectorType(
      VectorType: ElTy, NumElts: RHSType->castAs<VectorType>()->getNumElements());
  // Both operands are converted to the common unsigned vector type; the RHS
  // result is intentionally discarded since castElement updates RHS in place.
  (void)castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: NewTy);

  return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: NewTy);
}
1171
1172static CastKind getScalarCastKind(ASTContext &Ctx, QualType DestTy,
1173 QualType SrcTy) {
1174 if (DestTy->isRealFloatingType() && SrcTy->isRealFloatingType())
1175 return CK_FloatingCast;
1176 if (DestTy->isIntegralType(Ctx) && SrcTy->isIntegralType(Ctx))
1177 return CK_IntegralCast;
1178 if (DestTy->isRealFloatingType())
1179 return CK_IntegralToFloating;
1180 assert(SrcTy->isRealFloatingType() && DestTy->isIntegralType(Ctx));
1181 return CK_FloatingToIntegral;
1182}
1183
QualType SemaHLSL::handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS,
                                               QualType LHSType,
                                               QualType RHSType,
                                               bool IsCompAssign) {
  // Computes the common type for a binary operation where at least one
  // operand is an HLSL vector, inserting truncation/splat/element casts on
  // the operands as needed. Returns a null QualType on error (after
  // diagnosing). For compound assignment the LHS type is never changed.
  const auto *LVecTy = LHSType->getAs<VectorType>();
  const auto *RVecTy = RHSType->getAs<VectorType>();
  auto &Ctx = getASTContext();

  // If the LHS is not a vector and this is a compound assignment, we truncate
  // the argument to a scalar then convert it to the LHS's type.
  if (!LVecTy && IsCompAssign) {
    QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
    RHS = SemaRef.ImpCastExprToType(E: RHS.get(), Type: RElTy, CK: CK_HLSLVectorTruncation);
    RHSType = RHS.get()->getType();
    if (Ctx.hasSameUnqualifiedType(T1: LHSType, T2: RHSType))
      return LHSType;
    RHS = SemaRef.ImpCastExprToType(E: RHS.get(), Type: LHSType,
                                    CK: getScalarCastKind(Ctx, DestTy: LHSType, SrcTy: RHSType));
    return LHSType;
  }

  // EndSz is the common result width: the minimum of the operands' vector
  // sizes (at least one operand is a vector here).
  unsigned EndSz = std::numeric_limits<unsigned>::max();
  unsigned LSz = 0;
  if (LVecTy)
    LSz = EndSz = LVecTy->getNumElements();
  if (RVecTy)
    EndSz = std::min(a: RVecTy->getNumElements(), b: EndSz);
  assert(EndSz != std::numeric_limits<unsigned>::max() &&
         "one of the above should have had a value");

  // In a compound assignment, the left operand does not change type, the right
  // operand is converted to the type of the left operand.
  if (IsCompAssign && LSz != EndSz) {
    Diag(Loc: LHS.get()->getBeginLoc(),
         DiagID: diag::err_hlsl_vector_compound_assignment_truncation)
        << LHSType << RHSType;
    return QualType();
  }

  // First truncate any over-wide vector operand down to the common width...
  if (RVecTy && RVecTy->getNumElements() > EndSz)
    castVector<CK_HLSLVectorTruncation>(S&: SemaRef, E&: RHS, Ty&: RHSType, Sz: EndSz);
  if (!IsCompAssign && LVecTy && LVecTy->getNumElements() > EndSz)
    castVector<CK_HLSLVectorTruncation>(S&: SemaRef, E&: LHS, Ty&: LHSType, Sz: EndSz);

  // ...then splat any scalar operand up to it.
  if (!RVecTy)
    castVector<CK_VectorSplat>(S&: SemaRef, E&: RHS, Ty&: RHSType, Sz: EndSz);
  if (!IsCompAssign && !LVecTy)
    castVector<CK_VectorSplat>(S&: SemaRef, E&: LHS, Ty&: LHSType, Sz: EndSz);

  // If we're at the same type after resizing we can stop here.
  if (Ctx.hasSameUnqualifiedType(T1: LHSType, T2: RHSType))
    return Ctx.getCommonSugaredType(X: LHSType, Y: RHSType);

  QualType LElTy = LHSType->castAs<VectorType>()->getElementType();
  QualType RElTy = RHSType->castAs<VectorType>()->getElementType();

  // Handle conversion for floating point vectors.
  if (LElTy->isRealFloatingType() || RElTy->isRealFloatingType())
    return handleFloatVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
                                            LElTy, RElTy, IsCompAssign);

  assert(LElTy->isIntegralType(Ctx) && RElTy->isIntegralType(Ctx) &&
         "HLSL Vectors can only contain integer or floating point types");
  return handleIntegerVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
                                            LElTy, RElTy, IsCompAssign);
}
1250
1251void SemaHLSL::emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS,
1252 BinaryOperatorKind Opc) {
1253 assert((Opc == BO_LOr || Opc == BO_LAnd) &&
1254 "Called with non-logical operator");
1255 llvm::SmallVector<char, 256> Buff;
1256 llvm::raw_svector_ostream OS(Buff);
1257 PrintingPolicy PP(SemaRef.getLangOpts());
1258 StringRef NewFnName = Opc == BO_LOr ? "or" : "and";
1259 OS << NewFnName << "(";
1260 LHS->printPretty(OS, Helper: nullptr, Policy: PP);
1261 OS << ", ";
1262 RHS->printPretty(OS, Helper: nullptr, Policy: PP);
1263 OS << ")";
1264 SourceRange FullRange = SourceRange(LHS->getBeginLoc(), RHS->getEndLoc());
1265 SemaRef.Diag(Loc: LHS->getBeginLoc(), DiagID: diag::note_function_suggestion)
1266 << NewFnName << FixItHint::CreateReplacement(RemoveRange: FullRange, Code: OS.str());
1267}
1268
1269std::pair<IdentifierInfo *, bool>
1270SemaHLSL::ActOnStartRootSignatureDecl(StringRef Signature) {
1271 llvm::hash_code Hash = llvm::hash_value(S: Signature);
1272 std::string IdStr = "__hlsl_rootsig_decl_" + std::to_string(val: Hash);
1273 IdentifierInfo *DeclIdent = &(getASTContext().Idents.get(Name: IdStr));
1274
1275 // Check if we have already found a decl of the same name.
1276 LookupResult R(SemaRef, DeclIdent, SourceLocation(),
1277 Sema::LookupOrdinaryName);
1278 bool Found = SemaRef.LookupQualifiedName(R, LookupCtx: SemaRef.CurContext);
1279 return {DeclIdent, Found};
1280}
1281
1282void SemaHLSL::ActOnFinishRootSignatureDecl(
1283 SourceLocation Loc, IdentifierInfo *DeclIdent,
1284 ArrayRef<hlsl::RootSignatureElement> RootElements) {
1285
1286 if (handleRootSignatureElements(Elements: RootElements))
1287 return;
1288
1289 SmallVector<llvm::hlsl::rootsig::RootElement> Elements;
1290 for (auto &RootSigElement : RootElements)
1291 Elements.push_back(Elt: RootSigElement.getElement());
1292
1293 auto *SignatureDecl = HLSLRootSignatureDecl::Create(
1294 C&: SemaRef.getASTContext(), /*DeclContext=*/DC: SemaRef.CurContext, Loc,
1295 ID: DeclIdent, Version: SemaRef.getLangOpts().HLSLRootSigVer, RootElements: Elements);
1296
1297 SignatureDecl->setImplicit();
1298 SemaRef.PushOnScopeChains(D: SignatureDecl, S: SemaRef.getCurScope());
1299}
1300
1301HLSLRootSignatureDecl *
1302SemaHLSL::lookupRootSignatureOverrideDecl(DeclContext *DC) const {
1303 if (RootSigOverrideIdent) {
1304 LookupResult R(SemaRef, RootSigOverrideIdent, SourceLocation(),
1305 Sema::LookupOrdinaryName);
1306 if (SemaRef.LookupQualifiedName(R, LookupCtx: DC))
1307 return dyn_cast<HLSLRootSignatureDecl>(Val: R.getFoundDecl());
1308 }
1309
1310 return nullptr;
1311}
1312
1313namespace {
1314
// Tracks root-signature resource bindings per shader visibility and reports
// register-range overlaps. Bindings declared with "All" visibility are
// mirrored into every per-stage builder so they conflict with stage-specific
// bindings as well.
struct PerVisibilityBindingChecker {
  SemaHLSL *S;
  // We need one builder per `llvm::dxbc::ShaderVisibility` value.
  std::array<llvm::hlsl::BindingInfoBuilder, 8> Builders;

  // Per-element bookkeeping: the element's visibility, and whether an
  // overlap involving it has already been diagnosed (to avoid duplicate
  // diagnostics when "All" visibility mirrors it into several builders).
  struct ElemInfo {
    const hlsl::RootSignatureElement *Elem;
    llvm::dxbc::ShaderVisibility Vis;
    bool Diagnosed;
  };
  llvm::SmallVector<ElemInfo> ElemInfoMap;

  PerVisibilityBindingChecker(SemaHLSL *S) : S(S) {}

  // Records one register range [LowerBound, UpperBound] for the given
  // resource class and space under the builder matching Visibility; "All"
  // visibility is additionally recorded under every other visibility.
  void trackBinding(llvm::dxbc::ShaderVisibility Visibility,
                    llvm::dxil::ResourceClass RC, uint32_t Space,
                    uint32_t LowerBound, uint32_t UpperBound,
                    const hlsl::RootSignatureElement *Elem) {
    uint32_t BuilderIndex = llvm::to_underlying(E: Visibility);
    assert(BuilderIndex < Builders.size() &&
           "Not enough builders for visibility type");
    Builders[BuilderIndex].trackBinding(RC, Space, LowerBound, UpperBound,
                                        Cookie: static_cast<const void *>(Elem));

    static_assert(llvm::to_underlying(E: llvm::dxbc::ShaderVisibility::All) == 0,
                  "'All' visibility must come first");
    if (Visibility == llvm::dxbc::ShaderVisibility::All)
      for (size_t I = 1, E = Builders.size(); I < E; ++I)
        Builders[I].trackBinding(RC, Space, LowerBound, UpperBound,
                                 Cookie: static_cast<const void *>(Elem));

    ElemInfoMap.push_back(Elt: {.Elem: Elem, .Vis: Visibility, .Diagnosed: false});
  }

  // Looks up the bookkeeping record for Elem by binary search; requires
  // ElemInfoMap to be sorted by pointer (done at the top of checkOverlap).
  ElemInfo &getInfo(const hlsl::RootSignatureElement *Elem) {
    auto It = llvm::lower_bound(
        Range&: ElemInfoMap, Value&: Elem,
        C: [](const auto &LHS, const auto &RHS) { return LHS.Elem < RHS; });
    assert(It->Elem == Elem && "Element not in map");
    return *It;
  }

  // Runs overlap analysis across all visibility builders, diagnosing each
  // overlapping pair once. Returns true if any overlap was found.
  bool checkOverlap() {
    // Sort so getInfo can use lower_bound during overlap reporting.
    llvm::sort(C&: ElemInfoMap, Comp: [](const auto &LHS, const auto &RHS) {
      return LHS.Elem < RHS.Elem;
    });

    bool HadOverlap = false;

    using llvm::hlsl::BindingInfoBuilder;
    auto ReportOverlap = [this,
                          &HadOverlap](const BindingInfoBuilder &Builder,
                                       const llvm::hlsl::Binding &Reported) {
      HadOverlap = true;

      const auto *Elem =
          static_cast<const hlsl::RootSignatureElement *>(Reported.Cookie);
      const llvm::hlsl::Binding &Previous = Builder.findOverlapping(ReportedBinding: Reported);
      const auto *PrevElem =
          static_cast<const hlsl::RootSignatureElement *>(Previous.Cookie);

      ElemInfo &Info = getInfo(Elem);
      // We will have already diagnosed this binding if there's overlap in the
      // "All" visibility as well as any particular visibility.
      if (Info.Diagnosed)
        return;
      Info.Diagnosed = true;

      // Report the most specific visibility the two bindings share.
      ElemInfo &PrevInfo = getInfo(Elem: PrevElem);
      llvm::dxbc::ShaderVisibility CommonVis =
          Info.Vis == llvm::dxbc::ShaderVisibility::All ? PrevInfo.Vis
                                                        : Info.Vis;

      this->S->Diag(Loc: Elem->getLocation(), DiagID: diag::err_hlsl_resource_range_overlap)
          << llvm::to_underlying(E: Reported.RC) << Reported.LowerBound
          << Reported.isUnbounded() << Reported.UpperBound
          << llvm::to_underlying(E: Previous.RC) << Previous.LowerBound
          << Previous.isUnbounded() << Previous.UpperBound << Reported.Space
          << CommonVis;

      this->S->Diag(Loc: PrevElem->getLocation(),
                    DiagID: diag::note_hlsl_resource_range_here);
    };

    for (BindingInfoBuilder &Builder : Builders)
      Builder.calculateBindingInfo(ReportOverlap);

    return HadOverlap;
  }
};
1405
1406static CXXMethodDecl *lookupMethod(Sema &S, CXXRecordDecl *RecordDecl,
1407 StringRef Name, SourceLocation Loc) {
1408 DeclarationName DeclName(&S.getASTContext().Idents.get(Name));
1409 LookupResult Result(S, DeclName, Loc, Sema::LookupMemberName);
1410 if (!S.LookupQualifiedName(R&: Result, LookupCtx: static_cast<DeclContext *>(RecordDecl)))
1411 return nullptr;
1412 return cast<CXXMethodDecl>(Val: Result.getFoundDecl());
1413}
1414
1415} // end anonymous namespace
1416
1417static bool hasCounterHandle(const CXXRecordDecl *RD) {
1418 if (RD->field_empty())
1419 return false;
1420 auto It = std::next(x: RD->field_begin());
1421 if (It == RD->field_end())
1422 return false;
1423 const FieldDecl *SecondField = *It;
1424 if (const auto *ResTy =
1425 SecondField->getType()->getAs<HLSLAttributedResourceType>()) {
1426 return ResTy->getAttrs().IsCounter;
1427 }
1428 return false;
1429}
1430
bool SemaHLSL::handleRootSignatureElements(
    ArrayRef<hlsl::RootSignatureElement> Elements) {
  // Validates a root signature's elements in two passes: (1) basic value
  // checks on registers, spaces, flags and sampler parameters; (2) register
  // range overlap analysis per shader visibility. Returns true to indicate
  // the caller should not create the root signature declaration.
  //
  // NOTE(review): HadError is set by the value-range diagnostics below, but
  // the final return reflects only checkOverlap() (plus the NumDescriptors
  // == 0 shortcut), so value errors alone do not suppress decl creation —
  // confirm this is intentional.
  //
  // Define some common error handling functions
  bool HadError = false;
  auto ReportError = [this, &HadError](SourceLocation Loc, uint32_t LowerBound,
                                       uint32_t UpperBound) {
    HadError = true;
    this->Diag(Loc, DiagID: diag::err_hlsl_invalid_rootsig_value)
        << LowerBound << UpperBound;
  };

  auto ReportFloatError = [this, &HadError](SourceLocation Loc,
                                            float LowerBound,
                                            float UpperBound) {
    HadError = true;
    this->Diag(Loc, DiagID: diag::err_hlsl_invalid_rootsig_value)
        << llvm::formatv(Fmt: "{0:f}", Vals&: LowerBound).sstr<6>()
        << llvm::formatv(Fmt: "{0:f}", Vals&: UpperBound).sstr<6>();
  };

  auto VerifyRegister = [ReportError](SourceLocation Loc, uint32_t Register) {
    if (!llvm::hlsl::rootsig::verifyRegisterValue(RegisterValue: Register))
      ReportError(Loc, 0, 0xfffffffe);
  };

  auto VerifySpace = [ReportError](SourceLocation Loc, uint32_t Space) {
    if (!llvm::hlsl::rootsig::verifyRegisterSpace(RegisterSpace: Space))
      ReportError(Loc, 0, 0xffffffef);
  };

  // Flag validity depends on the root signature version; VersionEnum is the
  // zero-based value used by the diagnostic text.
  const uint32_t Version =
      llvm::to_underlying(E: SemaRef.getLangOpts().HLSLRootSigVer);
  const uint32_t VersionEnum = Version - 1;
  auto ReportFlagError = [this, &HadError, VersionEnum](SourceLocation Loc) {
    HadError = true;
    this->Diag(Loc, DiagID: diag::err_hlsl_invalid_rootsig_flag)
        << /*version minor*/ VersionEnum;
  };

  // Iterate through the elements and do basic validations
  for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
    SourceLocation Loc = RootSigElem.getLocation();
    const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
    if (const auto *Descriptor =
            std::get_if<llvm::hlsl::rootsig::RootDescriptor>(ptr: &Elem)) {
      VerifyRegister(Loc, Descriptor->Reg.Number);
      VerifySpace(Loc, Descriptor->Space);

      if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(Version,
                                                         Flags: Descriptor->Flags))
        ReportFlagError(Loc);
    } else if (const auto *Constants =
                   std::get_if<llvm::hlsl::rootsig::RootConstants>(ptr: &Elem)) {
      VerifyRegister(Loc, Constants->Reg.Number);
      VerifySpace(Loc, Constants->Space);
    } else if (const auto *Sampler =
                   std::get_if<llvm::hlsl::rootsig::StaticSampler>(ptr: &Elem)) {
      VerifyRegister(Loc, Sampler->Reg.Number);
      VerifySpace(Loc, Sampler->Space);

      assert(!std::isnan(Sampler->MaxLOD) && !std::isnan(Sampler->MinLOD) &&
             "By construction, parseFloatParam can't produce a NaN from a "
             "float_literal token");

      if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(MaxAnisotropy: Sampler->MaxAnisotropy))
        ReportError(Loc, 0, 16);
      if (!llvm::hlsl::rootsig::verifyMipLODBias(MipLODBias: Sampler->MipLODBias))
        ReportFloatError(Loc, -16.f, 15.99f);
    } else if (const auto *Clause =
                   std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
                       ptr: &Elem)) {
      VerifyRegister(Loc, Clause->Reg.Number);
      VerifySpace(Loc, Clause->Space);

      if (!llvm::hlsl::rootsig::verifyNumDescriptors(NumDescriptors: Clause->NumDescriptors)) {
        // NumDescriptor could techincally be ~0u but that is reserved for
        // unbounded, so the diagnostic will not report that as a valid int
        // value
        ReportError(Loc, 1, 0xfffffffe);
      }

      if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(Version, Type: Clause->Type,
                                                          Flags: Clause->Flags))
        ReportFlagError(Loc);
    }
  }

  // Second pass: feed every bound register range into the per-visibility
  // overlap checker. Table clauses are buffered until their enclosing table
  // (which carries the visibility) is seen.
  PerVisibilityBindingChecker BindingChecker(this);
  SmallVector<std::pair<const llvm::hlsl::rootsig::DescriptorTableClause *,
                        const hlsl::RootSignatureElement *>>
      UnboundClauses;

  for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
    const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
    if (const auto *Descriptor =
            std::get_if<llvm::hlsl::rootsig::RootDescriptor>(ptr: &Elem)) {
      uint32_t LowerBound(Descriptor->Reg.Number);
      uint32_t UpperBound(LowerBound); // inclusive range

      BindingChecker.trackBinding(
          Visibility: Descriptor->Visibility,
          RC: static_cast<llvm::dxil::ResourceClass>(Descriptor->Type),
          Space: Descriptor->Space, LowerBound, UpperBound, Elem: &RootSigElem);
    } else if (const auto *Constants =
                   std::get_if<llvm::hlsl::rootsig::RootConstants>(ptr: &Elem)) {
      uint32_t LowerBound(Constants->Reg.Number);
      uint32_t UpperBound(LowerBound); // inclusive range

      BindingChecker.trackBinding(
          Visibility: Constants->Visibility, RC: llvm::dxil::ResourceClass::CBuffer,
          Space: Constants->Space, LowerBound, UpperBound, Elem: &RootSigElem);
    } else if (const auto *Sampler =
                   std::get_if<llvm::hlsl::rootsig::StaticSampler>(ptr: &Elem)) {
      uint32_t LowerBound(Sampler->Reg.Number);
      uint32_t UpperBound(LowerBound); // inclusive range

      BindingChecker.trackBinding(
          Visibility: Sampler->Visibility, RC: llvm::dxil::ResourceClass::Sampler,
          Space: Sampler->Space, LowerBound, UpperBound, Elem: &RootSigElem);
    } else if (const auto *Clause =
                   std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
                       ptr: &Elem)) {
      // We'll process these once we see the table element.
      UnboundClauses.emplace_back(Args&: Clause, Args: &RootSigElem);
    } else if (const auto *Table =
                   std::get_if<llvm::hlsl::rootsig::DescriptorTable>(ptr: &Elem)) {
      assert(UnboundClauses.size() == Table->NumClauses &&
             "Number of unbound elements must match the number of clauses");
      // Walk the table's clauses in order, tracking the running descriptor
      // offset. Sampler and non-sampler clauses may not be mixed in one
      // table, and no clause may append after an unbounded one.
      bool HasAnySampler = false;
      bool HasAnyNonSampler = false;
      uint64_t Offset = 0;
      bool IsPrevUnbound = false;
      for (const auto &[Clause, ClauseElem] : UnboundClauses) {
        SourceLocation Loc = ClauseElem->getLocation();
        if (Clause->Type == llvm::dxil::ResourceClass::Sampler)
          HasAnySampler = true;
        else
          HasAnyNonSampler = true;

        if (HasAnySampler && HasAnyNonSampler)
          Diag(Loc, DiagID: diag::err_hlsl_invalid_mixed_resources);

        // Relevant error will have already been reported above and needs to be
        // fixed before we can conduct further analysis, so shortcut error
        // return
        if (Clause->NumDescriptors == 0)
          return true;

        bool IsAppending =
            Clause->Offset == llvm::hlsl::rootsig::DescriptorTableOffsetAppend;
        if (!IsAppending)
          Offset = Clause->Offset;

        uint64_t RangeBound = llvm::hlsl::rootsig::computeRangeBound(
            Offset, Size: Clause->NumDescriptors);

        if (IsPrevUnbound && IsAppending)
          Diag(Loc, DiagID: diag::err_hlsl_appending_onto_unbound);
        else if (!llvm::hlsl::rootsig::verifyNoOverflowedOffset(Offset: RangeBound))
          Diag(Loc, DiagID: diag::err_hlsl_offset_overflow) << Offset << RangeBound;

        // Update offset to be 1 past this range's bound
        Offset = RangeBound + 1;
        IsPrevUnbound = Clause->NumDescriptors ==
                        llvm::hlsl::rootsig::NumDescriptorsUnbounded;

        // Compute the register bounds and track resource binding
        uint32_t LowerBound(Clause->Reg.Number);
        uint32_t UpperBound = llvm::hlsl::rootsig::computeRangeBound(
            Offset: LowerBound, Size: Clause->NumDescriptors);

        BindingChecker.trackBinding(
            Visibility: Table->Visibility,
            RC: static_cast<llvm::dxil::ResourceClass>(Clause->Type), Space: Clause->Space,
            LowerBound, UpperBound, Elem: ClauseElem);
      }
      UnboundClauses.clear();
    }
  }

  return BindingChecker.checkOverlap();
}
1613
1614void SemaHLSL::handleRootSignatureAttr(Decl *D, const ParsedAttr &AL) {
1615 if (AL.getNumArgs() != 1) {
1616 Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_wrong_number_arguments) << AL << 1;
1617 return;
1618 }
1619
1620 IdentifierInfo *Ident = AL.getArgAsIdent(Arg: 0)->getIdentifierInfo();
1621 if (auto *RS = D->getAttr<RootSignatureAttr>()) {
1622 if (RS->getSignatureIdent() != Ident) {
1623 Diag(Loc: AL.getLoc(), DiagID: diag::err_disallowed_duplicate_attribute) << RS;
1624 return;
1625 }
1626
1627 Diag(Loc: AL.getLoc(), DiagID: diag::warn_duplicate_attribute_exact) << RS;
1628 return;
1629 }
1630
1631 LookupResult R(SemaRef, Ident, SourceLocation(), Sema::LookupOrdinaryName);
1632 if (SemaRef.LookupQualifiedName(R, LookupCtx: D->getDeclContext()))
1633 if (auto *SignatureDecl =
1634 dyn_cast<HLSLRootSignatureDecl>(Val: R.getFoundDecl())) {
1635 D->addAttr(A: ::new (getASTContext()) RootSignatureAttr(
1636 getASTContext(), AL, Ident, SignatureDecl));
1637 }
1638}
1639
1640void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) {
1641 llvm::VersionTuple SMVersion =
1642 getASTContext().getTargetInfo().getTriple().getOSVersion();
1643 bool IsDXIL = getASTContext().getTargetInfo().getTriple().getArch() ==
1644 llvm::Triple::dxil;
1645
1646 uint32_t ZMax = 1024;
1647 uint32_t ThreadMax = 1024;
1648 if (IsDXIL && SMVersion.getMajor() <= 4) {
1649 ZMax = 1;
1650 ThreadMax = 768;
1651 } else if (IsDXIL && SMVersion.getMajor() == 5) {
1652 ZMax = 64;
1653 ThreadMax = 1024;
1654 }
1655
1656 uint32_t X;
1657 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: X))
1658 return;
1659 if (X > 1024) {
1660 Diag(Loc: AL.getArgAsExpr(Arg: 0)->getExprLoc(),
1661 DiagID: diag::err_hlsl_numthreads_argument_oor)
1662 << 0 << 1024;
1663 return;
1664 }
1665 uint32_t Y;
1666 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Y))
1667 return;
1668 if (Y > 1024) {
1669 Diag(Loc: AL.getArgAsExpr(Arg: 1)->getExprLoc(),
1670 DiagID: diag::err_hlsl_numthreads_argument_oor)
1671 << 1 << 1024;
1672 return;
1673 }
1674 uint32_t Z;
1675 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 2), Val&: Z))
1676 return;
1677 if (Z > ZMax) {
1678 SemaRef.Diag(Loc: AL.getArgAsExpr(Arg: 2)->getExprLoc(),
1679 DiagID: diag::err_hlsl_numthreads_argument_oor)
1680 << 2 << ZMax;
1681 return;
1682 }
1683
1684 if (X * Y * Z > ThreadMax) {
1685 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_numthreads_invalid) << ThreadMax;
1686 return;
1687 }
1688
1689 HLSLNumThreadsAttr *NewAttr = mergeNumThreadsAttr(D, AL, X, Y, Z);
1690 if (NewAttr)
1691 D->addAttr(A: NewAttr);
1692}
1693
// Returns true when Value is a legal HLSL wave size: a power of two in the
// inclusive range [4, 128].
static bool isValidWaveSizeValue(unsigned Value) {
  if (Value < 4 || Value > 128)
    return false;
  // A power of two has exactly one bit set.
  return (Value & (Value - 1)) == 0;
}
1697
// Validates and applies the WaveSize attribute, which has three spellings:
//   WaveSize(min), WaveSize(min, max), WaveSize(min, max, preferred).
// Each spelled value must be a power of two in [4, 128]; with a min/max
// range, max must be greater than min, and preferred (if spelled) must lie
// within [min, max]. Only the "newest" spelled argument is range-checked
// here; merging/consistency across redeclarations is handled by
// mergeWaveSizeAttr.
void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) {
  // validate that the wavesize argument is a power of 2 between 4 and 128
  // inclusive
  unsigned SpelledArgsCount = AL.getNumArgs();
  // Silently ignore a malformed argument count; the parser has already
  // diagnosed it.
  if (SpelledArgsCount == 0 || SpelledArgsCount > 3)
    return;

  uint32_t Min;
  if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: Min))
    return;

  // Max and Preferred default to 0 when not spelled.
  uint32_t Max = 0;
  if (SpelledArgsCount > 1 &&
      !SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Max))
    return;

  uint32_t Preferred = 0;
  if (SpelledArgsCount > 2 &&
      !SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 2), Val&: Preferred))
    return;

  if (SpelledArgsCount > 2) {
    // Three-argument form: Preferred must be a valid wave size value...
    if (!isValidWaveSizeValue(Value: Preferred)) {
      Diag(Loc: AL.getArgAsExpr(Arg: 2)->getExprLoc(),
           DiagID: diag::err_attribute_power_of_two_in_range)
          << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize
          << Preferred;
      return;
    }
    // Preferred not in range.
    // ...and it must fall inside the spelled [Min, Max] window.
    if (Preferred < Min || Preferred > Max) {
      Diag(Loc: AL.getArgAsExpr(Arg: 2)->getExprLoc(),
           DiagID: diag::err_attribute_power_of_two_in_range)
          << AL << Min << Max << Preferred;
      return;
    }
  } else if (SpelledArgsCount > 1) {
    // Two-argument form: Max must be a valid wave size and strictly greater
    // than Min; equality is allowed but flagged as redundant.
    if (!isValidWaveSizeValue(Value: Max)) {
      Diag(Loc: AL.getArgAsExpr(Arg: 1)->getExprLoc(),
           DiagID: diag::err_attribute_power_of_two_in_range)
          << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Max;
      return;
    }
    if (Max < Min) {
      Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_invalid) << AL << 1;
      return;
    } else if (Max == Min) {
      Diag(Loc: AL.getLoc(), DiagID: diag::warn_attr_min_eq_max) << AL;
    }
  } else {
    // Single-argument form: only Min is spelled and checked.
    if (!isValidWaveSizeValue(Value: Min)) {
      Diag(Loc: AL.getArgAsExpr(Arg: 0)->getExprLoc(),
           DiagID: diag::err_attribute_power_of_two_in_range)
          << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Min;
      return;
    }
  }

  HLSLWaveSizeAttr *NewAttr =
      mergeWaveSizeAttr(D, AL, Min, Max, Preferred, SpelledArgsCount);
  if (NewAttr)
    D->addAttr(A: NewAttr);
}
1761
1762void SemaHLSL::handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL) {
1763 uint32_t ID;
1764 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: ID))
1765 return;
1766 D->addAttr(A: ::new (getASTContext())
1767 HLSLVkExtBuiltinInputAttr(getASTContext(), AL, ID));
1768}
1769
1770void SemaHLSL::handleVkPushConstantAttr(Decl *D, const ParsedAttr &AL) {
1771 D->addAttr(A: ::new (getASTContext())
1772 HLSLVkPushConstantAttr(getASTContext(), AL));
1773}
1774
1775void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) {
1776 uint32_t Id;
1777 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: Id))
1778 return;
1779 HLSLVkConstantIdAttr *NewAttr = mergeVkConstantIdAttr(D, AL, Id);
1780 if (NewAttr)
1781 D->addAttr(A: NewAttr);
1782}
1783
1784void SemaHLSL::handleVkBindingAttr(Decl *D, const ParsedAttr &AL) {
1785 uint32_t Binding = 0;
1786 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: Binding))
1787 return;
1788 uint32_t Set = 0;
1789 if (AL.getNumArgs() > 1 &&
1790 !SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Set))
1791 return;
1792
1793 D->addAttr(A: ::new (getASTContext())
1794 HLSLVkBindingAttr(getASTContext(), AL, Binding, Set));
1795}
1796
1797void SemaHLSL::handleVkLocationAttr(Decl *D, const ParsedAttr &AL) {
1798 uint32_t Location;
1799 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: Location))
1800 return;
1801
1802 D->addAttr(A: ::new (getASTContext())
1803 HLSLVkLocationAttr(getASTContext(), AL, Location));
1804}
1805
1806bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) {
1807 const auto *VT = T->getAs<VectorType>();
1808
1809 if (!T->hasUnsignedIntegerRepresentation() ||
1810 (VT && VT->getNumElements() > 3)) {
1811 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type)
1812 << AL << "uint/uint2/uint3";
1813 return false;
1814 }
1815
1816 return true;
1817}
1818
1819bool SemaHLSL::diagnosePositionType(QualType T, const ParsedAttr &AL) {
1820 const auto *VT = T->getAs<VectorType>();
1821 if (!T->hasFloatingRepresentation() || (VT && VT->getNumElements() > 4)) {
1822 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type)
1823 << AL << "float/float1/float2/float3/float4";
1824 return false;
1825 }
1826
1827 return true;
1828}
1829
1830void SemaHLSL::diagnoseSystemSemanticAttr(Decl *D, const ParsedAttr &AL,
1831 std::optional<unsigned> Index) {
1832 std::string SemanticName = AL.getAttrName()->getName().upper();
1833
1834 auto *VD = cast<ValueDecl>(Val: D);
1835 QualType ValueType = VD->getType();
1836 if (auto *FD = dyn_cast<FunctionDecl>(Val: D))
1837 ValueType = FD->getReturnType();
1838
1839 bool IsOutput = false;
1840 if (HLSLParamModifierAttr *MA = D->getAttr<HLSLParamModifierAttr>()) {
1841 if (MA->isOut()) {
1842 IsOutput = true;
1843 ValueType = cast<ReferenceType>(Val&: ValueType)->getPointeeType();
1844 }
1845 }
1846
1847 if (SemanticName == "SV_DISPATCHTHREADID") {
1848 diagnoseInputIDType(T: ValueType, AL);
1849 if (IsOutput)
1850 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_output_not_supported) << AL;
1851 if (Index.has_value())
1852 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_indexing_not_supported) << AL;
1853 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1854 return;
1855 }
1856
1857 if (SemanticName == "SV_GROUPINDEX") {
1858 if (IsOutput)
1859 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_output_not_supported) << AL;
1860 if (Index.has_value())
1861 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_indexing_not_supported) << AL;
1862 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1863 return;
1864 }
1865
1866 if (SemanticName == "SV_GROUPTHREADID") {
1867 diagnoseInputIDType(T: ValueType, AL);
1868 if (IsOutput)
1869 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_output_not_supported) << AL;
1870 if (Index.has_value())
1871 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_indexing_not_supported) << AL;
1872 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1873 return;
1874 }
1875
1876 if (SemanticName == "SV_GROUPID") {
1877 diagnoseInputIDType(T: ValueType, AL);
1878 if (IsOutput)
1879 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_output_not_supported) << AL;
1880 if (Index.has_value())
1881 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_indexing_not_supported) << AL;
1882 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1883 return;
1884 }
1885
1886 if (SemanticName == "SV_POSITION") {
1887 const auto *VT = ValueType->getAs<VectorType>();
1888 if (!ValueType->hasFloatingRepresentation() ||
1889 (VT && VT->getNumElements() > 4))
1890 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type)
1891 << AL << "float/float1/float2/float3/float4";
1892 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1893 return;
1894 }
1895
1896 if (SemanticName == "SV_TARGET") {
1897 const auto *VT = ValueType->getAs<VectorType>();
1898 if (!ValueType->hasFloatingRepresentation() ||
1899 (VT && VT->getNumElements() > 4))
1900 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type)
1901 << AL << "float/float1/float2/float3/float4";
1902 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1903 return;
1904 }
1905
1906 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_unknown_semantic) << AL;
1907}
1908
1909void SemaHLSL::handleSemanticAttr(Decl *D, const ParsedAttr &AL) {
1910 uint32_t IndexValue(0), ExplicitIndex(0);
1911 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: IndexValue) ||
1912 !SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: ExplicitIndex)) {
1913 assert(0 && "HLSLUnparsedSemantic is expected to have 2 int arguments.");
1914 }
1915 assert(IndexValue > 0 ? ExplicitIndex : true);
1916 std::optional<unsigned> Index =
1917 ExplicitIndex ? std::optional<unsigned>(IndexValue) : std::nullopt;
1918
1919 if (AL.getAttrName()->getName().starts_with_insensitive(Prefix: "SV_"))
1920 diagnoseSystemSemanticAttr(D, AL, Index);
1921 else
1922 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1923}
1924
// Validates and applies packoffset(c[SubComponent].[Component]): only
// allowed on a shader constant declared directly inside a cbuffer/tbuffer,
// and the component offset must keep the whole value within one 128-bit
// (four 32-bit component) register.
void SemaHLSL::handlePackOffsetAttr(Decl *D, const ParsedAttr &AL) {
  if (!isa<VarDecl>(Val: D) || !isa<HLSLBufferDecl>(Val: D->getDeclContext())) {
    Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_ast_node)
        << AL << "shader constant in a constant buffer";
    return;
  }

  // SubComponent is the register index, Component the 32-bit lane (0-3)
  // within it; both arrive as parsed integer arguments.
  uint32_t SubComponent;
  if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: SubComponent))
    return;
  uint32_t Component;
  if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Component))
    return;

  QualType T = cast<VarDecl>(Val: D)->getType().getCanonicalType();
  // Check if T is an array or struct type.
  // TODO: mark matrix type as aggregate type.
  bool IsAggregateTy = (T->isArrayType() || T->isStructureType());

  // Check Component is valid for T.
  // A non-zero lane is only meaningful for scalars/vectors that fit in the
  // remainder of the register; aggregates must start at lane 0.
  if (Component) {
    unsigned Size = getASTContext().getTypeSize(T);
    if (IsAggregateTy) {
      Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_invalid_register_or_packoffset);
      return;
    } else {
      // Make sure Component + sizeof(T) <= 4.
      // Sizes are in bits: each lane is 32 bits, a register is 128.
      if ((Component * 32 + Size) > 128) {
        Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_packoffset_cross_reg_boundary);
        return;
      }
      QualType EltTy = T;
      if (const auto *VT = T->getAs<VectorType>())
        EltTy = VT->getElementType();
      // 64-bit elements must start on an even lane.
      unsigned Align = getASTContext().getTypeAlign(T: EltTy);
      if (Align > 32 && Component == 1) {
        // NOTE: Component 3 will hit err_hlsl_packoffset_cross_reg_boundary.
        // So we only need to check Component 1 here.
        Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_packoffset_alignment_mismatch)
            << Align << EltTy;
        return;
      }
    }
  }

  D->addAttr(A: ::new (getASTContext()) HLSLPackOffsetAttr(
      getASTContext(), AL, SubComponent, Component));
}
1973
1974void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) {
1975 StringRef Str;
1976 SourceLocation ArgLoc;
1977 if (!SemaRef.checkStringLiteralArgumentAttr(Attr: AL, ArgNum: 0, Str, ArgLocation: &ArgLoc))
1978 return;
1979
1980 llvm::Triple::EnvironmentType ShaderType;
1981 if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Val: Str, Out&: ShaderType)) {
1982 Diag(Loc: AL.getLoc(), DiagID: diag::warn_attribute_type_not_supported)
1983 << AL << Str << ArgLoc;
1984 return;
1985 }
1986
1987 // FIXME: check function match the shader stage.
1988
1989 HLSLShaderAttr *NewAttr = mergeShaderAttr(D, AL, ShaderType);
1990 if (NewAttr)
1991 D->addAttr(A: NewAttr);
1992}
1993
// Builds an HLSLAttributedResourceType wrapping `Wrapped` from the resource
// type attributes in AttrList. Each attribute kind may appear at most once
// (duplicates are diagnosed and abort the build); a resource class is
// mandatory. On success ResType receives the new type, and LocInfo (if
// provided) receives the source range covering the attributes plus the
// contained-type's TypeSourceInfo. Returns false on any error.
bool clang::CreateHLSLAttributedResourceType(
    Sema &S, QualType Wrapped, ArrayRef<const Attr *> AttrList,
    QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo) {
  assert(AttrList.size() && "expected list of resource attributes");

  QualType ContainedTy = QualType();
  TypeSourceInfo *ContainedTyInfo = nullptr;
  // Track the source range spanned by the attribute list; LocEnd is pushed
  // forward as each attribute is visited.
  SourceLocation LocBegin = AttrList[0]->getRange().getBegin();
  SourceLocation LocEnd = AttrList[0]->getRange().getEnd();

  HLSLAttributedResourceType::Attributes ResAttrs;

  bool HasResourceClass = false;
  bool HasResourceDimension = false;
  for (const Attr *A : AttrList) {
    if (!A)
      continue;
    LocEnd = A->getRange().getEnd();
    switch (A->getKind()) {
    case attr::HLSLResourceClass: {
      ResourceClass RC = cast<HLSLResourceClassAttr>(Val: A)->getResourceClass();
      // A repeat with the same value is "exact" duplicate; a different
      // value is a conflicting duplicate. Either way the build fails.
      if (HasResourceClass) {
        S.Diag(Loc: A->getLocation(), DiagID: ResAttrs.ResourceClass == RC
                                        ? diag::warn_duplicate_attribute_exact
                                        : diag::warn_duplicate_attribute)
            << A;
        return false;
      }
      ResAttrs.ResourceClass = RC;
      HasResourceClass = true;
      break;
    }
    case attr::HLSLResourceDimension: {
      llvm::dxil::ResourceDimension RD =
          cast<HLSLResourceDimensionAttr>(Val: A)->getDimension();
      if (HasResourceDimension) {
        S.Diag(Loc: A->getLocation(), DiagID: ResAttrs.ResourceDimension == RD
                                        ? diag::warn_duplicate_attribute_exact
                                        : diag::warn_duplicate_attribute)
            << A;
        return false;
      }
      ResAttrs.ResourceDimension = RD;
      HasResourceDimension = true;
      break;
    }
    // The three flag attributes below carry no payload, so any repeat is an
    // exact duplicate.
    case attr::HLSLROV:
      if (ResAttrs.IsROV) {
        S.Diag(Loc: A->getLocation(), DiagID: diag::warn_duplicate_attribute_exact) << A;
        return false;
      }
      ResAttrs.IsROV = true;
      break;
    case attr::HLSLRawBuffer:
      if (ResAttrs.RawBuffer) {
        S.Diag(Loc: A->getLocation(), DiagID: diag::warn_duplicate_attribute_exact) << A;
        return false;
      }
      ResAttrs.RawBuffer = true;
      break;
    case attr::HLSLIsCounter:
      if (ResAttrs.IsCounter) {
        S.Diag(Loc: A->getLocation(), DiagID: diag::warn_duplicate_attribute_exact) << A;
        return false;
      }
      ResAttrs.IsCounter = true;
      break;
    case attr::HLSLContainedType: {
      const HLSLContainedTypeAttr *CTAttr = cast<HLSLContainedTypeAttr>(Val: A);
      QualType Ty = CTAttr->getType();
      if (!ContainedTy.isNull()) {
        S.Diag(Loc: A->getLocation(), DiagID: ContainedTy == Ty
                                        ? diag::warn_duplicate_attribute_exact
                                        : diag::warn_duplicate_attribute)
            << A;
        return false;
      }
      ContainedTy = Ty;
      ContainedTyInfo = CTAttr->getTypeLoc();
      break;
    }
    default:
      llvm_unreachable("unhandled resource attribute type");
    }
  }

  // The resource class is the one mandatory attribute.
  if (!HasResourceClass) {
    S.Diag(Loc: AttrList.back()->getRange().getEnd(),
           DiagID: diag::err_hlsl_missing_resource_class);
    return false;
  }

  ResType = S.getASTContext().getHLSLAttributedResourceType(
      Wrapped, Contained: ContainedTy, Attrs: ResAttrs);

  // Location info is only reported when a contained type was present.
  if (LocInfo && ContainedTyInfo) {
    LocInfo->Range = SourceRange(LocBegin, LocEnd);
    LocInfo->ContainedTyInfo = ContainedTyInfo;
  }
  return true;
}
2095
2096// Validates and creates an HLSL attribute that is applied as type attribute on
2097// HLSL resource. The attributes are collected in HLSLResourcesTypeAttrs and at
2098// the end of the declaration they are applied to the declaration type by
2099// wrapping it in HLSLAttributedResourceType.
bool SemaHLSL::handleResourceTypeAttr(QualType T, const ParsedAttr &AL) {
  // only allow resource type attributes on intangible types
  if (!T->isHLSLResourceType()) {
    Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attribute_needs_intangible_type)
        << AL << getASTContext().HLSLResourceTy;
    return false;
  }

  // validate number of arguments
  if (!AL.checkExactlyNumArgs(S&: SemaRef, Num: AL.getMinArgs()))
    return false;

  Attr *A = nullptr;

  // Build a fresh AttributeCommonInfo marked NoSemaHandler so the converted
  // semantic attribute is not re-dispatched through the normal Sema
  // attribute handling.
  AttributeCommonInfo ACI(
      AL.getLoc(), AttributeScopeInfo(AL.getScopeName(), AL.getScopeLoc()),
      AttributeCommonInfo::NoSemaHandlerAttribute,
      {
          AttributeCommonInfo::AS_CXX11, 0, false /*IsAlignas*/,
          false /*IsRegularKeywordAttribute*/
      });

  switch (AL.getKind()) {
  case ParsedAttr::AT_HLSLResourceClass: {
    // The argument must be an identifier naming a known resource class
    // (e.g. SRV, UAV, CBuffer, Sampler).
    if (!AL.isArgIdent(Arg: 0)) {
      Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_type)
          << AL << AANT_ArgumentIdentifier;
      return false;
    }

    IdentifierLoc *Loc = AL.getArgAsIdent(Arg: 0);
    StringRef Identifier = Loc->getIdentifierInfo()->getName();
    SourceLocation ArgLoc = Loc->getLoc();

    // Validate resource class value
    ResourceClass RC;
    if (!HLSLResourceClassAttr::ConvertStrToResourceClass(Val: Identifier, Out&: RC)) {
      Diag(Loc: ArgLoc, DiagID: diag::warn_attribute_type_not_supported)
          << "ResourceClass" << Identifier;
      return false;
    }
    A = HLSLResourceClassAttr::Create(Ctx&: getASTContext(), ResourceClass: RC, CommonInfo: ACI);
    break;
  }

  // The flag attributes below take no arguments.
  case ParsedAttr::AT_HLSLROV:
    A = HLSLROVAttr::Create(Ctx&: getASTContext(), CommonInfo: ACI);
    break;

  case ParsedAttr::AT_HLSLRawBuffer:
    A = HLSLRawBufferAttr::Create(Ctx&: getASTContext(), CommonInfo: ACI);
    break;

  case ParsedAttr::AT_HLSLIsCounter:
    A = HLSLIsCounterAttr::Create(Ctx&: getASTContext(), CommonInfo: ACI);
    break;

  case ParsedAttr::AT_HLSLContainedType: {
    // The argument is a type; it must be complete before it can be used as
    // the resource's contained (element) type.
    if (AL.getNumArgs() != 1 && !AL.hasParsedType()) {
      Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_wrong_number_arguments) << AL << 1;
      return false;
    }

    TypeSourceInfo *TSI = nullptr;
    QualType QT = SemaRef.GetTypeFromParser(Ty: AL.getTypeArg(), TInfo: &TSI);
    assert(TSI && "no type source info for attribute argument");
    if (SemaRef.RequireCompleteType(Loc: TSI->getTypeLoc().getBeginLoc(), T: QT,
                                    DiagID: diag::err_incomplete_type))
      return false;
    A = HLSLContainedTypeAttr::Create(Ctx&: getASTContext(), Type: TSI, CommonInfo: ACI);
    break;
  }

  default:
    llvm_unreachable("unhandled HLSL attribute");
  }

  // Accumulate the converted attribute; the full set is applied to the
  // declaration's type in ProcessResourceTypeAttributes.
  HLSLResourcesTypeAttrs.emplace_back(Args&: A);
  return true;
}
2180
2181// Combines all resource type attributes and creates HLSLAttributedResourceType.
2182QualType SemaHLSL::ProcessResourceTypeAttributes(QualType CurrentType) {
2183 if (!HLSLResourcesTypeAttrs.size())
2184 return CurrentType;
2185
2186 QualType QT = CurrentType;
2187 HLSLAttributedResourceLocInfo LocInfo;
2188 if (CreateHLSLAttributedResourceType(S&: SemaRef, Wrapped: CurrentType,
2189 AttrList: HLSLResourcesTypeAttrs, ResType&: QT, LocInfo: &LocInfo)) {
2190 const HLSLAttributedResourceType *RT =
2191 cast<HLSLAttributedResourceType>(Val: QT.getTypePtr());
2192
2193 // Temporarily store TypeLoc information for the new type.
2194 // It will be transferred to HLSLAttributesResourceTypeLoc
2195 // shortly after the type is created by TypeSpecLocFiller which
2196 // will call the TakeLocForHLSLAttribute method below.
2197 LocsForHLSLAttributedResources.insert(KV: std::pair(RT, LocInfo));
2198 }
2199 HLSLResourcesTypeAttrs.clear();
2200 return QT;
2201}
2202
2203// Returns source location for the HLSLAttributedResourceType
2204HLSLAttributedResourceLocInfo
2205SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
2206 HLSLAttributedResourceLocInfo LocInfo = {};
2207 auto I = LocsForHLSLAttributedResources.find(Val: RT);
2208 if (I != LocsForHLSLAttributedResources.end()) {
2209 LocInfo = I->second;
2210 LocsForHLSLAttributedResources.erase(I);
2211 return LocInfo;
2212 }
2213 LocInfo.Range = SourceRange();
2214 return LocInfo;
2215}
2216
2217// Walks though the global variable declaration, collects all resource binding
2218// requirements and adds them to Bindings
void SemaHLSL::collectResourceBindingsOnUserRecordDecl(const VarDecl *VD,
                                                       const RecordType *RT) {
  const RecordDecl *RD = RT->getDecl()->getDefinitionOrSelf();
  for (FieldDecl *FD : RD->fields()) {
    const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();

    // Unwrap arrays
    // FIXME: Calculate array size while unwrapping
    assert(!Ty->isIncompleteArrayType() &&
           "incomplete arrays inside user defined types are not supported");
    while (Ty->isConstantArrayType()) {
      const ConstantArrayType *CAT = cast<ConstantArrayType>(Val: Ty);
      Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
    }

    // Only record-like fields can hold (or be) resources; anything else is
    // plain data and contributes no bindings.
    if (!Ty->isRecordType())
      continue;

    if (const HLSLAttributedResourceType *AttrResType =
            HLSLAttributedResourceType::findHandleTypeOnResource(RT: Ty)) {
      // Add a new DeclBindingInfo to Bindings if it does not already exist
      ResourceClass RC = AttrResType->getAttrs().ResourceClass;
      DeclBindingInfo *DBI = Bindings.getDeclBindingInfo(VD, ResClass: RC);
      if (!DBI)
        Bindings.addDeclBindingInfo(VD, ResClass: RC);
    } else if (const RecordType *RT = dyn_cast<RecordType>(Val: Ty)) {
      // Recursively scan embedded struct or class; it would be nice to do this
      // without recursion, but tricky to correctly calculate the size of the
      // binding, which is something we are probably going to need to do later
      // on. Hopefully nesting of structs in structs too many levels is
      // unlikely.
      collectResourceBindingsOnUserRecordDecl(VD, RT);
    }
  }
}
2254
2255// Diagnose localized register binding errors for a single binding; does not
2256// diagnose resource binding on user record types, that will be done later
2257// in processResourceBindingOnDecl based on the information collected in
2258// collectResourceBindingsOnVarDecl.
2259// Returns false if the register binding is not valid.
static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
                                         Decl *D, RegisterType RegType,
                                         bool SpecifiedSpace) {
  int RegTypeNum = static_cast<int>(RegType);

  // check if the decl type is groupshared
  // groupshared variables cannot take a register binding at all.
  if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
    S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
    return false;
  }

  // Cbuffers and Tbuffers are HLSLBufferDecl types
  // A cbuffer binds as CBuffer (register b), a tbuffer as SRV (register t).
  if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(Val: D)) {
    ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer
                                                     : ResourceClass::SRV;
    if (RegType == getRegisterType(RC))
      return true;

    S.Diag(Loc: D->getLocation(), DiagID: diag::err_hlsl_binding_type_mismatch)
        << RegTypeNum;
    return false;
  }

  // Samplers, UAVs, and SRVs are VarDecl types
  assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
  VarDecl *VD = cast<VarDecl>(Val: D);

  // Resource
  // A resource-typed variable must use the register type that matches the
  // resource class recorded on its handle type.
  if (const HLSLAttributedResourceType *AttrResType =
          HLSLAttributedResourceType::findHandleTypeOnResource(
              RT: VD->getType().getTypePtr())) {
    if (RegType == getRegisterType(ResTy: AttrResType))
      return true;

    S.Diag(Loc: D->getLocation(), DiagID: diag::err_hlsl_binding_type_mismatch)
        << RegTypeNum;
    return false;
  }

  // Look through arrays to the element type before classifying.
  const clang::Type *Ty = VD->getType().getTypePtr();
  while (Ty->isArrayType())
    Ty = Ty->getArrayElementTypeNoTypeQual();

  // Basic types
  if (Ty->isArithmeticType() || Ty->isVectorType()) {
    bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(Val: D->getDeclContext());
    // `spaceN` has no meaning on a global constant outside any buffer.
    if (SpecifiedSpace && !DeclaredInCOrTBuffer)
      S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_space_on_global_constant);

    if (!DeclaredInCOrTBuffer && (Ty->isIntegralType(Ctx: S.getASTContext()) ||
                                  Ty->isFloatingType() || Ty->isVectorType())) {
      // Register annotation on default constant buffer declaration ($Globals)
      // Only register(c#) is valid here; register(b#) is deprecated usage.
      if (RegType == RegisterType::CBuffer)
        S.Diag(Loc: ArgLoc, DiagID: diag::warn_hlsl_deprecated_register_type_b);
      else if (RegType != RegisterType::C)
        S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
      else
        return true;
    } else {
      // Inside a cbuffer/tbuffer, register(c#) is legacy shorthand for
      // packoffset; any other register type is a mismatch.
      if (RegType == RegisterType::C)
        S.Diag(Loc: ArgLoc, DiagID: diag::warn_hlsl_register_type_c_packoffset);
      else
        S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
    }
    return false;
  }
  if (Ty->isRecordType())
    // RecordTypes will be diagnosed in processResourceBindingOnDecl
    // that is called from ActOnVariableDeclarator
    return true;

  // Anything else is an error
  S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
  return false;
}
2335
2336static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
2337 RegisterType regType) {
2338 // make sure that there are no two register annotations
2339 // applied to the decl with the same register type
2340 bool RegisterTypesDetected[5] = {false};
2341 RegisterTypesDetected[static_cast<int>(regType)] = true;
2342
2343 for (auto it = TheDecl->attr_begin(); it != TheDecl->attr_end(); ++it) {
2344 if (HLSLResourceBindingAttr *attr =
2345 dyn_cast<HLSLResourceBindingAttr>(Val: *it)) {
2346
2347 RegisterType otherRegType = attr->getRegisterType();
2348 if (RegisterTypesDetected[static_cast<int>(otherRegType)]) {
2349 int otherRegTypeNum = static_cast<int>(otherRegType);
2350 S.Diag(Loc: TheDecl->getLocation(),
2351 DiagID: diag::err_hlsl_duplicate_register_annotation)
2352 << otherRegTypeNum;
2353 return false;
2354 }
2355 RegisterTypesDetected[static_cast<int>(otherRegType)] = true;
2356 }
2357 }
2358 return true;
2359}
2360
2361static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
2362 Decl *D, RegisterType RegType,
2363 bool SpecifiedSpace) {
2364
2365 // exactly one of these two types should be set
2366 assert(((isa<VarDecl>(D) && !isa<HLSLBufferDecl>(D)) ||
2367 (!isa<VarDecl>(D) && isa<HLSLBufferDecl>(D))) &&
2368 "expecting VarDecl or HLSLBufferDecl");
2369
2370 // check if the declaration contains resource matching the register type
2371 if (!DiagnoseLocalRegisterBinding(S, ArgLoc, D, RegType, SpecifiedSpace))
2372 return false;
2373
2374 // next, if multiple register annotations exist, check that none conflict.
2375 return ValidateMultipleRegisterAnnotations(S, TheDecl: D, regType: RegType);
2376}
2377
2378// return false if the slot count exceeds the limit, true otherwise
// Walks Ty counting the register slots consumed by resources of class
// ResClass, starting at StartSlot. On return StartSlot points one past the
// last slot consumed. ArrayCount is the multiplicity accumulated from
// enclosing array dimensions. Returns false as soon as any consumed slot
// would exceed Limit.
// NOTE(review): Ctx is threaded through the recursive calls but not
// otherwise used in this function.
static bool AccumulateHLSLResourceSlots(QualType Ty, uint64_t &StartSlot,
                                        const uint64_t &Limit,
                                        const ResourceClass ResClass,
                                        ASTContext &Ctx,
                                        uint64_t ArrayCount = 1) {
  Ty = Ty.getCanonicalType();
  const Type *T = Ty.getTypePtr();

  // Early exit if already overflowed
  if (StartSlot > Limit)
    return false;

  // Case 1: array type
  // Multiply the element count into ArrayCount and recurse on the element
  // type. Non-constant arrays contribute a count of 1.
  if (const auto *AT = dyn_cast<ArrayType>(Val: T)) {
    uint64_t Count = 1;

    if (const auto *CAT = dyn_cast<ConstantArrayType>(Val: AT))
      Count = CAT->getSize().getZExtValue();

    QualType ElemTy = AT->getElementType();
    return AccumulateHLSLResourceSlots(Ty: ElemTy, StartSlot, Limit, ResClass, Ctx,
                                       ArrayCount: ArrayCount * Count);
  }

  // Case 2: resource leaf
  if (auto ResTy = dyn_cast<HLSLAttributedResourceType>(Val: T)) {
    // First ensure this resource counts towards the corresponding
    // register type limit.
    if (ResTy->getAttrs().ResourceClass != ResClass)
      return true;

    // Validate highest slot used
    uint64_t EndSlot = StartSlot + ArrayCount - 1;
    if (EndSlot > Limit)
      return false;

    // Advance SlotCount past the consumed range
    StartSlot = EndSlot + 1;
    return true;
  }

  // Case 3: struct / record
  // Visit base classes first, then fields, in declaration order.
  if (const auto *RT = dyn_cast<RecordType>(Val: T)) {
    const RecordDecl *RD = RT->getDecl();

    if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(Val: RD)) {
      for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
        if (!AccumulateHLSLResourceSlots(Ty: Base.getType(), StartSlot, Limit,
                                         ResClass, Ctx, ArrayCount))
          return false;
      }
    }

    for (const FieldDecl *Field : RD->fields()) {
      if (!AccumulateHLSLResourceSlots(Ty: Field->getType(), StartSlot, Limit,
                                       ResClass, Ctx, ArrayCount))
        return false;
    }

    return true;
  }

  // Case 4: everything else
  // Non-resource, non-aggregate types consume no slots.
  return true;
}
2444
2445// return true if there is something invalid, false otherwise
2446static bool ValidateRegisterNumber(uint64_t SlotNum, Decl *TheDecl,
2447 ASTContext &Ctx, RegisterType RegTy) {
2448 const uint64_t Limit = UINT32_MAX;
2449 if (SlotNum > Limit)
2450 return true;
2451
2452 // after verifying the number doesn't exceed uint32max, we don't need
2453 // to look further into c or i register types
2454 if (RegTy == RegisterType::C || RegTy == RegisterType::I)
2455 return false;
2456
2457 if (VarDecl *VD = dyn_cast<VarDecl>(Val: TheDecl)) {
2458 uint64_t BaseSlot = SlotNum;
2459
2460 if (!AccumulateHLSLResourceSlots(Ty: VD->getType(), StartSlot&: SlotNum, Limit,
2461 ResClass: getResourceClass(RT: RegTy), Ctx))
2462 return true;
2463
2464 // After AccumulateHLSLResourceSlots runs, SlotNum is now
2465 // the first free slot; last used was SlotNum - 1
2466 return (BaseSlot > Limit);
2467 }
2468 // handle the cbuffer/tbuffer case
2469 if (isa<HLSLBufferDecl>(Val: TheDecl))
2470 // resources cannot be put within a cbuffer, so no need
2471 // to analyze the structure since the register number
2472 // won't be pushed any higher.
2473 return (SlotNum > Limit);
2474
2475 // we don't expect any other decl type, so fail
2476 llvm_unreachable("unexpected decl type");
2477}
2478
// Handles the HLSL register binding annotation, e.g. register(t3, space1).
// Parses and validates the register type letter, slot number, and space
// identifier from the parsed attribute, then attaches an
// HLSLResourceBindingAttr to the declaration. Emits a diagnostic and
// returns early on any invalid input.
void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
  // The bound type (or, for resource arrays, the element type) must be
  // complete before slot usage can be reasoned about.
  if (VarDecl *VD = dyn_cast<VarDecl>(Val: TheDecl)) {
    QualType Ty = VD->getType();
    if (const auto *IAT = dyn_cast<IncompleteArrayType>(Val&: Ty))
      Ty = IAT->getElementType();
    if (SemaRef.RequireCompleteType(Loc: TheDecl->getBeginLoc(), T: Ty,
                                    DiagID: diag::err_incomplete_type))
      return;
  }

  StringRef Slot = "";
  StringRef Space = "";
  SourceLocation SlotLoc, SpaceLoc;

  // The first attribute argument must be an identifier (e.g. "t3" or
  // "space1"), not an arbitrary expression.
  if (!AL.isArgIdent(Arg: 0)) {
    Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_type)
        << AL << AANT_ArgumentIdentifier;
    return;
  }
  IdentifierLoc *Loc = AL.getArgAsIdent(Arg: 0);

  // Two arguments: register(slot, space). One argument: either just a
  // space, or just a slot with the space defaulting to space0.
  if (AL.getNumArgs() == 2) {
    Slot = Loc->getIdentifierInfo()->getName();
    SlotLoc = Loc->getLoc();
    if (!AL.isArgIdent(Arg: 1)) {
      Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_type)
          << AL << AANT_ArgumentIdentifier;
      return;
    }
    Loc = AL.getArgAsIdent(Arg: 1);
    Space = Loc->getIdentifierInfo()->getName();
    SpaceLoc = Loc->getLoc();
  } else {
    StringRef Str = Loc->getIdentifierInfo()->getName();
    if (Str.starts_with(Prefix: "space")) {
      Space = Str;
      SpaceLoc = Loc->getLoc();
    } else {
      Slot = Str;
      SlotLoc = Loc->getLoc();
      Space = "space0";
    }
  }

  RegisterType RegType = RegisterType::SRV;
  std::optional<unsigned> SlotNum;
  unsigned SpaceNum = 0;

  // Validate slot
  if (!Slot.empty()) {
    // The leading letter selects the register type (t/u/b/s/c/i).
    if (!convertToRegisterType(Slot, RT: &RegType)) {
      Diag(Loc: SlotLoc, DiagID: diag::err_hlsl_binding_type_invalid) << Slot.substr(Start: 0, N: 1);
      return;
    }
    // The legacy 'i' register type is deprecated; warn and drop the
    // attribute entirely.
    if (RegType == RegisterType::I) {
      Diag(Loc: SlotLoc, DiagID: diag::warn_hlsl_deprecated_register_type_i);
      return;
    }
    const StringRef SlotNumStr = Slot.substr(Start: 1);

    uint64_t N;

    // validate that the slot number is a non-empty number
    if (SlotNumStr.getAsInteger(Radix: 10, Result&: N)) {
      Diag(Loc: SlotLoc, DiagID: diag::err_hlsl_unsupported_register_number);
      return;
    }

    // Validate register number. It should not exceed UINT32_MAX,
    // including if the resource type is an array that starts
    // before UINT32_MAX, but ends afterwards.
    if (ValidateRegisterNumber(SlotNum: N, TheDecl, Ctx&: getASTContext(), RegTy: RegType)) {
      Diag(Loc: SlotLoc, DiagID: diag::err_hlsl_register_number_too_large);
      return;
    }

    // the slot number has been validated and does not exceed UINT32_MAX
    SlotNum = (unsigned)N;
  }

  // Validate space: must be "space" followed by a decimal number.
  if (!Space.starts_with(Prefix: "space")) {
    Diag(Loc: SpaceLoc, DiagID: diag::err_hlsl_expected_space) << Space;
    return;
  }
  StringRef SpaceNumStr = Space.substr(Start: 5);
  if (SpaceNumStr.getAsInteger(Radix: 10, Result&: SpaceNum)) {
    Diag(Loc: SpaceLoc, DiagID: diag::err_hlsl_expected_space) << Space;
    return;
  }

  // If we have slot, diagnose it is the right register type for the decl
  if (SlotNum.has_value())
    if (!DiagnoseHLSLRegisterAttribute(S&: SemaRef, ArgLoc&: SlotLoc, D: TheDecl, RegType,
                                       SpecifiedSpace: !SpaceLoc.isInvalid()))
      return;

  HLSLResourceBindingAttr *NewAttr =
      HLSLResourceBindingAttr::Create(Ctx&: getASTContext(), Slot, Space, CommonInfo: AL);
  if (NewAttr) {
    NewAttr->setBinding(RT: RegType, SlotNum, SpaceNum);
    TheDecl->addAttr(A: NewAttr);
  }
}
2583
2584void SemaHLSL::handleParamModifierAttr(Decl *D, const ParsedAttr &AL) {
2585 HLSLParamModifierAttr *NewAttr = mergeParamModifierAttr(
2586 D, AL,
2587 Spelling: static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling()));
2588 if (NewAttr)
2589 D->addAttr(A: NewAttr);
2590}
2591
2592namespace {
2593
2594/// This class implements HLSL availability diagnostics for default
2595/// and relaxed mode
2596///
2597/// The goal of this diagnostic is to emit an error or warning when an
2598/// unavailable API is found in code that is reachable from the shader
2599/// entry function or from an exported function (when compiling a shader
2600/// library).
2601///
2602/// This is done by traversing the AST of all shader entry point functions
2603/// and of all exported functions, and any functions that are referenced
2604/// from this AST. In other words, any functions that are reachable from
2605/// the entry points.
class DiagnoseHLSLAvailability : public DynamicRecursiveASTVisitor {
  Sema &SemaRef;

  // Stack of functions to be scanned
  llvm::SmallVector<const FunctionDecl *, 8> DeclsToScan;

  // Tracks which environments functions have been scanned in.
  //
  // Maps FunctionDecl to an unsigned number that represents the set of shader
  // environments the function has been scanned for.
  // The llvm::Triple::EnvironmentType enum values for shader stages guaranteed
  // to be numbered from llvm::Triple::Pixel to llvm::Triple::Amplification
  // (verified by static_asserts in Triple.cpp), we can use it to index
  // individual bits in the set, as long as we shift the values to start with 0
  // by subtracting the value of llvm::Triple::Pixel first.
  //
  // The N'th bit in the set will be set if the function has been scanned
  // in shader environment whose llvm::Triple::EnvironmentType integer value
  // equals (llvm::Triple::Pixel + N).
  //
  // For example, if a function has been scanned in compute and pixel stage
  // environment, the value will be 0x21 (100001 binary) because:
  //
  //   (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0
  //   (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5
  //
  // A FunctionDecl is mapped to 0 (or not included in the map) if it has not
  // been scanned in any environment.
  llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls;

  // Do not access these directly, use the get/set methods below to make
  // sure the values are in sync
  llvm::Triple::EnvironmentType CurrentShaderEnvironment;
  unsigned CurrentShaderStageBit;

  // True if scanning a function that was already scanned in a different
  // shader stage context, and therefore we should not report issues that
  // depend only on shader model version because they would be duplicate.
  bool ReportOnlyShaderStageIssues;

  // Helper methods for dealing with current stage context / environment

  // Sets the current environment to a concrete shader stage and computes the
  // matching single-bit mask used in ScannedDecls.
  void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) {
    static_assert(sizeof(unsigned) >= 4);
    assert(HLSLShaderAttr::isValidShaderType(ShaderType));
    assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 &&
           "ShaderType is too big for this bitmap"); // 31 is reserved for
                                                     // "unknown"

    unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel;
    CurrentShaderEnvironment = ShaderType;
    CurrentShaderStageBit = (1 << bitmapIndex);
  }

  // Marks the stage context as unknown (used for exported library functions
  // whose calling stage cannot be determined); bit 31 is reserved for this.
  void SetUnknownShaderStageContext() {
    CurrentShaderEnvironment = llvm::Triple::UnknownEnvironment;
    CurrentShaderStageBit = (1 << 31);
  }

  llvm::Triple::EnvironmentType GetCurrentShaderEnvironment() const {
    return CurrentShaderEnvironment;
  }

  bool InUnknownShaderStageContext() const {
    return CurrentShaderEnvironment == llvm::Triple::UnknownEnvironment;
  }

  // Helper methods for dealing with shader stage bitmap

  // Records that FD has been scanned in the current stage context.
  void AddToScannedFunctions(const FunctionDecl *FD) {
    unsigned &ScannedStages = ScannedDecls[FD];
    ScannedStages |= CurrentShaderStageBit;
  }

  // Returns the bitmap of stages FD has been scanned in (0 if never scanned).
  unsigned GetScannedStages(const FunctionDecl *FD) { return ScannedDecls[FD]; }

  bool WasAlreadyScannedInCurrentStage(const FunctionDecl *FD) {
    return WasAlreadyScannedInCurrentStage(ScannerStages: GetScannedStages(FD));
  }

  bool WasAlreadyScannedInCurrentStage(unsigned ScannerStages) {
    return ScannerStages & CurrentShaderStageBit;
  }

  static bool NeverBeenScanned(unsigned ScannedStages) {
    return ScannedStages == 0;
  }

  // Scanning methods
  void HandleFunctionOrMethodRef(FunctionDecl *FD, Expr *RefExpr);
  void CheckDeclAvailability(NamedDecl *D, const AvailabilityAttr *AA,
                             SourceRange Range);
  const AvailabilityAttr *FindAvailabilityAttr(const Decl *D);
  bool HasMatchingEnvironmentOrNone(const AvailabilityAttr *AA);

public:
  DiagnoseHLSLAvailability(Sema &SemaRef)
      : SemaRef(SemaRef),
        CurrentShaderEnvironment(llvm::Triple::UnknownEnvironment),
        CurrentShaderStageBit(0), ReportOnlyShaderStageIssues(false) {}

  // AST traversal methods
  void RunOnTranslationUnit(const TranslationUnitDecl *TU);
  void RunOnFunction(const FunctionDecl *FD);

  // Visit direct function references (calls, address-of) and forward the
  // referenced function for scanning or diagnosis.
  bool VisitDeclRefExpr(DeclRefExpr *DRE) override {
    FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(Val: DRE->getDecl());
    if (FD)
      HandleFunctionOrMethodRef(FD, RefExpr: DRE);
    return true;
  }

  // Visit member-function references (method calls) the same way.
  bool VisitMemberExpr(MemberExpr *ME) override {
    FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(Val: ME->getMemberDecl());
    if (FD)
      HandleFunctionOrMethodRef(FD, RefExpr: ME);
    return true;
  }
};
2723
2724void DiagnoseHLSLAvailability::HandleFunctionOrMethodRef(FunctionDecl *FD,
2725 Expr *RefExpr) {
2726 assert((isa<DeclRefExpr>(RefExpr) || isa<MemberExpr>(RefExpr)) &&
2727 "expected DeclRefExpr or MemberExpr");
2728
2729 // has a definition -> add to stack to be scanned
2730 const FunctionDecl *FDWithBody = nullptr;
2731 if (FD->hasBody(Definition&: FDWithBody)) {
2732 if (!WasAlreadyScannedInCurrentStage(FD: FDWithBody))
2733 DeclsToScan.push_back(Elt: FDWithBody);
2734 return;
2735 }
2736
2737 // no body -> diagnose availability
2738 const AvailabilityAttr *AA = FindAvailabilityAttr(D: FD);
2739 if (AA)
2740 CheckDeclAvailability(
2741 D: FD, AA, Range: SourceRange(RefExpr->getBeginLoc(), RefExpr->getEndLoc()));
2742}
2743
2744void DiagnoseHLSLAvailability::RunOnTranslationUnit(
2745 const TranslationUnitDecl *TU) {
2746
2747 // Iterate over all shader entry functions and library exports, and for those
2748 // that have a body (definiton), run diag scan on each, setting appropriate
2749 // shader environment context based on whether it is a shader entry function
2750 // or an exported function. Exported functions can be in namespaces and in
2751 // export declarations so we need to scan those declaration contexts as well.
2752 llvm::SmallVector<const DeclContext *, 8> DeclContextsToScan;
2753 DeclContextsToScan.push_back(Elt: TU);
2754
2755 while (!DeclContextsToScan.empty()) {
2756 const DeclContext *DC = DeclContextsToScan.pop_back_val();
2757 for (auto &D : DC->decls()) {
2758 // do not scan implicit declaration generated by the implementation
2759 if (D->isImplicit())
2760 continue;
2761
2762 // for namespace or export declaration add the context to the list to be
2763 // scanned later
2764 if (llvm::dyn_cast<NamespaceDecl>(Val: D) || llvm::dyn_cast<ExportDecl>(Val: D)) {
2765 DeclContextsToScan.push_back(Elt: llvm::dyn_cast<DeclContext>(Val: D));
2766 continue;
2767 }
2768
2769 // skip over other decls or function decls without body
2770 const FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(Val: D);
2771 if (!FD || !FD->isThisDeclarationADefinition())
2772 continue;
2773
2774 // shader entry point
2775 if (HLSLShaderAttr *ShaderAttr = FD->getAttr<HLSLShaderAttr>()) {
2776 SetShaderStageContext(ShaderAttr->getType());
2777 RunOnFunction(FD);
2778 continue;
2779 }
2780 // exported library function
2781 // FIXME: replace this loop with external linkage check once issue #92071
2782 // is resolved
2783 bool isExport = FD->isInExportDeclContext();
2784 if (!isExport) {
2785 for (const auto *Redecl : FD->redecls()) {
2786 if (Redecl->isInExportDeclContext()) {
2787 isExport = true;
2788 break;
2789 }
2790 }
2791 }
2792 if (isExport) {
2793 SetUnknownShaderStageContext();
2794 RunOnFunction(FD);
2795 continue;
2796 }
2797 }
2798 }
2799}
2800
2801void DiagnoseHLSLAvailability::RunOnFunction(const FunctionDecl *FD) {
2802 assert(DeclsToScan.empty() && "DeclsToScan should be empty");
2803 DeclsToScan.push_back(Elt: FD);
2804
2805 while (!DeclsToScan.empty()) {
2806 // Take one decl from the stack and check it by traversing its AST.
2807 // For any CallExpr found during the traversal add it's callee to the top of
2808 // the stack to be processed next. Functions already processed are stored in
2809 // ScannedDecls.
2810 const FunctionDecl *FD = DeclsToScan.pop_back_val();
2811
2812 // Decl was already scanned
2813 const unsigned ScannedStages = GetScannedStages(FD);
2814 if (WasAlreadyScannedInCurrentStage(ScannerStages: ScannedStages))
2815 continue;
2816
2817 ReportOnlyShaderStageIssues = !NeverBeenScanned(ScannedStages);
2818
2819 AddToScannedFunctions(FD);
2820 TraverseStmt(S: FD->getBody());
2821 }
2822}
2823
2824bool DiagnoseHLSLAvailability::HasMatchingEnvironmentOrNone(
2825 const AvailabilityAttr *AA) {
2826 const IdentifierInfo *IIEnvironment = AA->getEnvironment();
2827 if (!IIEnvironment)
2828 return true;
2829
2830 llvm::Triple::EnvironmentType CurrentEnv = GetCurrentShaderEnvironment();
2831 if (CurrentEnv == llvm::Triple::UnknownEnvironment)
2832 return false;
2833
2834 llvm::Triple::EnvironmentType AttrEnv =
2835 AvailabilityAttr::getEnvironmentType(Environment: IIEnvironment->getName());
2836
2837 return CurrentEnv == AttrEnv;
2838}
2839
2840const AvailabilityAttr *
2841DiagnoseHLSLAvailability::FindAvailabilityAttr(const Decl *D) {
2842 AvailabilityAttr const *PartialMatch = nullptr;
2843 // Check each AvailabilityAttr to find the one for this platform.
2844 // For multiple attributes with the same platform try to find one for this
2845 // environment.
2846 for (const auto *A : D->attrs()) {
2847 if (const auto *Avail = dyn_cast<AvailabilityAttr>(Val: A)) {
2848 StringRef AttrPlatform = Avail->getPlatform()->getName();
2849 StringRef TargetPlatform =
2850 SemaRef.getASTContext().getTargetInfo().getPlatformName();
2851
2852 // Match the platform name.
2853 if (AttrPlatform == TargetPlatform) {
2854 // Find the best matching attribute for this environment
2855 if (HasMatchingEnvironmentOrNone(AA: Avail))
2856 return Avail;
2857 PartialMatch = Avail;
2858 }
2859 }
2860 }
2861 return PartialMatch;
2862}
2863
// Check availability against target shader model version and current shader
// stage and emit diagnostic. Emits warn_hlsl_availability when the target
// shader model is older than the attribute's 'introduced' version, or
// warn_hlsl_availability_unavailable when the attribute's environment does
// not match the current stage, each followed by a note at the declaration.
void DiagnoseHLSLAvailability::CheckDeclAvailability(NamedDecl *D,
                                                     const AvailabilityAttr *AA,
                                                     SourceRange Range) {

  const IdentifierInfo *IIEnv = AA->getEnvironment();

  if (!IIEnv) {
    // The availability attribute does not have environment -> it depends only
    // on shader model version and not on specific the shader stage.

    // Skip emitting the diagnostics if the diagnostic mode is set to
    // strict (-fhlsl-strict-availability) because all relevant diagnostics
    // were already emitted in the DiagnoseUnguardedAvailability scan
    // (SemaAvailability.cpp).
    if (SemaRef.getLangOpts().HLSLStrictAvailability)
      return;

    // Do not report shader-stage-independent issues if scanning a function
    // that was already scanned in a different shader stage context (they would
    // be duplicate)
    if (ReportOnlyShaderStageIssues)
      return;

  } else {
    // The availability attribute has environment -> we need to know
    // the current stage context to property diagnose it.
    if (InUnknownShaderStageContext())
      return;
  }

  // Check introduced version and if environment matches
  bool EnvironmentMatches = HasMatchingEnvironmentOrNone(AA);
  VersionTuple Introduced = AA->getIntroduced();
  VersionTuple TargetVersion =
      SemaRef.Context.getTargetInfo().getPlatformMinVersion();

  // Nothing to report: the API is available here.
  if (TargetVersion >= Introduced && EnvironmentMatches)
    return;

  // Emit diagnostic message
  const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
  llvm::StringRef PlatformName(
      AvailabilityAttr::getPrettyPlatformName(Platform: TI.getPlatformName()));

  llvm::StringRef CurrentEnvStr =
      llvm::Triple::getEnvironmentTypeName(Kind: GetCurrentShaderEnvironment());

  llvm::StringRef AttrEnvStr =
      AA->getEnvironment() ? AA->getEnvironment()->getName() : "";
  bool UseEnvironment = !AttrEnvStr.empty();

  if (EnvironmentMatches) {
    // Environment is fine (or unconstrained) -> the shader model version is
    // what is too old.
    SemaRef.Diag(Loc: Range.getBegin(), DiagID: diag::warn_hlsl_availability)
        << Range << D << PlatformName << Introduced.getAsString()
        << UseEnvironment << CurrentEnvStr;
  } else {
    // Wrong environment -> the API is unavailable in this stage altogether.
    SemaRef.Diag(Loc: Range.getBegin(), DiagID: diag::warn_hlsl_availability_unavailable)
        << Range << D;
  }

  SemaRef.Diag(Loc: D->getLocation(), DiagID: diag::note_partial_availability_specified_here)
      << D << PlatformName << Introduced.getAsString()
      << SemaRef.Context.getTargetInfo().getPlatformMinVersion().getAsString()
      << UseEnvironment << AttrEnvStr << CurrentEnvStr;
}
2931
2932} // namespace
2933
2934void SemaHLSL::ActOnEndOfTranslationUnit(TranslationUnitDecl *TU) {
2935 // process default CBuffer - create buffer layout struct and invoke codegenCGH
2936 if (!DefaultCBufferDecls.empty()) {
2937 HLSLBufferDecl *DefaultCBuffer = HLSLBufferDecl::CreateDefaultCBuffer(
2938 C&: SemaRef.getASTContext(), LexicalParent: SemaRef.getCurLexicalContext(),
2939 DefaultCBufferDecls);
2940 addImplicitBindingAttrToDecl(S&: SemaRef, D: DefaultCBuffer, RT: RegisterType::CBuffer,
2941 ImplicitBindingOrderID: getNextImplicitBindingOrderID());
2942 SemaRef.getCurLexicalContext()->addDecl(D: DefaultCBuffer);
2943 createHostLayoutStructForBuffer(S&: SemaRef, BufDecl: DefaultCBuffer);
2944
2945 // Set HasValidPackoffset if any of the decls has a register(c#) annotation;
2946 for (const Decl *VD : DefaultCBufferDecls) {
2947 const HLSLResourceBindingAttr *RBA =
2948 VD->getAttr<HLSLResourceBindingAttr>();
2949 if (RBA && RBA->hasRegisterSlot() &&
2950 RBA->getRegisterType() == HLSLResourceBindingAttr::RegisterType::C) {
2951 DefaultCBuffer->setHasValidPackoffset(true);
2952 break;
2953 }
2954 }
2955
2956 DeclGroupRef DG(DefaultCBuffer);
2957 SemaRef.Consumer.HandleTopLevelDecl(D: DG);
2958 }
2959 diagnoseAvailabilityViolations(TU);
2960}
2961
2962void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
2963 // Skip running the diagnostics scan if the diagnostic mode is
2964 // strict (-fhlsl-strict-availability) and the target shader stage is known
2965 // because all relevant diagnostics were already emitted in the
2966 // DiagnoseUnguardedAvailability scan (SemaAvailability.cpp).
2967 const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
2968 if (SemaRef.getLangOpts().HLSLStrictAvailability &&
2969 TI.getTriple().getEnvironment() != llvm::Triple::EnvironmentType::Library)
2970 return;
2971
2972 DiagnoseHLSLAvailability(SemaRef).RunOnTranslationUnit(TU);
2973}
2974
2975static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
2976 assert(TheCall->getNumArgs() > 1);
2977 QualType ArgTy0 = TheCall->getArg(Arg: 0)->getType();
2978
2979 for (unsigned I = 1, N = TheCall->getNumArgs(); I < N; ++I) {
2980 if (!S->getASTContext().hasSameUnqualifiedType(
2981 T1: ArgTy0, T2: TheCall->getArg(Arg: I)->getType())) {
2982 S->Diag(Loc: TheCall->getBeginLoc(), DiagID: diag::err_vec_builtin_incompatible_vector)
2983 << TheCall->getDirectCallee() << /*useAllTerminology*/ true
2984 << SourceRange(TheCall->getArg(Arg: 0)->getBeginLoc(),
2985 TheCall->getArg(Arg: N - 1)->getEndLoc());
2986 return true;
2987 }
2988 }
2989 return false;
2990}
2991
2992static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
2993 QualType ArgType = Arg->getType();
2994 if (!S->getASTContext().hasSameUnqualifiedType(T1: ArgType, T2: ExpectedType)) {
2995 S->Diag(Loc: Arg->getBeginLoc(), DiagID: diag::err_typecheck_convert_incompatible)
2996 << ArgType << ExpectedType << 1 << 0 << 0;
2997 return true;
2998 }
2999 return false;
3000}
3001
3002static bool CheckAllArgTypesAreCorrect(
3003 Sema *S, CallExpr *TheCall,
3004 llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
3005 clang::QualType PassedType)>
3006 Check) {
3007 for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
3008 Expr *Arg = TheCall->getArg(Arg: I);
3009 if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
3010 return true;
3011 }
3012 return false;
3013}
3014
3015static bool CheckFloatRepresentation(Sema *S, SourceLocation Loc,
3016 int ArgOrdinal,
3017 clang::QualType PassedType) {
3018 clang::QualType BaseType =
3019 PassedType->isVectorType()
3020 ? PassedType->castAs<clang::VectorType>()->getElementType()
3021 : PassedType;
3022 if (!BaseType->isFloat32Type())
3023 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3024 << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
3025 << /* float */ 1 << PassedType;
3026 return false;
3027}
3028
3029static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
3030 int ArgOrdinal,
3031 clang::QualType PassedType) {
3032 clang::QualType BaseType =
3033 PassedType->isVectorType()
3034 ? PassedType->castAs<clang::VectorType>()->getElementType()
3035 : PassedType;
3036 if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
3037 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3038 << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
3039 << /* half or float */ 2 << PassedType;
3040 return false;
3041}
3042
3043static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
3044 unsigned ArgIndex) {
3045 auto *Arg = TheCall->getArg(Arg: ArgIndex);
3046 SourceLocation OrigLoc = Arg->getExprLoc();
3047 if (Arg->IgnoreCasts()->isModifiableLvalue(Ctx&: S->Context, Loc: &OrigLoc) ==
3048 Expr::MLV_Valid)
3049 return false;
3050 S->Diag(Loc: OrigLoc, DiagID: diag::error_hlsl_inout_lvalue) << Arg << 0;
3051 return true;
3052}
3053
3054static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal,
3055 clang::QualType PassedType) {
3056 const auto *VecTy = PassedType->getAs<VectorType>();
3057 if (!VecTy)
3058 return false;
3059
3060 if (VecTy->getElementType()->isDoubleType())
3061 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3062 << ArgOrdinal << /* scalar */ 1 << /* no int */ 0 << /* fp */ 1
3063 << PassedType;
3064 return false;
3065}
3066
3067static bool CheckFloatingOrIntRepresentation(Sema *S, SourceLocation Loc,
3068 int ArgOrdinal,
3069 clang::QualType PassedType) {
3070 if (!PassedType->hasIntegerRepresentation() &&
3071 !PassedType->hasFloatingRepresentation())
3072 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3073 << ArgOrdinal << /* scalar or vector of */ 5 << /* integer */ 1
3074 << /* fp */ 1 << PassedType;
3075 return false;
3076}
3077
3078static bool CheckUnsignedIntVecRepresentation(Sema *S, SourceLocation Loc,
3079 int ArgOrdinal,
3080 clang::QualType PassedType) {
3081 if (auto *VecTy = PassedType->getAs<VectorType>())
3082 if (VecTy->getElementType()->isUnsignedIntegerType())
3083 return false;
3084
3085 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3086 << ArgOrdinal << /* vector of */ 4 << /* uint */ 3 << /* no fp */ 0
3087 << PassedType;
3088}
3089
3090// checks for unsigned ints of all sizes
3091static bool CheckUnsignedIntRepresentation(Sema *S, SourceLocation Loc,
3092 int ArgOrdinal,
3093 clang::QualType PassedType) {
3094 if (!PassedType->hasUnsignedIntegerRepresentation())
3095 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3096 << ArgOrdinal << /* scalar or vector of */ 5 << /* unsigned int */ 3
3097 << /* no fp */ 0 << PassedType;
3098 return false;
3099}
3100
// Diagnoses when the (element) type of the call's first argument is not
// exactly Width bits wide. Returns true if a diagnostic was emitted.
//
// NOTE(review): ArgOrdinal is never used -- the check always inspects
// getArg(0). Confirm against the callers whether getArg(ArgOrdinal) (or a
// 0-based index) was intended, or whether the parameter is vestigial.
static bool CheckExpectedBitWidth(Sema *S, CallExpr *TheCall,
                                  unsigned ArgOrdinal, unsigned Width) {
  QualType ArgTy = TheCall->getArg(Arg: 0)->getType();
  // For vectors, the per-element width is what matters.
  if (auto *VTy = ArgTy->getAs<VectorType>())
    ArgTy = VTy->getElementType();
  // ensure arg type has expected bit width
  uint64_t ElementBitCount =
      S->getASTContext().getTypeSizeInChars(T: ArgTy).getQuantity() * 8;
  if (ElementBitCount != Width) {
    S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
            DiagID: diag::err_integer_incorrect_bit_count)
        << Width << ElementBitCount;
    return true;
  }
  return false;
}
3117
3118static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
3119 QualType ReturnType) {
3120 auto *VecTyA = TheCall->getArg(Arg: 0)->getType()->getAs<VectorType>();
3121 if (VecTyA)
3122 ReturnType =
3123 S->Context.getExtVectorType(VectorType: ReturnType, NumElts: VecTyA->getNumElements());
3124
3125 TheCall->setType(ReturnType);
3126}
3127
3128static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar,
3129 unsigned ArgIndex) {
3130 assert(TheCall->getNumArgs() >= ArgIndex);
3131 QualType ArgType = TheCall->getArg(Arg: ArgIndex)->getType();
3132 auto *VTy = ArgType->getAs<VectorType>();
3133 // not the scalar or vector<scalar>
3134 if (!(S->Context.hasSameUnqualifiedType(T1: ArgType, T2: Scalar) ||
3135 (VTy &&
3136 S->Context.hasSameUnqualifiedType(T1: VTy->getElementType(), T2: Scalar)))) {
3137 S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
3138 DiagID: diag::err_typecheck_expect_scalar_or_vector)
3139 << ArgType << Scalar;
3140 return true;
3141 }
3142 return false;
3143}
3144
3145static bool CheckScalarOrVectorOrMatrix(Sema *S, CallExpr *TheCall,
3146 QualType Scalar, unsigned ArgIndex) {
3147 assert(TheCall->getNumArgs() > ArgIndex);
3148
3149 Expr *Arg = TheCall->getArg(Arg: ArgIndex);
3150 QualType ArgType = Arg->getType();
3151
3152 // Scalar: T
3153 if (S->Context.hasSameUnqualifiedType(T1: ArgType, T2: Scalar))
3154 return false;
3155
3156 // Vector: vector<T>
3157 if (const auto *VTy = ArgType->getAs<VectorType>()) {
3158 if (S->Context.hasSameUnqualifiedType(T1: VTy->getElementType(), T2: Scalar))
3159 return false;
3160 }
3161
3162 // Matrix: ConstantMatrixType with element type T
3163 if (const auto *MTy = ArgType->getAs<ConstantMatrixType>()) {
3164 if (S->Context.hasSameUnqualifiedType(T1: MTy->getElementType(), T2: Scalar))
3165 return false;
3166 }
3167
3168 // Not a scalar/vector/matrix-of-scalar
3169 S->Diag(Loc: Arg->getBeginLoc(),
3170 DiagID: diag::err_typecheck_expect_scalar_or_vector_or_matrix)
3171 << ArgType << Scalar;
3172 return true;
3173}
3174
3175static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall,
3176 unsigned ArgIndex) {
3177 assert(TheCall->getNumArgs() >= ArgIndex);
3178 QualType ArgType = TheCall->getArg(Arg: ArgIndex)->getType();
3179 auto *VTy = ArgType->getAs<VectorType>();
3180 // not the scalar or vector<scalar>
3181 if (!(ArgType->isScalarType() ||
3182 (VTy && VTy->getElementType()->isScalarType()))) {
3183 S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
3184 DiagID: diag::err_typecheck_expect_any_scalar_or_vector)
3185 << ArgType << 1;
3186 return true;
3187 }
3188 return false;
3189}
3190
3191// Check that the argument is not a bool or vector<bool>
3192// Returns true on error
3193static bool CheckNotBoolScalarOrVector(Sema *S, CallExpr *TheCall,
3194 unsigned ArgIndex) {
3195 QualType BoolType = S->getASTContext().BoolTy;
3196 assert(ArgIndex < TheCall->getNumArgs());
3197 QualType ArgType = TheCall->getArg(Arg: ArgIndex)->getType();
3198 auto *VTy = ArgType->getAs<VectorType>();
3199 // is the bool or vector<bool>
3200 if (S->Context.hasSameUnqualifiedType(T1: ArgType, T2: BoolType) ||
3201 (VTy &&
3202 S->Context.hasSameUnqualifiedType(T1: VTy->getElementType(), T2: BoolType))) {
3203 S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
3204 DiagID: diag::err_typecheck_expect_any_scalar_or_vector)
3205 << ArgType << 0;
3206 return true;
3207 }
3208 return false;
3209}
3210
3211static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
3212 if (CheckNotBoolScalarOrVector(S, TheCall, ArgIndex: 0))
3213 return true;
3214 return false;
3215}
3216
3217static bool CheckWavePrefix(Sema *S, CallExpr *TheCall) {
3218 if (CheckNotBoolScalarOrVector(S, TheCall, ArgIndex: 0))
3219 return true;
3220 return false;
3221}
3222
3223static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
3224 assert(TheCall->getNumArgs() == 3);
3225 Expr *Arg1 = TheCall->getArg(Arg: 1);
3226 Expr *Arg2 = TheCall->getArg(Arg: 2);
3227 if (!S->Context.hasSameUnqualifiedType(T1: Arg1->getType(), T2: Arg2->getType())) {
3228 S->Diag(Loc: TheCall->getBeginLoc(),
3229 DiagID: diag::err_typecheck_call_different_arg_types)
3230 << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
3231 << Arg2->getSourceRange();
3232 return true;
3233 }
3234
3235 TheCall->setType(Arg1->getType());
3236 return false;
3237}
3238
// Validates select() with a vector condition: the two value operands must
// share an element type, and any vector operand must match the condition's
// length. Sets the result type to vector<element, condition-length>.
// Returns true on error.
//
// NOTE(review): arg 0 is dereferenced as a VectorType unconditionally;
// callers presumably guarantee a vector condition here -- confirm.
static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
  assert(TheCall->getNumArgs() == 3);
  Expr *Arg1 = TheCall->getArg(Arg: 1);
  QualType Arg1Ty = Arg1->getType();
  Expr *Arg2 = TheCall->getArg(Arg: 2);
  QualType Arg2Ty = Arg2->getType();

  // Reduce both value operands to their scalar (element) types; scalars
  // are used as-is.
  QualType Arg1ScalarTy = Arg1Ty;
  if (auto VTy = Arg1ScalarTy->getAs<VectorType>())
    Arg1ScalarTy = VTy->getElementType();

  QualType Arg2ScalarTy = Arg2Ty;
  if (auto VTy = Arg2ScalarTy->getAs<VectorType>())
    Arg2ScalarTy = VTy->getElementType();

  // NOTE(review): no early return after this mismatch diagnostic -- the
  // length checks below still run and the call's type is still set.
  // Presumably intentional so length errors are reported too; confirm.
  if (!S->Context.hasSameUnqualifiedType(T1: Arg1ScalarTy, T2: Arg2ScalarTy))
    S->Diag(Loc: Arg1->getBeginLoc(), DiagID: diag::err_hlsl_builtin_scalar_vector_mismatch)
        << /* second and third */ 1 << TheCall->getCallee() << Arg1Ty << Arg2Ty;

  // A length of 0 marks a scalar operand, which is exempt from the
  // length-match requirement.
  QualType Arg0Ty = TheCall->getArg(Arg: 0)->getType();
  unsigned Arg0Length = Arg0Ty->getAs<VectorType>()->getNumElements();
  unsigned Arg1Length = Arg1Ty->isVectorType()
                            ? Arg1Ty->getAs<VectorType>()->getNumElements()
                            : 0;
  unsigned Arg2Length = Arg2Ty->isVectorType()
                            ? Arg2Ty->getAs<VectorType>()->getNumElements()
                            : 0;
  if (Arg1Length > 0 && Arg0Length != Arg1Length) {
    S->Diag(Loc: TheCall->getBeginLoc(),
            DiagID: diag::err_typecheck_vector_lengths_not_equal)
        << Arg0Ty << Arg1Ty << TheCall->getArg(Arg: 0)->getSourceRange()
        << Arg1->getSourceRange();
    return true;
  }

  if (Arg2Length > 0 && Arg0Length != Arg2Length) {
    S->Diag(Loc: TheCall->getBeginLoc(),
            DiagID: diag::err_typecheck_vector_lengths_not_equal)
        << Arg0Ty << Arg2Ty << TheCall->getArg(Arg: 0)->getSourceRange()
        << Arg2->getSourceRange();
    return true;
  }

  // Result: vector of the value element type, with the condition's length.
  TheCall->setType(
      S->getASTContext().getExtVectorType(VectorType: Arg1ScalarTy, NumElts: Arg0Length));
  return false;
}
3286
3287static bool CheckResourceHandle(
3288 Sema *S, CallExpr *TheCall, unsigned ArgIndex,
3289 llvm::function_ref<bool(const HLSLAttributedResourceType *ResType)> Check =
3290 nullptr) {
3291 assert(TheCall->getNumArgs() >= ArgIndex);
3292 QualType ArgType = TheCall->getArg(Arg: ArgIndex)->getType();
3293 const HLSLAttributedResourceType *ResTy =
3294 ArgType.getTypePtr()->getAs<HLSLAttributedResourceType>();
3295 if (!ResTy) {
3296 S->Diag(Loc: TheCall->getArg(Arg: ArgIndex)->getBeginLoc(),
3297 DiagID: diag::err_typecheck_expect_hlsl_resource)
3298 << ArgType;
3299 return true;
3300 }
3301 if (Check && Check(ResTy)) {
3302 S->Diag(Loc: TheCall->getArg(Arg: ArgIndex)->getExprLoc(),
3303 DiagID: diag::err_invalid_hlsl_resource_type)
3304 << ArgType;
3305 return true;
3306 }
3307 return false;
3308}
3309
3310static bool CheckVectorElementCount(Sema *S, QualType PassedType,
3311 QualType BaseType, unsigned ExpectedCount,
3312 SourceLocation Loc) {
3313 unsigned PassedCount = 1;
3314 if (const auto *VecTy = PassedType->getAs<VectorType>())
3315 PassedCount = VecTy->getNumElements();
3316
3317 if (PassedCount != ExpectedCount) {
3318 QualType ExpectedType =
3319 S->Context.getExtVectorType(VectorType: BaseType, NumElts: ExpectedCount);
3320 S->Diag(Loc, DiagID: diag::err_typecheck_convert_incompatible)
3321 << PassedType << ExpectedType << 1 << 0 << 0;
3322 return true;
3323 }
3324 return false;
3325}
3326
// Kinds of texture-sampling builtins, distinguished by their extra operand
// (none, bias, gradient pair, LOD level, or comparison value).
enum class SampleKind { Sample, Bias, Grad, Level, Cmp, CmpLevelZero };
3328
// Validates a call to one of the __builtin_hlsl_resource_sample* builtins.
// The argument layout is (texture handle, sampler handle, coordinates,
// [bias | LOD | compare-value], [ddx, ddy], [offset], [clamp]), where the
// middle operands depend on `Kind` and the trailing operands are optional.
// On success the call's result type is set to the resource's contained type
// (forced to scalar float for the Cmp variants). Returns true, after
// emitting a diagnostic, on any error.
static bool CheckSamplingBuiltin(Sema &S, CallExpr *TheCall, SampleKind Kind) {
  // Per-kind argument-count bounds; the Min/Max spread accounts for the
  // optional trailing offset and (for some kinds) clamp operands.
  unsigned MinArgs, MaxArgs;
  if (Kind == SampleKind::Sample) {
    MinArgs = 3;
    MaxArgs = 5;
  } else if (Kind == SampleKind::Bias) {
    MinArgs = 4;
    MaxArgs = 6;
  } else if (Kind == SampleKind::Grad) {
    MinArgs = 5;
    MaxArgs = 7;
  } else if (Kind == SampleKind::Level) {
    MinArgs = 4;
    MaxArgs = 5;
  } else if (Kind == SampleKind::Cmp) {
    MinArgs = 4;
    MaxArgs = 6;
  } else {
    assert(Kind == SampleKind::CmpLevelZero);
    MinArgs = 4;
    MaxArgs = 5;
  }

  if (S.checkArgCountRange(Call: TheCall, MinArgCount: MinArgs, MaxArgCount: MaxArgs))
    return true;

  // Check the texture handle. It must carry a known resource dimension so
  // the coordinate/offset vector widths below can be derived from it.
  if (CheckResourceHandle(S: &S, TheCall, ArgIndex: 0,
                          Check: [](const HLSLAttributedResourceType *ResType) {
                            return ResType->getAttrs().ResourceDimension ==
                                   llvm::dxil::ResourceDimension::Unknown;
                          }))
    return true;

  // Check the sampler handle.
  if (CheckResourceHandle(S: &S, TheCall, ArgIndex: 1,
                          Check: [](const HLSLAttributedResourceType *ResType) {
                            return ResType->getAttrs().ResourceClass !=
                                   llvm::hlsl::ResourceClass::Sampler;
                          }))
    return true;

  auto *ResourceTy =
      TheCall->getArg(Arg: 0)->getType()->castAs<HLSLAttributedResourceType>();

  // Check the location. Its float-vector width must match the texture's
  // dimensionality.
  unsigned ExpectedDim =
      getResourceDimensions(Dim: ResourceTy->getAttrs().ResourceDimension);
  if (CheckVectorElementCount(S: &S, PassedType: TheCall->getArg(Arg: 2)->getType(),
                              BaseType: S.Context.FloatTy, ExpectedCount: ExpectedDim,
                              Loc: TheCall->getBeginLoc()))
    return true;

  // NextIdx walks the kind-dependent middle operands and then the optional
  // trailing ones; each consumed operand advances it.
  unsigned NextIdx = 3;
  if (Kind == SampleKind::Bias || Kind == SampleKind::Level ||
      Kind == SampleKind::Cmp || Kind == SampleKind::CmpLevelZero) {
    // Check the bias, lod level, or compare value, depending on the kind.
    // All of them must be a scalar float value.
    QualType BiasOrLODOrCmpTy = TheCall->getArg(Arg: NextIdx)->getType();
    if (!BiasOrLODOrCmpTy->isFloatingType() ||
        BiasOrLODOrCmpTy->isVectorType()) {
      S.Diag(Loc: TheCall->getArg(Arg: NextIdx)->getBeginLoc(),
             DiagID: diag::err_typecheck_convert_incompatible)
          << BiasOrLODOrCmpTy << S.Context.FloatTy << 1 << 0 << 0;
      return true;
    }
    NextIdx++;
  } else if (Kind == SampleKind::Grad) {
    // Check the DDX operand.
    if (CheckVectorElementCount(S: &S, PassedType: TheCall->getArg(Arg: NextIdx)->getType(),
                                BaseType: S.Context.FloatTy, ExpectedCount: ExpectedDim,
                                Loc: TheCall->getArg(Arg: NextIdx)->getBeginLoc()))
      return true;

    // Check the DDY operand.
    if (CheckVectorElementCount(S: &S, PassedType: TheCall->getArg(Arg: NextIdx + 1)->getType(),
                                BaseType: S.Context.FloatTy, ExpectedCount: ExpectedDim,
                                Loc: TheCall->getArg(Arg: NextIdx + 1)->getBeginLoc()))
      return true;
    NextIdx += 2;
  }

  // Check the offset operand (optional): an int vector matching the
  // texture's dimensionality.
  if (TheCall->getNumArgs() > NextIdx) {
    if (CheckVectorElementCount(S: &S, PassedType: TheCall->getArg(Arg: NextIdx)->getType(),
                                BaseType: S.Context.IntTy, ExpectedCount: ExpectedDim,
                                Loc: TheCall->getArg(Arg: NextIdx)->getBeginLoc()))
      return true;
    NextIdx++;
  }

  // Check the clamp operand (optional; not accepted by the Level and
  // CmpLevelZero kinds): a scalar float.
  if (Kind != SampleKind::Level && Kind != SampleKind::CmpLevelZero &&
      TheCall->getNumArgs() > NextIdx) {
    QualType ClampTy = TheCall->getArg(Arg: NextIdx)->getType();
    if (!ClampTy->isFloatingType() || ClampTy->isVectorType()) {
      S.Diag(Loc: TheCall->getArg(Arg: NextIdx)->getBeginLoc(),
             DiagID: diag::err_typecheck_convert_incompatible)
          << ClampTy << S.Context.FloatTy << 1 << 0 << 0;
      return true;
    }
  }

  assert(ResourceTy->hasContainedType() &&
         "Expecting a contained type for resource with a dimension "
         "attribute.");
  QualType ReturnType = ResourceTy->getContainedType();
  if (Kind == SampleKind::Cmp || Kind == SampleKind::CmpLevelZero) {
    // Comparison sampling requires a float texel type and always yields a
    // scalar float result.
    if (!ReturnType->hasFloatingRepresentation()) {
      S.Diag(Loc: TheCall->getBeginLoc(), DiagID: diag::err_hlsl_samplecmp_requires_float);
      return true;
    }
    ReturnType = S.Context.FloatTy;
  }
  TheCall->setType(ReturnType);

  return false;
}
3447
// Note: returning true in this case results in CheckBuiltinFunctionCall
// returning an ExprError
// Performs HLSL-specific semantic validation of a builtin call. For builtins
// whose result type depends on their arguments, the call's result type is
// rewritten in place after the checks pass. Builtins not listed here fall
// through and are accepted unchanged (returns false).
bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
  switch (BuiltinID) {
  case Builtin::BI__builtin_hlsl_adduint64: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
      return true;

    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckUnsignedIntVecRepresentation))
      return true;

    // ensure arg integers are 32-bits
    if (CheckExpectedBitWidth(S: &SemaRef, TheCall, ArgOrdinal: 0, Width: 32))
      return true;

    // ensure both args are vectors of total bit size of a multiple of 64
    auto *VTy = TheCall->getArg(Arg: 0)->getType()->getAs<VectorType>();
    int NumElementsArg = VTy->getNumElements();
    if (NumElementsArg != 2 && NumElementsArg != 4) {
      SemaRef.Diag(Loc: TheCall->getBeginLoc(), DiagID: diag::err_vector_incorrect_bit_count)
          << 1 /*a multiple of*/ << 64 << NumElementsArg * 32;
      return true;
    }

    // ensure first arg and second arg have the same type
    if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
      return true;

    ExprResult A = TheCall->getArg(Arg: 0);
    QualType ArgTyA = A.get()->getType();
    // return type is the same as the input type
    TheCall->setType(ArgTyA);
    break;
  }
  case Builtin::BI__builtin_hlsl_resource_getpointer: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2) ||
        CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 1),
                            ExpectedType: SemaRef.getASTContext().UnsignedIntTy))
      return true;

    // Result is an lvalue pointer (hlsl_device address space) to the
    // resource's contained type.
    auto *ResourceTy =
        TheCall->getArg(Arg: 0)->getType()->castAs<HLSLAttributedResourceType>();
    QualType ContainedTy = ResourceTy->getContainedType();
    auto ReturnType =
        SemaRef.Context.getAddrSpaceQualType(T: ContainedTy, AddressSpace: LangAS::hlsl_device);
    ReturnType = SemaRef.Context.getPointerType(T: ReturnType);
    TheCall->setType(ReturnType);
    TheCall->setValueKind(VK_LValue);

    break;
  }
  case Builtin::BI__builtin_hlsl_resource_getpointer_typed: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3) ||
        CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 1),
                            ExpectedType: SemaRef.getASTContext().UnsignedIntTy))
      return true;

    // The third argument carries the requested element type via a pointer.
    // NOTE(review): the assert message says "second argument" but this is
    // the builtin's third argument (index 2) -- confirm intended wording.
    QualType ElementTy = TheCall->getArg(Arg: 2)->getType();
    assert(ElementTy->isPointerType() &&
           "expected pointer type for second argument");
    ElementTy = ElementTy->getPointeeType();

    // Reject array types
    if (ElementTy->isArrayType())
      return SemaRef.Diag(
          Loc: cast<FunctionDecl>(Val: SemaRef.CurContext)->getPointOfInstantiation(),
          DiagID: diag::err_invalid_use_of_array_type);

    auto ReturnType =
        SemaRef.Context.getAddrSpaceQualType(T: ElementTy, AddressSpace: LangAS::hlsl_device);
    ReturnType = SemaRef.Context.getPointerType(T: ReturnType);
    TheCall->setType(ReturnType);

    break;
  }
  case Builtin::BI__builtin_hlsl_resource_load_with_status: {
    // (handle, index, out status) -- status must be a writable uint lvalue.
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3) ||
        CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 1),
                            ExpectedType: SemaRef.getASTContext().UnsignedIntTy) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 2),
                            ExpectedType: SemaRef.getASTContext().UnsignedIntTy) ||
        CheckModifiableLValue(S: &SemaRef, TheCall, ArgIndex: 2))
      return true;

    auto *ResourceTy =
        TheCall->getArg(Arg: 0)->getType()->castAs<HLSLAttributedResourceType>();
    QualType ReturnType = ResourceTy->getContainedType();
    TheCall->setType(ReturnType);

    break;
  }
  case Builtin::BI__builtin_hlsl_resource_load_with_status_typed: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 4) ||
        CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 1),
                            ExpectedType: SemaRef.getASTContext().UnsignedIntTy) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 2),
                            ExpectedType: SemaRef.getASTContext().UnsignedIntTy) ||
        CheckModifiableLValue(S: &SemaRef, TheCall, ArgIndex: 2))
      return true;

    // The fourth argument carries the requested return type via a pointer.
    // NOTE(review): the assert message says "second argument" but this is
    // the builtin's fourth argument (index 3) -- confirm intended wording.
    QualType ReturnType = TheCall->getArg(Arg: 3)->getType();
    assert(ReturnType->isPointerType() &&
           "expected pointer type for second argument");
    ReturnType = ReturnType->getPointeeType();

    // Reject array types
    if (ReturnType->isArrayType())
      return SemaRef.Diag(
          Loc: cast<FunctionDecl>(Val: SemaRef.CurContext)->getPointOfInstantiation(),
          DiagID: diag::err_invalid_use_of_array_type);

    TheCall->setType(ReturnType);

    break;
  }
  // The sampling builtins share one checker parameterized by SampleKind.
  case Builtin::BI__builtin_hlsl_resource_sample:
    return CheckSamplingBuiltin(S&: SemaRef, TheCall, Kind: SampleKind::Sample);
  case Builtin::BI__builtin_hlsl_resource_sample_bias:
    return CheckSamplingBuiltin(S&: SemaRef, TheCall, Kind: SampleKind::Bias);
  case Builtin::BI__builtin_hlsl_resource_sample_grad:
    return CheckSamplingBuiltin(S&: SemaRef, TheCall, Kind: SampleKind::Grad);
  case Builtin::BI__builtin_hlsl_resource_sample_level:
    return CheckSamplingBuiltin(S&: SemaRef, TheCall, Kind: SampleKind::Level);
  case Builtin::BI__builtin_hlsl_resource_sample_cmp:
    return CheckSamplingBuiltin(S&: SemaRef, TheCall, Kind: SampleKind::Cmp);
  case Builtin::BI__builtin_hlsl_resource_sample_cmp_level_zero:
    return CheckSamplingBuiltin(S&: SemaRef, TheCall, Kind: SampleKind::CmpLevelZero);
  case Builtin::BI__builtin_hlsl_resource_uninitializedhandle: {
    assert(TheCall->getNumArgs() == 1 && "expected 1 arg");
    // Update return type to be the attributed resource type from arg0.
    QualType ResourceTy = TheCall->getArg(Arg: 0)->getType();
    TheCall->setType(ResourceTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_resource_handlefrombinding: {
    assert(TheCall->getNumArgs() == 6 && "expected 6 args");
    // Update return type to be the attributed resource type from arg0.
    QualType ResourceTy = TheCall->getArg(Arg: 0)->getType();
    TheCall->setType(ResourceTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_resource_handlefromimplicitbinding: {
    assert(TheCall->getNumArgs() == 6 && "expected 6 args");
    // Update return type to be the attributed resource type from arg0.
    QualType ResourceTy = TheCall->getArg(Arg: 0)->getType();
    TheCall->setType(ResourceTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_resource_counterhandlefromimplicitbinding: {
    assert(TheCall->getNumArgs() == 3 && "expected 3 args");
    ASTContext &AST = SemaRef.getASTContext();
    QualType MainHandleTy = TheCall->getArg(Arg: 0)->getType();
    auto *MainResType = MainHandleTy->getAs<HLSLAttributedResourceType>();
    auto MainAttrs = MainResType->getAttrs();
    assert(!MainAttrs.IsCounter && "cannot create a counter from a counter");
    MainAttrs.IsCounter = true;
    QualType CounterHandleTy = AST.getHLSLAttributedResourceType(
        Wrapped: MainResType->getWrappedType(), Contained: MainResType->getContainedType(),
        Attrs: MainAttrs);
    // Update return type to be the attributed resource type from arg0
    // with added IsCounter flag.
    TheCall->setType(CounterHandleTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_and:
  case Builtin::BI__builtin_hlsl_or: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
      return true;
    if (CheckScalarOrVectorOrMatrix(S: &SemaRef, TheCall, Scalar: getASTContext().BoolTy,
                                    ArgIndex: 0))
      return true;
    if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
      return true;

    ExprResult A = TheCall->getArg(Arg: 0);
    QualType ArgTyA = A.get()->getType();
    // return type is the same as the input type
    TheCall->setType(ArgTyA);
    break;
  }
  case Builtin::BI__builtin_hlsl_all:
  case Builtin::BI__builtin_hlsl_any: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;
    if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_asdouble: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
      return true;
    if (CheckScalarOrVector(
            S: &SemaRef, TheCall,
            /*only check for uint*/ Scalar: SemaRef.Context.UnsignedIntTy,
            /* arg index */ ArgIndex: 0))
      return true;
    if (CheckScalarOrVector(
            S: &SemaRef, TheCall,
            /*only check for uint*/ Scalar: SemaRef.Context.UnsignedIntTy,
            /* arg index */ ArgIndex: 1))
      return true;
    if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
      return true;

    SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().DoubleTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_clamp: {
    if (SemaRef.BuiltinElementwiseTernaryMath(
            TheCall, /*ArgTyRestr=*/
            Sema::EltwiseBuiltinArgTyRestriction::None))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_dot: {
    // arg count is checked by BuiltinVectorToScalarMath
    if (SemaRef.BuiltinVectorToScalarMath(TheCall))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall, Check: CheckNoDoubleVectors))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_firstbithigh:
  case Builtin::BI__builtin_hlsl_elementwise_firstbitlow: {
    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
      return true;

    // Argument must be (a vector of) integers; the result is (a vector of)
    // uint with the same element count.
    const Expr *Arg = TheCall->getArg(Arg: 0);
    QualType ArgTy = Arg->getType();
    QualType EltTy = ArgTy;

    QualType ResTy = SemaRef.Context.UnsignedIntTy;

    if (auto *VecTy = EltTy->getAs<VectorType>()) {
      EltTy = VecTy->getElementType();
      ResTy = SemaRef.Context.getExtVectorType(VectorType: ResTy, NumElts: VecTy->getNumElements());
    }

    if (!EltTy->isIntegerType()) {
      Diag(Loc: Arg->getBeginLoc(), DiagID: diag::err_builtin_invalid_arg_type)
          << 1 << /* scalar or vector of */ 5 << /* integer ty */ 1
          << /* no fp */ 0 << ArgTy;
      return true;
    }

    TheCall->setType(ResTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_select: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3))
      return true;
    if (CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: getASTContext().BoolTy, ArgIndex: 0))
      return true;
    // Scalar and vector conditions get separate value/result checks.
    QualType ArgTy = TheCall->getArg(Arg: 0)->getType();
    if (ArgTy->isBooleanType() && CheckBoolSelect(S: &SemaRef, TheCall))
      return true;
    auto *VTy = ArgTy->getAs<VectorType>();
    if (VTy && VTy->getElementType()->isBooleanType() &&
        CheckVectorSelect(S: &SemaRef, TheCall))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_saturate:
  case Builtin::BI__builtin_hlsl_elementwise_rcp: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;
    if (!TheCall->getArg(Arg: 0)
             ->getType()
             ->hasFloatingRepresentation()) // half or float or double
      return SemaRef.Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
                          DiagID: diag::err_builtin_invalid_arg_type)
             << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
             << /* fp */ 1 << TheCall->getArg(Arg: 0)->getType();
    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_degrees:
  case Builtin::BI__builtin_hlsl_elementwise_radians:
  case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
  case Builtin::BI__builtin_hlsl_elementwise_frac:
  case Builtin::BI__builtin_hlsl_elementwise_ddx_coarse:
  case Builtin::BI__builtin_hlsl_elementwise_ddy_coarse:
  case Builtin::BI__builtin_hlsl_elementwise_ddx_fine:
  case Builtin::BI__builtin_hlsl_elementwise_ddy_fine: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatOrHalfRepresentation))
      return true;
    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_isinf:
  case Builtin::BI__builtin_hlsl_elementwise_isnan: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatOrHalfRepresentation))
      return true;
    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
      return true;
    // Float input, bool (vector) result.
    SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().BoolTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_lerp: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatOrHalfRepresentation))
      return true;
    if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
      return true;
    if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_mad: {
    if (SemaRef.BuiltinElementwiseTernaryMath(
            TheCall, /*ArgTyRestr=*/
            Sema::EltwiseBuiltinArgTyRestriction::None))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_normalize: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatOrHalfRepresentation))
      return true;
    ExprResult A = TheCall->getArg(Arg: 0);
    QualType ArgTyA = A.get()->getType();
    // return type is the same as the input type
    TheCall->setType(ArgTyA);
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_sign: {
    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatingOrIntRepresentation))
      return true;
    // sign() always yields (a vector of) int regardless of input type.
    SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().IntTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_step: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatOrHalfRepresentation))
      return true;

    ExprResult A = TheCall->getArg(Arg: 0);
    QualType ArgTyA = A.get()->getType();
    // return type is the same as the input type
    TheCall->setType(ArgTyA);
    break;
  }
  case Builtin::BI__builtin_hlsl_wave_active_max:
  case Builtin::BI__builtin_hlsl_wave_active_min:
  case Builtin::BI__builtin_hlsl_wave_active_sum: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;

    // Ensure input expr type is a scalar/vector and the same as the return type
    if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
      return true;
    if (CheckWaveActive(S: &SemaRef, TheCall))
      return true;
    ExprResult Expr = TheCall->getArg(Arg: 0);
    QualType ArgTyExpr = Expr.get()->getType();
    TheCall->setType(ArgTyExpr);
    break;
  }
  // Note these are llvm builtins that we want to catch invalid intrinsic
  // generation. Normal handling of these builtins will occur elsewhere.
  case Builtin::BI__builtin_elementwise_bitreverse: {
    // does not include a check for number of arguments
    // because that is done previously
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckUnsignedIntRepresentation))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_wave_prefix_count_bits: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;

    QualType ArgType = TheCall->getArg(Arg: 0)->getType();

    // NOTE(review): both checks below emit
    // err_typecheck_expect_any_scalar_or_vector with identical arguments, so
    // a non-bool scalar gets the same "scalar or vector" wording as a
    // non-scalar -- confirm the second diagnostic is the intended one.
    if (!(ArgType->isScalarType())) {
      SemaRef.Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
                   DiagID: diag::err_typecheck_expect_any_scalar_or_vector)
          << ArgType << 0;
      return true;
    }

    if (!(ArgType->isBooleanType())) {
      SemaRef.Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
                   DiagID: diag::err_typecheck_expect_any_scalar_or_vector)
          << ArgType << 0;
      return true;
    }

    break;
  }
  case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
      return true;

    // Ensure index parameter type can be interpreted as a uint
    ExprResult Index = TheCall->getArg(Arg: 1);
    QualType ArgTyIndex = Index.get()->getType();
    if (!ArgTyIndex->isIntegerType()) {
      SemaRef.Diag(Loc: TheCall->getArg(Arg: 1)->getBeginLoc(),
                   DiagID: diag::err_typecheck_convert_incompatible)
          << ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
      return true;
    }

    // Ensure input expr type is a scalar/vector and the same as the return type
    if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
      return true;

    ExprResult Expr = TheCall->getArg(Arg: 0);
    QualType ArgTyExpr = Expr.get()->getType();
    TheCall->setType(ArgTyExpr);
    break;
  }
  case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 0))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_wave_prefix_sum:
  case Builtin::BI__builtin_hlsl_wave_prefix_product: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;

    // Ensure input expr type is a scalar/vector and the same as the return type
    if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
      return true;
    if (CheckWavePrefix(S: &SemaRef, TheCall))
      return true;
    ExprResult Expr = TheCall->getArg(Arg: 0);
    QualType ArgTyExpr = Expr.get()->getType();
    TheCall->setType(ArgTyExpr);
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
    // (double in, uint out, uint out) -- both outputs must be writable
    // lvalues.
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3))
      return true;

    if (CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: SemaRef.Context.DoubleTy, ArgIndex: 0) ||
        CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: SemaRef.Context.UnsignedIntTy,
                            ArgIndex: 1) ||
        CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: SemaRef.Context.UnsignedIntTy,
                            ArgIndex: 2))
      return true;

    if (CheckModifiableLValue(S: &SemaRef, TheCall, ArgIndex: 1) ||
        CheckModifiableLValue(S: &SemaRef, TheCall, ArgIndex: 2))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_clip: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;

    if (CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: SemaRef.Context.FloatTy, ArgIndex: 0))
      return true;
    break;
  }
  // Generic elementwise math builtins: HLSL restricts them to float/half
  // representations; everything else about them is checked elsewhere.
  case Builtin::BI__builtin_elementwise_acos:
  case Builtin::BI__builtin_elementwise_asin:
  case Builtin::BI__builtin_elementwise_atan:
  case Builtin::BI__builtin_elementwise_atan2:
  case Builtin::BI__builtin_elementwise_ceil:
  case Builtin::BI__builtin_elementwise_cos:
  case Builtin::BI__builtin_elementwise_cosh:
  case Builtin::BI__builtin_elementwise_exp:
  case Builtin::BI__builtin_elementwise_exp2:
  case Builtin::BI__builtin_elementwise_exp10:
  case Builtin::BI__builtin_elementwise_floor:
  case Builtin::BI__builtin_elementwise_fmod:
  case Builtin::BI__builtin_elementwise_log:
  case Builtin::BI__builtin_elementwise_log2:
  case Builtin::BI__builtin_elementwise_log10:
  case Builtin::BI__builtin_elementwise_pow:
  case Builtin::BI__builtin_elementwise_roundeven:
  case Builtin::BI__builtin_elementwise_sin:
  case Builtin::BI__builtin_elementwise_sinh:
  case Builtin::BI__builtin_elementwise_sqrt:
  case Builtin::BI__builtin_elementwise_tan:
  case Builtin::BI__builtin_elementwise_tanh:
  case Builtin::BI__builtin_elementwise_trunc: {
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatOrHalfRepresentation))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_buffer_update_counter: {
    assert(TheCall->getNumArgs() == 2 && "expected 2 args");
    // Resource must be a raw UAV buffer with a contained type.
    auto checkResTy = [](const HLSLAttributedResourceType *ResTy) -> bool {
      return !(ResTy->getAttrs().ResourceClass == ResourceClass::UAV &&
               ResTy->getAttrs().RawBuffer && ResTy->hasContainedType());
    };
    if (CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0, Check: checkResTy))
      return true;
    // The increment must be the integer constant 1 or -1.
    Expr *OffsetExpr = TheCall->getArg(Arg: 1);
    std::optional<llvm::APSInt> Offset =
        OffsetExpr->getIntegerConstantExpr(Ctx: SemaRef.getASTContext());
    if (!Offset.has_value() || std::abs(i: Offset->getExtValue()) != 1) {
      SemaRef.Diag(Loc: TheCall->getArg(Arg: 1)->getBeginLoc(),
                   DiagID: diag::err_hlsl_expect_arg_const_int_one_or_neg_one)
          << 1;
      return true;
    }
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_f16tof32: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckUnsignedIntRepresentation))
      return true;
    // ensure arg integers are 32 bits
    if (CheckExpectedBitWidth(S: &SemaRef, TheCall, ArgOrdinal: 0, Width: 32))
      return true;
    // check it wasn't a bool type
    QualType ArgTy = TheCall->getArg(Arg: 0)->getType();
    if (auto *VTy = ArgTy->getAs<VectorType>())
      ArgTy = VTy->getElementType();
    if (ArgTy->isBooleanType()) {
      SemaRef.Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
                   DiagID: diag::err_builtin_invalid_arg_type)
          << 1 << /* scalar or vector of */ 5 << /* unsigned int */ 3
          << /* no fp */ 0 << TheCall->getArg(Arg: 0)->getType();
      return true;
    }

    SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().FloatTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_f32tof16: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall, Check: CheckFloatRepresentation))
      return true;
    SetElementTypeAsReturnType(S: &SemaRef, TheCall,
                               ReturnType: getASTContext().UnsignedIntTy);
    break;
  }
  }
  return false;
}
4010
// Appends to `List` the ordered sequence of leaf element types obtained by
// fully flattening `BaseTy`: arrays repeat their element's flattened list,
// vectors and matrices expand into N copies of their element type, and
// aggregate records expand base classes then fields. Unions and
// non-aggregate records are treated as leaves and added as-is.
static void BuildFlattenedTypeList(QualType BaseTy,
                                   llvm::SmallVectorImpl<QualType> &List) {
  // LIFO worklist: the top of WorkList is the next type to emit, so
  // sub-ranges are pushed in reverse to preserve declaration order.
  llvm::SmallVector<QualType, 16> WorkList;
  WorkList.push_back(Elt: BaseTy);
  while (!WorkList.empty()) {
    QualType T = WorkList.pop_back_val();
    T = T.getCanonicalType().getUnqualifiedType();
    if (const auto *AT = dyn_cast<ConstantArrayType>(Val&: T)) {
      llvm::SmallVector<QualType, 16> ElementFields;
      // Generally I've avoided recursion in this algorithm, but arrays of
      // structs could be time-consuming to flatten and churn through on the
      // work list. Hopefully nesting arrays of structs containing arrays
      // of structs too many levels deep is unlikely.
      BuildFlattenedTypeList(BaseTy: AT->getElementType(), List&: ElementFields);
      // Repeat the element's field list n times.
      for (uint64_t Ct = 0; Ct < AT->getZExtSize(); ++Ct)
        llvm::append_range(C&: List, R&: ElementFields);
      continue;
    }
    // Vectors can only have element types that are builtin types, so this can
    // add directly to the list instead of to the WorkList.
    if (const auto *VT = dyn_cast<VectorType>(Val&: T)) {
      List.insert(I: List.end(), NumToInsert: VT->getNumElements(), Elt: VT->getElementType());
      continue;
    }
    if (const auto *MT = dyn_cast<ConstantMatrixType>(Val&: T)) {
      List.insert(I: List.end(), NumToInsert: MT->getNumElementsFlattened(),
                  Elt: MT->getElementType());
      continue;
    }
    if (const auto *RD = T->getAsCXXRecordDecl()) {
      if (RD->isStandardLayout())
        RD = RD->getStandardLayoutBaseWithFields();

      // For types that we shouldn't decompose (unions and non-aggregates), just
      // add the type itself to the list.
      if (RD->isUnion() || !RD->isAggregate()) {
        List.push_back(Elt: T);
        continue;
      }

      llvm::SmallVector<QualType, 16> FieldTypes;
      for (const auto *FD : RD->fields())
        if (!FD->isUnnamedBitField())
          FieldTypes.push_back(Elt: FD->getType());
      // Reverse the newly added sub-range so that, popped LIFO off WorkList,
      // the fields come out in declaration order.
      std::reverse(first: FieldTypes.begin(), last: FieldTypes.end());
      llvm::append_range(C&: WorkList, R&: FieldTypes);

      // If this wasn't a standard layout type we may also have some base
      // classes to deal with. Bases are pushed after fields, so they are
      // popped (and thus flattened) before the fields.
      if (!RD->isStandardLayout()) {
        FieldTypes.clear();
        for (const auto &Base : RD->bases())
          FieldTypes.push_back(Elt: Base.getType());
        std::reverse(first: FieldTypes.begin(), last: FieldTypes.end());
        llvm::append_range(C&: WorkList, R&: FieldTypes);
      }
      continue;
    }
    // Anything else is a leaf type.
    List.push_back(Elt: T);
  }
}
4074
4075bool SemaHLSL::IsTypedResourceElementCompatible(clang::QualType QT) {
4076 // null and array types are not allowed.
4077 if (QT.isNull() || QT->isArrayType())
4078 return false;
4079
4080 // UDT types are not allowed
4081 if (QT->isRecordType())
4082 return false;
4083
4084 if (QT->isBooleanType() || QT->isEnumeralType())
4085 return false;
4086
4087 // the only other valid builtin types are scalars or vectors
4088 if (QT->isArithmeticType()) {
4089 if (SemaRef.Context.getTypeSize(T: QT) / 8 > 16)
4090 return false;
4091 return true;
4092 }
4093
4094 if (const VectorType *VT = QT->getAs<VectorType>()) {
4095 int ArraySize = VT->getNumElements();
4096
4097 if (ArraySize > 4)
4098 return false;
4099
4100 QualType ElTy = VT->getElementType();
4101 if (ElTy->isBooleanType())
4102 return false;
4103
4104 if (SemaRef.Context.getTypeSize(T: QT) / 8 > 16)
4105 return false;
4106 return true;
4107 }
4108
4109 return false;
4110}
4111
4112bool SemaHLSL::IsScalarizedLayoutCompatible(QualType T1, QualType T2) const {
4113 if (T1.isNull() || T2.isNull())
4114 return false;
4115
4116 T1 = T1.getCanonicalType().getUnqualifiedType();
4117 T2 = T2.getCanonicalType().getUnqualifiedType();
4118
4119 // If both types are the same canonical type, they're obviously compatible.
4120 if (SemaRef.getASTContext().hasSameType(T1, T2))
4121 return true;
4122
4123 llvm::SmallVector<QualType, 16> T1Types;
4124 BuildFlattenedTypeList(BaseTy: T1, List&: T1Types);
4125 llvm::SmallVector<QualType, 16> T2Types;
4126 BuildFlattenedTypeList(BaseTy: T2, List&: T2Types);
4127
4128 // Check the flattened type list
4129 return llvm::equal(LRange&: T1Types, RRange&: T2Types,
4130 P: [this](QualType LHS, QualType RHS) -> bool {
4131 return SemaRef.IsLayoutCompatible(T1: LHS, T2: RHS);
4132 });
4133}
4134
4135bool SemaHLSL::CheckCompatibleParameterABI(FunctionDecl *New,
4136 FunctionDecl *Old) {
4137 if (New->getNumParams() != Old->getNumParams())
4138 return true;
4139
4140 bool HadError = false;
4141
4142 for (unsigned i = 0, e = New->getNumParams(); i != e; ++i) {
4143 ParmVarDecl *NewParam = New->getParamDecl(i);
4144 ParmVarDecl *OldParam = Old->getParamDecl(i);
4145
4146 // HLSL parameter declarations for inout and out must match between
4147 // declarations. In HLSL inout and out are ambiguous at the call site,
4148 // but have different calling behavior, so you cannot overload a
4149 // method based on a difference between inout and out annotations.
4150 const auto *NDAttr = NewParam->getAttr<HLSLParamModifierAttr>();
4151 unsigned NSpellingIdx = (NDAttr ? NDAttr->getSpellingListIndex() : 0);
4152 const auto *ODAttr = OldParam->getAttr<HLSLParamModifierAttr>();
4153 unsigned OSpellingIdx = (ODAttr ? ODAttr->getSpellingListIndex() : 0);
4154
4155 if (NSpellingIdx != OSpellingIdx) {
4156 SemaRef.Diag(Loc: NewParam->getLocation(),
4157 DiagID: diag::err_hlsl_param_qualifier_mismatch)
4158 << NDAttr << NewParam;
4159 SemaRef.Diag(Loc: OldParam->getLocation(), DiagID: diag::note_previous_declaration_as)
4160 << ODAttr;
4161 HadError = true;
4162 }
4163 }
4164 return HadError;
4165}
4166
// Generally follows PerformScalarCast, with cases reordered for
// clarity of what types are supported
bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) {

  // Only scalar-to-scalar conversions are considered here.
  if (!SrcTy->isScalarType() || !DestTy->isScalarType())
    return false;

  // Identical types (ignoring qualifiers) are trivially convertible.
  if (SemaRef.getASTContext().hasSameUnqualifiedType(T1: SrcTy, T2: DestTy))
    return true;

  switch (SrcTy->getScalarTypeKind()) {
  case Type::STK_Bool: // casting from bool is like casting from an integer
  case Type::STK_Integral:
    // Integral (or bool) sources may convert to bool, integer, or float.
    // All other destination kinds cannot occur in HLSL.
    switch (DestTy->getScalarTypeKind()) {
    case Type::STK_Bool:
    case Type::STK_Integral:
    case Type::STK_Floating:
      return true;
    case Type::STK_CPointer:
    case Type::STK_ObjCObjectPointer:
    case Type::STK_BlockPointer:
    case Type::STK_MemberPointer:
      llvm_unreachable("HLSL doesn't support pointers.");
    case Type::STK_IntegralComplex:
    case Type::STK_FloatingComplex:
      llvm_unreachable("HLSL doesn't support complex types.");
    case Type::STK_FixedPoint:
      llvm_unreachable("HLSL doesn't support fixed point types.");
    }
    llvm_unreachable("Should have returned before this");

  case Type::STK_Floating:
    // Floating sources may convert to float, bool, or integer.
    switch (DestTy->getScalarTypeKind()) {
    case Type::STK_Floating:
    case Type::STK_Bool:
    case Type::STK_Integral:
      return true;
    case Type::STK_FloatingComplex:
    case Type::STK_IntegralComplex:
      llvm_unreachable("HLSL doesn't support complex types.");
    case Type::STK_FixedPoint:
      llvm_unreachable("HLSL doesn't support fixed point types.");
    case Type::STK_CPointer:
    case Type::STK_ObjCObjectPointer:
    case Type::STK_BlockPointer:
    case Type::STK_MemberPointer:
      llvm_unreachable("HLSL doesn't support pointers.");
    }
    llvm_unreachable("Should have returned before this");

  // Source kinds that HLSL can never produce.
  case Type::STK_MemberPointer:
  case Type::STK_CPointer:
  case Type::STK_BlockPointer:
  case Type::STK_ObjCObjectPointer:
    llvm_unreachable("HLSL doesn't support pointers.");

  case Type::STK_FixedPoint:
    llvm_unreachable("HLSL doesn't support fixed point types.");

  case Type::STK_FloatingComplex:
  case Type::STK_IntegralComplex:
    llvm_unreachable("HLSL doesn't support complex types.");
  }

  llvm_unreachable("Unhandled scalar cast");
}
4233
4234// Can perform an HLSL Aggregate splat cast if the Dest is an aggregate and the
4235// Src is a scalar or a vector of length 1
4236// Or if Dest is a vector and Src is a vector of length 1
4237bool SemaHLSL::CanPerformAggregateSplatCast(Expr *Src, QualType DestTy) {
4238
4239 QualType SrcTy = Src->getType();
4240 // Not a valid HLSL Aggregate Splat cast if Dest is a scalar or if this is
4241 // going to be a vector splat from a scalar.
4242 if ((SrcTy->isScalarType() && DestTy->isVectorType()) ||
4243 DestTy->isScalarType())
4244 return false;
4245
4246 const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
4247
4248 // Src isn't a scalar or a vector of length 1
4249 if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1))
4250 return false;
4251
4252 if (SrcVecTy)
4253 SrcTy = SrcVecTy->getElementType();
4254
4255 llvm::SmallVector<QualType> DestTypes;
4256 BuildFlattenedTypeList(BaseTy: DestTy, List&: DestTypes);
4257
4258 for (unsigned I = 0, Size = DestTypes.size(); I < Size; ++I) {
4259 if (DestTypes[I]->isUnionType())
4260 return false;
4261 if (!CanPerformScalarCast(SrcTy, DestTy: DestTypes[I]))
4262 return false;
4263 }
4264 return true;
4265}
4266
4267// Can we perform an HLSL Elementwise cast?
4268bool SemaHLSL::CanPerformElementwiseCast(Expr *Src, QualType DestTy) {
4269
4270 // Don't handle casts where LHS and RHS are any combination of scalar/vector
4271 // There must be an aggregate somewhere
4272 QualType SrcTy = Src->getType();
4273 if (SrcTy->isScalarType()) // always a splat and this cast doesn't handle that
4274 return false;
4275
4276 if (SrcTy->isVectorType() &&
4277 (DestTy->isScalarType() || DestTy->isVectorType()))
4278 return false;
4279
4280 if (SrcTy->isConstantMatrixType() &&
4281 (DestTy->isScalarType() || DestTy->isConstantMatrixType()))
4282 return false;
4283
4284 llvm::SmallVector<QualType> DestTypes;
4285 BuildFlattenedTypeList(BaseTy: DestTy, List&: DestTypes);
4286 llvm::SmallVector<QualType> SrcTypes;
4287 BuildFlattenedTypeList(BaseTy: SrcTy, List&: SrcTypes);
4288
4289 // Usually the size of SrcTypes must be greater than or equal to the size of
4290 // DestTypes.
4291 if (SrcTypes.size() < DestTypes.size())
4292 return false;
4293
4294 unsigned SrcSize = SrcTypes.size();
4295 unsigned DstSize = DestTypes.size();
4296 unsigned I;
4297 for (I = 0; I < DstSize && I < SrcSize; I++) {
4298 if (SrcTypes[I]->isUnionType() || DestTypes[I]->isUnionType())
4299 return false;
4300 if (!CanPerformScalarCast(SrcTy: SrcTypes[I], DestTy: DestTypes[I])) {
4301 return false;
4302 }
4303 }
4304
4305 // check the rest of the source type for unions.
4306 for (; I < SrcSize; I++) {
4307 if (SrcTypes[I]->isUnionType())
4308 return false;
4309 }
4310 return true;
4311}
4312
// Builds the expression used to pass an argument to an `out` or `inout`
// parameter: copy-in via copy initialization and copy-out via an `=`
// writeback, both wrapped in an HLSLOutArgExpr. Ordinary (by-value)
// parameters pass the argument through unchanged.
ExprResult SemaHLSL::ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg) {
  assert(Param->hasAttr<HLSLParamModifierAttr>() &&
         "We should not get here without a parameter modifier expression");
  const auto *Attr = Param->getAttr<HLSLParamModifierAttr>();
  // Ordinary ABI: no copy-in/copy-out machinery is needed.
  if (Attr->getABI() == ParameterABI::Ordinary)
    return ExprResult(Arg);

  bool IsInOut = Attr->getABI() == ParameterABI::HLSLInOut;
  // The writeback needs a destination, so the argument must be an lvalue.
  if (!Arg->isLValue()) {
    SemaRef.Diag(Loc: Arg->getBeginLoc(), DiagID: diag::error_hlsl_inout_lvalue)
        << Arg << (IsInOut ? 1 : 0);
    return ExprError();
  }

  ASTContext &Ctx = SemaRef.getASTContext();

  QualType Ty = Param->getType().getNonLValueExprType(Context: Ctx);

  // HLSL allows implicit conversions from scalars to vectors, but not the
  // inverse, so we need to disallow `inout` with scalar->vector or
  // scalar->matrix conversions.
  if (Arg->getType()->isScalarType() != Ty->isScalarType()) {
    SemaRef.Diag(Loc: Arg->getBeginLoc(), DiagID: diag::error_hlsl_inout_scalar_extension)
        << Arg << (IsInOut ? 1 : 0);
    return ExprError();
  }

  // Opaque value wrapping the caller's argument; reused below as both the
  // copy-in source and the writeback destination.
  auto *ArgOpV = new (Ctx) OpaqueValueExpr(Param->getBeginLoc(), Arg->getType(),
                                           VK_LValue, OK_Ordinary, Arg);

  // Parameters are initialized via copy initialization. This allows for
  // overload resolution of argument constructors.
  InitializedEntity Entity =
      InitializedEntity::InitializeParameter(Context&: Ctx, Type: Ty, Consumed: false);
  ExprResult Res =
      SemaRef.PerformCopyInitialization(Entity, EqualLoc: Param->getBeginLoc(), Init: ArgOpV);
  if (Res.isInvalid())
    return ExprError();
  Expr *Base = Res.get();
  // After the cast, drop the reference type when creating the exprs.
  Ty = Ty.getNonLValueExprType(Context: Ctx);
  auto *OpV = new (Ctx)
      OpaqueValueExpr(Param->getBeginLoc(), Ty, VK_LValue, OK_Ordinary, Base);

  // Writebacks are performed with `=` binary operator, which allows for
  // overload resolution on writeback result expressions.
  Res = SemaRef.ActOnBinOp(S: SemaRef.getCurScope(), TokLoc: Param->getBeginLoc(),
                           Kind: tok::equal, LHSExpr: ArgOpV, RHSExpr: OpV);

  if (Res.isInvalid())
    return ExprError();
  Expr *Writeback = Res.get();
  auto *OutExpr =
      HLSLOutArgExpr::Create(C: Ctx, Ty, Base: ArgOpV, OpV, WB: Writeback, IsInOut);

  return ExprResult(OutExpr);
}
4370
4371QualType SemaHLSL::getInoutParameterType(QualType Ty) {
4372 // If HLSL gains support for references, all the cites that use this will need
4373 // to be updated with semantic checking to produce errors for
4374 // pointers/references.
4375 assert(!Ty->isReferenceType() &&
4376 "Pointer and reference types cannot be inout or out parameters");
4377 Ty = SemaRef.getASTContext().getLValueReferenceType(T: Ty);
4378 Ty.addRestrict();
4379 return Ty;
4380}
4381
4382static bool IsDefaultBufferConstantDecl(const ASTContext &Ctx, VarDecl *VD) {
4383 bool IsVulkan =
4384 Ctx.getTargetInfo().getTriple().getOS() == llvm::Triple::Vulkan;
4385 bool IsVKPushConstant = IsVulkan && VD->hasAttr<HLSLVkPushConstantAttr>();
4386 QualType QT = VD->getType();
4387 return VD->getDeclContext()->isTranslationUnit() &&
4388 QT.getAddressSpace() == LangAS::Default &&
4389 VD->getStorageClass() != SC_Static &&
4390 !VD->hasAttr<HLSLVkConstantIdAttr>() && !IsVKPushConstant &&
4391 !isInvalidConstantBufferLeafElementType(Ty: QT.getTypePtr());
4392}
4393
4394void SemaHLSL::deduceAddressSpace(VarDecl *Decl) {
4395 // The variable already has an address space (groupshared for ex).
4396 if (Decl->getType().hasAddressSpace())
4397 return;
4398
4399 if (Decl->getType()->isDependentType())
4400 return;
4401
4402 QualType Type = Decl->getType();
4403
4404 if (Decl->hasAttr<HLSLVkExtBuiltinInputAttr>()) {
4405 LangAS ImplAS = LangAS::hlsl_input;
4406 Type = SemaRef.getASTContext().getAddrSpaceQualType(T: Type, AddressSpace: ImplAS);
4407 Decl->setType(Type);
4408 return;
4409 }
4410
4411 bool IsVulkan = getASTContext().getTargetInfo().getTriple().getOS() ==
4412 llvm::Triple::Vulkan;
4413 if (IsVulkan && Decl->hasAttr<HLSLVkPushConstantAttr>()) {
4414 if (HasDeclaredAPushConstant)
4415 SemaRef.Diag(Loc: Decl->getLocation(), DiagID: diag::err_hlsl_push_constant_unique);
4416
4417 LangAS ImplAS = LangAS::hlsl_push_constant;
4418 Type = SemaRef.getASTContext().getAddrSpaceQualType(T: Type, AddressSpace: ImplAS);
4419 Decl->setType(Type);
4420 HasDeclaredAPushConstant = true;
4421 return;
4422 }
4423
4424 if (Type->isSamplerT() || Type->isVoidType())
4425 return;
4426
4427 // Resource handles.
4428 if (Type->isHLSLResourceRecord() || Type->isHLSLResourceRecordArray())
4429 return;
4430
4431 // Only static globals belong to the Private address space.
4432 // Non-static globals belongs to the cbuffer.
4433 if (Decl->getStorageClass() != SC_Static && !Decl->isStaticDataMember())
4434 return;
4435
4436 LangAS ImplAS = LangAS::hlsl_private;
4437 Type = SemaRef.getASTContext().getAddrSpaceQualType(T: Type, AddressSpace: ImplAS);
4438 Decl->setType(Type);
4439}
4440
// Post-declarator processing for HLSL variables: requires a complete type,
// assigns default-buffer constants to $Globals, collects resource bindings,
// adjusts storage/linkage for resources, and finally deduces the address
// space.
void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
  if (VD->hasGlobalStorage()) {
    // make sure the declaration has a complete type
    if (SemaRef.RequireCompleteType(
            Loc: VD->getLocation(),
            T: SemaRef.getASTContext().getBaseElementType(QT: VD->getType()),
            DiagID: diag::err_typecheck_decl_incomplete_type)) {
      VD->setInvalidDecl();
      // Still deduce an address space so downstream code sees a stable type.
      deduceAddressSpace(Decl: VD);
      return;
    }

    // Global variables outside a cbuffer block that are not a resource, static,
    // groupshared, or an empty array or struct belong to the default constant
    // buffer $Globals (to be created at the end of the translation unit).
    if (IsDefaultBufferConstantDecl(Ctx: getASTContext(), VD)) {
      // update address space to hlsl_constant
      QualType NewTy = getASTContext().getAddrSpaceQualType(
          T: VD->getType(), AddressSpace: LangAS::hlsl_constant);
      VD->setType(NewTy);
      DefaultCBufferDecls.push_back(Elt: VD);
    }

    // find all resources bindings on decl
    if (VD->getType()->isHLSLIntangibleType())
      collectResourceBindingsOnVarDecl(D: VD);

    // vk::constant_id specialization constants are forced to static storage.
    if (VD->hasAttr<HLSLVkConstantIdAttr>())
      VD->setStorageClass(StorageClass::SC_Static);

    if (isResourceRecordTypeOrArrayOf(VD) &&
        VD->getStorageClass() != SC_Static) {
      // Add internal linkage attribute to non-static resource variables. The
      // global externally visible storage is accessed through the handle, which
      // is a member. The variable itself is not externally visible.
      VD->addAttr(A: InternalLinkageAttr::CreateImplicit(Ctx&: getASTContext()));
    }

    // process explicit bindings
    processExplicitBindingsOnDecl(D: VD);

    // Add implicit binding attribute to non-static resource arrays.
    if (VD->getType()->isHLSLResourceRecordArray() &&
        VD->getStorageClass() != SC_Static) {
      // If the resource array does not have an explicit binding attribute,
      // create an implicit one. It will be used to transfer implicit binding
      // order_ID to codegen.
      ResourceBindingAttrs Binding(VD);
      if (!Binding.isExplicit()) {
        uint32_t OrderID = getNextImplicitBindingOrderID();
        if (Binding.hasBinding())
          Binding.setImplicitOrderID(OrderID);
        else {
          addImplicitBindingAttrToDecl(
              S&: SemaRef, D: VD, RT: getRegisterType(ResTy: getResourceArrayHandleType(VD)),
              ImplicitBindingOrderID: OrderID);
          // Re-create the binding object to pick up the new attribute.
          Binding = ResourceBindingAttrs(VD);
        }
      }

      // Get to the base type of a potentially multi-dimensional array.
      QualType Ty = getASTContext().getBaseElementType(QT: VD->getType());

      // Resources with a counter handle also need an implicit order ID for
      // the counter binding.
      const CXXRecordDecl *RD = Ty->getAsCXXRecordDecl();
      if (hasCounterHandle(RD)) {
        if (!Binding.hasCounterImplicitOrderID()) {
          uint32_t OrderID = getNextImplicitBindingOrderID();
          Binding.setCounterImplicitOrderID(OrderID);
        }
      }
    }
  }

  deduceAddressSpace(Decl: VD);
}
4517
// Synthesizes the initializer for a non-static global resource variable: a
// call to the resource type's static create method, passing the binding
// (explicit register slot or implicit order ID), space, range size (1),
// index (0), the variable name, and — when the resource has a counter
// handle — a counter order ID. Returns false (leaving the variable
// uninitialized) if the expected create method is missing.
bool SemaHLSL::initGlobalResourceDecl(VarDecl *VD) {
  assert(VD->getType()->isHLSLResourceRecord() &&
         "expected resource record type");

  ASTContext &AST = SemaRef.getASTContext();
  uint64_t UIntTySize = AST.getTypeSize(T: AST.UnsignedIntTy);
  uint64_t IntTySize = AST.getTypeSize(T: AST.IntTy);

  // Gather resource binding attributes.
  ResourceBindingAttrs Binding(VD);

  // Find correct initialization method and create its arguments.
  QualType ResourceTy = VD->getType();
  CXXRecordDecl *ResourceDecl = ResourceTy->getAsCXXRecordDecl();
  CXXMethodDecl *CreateMethod = nullptr;
  llvm::SmallVector<Expr *> Args;

  // The method name depends on whether the binding is explicit and on the
  // presence of an implicit counter handle.
  bool HasCounter = hasCounterHandle(RD: ResourceDecl);
  const char *CreateMethodName;
  if (Binding.isExplicit())
    CreateMethodName = HasCounter ? "__createFromBindingWithImplicitCounter"
                                  : "__createFromBinding";
  else
    CreateMethodName = HasCounter
                           ? "__createFromImplicitBindingWithImplicitCounter"
                           : "__createFromImplicitBinding";

  CreateMethod =
      lookupMethod(S&: SemaRef, RecordDecl: ResourceDecl, Name: CreateMethodName, Loc: VD->getLocation());

  if (!CreateMethod)
    // This can happen if someone creates a struct that looks like an HLSL
    // resource record but does not have the required static create method.
    // No binding will be generated for it.
    return false;

  // First argument: the explicit register slot, or the implicit binding
  // order ID.
  if (Binding.isExplicit()) {
    IntegerLiteral *RegSlot =
        IntegerLiteral::Create(C: AST, V: llvm::APInt(UIntTySize, Binding.getSlot()),
                               type: AST.UnsignedIntTy, l: SourceLocation());
    Args.push_back(Elt: RegSlot);
  } else {
    uint32_t OrderID = (Binding.hasImplicitOrderID())
                           ? Binding.getImplicitOrderID()
                           : getNextImplicitBindingOrderID();
    IntegerLiteral *OrderId =
        IntegerLiteral::Create(C: AST, V: llvm::APInt(UIntTySize, OrderID),
                               type: AST.UnsignedIntTy, l: SourceLocation());
    Args.push_back(Elt: OrderId);
  }

  // Register space.
  IntegerLiteral *Space =
      IntegerLiteral::Create(C: AST, V: llvm::APInt(UIntTySize, Binding.getSpace()),
                             type: AST.UnsignedIntTy, l: SourceLocation());
  Args.push_back(Elt: Space);

  // A single (non-array) resource occupies a range of size 1...
  IntegerLiteral *RangeSize = IntegerLiteral::Create(
      C: AST, V: llvm::APInt(IntTySize, 1), type: AST.IntTy, l: SourceLocation());
  Args.push_back(Elt: RangeSize);

  // ...at index 0 within that range.
  IntegerLiteral *Index = IntegerLiteral::Create(
      C: AST, V: llvm::APInt(UIntTySize, 0), type: AST.UnsignedIntTy, l: SourceLocation());
  Args.push_back(Elt: Index);

  // The variable name, passed as a decayed `const char *` string literal.
  StringRef VarName = VD->getName();
  StringLiteral *Name = StringLiteral::Create(
      Ctx: AST, Str: VarName, Kind: StringLiteralKind::Ordinary, Pascal: false,
      Ty: AST.getStringLiteralArrayType(EltTy: AST.CharTy.withConst(), Length: VarName.size()),
      Locs: SourceLocation());
  ImplicitCastExpr *NameCast = ImplicitCastExpr::Create(
      Context: AST, T: AST.getPointerType(T: AST.CharTy.withConst()), Kind: CK_ArrayToPointerDecay,
      Operand: Name, BasePath: nullptr, Cat: VK_PRValue, FPO: FPOptionsOverride());
  Args.push_back(Elt: NameCast);

  if (HasCounter) {
    // Will this be in the correct order?
    uint32_t CounterOrderID = getNextImplicitBindingOrderID();
    IntegerLiteral *CounterId =
        IntegerLiteral::Create(C: AST, V: llvm::APInt(UIntTySize, CounterOrderID),
                               type: AST.UnsignedIntTy, l: SourceLocation());
    Args.push_back(Elt: CounterId);
  }

  // Make sure the create method template is instantiated and emitted.
  if (!CreateMethod->isDefined() && CreateMethod->isTemplateInstantiation())
    SemaRef.InstantiateFunctionDefinition(PointOfInstantiation: VD->getLocation(), Function: CreateMethod,
                                          Recursive: true);

  // Create CallExpr with a call to the static method and set it as the decl
  // initialization.
  DeclRefExpr *DRE = DeclRefExpr::Create(
      Context: AST, QualifierLoc: NestedNameSpecifierLoc(), TemplateKWLoc: SourceLocation(), D: CreateMethod, RefersToEnclosingVariableOrCapture: false,
      NameInfo: CreateMethod->getNameInfo(), T: CreateMethod->getType(), VK: VK_PRValue);

  auto *ImpCast = ImplicitCastExpr::Create(
      Context: AST, T: AST.getPointerType(T: CreateMethod->getType()),
      Kind: CK_FunctionToPointerDecay, Operand: DRE, BasePath: nullptr, Cat: VK_PRValue, FPO: FPOptionsOverride());

  CallExpr *InitExpr =
      CallExpr::Create(Ctx: AST, Fn: ImpCast, Args, Ty: ResourceTy, VK: VK_PRValue,
                       RParenLoc: SourceLocation(), FPFeatures: FPOptionsOverride());
  VD->setInit(InitExpr);
  VD->setInitStyle(VarDecl::CallInit);
  SemaRef.CheckCompleteVariableDeclaration(VD);
  return true;
}
4624
4625bool SemaHLSL::initGlobalResourceArrayDecl(VarDecl *VD) {
4626 assert(VD->getType()->isHLSLResourceRecordArray() &&
4627 "expected array of resource records");
4628
4629 // Individual resources in a resource array are not initialized here. They
4630 // are initialized later on during codegen when the individual resources are
4631 // accessed. Codegen will emit a call to the resource initialization method
4632 // with the specified array index. We need to make sure though that the method
4633 // for the specific resource type is instantiated, so codegen can emit a call
4634 // to it when the array element is accessed.
4635
4636 // Find correct initialization method based on the resource binding
4637 // information.
4638 ASTContext &AST = SemaRef.getASTContext();
4639 QualType ResElementTy = AST.getBaseElementType(QT: VD->getType());
4640 CXXRecordDecl *ResourceDecl = ResElementTy->getAsCXXRecordDecl();
4641 CXXMethodDecl *CreateMethod = nullptr;
4642
4643 bool HasCounter = hasCounterHandle(RD: ResourceDecl);
4644 ResourceBindingAttrs ResourceAttrs(VD);
4645 if (ResourceAttrs.isExplicit())
4646 // Resource has explicit binding.
4647 CreateMethod =
4648 lookupMethod(S&: SemaRef, RecordDecl: ResourceDecl,
4649 Name: HasCounter ? "__createFromBindingWithImplicitCounter"
4650 : "__createFromBinding",
4651 Loc: VD->getLocation());
4652 else
4653 // Resource has implicit binding.
4654 CreateMethod = lookupMethod(
4655 S&: SemaRef, RecordDecl: ResourceDecl,
4656 Name: HasCounter ? "__createFromImplicitBindingWithImplicitCounter"
4657 : "__createFromImplicitBinding",
4658 Loc: VD->getLocation());
4659
4660 if (!CreateMethod)
4661 return false;
4662
4663 // Make sure the create method template is instantiated and emitted.
4664 if (!CreateMethod->isDefined() && CreateMethod->isTemplateInstantiation())
4665 SemaRef.InstantiateFunctionDefinition(PointOfInstantiation: VD->getLocation(), Function: CreateMethod,
4666 Recursive: true);
4667 return true;
4668}
4669
4670// Returns true if the initialization has been handled.
4671// Returns false to use default initialization.
4672bool SemaHLSL::ActOnUninitializedVarDecl(VarDecl *VD) {
4673 // Objects in the hlsl_constant address space are initialized
4674 // externally, so don't synthesize an implicit initializer.
4675 if (VD->getType().getAddressSpace() == LangAS::hlsl_constant)
4676 return true;
4677
4678 // Initialize non-static resources at the global scope.
4679 if (VD->hasGlobalStorage() && VD->getStorageClass() != SC_Static) {
4680 const Type *Ty = VD->getType().getTypePtr();
4681 if (Ty->isHLSLResourceRecord())
4682 return initGlobalResourceDecl(VD);
4683 if (Ty->isHLSLResourceRecordArray())
4684 return initGlobalResourceArrayDecl(VD);
4685 }
4686 return false;
4687}
4688
4689// Return true if everything is ok; returns false if there was an error.
4690bool SemaHLSL::CheckResourceBinOp(BinaryOperatorKind Opc, Expr *LHSExpr,
4691 Expr *RHSExpr, SourceLocation Loc) {
4692 assert((LHSExpr->getType()->isHLSLResourceRecord() ||
4693 LHSExpr->getType()->isHLSLResourceRecordArray()) &&
4694 "expected LHS to be a resource record or array of resource records");
4695 if (Opc != BO_Assign)
4696 return true;
4697
4698 // If LHS is an array subscript, get the underlying declaration.
4699 Expr *E = LHSExpr;
4700 while (auto *ASE = dyn_cast<ArraySubscriptExpr>(Val: E))
4701 E = ASE->getBase()->IgnoreParenImpCasts();
4702
4703 // Report error if LHS is a non-static resource declared at a global scope.
4704 if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(Val: E->IgnoreParens())) {
4705 if (VarDecl *VD = dyn_cast<VarDecl>(Val: DRE->getDecl())) {
4706 if (VD->hasGlobalStorage() && VD->getStorageClass() != SC_Static) {
4707 // assignment to global resource is not allowed
4708 SemaRef.Diag(Loc, DiagID: diag::err_hlsl_assign_to_global_resource) << VD;
4709 SemaRef.Diag(Loc: VD->getLocation(), DiagID: diag::note_var_declared_here) << VD;
4710 return false;
4711 }
4712 }
4713 }
4714 return true;
4715}
4716
// Walks though the global variable declaration, collects all resource binding
// requirements and adds them to Bindings
void SemaHLSL::collectResourceBindingsOnVarDecl(VarDecl *VD) {
  assert(VD->hasGlobalStorage() && VD->getType()->isHLSLIntangibleType() &&
         "expected global variable that contains HLSL resource");

  // Cbuffers and Tbuffers are HLSLBufferDecl types
  // NOTE(review): VD is a VarDecl; HLSLBufferDecl does not appear to be a
  // VarDecl subclass, so confirm this branch is actually reachable here.
  if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(Val: VD)) {
    Bindings.addDeclBindingInfo(VD, ResClass: CBufferOrTBuffer->isCBuffer()
                                        ? ResourceClass::CBuffer
                                        : ResourceClass::SRV);
    return;
  }

  // Unwrap arrays
  // FIXME: Calculate array size while unwrapping
  const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
  while (Ty->isArrayType()) {
    const ArrayType *AT = cast<ArrayType>(Val: Ty);
    Ty = AT->getElementType()->getUnqualifiedDesugaredType();
  }

  // Resource (or array of resources)
  if (const HLSLAttributedResourceType *AttrResType =
          HLSLAttributedResourceType::findHandleTypeOnResource(RT: Ty)) {
    Bindings.addDeclBindingInfo(VD, ResClass: AttrResType->getAttrs().ResourceClass);
    return;
  }

  // User defined record type
  if (const RecordType *RT = dyn_cast<RecordType>(Val: Ty))
    collectResourceBindingsOnUserRecordDecl(VD, RT);
}
4750
// Walks though the explicit resource binding attributes on the declaration,
// and makes sure there is a resource that matched the binding and updates
// DeclBindingInfoLists
void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) {
  assert(VD->hasGlobalStorage() && "expected global variable");

  bool HasBinding = false;
  for (Attr *A : VD->attrs()) {
    // [[vk::binding]] counts as an explicit binding but is incompatible with
    // [[vk::push_constant]].
    if (isa<HLSLVkBindingAttr>(Val: A)) {
      HasBinding = true;
      if (auto PA = VD->getAttr<HLSLVkPushConstantAttr>())
        Diag(Loc: PA->getLoc(), DiagID: diag::err_hlsl_attr_incompatible) << A << PA;
    }

    // Only register() attributes with an explicit slot are processed below.
    HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(Val: A);
    if (!RBA || !RBA->hasRegisterSlot())
      continue;
    HasBinding = true;

    RegisterType RT = RBA->getRegisterType();
    assert(RT != RegisterType::I && "invalid or obsolete register type should "
                                    "never have an attribute created");

    // register(c#) bindings: warn when the declaration actually contains
    // resources, then skip the binding-info update.
    if (RT == RegisterType::C) {
      if (Bindings.hasBindingInfoForDecl(VD))
        SemaRef.Diag(Loc: VD->getLocation(),
                     DiagID: diag::warn_hlsl_user_defined_type_missing_member)
            << static_cast<int>(RT);
      continue;
    }

    // Find DeclBindingInfo for this binding and update it, or report error
    // if it does not exist (user type does not contain resources with the
    // expected resource class).
    ResourceClass RC = getResourceClass(RT);
    if (DeclBindingInfo *BI = Bindings.getDeclBindingInfo(VD, ResClass: RC)) {
      // update binding info
      BI->setBindingAttribute(A: RBA, BT: BindingType::Explicit);
    } else {
      SemaRef.Diag(Loc: VD->getLocation(),
                   DiagID: diag::warn_hlsl_user_defined_type_missing_member)
          << static_cast<int>(RT);
    }
  }

  // Resources without any explicit binding attribute are bound implicitly;
  // warn so the user knows.
  if (!HasBinding && isResourceRecordTypeOrArrayOf(VD))
    SemaRef.Diag(Loc: VD->getLocation(), DiagID: diag::warn_hlsl_implicit_binding);
}
4799namespace {
4800class InitListTransformer {
4801 Sema &S;
4802 ASTContext &Ctx;
4803 QualType InitTy;
4804 QualType *DstIt = nullptr;
4805 Expr **ArgIt = nullptr;
4806 // Is wrapping the destination type iterator required? This is only used for
4807 // incomplete array types where we loop over the destination type since we
4808 // don't know the full number of elements from the declaration.
4809 bool Wrap;
4810
  // Casts one leaf initializer expression to the next expected destination
  // type and appends the converted expression to ArgExprs. Returns false if
  // the copy initialization fails.
  bool castInitializer(Expr *E) {
    assert(DstIt && "This should always be something!");
    if (DstIt == DestTypes.end()) {
      if (!Wrap) {
        ArgExprs.push_back(Elt: E);
        // This is odd, but it isn't technically a failure due to conversion, we
        // handle mismatched counts of arguments differently.
        return true;
      }
      // Incomplete array destination: wrap back to the start of the
      // element's flattened type list.
      DstIt = DestTypes.begin();
    }
    InitializedEntity Entity = InitializedEntity::InitializeParameter(
        Context&: Ctx, Type: *DstIt, /* Consumed (ObjC) */ Consumed: false);
    ExprResult Res = S.PerformCopyInitialization(Entity, EqualLoc: E->getBeginLoc(), Init: E);
    if (Res.isInvalid())
      return false;
    Expr *Init = Res.get();
    ArgExprs.push_back(Elt: Init);
    DstIt++;
    return true;
  }
4832
  // Recursively flattens initializer E: init lists are traversed, vectors,
  // matrices, and constant arrays are expanded element-by-element, records
  // are expanded field-by-field (base class fields first), and each scalar
  // leaf is converted and queued via castInitializer. Returns false on any
  // conversion failure.
  bool buildInitializerListImpl(Expr *E) {
    // If this is an initialization list, traverse the sub initializers.
    if (auto *Init = dyn_cast<InitListExpr>(Val: E)) {
      for (auto *SubInit : Init->inits())
        if (!buildInitializerListImpl(E: SubInit))
          return false;
      return true;
    }

    // If this is a scalar type, just enqueue the expression.
    QualType Ty = E->getType();

    // Non-aggregate records are treated as opaque leaves as well.
    if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()))
      return castInitializer(E);

    // Vectors: emit one subscript expression per element.
    if (auto *VecTy = Ty->getAs<VectorType>()) {
      uint64_t Size = VecTy->getNumElements();

      QualType SizeTy = Ctx.getSizeType();
      uint64_t SizeTySize = Ctx.getTypeSize(T: SizeTy);
      for (uint64_t I = 0; I < Size; ++I) {
        auto *Idx = IntegerLiteral::Create(C: Ctx, V: llvm::APInt(SizeTySize, I),
                                           type: SizeTy, l: SourceLocation());

        ExprResult ElExpr = S.CreateBuiltinArraySubscriptExpr(
            Base: E, LLoc: E->getBeginLoc(), Idx, RLoc: E->getEndLoc());
        if (ElExpr.isInvalid())
          return false;
        if (!castInitializer(E: ElExpr.get()))
          return false;
      }
      return true;
    }
    // Matrices: emit one row/column subscript expression per element.
    if (auto *MTy = Ty->getAs<ConstantMatrixType>()) {
      unsigned Rows = MTy->getNumRows();
      unsigned Cols = MTy->getNumColumns();
      QualType ElemTy = MTy->getElementType();

      for (unsigned R = 0; R < Rows; ++R) {
        for (unsigned C = 0; C < Cols; ++C) {
          // row index literal
          Expr *RowIdx = IntegerLiteral::Create(
              C: Ctx, V: llvm::APInt(Ctx.getIntWidth(T: Ctx.IntTy), R), type: Ctx.IntTy,
              l: E->getBeginLoc());
          // column index literal
          Expr *ColIdx = IntegerLiteral::Create(
              C: Ctx, V: llvm::APInt(Ctx.getIntWidth(T: Ctx.IntTy), C), type: Ctx.IntTy,
              l: E->getBeginLoc());
          ExprResult ElExpr = S.CreateBuiltinMatrixSubscriptExpr(
              Base: E, RowIdx, ColumnIdx: ColIdx, RBLoc: E->getEndLoc());
          if (ElExpr.isInvalid())
            return false;
          if (!castInitializer(E: ElExpr.get()))
            return false;
          // NOTE(review): the element type is set after the expression has
          // already been cast and queued — confirm this ordering is
          // intentional.
          ElExpr.get()->setType(ElemTy);
        }
      }
      return true;
    }

    // Constant-size arrays: recurse into each element.
    if (auto *ArrTy = dyn_cast<ConstantArrayType>(Val: Ty.getTypePtr())) {
      uint64_t Size = ArrTy->getZExtSize();
      QualType SizeTy = Ctx.getSizeType();
      uint64_t SizeTySize = Ctx.getTypeSize(T: SizeTy);
      for (uint64_t I = 0; I < Size; ++I) {
        auto *Idx = IntegerLiteral::Create(C: Ctx, V: llvm::APInt(SizeTySize, I),
                                           type: SizeTy, l: SourceLocation());
        ExprResult ElExpr = S.CreateBuiltinArraySubscriptExpr(
            Base: E, LLoc: E->getBeginLoc(), Idx, RLoc: E->getEndLoc());
        if (ElExpr.isInvalid())
          return false;
        if (!buildInitializerListImpl(E: ElExpr.get()))
          return false;
      }
      return true;
    }

    // Records: collect the (single-inheritance) base chain, then recurse
    // into each named field, root base first.
    if (auto *RD = Ty->getAsCXXRecordDecl()) {
      llvm::SmallVector<CXXRecordDecl *> RecordDecls;
      RecordDecls.push_back(Elt: RD);
      while (RecordDecls.back()->getNumBases()) {
        CXXRecordDecl *D = RecordDecls.back();
        assert(D->getNumBases() == 1 &&
               "HLSL doesn't support multiple inheritance");
        RecordDecls.push_back(
            Elt: D->bases_begin()->getType()->castAsCXXRecordDecl());
      }
      while (!RecordDecls.empty()) {
        CXXRecordDecl *RD = RecordDecls.pop_back_val();
        for (auto *FD : RD->fields()) {
          if (FD->isUnnamedBitField())
            continue;
          DeclAccessPair Found = DeclAccessPair::make(D: FD, AS: FD->getAccess());
          DeclarationNameInfo NameInfo(FD->getDeclName(), E->getBeginLoc());
          ExprResult Res = S.BuildFieldReferenceExpr(
              BaseExpr: E, IsArrow: false, OpLoc: E->getBeginLoc(), SS: CXXScopeSpec(), Field: FD, FoundDecl: Found, MemberNameInfo: NameInfo);
          if (Res.isInvalid())
            return false;
          if (!buildInitializerListImpl(E: Res.get()))
            return false;
        }
      }
    }
    return true;
  }
4938
  // Rebuilds a nested InitListExpr of type Ty by consuming flattened argument
  // expressions from ArgIt in order. Scalars and non-aggregate records take
  // one argument; vectors, matrices, arrays, and records recurse per element
  // or per field.
  Expr *generateInitListsImpl(QualType Ty) {
    assert(ArgIt != ArgExprs.end() && "Something is off in iteration!");
    if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()))
      return *(ArgIt++);

    llvm::SmallVector<Expr *> Inits;
    Ty = Ty.getDesugaredType(Context: Ctx);
    if (Ty->isVectorType() || Ty->isConstantArrayType() ||
        Ty->isConstantMatrixType()) {
      // Determine the element type and element count for the sequence type.
      QualType ElTy;
      uint64_t Size = 0;
      if (auto *ATy = Ty->getAs<VectorType>()) {
        ElTy = ATy->getElementType();
        Size = ATy->getNumElements();
      } else if (auto *CMTy = Ty->getAs<ConstantMatrixType>()) {
        ElTy = CMTy->getElementType();
        Size = CMTy->getNumElementsFlattened();
      } else {
        auto *VTy = cast<ConstantArrayType>(Val: Ty.getTypePtr());
        ElTy = VTy->getElementType();
        Size = VTy->getZExtSize();
      }
      for (uint64_t I = 0; I < Size; ++I)
        Inits.push_back(Elt: generateInitListsImpl(Ty: ElTy));
    }
    if (auto *RD = Ty->getAsCXXRecordDecl()) {
      // Walk the (single-inheritance) base chain so base-class fields come
      // first, mirroring buildInitializerListImpl.
      llvm::SmallVector<CXXRecordDecl *> RecordDecls;
      RecordDecls.push_back(Elt: RD);
      while (RecordDecls.back()->getNumBases()) {
        CXXRecordDecl *D = RecordDecls.back();
        assert(D->getNumBases() == 1 &&
               "HLSL doesn't support multiple inheritance");
        RecordDecls.push_back(
            Elt: D->bases_begin()->getType()->castAsCXXRecordDecl());
      }
      while (!RecordDecls.empty()) {
        CXXRecordDecl *RD = RecordDecls.pop_back_val();
        for (auto *FD : RD->fields())
          if (!FD->isUnnamedBitField())
            Inits.push_back(Elt: generateInitListsImpl(Ty: FD->getType()));
      }
    }
    auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
                                           Inits, Inits.back()->getEndLoc());
    NewInit->setType(Ty);
    return NewInit;
  }
4986
public:
  // Flattened element types of the destination entity, in initialization
  // order (filled by the constructor via BuildFlattenedTypeList).
  llvm::SmallVector<QualType, 16> DestTypes;
  // Flattened initializer argument expressions, collected by
  // buildInitializerList; consumed positionally against DestTypes.
  llvm::SmallVector<Expr *, 16> ArgExprs;
  InitListTransformer(Sema &SemaRef, const InitializedEntity &Entity)
      : S(SemaRef), Ctx(SemaRef.getASTContext()),
        Wrap(Entity.getType()->isIncompleteArrayType()) {
    InitTy = Entity.getType().getNonReferenceType();
    // When we're generating initializer lists for incomplete array types we
    // need to wrap around both when building the initializers and when
    // generating the final initializer lists.
    if (Wrap) {
      assert(InitTy->isIncompleteArrayType());
      // Strip the incomplete array; InitTy becomes the element type and the
      // array size is inferred later in generateInitLists.
      const IncompleteArrayType *IAT = Ctx.getAsIncompleteArrayType(T: InitTy);
      InitTy = IAT->getElementType();
    }
    // Precompute the flattened destination types and position the
    // destination-type cursor at the start.
    BuildFlattenedTypeList(BaseTy: InitTy, List&: DestTypes);
    DstIt = DestTypes.begin();
  }
5005
5006 bool buildInitializerList(Expr *E) { return buildInitializerListImpl(E); }
5007
  // Produces the final initializer expression from the flattened ArgExprs.
  // Must be called only after buildInitializerList has populated ArgExprs.
  Expr *generateInitLists() {
    assert(!ArgExprs.empty() &&
           "Call buildInitializerList to generate argument expressions.");
    ArgIt = ArgExprs.begin();
    if (!Wrap)
      return generateInitListsImpl(Ty: InitTy);
    // Incomplete array case: repeatedly build one element's initializer list
    // until every flattened argument is consumed; the number of iterations
    // becomes the inferred array size.
    llvm::SmallVector<Expr *> Inits;
    while (ArgIt != ArgExprs.end())
      Inits.push_back(Elt: generateInitListsImpl(Ty: InitTy));

    auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
                                           Inits, Inits.back()->getEndLoc());
    // Complete the array type with the element count inferred above.
    llvm::APInt ArySize(64, Inits.size());
    NewInit->setType(Ctx.getConstantArrayType(EltTy: InitTy, ArySize, SizeExpr: nullptr,
                                              ASM: ArraySizeModifier::Normal, IndexTypeQuals: 0));
    return NewInit;
  }
5025};
5026} // namespace
5027
5028// Recursively detect any incomplete array anywhere in the type graph,
5029// including arrays, struct fields, and base classes.
5030static bool containsIncompleteArrayType(QualType Ty) {
5031 Ty = Ty.getCanonicalType();
5032
5033 // Array types
5034 if (const ArrayType *AT = dyn_cast<ArrayType>(Val&: Ty)) {
5035 if (isa<IncompleteArrayType>(Val: AT))
5036 return true;
5037 return containsIncompleteArrayType(Ty: AT->getElementType());
5038 }
5039
5040 // Record (struct/class) types
5041 if (const auto *RT = Ty->getAs<RecordType>()) {
5042 const RecordDecl *RD = RT->getDecl();
5043
5044 // Walk base classes (for C++ / HLSL structs with inheritance)
5045 if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(Val: RD)) {
5046 for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
5047 if (containsIncompleteArrayType(Ty: Base.getType()))
5048 return true;
5049 }
5050 }
5051
5052 // Walk fields
5053 for (const FieldDecl *F : RD->fields()) {
5054 if (containsIncompleteArrayType(Ty: F->getType()))
5055 return true;
5056 }
5057 }
5058
5059 return false;
5060}
5061
// Flattens and re-nests an HLSL initializer list so it matches the memory
// layout of the destination entity. Returns true on success (Init is updated
// in place), false after emitting a diagnostic.
bool SemaHLSL::transformInitList(const InitializedEntity &Entity,
                                 InitListExpr *Init) {
  // If the initializer is a scalar, just return it.
  if (Init->getType()->isScalarType())
    return true;
  ASTContext &Ctx = SemaRef.getASTContext();
  InitListTransformer ILT(SemaRef, Entity);

  for (unsigned I = 0; I < Init->getNumInits(); ++I) {
    Expr *E = Init->getInit(Init: I);
    if (E->HasSideEffects(Ctx)) {
      // Wrap side-effecting initializers in an OpaqueValueExpr (with a
      // materialized temporary for record values) so that flattening them
      // into multiple references does not duplicate the side effects.
      QualType Ty = E->getType();
      if (Ty->isRecordType())
        E = new (Ctx) MaterializeTemporaryExpr(Ty, E, E->isLValue());
      E = new (Ctx) OpaqueValueExpr(E->getBeginLoc(), Ty, E->getValueKind(),
                                    E->getObjectKind(), E);
      Init->setInit(Init: I, expr: E);
    }
    if (!ILT.buildInitializerList(E))
      return false;
  }
  size_t ExpectedSize = ILT.DestTypes.size();
  size_t ActualSize = ILT.ArgExprs.size();
  // Nothing to initialize and nothing provided: trivially fine.
  if (ExpectedSize == 0 && ActualSize == 0)
    return true;

  // Reject empty initializer if *any* incomplete array exists structurally
  if (ActualSize == 0 && containsIncompleteArrayType(Ty: Entity.getType())) {
    QualType InitTy = Entity.getType().getNonReferenceType();
    if (InitTy.hasAddressSpace())
      InitTy = SemaRef.getASTContext().removeAddrSpaceQualType(T: InitTy);

    SemaRef.Diag(Loc: Init->getBeginLoc(), DiagID: diag::err_hlsl_incorrect_num_initializers)
        << /*TooManyOrFew=*/(int)(ExpectedSize < ActualSize) << InitTy
        << /*ExpectedSize=*/ExpectedSize << /*ActualSize=*/ActualSize;
    return false;
  }

  // We infer size after validating legality.
  // For incomplete arrays it is completely arbitrary to choose whether we think
  // the user intended fewer or more elements. This implementation assumes that
  // the user intended more, and errors that there are too few initializers to
  // complete the final element.
  if (Entity.getType()->isIncompleteArrayType()) {
    assert(ExpectedSize > 0 &&
           "The expected size of an incomplete array type must be at least 1.");
    // Round ActualSize up to the next multiple of the per-element flattened
    // size; a mismatch below then reports "too few" for a partial element.
    ExpectedSize =
        ((ActualSize + ExpectedSize - 1) / ExpectedSize) * ExpectedSize;
  }

  // An initializer list might be attempting to initialize a reference or
  // rvalue-reference. When checking the initializer we should look through
  // the reference.
  QualType InitTy = Entity.getType().getNonReferenceType();
  if (InitTy.hasAddressSpace())
    InitTy = SemaRef.getASTContext().removeAddrSpaceQualType(T: InitTy);
  if (ExpectedSize != ActualSize) {
    int TooManyOrFew = ActualSize > ExpectedSize ? 1 : 0;
    SemaRef.Diag(Loc: Init->getBeginLoc(), DiagID: diag::err_hlsl_incorrect_num_initializers)
        << TooManyOrFew << InitTy << ExpectedSize << ActualSize;
    return false;
  }

  // generateInitListsImpl will always return an InitListExpr here, because the
  // scalar case is handled above.
  auto *NewInit = cast<InitListExpr>(Val: ILT.generateInitLists());
  Init->resizeInits(Context: Ctx, NumInits: NewInit->getNumInits());
  for (unsigned I = 0; I < NewInit->getNumInits(); ++I)
    Init->updateInit(C: Ctx, Init: I, expr: NewInit->getInit(Init: I));
  return true;
}
5133
5134static QualType ReportMatrixInvalidMember(Sema &S, StringRef Name,
5135 StringRef Expected,
5136 SourceLocation OpLoc,
5137 SourceLocation CompLoc) {
5138 S.Diag(Loc: OpLoc, DiagID: diag::err_builtin_matrix_invalid_member)
5139 << Name << Expected << SourceRange(CompLoc);
5140 return QualType();
5141}
5142
// Type-checks an HLSL matrix element/swizzle accessor (zero-based "_mRC..."
// or one-based "_RC..." chains) on a constant matrix `baseType`. Returns the
// element type for a single component, an ext-vector type (or a matching
// typedef) for multi-component swizzles, or a null QualType after diagnosing
// an error. VK is downgraded to a prvalue when a component repeats.
QualType SemaHLSL::checkMatrixComponent(Sema &S, QualType baseType,
                                        ExprValueKind &VK, SourceLocation OpLoc,
                                        const IdentifierInfo *CompName,
                                        SourceLocation CompLoc) {
  const auto *MT = baseType->castAs<ConstantMatrixType>();
  StringRef AccessorName = CompName->getName();
  assert(!AccessorName.empty() && "Matrix Accessor must have a name");

  unsigned Rows = MT->getNumRows();
  unsigned Cols = MT->getNumColumns();
  bool IsZeroBasedAccessor = false;
  unsigned ChunkLen = 0;
  // Too short for even the smallest ("_RC") accessor form.
  if (AccessorName.size() < 2)
    return ReportMatrixInvalidMember(S, Name: AccessorName,
                                     Expected: "length 4 for zero based: \'_mRC\' or "
                                     "length 3 for one-based: \'_RC\' accessor",
                                     OpLoc, CompLoc);

  // Classify the accessor family by its prefix; this fixes the per-component
  // chunk length for the rest of the scan.
  if (AccessorName[0] == '_') {
    if (AccessorName[1] == 'm') {
      IsZeroBasedAccessor = true;
      ChunkLen = 4; // zero-based: "_mRC"
    } else {
      ChunkLen = 3; // one-based: "_RC"
    }
  } else
    return ReportMatrixInvalidMember(
        S, Name: AccessorName, Expected: "zero based: \'_mRC\' or one-based: \'_RC\' accessor",
        OpLoc, CompLoc);

  // The accessor string must be an exact sequence of whole chunks.
  if (AccessorName.size() % ChunkLen != 0) {
    const llvm::StringRef Expected = IsZeroBasedAccessor
                                         ? "zero based: '_mRC' accessor"
                                         : "one-based: '_RC' accessor";

    return ReportMatrixInvalidMember(S, Name: AccessorName, Expected, OpLoc, CompLoc);
  }

  auto isDigit = [](char c) { return c >= '0' && c <= '9'; };
  auto isZeroBasedIndex = [](unsigned i) { return i <= 3; };
  auto isOneBasedIndex = [](unsigned i) { return i >= 1 && i <= 4; };

  bool HasRepeated = false;
  // Tracks which (row, col) cells have already been named, to detect repeats.
  SmallVector<bool, 16> Seen(Rows * Cols, false);
  unsigned NumComponents = 0;
  const char *Begin = AccessorName.data();

  for (unsigned I = 0, E = AccessorName.size(); I < E; I += ChunkLen) {
    const char *Chunk = Begin + I;
    char RowChar = 0, ColChar = 0;
    if (IsZeroBasedAccessor) {
      // Zero-based: "_mRC"
      if (Chunk[0] != '_' || Chunk[1] != 'm') {
        char Bad = (Chunk[0] != '_') ? Chunk[0] : Chunk[1];
        // NOTE(review): the offset arithmetic below points the diagnostic at
        // the offending character relative to OpLoc — assumes OpLoc directly
        // precedes the accessor text; TODO confirm.
        return ReportMatrixInvalidMember(
            S, Name: StringRef(&Bad, 1), Expected: "\'_m\' prefix",
            OpLoc: OpLoc.getLocWithOffset(Offset: I + (Bad == Chunk[0] ? 1 : 2)), CompLoc);
      }
      RowChar = Chunk[2];
      ColChar = Chunk[3];
    } else {
      // One-based: "_RC"
      if (Chunk[0] != '_')
        return ReportMatrixInvalidMember(
            S, Name: StringRef(&Chunk[0], 1), Expected: "\'_\' prefix",
            OpLoc: OpLoc.getLocWithOffset(Offset: I + 1), CompLoc);
      RowChar = Chunk[1];
      ColChar = Chunk[2];
    }

    // Must be digits. Both row and column are checked before bailing so a
    // single chunk can produce both diagnostics at once.
    bool IsDigitsError = false;
    if (!isDigit(RowChar)) {
      unsigned BadPos = IsZeroBasedAccessor ? 2 : 1;
      ReportMatrixInvalidMember(S, Name: StringRef(&RowChar, 1), Expected: "row as integer",
                                OpLoc: OpLoc.getLocWithOffset(Offset: I + BadPos + 1),
                                CompLoc);
      IsDigitsError = true;
    }

    if (!isDigit(ColChar)) {
      unsigned BadPos = IsZeroBasedAccessor ? 3 : 2;
      ReportMatrixInvalidMember(S, Name: StringRef(&ColChar, 1), Expected: "column as integer",
                                OpLoc: OpLoc.getLocWithOffset(Offset: I + BadPos + 1),
                                CompLoc);
      IsDigitsError = true;
    }
    if (IsDigitsError)
      return QualType();

    unsigned Row = RowChar - '0';
    unsigned Col = ColChar - '0';

    // Validate the digit range for the accessor family; again both row and
    // column are diagnosed before returning.
    bool HasIndexingError = false;
    if (IsZeroBasedAccessor) {
      // 0-based [0..3]
      if (!isZeroBasedIndex(Row)) {
        S.Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_element_not_in_bounds)
            << /*row*/ 0 << /*zero-based*/ 0 << SourceRange(CompLoc);
        HasIndexingError = true;
      }
      if (!isZeroBasedIndex(Col)) {
        S.Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_element_not_in_bounds)
            << /*col*/ 1 << /*zero-based*/ 0 << SourceRange(CompLoc);
        HasIndexingError = true;
      }
    } else {
      // 1-based [1..4]
      if (!isOneBasedIndex(Row)) {
        S.Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_element_not_in_bounds)
            << /*row*/ 0 << /*one-based*/ 1 << SourceRange(CompLoc);
        HasIndexingError = true;
      }
      if (!isOneBasedIndex(Col)) {
        S.Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_element_not_in_bounds)
            << /*col*/ 1 << /*one-based*/ 1 << SourceRange(CompLoc);
        HasIndexingError = true;
      }
      // Convert to 0-based after range checking.
      --Row;
      --Col;
    }

    if (HasIndexingError)
      return QualType();

    // Note: matrix swizzle index is hard coded. That means Row and Col can
    // potentially be larger than Rows and Cols if matrix size is less than
    // the max index size.
    bool HasBoundsError = false;
    if (Row >= Rows) {
      Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_index_out_of_bounds)
          << /*Row*/ 0 << Row << Rows << SourceRange(CompLoc);
      HasBoundsError = true;
    }
    if (Col >= Cols) {
      Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_index_out_of_bounds)
          << /*Col*/ 1 << Col << Cols << SourceRange(CompLoc);
      HasBoundsError = true;
    }
    if (HasBoundsError)
      return QualType();

    // Row-major flattening; safe to index Seen because Row/Col were bounds
    // checked above.
    unsigned FlatIndex = Row * Cols + Col;
    if (Seen[FlatIndex])
      HasRepeated = true;
    Seen[FlatIndex] = true;
    ++NumComponents;
  }
  // Swizzles are limited to 1..4 components, matching vector width limits.
  if (NumComponents == 0 || NumComponents > 4) {
    S.Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_swizzle_invalid_length)
        << NumComponents << SourceRange(CompLoc);
    return QualType();
  }

  QualType ElemTy = MT->getElementType();
  if (NumComponents == 1)
    return ElemTy;
  QualType VT = S.Context.getExtVectorType(VectorType: ElemTy, NumElts: NumComponents);
  // A repeated component cannot be written through, so the result is a
  // prvalue rather than an lvalue.
  if (HasRepeated)
    VK = VK_PRValue;

  // Prefer an existing ext_vector_type typedef of the same shape, if any, so
  // diagnostics print the user-visible name.
  for (Sema::ExtVectorDeclsType::iterator
           I = S.ExtVectorDecls.begin(source: S.getExternalSource()),
           E = S.ExtVectorDecls.end();
       I != E; ++I) {
    if ((*I)->getUnderlyingType() == VT)
      return S.Context.getTypedefType(Keyword: ElaboratedTypeKeyword::None,
                                      /*Qualifier=*/std::nullopt, Decl: *I);
  }

  return VT;
}
5316
// Rewrites the initializer of a [[vk::constant_id]] variable into a call to
// the matching specialization-constant builtin. Returns true on success or
// when the attribute is absent; false (and marks VDecl invalid) when the
// initializer is not a C++11 constant expression.
bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) {
  const HLSLVkConstantIdAttr *ConstIdAttr =
      VDecl->getAttr<HLSLVkConstantIdAttr>();
  // Only variables annotated with vk::constant_id need rewriting.
  if (!ConstIdAttr)
    return true;

  ASTContext &Context = SemaRef.getASTContext();

  // Specialization constants require a compile-time constant default value.
  APValue InitValue;
  if (!Init->isCXX11ConstantExpr(Ctx: Context, Result: &InitValue)) {
    Diag(Loc: VDecl->getLocation(), DiagID: diag::err_specialization_const);
    VDecl->setInvalidDecl();
    return false;
  }

  // Select the builtin overload matching the variable's (desugared) type.
  Builtin::ID BID =
      getSpecConstBuiltinId(Type: VDecl->getType()->getUnqualifiedDesugaredType());

  // Argument 1: The ID from the attribute
  int ConstantID = ConstIdAttr->getId();
  llvm::APInt IDVal(Context.getIntWidth(T: Context.IntTy), ConstantID);
  Expr *IdExpr = IntegerLiteral::Create(C: Context, V: IDVal, type: Context.IntTy,
                                        l: ConstIdAttr->getLocation());

  // Argument 2: the original initializer (the default value).
  SmallVector<Expr *, 2> Args = {IdExpr, Init};
  Expr *C = SemaRef.BuildBuiltinCallExpr(Loc: Init->getExprLoc(), Id: BID, CallArgs: Args);
  if (C->getType()->getCanonicalTypeUnqualified() !=
      VDecl->getType()->getCanonicalTypeUnqualified()) {
    // NOTE(review): the cast target is Init's type, not VDecl's — presumably
    // Init has already been converted to the declared type upstream; confirm.
    C = SemaRef
            .BuildCStyleCastExpr(LParenLoc: SourceLocation(),
                                 Ty: Context.getTrivialTypeSourceInfo(
                                     T: Init->getType(), Loc: Init->getExprLoc()),
                                 RParenLoc: SourceLocation(), Op: C)
            .get();
  }
  Init = C;
  return true;
}
5355