1//===- SemaHLSL.cpp - Semantic Analysis for HLSL constructs ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This implements Semantic Analysis for HLSL constructs.
9//===----------------------------------------------------------------------===//
10
11#include "clang/Sema/SemaHLSL.h"
12#include "clang/AST/ASTConsumer.h"
13#include "clang/AST/ASTContext.h"
14#include "clang/AST/Attr.h"
15#include "clang/AST/Attrs.inc"
16#include "clang/AST/Decl.h"
17#include "clang/AST/DeclBase.h"
18#include "clang/AST/DeclCXX.h"
19#include "clang/AST/DeclarationName.h"
20#include "clang/AST/DynamicRecursiveASTVisitor.h"
21#include "clang/AST/Expr.h"
22#include "clang/AST/HLSLResource.h"
23#include "clang/AST/Type.h"
24#include "clang/AST/TypeBase.h"
25#include "clang/AST/TypeLoc.h"
26#include "clang/Basic/Builtins.h"
27#include "clang/Basic/DiagnosticSema.h"
28#include "clang/Basic/IdentifierTable.h"
29#include "clang/Basic/LLVM.h"
30#include "clang/Basic/SourceLocation.h"
31#include "clang/Basic/Specifiers.h"
32#include "clang/Basic/TargetInfo.h"
33#include "clang/Sema/Initialization.h"
34#include "clang/Sema/Lookup.h"
35#include "clang/Sema/ParsedAttr.h"
36#include "clang/Sema/Sema.h"
37#include "clang/Sema/Template.h"
38#include "llvm/ADT/ArrayRef.h"
39#include "llvm/ADT/STLExtras.h"
40#include "llvm/ADT/SmallVector.h"
41#include "llvm/ADT/StringExtras.h"
42#include "llvm/ADT/StringRef.h"
43#include "llvm/ADT/Twine.h"
44#include "llvm/Frontend/HLSL/HLSLBinding.h"
45#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
46#include "llvm/Support/Casting.h"
47#include "llvm/Support/DXILABI.h"
48#include "llvm/Support/ErrorHandling.h"
49#include "llvm/Support/FormatVariadic.h"
50#include "llvm/TargetParser/Triple.h"
51#include <cmath>
52#include <cstddef>
53#include <iterator>
54#include <utility>
55
56using namespace clang;
57using namespace clang::hlsl;
58using RegisterType = HLSLResourceBindingAttr::RegisterType;
59
60static CXXRecordDecl *createHostLayoutStruct(Sema &S,
61 CXXRecordDecl *StructDecl);
62
63static RegisterType getRegisterType(ResourceClass RC) {
64 switch (RC) {
65 case ResourceClass::SRV:
66 return RegisterType::SRV;
67 case ResourceClass::UAV:
68 return RegisterType::UAV;
69 case ResourceClass::CBuffer:
70 return RegisterType::CBuffer;
71 case ResourceClass::Sampler:
72 return RegisterType::Sampler;
73 }
74 llvm_unreachable("unexpected ResourceClass value");
75}
76
77static RegisterType getRegisterType(const HLSLAttributedResourceType *ResTy) {
78 return getRegisterType(RC: ResTy->getAttrs().ResourceClass);
79}
80
81// Converts the first letter of string Slot to RegisterType.
82// Returns false if the letter does not correspond to a valid register type.
83static bool convertToRegisterType(StringRef Slot, RegisterType *RT) {
84 assert(RT != nullptr);
85 switch (Slot[0]) {
86 case 't':
87 case 'T':
88 *RT = RegisterType::SRV;
89 return true;
90 case 'u':
91 case 'U':
92 *RT = RegisterType::UAV;
93 return true;
94 case 'b':
95 case 'B':
96 *RT = RegisterType::CBuffer;
97 return true;
98 case 's':
99 case 'S':
100 *RT = RegisterType::Sampler;
101 return true;
102 case 'c':
103 case 'C':
104 *RT = RegisterType::C;
105 return true;
106 case 'i':
107 case 'I':
108 *RT = RegisterType::I;
109 return true;
110 default:
111 return false;
112 }
113}
114
// Maps a register type back to its resource class. Register types 'c' and
// 'i' have no corresponding resource class and must not be passed in.
static ResourceClass getResourceClass(RegisterType RT) {
  switch (RT) {
  case RegisterType::SRV:
    return ResourceClass::SRV;
  case RegisterType::UAV:
    return ResourceClass::UAV;
  case RegisterType::CBuffer:
    return ResourceClass::CBuffer;
  case RegisterType::Sampler:
    return ResourceClass::Sampler;
  case RegisterType::C:
  case RegisterType::I:
    // Deliberately falling through to the unreachable below.
    break;
  }
  llvm_unreachable("unexpected RegisterType value");
}
132
// Returns the __builtin_get_spirv_spec_constant_* builtin matching the given
// scalar type, or Builtin::NotBuiltin if the type cannot be used as a
// SPIR-V specialization constant. Enum types map to the 'int' variant.
static Builtin::ID getSpecConstBuiltinId(const Type *Type) {
  const auto *BT = dyn_cast<BuiltinType>(Val: Type);
  if (!BT) {
    if (!Type->isEnumeralType())
      return Builtin::NotBuiltin;
    return Builtin::BI__builtin_get_spirv_spec_constant_int;
  }

  switch (BT->getKind()) {
  case BuiltinType::Bool:
    return Builtin::BI__builtin_get_spirv_spec_constant_bool;
  case BuiltinType::Short:
    return Builtin::BI__builtin_get_spirv_spec_constant_short;
  case BuiltinType::Int:
    return Builtin::BI__builtin_get_spirv_spec_constant_int;
  case BuiltinType::LongLong:
    return Builtin::BI__builtin_get_spirv_spec_constant_longlong;
  case BuiltinType::UShort:
    return Builtin::BI__builtin_get_spirv_spec_constant_ushort;
  case BuiltinType::UInt:
    return Builtin::BI__builtin_get_spirv_spec_constant_uint;
  case BuiltinType::ULongLong:
    return Builtin::BI__builtin_get_spirv_spec_constant_ulonglong;
  case BuiltinType::Half:
    return Builtin::BI__builtin_get_spirv_spec_constant_half;
  case BuiltinType::Float:
    return Builtin::BI__builtin_get_spirv_spec_constant_float;
  case BuiltinType::Double:
    return Builtin::BI__builtin_get_spirv_spec_constant_double;
  default:
    // All other builtin kinds (e.g. vectors come in as non-Builtin types)
    // are not valid specialization constants.
    return Builtin::NotBuiltin;
  }
}
166
// Appends a new binding entry for variable VD with resource class ResClass
// and returns a pointer to it. Entries for the same VarDecl must be added
// consecutively (enforced by the second assert below).
DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD,
                                                      ResourceClass ResClass) {
  assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
         "DeclBindingInfo already added");
  assert(!hasBindingInfoForDecl(VD) || BindingsList.back().Decl == VD);
  // VarDecl may have multiple entries for different resource classes.
  // DeclToBindingListIndex stores the index of the first binding we saw
  // for this decl. If there are any additional ones then that index
  // shouldn't be updated.
  DeclToBindingListIndex.try_emplace(Key: VD, Args: BindingsList.size());
  return &BindingsList.emplace_back(Args&: VD, Args&: ResClass);
}
179
// Finds the binding entry for VD with the given resource class, or returns
// nullptr if none was recorded. Starts at the index of the first entry for
// VD and walks the consecutive run of entries belonging to the same decl
// (entries for one decl are stored contiguously; see addDeclBindingInfo).
DeclBindingInfo *ResourceBindings::getDeclBindingInfo(const VarDecl *VD,
                                                      ResourceClass ResClass) {
  auto Entry = DeclToBindingListIndex.find(Val: VD);
  if (Entry != DeclToBindingListIndex.end()) {
    for (unsigned Index = Entry->getSecond();
         Index < BindingsList.size() && BindingsList[Index].Decl == VD;
         ++Index) {
      if (BindingsList[Index].ResClass == ResClass)
        return &BindingsList[Index];
    }
  }
  return nullptr;
}
193
194bool ResourceBindings::hasBindingInfoForDecl(const VarDecl *VD) const {
195 return DeclToBindingListIndex.contains(Val: VD);
196}
197
198SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}
199
// Called when parsing reaches the opening brace of a cbuffer/tbuffer.
// Creates the HLSLBufferDecl, tags it with the implicit resource-class
// attribute (CBuffer for cbuffer, SRV for tbuffer), and pushes it as the
// current declaration context so member declarations land inside it.
Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
                                 SourceLocation KwLoc, IdentifierInfo *Ident,
                                 SourceLocation IdentLoc,
                                 SourceLocation LBrace) {
  // For anonymous namespace, take the location of the left brace.
  DeclContext *LexicalParent = SemaRef.getCurLexicalContext();
  HLSLBufferDecl *Result = HLSLBufferDecl::Create(
      C&: getASTContext(), LexicalParent, CBuffer, KwLoc, ID: Ident, IDLoc: IdentLoc, LBrace);

  // if CBuffer is false, then it's a TBuffer
  auto RC = CBuffer ? llvm::hlsl::ResourceClass::CBuffer
                    : llvm::hlsl::ResourceClass::SRV;
  Result->addAttr(A: HLSLResourceClassAttr::CreateImplicit(Ctx&: getASTContext(), ResourceClass: RC));

  SemaRef.PushOnScopeChains(D: Result, S: BufferScope);
  SemaRef.PushDeclContext(S: BufferScope, DC: Result);

  return Result;
}
219
// Returns the legacy cbuffer packing alignment (in bytes) for a field of
// type T: arrays/structs start a new 16-byte row, vectors align like their
// element type, and scalars align to their byte width.
static unsigned calculateLegacyCbufferFieldAlign(const ASTContext &Context,
                                                 QualType T) {
  // Arrays and Structs are always aligned to new buffer rows
  if (T->isArrayType() || T->isStructureType())
    return 16;

  // Vectors are aligned to the type they contain
  if (const VectorType *VT = T->getAs<VectorType>())
    return calculateLegacyCbufferFieldAlign(Context, T: VT->getElementType());

  assert(Context.getTypeSize(T) <= 64 &&
         "Scalar bit widths larger than 64 not supported");

  // Scalar types are aligned to their byte width
  return Context.getTypeSize(T) / 8;
}
236
// Calculate the size of a legacy cbuffer type in bytes based on
// https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-packing-rules
static unsigned calculateLegacyCbufferSize(const ASTContext &Context,
                                           QualType T) {
  constexpr unsigned CBufferAlign = 16;
  if (const auto *RD = T->getAsRecordDecl()) {
    // Records: lay fields out sequentially, applying row-crossing rules.
    unsigned Size = 0;
    for (const FieldDecl *Field : RD->fields()) {
      QualType Ty = Field->getType();
      unsigned FieldSize = calculateLegacyCbufferSize(Context, T: Ty);
      unsigned FieldAlign = calculateLegacyCbufferFieldAlign(Context, T: Ty);

      // If the field crosses the row boundary after alignment it drops to the
      // next row
      unsigned AlignSize = llvm::alignTo(Value: Size, Align: FieldAlign);
      if ((AlignSize % CBufferAlign) + FieldSize > CBufferAlign) {
        FieldAlign = CBufferAlign;
      }

      Size = llvm::alignTo(Value: Size, Align: FieldAlign);
      Size += FieldSize;
    }
    return Size;
  }

  if (const ConstantArrayType *AT = Context.getAsConstantArrayType(T)) {
    unsigned ElementCount = AT->getSize().getZExtValue();
    if (ElementCount == 0)
      return 0;

    // Every element except the last is padded out to a full 16-byte row;
    // the last element contributes only its own size.
    unsigned ElementSize =
        calculateLegacyCbufferSize(Context, T: AT->getElementType());
    unsigned AlignedElementSize = llvm::alignTo(Value: ElementSize, Align: CBufferAlign);
    return AlignedElementSize * (ElementCount - 1) + ElementSize;
  }

  if (const VectorType *VT = T->getAs<VectorType>()) {
    // Vectors pack their elements tightly.
    unsigned ElementCount = VT->getNumElements();
    unsigned ElementSize =
        calculateLegacyCbufferSize(Context, T: VT->getElementType());
    return ElementSize * ElementCount;
  }

  // Scalar: size in bytes.
  return Context.getTypeSize(T) / 8;
}
282
283// Validate packoffset:
284// - if packoffset it used it must be set on all declarations inside the buffer
285// - packoffset ranges must not overlap
286static void validatePackoffset(Sema &S, HLSLBufferDecl *BufDecl) {
287 llvm::SmallVector<std::pair<VarDecl *, HLSLPackOffsetAttr *>> PackOffsetVec;
288
289 // Make sure the packoffset annotations are either on all declarations
290 // or on none.
291 bool HasPackOffset = false;
292 bool HasNonPackOffset = false;
293 for (auto *Field : BufDecl->buffer_decls()) {
294 VarDecl *Var = dyn_cast<VarDecl>(Val: Field);
295 if (!Var)
296 continue;
297 if (Field->hasAttr<HLSLPackOffsetAttr>()) {
298 PackOffsetVec.emplace_back(Args&: Var, Args: Field->getAttr<HLSLPackOffsetAttr>());
299 HasPackOffset = true;
300 } else {
301 HasNonPackOffset = true;
302 }
303 }
304
305 if (!HasPackOffset)
306 return;
307
308 if (HasNonPackOffset)
309 S.Diag(Loc: BufDecl->getLocation(), DiagID: diag::warn_hlsl_packoffset_mix);
310
311 // Make sure there is no overlap in packoffset - sort PackOffsetVec by offset
312 // and compare adjacent values.
313 bool IsValid = true;
314 ASTContext &Context = S.getASTContext();
315 std::sort(first: PackOffsetVec.begin(), last: PackOffsetVec.end(),
316 comp: [](const std::pair<VarDecl *, HLSLPackOffsetAttr *> &LHS,
317 const std::pair<VarDecl *, HLSLPackOffsetAttr *> &RHS) {
318 return LHS.second->getOffsetInBytes() <
319 RHS.second->getOffsetInBytes();
320 });
321 for (unsigned i = 0; i < PackOffsetVec.size() - 1; i++) {
322 VarDecl *Var = PackOffsetVec[i].first;
323 HLSLPackOffsetAttr *Attr = PackOffsetVec[i].second;
324 unsigned Size = calculateLegacyCbufferSize(Context, T: Var->getType());
325 unsigned Begin = Attr->getOffsetInBytes();
326 unsigned End = Begin + Size;
327 unsigned NextBegin = PackOffsetVec[i + 1].second->getOffsetInBytes();
328 if (End > NextBegin) {
329 VarDecl *NextVar = PackOffsetVec[i + 1].first;
330 S.Diag(Loc: NextVar->getLocation(), DiagID: diag::err_hlsl_packoffset_overlap)
331 << NextVar << Var;
332 IsValid = false;
333 }
334 }
335 BufDecl->setHasValidPackoffset(IsValid);
336}
337
338// Returns true if the array has a zero size = if any of the dimensions is 0
339static bool isZeroSizedArray(const ConstantArrayType *CAT) {
340 while (CAT && !CAT->isZeroSize())
341 CAT = dyn_cast<ConstantArrayType>(
342 Val: CAT->getElementType()->getUnqualifiedDesugaredType());
343 return CAT != nullptr;
344}
345
346static bool isResourceRecordTypeOrArrayOf(VarDecl *VD) {
347 const Type *Ty = VD->getType().getTypePtr();
348 return Ty->isHLSLResourceRecord() || Ty->isHLSLResourceRecordArray();
349}
350
// Strips all array dimensions off a resource-record-array variable's type
// and returns the attributed handle type found on the underlying resource
// record.
static const HLSLAttributedResourceType *
getResourceArrayHandleType(VarDecl *VD) {
  assert(VD->getType()->isHLSLResourceRecordArray() &&
         "expected array of resource records");
  const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
  // Peel off every array level (handles multi-dimensional arrays).
  while (const ArrayType *AT = dyn_cast<ArrayType>(Val: Ty))
    Ty = AT->getArrayElementTypeNoTypeQual()->getUnqualifiedDesugaredType();
  return HLSLAttributedResourceType::findHandleTypeOnResource(RT: Ty);
}
360
361// Returns true if the type is a leaf element type that is not valid to be
362// included in HLSL Buffer, such as a resource class, empty struct, zero-sized
363// array, or a builtin intangible type. Returns false it is a valid leaf element
364// type or if it is a record type that needs to be inspected further.
365static bool isInvalidConstantBufferLeafElementType(const Type *Ty) {
366 Ty = Ty->getUnqualifiedDesugaredType();
367 if (Ty->isHLSLResourceRecord() || Ty->isHLSLResourceRecordArray())
368 return true;
369 if (const auto *RD = Ty->getAsCXXRecordDecl())
370 return RD->isEmpty();
371 if (Ty->isConstantArrayType() &&
372 isZeroSizedArray(CAT: cast<ConstantArrayType>(Val: Ty)))
373 return true;
374 if (Ty->isHLSLBuiltinIntangibleType() || Ty->isHLSLAttributedResourceType())
375 return true;
376 return false;
377}
378
// Returns true if the struct contains at least one element that prevents it
// from being included inside HLSL Buffer as is, such as an intangible type,
// empty struct, or zero-sized array. If it does, a new implicit layout struct
// needs to be created for HLSL Buffer use that will exclude these unwanted
// declarations (see createHostLayoutStruct function).
static bool requiresImplicitBufferLayoutStructure(const CXXRecordDecl *RD) {
  // An intangible or empty record can never be used as-is.
  if (RD->isHLSLIntangible() || RD->isEmpty())
    return true;
  // check fields
  for (const FieldDecl *Field : RD->fields()) {
    QualType Ty = Field->getType();
    if (isInvalidConstantBufferLeafElementType(Ty: Ty.getTypePtr()))
      return true;
    // Recurse into nested record fields.
    if (const auto *RD = Ty->getAsCXXRecordDecl();
        RD && requiresImplicitBufferLayoutStructure(RD))
      return true;
  }
  // check bases
  for (const CXXBaseSpecifier &Base : RD->bases())
    if (requiresImplicitBufferLayoutStructure(
            RD: Base.getType()->castAsCXXRecordDecl()))
      return true;
  return false;
}
403
404static CXXRecordDecl *findRecordDeclInContext(IdentifierInfo *II,
405 DeclContext *DC) {
406 CXXRecordDecl *RD = nullptr;
407 for (NamedDecl *Decl :
408 DC->getNonTransparentContext()->lookup(Name: DeclarationName(II))) {
409 if (CXXRecordDecl *FoundRD = dyn_cast<CXXRecordDecl>(Val: Decl)) {
410 assert(RD == nullptr &&
411 "there should be at most 1 record by a given name in a scope");
412 RD = FoundRD;
413 }
414 }
415 return RD;
416}
417
// Creates a name for a buffer layout struct using the provided name base
// ("__cblayout_<base>"). If the name must be unique (not previously defined),
// a numeric suffix is appended and incremented until a unique name is found.
// Anonymous declarations use the base "anon" and are always uniqued.
static IdentifierInfo *getHostLayoutStructName(Sema &S, NamedDecl *BaseDecl,
                                               bool MustBeUnique) {
  ASTContext &AST = S.getASTContext();

  IdentifierInfo *NameBaseII = BaseDecl->getIdentifier();
  llvm::SmallString<64> Name("__cblayout_");
  if (NameBaseII) {
    Name.append(RHS: NameBaseII->getName());
  } else {
    // anonymous struct
    Name.append(RHS: "anon");
    MustBeUnique = true;
  }

  // Remember the base-name length so the numeric suffix can be replaced on
  // each iteration below.
  size_t NameLength = Name.size();
  IdentifierInfo *II = &AST.Idents.get(Name, TokenCode: tok::TokenKind::identifier);
  if (!MustBeUnique)
    return II;

  unsigned suffix = 0;
  while (true) {
    if (suffix != 0) {
      Name.append(RHS: "_");
      Name.append(RHS: llvm::Twine(suffix).str());
      II = &AST.Idents.get(Name, TokenCode: tok::TokenKind::identifier);
    }
    if (!findRecordDeclInContext(II, DC: BaseDecl->getDeclContext()))
      return II;
    // declaration with that name already exists - increment suffix and try
    // again until unique name is found
    suffix++;
    Name.truncate(N: NameLength);
  };
}
455
// Creates a field declaration of given name and type for HLSL buffer layout
// struct. Returns nullptr if the type cannot be used in HLSL Buffer layout.
// If the field type itself requires a layout struct, that struct is created
// (or reused) and the field is declared with the layout struct's type.
static FieldDecl *createFieldForHostLayoutStruct(Sema &S, const Type *Ty,
                                                 IdentifierInfo *II,
                                                 CXXRecordDecl *LayoutStruct) {
  if (isInvalidConstantBufferLeafElementType(Ty))
    return nullptr;

  if (auto *RD = Ty->getAsCXXRecordDecl()) {
    if (requiresImplicitBufferLayoutStructure(RD)) {
      // Substitute the record type with its host-layout equivalent.
      RD = createHostLayoutStruct(S, StructDecl: RD);
      if (!RD)
        return nullptr;
      Ty = S.Context.getCanonicalTagType(TD: RD)->getTypePtr();
    }
  }

  QualType QT = QualType(Ty, 0);
  ASTContext &AST = S.getASTContext();
  TypeSourceInfo *TSI = AST.getTrivialTypeSourceInfo(T: QT, Loc: SourceLocation());
  auto *Field = FieldDecl::Create(C: AST, DC: LayoutStruct, StartLoc: SourceLocation(),
                                  IdLoc: SourceLocation(), Id: II, T: QT, TInfo: TSI, BW: nullptr, Mutable: false,
                                  InitStyle: InClassInitStyle::ICIS_NoInit);
  Field->setAccess(AccessSpecifier::AS_public);
  return Field;
}
482
// Creates host layout struct for a struct included in HLSL Buffer.
// The layout struct will include only fields that are allowed in HLSL buffer.
// These fields will be filtered out:
// - resource classes
// - empty structs
// - zero-sized arrays
// Returns nullptr if the resulting layout struct would be empty.
static CXXRecordDecl *createHostLayoutStruct(Sema &S,
                                             CXXRecordDecl *StructDecl) {
  assert(requiresImplicitBufferLayoutStructure(StructDecl) &&
         "struct is already HLSL buffer compatible");

  ASTContext &AST = S.getASTContext();
  DeclContext *DC = StructDecl->getDeclContext();
  IdentifierInfo *II = getHostLayoutStructName(S, BaseDecl: StructDecl, MustBeUnique: false);

  // reuse the existing layout struct if it already exists
  if (CXXRecordDecl *RD = findRecordDeclInContext(II, DC))
    return RD;

  CXXRecordDecl *LS =
      CXXRecordDecl::Create(C: AST, TK: TagDecl::TagKind::Struct, DC, StartLoc: SourceLocation(),
                            IdLoc: SourceLocation(), Id: II);
  LS->setImplicit(true);
  LS->addAttr(A: PackedAttr::CreateImplicit(Ctx&: AST));
  LS->startDefinition();

  // copy base struct, create HLSL Buffer compatible version if needed
  if (unsigned NumBases = StructDecl->getNumBases()) {
    assert(NumBases == 1 && "HLSL supports only one base type");
    (void)NumBases;
    CXXBaseSpecifier Base = *StructDecl->bases_begin();
    CXXRecordDecl *BaseDecl = Base.getType()->castAsCXXRecordDecl();
    if (requiresImplicitBufferLayoutStructure(RD: BaseDecl)) {
      // The base itself needs a layout struct; recurse. A nullptr result
      // means the base's layout would be empty and the base is dropped.
      BaseDecl = createHostLayoutStruct(S, StructDecl: BaseDecl);
      if (BaseDecl) {
        TypeSourceInfo *TSI =
            AST.getTrivialTypeSourceInfo(T: AST.getCanonicalTagType(TD: BaseDecl));
        Base = CXXBaseSpecifier(SourceRange(), false, StructDecl->isClass(),
                                AS_none, TSI, SourceLocation());
      }
    }
    if (BaseDecl) {
      const CXXBaseSpecifier *BasesArray[1] = {&Base};
      LS->setBases(Bases: BasesArray, NumBases: 1);
    }
  }

  // filter struct fields
  for (const FieldDecl *FD : StructDecl->fields()) {
    const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
    if (FieldDecl *NewFD =
            createFieldForHostLayoutStruct(S, Ty, II: FD->getIdentifier(), LayoutStruct: LS))
      LS->addDecl(D: NewFD);
  }
  LS->completeDefinition();

  // An empty layout struct is useless; signal the caller to drop it.
  if (LS->field_empty() && LS->getNumBases() == 0)
    return nullptr;

  DC->addDecl(D: LS);
  return LS;
}
546
// Creates host layout struct for HLSL Buffer. The struct will include only
// fields of types that are allowed in HLSL buffer and it will filter out:
// - static or groupshared variable declarations
// - resource classes
// - empty structs
// - zero-sized arrays
// - non-variable declarations
// The layout struct will be added to the HLSLBufferDecl declarations.
void createHostLayoutStructForBuffer(Sema &S, HLSLBufferDecl *BufDecl) {
  ASTContext &AST = S.getASTContext();
  IdentifierInfo *II = getHostLayoutStructName(S, BaseDecl: BufDecl, MustBeUnique: true);

  CXXRecordDecl *LS =
      CXXRecordDecl::Create(C: AST, TK: TagDecl::TagKind::Struct, DC: BufDecl,
                            StartLoc: SourceLocation(), IdLoc: SourceLocation(), Id: II);
  LS->addAttr(A: PackedAttr::CreateImplicit(Ctx&: AST));
  LS->setImplicit(true);
  LS->startDefinition();

  for (Decl *D : BufDecl->buffer_decls()) {
    // Skip non-variables, statics, and groupshared variables - they do not
    // occupy space in the constant buffer.
    VarDecl *VD = dyn_cast<VarDecl>(Val: D);
    if (!VD || VD->getStorageClass() == SC_Static ||
        VD->getType().getAddressSpace() == LangAS::hlsl_groupshared)
      continue;
    const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
    if (FieldDecl *FD =
            createFieldForHostLayoutStruct(S, Ty, II: VD->getIdentifier(), LayoutStruct: LS)) {
      // add the field decl to the layout struct
      LS->addDecl(D: FD);
      // update address space of the original decl to hlsl_constant
      QualType NewTy =
          AST.getAddrSpaceQualType(T: VD->getType(), AddressSpace: LangAS::hlsl_constant);
      VD->setType(NewTy);
    }
  }
  LS->completeDefinition();
  BufDecl->addLayoutStruct(LS);
}
585
// Attaches an implicit HLSLResourceBindingAttr (no slot number, space 0) to
// D, recording the implicit-binding order ID that codegen uses to assign
// bindings deterministically.
static void addImplicitBindingAttrToDecl(Sema &S, Decl *D, RegisterType RT,
                                         uint32_t ImplicitBindingOrderID) {
  auto *Attr =
      HLSLResourceBindingAttr::CreateImplicit(Ctx&: S.getASTContext(), Slot: "", Space: "0", Range: {});
  Attr->setBinding(RT, SlotNum: std::nullopt, SpaceNum: 0);
  Attr->setImplicitBindingOrderID(ImplicitBindingOrderID);
  D->addAttr(A: Attr);
}
594
// Handle end of cbuffer/tbuffer declaration: validates packoffset
// annotations, builds the host layout struct, and assigns an implicit
// binding when no explicit register annotation was given.
void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) {
  auto *BufDecl = cast<HLSLBufferDecl>(Val: Dcl);
  BufDecl->setRBraceLoc(RBrace);

  validatePackoffset(S&: SemaRef, BufDecl);

  createHostLayoutStructForBuffer(S&: SemaRef, BufDecl);

  // Handle implicit binding if needed.
  ResourceBindingAttrs ResourceAttrs(Dcl);
  if (!ResourceAttrs.isExplicit()) {
    SemaRef.Diag(Loc: Dcl->getLocation(), DiagID: diag::warn_hlsl_implicit_binding);
    // Use HLSLResourceBindingAttr to transfer implicit binding order_ID
    // to codegen. If it does not exist, create an implicit attribute.
    uint32_t OrderID = getNextImplicitBindingOrderID();
    if (ResourceAttrs.hasBinding())
      ResourceAttrs.setImplicitOrderID(OrderID);
    else
      // cbuffer binds as a CBuffer register, tbuffer as an SRV register.
      addImplicitBindingAttrToDecl(S&: SemaRef, D: BufDecl,
                                   RT: BufDecl->isCBuffer() ? RegisterType::CBuffer
                                                       : RegisterType::SRV,
                                   ImplicitBindingOrderID: OrderID);
  }

  SemaRef.PopDeclContext();
}
622
// Merges a numthreads attribute onto D. If the decl already has one with
// different X/Y/Z values, a mismatch error is emitted; a duplicate (matching
// or not) always returns nullptr so no second attribute is attached.
HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
                                                  const AttributeCommonInfo &AL,
                                                  int X, int Y, int Z) {
  if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
    if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
      Diag(Loc: NT->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL;
      Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute);
    }
    return nullptr;
  }
  return ::new (getASTContext())
      HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z);
}
636
// Merges a WaveSize attribute onto D. A pre-existing attribute with any
// differing value (including the number of spelled arguments) triggers a
// mismatch diagnostic; duplicates always return nullptr.
HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D,
                                              const AttributeCommonInfo &AL,
                                              int Min, int Max, int Preferred,
                                              int SpelledArgsCount) {
  if (HLSLWaveSizeAttr *WS = D->getAttr<HLSLWaveSizeAttr>()) {
    if (WS->getMin() != Min || WS->getMax() != Max ||
        WS->getPreferred() != Preferred ||
        WS->getSpelledArgsCount() != SpelledArgsCount) {
      Diag(Loc: WS->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL;
      Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute);
    }
    return nullptr;
  }
  HLSLWaveSizeAttr *Result = ::new (getASTContext())
      HLSLWaveSizeAttr(getASTContext(), AL, Min, Max, Preferred);
  // The spelled-arg count is not a constructor parameter; set it separately.
  Result->setSpelledArgsCount(SpelledArgsCount);
  return Result;
}
655
// Merges a vk::constant_id attribute onto D. The attribute is only
// meaningful for the SPIR-V target (ignored with a warning otherwise), and
// the variable must be a const-qualified scalar/enum type usable as a
// specialization constant. Duplicate attributes with a different id are
// diagnosed; duplicates always return nullptr.
HLSLVkConstantIdAttr *
SemaHLSL::mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL,
                                int Id) {

  auto &TargetInfo = getASTContext().getTargetInfo();
  if (TargetInfo.getTriple().getArch() != llvm::Triple::spirv) {
    Diag(Loc: AL.getLoc(), DiagID: diag::warn_attribute_ignored) << AL;
    return nullptr;
  }

  auto *VD = cast<VarDecl>(Val: D);

  // The type must map to one of the spec-constant builtins.
  if (getSpecConstBuiltinId(Type: VD->getType()->getUnqualifiedDesugaredType()) ==
      Builtin::NotBuiltin) {
    Diag(Loc: VD->getLocation(), DiagID: diag::err_specialization_const);
    return nullptr;
  }

  // Specialization constants must be const-qualified.
  if (!VD->getType().isConstQualified()) {
    Diag(Loc: VD->getLocation(), DiagID: diag::err_specialization_const);
    return nullptr;
  }

  if (HLSLVkConstantIdAttr *CI = D->getAttr<HLSLVkConstantIdAttr>()) {
    if (CI->getId() != Id) {
      Diag(Loc: CI->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL;
      Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute);
    }
    return nullptr;
  }

  HLSLVkConstantIdAttr *Result =
      ::new (getASTContext()) HLSLVkConstantIdAttr(getASTContext(), AL, Id);
  return Result;
}
691
// Merges a shader-stage attribute onto D. A pre-existing attribute with a
// different shader type is diagnosed as a mismatch; duplicates always
// return nullptr.
HLSLShaderAttr *
SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
                          llvm::Triple::EnvironmentType ShaderType) {
  if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
    if (NT->getType() != ShaderType) {
      Diag(Loc: NT->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL;
      Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute);
    }
    return nullptr;
  }
  return HLSLShaderAttr::Create(Ctx&: getASTContext(), Type: ShaderType, CommonInfo: AL);
}
704
// Merges an HLSL parameter modifier (in/out/inout) onto D.
HLSLParamModifierAttr *
SemaHLSL::mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
                                 HLSLParamModifierAttr::Spelling Spelling) {
  // We can only merge an `in` attribute with an `out` attribute. All other
  // combinations of duplicated attributes are ill-formed.
  if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
    if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
        (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
      // `in` + `out` combine into a single `inout` attribute spanning both
      // spellings' source range.
      D->dropAttr<HLSLParamModifierAttr>();
      SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
      return HLSLParamModifierAttr::Create(
          Ctx&: getASTContext(), /*MergedSpelling=*/true, Range: AdjustedRange,
          S: HLSLParamModifierAttr::Keyword_inout);
    }
    Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_duplicate_parameter_modifier) << AL;
    Diag(Loc: PA->getLocation(), DiagID: diag::note_conflicting_attribute);
    return nullptr;
  }
  return HLSLParamModifierAttr::Create(Ctx&: getASTContext(), CommonInfo: AL);
}
725
// Post-processes a top-level function: if it is the configured entry point,
// applies any root-signature override and attaches/validates the shader
// attribute implied by the target triple's environment.
void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
  auto &TargetInfo = getASTContext().getTargetInfo();

  // Only the function named as the HLSL entry point is processed here.
  if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
    return;

  // If we have specified a root signature to override the entry function then
  // attach it now
  HLSLRootSignatureDecl *SignatureDecl =
      lookupRootSignatureOverrideDecl(DC: FD->getDeclContext());
  if (SignatureDecl) {
    FD->dropAttr<RootSignatureAttr>();
    // We could look up the SourceRange of the macro here as well
    AttributeCommonInfo AL(RootSigOverrideIdent, AttributeScopeInfo(),
                           SourceRange(), ParsedAttr::Form::Microsoft());
    FD->addAttr(A: ::new (getASTContext()) RootSignatureAttr(
        getASTContext(), AL, RootSigOverrideIdent, SignatureDecl));
  }

  llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment();
  if (HLSLShaderAttr::isValidShaderType(ShaderType: Env) && Env != llvm::Triple::Library) {
    if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
      // The entry point is already annotated - check that it matches the
      // triple.
      if (Shader->getType() != Env) {
        Diag(Loc: Shader->getLocation(), DiagID: diag::err_hlsl_entry_shader_attr_mismatch)
            << Shader;
        FD->setInvalidDecl();
      }
    } else {
      // Implicitly add the shader attribute if the entry function isn't
      // explicitly annotated.
      FD->addAttr(A: HLSLShaderAttr::CreateImplicit(Ctx&: getASTContext(), Type: Env,
                                                 Range: FD->getBeginLoc()));
    }
  } else {
    switch (Env) {
    case llvm::Triple::UnknownEnvironment:
    case llvm::Triple::Library:
      // Library/unknown targets carry no single entry stage; nothing to do.
      break;
    case llvm::Triple::RootSignature:
      llvm_unreachable("rootsig environment has no functions");
    default:
      llvm_unreachable("Unhandled environment in triple");
    }
  }
}
773
774static bool isVkPipelineBuiltin(const ASTContext &AstContext, FunctionDecl *FD,
775 HLSLAppliedSemanticAttr *Semantic,
776 bool IsInput) {
777 if (AstContext.getTargetInfo().getTriple().getOS() != llvm::Triple::Vulkan)
778 return false;
779
780 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
781 assert(ShaderAttr && "Entry point has no shader attribute");
782 llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
783 auto SemanticName = Semantic->getSemanticName().upper();
784
785 // The SV_Position semantic is lowered to:
786 // - Position built-in for vertex output.
787 // - FragCoord built-in for fragment input.
788 if (SemanticName == "SV_POSITION") {
789 return (ST == llvm::Triple::Vertex && !IsInput) ||
790 (ST == llvm::Triple::Pixel && IsInput);
791 }
792
793 return false;
794}
795
// Resolves the semantic for a scalar (non-record) entry-point input/output.
// Inherits the active semantic from an enclosing declaration if one is set,
// otherwise reads it from D's parsed semantic attribute. Attaches an
// HLSLAppliedSemanticAttr to OutputDecl, validates it, assigns locations
// (honoring explicit vk::location where present), and records used
// semantic+index pairs to detect overlaps. Returns false on any diagnosed
// error.
bool SemaHLSL::determineActiveSemanticOnScalar(FunctionDecl *FD,
                                               DeclaratorDecl *OutputDecl,
                                               DeclaratorDecl *D,
                                               SemanticInfo &ActiveSemantic,
                                               SemaHLSL::SemanticContext &SC) {
  if (ActiveSemantic.Semantic == nullptr) {
    ActiveSemantic.Semantic = D->getAttr<HLSLParsedSemanticAttr>();
    if (ActiveSemantic.Semantic)
      ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
  }

  // A scalar with no semantic (own or inherited) is an error.
  if (!ActiveSemantic.Semantic) {
    Diag(Loc: D->getLocation(), DiagID: diag::err_hlsl_missing_semantic_annotation);
    return false;
  }

  auto *A = ::new (getASTContext())
      HLSLAppliedSemanticAttr(getASTContext(), *ActiveSemantic.Semantic,
                              ActiveSemantic.Semantic->getAttrName()->getName(),
                              ActiveSemantic.Index.value_or(u: 0));
  if (!A)
    return false;

  checkSemanticAnnotation(EntryPoint: FD, Param: D, SemanticAttr: A, SC);
  OutputDecl->addAttr(A);

  unsigned Location = ActiveSemantic.Index.value_or(u: 0);

  // Pipeline built-ins (e.g. SV_Position on Vulkan) do not consume a
  // user location, so explicit vk::location handling is skipped for them.
  if (!isVkPipelineBuiltin(AstContext: getASTContext(), FD, Semantic: A,
                           IsInput: SC.CurrentIOType & IOType::In)) {
    bool HasVkLocation = false;
    if (auto *A = D->getAttr<HLSLVkLocationAttr>()) {
      HasVkLocation = true;
      Location = A->getLocation();
    }

    // Explicit vk::location must be used on all locations or on none.
    if (SC.UsesExplicitVkLocations.value_or(u&: HasVkLocation) != HasVkLocation) {
      Diag(Loc: D->getLocation(), DiagID: diag::err_hlsl_semantic_partial_explicit_indexing);
      return false;
    }
    SC.UsesExplicitVkLocations = HasVkLocation;
  }

  // Arrays occupy one location per element.
  const ConstantArrayType *AT = dyn_cast<ConstantArrayType>(Val: D->getType());
  unsigned ElementCount = AT ? AT->getZExtSize() : 1;
  ActiveSemantic.Index = Location + ElementCount;

  // Record each semantic+index this declaration occupies; a second user of
  // the same pair is an overlap error.
  Twine BaseName = Twine(ActiveSemantic.Semantic->getAttrName()->getName());
  for (unsigned I = 0; I < ElementCount; ++I) {
    Twine VariableName = BaseName.concat(Suffix: Twine(Location + I));

    auto [_, Inserted] = SC.ActiveSemantics.insert(key: VariableName.str());
    if (!Inserted) {
      Diag(Loc: D->getLocation(), DiagID: diag::err_hlsl_semantic_index_overlap)
          << VariableName.str();
      return false;
    }
  }

  return true;
}
857
bool SemaHLSL::determineActiveSemantic(FunctionDecl *FD,
                                       DeclaratorDecl *OutputDecl,
                                       DeclaratorDecl *D,
                                       SemanticInfo &ActiveSemantic,
                                       SemaHLSL::SemanticContext &SC) {
  // Adopt D's own parsed semantic if no semantic is being inherited from an
  // enclosing declaration.
  if (ActiveSemantic.Semantic == nullptr) {
    ActiveSemantic.Semantic = D->getAttr<HLSLParsedSemanticAttr>();
    if (ActiveSemantic.Semantic)
      ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
  }

  // When validating the entry point's return value, D is the function itself
  // and the relevant type is the return type, not the function type.
  const Type *T = D == FD ? &*FD->getReturnType() : &*D->getType();
  T = T->getUnqualifiedDesugaredType();

  // Non-record types are handled by the scalar path.
  const RecordType *RT = dyn_cast<RecordType>(Val: T);
  if (!RT)
    return determineActiveSemanticOnScalar(FD, OutputDecl, D, ActiveSemantic,
                                           SC);

  // Record type: recurse into each field. Each field starts from a copy of
  // the current state; the updated state (advanced semantic index) carries
  // over to subsequent fields only while a semantic is being inherited from
  // the enclosing declaration.
  const RecordDecl *RD = RT->getDecl();
  for (FieldDecl *Field : RD->fields()) {
    SemanticInfo Info = ActiveSemantic;
    if (!determineActiveSemantic(FD, OutputDecl, D: Field, ActiveSemantic&: Info, SC)) {
      Diag(Loc: Field->getLocation(), DiagID: diag::note_hlsl_semantic_used_here) << Field;
      return false;
    }
    if (ActiveSemantic.Semantic)
      ActiveSemantic = Info;
  }

  return true;
}
890
void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
  // Validates a shader entry point: stage-specific attributes
  // (numthreads/wavesize) and the semantic annotations on its parameters and
  // return value.
  const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
  assert(ShaderAttr && "Entry point has no shader attribute");
  llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
  auto &TargetInfo = getASTContext().getTargetInfo();
  // Shader model version is encoded as the triple's OS version.
  VersionTuple Ver = TargetInfo.getTriple().getOSVersion();
  switch (ST) {
  case llvm::Triple::Pixel:
  case llvm::Triple::Vertex:
  case llvm::Triple::Geometry:
  case llvm::Triple::Hull:
  case llvm::Triple::Domain:
  case llvm::Triple::RayGeneration:
  case llvm::Triple::Intersection:
  case llvm::Triple::AnyHit:
  case llvm::Triple::ClosestHit:
  case llvm::Triple::Miss:
  case llvm::Triple::Callable:
    // numthreads/wavesize are only meaningful for compute-like stages.
    if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
      diagnoseAttrStageMismatch(A: NT, Stage: ST,
                                AllowedStages: {llvm::Triple::Compute,
                                                llvm::Triple::Amplification,
                                                llvm::Triple::Mesh});
      FD->setInvalidDecl();
    }
    if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
      diagnoseAttrStageMismatch(A: WS, Stage: ST,
                                AllowedStages: {llvm::Triple::Compute,
                                                llvm::Triple::Amplification,
                                                llvm::Triple::Mesh});
      FD->setInvalidDecl();
    }
    break;

  case llvm::Triple::Compute:
  case llvm::Triple::Amplification:
  case llvm::Triple::Mesh:
    // Compute-like stages require numthreads.
    if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
      Diag(Loc: FD->getLocation(), DiagID: diag::err_hlsl_missing_numthreads)
          << llvm::Triple::getEnvironmentTypeName(Kind: ST);
      FD->setInvalidDecl();
    }
    // WaveSize requires SM 6.6+, and the multi-argument (range) form
    // requires SM 6.8+.
    if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
      if (Ver < VersionTuple(6, 6)) {
        Diag(Loc: WS->getLocation(), DiagID: diag::err_hlsl_attribute_in_wrong_shader_model)
            << WS << "6.6";
        FD->setInvalidDecl();
      } else if (WS->getSpelledArgsCount() > 1 && Ver < VersionTuple(6, 8)) {
        Diag(
            Loc: WS->getLocation(),
            DiagID: diag::err_hlsl_attribute_number_arguments_insufficient_shader_model)
            << WS << WS->getSpelledArgsCount() << "6.8";
        FD->setInvalidDecl();
      }
    }
    break;
  case llvm::Triple::RootSignature:
    llvm_unreachable("rootsig environment has no function entry point");
  default:
    llvm_unreachable("Unhandled environment in triple");
  }

  // Validate input semantics: all parameters share one context so that
  // location/index overlaps across parameters are caught.
  SemaHLSL::SemanticContext InputSC = {};
  InputSC.CurrentIOType = IOType::In;

  for (ParmVarDecl *Param : FD->parameters()) {
    SemanticInfo ActiveSemantic;
    ActiveSemantic.Semantic = Param->getAttr<HLSLParsedSemanticAttr>();
    if (ActiveSemantic.Semantic)
      ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();

    // FIXME: Verify output semantics in parameters.
    if (!determineActiveSemantic(FD, OutputDecl: Param, D: Param, ActiveSemantic, SC&: InputSC)) {
      Diag(Loc: Param->getLocation(), DiagID: diag::note_previous_decl) << Param;
      FD->setInvalidDecl();
    }
  }

  // Validate the output semantic attached to the return value (if any).
  SemanticInfo ActiveSemantic;
  SemaHLSL::SemanticContext OutputSC = {};
  OutputSC.CurrentIOType = IOType::Out;
  ActiveSemantic.Semantic = FD->getAttr<HLSLParsedSemanticAttr>();
  if (ActiveSemantic.Semantic)
    ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
  if (!FD->getReturnType()->isVoidType())
    determineActiveSemantic(FD, OutputDecl: FD, D: FD, ActiveSemantic, SC&: OutputSC);
}
978
void SemaHLSL::checkSemanticAnnotation(
    FunctionDecl *EntryPoint, const Decl *Param,
    const HLSLAppliedSemanticAttr *SemanticAttr, const SemanticContext &SC) {
  // Validates that a system-value semantic is legal for the entry point's
  // shader stage and IO direction, and that an index suffix is only used
  // where the semantic supports one.
  auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
  assert(ShaderAttr && "Entry point has no shader attribute");
  llvm::Triple::EnvironmentType ST = ShaderAttr->getType();

  // Semantic names are matched case-insensitively.
  auto SemanticName = SemanticAttr->getSemanticName().upper();
  if (SemanticName == "SV_DISPATCHTHREADID" ||
      SemanticName == "SV_GROUPINDEX" || SemanticName == "SV_GROUPTHREADID" ||
      SemanticName == "SV_GROUPID") {

    // These are compute-only input semantics.
    if (ST != llvm::Triple::Compute)
      diagnoseSemanticStageMismatch(A: SemanticAttr, Stage: ST, CurrentIOType: SC.CurrentIOType,
                                    AllowedStages: {{.Stage: llvm::Triple::Compute, .AllowedIOTypesMask: IOType::In}});

    // None of these semantics accept an index suffix (e.g. SV_GroupID1).
    if (SemanticAttr->getSemanticIndex() != 0) {
      std::string PrettyName =
          "'" + SemanticAttr->getSemanticName().str() + "'";
      Diag(Loc: SemanticAttr->getLoc(),
           DiagID: diag::err_hlsl_semantic_indexing_not_supported)
          << PrettyName;
    }
    return;
  }

  if (SemanticName == "SV_POSITION") {
    // SV_Position can be an input or output in vertex shaders,
    // but only an input in pixel shaders.
    diagnoseSemanticStageMismatch(A: SemanticAttr, Stage: ST, CurrentIOType: SC.CurrentIOType,
                                  AllowedStages: {{.Stage: llvm::Triple::Vertex, .AllowedIOTypesMask: IOType::InOut},
                                                 {.Stage: llvm::Triple::Pixel, .AllowedIOTypesMask: IOType::In}});
    return;
  }

  if (SemanticName == "SV_TARGET") {
    // SV_Target is a pixel-shader output only.
    diagnoseSemanticStageMismatch(A: SemanticAttr, Stage: ST, CurrentIOType: SC.CurrentIOType,
                                  AllowedStages: {{.Stage: llvm::Triple::Pixel, .AllowedIOTypesMask: IOType::Out}});
    return;
  }

  // FIXME: catch-all for non-implemented system semantics reaching this
  // location.
  if (SemanticAttr->getAttrName()->getName().starts_with_insensitive(Prefix: "SV_"))
    llvm_unreachable("Unknown SemanticAttr");
}
1025
1026void SemaHLSL::diagnoseAttrStageMismatch(
1027 const Attr *A, llvm::Triple::EnvironmentType Stage,
1028 std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) {
1029 SmallVector<StringRef, 8> StageStrings;
1030 llvm::transform(Range&: AllowedStages, d_first: std::back_inserter(x&: StageStrings),
1031 F: [](llvm::Triple::EnvironmentType ST) {
1032 return StringRef(
1033 HLSLShaderAttr::ConvertEnvironmentTypeToStr(Val: ST));
1034 });
1035 Diag(Loc: A->getLoc(), DiagID: diag::err_hlsl_attr_unsupported_in_stage)
1036 << A->getAttrName() << llvm::Triple::getEnvironmentTypeName(Kind: Stage)
1037 << (AllowedStages.size() != 1) << join(R&: StageStrings, Separator: ", ");
1038}
1039
void SemaHLSL::diagnoseSemanticStageMismatch(
    const Attr *A, llvm::Triple::EnvironmentType Stage, IOType CurrentIOType,
    std::initializer_list<SemanticStageInfo> Allowed) {
  // Diagnoses a semantic used in a stage/IO-direction combination it does
  // not support. `Allowed` lists the stages where the semantic is valid,
  // each with a mask of the IO directions permitted in that stage.

  // First see whether the current stage appears in the allowed list.
  for (auto &Case : Allowed) {
    if (Case.Stage != Stage)
      continue;

    // Stage and IO direction are both allowed: nothing to report.
    if (CurrentIOType & Case.AllowedIOTypesMask)
      return;

    // Stage matches but the IO direction is wrong; list every valid
    // "<stage> <input/output>" combination in the diagnostic.
    SmallVector<std::string, 8> ValidCases;
    llvm::transform(
        Range&: Allowed, d_first: std::back_inserter(x&: ValidCases), F: [](SemanticStageInfo Case) {
          SmallVector<std::string, 2> ValidType;
          if (Case.AllowedIOTypesMask & IOType::In)
            ValidType.push_back(Elt: "input");
          if (Case.AllowedIOTypesMask & IOType::Out)
            ValidType.push_back(Elt: "output");
          return std::string(
                     HLSLShaderAttr::ConvertEnvironmentTypeToStr(Val: Case.Stage)) +
                 " " + join(R&: ValidType, Separator: "/");
        });
    Diag(Loc: A->getLoc(), DiagID: diag::err_hlsl_semantic_unsupported_iotype_for_stage)
        << A->getAttrName() << (CurrentIOType & IOType::In ? "input" : "output")
        << llvm::Triple::getEnvironmentTypeName(Kind: Case.Stage)
        << join(R&: ValidCases, Separator: ", ");
    return;
  }

  // The current stage is not allowed at all: report a stage mismatch with
  // the list of allowed stages.
  SmallVector<StringRef, 8> StageStrings;
  llvm::transform(
      Range&: Allowed, d_first: std::back_inserter(x&: StageStrings), F: [](SemanticStageInfo Case) {
        return StringRef(
            HLSLShaderAttr::ConvertEnvironmentTypeToStr(Val: Case.Stage));
      });

  Diag(Loc: A->getLoc(), DiagID: diag::err_hlsl_attr_unsupported_in_stage)
      << A->getAttrName() << llvm::Triple::getEnvironmentTypeName(Kind: Stage)
      << (Allowed.size() != 1) << join(R&: StageStrings, Separator: ", ");
}
1081
// Implicitly casts `E` (with cast kind `Kind`) to an ext-vector of `Sz`
// elements whose element type is the scalar form of `Ty`. Updates `Ty` to
// the resulting vector type in place.
template <CastKind Kind>
static void castVector(Sema &S, ExprResult &E, QualType &Ty, unsigned Sz) {
  // If Ty is already a vector, rebuild from its element type.
  if (const auto *VTy = Ty->getAs<VectorType>())
    Ty = VTy->getElementType();
  Ty = S.getASTContext().getExtVectorType(VectorType: Ty, NumElts: Sz);
  E = S.ImpCastExprToType(E: E.get(), Type: Ty, CK: Kind);
}
1089
// Implicitly casts `E` to `Ty` with cast kind `Kind` and returns `Ty`, so
// callers can `return castElement<...>(...)` directly.
template <CastKind Kind>
static QualType castElement(Sema &S, ExprResult &E, QualType Ty) {
  E = S.ImpCastExprToType(E: E.get(), Type: Ty, CK: Kind);
  return Ty;
}
1095
1096static QualType handleFloatVectorBinOpConversion(
1097 Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
1098 QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
1099 bool LHSFloat = LElTy->isRealFloatingType();
1100 bool RHSFloat = RElTy->isRealFloatingType();
1101
1102 if (LHSFloat && RHSFloat) {
1103 if (IsCompAssign ||
1104 SemaRef.getASTContext().getFloatingTypeOrder(LHS: LElTy, RHS: RElTy) > 0)
1105 return castElement<CK_FloatingCast>(S&: SemaRef, E&: RHS, Ty: LHSType);
1106
1107 return castElement<CK_FloatingCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
1108 }
1109
1110 if (LHSFloat)
1111 return castElement<CK_IntegralToFloating>(S&: SemaRef, E&: RHS, Ty: LHSType);
1112
1113 assert(RHSFloat);
1114 if (IsCompAssign)
1115 return castElement<clang::CK_FloatingToIntegral>(S&: SemaRef, E&: RHS, Ty: LHSType);
1116
1117 return castElement<CK_IntegralToFloating>(S&: SemaRef, E&: LHS, Ty: RHSType);
1118}
1119
// Performs the usual arithmetic conversions for vector operands whose
// element types are both integral, following C/C++ rank and signedness
// rules. For compound assignment the result is always the LHS type.
static QualType handleIntegerVectorBinOpConversion(
    Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
    QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {

  // IntOrder > 0 means LElTy has higher rank than RElTy; 0 means equal.
  int IntOrder = SemaRef.Context.getIntegerTypeOrder(LHS: LElTy, RHS: RElTy);
  bool LHSSigned = LElTy->hasSignedIntegerRepresentation();
  bool RHSSigned = RElTy->hasSignedIntegerRepresentation();
  auto &Ctx = SemaRef.getASTContext();

  // If both types have the same signedness, use the higher ranked type.
  if (LHSSigned == RHSSigned) {
    if (IsCompAssign || IntOrder >= 0)
      return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);

    return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
  }

  // If the unsigned type has greater than or equal rank of the signed type, use
  // the unsigned type.
  if (IntOrder != (LHSSigned ? 1 : -1)) {
    if (IsCompAssign || RHSSigned)
      return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);
    return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
  }

  // At this point the signed type has higher rank than the unsigned type, which
  // means it will be the same size or bigger. If the signed type is bigger, it
  // can represent all the values of the unsigned type, so select it.
  if (Ctx.getIntWidth(T: LElTy) != Ctx.getIntWidth(T: RElTy)) {
    if (IsCompAssign || LHSSigned)
      return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);
    return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
  }

  // This is a bit of an odd duck case in HLSL. It shouldn't happen, but can due
  // to C/C++ leaking through. The place this happens today is long vs long
  // long. When arguments are vector<unsigned long, N> and vector<long long, N>,
  // the long long has higher rank than long even though they are the same size.

  // If this is a compound assignment cast the right hand side to the left hand
  // side's type.
  if (IsCompAssign)
    return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);

  // If this isn't a compound assignment we convert to unsigned long long.
  QualType ElTy = Ctx.getCorrespondingUnsignedType(T: LHSSigned ? LElTy : RElTy);
  QualType NewTy = Ctx.getExtVectorType(
      VectorType: ElTy, NumElts: RHSType->castAs<VectorType>()->getNumElements());
  (void)castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: NewTy);

  return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: NewTy);
}
1172
1173static CastKind getScalarCastKind(ASTContext &Ctx, QualType DestTy,
1174 QualType SrcTy) {
1175 if (DestTy->isRealFloatingType() && SrcTy->isRealFloatingType())
1176 return CK_FloatingCast;
1177 if (DestTy->isIntegralType(Ctx) && SrcTy->isIntegralType(Ctx))
1178 return CK_IntegralCast;
1179 if (DestTy->isRealFloatingType())
1180 return CK_IntegralToFloating;
1181 assert(SrcTy->isRealFloatingType() && DestTy->isIntegralType(Ctx));
1182 return CK_FloatingToIntegral;
1183}
1184
// Computes the common type for an HLSL vector binary operation: first sizes
// both operands to the same element count (truncating or splatting as
// needed), then converts the element types. Returns the common type, or a
// null QualType after diagnosing an invalid compound-assignment truncation.
QualType SemaHLSL::handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS,
                                               QualType LHSType,
                                               QualType RHSType,
                                               bool IsCompAssign) {
  const auto *LVecTy = LHSType->getAs<VectorType>();
  const auto *RVecTy = RHSType->getAs<VectorType>();
  auto &Ctx = getASTContext();

  // If the LHS is not a vector and this is a compound assignment, we truncate
  // the argument to a scalar then convert it to the LHS's type.
  if (!LVecTy && IsCompAssign) {
    QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
    RHS = SemaRef.ImpCastExprToType(E: RHS.get(), Type: RElTy, CK: CK_HLSLVectorTruncation);
    RHSType = RHS.get()->getType();
    if (Ctx.hasSameUnqualifiedType(T1: LHSType, T2: RHSType))
      return LHSType;
    RHS = SemaRef.ImpCastExprToType(E: RHS.get(), Type: LHSType,
                                    CK: getScalarCastKind(Ctx, DestTy: LHSType, SrcTy: RHSType));
    return LHSType;
  }

  // The common element count is the smaller of the two vector sizes (at
  // least one operand is a vector here).
  unsigned EndSz = std::numeric_limits<unsigned>::max();
  unsigned LSz = 0;
  if (LVecTy)
    LSz = EndSz = LVecTy->getNumElements();
  if (RVecTy)
    EndSz = std::min(a: RVecTy->getNumElements(), b: EndSz);
  assert(EndSz != std::numeric_limits<unsigned>::max() &&
         "one of the above should have had a value");

  // In a compound assignment, the left operand does not change type, the right
  // operand is converted to the type of the left operand.
  if (IsCompAssign && LSz != EndSz) {
    Diag(Loc: LHS.get()->getBeginLoc(),
         DiagID: diag::err_hlsl_vector_compound_assignment_truncation)
        << LHSType << RHSType;
    return QualType();
  }

  // Truncate the longer vector, then splat any scalar operand to EndSz.
  if (RVecTy && RVecTy->getNumElements() > EndSz)
    castVector<CK_HLSLVectorTruncation>(S&: SemaRef, E&: RHS, Ty&: RHSType, Sz: EndSz);
  if (!IsCompAssign && LVecTy && LVecTy->getNumElements() > EndSz)
    castVector<CK_HLSLVectorTruncation>(S&: SemaRef, E&: LHS, Ty&: LHSType, Sz: EndSz);

  if (!RVecTy)
    castVector<CK_VectorSplat>(S&: SemaRef, E&: RHS, Ty&: RHSType, Sz: EndSz);
  if (!IsCompAssign && !LVecTy)
    castVector<CK_VectorSplat>(S&: SemaRef, E&: LHS, Ty&: LHSType, Sz: EndSz);

  // If we're at the same type after resizing we can stop here.
  if (Ctx.hasSameUnqualifiedType(T1: LHSType, T2: RHSType))
    return Ctx.getCommonSugaredType(X: LHSType, Y: RHSType);

  QualType LElTy = LHSType->castAs<VectorType>()->getElementType();
  QualType RElTy = RHSType->castAs<VectorType>()->getElementType();

  // Handle conversion for floating point vectors.
  if (LElTy->isRealFloatingType() || RElTy->isRealFloatingType())
    return handleFloatVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
                                            LElTy, RElTy, IsCompAssign);

  assert(LElTy->isIntegralType(Ctx) && RElTy->isIntegralType(Ctx) &&
         "HLSL Vectors can only contain integer or floating point types");
  return handleIntegerVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
                                            LElTy, RElTy, IsCompAssign);
}
1251
1252void SemaHLSL::emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS,
1253 BinaryOperatorKind Opc) {
1254 assert((Opc == BO_LOr || Opc == BO_LAnd) &&
1255 "Called with non-logical operator");
1256 llvm::SmallVector<char, 256> Buff;
1257 llvm::raw_svector_ostream OS(Buff);
1258 PrintingPolicy PP(SemaRef.getLangOpts());
1259 StringRef NewFnName = Opc == BO_LOr ? "or" : "and";
1260 OS << NewFnName << "(";
1261 LHS->printPretty(OS, Helper: nullptr, Policy: PP);
1262 OS << ", ";
1263 RHS->printPretty(OS, Helper: nullptr, Policy: PP);
1264 OS << ")";
1265 SourceRange FullRange = SourceRange(LHS->getBeginLoc(), RHS->getEndLoc());
1266 SemaRef.Diag(Loc: LHS->getBeginLoc(), DiagID: diag::note_function_suggestion)
1267 << NewFnName << FixItHint::CreateReplacement(RemoveRange: FullRange, Code: OS.str());
1268}
1269
1270std::pair<IdentifierInfo *, bool>
1271SemaHLSL::ActOnStartRootSignatureDecl(StringRef Signature) {
1272 llvm::hash_code Hash = llvm::hash_value(S: Signature);
1273 std::string IdStr = "__hlsl_rootsig_decl_" + std::to_string(val: Hash);
1274 IdentifierInfo *DeclIdent = &(getASTContext().Idents.get(Name: IdStr));
1275
1276 // Check if we have already found a decl of the same name.
1277 LookupResult R(SemaRef, DeclIdent, SourceLocation(),
1278 Sema::LookupOrdinaryName);
1279 bool Found = SemaRef.LookupQualifiedName(R, LookupCtx: SemaRef.CurContext);
1280 return {DeclIdent, Found};
1281}
1282
1283void SemaHLSL::ActOnFinishRootSignatureDecl(
1284 SourceLocation Loc, IdentifierInfo *DeclIdent,
1285 ArrayRef<hlsl::RootSignatureElement> RootElements) {
1286
1287 if (handleRootSignatureElements(Elements: RootElements))
1288 return;
1289
1290 SmallVector<llvm::hlsl::rootsig::RootElement> Elements;
1291 for (auto &RootSigElement : RootElements)
1292 Elements.push_back(Elt: RootSigElement.getElement());
1293
1294 auto *SignatureDecl = HLSLRootSignatureDecl::Create(
1295 C&: SemaRef.getASTContext(), /*DeclContext=*/DC: SemaRef.CurContext, Loc,
1296 ID: DeclIdent, Version: SemaRef.getLangOpts().HLSLRootSigVer, RootElements: Elements);
1297
1298 SignatureDecl->setImplicit();
1299 SemaRef.PushOnScopeChains(D: SignatureDecl, S: SemaRef.getCurScope());
1300}
1301
1302HLSLRootSignatureDecl *
1303SemaHLSL::lookupRootSignatureOverrideDecl(DeclContext *DC) const {
1304 if (RootSigOverrideIdent) {
1305 LookupResult R(SemaRef, RootSigOverrideIdent, SourceLocation(),
1306 Sema::LookupOrdinaryName);
1307 if (SemaRef.LookupQualifiedName(R, LookupCtx: DC))
1308 return dyn_cast<HLSLRootSignatureDecl>(Val: R.getFoundDecl());
1309 }
1310
1311 return nullptr;
1312}
1313
1314namespace {
1315
// Detects register-range overlaps in a root signature, per shader
// visibility. A binding with 'All' visibility conflicts with bindings in
// every specific stage, so it is mirrored into each per-stage builder.
struct PerVisibilityBindingChecker {
  SemaHLSL *S;
  // We need one builder per `llvm::dxbc::ShaderVisibility` value.
  std::array<llvm::hlsl::BindingInfoBuilder, 8> Builders;

  // Per-element bookkeeping used when reporting overlaps.
  struct ElemInfo {
    const hlsl::RootSignatureElement *Elem;
    llvm::dxbc::ShaderVisibility Vis;
    bool Diagnosed; // set once an overlap has been reported for this element
  };
  llvm::SmallVector<ElemInfo> ElemInfoMap;

  PerVisibilityBindingChecker(SemaHLSL *S) : S(S) {}

  // Records the inclusive register range [LowerBound, UpperBound] of `Elem`
  // in the builder for its visibility; an 'All'-visibility binding is also
  // mirrored into every stage-specific builder.
  void trackBinding(llvm::dxbc::ShaderVisibility Visibility,
                    llvm::dxil::ResourceClass RC, uint32_t Space,
                    uint32_t LowerBound, uint32_t UpperBound,
                    const hlsl::RootSignatureElement *Elem) {
    uint32_t BuilderIndex = llvm::to_underlying(E: Visibility);
    assert(BuilderIndex < Builders.size() &&
           "Not enough builders for visibility type");
    Builders[BuilderIndex].trackBinding(RC, Space, LowerBound, UpperBound,
                                        Cookie: static_cast<const void *>(Elem));

    static_assert(llvm::to_underlying(E: llvm::dxbc::ShaderVisibility::All) == 0,
                  "'All' visibility must come first");
    if (Visibility == llvm::dxbc::ShaderVisibility::All)
      for (size_t I = 1, E = Builders.size(); I < E; ++I)
        Builders[I].trackBinding(RC, Space, LowerBound, UpperBound,
                                 Cookie: static_cast<const void *>(Elem));

    ElemInfoMap.push_back(Elt: {.Elem: Elem, .Vis: Visibility, .Diagnosed: false});
  }

  // Finds the bookkeeping entry for `Elem` by binary search. Only valid
  // after checkOverlap() has sorted ElemInfoMap by element pointer.
  ElemInfo &getInfo(const hlsl::RootSignatureElement *Elem) {
    auto It = llvm::lower_bound(
        Range&: ElemInfoMap, Value&: Elem,
        C: [](const auto &LHS, const auto &RHS) { return LHS.Elem < RHS; });
    assert(It->Elem == Elem && "Element not in map");
    return *It;
  }

  // Runs overlap analysis on every builder and emits one diagnostic (plus a
  // note at the previous binding) per overlapping element. Returns true if
  // any overlap was found.
  bool checkOverlap() {
    // Sort by pointer so getInfo() can use lower_bound.
    llvm::sort(C&: ElemInfoMap, Comp: [](const auto &LHS, const auto &RHS) {
      return LHS.Elem < RHS.Elem;
    });

    bool HadOverlap = false;

    using llvm::hlsl::BindingInfoBuilder;
    auto ReportOverlap = [this,
                          &HadOverlap](const BindingInfoBuilder &Builder,
                                       const llvm::hlsl::Binding &Reported) {
      HadOverlap = true;

      const auto *Elem =
          static_cast<const hlsl::RootSignatureElement *>(Reported.Cookie);
      const llvm::hlsl::Binding &Previous = Builder.findOverlapping(ReportedBinding: Reported);
      const auto *PrevElem =
          static_cast<const hlsl::RootSignatureElement *>(Previous.Cookie);

      ElemInfo &Info = getInfo(Elem);
      // We will have already diagnosed this binding if there's overlap in the
      // "All" visibility as well as any particular visibility.
      if (Info.Diagnosed)
        return;
      Info.Diagnosed = true;

      // Report the most specific visibility the two bindings share.
      ElemInfo &PrevInfo = getInfo(Elem: PrevElem);
      llvm::dxbc::ShaderVisibility CommonVis =
          Info.Vis == llvm::dxbc::ShaderVisibility::All ? PrevInfo.Vis
                                                        : Info.Vis;

      this->S->Diag(Loc: Elem->getLocation(), DiagID: diag::err_hlsl_resource_range_overlap)
          << llvm::to_underlying(E: Reported.RC) << Reported.LowerBound
          << Reported.isUnbounded() << Reported.UpperBound
          << llvm::to_underlying(E: Previous.RC) << Previous.LowerBound
          << Previous.isUnbounded() << Previous.UpperBound << Reported.Space
          << CommonVis;

      this->S->Diag(Loc: PrevElem->getLocation(),
                    DiagID: diag::note_hlsl_resource_range_here);
    };

    for (BindingInfoBuilder &Builder : Builders)
      Builder.calculateBindingInfo(ReportOverlap);

    return HadOverlap;
  }
};
1406
1407static CXXMethodDecl *lookupMethod(Sema &S, CXXRecordDecl *RecordDecl,
1408 StringRef Name, SourceLocation Loc) {
1409 DeclarationName DeclName(&S.getASTContext().Idents.get(Name));
1410 LookupResult Result(S, DeclName, Loc, Sema::LookupMemberName);
1411 if (!S.LookupQualifiedName(R&: Result, LookupCtx: static_cast<DeclContext *>(RecordDecl)))
1412 return nullptr;
1413 return cast<CXXMethodDecl>(Val: Result.getFoundDecl());
1414}
1415
1416} // end anonymous namespace
1417
1418static bool hasCounterHandle(const CXXRecordDecl *RD) {
1419 if (RD->field_empty())
1420 return false;
1421 auto It = std::next(x: RD->field_begin());
1422 if (It == RD->field_end())
1423 return false;
1424 const FieldDecl *SecondField = *It;
1425 if (const auto *ResTy =
1426 SecondField->getType()->getAs<HLSLAttributedResourceType>()) {
1427 return ResTy->getAttrs().IsCounter;
1428 }
1429 return false;
1430}
1431
// Validates parsed root-signature elements in two passes: per-element value
// and flag range checks, then per-visibility register-range overlap
// analysis. Returns true when overlap errors (or a zero-descriptor clause)
// were found.
// NOTE(review): HadError from the first pass is not folded into the return
// value — only overlap results (and the NumDescriptors==0 shortcut) gate
// it. The Diag calls still fail the compile, but a decl may be created
// despite first-pass errors; confirm this is intended.
bool SemaHLSL::handleRootSignatureElements(
    ArrayRef<hlsl::RootSignatureElement> Elements) {
  // Define some common error handling functions
  bool HadError = false;
  auto ReportError = [this, &HadError](SourceLocation Loc, uint32_t LowerBound,
                                       uint32_t UpperBound) {
    HadError = true;
    this->Diag(Loc, DiagID: diag::err_hlsl_invalid_rootsig_value)
        << LowerBound << UpperBound;
  };

  auto ReportFloatError = [this, &HadError](SourceLocation Loc,
                                            float LowerBound,
                                            float UpperBound) {
    HadError = true;
    this->Diag(Loc, DiagID: diag::err_hlsl_invalid_rootsig_value)
        << llvm::formatv(Fmt: "{0:f}", Vals&: LowerBound).sstr<6>()
        << llvm::formatv(Fmt: "{0:f}", Vals&: UpperBound).sstr<6>();
  };

  auto VerifyRegister = [ReportError](SourceLocation Loc, uint32_t Register) {
    if (!llvm::hlsl::rootsig::verifyRegisterValue(RegisterValue: Register))
      ReportError(Loc, 0, 0xfffffffe);
  };

  auto VerifySpace = [ReportError](SourceLocation Loc, uint32_t Space) {
    if (!llvm::hlsl::rootsig::verifyRegisterSpace(RegisterSpace: Space))
      ReportError(Loc, 0, 0xffffffef);
  };

  // Flag validity depends on the root-signature version in use.
  const uint32_t Version =
      llvm::to_underlying(E: SemaRef.getLangOpts().HLSLRootSigVer);
  const uint32_t VersionEnum = Version - 1;
  auto ReportFlagError = [this, &HadError, VersionEnum](SourceLocation Loc) {
    HadError = true;
    this->Diag(Loc, DiagID: diag::err_hlsl_invalid_rootsig_flag)
        << /*version minor*/ VersionEnum;
  };

  // Iterate through the elements and do basic validations
  for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
    SourceLocation Loc = RootSigElem.getLocation();
    const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
    if (const auto *Descriptor =
            std::get_if<llvm::hlsl::rootsig::RootDescriptor>(ptr: &Elem)) {
      VerifyRegister(Loc, Descriptor->Reg.Number);
      VerifySpace(Loc, Descriptor->Space);

      if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(Version,
                                                         Flags: Descriptor->Flags))
        ReportFlagError(Loc);
    } else if (const auto *Constants =
                   std::get_if<llvm::hlsl::rootsig::RootConstants>(ptr: &Elem)) {
      VerifyRegister(Loc, Constants->Reg.Number);
      VerifySpace(Loc, Constants->Space);
    } else if (const auto *Sampler =
                   std::get_if<llvm::hlsl::rootsig::StaticSampler>(ptr: &Elem)) {
      VerifyRegister(Loc, Sampler->Reg.Number);
      VerifySpace(Loc, Sampler->Space);

      assert(!std::isnan(Sampler->MaxLOD) && !std::isnan(Sampler->MinLOD) &&
             "By construction, parseFloatParam can't produce a NaN from a "
             "float_literal token");

      if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(MaxAnisotropy: Sampler->MaxAnisotropy))
        ReportError(Loc, 0, 16);
      if (!llvm::hlsl::rootsig::verifyMipLODBias(MipLODBias: Sampler->MipLODBias))
        ReportFloatError(Loc, -16.f, 15.99f);
    } else if (const auto *Clause =
                   std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
                       ptr: &Elem)) {
      VerifyRegister(Loc, Clause->Reg.Number);
      VerifySpace(Loc, Clause->Space);

      if (!llvm::hlsl::rootsig::verifyNumDescriptors(NumDescriptors: Clause->NumDescriptors)) {
        // NumDescriptor could techincally be ~0u but that is reserved for
        // unbounded, so the diagnostic will not report that as a valid int
        // value
        ReportError(Loc, 1, 0xfffffffe);
      }

      if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(Version, Type: Clause->Type,
                                                          Flags: Clause->Flags))
        ReportFlagError(Loc);
    }
  }

  // Second pass: track every element's register range and check for
  // overlaps. Table clauses are buffered until their enclosing table (which
  // carries the visibility) is seen.
  PerVisibilityBindingChecker BindingChecker(this);
  SmallVector<std::pair<const llvm::hlsl::rootsig::DescriptorTableClause *,
                        const hlsl::RootSignatureElement *>>
      UnboundClauses;

  for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
    const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
    if (const auto *Descriptor =
            std::get_if<llvm::hlsl::rootsig::RootDescriptor>(ptr: &Elem)) {
      uint32_t LowerBound(Descriptor->Reg.Number);
      uint32_t UpperBound(LowerBound); // inclusive range

      BindingChecker.trackBinding(
          Visibility: Descriptor->Visibility,
          RC: static_cast<llvm::dxil::ResourceClass>(Descriptor->Type),
          Space: Descriptor->Space, LowerBound, UpperBound, Elem: &RootSigElem);
    } else if (const auto *Constants =
                   std::get_if<llvm::hlsl::rootsig::RootConstants>(ptr: &Elem)) {
      uint32_t LowerBound(Constants->Reg.Number);
      uint32_t UpperBound(LowerBound); // inclusive range

      BindingChecker.trackBinding(
          Visibility: Constants->Visibility, RC: llvm::dxil::ResourceClass::CBuffer,
          Space: Constants->Space, LowerBound, UpperBound, Elem: &RootSigElem);
    } else if (const auto *Sampler =
                   std::get_if<llvm::hlsl::rootsig::StaticSampler>(ptr: &Elem)) {
      uint32_t LowerBound(Sampler->Reg.Number);
      uint32_t UpperBound(LowerBound); // inclusive range

      BindingChecker.trackBinding(
          Visibility: Sampler->Visibility, RC: llvm::dxil::ResourceClass::Sampler,
          Space: Sampler->Space, LowerBound, UpperBound, Elem: &RootSigElem);
    } else if (const auto *Clause =
                   std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
                       ptr: &Elem)) {
      // We'll process these once we see the table element.
      UnboundClauses.emplace_back(Args&: Clause, Args: &RootSigElem);
    } else if (const auto *Table =
                   std::get_if<llvm::hlsl::rootsig::DescriptorTable>(ptr: &Elem)) {
      assert(UnboundClauses.size() == Table->NumClauses &&
             "Number of unbound elements must match the number of clauses");
      bool HasAnySampler = false;
      bool HasAnyNonSampler = false;
      uint64_t Offset = 0;
      bool IsPrevUnbound = false;
      for (const auto &[Clause, ClauseElem] : UnboundClauses) {
        SourceLocation Loc = ClauseElem->getLocation();
        // Sampler and non-sampler ranges cannot be mixed in one table.
        if (Clause->Type == llvm::dxil::ResourceClass::Sampler)
          HasAnySampler = true;
        else
          HasAnyNonSampler = true;

        if (HasAnySampler && HasAnyNonSampler)
          Diag(Loc, DiagID: diag::err_hlsl_invalid_mixed_resources);

        // Relevant error will have already been reported above and needs to be
        // fixed before we can conduct further analysis, so shortcut error
        // return
        if (Clause->NumDescriptors == 0)
          return true;

        // An explicit offset resets the running append position.
        bool IsAppending =
            Clause->Offset == llvm::hlsl::rootsig::DescriptorTableOffsetAppend;
        if (!IsAppending)
          Offset = Clause->Offset;

        uint64_t RangeBound = llvm::hlsl::rootsig::computeRangeBound(
            Offset, Size: Clause->NumDescriptors);

        // A range cannot append after an unbounded range, and bounded
        // ranges must not overflow the descriptor offset space.
        if (IsPrevUnbound && IsAppending)
          Diag(Loc, DiagID: diag::err_hlsl_appending_onto_unbound);
        else if (!llvm::hlsl::rootsig::verifyNoOverflowedOffset(Offset: RangeBound))
          Diag(Loc, DiagID: diag::err_hlsl_offset_overflow) << Offset << RangeBound;

        // Update offset to be 1 past this range's bound
        Offset = RangeBound + 1;
        IsPrevUnbound = Clause->NumDescriptors ==
                        llvm::hlsl::rootsig::NumDescriptorsUnbounded;

        // Compute the register bounds and track resource binding
        uint32_t LowerBound(Clause->Reg.Number);
        uint32_t UpperBound = llvm::hlsl::rootsig::computeRangeBound(
            Offset: LowerBound, Size: Clause->NumDescriptors);

        BindingChecker.trackBinding(
            Visibility: Table->Visibility,
            RC: static_cast<llvm::dxil::ResourceClass>(Clause->Type), Space: Clause->Space,
            LowerBound, UpperBound, Elem: ClauseElem);
      }
      UnboundClauses.clear();
    }
  }

  return BindingChecker.checkOverlap();
}
1614
1615void SemaHLSL::handleRootSignatureAttr(Decl *D, const ParsedAttr &AL) {
1616 if (AL.getNumArgs() != 1) {
1617 Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_wrong_number_arguments) << AL << 1;
1618 return;
1619 }
1620
1621 IdentifierInfo *Ident = AL.getArgAsIdent(Arg: 0)->getIdentifierInfo();
1622 if (auto *RS = D->getAttr<RootSignatureAttr>()) {
1623 if (RS->getSignatureIdent() != Ident) {
1624 Diag(Loc: AL.getLoc(), DiagID: diag::err_disallowed_duplicate_attribute) << RS;
1625 return;
1626 }
1627
1628 Diag(Loc: AL.getLoc(), DiagID: diag::warn_duplicate_attribute_exact) << RS;
1629 return;
1630 }
1631
1632 LookupResult R(SemaRef, Ident, SourceLocation(), Sema::LookupOrdinaryName);
1633 if (SemaRef.LookupQualifiedName(R, LookupCtx: D->getDeclContext()))
1634 if (auto *SignatureDecl =
1635 dyn_cast<HLSLRootSignatureDecl>(Val: R.getFoundDecl())) {
1636 D->addAttr(A: ::new (getASTContext()) RootSignatureAttr(
1637 getASTContext(), AL, Ident, SignatureDecl));
1638 }
1639}
1640
1641void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) {
1642 llvm::VersionTuple SMVersion =
1643 getASTContext().getTargetInfo().getTriple().getOSVersion();
1644 bool IsDXIL = getASTContext().getTargetInfo().getTriple().getArch() ==
1645 llvm::Triple::dxil;
1646
1647 uint32_t ZMax = 1024;
1648 uint32_t ThreadMax = 1024;
1649 if (IsDXIL && SMVersion.getMajor() <= 4) {
1650 ZMax = 1;
1651 ThreadMax = 768;
1652 } else if (IsDXIL && SMVersion.getMajor() == 5) {
1653 ZMax = 64;
1654 ThreadMax = 1024;
1655 }
1656
1657 uint32_t X;
1658 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: X))
1659 return;
1660 if (X > 1024) {
1661 Diag(Loc: AL.getArgAsExpr(Arg: 0)->getExprLoc(),
1662 DiagID: diag::err_hlsl_numthreads_argument_oor)
1663 << 0 << 1024;
1664 return;
1665 }
1666 uint32_t Y;
1667 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Y))
1668 return;
1669 if (Y > 1024) {
1670 Diag(Loc: AL.getArgAsExpr(Arg: 1)->getExprLoc(),
1671 DiagID: diag::err_hlsl_numthreads_argument_oor)
1672 << 1 << 1024;
1673 return;
1674 }
1675 uint32_t Z;
1676 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 2), Val&: Z))
1677 return;
1678 if (Z > ZMax) {
1679 SemaRef.Diag(Loc: AL.getArgAsExpr(Arg: 2)->getExprLoc(),
1680 DiagID: diag::err_hlsl_numthreads_argument_oor)
1681 << 2 << ZMax;
1682 return;
1683 }
1684
1685 if (X * Y * Z > ThreadMax) {
1686 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_numthreads_invalid) << ThreadMax;
1687 return;
1688 }
1689
1690 HLSLNumThreadsAttr *NewAttr = mergeNumThreadsAttr(D, AL, X, Y, Z);
1691 if (NewAttr)
1692 D->addAttr(A: NewAttr);
1693}
1694
// A wave size is valid when it is a power of two in the inclusive
// range [4, 128].
static bool isValidWaveSizeValue(unsigned Value) {
  bool InRange = Value >= 4 && Value <= 128;
  bool IsPowerOfTwo = Value != 0 && (Value & (Value - 1)) == 0;
  return InRange && IsPowerOfTwo;
}
1698
// Handles the HLSL [WaveSize(min[, max[, preferred]])] attribute. Evaluates
// each spelled argument as a uint32, validates the most specific spelled
// argument against the power-of-two [4, 128] requirement plus mutual
// consistency, and merges the attribute onto D.
void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) {
  // validate that the wavesize argument is a power of 2 between 4 and 128
  // inclusive
  unsigned SpelledArgsCount = AL.getNumArgs();
  if (SpelledArgsCount == 0 || SpelledArgsCount > 3)
    return;

  // Evaluate the spelled arguments; Max and Preferred stay 0 when omitted.
  uint32_t Min;
  if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: Min))
    return;

  uint32_t Max = 0;
  if (SpelledArgsCount > 1 &&
      !SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Max))
    return;

  uint32_t Preferred = 0;
  if (SpelledArgsCount > 2 &&
      !SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 2), Val&: Preferred))
    return;

  if (SpelledArgsCount > 2) {
    // Three-argument form: Preferred must itself be a valid wave size and
    // must lie within [Min, Max].
    // NOTE(review): Min and Max are not checked against isValidWaveSizeValue
    // in this form -- confirm whether that validation happens elsewhere
    // (e.g. in mergeWaveSizeAttr) or is intentional.
    if (!isValidWaveSizeValue(Value: Preferred)) {
      Diag(Loc: AL.getArgAsExpr(Arg: 2)->getExprLoc(),
           DiagID: diag::err_attribute_power_of_two_in_range)
          << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize
          << Preferred;
      return;
    }
    // Preferred not in range.
    if (Preferred < Min || Preferred > Max) {
      Diag(Loc: AL.getArgAsExpr(Arg: 2)->getExprLoc(),
           DiagID: diag::err_attribute_power_of_two_in_range)
          << AL << Min << Max << Preferred;
      return;
    }
  } else if (SpelledArgsCount > 1) {
    // Two-argument form: Max must be a valid wave size and must not be
    // smaller than Min; equality is allowed but warned about.
    if (!isValidWaveSizeValue(Value: Max)) {
      Diag(Loc: AL.getArgAsExpr(Arg: 1)->getExprLoc(),
           DiagID: diag::err_attribute_power_of_two_in_range)
          << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Max;
      return;
    }
    if (Max < Min) {
      Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_invalid) << AL << 1;
      return;
    } else if (Max == Min) {
      Diag(Loc: AL.getLoc(), DiagID: diag::warn_attr_min_eq_max) << AL;
    }
  } else {
    // One-argument form: the single value must be a valid wave size.
    if (!isValidWaveSizeValue(Value: Min)) {
      Diag(Loc: AL.getArgAsExpr(Arg: 0)->getExprLoc(),
           DiagID: diag::err_attribute_power_of_two_in_range)
          << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Min;
      return;
    }
  }

  // Merge with any existing wave-size attribute; merge failure (e.g. a
  // conflicting earlier attribute) yields a null result and nothing is added.
  HLSLWaveSizeAttr *NewAttr =
      mergeWaveSizeAttr(D, AL, Min, Max, Preferred, SpelledArgsCount);
  if (NewAttr)
    D->addAttr(A: NewAttr);
}
1762
1763void SemaHLSL::handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL) {
1764 uint32_t ID;
1765 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: ID))
1766 return;
1767 D->addAttr(A: ::new (getASTContext())
1768 HLSLVkExtBuiltinInputAttr(getASTContext(), AL, ID));
1769}
1770
1771void SemaHLSL::handleVkPushConstantAttr(Decl *D, const ParsedAttr &AL) {
1772 D->addAttr(A: ::new (getASTContext())
1773 HLSLVkPushConstantAttr(getASTContext(), AL));
1774}
1775
1776void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) {
1777 uint32_t Id;
1778 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: Id))
1779 return;
1780 HLSLVkConstantIdAttr *NewAttr = mergeVkConstantIdAttr(D, AL, Id);
1781 if (NewAttr)
1782 D->addAttr(A: NewAttr);
1783}
1784
1785void SemaHLSL::handleVkBindingAttr(Decl *D, const ParsedAttr &AL) {
1786 uint32_t Binding = 0;
1787 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: Binding))
1788 return;
1789 uint32_t Set = 0;
1790 if (AL.getNumArgs() > 1 &&
1791 !SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Set))
1792 return;
1793
1794 D->addAttr(A: ::new (getASTContext())
1795 HLSLVkBindingAttr(getASTContext(), AL, Binding, Set));
1796}
1797
1798void SemaHLSL::handleVkLocationAttr(Decl *D, const ParsedAttr &AL) {
1799 uint32_t Location;
1800 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: Location))
1801 return;
1802
1803 D->addAttr(A: ::new (getASTContext())
1804 HLSLVkLocationAttr(getASTContext(), AL, Location));
1805}
1806
1807bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) {
1808 const auto *VT = T->getAs<VectorType>();
1809
1810 if (!T->hasUnsignedIntegerRepresentation() ||
1811 (VT && VT->getNumElements() > 3)) {
1812 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type)
1813 << AL << "uint/uint2/uint3";
1814 return false;
1815 }
1816
1817 return true;
1818}
1819
1820bool SemaHLSL::diagnosePositionType(QualType T, const ParsedAttr &AL) {
1821 const auto *VT = T->getAs<VectorType>();
1822 if (!T->hasFloatingRepresentation() || (VT && VT->getNumElements() > 4)) {
1823 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type)
1824 << AL << "float/float1/float2/float3/float4";
1825 return false;
1826 }
1827
1828 return true;
1829}
1830
1831void SemaHLSL::diagnoseSystemSemanticAttr(Decl *D, const ParsedAttr &AL,
1832 std::optional<unsigned> Index) {
1833 std::string SemanticName = AL.getAttrName()->getName().upper();
1834
1835 auto *VD = cast<ValueDecl>(Val: D);
1836 QualType ValueType = VD->getType();
1837 if (auto *FD = dyn_cast<FunctionDecl>(Val: D))
1838 ValueType = FD->getReturnType();
1839
1840 bool IsOutput = false;
1841 if (HLSLParamModifierAttr *MA = D->getAttr<HLSLParamModifierAttr>()) {
1842 if (MA->isOut()) {
1843 IsOutput = true;
1844 ValueType = cast<ReferenceType>(Val&: ValueType)->getPointeeType();
1845 }
1846 }
1847
1848 if (SemanticName == "SV_DISPATCHTHREADID") {
1849 diagnoseInputIDType(T: ValueType, AL);
1850 if (IsOutput)
1851 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_output_not_supported) << AL;
1852 if (Index.has_value())
1853 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_indexing_not_supported) << AL;
1854 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1855 return;
1856 }
1857
1858 if (SemanticName == "SV_GROUPINDEX") {
1859 if (IsOutput)
1860 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_output_not_supported) << AL;
1861 if (Index.has_value())
1862 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_indexing_not_supported) << AL;
1863 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1864 return;
1865 }
1866
1867 if (SemanticName == "SV_GROUPTHREADID") {
1868 diagnoseInputIDType(T: ValueType, AL);
1869 if (IsOutput)
1870 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_output_not_supported) << AL;
1871 if (Index.has_value())
1872 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_indexing_not_supported) << AL;
1873 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1874 return;
1875 }
1876
1877 if (SemanticName == "SV_GROUPID") {
1878 diagnoseInputIDType(T: ValueType, AL);
1879 if (IsOutput)
1880 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_output_not_supported) << AL;
1881 if (Index.has_value())
1882 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_indexing_not_supported) << AL;
1883 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1884 return;
1885 }
1886
1887 if (SemanticName == "SV_POSITION") {
1888 const auto *VT = ValueType->getAs<VectorType>();
1889 if (!ValueType->hasFloatingRepresentation() ||
1890 (VT && VT->getNumElements() > 4))
1891 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type)
1892 << AL << "float/float1/float2/float3/float4";
1893 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1894 return;
1895 }
1896
1897 if (SemanticName == "SV_TARGET") {
1898 const auto *VT = ValueType->getAs<VectorType>();
1899 if (!ValueType->hasFloatingRepresentation() ||
1900 (VT && VT->getNumElements() > 4))
1901 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type)
1902 << AL << "float/float1/float2/float3/float4";
1903 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1904 return;
1905 }
1906
1907 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_unknown_semantic) << AL;
1908}
1909
1910void SemaHLSL::handleSemanticAttr(Decl *D, const ParsedAttr &AL) {
1911 uint32_t IndexValue(0), ExplicitIndex(0);
1912 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: IndexValue) ||
1913 !SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: ExplicitIndex)) {
1914 assert(0 && "HLSLUnparsedSemantic is expected to have 2 int arguments.");
1915 }
1916 assert(IndexValue > 0 ? ExplicitIndex : true);
1917 std::optional<unsigned> Index =
1918 ExplicitIndex ? std::optional<unsigned>(IndexValue) : std::nullopt;
1919
1920 if (AL.getAttrName()->getName().starts_with_insensitive(Prefix: "SV_"))
1921 diagnoseSystemSemanticAttr(D, AL, Index);
1922 else
1923 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1924}
1925
// Handles the HLSL packoffset(c<Component>.<SubComponent>) annotation on a
// constant-buffer member. Validates that the offset keeps the value inside a
// single 128-bit register row and respects element alignment, then attaches
// HLSLPackOffsetAttr.
void SemaHLSL::handlePackOffsetAttr(Decl *D, const ParsedAttr &AL) {
  // packoffset is only meaningful on a variable declared directly inside a
  // cbuffer/tbuffer.
  if (!isa<VarDecl>(Val: D) || !isa<HLSLBufferDecl>(Val: D->getDeclContext())) {
    Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_ast_node)
        << AL << "shader constant in a constant buffer";
    return;
  }

  // The parser supplies two uint32 arguments: the sub-component (x/y/z/w
  // lane) and the register component index.
  uint32_t SubComponent;
  if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: SubComponent))
    return;
  uint32_t Component;
  if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Component))
    return;

  QualType T = cast<VarDecl>(Val: D)->getType().getCanonicalType();
  // Check if T is an array or struct type.
  // TODO: mark matrix type as aggregate type.
  bool IsAggregateTy = (T->isArrayType() || T->isStructureType());

  // Check Component is valid for T.
  if (Component) {
    unsigned Size = getASTContext().getTypeSize(T);
    if (IsAggregateTy) {
      // Aggregates must start at component 0; a non-zero component cannot
      // be honored.
      Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_invalid_register_or_packoffset);
      return;
    } else {
      // Make sure Component + sizeof(T) <= 4.
      // Each component is 32 bits; a register row holds 128 bits total.
      if ((Component * 32 + Size) > 128) {
        Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_packoffset_cross_reg_boundary);
        return;
      }
      QualType EltTy = T;
      if (const auto *VT = T->getAs<VectorType>())
        EltTy = VT->getElementType();
      // A 64-bit element must start on an even component boundary.
      unsigned Align = getASTContext().getTypeAlign(T: EltTy);
      if (Align > 32 && Component == 1) {
        // NOTE: Component 3 will hit err_hlsl_packoffset_cross_reg_boundary.
        // So we only need to check Component 1 here.
        Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_packoffset_alignment_mismatch)
            << Align << EltTy;
        return;
      }
    }
  }

  D->addAttr(A: ::new (getASTContext()) HLSLPackOffsetAttr(
      getASTContext(), AL, SubComponent, Component));
}
1974
1975void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) {
1976 StringRef Str;
1977 SourceLocation ArgLoc;
1978 if (!SemaRef.checkStringLiteralArgumentAttr(Attr: AL, ArgNum: 0, Str, ArgLocation: &ArgLoc))
1979 return;
1980
1981 llvm::Triple::EnvironmentType ShaderType;
1982 if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Val: Str, Out&: ShaderType)) {
1983 Diag(Loc: AL.getLoc(), DiagID: diag::warn_attribute_type_not_supported)
1984 << AL << Str << ArgLoc;
1985 return;
1986 }
1987
1988 // FIXME: check function match the shader stage.
1989
1990 HLSLShaderAttr *NewAttr = mergeShaderAttr(D, AL, ShaderType);
1991 if (NewAttr)
1992 D->addAttr(A: NewAttr);
1993}
1994
// Folds a list of HLSL resource attributes into a single
// HLSLAttributedResourceType wrapping `Wrapped`. Duplicate attributes are
// diagnosed (warning) and abort the fold; a resource class is mandatory.
// Returns true on success, storing the new type in ResType and, when both
// LocInfo and a contained-type attribute are present, the source-location
// info in *LocInfo.
bool clang::CreateHLSLAttributedResourceType(
    Sema &S, QualType Wrapped, ArrayRef<const Attr *> AttrList,
    QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo) {
  assert(AttrList.size() && "expected list of resource attributes");

  QualType ContainedTy = QualType();
  TypeSourceInfo *ContainedTyInfo = nullptr;
  // The reported range spans from the first attribute to the last one
  // actually visited below.
  SourceLocation LocBegin = AttrList[0]->getRange().getBegin();
  SourceLocation LocEnd = AttrList[0]->getRange().getEnd();

  HLSLAttributedResourceType::Attributes ResAttrs;

  bool HasResourceClass = false;
  bool HasResourceDimension = false;
  for (const Attr *A : AttrList) {
    // Null entries (attributes that failed earlier processing) are skipped.
    if (!A)
      continue;
    LocEnd = A->getRange().getEnd();
    switch (A->getKind()) {
    case attr::HLSLResourceClass: {
      ResourceClass RC = cast<HLSLResourceClassAttr>(Val: A)->getResourceClass();
      if (HasResourceClass) {
        // Exact repeats get a softer diagnostic than conflicting repeats,
        // but both abort the fold.
        S.Diag(Loc: A->getLocation(), DiagID: ResAttrs.ResourceClass == RC
                                       ? diag::warn_duplicate_attribute_exact
                                       : diag::warn_duplicate_attribute)
            << A;
        return false;
      }
      ResAttrs.ResourceClass = RC;
      HasResourceClass = true;
      break;
    }
    case attr::HLSLResourceDimension: {
      llvm::dxil::ResourceDimension RD =
          cast<HLSLResourceDimensionAttr>(Val: A)->getDimension();
      if (HasResourceDimension) {
        S.Diag(Loc: A->getLocation(), DiagID: ResAttrs.ResourceDimension == RD
                                       ? diag::warn_duplicate_attribute_exact
                                       : diag::warn_duplicate_attribute)
            << A;
        return false;
      }
      ResAttrs.ResourceDimension = RD;
      HasResourceDimension = true;
      break;
    }
    case attr::HLSLROV:
      // Boolean flags: any repeat is necessarily an exact duplicate.
      if (ResAttrs.IsROV) {
        S.Diag(Loc: A->getLocation(), DiagID: diag::warn_duplicate_attribute_exact) << A;
        return false;
      }
      ResAttrs.IsROV = true;
      break;
    case attr::HLSLRawBuffer:
      if (ResAttrs.RawBuffer) {
        S.Diag(Loc: A->getLocation(), DiagID: diag::warn_duplicate_attribute_exact) << A;
        return false;
      }
      ResAttrs.RawBuffer = true;
      break;
    case attr::HLSLIsCounter:
      if (ResAttrs.IsCounter) {
        S.Diag(Loc: A->getLocation(), DiagID: diag::warn_duplicate_attribute_exact) << A;
        return false;
      }
      ResAttrs.IsCounter = true;
      break;
    case attr::HLSLContainedType: {
      const HLSLContainedTypeAttr *CTAttr = cast<HLSLContainedTypeAttr>(Val: A);
      QualType Ty = CTAttr->getType();
      if (!ContainedTy.isNull()) {
        S.Diag(Loc: A->getLocation(), DiagID: ContainedTy == Ty
                                       ? diag::warn_duplicate_attribute_exact
                                       : diag::warn_duplicate_attribute)
            << A;
        return false;
      }
      ContainedTy = Ty;
      ContainedTyInfo = CTAttr->getTypeLoc();
      break;
    }
    default:
      llvm_unreachable("unhandled resource attribute type");
    }
  }

  if (!HasResourceClass) {
    S.Diag(Loc: AttrList.back()->getRange().getEnd(),
           DiagID: diag::err_hlsl_missing_resource_class);
    return false;
  }

  ResType = S.getASTContext().getHLSLAttributedResourceType(
      Wrapped, Contained: ContainedTy, Attrs: ResAttrs);

  // Location info is only meaningful when a contained type was spelled.
  if (LocInfo && ContainedTyInfo) {
    LocInfo->Range = SourceRange(LocBegin, LocEnd);
    LocInfo->ContainedTyInfo = ContainedTyInfo;
  }
  return true;
}
2096
2097// Validates and creates an HLSL attribute that is applied as type attribute on
2098// HLSL resource. The attributes are collected in HLSLResourcesTypeAttrs and at
2099// the end of the declaration they are applied to the declaration type by
2100// wrapping it in HLSLAttributedResourceType.
// Validates one parsed resource-type attribute applied to an HLSL resource
// handle type and, on success, appends the corresponding semantic attribute
// to HLSLResourcesTypeAttrs for later folding by
// ProcessResourceTypeAttributes. Returns false when the attribute is invalid.
bool SemaHLSL::handleResourceTypeAttr(QualType T, const ParsedAttr &AL) {
  // only allow resource type attributes on intangible types
  if (!T->isHLSLResourceType()) {
    Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attribute_needs_intangible_type)
        << AL << getASTContext().HLSLResourceTy;
    return false;
  }

  // validate number of arguments
  if (!AL.checkExactlyNumArgs(S&: SemaRef, Num: AL.getMinArgs()))
    return false;

  Attr *A = nullptr;

  // These are type attributes with no dedicated Sema handler; build the
  // common info by hand from the parsed spelling.
  AttributeCommonInfo ACI(
      AL.getLoc(), AttributeScopeInfo(AL.getScopeName(), AL.getScopeLoc()),
      AttributeCommonInfo::NoSemaHandlerAttribute,
      {
          AttributeCommonInfo::AS_CXX11, 0, false /*IsAlignas*/,
          false /*IsRegularKeywordAttribute*/
      });

  switch (AL.getKind()) {
  case ParsedAttr::AT_HLSLResourceClass: {
    // Argument must be a bare identifier naming one of the resource
    // classes (SRV/UAV/CBuffer/Sampler).
    if (!AL.isArgIdent(Arg: 0)) {
      Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_type)
          << AL << AANT_ArgumentIdentifier;
      return false;
    }

    IdentifierLoc *Loc = AL.getArgAsIdent(Arg: 0);
    StringRef Identifier = Loc->getIdentifierInfo()->getName();
    SourceLocation ArgLoc = Loc->getLoc();

    // Validate resource class value
    ResourceClass RC;
    if (!HLSLResourceClassAttr::ConvertStrToResourceClass(Val: Identifier, Out&: RC)) {
      Diag(Loc: ArgLoc, DiagID: diag::warn_attribute_type_not_supported)
          << "ResourceClass" << Identifier;
      return false;
    }
    A = HLSLResourceClassAttr::Create(Ctx&: getASTContext(), ResourceClass: RC, CommonInfo: ACI);
    break;
  }

  case ParsedAttr::AT_HLSLROV:
    A = HLSLROVAttr::Create(Ctx&: getASTContext(), CommonInfo: ACI);
    break;

  case ParsedAttr::AT_HLSLRawBuffer:
    A = HLSLRawBufferAttr::Create(Ctx&: getASTContext(), CommonInfo: ACI);
    break;

  case ParsedAttr::AT_HLSLIsCounter:
    A = HLSLIsCounterAttr::Create(Ctx&: getASTContext(), CommonInfo: ACI);
    break;

  case ParsedAttr::AT_HLSLContainedType: {
    // Argument is a type; it must be present and complete.
    if (AL.getNumArgs() != 1 && !AL.hasParsedType()) {
      Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_wrong_number_arguments) << AL << 1;
      return false;
    }

    TypeSourceInfo *TSI = nullptr;
    QualType QT = SemaRef.GetTypeFromParser(Ty: AL.getTypeArg(), TInfo: &TSI);
    assert(TSI && "no type source info for attribute argument");
    if (SemaRef.RequireCompleteType(Loc: TSI->getTypeLoc().getBeginLoc(), T: QT,
                                    DiagID: diag::err_incomplete_type))
      return false;
    A = HLSLContainedTypeAttr::Create(Ctx&: getASTContext(), Type: TSI, CommonInfo: ACI);
    break;
  }

  default:
    llvm_unreachable("unhandled HLSL attribute");
  }

  HLSLResourcesTypeAttrs.emplace_back(Args&: A);
  return true;
}
2181
2182// Combines all resource type attributes and creates HLSLAttributedResourceType.
2183QualType SemaHLSL::ProcessResourceTypeAttributes(QualType CurrentType) {
2184 if (!HLSLResourcesTypeAttrs.size())
2185 return CurrentType;
2186
2187 QualType QT = CurrentType;
2188 HLSLAttributedResourceLocInfo LocInfo;
2189 if (CreateHLSLAttributedResourceType(S&: SemaRef, Wrapped: CurrentType,
2190 AttrList: HLSLResourcesTypeAttrs, ResType&: QT, LocInfo: &LocInfo)) {
2191 const HLSLAttributedResourceType *RT =
2192 cast<HLSLAttributedResourceType>(Val: QT.getTypePtr());
2193
2194 // Temporarily store TypeLoc information for the new type.
2195 // It will be transferred to HLSLAttributesResourceTypeLoc
2196 // shortly after the type is created by TypeSpecLocFiller which
2197 // will call the TakeLocForHLSLAttribute method below.
2198 LocsForHLSLAttributedResources.insert(KV: std::pair(RT, LocInfo));
2199 }
2200 HLSLResourcesTypeAttrs.clear();
2201 return QT;
2202}
2203
2204// Returns source location for the HLSLAttributedResourceType
2205HLSLAttributedResourceLocInfo
2206SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
2207 HLSLAttributedResourceLocInfo LocInfo = {};
2208 auto I = LocsForHLSLAttributedResources.find(Val: RT);
2209 if (I != LocsForHLSLAttributedResources.end()) {
2210 LocInfo = I->second;
2211 LocsForHLSLAttributedResources.erase(I);
2212 return LocInfo;
2213 }
2214 LocInfo.Range = SourceRange();
2215 return LocInfo;
2216}
2217
2218// Walks though the global variable declaration, collects all resource binding
2219// requirements and adds them to Bindings
void SemaHLSL::collectResourceBindingsOnUserRecordDecl(const VarDecl *VD,
                                                       const RecordType *RT) {
  // Scan every field of the record (including through constant arrays) for
  // HLSL resource handles; each distinct resource class found contributes a
  // DeclBindingInfo for VD.
  const RecordDecl *RD = RT->getDecl()->getDefinitionOrSelf();
  for (FieldDecl *FD : RD->fields()) {
    const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();

    // Unwrap arrays
    // FIXME: Calculate array size while unwrapping
    assert(!Ty->isIncompleteArrayType() &&
           "incomplete arrays inside user defined types are not supported");
    while (Ty->isConstantArrayType()) {
      const ConstantArrayType *CAT = cast<ConstantArrayType>(Val: Ty);
      Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
    }

    // Only record types can be (or contain) resources; scalars, vectors,
    // etc. are plain constant-buffer data and are skipped here.
    if (!Ty->isRecordType())
      continue;

    if (const HLSLAttributedResourceType *AttrResType =
            HLSLAttributedResourceType::findHandleTypeOnResource(RT: Ty)) {
      // Add a new DeclBindingInfo to Bindings if it does not already exist
      ResourceClass RC = AttrResType->getAttrs().ResourceClass;
      DeclBindingInfo *DBI = Bindings.getDeclBindingInfo(VD, ResClass: RC);
      if (!DBI)
        Bindings.addDeclBindingInfo(VD, ResClass: RC);
    } else if (const RecordType *RT = dyn_cast<RecordType>(Val: Ty)) {
      // Recursively scan embedded struct or class; it would be nice to do this
      // without recursion, but tricky to correctly calculate the size of the
      // binding, which is something we are probably going to need to do later
      // on. Hopefully nesting of structs in structs too many levels is
      // unlikely.
      collectResourceBindingsOnUserRecordDecl(VD, RT);
    }
  }
}
2255
2256// Diagnose localized register binding errors for a single binding; does not
2257// diagnose resource binding on user record types, that will be done later
2258// in processResourceBindingOnDecl based on the information collected in
2259// collectResourceBindingsOnVarDecl.
2260// Returns false if the register binding is not valid.
static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
                                         Decl *D, RegisterType RegType,
                                         bool SpecifiedSpace) {
  // The numeric value is streamed into diagnostics to select the register
  // letter in the message text.
  int RegTypeNum = static_cast<int>(RegType);

  // check if the decl type is groupshared
  if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
    S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
    return false;
  }

  // Cbuffers and Tbuffers are HLSLBufferDecl types
  if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(Val: D)) {
    // cbuffer binds with 'b', tbuffer with 't' (SRV).
    ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer
                                                     : ResourceClass::SRV;
    if (RegType == getRegisterType(RC))
      return true;

    S.Diag(Loc: D->getLocation(), DiagID: diag::err_hlsl_binding_type_mismatch)
        << RegTypeNum;
    return false;
  }

  // Samplers, UAVs, and SRVs are VarDecl types
  assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
  VarDecl *VD = cast<VarDecl>(Val: D);

  // Resource
  if (const HLSLAttributedResourceType *AttrResType =
          HLSLAttributedResourceType::findHandleTypeOnResource(
              RT: VD->getType().getTypePtr())) {
    // The register letter must match the resource's class.
    if (RegType == getRegisterType(ResTy: AttrResType))
      return true;

    S.Diag(Loc: D->getLocation(), DiagID: diag::err_hlsl_binding_type_mismatch)
        << RegTypeNum;
    return false;
  }

  // Strip array dimensions so the element type drives the checks below.
  const clang::Type *Ty = VD->getType().getTypePtr();
  while (Ty->isArrayType())
    Ty = Ty->getArrayElementTypeNoTypeQual();

  // Basic types
  if (Ty->isArithmeticType() || Ty->isVectorType()) {
    bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(Val: D->getDeclContext());
    // A space is only meaningful on buffer members; on a loose global
    // constant it is an error (diagnosed in addition to any mismatch below).
    if (SpecifiedSpace && !DeclaredInCOrTBuffer)
      S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_space_on_global_constant);

    if (!DeclaredInCOrTBuffer && (Ty->isIntegralType(Ctx: S.getASTContext()) ||
                                  Ty->isFloatingType() || Ty->isVectorType())) {
      // Register annotation on default constant buffer declaration ($Globals)
      if (RegType == RegisterType::CBuffer)
        S.Diag(Loc: ArgLoc, DiagID: diag::warn_hlsl_deprecated_register_type_b);
      else if (RegType != RegisterType::C)
        S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
      else
        return true;
    } else {
      // Inside a cbuffer/tbuffer, 'c' registers are legacy packoffset-style
      // bindings (warn); any other letter is a mismatch.
      if (RegType == RegisterType::C)
        S.Diag(Loc: ArgLoc, DiagID: diag::warn_hlsl_register_type_c_packoffset);
      else
        S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
    }
    return false;
  }
  if (Ty->isRecordType())
    // RecordTypes will be diagnosed in processResourceBindingOnDecl
    // that is called from ActOnVariableDeclarator
    return true;

  // Anything else is an error
  S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
  return false;
}
2336
2337static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
2338 RegisterType regType) {
2339 // make sure that there are no two register annotations
2340 // applied to the decl with the same register type
2341 bool RegisterTypesDetected[5] = {false};
2342 RegisterTypesDetected[static_cast<int>(regType)] = true;
2343
2344 for (auto it = TheDecl->attr_begin(); it != TheDecl->attr_end(); ++it) {
2345 if (HLSLResourceBindingAttr *attr =
2346 dyn_cast<HLSLResourceBindingAttr>(Val: *it)) {
2347
2348 RegisterType otherRegType = attr->getRegisterType();
2349 if (RegisterTypesDetected[static_cast<int>(otherRegType)]) {
2350 int otherRegTypeNum = static_cast<int>(otherRegType);
2351 S.Diag(Loc: TheDecl->getLocation(),
2352 DiagID: diag::err_hlsl_duplicate_register_annotation)
2353 << otherRegTypeNum;
2354 return false;
2355 }
2356 RegisterTypesDetected[static_cast<int>(otherRegType)] = true;
2357 }
2358 }
2359 return true;
2360}
2361
2362static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
2363 Decl *D, RegisterType RegType,
2364 bool SpecifiedSpace) {
2365
2366 // exactly one of these two types should be set
2367 assert(((isa<VarDecl>(D) && !isa<HLSLBufferDecl>(D)) ||
2368 (!isa<VarDecl>(D) && isa<HLSLBufferDecl>(D))) &&
2369 "expecting VarDecl or HLSLBufferDecl");
2370
2371 // check if the declaration contains resource matching the register type
2372 if (!DiagnoseLocalRegisterBinding(S, ArgLoc, D, RegType, SpecifiedSpace))
2373 return false;
2374
2375 // next, if multiple register annotations exist, check that none conflict.
2376 return ValidateMultipleRegisterAnnotations(S, TheDecl: D, regType: RegType);
2377}
2378
2379// return false if the slot count exceeds the limit, true otherwise
// return false if the slot count exceeds the limit, true otherwise
//
// Recursively walks Ty (through arrays, resource leaves, and record fields)
// and advances StartSlot past every register consumed by resources of class
// ResClass, multiplying by the enclosing array extents (ArrayCount). On
// return, StartSlot is the first slot after the consumed range.
// NOTE(review): Ctx is threaded through the recursion but never read here --
// confirm whether it is needed or can be dropped from the signature.
static bool AccumulateHLSLResourceSlots(QualType Ty, uint64_t &StartSlot,
                                        const uint64_t &Limit,
                                        const ResourceClass ResClass,
                                        ASTContext &Ctx,
                                        uint64_t ArrayCount = 1) {
  Ty = Ty.getCanonicalType();
  const Type *T = Ty.getTypePtr();

  // Early exit if already overflowed
  if (StartSlot > Limit)
    return false;

  // Case 1: array type
  if (const auto *AT = dyn_cast<ArrayType>(Val: T)) {
    // Non-constant arrays (e.g. incomplete) contribute a factor of 1.
    uint64_t Count = 1;

    if (const auto *CAT = dyn_cast<ConstantArrayType>(Val: AT))
      Count = CAT->getSize().getZExtValue();

    QualType ElemTy = AT->getElementType();
    return AccumulateHLSLResourceSlots(Ty: ElemTy, StartSlot, Limit, ResClass, Ctx,
                                       ArrayCount: ArrayCount * Count);
  }

  // Case 2: resource leaf
  if (auto ResTy = dyn_cast<HLSLAttributedResourceType>(Val: T)) {
    // First ensure this resource counts towards the corresponding
    // register type limit.
    if (ResTy->getAttrs().ResourceClass != ResClass)
      return true;

    // Validate highest slot used
    uint64_t EndSlot = StartSlot + ArrayCount - 1;
    if (EndSlot > Limit)
      return false;

    // Advance SlotCount past the consumed range
    StartSlot = EndSlot + 1;
    return true;
  }

  // Case 3: struct / record
  if (const auto *RT = dyn_cast<RecordType>(Val: T)) {
    const RecordDecl *RD = RT->getDecl();

    // Base-class subobjects consume slots before the record's own fields.
    if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(Val: RD)) {
      for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
        if (!AccumulateHLSLResourceSlots(Ty: Base.getType(), StartSlot, Limit,
                                         ResClass, Ctx, ArrayCount))
          return false;
      }
    }

    for (const FieldDecl *Field : RD->fields()) {
      if (!AccumulateHLSLResourceSlots(Ty: Field->getType(), StartSlot, Limit,
                                       ResClass, Ctx, ArrayCount))
        return false;
    }

    return true;
  }

  // Case 4: everything else
  // Non-resource, non-record types consume no slots.
  return true;
}
2445
2446// return true if there is something invalid, false otherwise
static bool ValidateRegisterNumber(uint64_t SlotNum, Decl *TheDecl,
                                   ASTContext &Ctx, RegisterType RegTy) {
  // Slot numbers, including the highest slot of a resource array, must fit
  // in a uint32.
  const uint64_t Limit = UINT32_MAX;
  if (SlotNum > Limit)
    return true;

  // after verifying the number doesn't exceed uint32max, we don't need
  // to look further into c or i register types
  if (RegTy == RegisterType::C || RegTy == RegisterType::I)
    return false;

  if (VarDecl *VD = dyn_cast<VarDecl>(Val: TheDecl)) {
    uint64_t BaseSlot = SlotNum;

    // Walk the type to make sure the highest consumed slot stays within
    // Limit; SlotNum is advanced past the consumed range as a side effect.
    if (!AccumulateHLSLResourceSlots(Ty: VD->getType(), StartSlot&: SlotNum, Limit,
                                     ResClass: getResourceClass(RT: RegTy), Ctx))
      return true;

    // After AccumulateHLSLResourceSlots runs, SlotNum is now
    // the first free slot; last used was SlotNum - 1
    // NOTE(review): BaseSlot was already checked against Limit at entry, so
    // this condition is always false (overflow is reported by the call
    // above). Presumably the advanced SlotNum was meant -- confirm intent.
    return (BaseSlot > Limit);
  }
  // handle the cbuffer/tbuffer case
  if (isa<HLSLBufferDecl>(Val: TheDecl))
    // resources cannot be put within a cbuffer, so no need
    // to analyze the structure since the register number
    // won't be pushed any higher.
    return (SlotNum > Limit);

  // we don't expect any other decl type, so fail
  llvm_unreachable("unexpected decl type");
}
2479
// Handle the HLSL 'register(...)' annotation: parse the slot and/or space
// identifier arguments, validate them, and attach an HLSLResourceBindingAttr
// to the declaration. On any error a diagnostic is emitted and no attribute
// is added.
void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
  // Resource variables must have a complete type (for unbounded arrays, a
  // complete element type) before the binding can be validated.
  if (VarDecl *VD = dyn_cast<VarDecl>(Val: TheDecl)) {
    QualType Ty = VD->getType();
    if (const auto *IAT = dyn_cast<IncompleteArrayType>(Val&: Ty))
      Ty = IAT->getElementType();
    if (SemaRef.RequireCompleteType(Loc: TheDecl->getBeginLoc(), T: Ty,
                                    DiagID: diag::err_incomplete_type))
      return;
  }

  StringRef Slot = "";
  StringRef Space = "";
  SourceLocation SlotLoc, SpaceLoc;

  // The first argument must be an identifier (e.g. "t0" or "space1").
  if (!AL.isArgIdent(Arg: 0)) {
    Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_type)
        << AL << AANT_ArgumentIdentifier;
    return;
  }
  IdentifierLoc *Loc = AL.getArgAsIdent(Arg: 0);

  if (AL.getNumArgs() == 2) {
    // register(slot, space): both arguments must be identifiers.
    Slot = Loc->getIdentifierInfo()->getName();
    SlotLoc = Loc->getLoc();
    if (!AL.isArgIdent(Arg: 1)) {
      Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_type)
          << AL << AANT_ArgumentIdentifier;
      return;
    }
    Loc = AL.getArgAsIdent(Arg: 1);
    Space = Loc->getIdentifierInfo()->getName();
    SpaceLoc = Loc->getLoc();
  } else {
    // Single argument: either just a space ("spaceN") or just a slot, in
    // which case the space defaults to "space0".
    StringRef Str = Loc->getIdentifierInfo()->getName();
    if (Str.starts_with(Prefix: "space")) {
      Space = Str;
      SpaceLoc = Loc->getLoc();
    } else {
      Slot = Str;
      SlotLoc = Loc->getLoc();
      Space = "space0";
    }
  }

  RegisterType RegType = RegisterType::SRV;
  std::optional<unsigned> SlotNum;
  unsigned SpaceNum = 0;

  // Validate slot
  if (!Slot.empty()) {
    // The leading character of the slot encodes the register type.
    if (!convertToRegisterType(Slot, RT: &RegType)) {
      Diag(Loc: SlotLoc, DiagID: diag::err_hlsl_binding_type_invalid) << Slot.substr(Start: 0, N: 1);
      return;
    }
    // 'i' registers are deprecated; warn and drop the binding entirely.
    if (RegType == RegisterType::I) {
      Diag(Loc: SlotLoc, DiagID: diag::warn_hlsl_deprecated_register_type_i);
      return;
    }
    const StringRef SlotNumStr = Slot.substr(Start: 1);

    uint64_t N;

    // validate that the slot number is a non-empty number
    if (SlotNumStr.getAsInteger(Radix: 10, Result&: N)) {
      Diag(Loc: SlotLoc, DiagID: diag::err_hlsl_unsupported_register_number);
      return;
    }

    // Validate register number. It should not exceed UINT32_MAX,
    // including if the resource type is an array that starts
    // before UINT32_MAX, but ends afterwards.
    if (ValidateRegisterNumber(SlotNum: N, TheDecl, Ctx&: getASTContext(), RegTy: RegType)) {
      Diag(Loc: SlotLoc, DiagID: diag::err_hlsl_register_number_too_large);
      return;
    }

    // the slot number has been validated and does not exceed UINT32_MAX
    SlotNum = (unsigned)N;
  }

  // Validate space: it must be of the form "spaceN" with N a decimal number.
  if (!Space.starts_with(Prefix: "space")) {
    Diag(Loc: SpaceLoc, DiagID: diag::err_hlsl_expected_space) << Space;
    return;
  }
  StringRef SpaceNumStr = Space.substr(Start: 5);
  if (SpaceNumStr.getAsInteger(Radix: 10, Result&: SpaceNum)) {
    Diag(Loc: SpaceLoc, DiagID: diag::err_hlsl_expected_space) << Space;
    return;
  }

  // If we have slot, diagnose it is the right register type for the decl
  if (SlotNum.has_value())
    if (!DiagnoseHLSLRegisterAttribute(S&: SemaRef, ArgLoc&: SlotLoc, D: TheDecl, RegType,
                                       SpecifiedSpace: !SpaceLoc.isInvalid()))
      return;

  HLSLResourceBindingAttr *NewAttr =
      HLSLResourceBindingAttr::Create(Ctx&: getASTContext(), Slot, Space, CommonInfo: AL);
  if (NewAttr) {
    NewAttr->setBinding(RT: RegType, SlotNum, SpaceNum);
    TheDecl->addAttr(A: NewAttr);
  }
}
2584
2585void SemaHLSL::handleParamModifierAttr(Decl *D, const ParsedAttr &AL) {
2586 HLSLParamModifierAttr *NewAttr = mergeParamModifierAttr(
2587 D, AL,
2588 Spelling: static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling()));
2589 if (NewAttr)
2590 D->addAttr(A: NewAttr);
2591}
2592
2593namespace {
2594
/// This class implements HLSL availability diagnostics for default
/// and relaxed mode
///
/// The goal of this diagnostic is to emit an error or warning when an
/// unavailable API is found in code that is reachable from the shader
/// entry function or from an exported function (when compiling a shader
/// library).
///
/// This is done by traversing the AST of all shader entry point functions
/// and of all exported functions, and any functions that are referenced
/// from this AST. In other words, any functions that are reachable from
/// the entry points.
class DiagnoseHLSLAvailability : public DynamicRecursiveASTVisitor {
  Sema &SemaRef;

  // Stack of functions to be scanned
  llvm::SmallVector<const FunctionDecl *, 8> DeclsToScan;

  // Tracks which environments functions have been scanned in.
  //
  // Maps FunctionDecl to an unsigned number that represents the set of shader
  // environments the function has been scanned for.
  // The llvm::Triple::EnvironmentType enum values for shader stages are
  // guaranteed to be numbered from llvm::Triple::Pixel to
  // llvm::Triple::Amplification (verified by static_asserts in Triple.cpp),
  // so we can use it to index individual bits in the set, as long as we shift
  // the values to start with 0 by subtracting the value of
  // llvm::Triple::Pixel first.
  //
  // The N'th bit in the set will be set if the function has been scanned
  // in shader environment whose llvm::Triple::EnvironmentType integer value
  // equals (llvm::Triple::Pixel + N).
  //
  // For example, if a function has been scanned in compute and pixel stage
  // environment, the value will be 0x21 (100001 binary) because:
  //
  //   (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0
  //   (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5
  //
  // A FunctionDecl is mapped to 0 (or not included in the map) if it has not
  // been scanned in any environment.
  llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls;

  // Do not access these directly, use the get/set methods below to make
  // sure the values are in sync
  llvm::Triple::EnvironmentType CurrentShaderEnvironment;
  unsigned CurrentShaderStageBit;

  // True if scanning a function that was already scanned in a different
  // shader stage context, and therefore we should not report issues that
  // depend only on shader model version because they would be duplicate.
  bool ReportOnlyShaderStageIssues;

  // Helper methods for dealing with current stage context / environment

  // Set the current environment to a known shader stage and compute the
  // corresponding bit for the ScannedDecls bitmap.
  void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) {
    static_assert(sizeof(unsigned) >= 4);
    assert(HLSLShaderAttr::isValidShaderType(ShaderType));
    assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 &&
           "ShaderType is too big for this bitmap"); // 31 is reserved for
                                                     // "unknown"

    unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel;
    CurrentShaderEnvironment = ShaderType;
    CurrentShaderStageBit = (1 << bitmapIndex);
  }

  // Enter the "unknown stage" context (used for exported library functions);
  // bit 31 of the bitmap is reserved for it.
  void SetUnknownShaderStageContext() {
    CurrentShaderEnvironment = llvm::Triple::UnknownEnvironment;
    CurrentShaderStageBit = (1 << 31);
  }

  llvm::Triple::EnvironmentType GetCurrentShaderEnvironment() const {
    return CurrentShaderEnvironment;
  }

  bool InUnknownShaderStageContext() const {
    return CurrentShaderEnvironment == llvm::Triple::UnknownEnvironment;
  }

  // Helper methods for dealing with shader stage bitmap

  // Record that FD has been scanned in the current stage context.
  void AddToScannedFunctions(const FunctionDecl *FD) {
    unsigned &ScannedStages = ScannedDecls[FD];
    ScannedStages |= CurrentShaderStageBit;
  }

  unsigned GetScannedStages(const FunctionDecl *FD) { return ScannedDecls[FD]; }

  bool WasAlreadyScannedInCurrentStage(const FunctionDecl *FD) {
    return WasAlreadyScannedInCurrentStage(ScannerStages: GetScannedStages(FD));
  }

  bool WasAlreadyScannedInCurrentStage(unsigned ScannerStages) {
    return ScannerStages & CurrentShaderStageBit;
  }

  static bool NeverBeenScanned(unsigned ScannedStages) {
    return ScannedStages == 0;
  }

  // Scanning methods
  void HandleFunctionOrMethodRef(FunctionDecl *FD, Expr *RefExpr);
  void CheckDeclAvailability(NamedDecl *D, const AvailabilityAttr *AA,
                             SourceRange Range);
  const AvailabilityAttr *FindAvailabilityAttr(const Decl *D);
  bool HasMatchingEnvironmentOrNone(const AvailabilityAttr *AA);

public:
  DiagnoseHLSLAvailability(Sema &SemaRef)
      : SemaRef(SemaRef),
        CurrentShaderEnvironment(llvm::Triple::UnknownEnvironment),
        CurrentShaderStageBit(0), ReportOnlyShaderStageIssues(false) {}

  // AST traversal methods
  void RunOnTranslationUnit(const TranslationUnitDecl *TU);
  void RunOnFunction(const FunctionDecl *FD);

  // Follow references to functions (direct calls, address-of, etc.).
  bool VisitDeclRefExpr(DeclRefExpr *DRE) override {
    FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(Val: DRE->getDecl());
    if (FD)
      HandleFunctionOrMethodRef(FD, RefExpr: DRE);
    return true;
  }

  // Follow references to methods (member calls).
  bool VisitMemberExpr(MemberExpr *ME) override {
    FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(Val: ME->getMemberDecl());
    if (FD)
      HandleFunctionOrMethodRef(FD, RefExpr: ME);
    return true;
  }
};
2724
2725void DiagnoseHLSLAvailability::HandleFunctionOrMethodRef(FunctionDecl *FD,
2726 Expr *RefExpr) {
2727 assert((isa<DeclRefExpr>(RefExpr) || isa<MemberExpr>(RefExpr)) &&
2728 "expected DeclRefExpr or MemberExpr");
2729
2730 // has a definition -> add to stack to be scanned
2731 const FunctionDecl *FDWithBody = nullptr;
2732 if (FD->hasBody(Definition&: FDWithBody)) {
2733 if (!WasAlreadyScannedInCurrentStage(FD: FDWithBody))
2734 DeclsToScan.push_back(Elt: FDWithBody);
2735 return;
2736 }
2737
2738 // no body -> diagnose availability
2739 const AvailabilityAttr *AA = FindAvailabilityAttr(D: FD);
2740 if (AA)
2741 CheckDeclAvailability(
2742 D: FD, AA, Range: SourceRange(RefExpr->getBeginLoc(), RefExpr->getEndLoc()));
2743}
2744
2745void DiagnoseHLSLAvailability::RunOnTranslationUnit(
2746 const TranslationUnitDecl *TU) {
2747
2748 // Iterate over all shader entry functions and library exports, and for those
2749 // that have a body (definiton), run diag scan on each, setting appropriate
2750 // shader environment context based on whether it is a shader entry function
2751 // or an exported function. Exported functions can be in namespaces and in
2752 // export declarations so we need to scan those declaration contexts as well.
2753 llvm::SmallVector<const DeclContext *, 8> DeclContextsToScan;
2754 DeclContextsToScan.push_back(Elt: TU);
2755
2756 while (!DeclContextsToScan.empty()) {
2757 const DeclContext *DC = DeclContextsToScan.pop_back_val();
2758 for (auto &D : DC->decls()) {
2759 // do not scan implicit declaration generated by the implementation
2760 if (D->isImplicit())
2761 continue;
2762
2763 // for namespace or export declaration add the context to the list to be
2764 // scanned later
2765 if (llvm::dyn_cast<NamespaceDecl>(Val: D) || llvm::dyn_cast<ExportDecl>(Val: D)) {
2766 DeclContextsToScan.push_back(Elt: llvm::dyn_cast<DeclContext>(Val: D));
2767 continue;
2768 }
2769
2770 // skip over other decls or function decls without body
2771 const FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(Val: D);
2772 if (!FD || !FD->isThisDeclarationADefinition())
2773 continue;
2774
2775 // shader entry point
2776 if (HLSLShaderAttr *ShaderAttr = FD->getAttr<HLSLShaderAttr>()) {
2777 SetShaderStageContext(ShaderAttr->getType());
2778 RunOnFunction(FD);
2779 continue;
2780 }
2781 // exported library function
2782 // FIXME: replace this loop with external linkage check once issue #92071
2783 // is resolved
2784 bool isExport = FD->isInExportDeclContext();
2785 if (!isExport) {
2786 for (const auto *Redecl : FD->redecls()) {
2787 if (Redecl->isInExportDeclContext()) {
2788 isExport = true;
2789 break;
2790 }
2791 }
2792 }
2793 if (isExport) {
2794 SetUnknownShaderStageContext();
2795 RunOnFunction(FD);
2796 continue;
2797 }
2798 }
2799 }
2800}
2801
2802void DiagnoseHLSLAvailability::RunOnFunction(const FunctionDecl *FD) {
2803 assert(DeclsToScan.empty() && "DeclsToScan should be empty");
2804 DeclsToScan.push_back(Elt: FD);
2805
2806 while (!DeclsToScan.empty()) {
2807 // Take one decl from the stack and check it by traversing its AST.
2808 // For any CallExpr found during the traversal add it's callee to the top of
2809 // the stack to be processed next. Functions already processed are stored in
2810 // ScannedDecls.
2811 const FunctionDecl *FD = DeclsToScan.pop_back_val();
2812
2813 // Decl was already scanned
2814 const unsigned ScannedStages = GetScannedStages(FD);
2815 if (WasAlreadyScannedInCurrentStage(ScannerStages: ScannedStages))
2816 continue;
2817
2818 ReportOnlyShaderStageIssues = !NeverBeenScanned(ScannedStages);
2819
2820 AddToScannedFunctions(FD);
2821 TraverseStmt(S: FD->getBody());
2822 }
2823}
2824
2825bool DiagnoseHLSLAvailability::HasMatchingEnvironmentOrNone(
2826 const AvailabilityAttr *AA) {
2827 const IdentifierInfo *IIEnvironment = AA->getEnvironment();
2828 if (!IIEnvironment)
2829 return true;
2830
2831 llvm::Triple::EnvironmentType CurrentEnv = GetCurrentShaderEnvironment();
2832 if (CurrentEnv == llvm::Triple::UnknownEnvironment)
2833 return false;
2834
2835 llvm::Triple::EnvironmentType AttrEnv =
2836 AvailabilityAttr::getEnvironmentType(Environment: IIEnvironment->getName());
2837
2838 return CurrentEnv == AttrEnv;
2839}
2840
2841const AvailabilityAttr *
2842DiagnoseHLSLAvailability::FindAvailabilityAttr(const Decl *D) {
2843 AvailabilityAttr const *PartialMatch = nullptr;
2844 // Check each AvailabilityAttr to find the one for this platform.
2845 // For multiple attributes with the same platform try to find one for this
2846 // environment.
2847 for (const auto *A : D->attrs()) {
2848 if (const auto *Avail = dyn_cast<AvailabilityAttr>(Val: A)) {
2849 StringRef AttrPlatform = Avail->getPlatform()->getName();
2850 StringRef TargetPlatform =
2851 SemaRef.getASTContext().getTargetInfo().getPlatformName();
2852
2853 // Match the platform name.
2854 if (AttrPlatform == TargetPlatform) {
2855 // Find the best matching attribute for this environment
2856 if (HasMatchingEnvironmentOrNone(AA: Avail))
2857 return Avail;
2858 PartialMatch = Avail;
2859 }
2860 }
2861 }
2862 return PartialMatch;
2863}
2864
// Check availability against target shader model version and current shader
// stage and emit diagnostic
void DiagnoseHLSLAvailability::CheckDeclAvailability(NamedDecl *D,
                                                     const AvailabilityAttr *AA,
                                                     SourceRange Range) {

  const IdentifierInfo *IIEnv = AA->getEnvironment();

  if (!IIEnv) {
    // The availability attribute does not have environment -> it depends only
    // on shader model version and not on specific the shader stage.

    // Skip emitting the diagnostics if the diagnostic mode is set to
    // strict (-fhlsl-strict-availability) because all relevant diagnostics
    // were already emitted in the DiagnoseUnguardedAvailability scan
    // (SemaAvailability.cpp).
    if (SemaRef.getLangOpts().HLSLStrictAvailability)
      return;

    // Do not report shader-stage-independent issues if scanning a function
    // that was already scanned in a different shader stage context (they would
    // be duplicate)
    if (ReportOnlyShaderStageIssues)
      return;

  } else {
    // The availability attribute has environment -> we need to know
    // the current stage context to property diagnose it.
    if (InUnknownShaderStageContext())
      return;
  }

  // Check introduced version and if environment matches; nothing to report
  // when the declaration is available for this target and stage.
  bool EnvironmentMatches = HasMatchingEnvironmentOrNone(AA);
  VersionTuple Introduced = AA->getIntroduced();
  VersionTuple TargetVersion =
      SemaRef.Context.getTargetInfo().getPlatformMinVersion();

  if (TargetVersion >= Introduced && EnvironmentMatches)
    return;

  // Emit diagnostic message: a "not yet introduced" warning when the
  // environment matches but the version is too old, otherwise an
  // "unavailable in this stage" warning.
  const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
  llvm::StringRef PlatformName(
      AvailabilityAttr::getPrettyPlatformName(Platform: TI.getPlatformName()));

  llvm::StringRef CurrentEnvStr =
      llvm::Triple::getEnvironmentTypeName(Kind: GetCurrentShaderEnvironment());

  llvm::StringRef AttrEnvStr =
      AA->getEnvironment() ? AA->getEnvironment()->getName() : "";
  bool UseEnvironment = !AttrEnvStr.empty();

  if (EnvironmentMatches) {
    SemaRef.Diag(Loc: Range.getBegin(), DiagID: diag::warn_hlsl_availability)
        << Range << D << PlatformName << Introduced.getAsString()
        << UseEnvironment << CurrentEnvStr;
  } else {
    SemaRef.Diag(Loc: Range.getBegin(), DiagID: diag::warn_hlsl_availability_unavailable)
        << Range << D;
  }

  // Note pointing at the availability attribute itself.
  SemaRef.Diag(Loc: D->getLocation(), DiagID: diag::note_partial_availability_specified_here)
      << D << PlatformName << Introduced.getAsString()
      << SemaRef.Context.getTargetInfo().getPlatformMinVersion().getAsString()
      << UseEnvironment << AttrEnvStr << CurrentEnvStr;
}
2932
2933} // namespace
2934
2935void SemaHLSL::ActOnEndOfTranslationUnit(TranslationUnitDecl *TU) {
2936 // process default CBuffer - create buffer layout struct and invoke codegenCGH
2937 if (!DefaultCBufferDecls.empty()) {
2938 HLSLBufferDecl *DefaultCBuffer = HLSLBufferDecl::CreateDefaultCBuffer(
2939 C&: SemaRef.getASTContext(), LexicalParent: SemaRef.getCurLexicalContext(),
2940 DefaultCBufferDecls);
2941 addImplicitBindingAttrToDecl(S&: SemaRef, D: DefaultCBuffer, RT: RegisterType::CBuffer,
2942 ImplicitBindingOrderID: getNextImplicitBindingOrderID());
2943 SemaRef.getCurLexicalContext()->addDecl(D: DefaultCBuffer);
2944 createHostLayoutStructForBuffer(S&: SemaRef, BufDecl: DefaultCBuffer);
2945
2946 // Set HasValidPackoffset if any of the decls has a register(c#) annotation;
2947 for (const Decl *VD : DefaultCBufferDecls) {
2948 const HLSLResourceBindingAttr *RBA =
2949 VD->getAttr<HLSLResourceBindingAttr>();
2950 if (RBA && RBA->hasRegisterSlot() &&
2951 RBA->getRegisterType() == HLSLResourceBindingAttr::RegisterType::C) {
2952 DefaultCBuffer->setHasValidPackoffset(true);
2953 break;
2954 }
2955 }
2956
2957 DeclGroupRef DG(DefaultCBuffer);
2958 SemaRef.Consumer.HandleTopLevelDecl(D: DG);
2959 }
2960 diagnoseAvailabilityViolations(TU);
2961}
2962
2963void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
2964 // Skip running the diagnostics scan if the diagnostic mode is
2965 // strict (-fhlsl-strict-availability) and the target shader stage is known
2966 // because all relevant diagnostics were already emitted in the
2967 // DiagnoseUnguardedAvailability scan (SemaAvailability.cpp).
2968 const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
2969 if (SemaRef.getLangOpts().HLSLStrictAvailability &&
2970 TI.getTriple().getEnvironment() != llvm::Triple::EnvironmentType::Library)
2971 return;
2972
2973 DiagnoseHLSLAvailability(SemaRef).RunOnTranslationUnit(TU);
2974}
2975
2976static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
2977 assert(TheCall->getNumArgs() > 1);
2978 QualType ArgTy0 = TheCall->getArg(Arg: 0)->getType();
2979
2980 for (unsigned I = 1, N = TheCall->getNumArgs(); I < N; ++I) {
2981 if (!S->getASTContext().hasSameUnqualifiedType(
2982 T1: ArgTy0, T2: TheCall->getArg(Arg: I)->getType())) {
2983 S->Diag(Loc: TheCall->getBeginLoc(), DiagID: diag::err_vec_builtin_incompatible_vector)
2984 << TheCall->getDirectCallee() << /*useAllTerminology*/ true
2985 << SourceRange(TheCall->getArg(Arg: 0)->getBeginLoc(),
2986 TheCall->getArg(Arg: N - 1)->getEndLoc());
2987 return true;
2988 }
2989 }
2990 return false;
2991}
2992
2993static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
2994 QualType ArgType = Arg->getType();
2995 if (!S->getASTContext().hasSameUnqualifiedType(T1: ArgType, T2: ExpectedType)) {
2996 S->Diag(Loc: Arg->getBeginLoc(), DiagID: diag::err_typecheck_convert_incompatible)
2997 << ArgType << ExpectedType << 1 << 0 << 0;
2998 return true;
2999 }
3000 return false;
3001}
3002
3003static bool CheckAllArgTypesAreCorrect(
3004 Sema *S, CallExpr *TheCall,
3005 llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
3006 clang::QualType PassedType)>
3007 Check) {
3008 for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
3009 Expr *Arg = TheCall->getArg(Arg: I);
3010 if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
3011 return true;
3012 }
3013 return false;
3014}
3015
3016static bool CheckFloatRepresentation(Sema *S, SourceLocation Loc,
3017 int ArgOrdinal,
3018 clang::QualType PassedType) {
3019 clang::QualType BaseType =
3020 PassedType->isVectorType()
3021 ? PassedType->castAs<clang::VectorType>()->getElementType()
3022 : PassedType;
3023 if (!BaseType->isFloat32Type())
3024 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3025 << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
3026 << /* float */ 1 << PassedType;
3027 return false;
3028}
3029
3030static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
3031 int ArgOrdinal,
3032 clang::QualType PassedType) {
3033 clang::QualType BaseType =
3034 PassedType->isVectorType()
3035 ? PassedType->castAs<clang::VectorType>()->getElementType()
3036 : PassedType;
3037 if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
3038 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3039 << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
3040 << /* half or float */ 2 << PassedType;
3041 return false;
3042}
3043
3044static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
3045 unsigned ArgIndex) {
3046 auto *Arg = TheCall->getArg(Arg: ArgIndex);
3047 SourceLocation OrigLoc = Arg->getExprLoc();
3048 if (Arg->IgnoreCasts()->isModifiableLvalue(Ctx&: S->Context, Loc: &OrigLoc) ==
3049 Expr::MLV_Valid)
3050 return false;
3051 S->Diag(Loc: OrigLoc, DiagID: diag::error_hlsl_inout_lvalue) << Arg << 0;
3052 return true;
3053}
3054
3055static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal,
3056 clang::QualType PassedType) {
3057 const auto *VecTy = PassedType->getAs<VectorType>();
3058 if (!VecTy)
3059 return false;
3060
3061 if (VecTy->getElementType()->isDoubleType())
3062 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3063 << ArgOrdinal << /* scalar */ 1 << /* no int */ 0 << /* fp */ 1
3064 << PassedType;
3065 return false;
3066}
3067
3068static bool CheckFloatingOrIntRepresentation(Sema *S, SourceLocation Loc,
3069 int ArgOrdinal,
3070 clang::QualType PassedType) {
3071 if (!PassedType->hasIntegerRepresentation() &&
3072 !PassedType->hasFloatingRepresentation())
3073 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3074 << ArgOrdinal << /* scalar or vector of */ 5 << /* integer */ 1
3075 << /* fp */ 1 << PassedType;
3076 return false;
3077}
3078
3079static bool CheckUnsignedIntVecRepresentation(Sema *S, SourceLocation Loc,
3080 int ArgOrdinal,
3081 clang::QualType PassedType) {
3082 if (auto *VecTy = PassedType->getAs<VectorType>())
3083 if (VecTy->getElementType()->isUnsignedIntegerType())
3084 return false;
3085
3086 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3087 << ArgOrdinal << /* vector of */ 4 << /* uint */ 3 << /* no fp */ 0
3088 << PassedType;
3089}
3090
3091// checks for unsigned ints of all sizes
3092static bool CheckUnsignedIntRepresentation(Sema *S, SourceLocation Loc,
3093 int ArgOrdinal,
3094 clang::QualType PassedType) {
3095 if (!PassedType->hasUnsignedIntegerRepresentation())
3096 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3097 << ArgOrdinal << /* scalar or vector of */ 5 << /* unsigned int */ 3
3098 << /* no fp */ 0 << PassedType;
3099 return false;
3100}
3101
3102static bool CheckExpectedBitWidth(Sema *S, CallExpr *TheCall,
3103 unsigned ArgOrdinal, unsigned Width) {
3104 QualType ArgTy = TheCall->getArg(Arg: 0)->getType();
3105 if (auto *VTy = ArgTy->getAs<VectorType>())
3106 ArgTy = VTy->getElementType();
3107 // ensure arg type has expected bit width
3108 uint64_t ElementBitCount =
3109 S->getASTContext().getTypeSizeInChars(T: ArgTy).getQuantity() * 8;
3110 if (ElementBitCount != Width) {
3111 S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
3112 DiagID: diag::err_integer_incorrect_bit_count)
3113 << Width << ElementBitCount;
3114 return true;
3115 }
3116 return false;
3117}
3118
3119static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
3120 QualType ReturnType) {
3121 auto *VecTyA = TheCall->getArg(Arg: 0)->getType()->getAs<VectorType>();
3122 if (VecTyA)
3123 ReturnType =
3124 S->Context.getExtVectorType(VectorType: ReturnType, NumElts: VecTyA->getNumElements());
3125
3126 TheCall->setType(ReturnType);
3127}
3128
3129static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar,
3130 unsigned ArgIndex) {
3131 assert(TheCall->getNumArgs() >= ArgIndex);
3132 QualType ArgType = TheCall->getArg(Arg: ArgIndex)->getType();
3133 auto *VTy = ArgType->getAs<VectorType>();
3134 // not the scalar or vector<scalar>
3135 if (!(S->Context.hasSameUnqualifiedType(T1: ArgType, T2: Scalar) ||
3136 (VTy &&
3137 S->Context.hasSameUnqualifiedType(T1: VTy->getElementType(), T2: Scalar)))) {
3138 S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
3139 DiagID: diag::err_typecheck_expect_scalar_or_vector)
3140 << ArgType << Scalar;
3141 return true;
3142 }
3143 return false;
3144}
3145
3146static bool CheckScalarOrVectorOrMatrix(Sema *S, CallExpr *TheCall,
3147 QualType Scalar, unsigned ArgIndex) {
3148 assert(TheCall->getNumArgs() > ArgIndex);
3149
3150 Expr *Arg = TheCall->getArg(Arg: ArgIndex);
3151 QualType ArgType = Arg->getType();
3152
3153 // Scalar: T
3154 if (S->Context.hasSameUnqualifiedType(T1: ArgType, T2: Scalar))
3155 return false;
3156
3157 // Vector: vector<T>
3158 if (const auto *VTy = ArgType->getAs<VectorType>()) {
3159 if (S->Context.hasSameUnqualifiedType(T1: VTy->getElementType(), T2: Scalar))
3160 return false;
3161 }
3162
3163 // Matrix: ConstantMatrixType with element type T
3164 if (const auto *MTy = ArgType->getAs<ConstantMatrixType>()) {
3165 if (S->Context.hasSameUnqualifiedType(T1: MTy->getElementType(), T2: Scalar))
3166 return false;
3167 }
3168
3169 // Not a scalar/vector/matrix-of-scalar
3170 S->Diag(Loc: Arg->getBeginLoc(),
3171 DiagID: diag::err_typecheck_expect_scalar_or_vector_or_matrix)
3172 << ArgType << Scalar;
3173 return true;
3174}
3175
3176static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall,
3177 unsigned ArgIndex) {
3178 assert(TheCall->getNumArgs() >= ArgIndex);
3179 QualType ArgType = TheCall->getArg(Arg: ArgIndex)->getType();
3180 auto *VTy = ArgType->getAs<VectorType>();
3181 // not the scalar or vector<scalar>
3182 if (!(ArgType->isScalarType() ||
3183 (VTy && VTy->getElementType()->isScalarType()))) {
3184 S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
3185 DiagID: diag::err_typecheck_expect_any_scalar_or_vector)
3186 << ArgType << 1;
3187 return true;
3188 }
3189 return false;
3190}
3191
3192// Check that the argument is not a bool or vector<bool>
3193// Returns true on error
3194static bool CheckNotBoolScalarOrVector(Sema *S, CallExpr *TheCall,
3195 unsigned ArgIndex) {
3196 QualType BoolType = S->getASTContext().BoolTy;
3197 assert(ArgIndex < TheCall->getNumArgs());
3198 QualType ArgType = TheCall->getArg(Arg: ArgIndex)->getType();
3199 auto *VTy = ArgType->getAs<VectorType>();
3200 // is the bool or vector<bool>
3201 if (S->Context.hasSameUnqualifiedType(T1: ArgType, T2: BoolType) ||
3202 (VTy &&
3203 S->Context.hasSameUnqualifiedType(T1: VTy->getElementType(), T2: BoolType))) {
3204 S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
3205 DiagID: diag::err_typecheck_expect_any_scalar_or_vector)
3206 << ArgType << 0;
3207 return true;
3208 }
3209 return false;
3210}
3211
3212static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
3213 if (CheckNotBoolScalarOrVector(S, TheCall, ArgIndex: 0))
3214 return true;
3215 return false;
3216}
3217
3218static bool CheckWavePrefix(Sema *S, CallExpr *TheCall) {
3219 if (CheckNotBoolScalarOrVector(S, TheCall, ArgIndex: 0))
3220 return true;
3221 return false;
3222}
3223
3224static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
3225 assert(TheCall->getNumArgs() == 3);
3226 Expr *Arg1 = TheCall->getArg(Arg: 1);
3227 Expr *Arg2 = TheCall->getArg(Arg: 2);
3228 if (!S->Context.hasSameUnqualifiedType(T1: Arg1->getType(), T2: Arg2->getType())) {
3229 S->Diag(Loc: TheCall->getBeginLoc(),
3230 DiagID: diag::err_typecheck_call_different_arg_types)
3231 << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
3232 << Arg2->getSourceRange();
3233 return true;
3234 }
3235
3236 TheCall->setType(Arg1->getType());
3237 return false;
3238}
3239
3240static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
3241 assert(TheCall->getNumArgs() == 3);
3242 Expr *Arg1 = TheCall->getArg(Arg: 1);
3243 QualType Arg1Ty = Arg1->getType();
3244 Expr *Arg2 = TheCall->getArg(Arg: 2);
3245 QualType Arg2Ty = Arg2->getType();
3246
3247 QualType Arg1ScalarTy = Arg1Ty;
3248 if (auto VTy = Arg1ScalarTy->getAs<VectorType>())
3249 Arg1ScalarTy = VTy->getElementType();
3250
3251 QualType Arg2ScalarTy = Arg2Ty;
3252 if (auto VTy = Arg2ScalarTy->getAs<VectorType>())
3253 Arg2ScalarTy = VTy->getElementType();
3254
3255 if (!S->Context.hasSameUnqualifiedType(T1: Arg1ScalarTy, T2: Arg2ScalarTy))
3256 S->Diag(Loc: Arg1->getBeginLoc(), DiagID: diag::err_hlsl_builtin_scalar_vector_mismatch)
3257 << /* second and third */ 1 << TheCall->getCallee() << Arg1Ty << Arg2Ty;
3258
3259 QualType Arg0Ty = TheCall->getArg(Arg: 0)->getType();
3260 unsigned Arg0Length = Arg0Ty->getAs<VectorType>()->getNumElements();
3261 unsigned Arg1Length = Arg1Ty->isVectorType()
3262 ? Arg1Ty->getAs<VectorType>()->getNumElements()
3263 : 0;
3264 unsigned Arg2Length = Arg2Ty->isVectorType()
3265 ? Arg2Ty->getAs<VectorType>()->getNumElements()
3266 : 0;
3267 if (Arg1Length > 0 && Arg0Length != Arg1Length) {
3268 S->Diag(Loc: TheCall->getBeginLoc(),
3269 DiagID: diag::err_typecheck_vector_lengths_not_equal)
3270 << Arg0Ty << Arg1Ty << TheCall->getArg(Arg: 0)->getSourceRange()
3271 << Arg1->getSourceRange();
3272 return true;
3273 }
3274
3275 if (Arg2Length > 0 && Arg0Length != Arg2Length) {
3276 S->Diag(Loc: TheCall->getBeginLoc(),
3277 DiagID: diag::err_typecheck_vector_lengths_not_equal)
3278 << Arg0Ty << Arg2Ty << TheCall->getArg(Arg: 0)->getSourceRange()
3279 << Arg2->getSourceRange();
3280 return true;
3281 }
3282
3283 TheCall->setType(
3284 S->getASTContext().getExtVectorType(VectorType: Arg1ScalarTy, NumElts: Arg0Length));
3285 return false;
3286}
3287
3288static bool CheckResourceHandle(
3289 Sema *S, CallExpr *TheCall, unsigned ArgIndex,
3290 llvm::function_ref<bool(const HLSLAttributedResourceType *ResType)> Check =
3291 nullptr) {
3292 assert(TheCall->getNumArgs() >= ArgIndex);
3293 QualType ArgType = TheCall->getArg(Arg: ArgIndex)->getType();
3294 const HLSLAttributedResourceType *ResTy =
3295 ArgType.getTypePtr()->getAs<HLSLAttributedResourceType>();
3296 if (!ResTy) {
3297 S->Diag(Loc: TheCall->getArg(Arg: ArgIndex)->getBeginLoc(),
3298 DiagID: diag::err_typecheck_expect_hlsl_resource)
3299 << ArgType;
3300 return true;
3301 }
3302 if (Check && Check(ResTy)) {
3303 S->Diag(Loc: TheCall->getArg(Arg: ArgIndex)->getExprLoc(),
3304 DiagID: diag::err_invalid_hlsl_resource_type)
3305 << ArgType;
3306 return true;
3307 }
3308 return false;
3309}
3310
3311static bool CheckVectorElementCount(Sema *S, QualType PassedType,
3312 QualType BaseType, unsigned ExpectedCount,
3313 SourceLocation Loc) {
3314 unsigned PassedCount = 1;
3315 if (const auto *VecTy = PassedType->getAs<VectorType>())
3316 PassedCount = VecTy->getNumElements();
3317
3318 if (PassedCount != ExpectedCount) {
3319 QualType ExpectedType =
3320 S->Context.getExtVectorType(VectorType: BaseType, NumElts: ExpectedCount);
3321 S->Diag(Loc, DiagID: diag::err_typecheck_convert_incompatible)
3322 << PassedType << ExpectedType << 1 << 0 << 0;
3323 return true;
3324 }
3325 return false;
3326}
3327
// Note: returning true in this case results in CheckBuiltinFunctionCall
// returning an ExprError
//
// Performs HLSL-specific semantic checks for builtin calls: argument counts,
// argument types, and resource-handle attributes. For builtins whose result
// type depends on the argument types, this also fixes up the call's type
// (many of these builtins are declared with a placeholder return type).
bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
  switch (BuiltinID) {
  case Builtin::BI__builtin_hlsl_adduint64: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
      return true;

    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckUnsignedIntVecRepresentation))
      return true;

    // ensure arg integers are 32-bits
    if (CheckExpectedBitWidth(S: &SemaRef, TheCall, ArgOrdinal: 0, Width: 32))
      return true;

    // ensure both args are vectors of total bit size of a multiple of 64
    auto *VTy = TheCall->getArg(Arg: 0)->getType()->getAs<VectorType>();
    int NumElementsArg = VTy->getNumElements();
    if (NumElementsArg != 2 && NumElementsArg != 4) {
      SemaRef.Diag(Loc: TheCall->getBeginLoc(), DiagID: diag::err_vector_incorrect_bit_count)
          << 1 /*a multiple of*/ << 64 << NumElementsArg * 32;
      return true;
    }

    // ensure first arg and second arg have the same type
    if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
      return true;

    ExprResult A = TheCall->getArg(Arg: 0);
    QualType ArgTyA = A.get()->getType();
    // return type is the same as the input type
    TheCall->setType(ArgTyA);
    break;
  }
  case Builtin::BI__builtin_hlsl_resource_getpointer: {
    // Args: resource handle, uint index.
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2) ||
        CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 1),
                            ExpectedType: SemaRef.getASTContext().UnsignedIntTy))
      return true;

    // Result is an lvalue pointer into device memory to the resource's
    // contained (element) type.
    auto *ResourceTy =
        TheCall->getArg(Arg: 0)->getType()->castAs<HLSLAttributedResourceType>();
    QualType ContainedTy = ResourceTy->getContainedType();
    auto ReturnType =
        SemaRef.Context.getAddrSpaceQualType(T: ContainedTy, AddressSpace: LangAS::hlsl_device);
    ReturnType = SemaRef.Context.getPointerType(T: ReturnType);
    TheCall->setType(ReturnType);
    TheCall->setValueKind(VK_LValue);

    break;
  }
  case Builtin::BI__builtin_hlsl_resource_load_with_status: {
    // Args: resource handle, uint index, uint out-param receiving the status.
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3) ||
        CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 1),
                            ExpectedType: SemaRef.getASTContext().UnsignedIntTy) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 2),
                            ExpectedType: SemaRef.getASTContext().UnsignedIntTy) ||
        CheckModifiableLValue(S: &SemaRef, TheCall, ArgIndex: 2))
      return true;

    auto *ResourceTy =
        TheCall->getArg(Arg: 0)->getType()->castAs<HLSLAttributedResourceType>();
    QualType ReturnType = ResourceTy->getContainedType();
    TheCall->setType(ReturnType);

    break;
  }
  case Builtin::BI__builtin_hlsl_resource_sample: {
    // Args: resource handle, sampler handle, float coords
    // [, int offset [, float clamp]].
    if (SemaRef.checkArgCountRange(Call: TheCall, MinArgCount: 3, MaxArgCount: 5))
      return true;

    // The sampled resource must have a known dimension so the coordinate
    // vector width can be validated below.
    if (CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0,
                            Check: [](const HLSLAttributedResourceType *ResType) {
                              return ResType->getAttrs().ResourceDimension ==
                                     llvm::dxil::ResourceDimension::Unknown;
                            }))
      return true;

    // The second argument must be a sampler handle.
    if (CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 1,
                            Check: [](const HLSLAttributedResourceType *ResType) {
                              return ResType->getAttrs().ResourceClass !=
                                     llvm::hlsl::ResourceClass::Sampler;
                            }))
      return true;

    auto *ResourceTy =
        TheCall->getArg(Arg: 0)->getType()->castAs<HLSLAttributedResourceType>();

    // Coordinate (and offset) vector width must match the resource dimension.
    unsigned ExpectedDim =
        getResourceDimensions(Dim: ResourceTy->getAttrs().ResourceDimension);
    if (CheckVectorElementCount(S: &SemaRef, PassedType: TheCall->getArg(Arg: 2)->getType(),
                                BaseType: SemaRef.Context.FloatTy, ExpectedCount: ExpectedDim,
                                Loc: TheCall->getArg(Arg: 2)->getBeginLoc()))
      return true;

    if (TheCall->getNumArgs() > 3) {
      if (CheckVectorElementCount(S: &SemaRef, PassedType: TheCall->getArg(Arg: 3)->getType(),
                                  BaseType: SemaRef.Context.IntTy, ExpectedCount: ExpectedDim,
                                  Loc: TheCall->getArg(Arg: 3)->getBeginLoc()))
        return true;
    }

    if (TheCall->getNumArgs() > 4) {
      // The optional clamp argument must be a floating-point scalar.
      QualType ClampTy = TheCall->getArg(Arg: 4)->getType();
      if (!ClampTy->isFloatingType() || ClampTy->isVectorType()) {
        SemaRef.Diag(Loc: TheCall->getArg(Arg: 4)->getBeginLoc(),
                     DiagID: diag::err_typecheck_convert_incompatible)
            << ClampTy << SemaRef.Context.FloatTy << 1 << 0 << 0;
        return true;
      }
    }

    assert(ResourceTy->hasContainedType() &&
           "Expecting a contained type for resource with a dimension "
           "attribute.");
    QualType ReturnType = ResourceTy->getContainedType();
    TheCall->setType(ReturnType);

    break;
  }
  case Builtin::BI__builtin_hlsl_resource_uninitializedhandle: {
    assert(TheCall->getNumArgs() == 1 && "expected 1 arg");
    // Update return type to be the attributed resource type from arg0.
    QualType ResourceTy = TheCall->getArg(Arg: 0)->getType();
    TheCall->setType(ResourceTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_resource_handlefrombinding: {
    assert(TheCall->getNumArgs() == 6 && "expected 6 args");
    // Update return type to be the attributed resource type from arg0.
    QualType ResourceTy = TheCall->getArg(Arg: 0)->getType();
    TheCall->setType(ResourceTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_resource_handlefromimplicitbinding: {
    assert(TheCall->getNumArgs() == 6 && "expected 6 args");
    // Update return type to be the attributed resource type from arg0.
    QualType ResourceTy = TheCall->getArg(Arg: 0)->getType();
    TheCall->setType(ResourceTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_resource_counterhandlefromimplicitbinding: {
    assert(TheCall->getNumArgs() == 3 && "expected 3 args");
    ASTContext &AST = SemaRef.getASTContext();
    QualType MainHandleTy = TheCall->getArg(Arg: 0)->getType();
    auto *MainResType = MainHandleTy->getAs<HLSLAttributedResourceType>();
    auto MainAttrs = MainResType->getAttrs();
    assert(!MainAttrs.IsCounter && "cannot create a counter from a counter");
    MainAttrs.IsCounter = true;
    // Rebuild the resource type with the IsCounter flag set; everything else
    // is copied from the main handle's type.
    QualType CounterHandleTy = AST.getHLSLAttributedResourceType(
        Wrapped: MainResType->getWrappedType(), Contained: MainResType->getContainedType(),
        Attrs: MainAttrs);
    // Update return type to be the attributed resource type from arg0
    // with added IsCounter flag.
    TheCall->setType(CounterHandleTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_and:
  case Builtin::BI__builtin_hlsl_or: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
      return true;
    if (CheckScalarOrVectorOrMatrix(S: &SemaRef, TheCall, Scalar: getASTContext().BoolTy,
                                    ArgIndex: 0))
      return true;
    if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
      return true;

    ExprResult A = TheCall->getArg(Arg: 0);
    QualType ArgTyA = A.get()->getType();
    // return type is the same as the input type
    TheCall->setType(ArgTyA);
    break;
  }
  case Builtin::BI__builtin_hlsl_all:
  case Builtin::BI__builtin_hlsl_any: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;
    if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_asdouble: {
    // Reassembles doubles from (low, high) uint pairs.
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
      return true;
    if (CheckScalarOrVector(
            S: &SemaRef, TheCall,
            /*only check for uint*/ Scalar: SemaRef.Context.UnsignedIntTy,
            /* arg index */ ArgIndex: 0))
      return true;
    if (CheckScalarOrVector(
            S: &SemaRef, TheCall,
            /*only check for uint*/ Scalar: SemaRef.Context.UnsignedIntTy,
            /* arg index */ ArgIndex: 1))
      return true;
    if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
      return true;

    SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().DoubleTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_clamp: {
    if (SemaRef.BuiltinElementwiseTernaryMath(
            TheCall, /*ArgTyRestr=*/
            Sema::EltwiseBuiltinArgTyRestriction::None))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_dot: {
    // arg count is checked by BuiltinVectorToScalarMath
    if (SemaRef.BuiltinVectorToScalarMath(TheCall))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall, Check: CheckNoDoubleVectors))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_firstbithigh:
  case Builtin::BI__builtin_hlsl_elementwise_firstbitlow: {
    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
      return true;

    const Expr *Arg = TheCall->getArg(Arg: 0);
    QualType ArgTy = Arg->getType();
    QualType EltTy = ArgTy;

    // Result is uint (or a uint vector matching the argument's width),
    // regardless of the integer argument's type.
    QualType ResTy = SemaRef.Context.UnsignedIntTy;

    if (auto *VecTy = EltTy->getAs<VectorType>()) {
      EltTy = VecTy->getElementType();
      ResTy = SemaRef.Context.getExtVectorType(VectorType: ResTy, NumElts: VecTy->getNumElements());
    }

    if (!EltTy->isIntegerType()) {
      Diag(Loc: Arg->getBeginLoc(), DiagID: diag::err_builtin_invalid_arg_type)
          << 1 << /* scalar or vector of */ 5 << /* integer ty */ 1
          << /* no fp */ 0 << ArgTy;
      return true;
    }

    TheCall->setType(ResTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_select: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3))
      return true;
    if (CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: getASTContext().BoolTy, ArgIndex: 0))
      return true;
    // Scalar and vector conditions have different rules for the value
    // operands; dispatch on the condition's shape.
    QualType ArgTy = TheCall->getArg(Arg: 0)->getType();
    if (ArgTy->isBooleanType() && CheckBoolSelect(S: &SemaRef, TheCall))
      return true;
    auto *VTy = ArgTy->getAs<VectorType>();
    if (VTy && VTy->getElementType()->isBooleanType() &&
        CheckVectorSelect(S: &SemaRef, TheCall))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_saturate:
  case Builtin::BI__builtin_hlsl_elementwise_rcp: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;
    if (!TheCall->getArg(Arg: 0)
             ->getType()
             ->hasFloatingRepresentation()) // half or float or double
      return SemaRef.Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
                          DiagID: diag::err_builtin_invalid_arg_type)
             << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
             << /* fp */ 1 << TheCall->getArg(Arg: 0)->getType();
    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_degrees:
  case Builtin::BI__builtin_hlsl_elementwise_radians:
  case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
  case Builtin::BI__builtin_hlsl_elementwise_frac:
  case Builtin::BI__builtin_hlsl_elementwise_ddx_coarse:
  case Builtin::BI__builtin_hlsl_elementwise_ddy_coarse:
  case Builtin::BI__builtin_hlsl_elementwise_ddx_fine:
  case Builtin::BI__builtin_hlsl_elementwise_ddy_fine: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatOrHalfRepresentation))
      return true;
    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_isinf:
  case Builtin::BI__builtin_hlsl_elementwise_isnan: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatOrHalfRepresentation))
      return true;
    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
      return true;
    // These return bool (or a bool vector), not the argument's float type.
    SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().BoolTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_lerp: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatOrHalfRepresentation))
      return true;
    if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
      return true;
    if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_mad: {
    if (SemaRef.BuiltinElementwiseTernaryMath(
            TheCall, /*ArgTyRestr=*/
            Sema::EltwiseBuiltinArgTyRestriction::None))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_normalize: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatOrHalfRepresentation))
      return true;
    ExprResult A = TheCall->getArg(Arg: 0);
    QualType ArgTyA = A.get()->getType();
    // return type is the same as the input type
    TheCall->setType(ArgTyA);
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_sign: {
    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatingOrIntRepresentation))
      return true;
    // sign() always produces int (or an int vector), never the input type.
    SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().IntTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_step: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatOrHalfRepresentation))
      return true;

    ExprResult A = TheCall->getArg(Arg: 0);
    QualType ArgTyA = A.get()->getType();
    // return type is the same as the input type
    TheCall->setType(ArgTyA);
    break;
  }
  case Builtin::BI__builtin_hlsl_wave_active_max:
  case Builtin::BI__builtin_hlsl_wave_active_min:
  case Builtin::BI__builtin_hlsl_wave_active_sum: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;

    // Ensure input expr type is a scalar/vector and the same as the return type
    if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
      return true;
    if (CheckWaveActive(S: &SemaRef, TheCall))
      return true;
    ExprResult Expr = TheCall->getArg(Arg: 0);
    QualType ArgTyExpr = Expr.get()->getType();
    TheCall->setType(ArgTyExpr);
    break;
  }
  // Note these are llvm builtins that we want to catch invalid intrinsic
  // generation. Normal handling of these builtins will occur elsewhere.
  case Builtin::BI__builtin_elementwise_bitreverse: {
    // does not include a check for number of arguments
    // because that is done previously
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckUnsignedIntRepresentation))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_wave_prefix_count_bits: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;

    QualType ArgType = TheCall->getArg(Arg: 0)->getType();

    if (!(ArgType->isScalarType())) {
      SemaRef.Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
                   DiagID: diag::err_typecheck_expect_any_scalar_or_vector)
          << ArgType << 0;
      return true;
    }

    // NOTE(review): this second check requires a bool argument but reuses the
    // "expected scalar or vector" diagnostic with identical arguments; a
    // bool-specific diagnostic would read better. Also, since bool implies
    // scalar, the check above is subsumed by this one. Confirm the intended
    // diagnostic before changing.
    if (!(ArgType->isBooleanType())) {
      SemaRef.Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
                   DiagID: diag::err_typecheck_expect_any_scalar_or_vector)
          << ArgType << 0;
      return true;
    }

    break;
  }
  case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
      return true;

    // Ensure index parameter type can be interpreted as a uint
    ExprResult Index = TheCall->getArg(Arg: 1);
    QualType ArgTyIndex = Index.get()->getType();
    if (!ArgTyIndex->isIntegerType()) {
      SemaRef.Diag(Loc: TheCall->getArg(Arg: 1)->getBeginLoc(),
                   DiagID: diag::err_typecheck_convert_incompatible)
          << ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
      return true;
    }

    // Ensure input expr type is a scalar/vector and the same as the return type
    if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
      return true;

    ExprResult Expr = TheCall->getArg(Arg: 0);
    QualType ArgTyExpr = Expr.get()->getType();
    TheCall->setType(ArgTyExpr);
    break;
  }
  case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 0))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_wave_prefix_sum: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;

    // Ensure input expr type is a scalar/vector and the same as the return type
    if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
      return true;
    if (CheckWavePrefix(S: &SemaRef, TheCall))
      return true;
    ExprResult Expr = TheCall->getArg(Arg: 0);
    QualType ArgTyExpr = Expr.get()->getType();
    TheCall->setType(ArgTyExpr);
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
    // Args: double input, uint out-param (low bits), uint out-param (high
    // bits); the two outputs must be modifiable lvalues.
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3))
      return true;

    if (CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: SemaRef.Context.DoubleTy, ArgIndex: 0) ||
        CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: SemaRef.Context.UnsignedIntTy,
                            ArgIndex: 1) ||
        CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: SemaRef.Context.UnsignedIntTy,
                            ArgIndex: 2))
      return true;

    if (CheckModifiableLValue(S: &SemaRef, TheCall, ArgIndex: 1) ||
        CheckModifiableLValue(S: &SemaRef, TheCall, ArgIndex: 2))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_clip: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;

    if (CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: SemaRef.Context.FloatTy, ArgIndex: 0))
      return true;
    break;
  }
  case Builtin::BI__builtin_elementwise_acos:
  case Builtin::BI__builtin_elementwise_asin:
  case Builtin::BI__builtin_elementwise_atan:
  case Builtin::BI__builtin_elementwise_atan2:
  case Builtin::BI__builtin_elementwise_ceil:
  case Builtin::BI__builtin_elementwise_cos:
  case Builtin::BI__builtin_elementwise_cosh:
  case Builtin::BI__builtin_elementwise_exp:
  case Builtin::BI__builtin_elementwise_exp2:
  case Builtin::BI__builtin_elementwise_exp10:
  case Builtin::BI__builtin_elementwise_floor:
  case Builtin::BI__builtin_elementwise_fmod:
  case Builtin::BI__builtin_elementwise_log:
  case Builtin::BI__builtin_elementwise_log2:
  case Builtin::BI__builtin_elementwise_log10:
  case Builtin::BI__builtin_elementwise_pow:
  case Builtin::BI__builtin_elementwise_roundeven:
  case Builtin::BI__builtin_elementwise_sin:
  case Builtin::BI__builtin_elementwise_sinh:
  case Builtin::BI__builtin_elementwise_sqrt:
  case Builtin::BI__builtin_elementwise_tan:
  case Builtin::BI__builtin_elementwise_tanh:
  case Builtin::BI__builtin_elementwise_trunc: {
    // Generic LLVM elementwise builtins: HLSL restricts them to float/half
    // representations (no doubles, no integers).
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatOrHalfRepresentation))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_buffer_update_counter: {
    assert(TheCall->getNumArgs() == 2 && "expected 2 args");
    // Only a raw UAV buffer with a contained type has a counter.
    auto checkResTy = [](const HLSLAttributedResourceType *ResTy) -> bool {
      return !(ResTy->getAttrs().ResourceClass == ResourceClass::UAV &&
               ResTy->getAttrs().RawBuffer && ResTy->hasContainedType());
    };
    if (CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0, Check: checkResTy))
      return true;
    // The increment must be a constant +1 or -1.
    Expr *OffsetExpr = TheCall->getArg(Arg: 1);
    std::optional<llvm::APSInt> Offset =
        OffsetExpr->getIntegerConstantExpr(Ctx: SemaRef.getASTContext());
    if (!Offset.has_value() || std::abs(i: Offset->getExtValue()) != 1) {
      SemaRef.Diag(Loc: TheCall->getArg(Arg: 1)->getBeginLoc(),
                   DiagID: diag::err_hlsl_expect_arg_const_int_one_or_neg_one)
          << 1;
      return true;
    }
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_f16tof32: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckUnsignedIntRepresentation))
      return true;
    // ensure arg integers are 32 bits
    if (CheckExpectedBitWidth(S: &SemaRef, TheCall, ArgOrdinal: 0, Width: 32))
      return true;
    // check it wasn't a bool type
    QualType ArgTy = TheCall->getArg(Arg: 0)->getType();
    if (auto *VTy = ArgTy->getAs<VectorType>())
      ArgTy = VTy->getElementType();
    if (ArgTy->isBooleanType()) {
      SemaRef.Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
                   DiagID: diag::err_builtin_invalid_arg_type)
          << 1 << /* scalar or vector of */ 5 << /* unsigned int */ 3
          << /* no fp */ 0 << TheCall->getArg(Arg: 0)->getType();
      return true;
    }

    SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().FloatTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_f32tof16: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall, Check: CheckFloatRepresentation))
      return true;
    SetElementTypeAsReturnType(S: &SemaRef, TheCall,
                               ReturnType: getASTContext().UnsignedIntTy);
    break;
  }
  }
  return false;
}
3880
// Flattens \p BaseTy into the ordered sequence of leaf types that make up its
// layout and appends them to \p List. Arrays repeat their element's flattened
// list, vectors and matrices expand to N copies of their element type, and
// aggregate records decompose into base classes followed by fields. Unions
// and non-aggregate records are kept as opaque single entries.
static void BuildFlattenedTypeList(QualType BaseTy,
                                   llvm::SmallVectorImpl<QualType> &List) {
  // Explicit LIFO work list; sub-object type lists are pushed in reverse so
  // they pop (and are therefore emitted) in declaration order.
  llvm::SmallVector<QualType, 16> WorkList;
  WorkList.push_back(Elt: BaseTy);
  while (!WorkList.empty()) {
    QualType T = WorkList.pop_back_val();
    T = T.getCanonicalType().getUnqualifiedType();
    if (const auto *AT = dyn_cast<ConstantArrayType>(Val&: T)) {
      llvm::SmallVector<QualType, 16> ElementFields;
      // Generally I've avoided recursion in this algorithm, but arrays of
      // structs could be time-consuming to flatten and churn through on the
      // work list. Hopefully nesting arrays of structs containing arrays
      // of structs too many levels deep is unlikely.
      BuildFlattenedTypeList(BaseTy: AT->getElementType(), List&: ElementFields);
      // Repeat the element's field list n times.
      for (uint64_t Ct = 0; Ct < AT->getZExtSize(); ++Ct)
        llvm::append_range(C&: List, R&: ElementFields);
      continue;
    }
    // Vectors can only have element types that are builtin types, so this can
    // add directly to the list instead of to the WorkList.
    if (const auto *VT = dyn_cast<VectorType>(Val&: T)) {
      List.insert(I: List.end(), NumToInsert: VT->getNumElements(), Elt: VT->getElementType());
      continue;
    }
    // Matrices likewise expand directly to rows*cols copies of the element.
    if (const auto *MT = dyn_cast<ConstantMatrixType>(Val&: T)) {
      List.insert(I: List.end(), NumToInsert: MT->getNumElementsFlattened(),
                  Elt: MT->getElementType());
      continue;
    }
    if (const auto *RD = T->getAsCXXRecordDecl()) {
      // For standard-layout types, fields all live in a single class of the
      // hierarchy; walk down to the base that actually declares them.
      if (RD->isStandardLayout())
        RD = RD->getStandardLayoutBaseWithFields();

      // For types that we shouldn't decompose (unions and non-aggregates), just
      // add the type itself to the list.
      if (RD->isUnion() || !RD->isAggregate()) {
        List.push_back(Elt: T);
        continue;
      }

      llvm::SmallVector<QualType, 16> FieldTypes;
      for (const auto *FD : RD->fields())
        if (!FD->isUnnamedBitField())
          FieldTypes.push_back(Elt: FD->getType());
      // Reverse the newly added sub-range.
      std::reverse(first: FieldTypes.begin(), last: FieldTypes.end())
;
      llvm::append_range(C&: WorkList, R&: FieldTypes);

      // If this wasn't a standard layout type we may also have some base
      // classes to deal with.
      // Bases are pushed after the fields, so the LIFO pop order processes
      // bases (in declaration order) before this class's own fields.
      if (!RD->isStandardLayout()) {
        FieldTypes.clear();
        for (const auto &Base : RD->bases())
          FieldTypes.push_back(Elt: Base.getType());
        std::reverse(first: FieldTypes.begin(), last: FieldTypes.end());
        llvm::append_range(C&: WorkList, R&: FieldTypes);
      }
      continue;
    }
    // Scalar (or otherwise indivisible) type: it is its own leaf.
    List.push_back(Elt: T);
  }
}
3944
3945bool SemaHLSL::IsTypedResourceElementCompatible(clang::QualType QT) {
3946 // null and array types are not allowed.
3947 if (QT.isNull() || QT->isArrayType())
3948 return false;
3949
3950 // UDT types are not allowed
3951 if (QT->isRecordType())
3952 return false;
3953
3954 if (QT->isBooleanType() || QT->isEnumeralType())
3955 return false;
3956
3957 // the only other valid builtin types are scalars or vectors
3958 if (QT->isArithmeticType()) {
3959 if (SemaRef.Context.getTypeSize(T: QT) / 8 > 16)
3960 return false;
3961 return true;
3962 }
3963
3964 if (const VectorType *VT = QT->getAs<VectorType>()) {
3965 int ArraySize = VT->getNumElements();
3966
3967 if (ArraySize > 4)
3968 return false;
3969
3970 QualType ElTy = VT->getElementType();
3971 if (ElTy->isBooleanType())
3972 return false;
3973
3974 if (SemaRef.Context.getTypeSize(T: QT) / 8 > 16)
3975 return false;
3976 return true;
3977 }
3978
3979 return false;
3980}
3981
3982bool SemaHLSL::IsScalarizedLayoutCompatible(QualType T1, QualType T2) const {
3983 if (T1.isNull() || T2.isNull())
3984 return false;
3985
3986 T1 = T1.getCanonicalType().getUnqualifiedType();
3987 T2 = T2.getCanonicalType().getUnqualifiedType();
3988
3989 // If both types are the same canonical type, they're obviously compatible.
3990 if (SemaRef.getASTContext().hasSameType(T1, T2))
3991 return true;
3992
3993 llvm::SmallVector<QualType, 16> T1Types;
3994 BuildFlattenedTypeList(BaseTy: T1, List&: T1Types);
3995 llvm::SmallVector<QualType, 16> T2Types;
3996 BuildFlattenedTypeList(BaseTy: T2, List&: T2Types);
3997
3998 // Check the flattened type list
3999 return llvm::equal(LRange&: T1Types, RRange&: T2Types,
4000 P: [this](QualType LHS, QualType RHS) -> bool {
4001 return SemaRef.IsLayoutCompatible(T1: LHS, T2: RHS);
4002 });
4003}
4004
4005bool SemaHLSL::CheckCompatibleParameterABI(FunctionDecl *New,
4006 FunctionDecl *Old) {
4007 if (New->getNumParams() != Old->getNumParams())
4008 return true;
4009
4010 bool HadError = false;
4011
4012 for (unsigned i = 0, e = New->getNumParams(); i != e; ++i) {
4013 ParmVarDecl *NewParam = New->getParamDecl(i);
4014 ParmVarDecl *OldParam = Old->getParamDecl(i);
4015
4016 // HLSL parameter declarations for inout and out must match between
4017 // declarations. In HLSL inout and out are ambiguous at the call site,
4018 // but have different calling behavior, so you cannot overload a
4019 // method based on a difference between inout and out annotations.
4020 const auto *NDAttr = NewParam->getAttr<HLSLParamModifierAttr>();
4021 unsigned NSpellingIdx = (NDAttr ? NDAttr->getSpellingListIndex() : 0);
4022 const auto *ODAttr = OldParam->getAttr<HLSLParamModifierAttr>();
4023 unsigned OSpellingIdx = (ODAttr ? ODAttr->getSpellingListIndex() : 0);
4024
4025 if (NSpellingIdx != OSpellingIdx) {
4026 SemaRef.Diag(Loc: NewParam->getLocation(),
4027 DiagID: diag::err_hlsl_param_qualifier_mismatch)
4028 << NDAttr << NewParam;
4029 SemaRef.Diag(Loc: OldParam->getLocation(), DiagID: diag::note_previous_declaration_as)
4030 << ODAttr;
4031 HadError = true;
4032 }
4033 }
4034 return HadError;
4035}
4036
4037// Generally follows PerformScalarCast, with cases reordered for
4038// clarity of what types are supported
4039bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) {
4040
4041 if (!SrcTy->isScalarType() || !DestTy->isScalarType())
4042 return false;
4043
4044 if (SemaRef.getASTContext().hasSameUnqualifiedType(T1: SrcTy, T2: DestTy))
4045 return true;
4046
4047 switch (SrcTy->getScalarTypeKind()) {
4048 case Type::STK_Bool: // casting from bool is like casting from an integer
4049 case Type::STK_Integral:
4050 switch (DestTy->getScalarTypeKind()) {
4051 case Type::STK_Bool:
4052 case Type::STK_Integral:
4053 case Type::STK_Floating:
4054 return true;
4055 case Type::STK_CPointer:
4056 case Type::STK_ObjCObjectPointer:
4057 case Type::STK_BlockPointer:
4058 case Type::STK_MemberPointer:
4059 llvm_unreachable("HLSL doesn't support pointers.");
4060 case Type::STK_IntegralComplex:
4061 case Type::STK_FloatingComplex:
4062 llvm_unreachable("HLSL doesn't support complex types.");
4063 case Type::STK_FixedPoint:
4064 llvm_unreachable("HLSL doesn't support fixed point types.");
4065 }
4066 llvm_unreachable("Should have returned before this");
4067
4068 case Type::STK_Floating:
4069 switch (DestTy->getScalarTypeKind()) {
4070 case Type::STK_Floating:
4071 case Type::STK_Bool:
4072 case Type::STK_Integral:
4073 return true;
4074 case Type::STK_FloatingComplex:
4075 case Type::STK_IntegralComplex:
4076 llvm_unreachable("HLSL doesn't support complex types.");
4077 case Type::STK_FixedPoint:
4078 llvm_unreachable("HLSL doesn't support fixed point types.");
4079 case Type::STK_CPointer:
4080 case Type::STK_ObjCObjectPointer:
4081 case Type::STK_BlockPointer:
4082 case Type::STK_MemberPointer:
4083 llvm_unreachable("HLSL doesn't support pointers.");
4084 }
4085 llvm_unreachable("Should have returned before this");
4086
4087 case Type::STK_MemberPointer:
4088 case Type::STK_CPointer:
4089 case Type::STK_BlockPointer:
4090 case Type::STK_ObjCObjectPointer:
4091 llvm_unreachable("HLSL doesn't support pointers.");
4092
4093 case Type::STK_FixedPoint:
4094 llvm_unreachable("HLSL doesn't support fixed point types.");
4095
4096 case Type::STK_FloatingComplex:
4097 case Type::STK_IntegralComplex:
4098 llvm_unreachable("HLSL doesn't support complex types.");
4099 }
4100
4101 llvm_unreachable("Unhandled scalar cast");
4102}
4103
4104// Can perform an HLSL Aggregate splat cast if the Dest is an aggregate and the
4105// Src is a scalar or a vector of length 1
4106// Or if Dest is a vector and Src is a vector of length 1
4107bool SemaHLSL::CanPerformAggregateSplatCast(Expr *Src, QualType DestTy) {
4108
4109 QualType SrcTy = Src->getType();
4110 // Not a valid HLSL Aggregate Splat cast if Dest is a scalar or if this is
4111 // going to be a vector splat from a scalar.
4112 if ((SrcTy->isScalarType() && DestTy->isVectorType()) ||
4113 DestTy->isScalarType())
4114 return false;
4115
4116 const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
4117
4118 // Src isn't a scalar or a vector of length 1
4119 if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1))
4120 return false;
4121
4122 if (SrcVecTy)
4123 SrcTy = SrcVecTy->getElementType();
4124
4125 llvm::SmallVector<QualType> DestTypes;
4126 BuildFlattenedTypeList(BaseTy: DestTy, List&: DestTypes);
4127
4128 for (unsigned I = 0, Size = DestTypes.size(); I < Size; ++I) {
4129 if (DestTypes[I]->isUnionType())
4130 return false;
4131 if (!CanPerformScalarCast(SrcTy, DestTy: DestTypes[I]))
4132 return false;
4133 }
4134 return true;
4135}
4136
4137// Can we perform an HLSL Elementwise cast?
4138bool SemaHLSL::CanPerformElementwiseCast(Expr *Src, QualType DestTy) {
4139
4140 // Don't handle casts where LHS and RHS are any combination of scalar/vector
4141 // There must be an aggregate somewhere
4142 QualType SrcTy = Src->getType();
4143 if (SrcTy->isScalarType()) // always a splat and this cast doesn't handle that
4144 return false;
4145
4146 if (SrcTy->isVectorType() &&
4147 (DestTy->isScalarType() || DestTy->isVectorType()))
4148 return false;
4149
4150 if (SrcTy->isConstantMatrixType() &&
4151 (DestTy->isScalarType() || DestTy->isConstantMatrixType()))
4152 return false;
4153
4154 llvm::SmallVector<QualType> DestTypes;
4155 BuildFlattenedTypeList(BaseTy: DestTy, List&: DestTypes);
4156 llvm::SmallVector<QualType> SrcTypes;
4157 BuildFlattenedTypeList(BaseTy: SrcTy, List&: SrcTypes);
4158
4159 // Usually the size of SrcTypes must be greater than or equal to the size of
4160 // DestTypes.
4161 if (SrcTypes.size() < DestTypes.size())
4162 return false;
4163
4164 unsigned SrcSize = SrcTypes.size();
4165 unsigned DstSize = DestTypes.size();
4166 unsigned I;
4167 for (I = 0; I < DstSize && I < SrcSize; I++) {
4168 if (SrcTypes[I]->isUnionType() || DestTypes[I]->isUnionType())
4169 return false;
4170 if (!CanPerformScalarCast(SrcTy: SrcTypes[I], DestTy: DestTypes[I])) {
4171 return false;
4172 }
4173 }
4174
4175 // check the rest of the source type for unions.
4176 for (; I < SrcSize; I++) {
4177 if (SrcTypes[I]->isUnionType())
4178 return false;
4179 }
4180 return true;
4181}
4182
/// Wraps an argument to an `out`/`inout` parameter so that copy-in (for
/// `inout`) and copy-back semantics are modeled in the AST.
///
/// Returns the argument unchanged for ordinary parameters, an HLSLOutArgExpr
/// for `out`/`inout` parameters, or ExprError() on failure.
ExprResult SemaHLSL::ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg) {
  assert(Param->hasAttr<HLSLParamModifierAttr>() &&
         "We should not get here without a parameter modifier expression");
  const auto *Attr = Param->getAttr<HLSLParamModifierAttr>();
  // Plain `in` parameters need no out-argument machinery.
  if (Attr->getABI() == ParameterABI::Ordinary)
    return ExprResult(Arg);

  bool IsInOut = Attr->getABI() == ParameterABI::HLSLInOut;
  // The argument must be writable for the copy-back, so it has to be an
  // lvalue.
  if (!Arg->isLValue()) {
    SemaRef.Diag(Arg->getBeginLoc(), diag::error_hlsl_inout_lvalue)
        << Arg << (IsInOut ? 1 : 0);
    return ExprError();
  }

  ASTContext &Ctx = SemaRef.getASTContext();

  QualType Ty = Param->getType().getNonLValueExprType(Ctx);

  // HLSL allows implicit conversions from scalars to vectors, but not the
  // inverse, so we need to disallow `inout` with scalar->vector or
  // scalar->matrix conversions.
  if (Arg->getType()->isScalarType() != Ty->isScalarType()) {
    SemaRef.Diag(Arg->getBeginLoc(), diag::error_hlsl_inout_scalar_extension)
        << Arg << (IsInOut ? 1 : 0);
    return ExprError();
  }

  // Wrap the argument so it can be referenced both as the copy-in source and
  // the writeback destination without being re-evaluated.
  auto *ArgOpV = new (Ctx) OpaqueValueExpr(Param->getBeginLoc(), Arg->getType(),
                                           VK_LValue, OK_Ordinary, Arg);

  // Parameters are initialized via copy initialization. This allows for
  // overload resolution of argument constructors.
  InitializedEntity Entity =
      InitializedEntity::InitializeParameter(Ctx, Ty, false);
  ExprResult Res =
      SemaRef.PerformCopyInitialization(Entity, Param->getBeginLoc(), ArgOpV);
  if (Res.isInvalid())
    return ExprError();
  Expr *Base = Res.get();
  // After the cast, drop the reference type when creating the exprs.
  Ty = Ty.getNonLValueExprType(Ctx);
  auto *OpV = new (Ctx)
      OpaqueValueExpr(Param->getBeginLoc(), Ty, VK_LValue, OK_Ordinary, Base);

  // Writebacks are performed with `=` binary operator, which allows for
  // overload resolution on writeback result expressions.
  Res = SemaRef.ActOnBinOp(SemaRef.getCurScope(), Param->getBeginLoc(),
                           tok::equal, ArgOpV, OpV);

  if (Res.isInvalid())
    return ExprError();
  Expr *Writeback = Res.get();
  auto *OutExpr =
      HLSLOutArgExpr::Create(Ctx, Ty, ArgOpV, OpV, Writeback, IsInOut);

  return ExprResult(OutExpr);
}
4240
4241QualType SemaHLSL::getInoutParameterType(QualType Ty) {
4242 // If HLSL gains support for references, all the cites that use this will need
4243 // to be updated with semantic checking to produce errors for
4244 // pointers/references.
4245 assert(!Ty->isReferenceType() &&
4246 "Pointer and reference types cannot be inout or out parameters");
4247 Ty = SemaRef.getASTContext().getLValueReferenceType(T: Ty);
4248 Ty.addRestrict();
4249 return Ty;
4250}
4251
4252static bool IsDefaultBufferConstantDecl(const ASTContext &Ctx, VarDecl *VD) {
4253 bool IsVulkan =
4254 Ctx.getTargetInfo().getTriple().getOS() == llvm::Triple::Vulkan;
4255 bool IsVKPushConstant = IsVulkan && VD->hasAttr<HLSLVkPushConstantAttr>();
4256 QualType QT = VD->getType();
4257 return VD->getDeclContext()->isTranslationUnit() &&
4258 QT.getAddressSpace() == LangAS::Default &&
4259 VD->getStorageClass() != SC_Static &&
4260 !VD->hasAttr<HLSLVkConstantIdAttr>() && !IsVKPushConstant &&
4261 !isInvalidConstantBufferLeafElementType(Ty: QT.getTypePtr());
4262}
4263
4264void SemaHLSL::deduceAddressSpace(VarDecl *Decl) {
4265 // The variable already has an address space (groupshared for ex).
4266 if (Decl->getType().hasAddressSpace())
4267 return;
4268
4269 if (Decl->getType()->isDependentType())
4270 return;
4271
4272 QualType Type = Decl->getType();
4273
4274 if (Decl->hasAttr<HLSLVkExtBuiltinInputAttr>()) {
4275 LangAS ImplAS = LangAS::hlsl_input;
4276 Type = SemaRef.getASTContext().getAddrSpaceQualType(T: Type, AddressSpace: ImplAS);
4277 Decl->setType(Type);
4278 return;
4279 }
4280
4281 bool IsVulkan = getASTContext().getTargetInfo().getTriple().getOS() ==
4282 llvm::Triple::Vulkan;
4283 if (IsVulkan && Decl->hasAttr<HLSLVkPushConstantAttr>()) {
4284 if (HasDeclaredAPushConstant)
4285 SemaRef.Diag(Loc: Decl->getLocation(), DiagID: diag::err_hlsl_push_constant_unique);
4286
4287 LangAS ImplAS = LangAS::hlsl_push_constant;
4288 Type = SemaRef.getASTContext().getAddrSpaceQualType(T: Type, AddressSpace: ImplAS);
4289 Decl->setType(Type);
4290 HasDeclaredAPushConstant = true;
4291 return;
4292 }
4293
4294 if (Type->isSamplerT() || Type->isVoidType())
4295 return;
4296
4297 // Resource handles.
4298 if (Type->isHLSLResourceRecord() || Type->isHLSLResourceRecordArray())
4299 return;
4300
4301 // Only static globals belong to the Private address space.
4302 // Non-static globals belongs to the cbuffer.
4303 if (Decl->getStorageClass() != SC_Static && !Decl->isStaticDataMember())
4304 return;
4305
4306 LangAS ImplAS = LangAS::hlsl_private;
4307 Type = SemaRef.getASTContext().getAddrSpaceQualType(T: Type, AddressSpace: ImplAS);
4308 Decl->setType(Type);
4309}
4310
/// HLSL-specific processing of a variable declarator: completes the type,
/// assigns default-$Globals membership, collects resource bindings, adjusts
/// storage/linkage, assigns implicit binding order IDs for resource arrays,
/// and finally deduces the variable's address space.
void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
  if (VD->hasGlobalStorage()) {
    // make sure the declaration has a complete type
    if (SemaRef.RequireCompleteType(
            VD->getLocation(),
            SemaRef.getASTContext().getBaseElementType(VD->getType()),
            diag::err_typecheck_decl_incomplete_type)) {
      VD->setInvalidDecl();
      // Still deduce an address space so later passes see a consistent type.
      deduceAddressSpace(VD);
      return;
    }

    // Global variables outside a cbuffer block that are not a resource, static,
    // groupshared, or an empty array or struct belong to the default constant
    // buffer $Globals (to be created at the end of the translation unit).
    if (IsDefaultBufferConstantDecl(getASTContext(), VD)) {
      // update address space to hlsl_constant
      QualType NewTy = getASTContext().getAddrSpaceQualType(
          VD->getType(), LangAS::hlsl_constant);
      VD->setType(NewTy);
      DefaultCBufferDecls.push_back(VD);
    }

    // find all resources bindings on decl
    if (VD->getType()->isHLSLIntangibleType())
      collectResourceBindingsOnVarDecl(VD);

    // Specialization constants get internal (static) storage.
    if (VD->hasAttr<HLSLVkConstantIdAttr>())
      VD->setStorageClass(StorageClass::SC_Static);

    if (isResourceRecordTypeOrArrayOf(VD) &&
        VD->getStorageClass() != SC_Static) {
      // Add internal linkage attribute to non-static resource variables. The
      // global externally visible storage is accessed through the handle, which
      // is a member. The variable itself is not externally visible.
      VD->addAttr(InternalLinkageAttr::CreateImplicit(getASTContext()));
    }

    // process explicit bindings
    processExplicitBindingsOnDecl(VD);

    // Add implicit binding attribute to non-static resource arrays.
    if (VD->getType()->isHLSLResourceRecordArray() &&
        VD->getStorageClass() != SC_Static) {
      // If the resource array does not have an explicit binding attribute,
      // create an implicit one. It will be used to transfer implicit binding
      // order_ID to codegen.
      ResourceBindingAttrs Binding(VD);
      if (!Binding.isExplicit()) {
        uint32_t OrderID = getNextImplicitBindingOrderID();
        if (Binding.hasBinding())
          Binding.setImplicitOrderID(OrderID);
        else {
          addImplicitBindingAttrToDecl(
              SemaRef, VD, getRegisterType(getResourceArrayHandleType(VD)),
              OrderID);
          // Re-create the binding object to pick up the new attribute.
          Binding = ResourceBindingAttrs(VD);
        }
      }

      // Get to the base type of a potentially multi-dimensional array.
      QualType Ty = getASTContext().getBaseElementType(VD->getType());

      const CXXRecordDecl *RD = Ty->getAsCXXRecordDecl();
      // Resources with a counter handle also need an implicit order ID for
      // the counter binding, if one was not already assigned.
      if (hasCounterHandle(RD)) {
        if (!Binding.hasCounterImplicitOrderID()) {
          uint32_t OrderID = getNextImplicitBindingOrderID();
          Binding.setCounterImplicitOrderID(OrderID);
        }
      }
    }
  }

  deduceAddressSpace(VD);
}
4387
/// Synthesizes an initializer for a non-static global resource variable by
/// calling the resource class's static "create" method with the binding
/// information (register slot or implicit order ID, space, range size, index,
/// and variable name).
///
/// Returns true if initialization was set up; false if the type has no usable
/// create method (no binding will be generated for it).
bool SemaHLSL::initGlobalResourceDecl(VarDecl *VD) {
  assert(VD->getType()->isHLSLResourceRecord() &&
         "expected resource record type");

  ASTContext &AST = SemaRef.getASTContext();
  uint64_t UIntTySize = AST.getTypeSize(AST.UnsignedIntTy);
  uint64_t IntTySize = AST.getTypeSize(AST.IntTy);

  // Gather resource binding attributes.
  ResourceBindingAttrs Binding(VD);

  // Find correct initialization method and create its arguments.
  QualType ResourceTy = VD->getType();
  CXXRecordDecl *ResourceDecl = ResourceTy->getAsCXXRecordDecl();
  CXXMethodDecl *CreateMethod = nullptr;
  llvm::SmallVector<Expr *> Args;

  // The create method name encodes whether the binding is explicit
  // (register(...)) or implicit, and whether the resource has a counter
  // handle.
  bool HasCounter = hasCounterHandle(ResourceDecl);
  const char *CreateMethodName;
  if (Binding.isExplicit())
    CreateMethodName = HasCounter ? "__createFromBindingWithImplicitCounter"
                                  : "__createFromBinding";
  else
    CreateMethodName = HasCounter
                           ? "__createFromImplicitBindingWithImplicitCounter"
                           : "__createFromImplicitBinding";

  CreateMethod =
      lookupMethod(SemaRef, ResourceDecl, CreateMethodName, VD->getLocation());

  if (!CreateMethod)
    // This can happen if someone creates a struct that looks like an HLSL
    // resource record but does not have the required static create method.
    // No binding will be generated for it.
    return false;

  // First argument: the explicit register slot, or the implicit binding
  // order ID.
  if (Binding.isExplicit()) {
    IntegerLiteral *RegSlot =
        IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, Binding.getSlot()),
                               AST.UnsignedIntTy, SourceLocation());
    Args.push_back(RegSlot);
  } else {
    uint32_t OrderID = (Binding.hasImplicitOrderID())
                           ? Binding.getImplicitOrderID()
                           : getNextImplicitBindingOrderID();
    IntegerLiteral *OrderId =
        IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, OrderID),
                               AST.UnsignedIntTy, SourceLocation());
    Args.push_back(OrderId);
  }

  // Register space argument.
  IntegerLiteral *Space =
      IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, Binding.getSpace()),
                             AST.UnsignedIntTy, SourceLocation());
  Args.push_back(Space);

  // Range size: a standalone (non-array) resource occupies one register.
  IntegerLiteral *RangeSize = IntegerLiteral::Create(
      AST, llvm::APInt(IntTySize, 1), AST.IntTy, SourceLocation());
  Args.push_back(RangeSize);

  // Index within the range: always 0 for a non-array resource.
  IntegerLiteral *Index = IntegerLiteral::Create(
      AST, llvm::APInt(UIntTySize, 0), AST.UnsignedIntTy, SourceLocation());
  Args.push_back(Index);

  // The variable's name, decayed from a string literal to `const char *`.
  StringRef VarName = VD->getName();
  StringLiteral *Name = StringLiteral::Create(
      AST, VarName, StringLiteralKind::Ordinary, false,
      AST.getStringLiteralArrayType(AST.CharTy.withConst(), VarName.size()),
      SourceLocation());
  ImplicitCastExpr *NameCast = ImplicitCastExpr::Create(
      AST, AST.getPointerType(AST.CharTy.withConst()), CK_ArrayToPointerDecay,
      Name, nullptr, VK_PRValue, FPOptionsOverride());
  Args.push_back(NameCast);

  if (HasCounter) {
    // NOTE(review): the counter's order ID is allocated here, after the main
    // binding's order ID — confirm this matches the ordering codegen expects.
    uint32_t CounterOrderID = getNextImplicitBindingOrderID();
    IntegerLiteral *CounterId =
        IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, CounterOrderID),
                               AST.UnsignedIntTy, SourceLocation());
    Args.push_back(CounterId);
  }

  // Make sure the create method template is instantiated and emitted.
  if (!CreateMethod->isDefined() && CreateMethod->isTemplateInstantiation())
    SemaRef.InstantiateFunctionDefinition(VD->getLocation(), CreateMethod,
                                          true);

  // Create CallExpr with a call to the static method and set it as the decl
  // initialization.
  DeclRefExpr *DRE = DeclRefExpr::Create(
      AST, NestedNameSpecifierLoc(), SourceLocation(), CreateMethod, false,
      CreateMethod->getNameInfo(), CreateMethod->getType(), VK_PRValue);

  auto *ImpCast = ImplicitCastExpr::Create(
      AST, AST.getPointerType(CreateMethod->getType()),
      CK_FunctionToPointerDecay, DRE, nullptr, VK_PRValue, FPOptionsOverride());

  CallExpr *InitExpr =
      CallExpr::Create(AST, ImpCast, Args, ResourceTy, VK_PRValue,
                       SourceLocation(), FPOptionsOverride());
  VD->setInit(InitExpr);
  VD->setInitStyle(VarDecl::CallInit);
  SemaRef.CheckCompleteVariableDeclaration(VD);
  return true;
}
4494
4495bool SemaHLSL::initGlobalResourceArrayDecl(VarDecl *VD) {
4496 assert(VD->getType()->isHLSLResourceRecordArray() &&
4497 "expected array of resource records");
4498
4499 // Individual resources in a resource array are not initialized here. They
4500 // are initialized later on during codegen when the individual resources are
4501 // accessed. Codegen will emit a call to the resource initialization method
4502 // with the specified array index. We need to make sure though that the method
4503 // for the specific resource type is instantiated, so codegen can emit a call
4504 // to it when the array element is accessed.
4505
4506 // Find correct initialization method based on the resource binding
4507 // information.
4508 ASTContext &AST = SemaRef.getASTContext();
4509 QualType ResElementTy = AST.getBaseElementType(QT: VD->getType());
4510 CXXRecordDecl *ResourceDecl = ResElementTy->getAsCXXRecordDecl();
4511 CXXMethodDecl *CreateMethod = nullptr;
4512
4513 bool HasCounter = hasCounterHandle(RD: ResourceDecl);
4514 ResourceBindingAttrs ResourceAttrs(VD);
4515 if (ResourceAttrs.isExplicit())
4516 // Resource has explicit binding.
4517 CreateMethod =
4518 lookupMethod(S&: SemaRef, RecordDecl: ResourceDecl,
4519 Name: HasCounter ? "__createFromBindingWithImplicitCounter"
4520 : "__createFromBinding",
4521 Loc: VD->getLocation());
4522 else
4523 // Resource has implicit binding.
4524 CreateMethod = lookupMethod(
4525 S&: SemaRef, RecordDecl: ResourceDecl,
4526 Name: HasCounter ? "__createFromImplicitBindingWithImplicitCounter"
4527 : "__createFromImplicitBinding",
4528 Loc: VD->getLocation());
4529
4530 if (!CreateMethod)
4531 return false;
4532
4533 // Make sure the create method template is instantiated and emitted.
4534 if (!CreateMethod->isDefined() && CreateMethod->isTemplateInstantiation())
4535 SemaRef.InstantiateFunctionDefinition(PointOfInstantiation: VD->getLocation(), Function: CreateMethod,
4536 Recursive: true);
4537 return true;
4538}
4539
4540// Returns true if the initialization has been handled.
4541// Returns false to use default initialization.
4542bool SemaHLSL::ActOnUninitializedVarDecl(VarDecl *VD) {
4543 // Objects in the hlsl_constant address space are initialized
4544 // externally, so don't synthesize an implicit initializer.
4545 if (VD->getType().getAddressSpace() == LangAS::hlsl_constant)
4546 return true;
4547
4548 // Initialize non-static resources at the global scope.
4549 if (VD->hasGlobalStorage() && VD->getStorageClass() != SC_Static) {
4550 const Type *Ty = VD->getType().getTypePtr();
4551 if (Ty->isHLSLResourceRecord())
4552 return initGlobalResourceDecl(VD);
4553 if (Ty->isHLSLResourceRecordArray())
4554 return initGlobalResourceArrayDecl(VD);
4555 }
4556 return false;
4557}
4558
4559// Return true if everything is ok; returns false if there was an error.
4560bool SemaHLSL::CheckResourceBinOp(BinaryOperatorKind Opc, Expr *LHSExpr,
4561 Expr *RHSExpr, SourceLocation Loc) {
4562 assert((LHSExpr->getType()->isHLSLResourceRecord() ||
4563 LHSExpr->getType()->isHLSLResourceRecordArray()) &&
4564 "expected LHS to be a resource record or array of resource records");
4565 if (Opc != BO_Assign)
4566 return true;
4567
4568 // If LHS is an array subscript, get the underlying declaration.
4569 Expr *E = LHSExpr;
4570 while (auto *ASE = dyn_cast<ArraySubscriptExpr>(Val: E))
4571 E = ASE->getBase()->IgnoreParenImpCasts();
4572
4573 // Report error if LHS is a non-static resource declared at a global scope.
4574 if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(Val: E->IgnoreParens())) {
4575 if (VarDecl *VD = dyn_cast<VarDecl>(Val: DRE->getDecl())) {
4576 if (VD->hasGlobalStorage() && VD->getStorageClass() != SC_Static) {
4577 // assignment to global resource is not allowed
4578 SemaRef.Diag(Loc, DiagID: diag::err_hlsl_assign_to_global_resource) << VD;
4579 SemaRef.Diag(Loc: VD->getLocation(), DiagID: diag::note_var_declared_here) << VD;
4580 return false;
4581 }
4582 }
4583 }
4584 return true;
4585}
4586
// Walks though the global variable declaration, collects all resource binding
// requirements and adds them to Bindings
void SemaHLSL::collectResourceBindingsOnVarDecl(VarDecl *VD) {
  assert(VD->hasGlobalStorage() && VD->getType()->isHLSLIntangibleType() &&
         "expected global variable that contains HLSL resource");

  // Cbuffers and Tbuffers are HLSLBufferDecl types
  if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(VD)) {
    // tbuffers are bound as SRVs; cbuffers as constant buffers.
    Bindings.addDeclBindingInfo(VD, CBufferOrTBuffer->isCBuffer()
                                        ? ResourceClass::CBuffer
                                        : ResourceClass::SRV);
    return;
  }

  // Unwrap arrays
  // FIXME: Calculate array size while unwrapping
  const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
  while (Ty->isArrayType()) {
    const ArrayType *AT = cast<ArrayType>(Ty);
    Ty = AT->getElementType()->getUnqualifiedDesugaredType();
  }

  // Resource (or array of resources)
  if (const HLSLAttributedResourceType *AttrResType =
          HLSLAttributedResourceType::findHandleTypeOnResource(Ty)) {
    Bindings.addDeclBindingInfo(VD, AttrResType->getAttrs().ResourceClass);
    return;
  }

  // User defined record type
  if (const RecordType *RT = dyn_cast<RecordType>(Ty))
    collectResourceBindingsOnUserRecordDecl(VD, RT);
}
4620
// Walks though the explicit resource binding attributes on the declaration,
// and makes sure there is a resource that matched the binding and updates
// DeclBindingInfoLists
void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) {
  assert(VD->hasGlobalStorage() && "expected global variable");

  bool HasBinding = false;
  for (Attr *A : VD->attrs()) {
    // [[vk::binding]] counts as an explicit binding, but is incompatible
    // with [[vk::push_constant]].
    if (isa<HLSLVkBindingAttr>(A)) {
      HasBinding = true;
      if (auto PA = VD->getAttr<HLSLVkPushConstantAttr>())
        Diag(PA->getLoc(), diag::err_hlsl_attr_incompatible) << A << PA;
    }

    HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A);
    if (!RBA || !RBA->hasRegisterSlot())
      continue;
    HasBinding = true;

    RegisterType RT = RBA->getRegisterType();
    assert(RT != RegisterType::I && "invalid or obsolete register type should "
                                    "never have an attribute created");

    // register(c#) is for default-constant-buffer members; warn when the decl
    // itself already has resource binding info.
    if (RT == RegisterType::C) {
      if (Bindings.hasBindingInfoForDecl(VD))
        SemaRef.Diag(VD->getLocation(),
                     diag::warn_hlsl_user_defined_type_missing_member)
            << static_cast<int>(RT);
      continue;
    }

    // Find DeclBindingInfo for this binding and update it, or report error
    // if it does not exist (user type does not contain resources with the
    // expected resource class).
    ResourceClass RC = getResourceClass(RT);
    if (DeclBindingInfo *BI = Bindings.getDeclBindingInfo(VD, RC)) {
      // update binding info
      BI->setBindingAttribute(RBA, BindingType::Explicit);
    } else {
      SemaRef.Diag(VD->getLocation(),
                   diag::warn_hlsl_user_defined_type_missing_member)
          << static_cast<int>(RT);
    }
  }

  // A resource with no explicit binding at all will be bound implicitly;
  // warn so the user knows.
  if (!HasBinding && isResourceRecordTypeOrArrayOf(VD))
    SemaRef.Diag(VD->getLocation(), diag::warn_hlsl_implicit_binding);
}
namespace {
// Helper that flattens an HLSL initializer list into one converted expression
// per scalar destination element (buildInitializerList), then regroups those
// expressions into nested InitListExprs matching the structure of the
// destination type (generateInitLists).
class InitListTransformer {
  Sema &S;
  ASTContext &Ctx;
  // Destination type being initialized; for wrapped incomplete arrays this is
  // the array's element type.
  QualType InitTy;
  // Cursor into DestTypes while flattening initializer expressions.
  QualType *DstIt = nullptr;
  // Cursor into ArgExprs while regrouping into initializer lists.
  Expr **ArgIt = nullptr;
  // Is wrapping the destination type iterator required? This is only used for
  // incomplete array types where we loop over the destination type since we
  // don't know the full number of elements from the declaration.
  bool Wrap;

  // Copy-initializes E to the current destination element type, appends the
  // converted expression to ArgExprs, and advances DstIt. Returns false only
  // on a conversion failure.
  bool castInitializer(Expr *E) {
    assert(DstIt && "This should always be something!");
    if (DstIt == DestTypes.end()) {
      if (!Wrap) {
        ArgExprs.push_back(E);
        // This is odd, but it isn't technically a failure due to conversion, we
        // handle mismatched counts of arguments differently.
        return true;
      }
      DstIt = DestTypes.begin();
    }
    InitializedEntity Entity = InitializedEntity::InitializeParameter(
        Ctx, *DstIt, /* Consumed (ObjC) */ false);
    ExprResult Res = S.PerformCopyInitialization(Entity, E->getBeginLoc(), E);
    if (Res.isInvalid())
      return false;
    Expr *Init = Res.get();
    ArgExprs.push_back(Init);
    DstIt++;
    return true;
  }

  // Recursively decomposes E (init lists, vectors, matrices, constant arrays,
  // records) into scalar initializer expressions via castInitializer.
  bool buildInitializerListImpl(Expr *E) {
    // If this is an initialization list, traverse the sub initializers.
    if (auto *Init = dyn_cast<InitListExpr>(E)) {
      for (auto *SubInit : Init->inits())
        if (!buildInitializerListImpl(SubInit))
          return false;
      return true;
    }

    // If this is a scalar type, just enqueue the expression.
    QualType Ty = E->getType();

    if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()))
      return castInitializer(E);

    if (auto *VecTy = Ty->getAs<VectorType>()) {
      uint64_t Size = VecTy->getNumElements();

      QualType SizeTy = Ctx.getSizeType();
      uint64_t SizeTySize = Ctx.getTypeSize(SizeTy);
      // Expand the vector into one subscript expression per element.
      for (uint64_t I = 0; I < Size; ++I) {
        auto *Idx = IntegerLiteral::Create(Ctx, llvm::APInt(SizeTySize, I),
                                           SizeTy, SourceLocation());

        ExprResult ElExpr = S.CreateBuiltinArraySubscriptExpr(
            E, E->getBeginLoc(), Idx, E->getEndLoc());
        if (ElExpr.isInvalid())
          return false;
        if (!castInitializer(ElExpr.get()))
          return false;
      }
      return true;
    }
    if (auto *MTy = Ty->getAs<ConstantMatrixType>()) {
      unsigned Rows = MTy->getNumRows();
      unsigned Cols = MTy->getNumColumns();
      QualType ElemTy = MTy->getElementType();

      // Expand the matrix element-by-element, iterating rows then columns.
      for (unsigned R = 0; R < Rows; ++R) {
        for (unsigned C = 0; C < Cols; ++C) {
          // row index literal
          Expr *RowIdx = IntegerLiteral::Create(
              Ctx, llvm::APInt(Ctx.getIntWidth(Ctx.IntTy), R), Ctx.IntTy,
              E->getBeginLoc());
          // column index literal
          Expr *ColIdx = IntegerLiteral::Create(
              Ctx, llvm::APInt(Ctx.getIntWidth(Ctx.IntTy), C), Ctx.IntTy,
              E->getBeginLoc());
          ExprResult ElExpr = S.CreateBuiltinMatrixSubscriptExpr(
              E, RowIdx, ColIdx, E->getEndLoc());
          if (ElExpr.isInvalid())
            return false;
          if (!castInitializer(ElExpr.get()))
            return false;
          ElExpr.get()->setType(ElemTy);
        }
      }
      return true;
    }

    if (auto *ArrTy = dyn_cast<ConstantArrayType>(Ty.getTypePtr())) {
      uint64_t Size = ArrTy->getZExtSize();
      QualType SizeTy = Ctx.getSizeType();
      uint64_t SizeTySize = Ctx.getTypeSize(SizeTy);
      // Recurse into each array element.
      for (uint64_t I = 0; I < Size; ++I) {
        auto *Idx = IntegerLiteral::Create(Ctx, llvm::APInt(SizeTySize, I),
                                           SizeTy, SourceLocation());
        ExprResult ElExpr = S.CreateBuiltinArraySubscriptExpr(
            E, E->getBeginLoc(), Idx, E->getEndLoc());
        if (ElExpr.isInvalid())
          return false;
        if (!buildInitializerListImpl(ElExpr.get()))
          return false;
      }
      return true;
    }

    if (auto *RD = Ty->getAsCXXRecordDecl()) {
      // Collect the single-inheritance chain so base-class fields are visited
      // before derived-class fields.
      llvm::SmallVector<CXXRecordDecl *> RecordDecls;
      RecordDecls.push_back(RD);
      while (RecordDecls.back()->getNumBases()) {
        CXXRecordDecl *D = RecordDecls.back();
        assert(D->getNumBases() == 1 &&
               "HLSL doesn't support multiple inheritance");
        RecordDecls.push_back(
            D->bases_begin()->getType()->castAsCXXRecordDecl());
      }
      while (!RecordDecls.empty()) {
        CXXRecordDecl *RD = RecordDecls.pop_back_val();
        for (auto *FD : RD->fields()) {
          if (FD->isUnnamedBitField())
            continue;
          DeclAccessPair Found = DeclAccessPair::make(FD, FD->getAccess());
          DeclarationNameInfo NameInfo(FD->getDeclName(), E->getBeginLoc());
          ExprResult Res = S.BuildFieldReferenceExpr(
              E, false, E->getBeginLoc(), CXXScopeSpec(), FD, Found, NameInfo);
          if (Res.isInvalid())
            return false;
          if (!buildInitializerListImpl(Res.get()))
            return false;
        }
      }
    }
    return true;
  }

  // Consumes expressions from ArgExprs (via ArgIt) and rebuilds a possibly
  // nested initializer that matches the structure of Ty.
  Expr *generateInitListsImpl(QualType Ty) {
    assert(ArgIt != ArgExprs.end() && "Something is off in iteration!");
    if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()))
      return *(ArgIt++);

    llvm::SmallVector<Expr *> Inits;
    Ty = Ty.getDesugaredType(Ctx);
    if (Ty->isVectorType() || Ty->isConstantArrayType() ||
        Ty->isConstantMatrixType()) {
      QualType ElTy;
      uint64_t Size = 0;
      if (auto *ATy = Ty->getAs<VectorType>()) {
        ElTy = ATy->getElementType();
        Size = ATy->getNumElements();
      } else if (auto *CMTy = Ty->getAs<ConstantMatrixType>()) {
        ElTy = CMTy->getElementType();
        Size = CMTy->getNumElementsFlattened();
      } else {
        auto *VTy = cast<ConstantArrayType>(Ty.getTypePtr());
        ElTy = VTy->getElementType();
        Size = VTy->getZExtSize();
      }
      for (uint64_t I = 0; I < Size; ++I)
        Inits.push_back(generateInitListsImpl(ElTy));
    }
    if (auto *RD = Ty->getAsCXXRecordDecl()) {
      // Walk the inheritance chain base-first, mirroring
      // buildInitializerListImpl.
      llvm::SmallVector<CXXRecordDecl *> RecordDecls;
      RecordDecls.push_back(RD);
      while (RecordDecls.back()->getNumBases()) {
        CXXRecordDecl *D = RecordDecls.back();
        assert(D->getNumBases() == 1 &&
               "HLSL doesn't support multiple inheritance");
        RecordDecls.push_back(
            D->bases_begin()->getType()->castAsCXXRecordDecl());
      }
      while (!RecordDecls.empty()) {
        CXXRecordDecl *RD = RecordDecls.pop_back_val();
        for (auto *FD : RD->fields())
          if (!FD->isUnnamedBitField())
            Inits.push_back(generateInitListsImpl(FD->getType()));
      }
    }
    auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
                                           Inits, Inits.back()->getEndLoc());
    NewInit->setType(Ty);
    return NewInit;
  }

public:
  // Flattened destination element types, in initialization order.
  llvm::SmallVector<QualType, 16> DestTypes;
  // Flattened, converted initializer expressions, in initialization order.
  llvm::SmallVector<Expr *, 16> ArgExprs;
  InitListTransformer(Sema &SemaRef, const InitializedEntity &Entity)
      : S(SemaRef), Ctx(SemaRef.getASTContext()),
        Wrap(Entity.getType()->isIncompleteArrayType()) {
    InitTy = Entity.getType().getNonReferenceType();
    // When we're generating initializer lists for incomplete array types we
    // need to wrap around both when building the initializers and when
    // generating the final initializer lists.
    if (Wrap) {
      assert(InitTy->isIncompleteArrayType());
      const IncompleteArrayType *IAT = Ctx.getAsIncompleteArrayType(InitTy);
      InitTy = IAT->getElementType();
    }
    BuildFlattenedTypeList(InitTy, DestTypes);
    DstIt = DestTypes.begin();
  }

  bool buildInitializerList(Expr *E) { return buildInitializerListImpl(E); }

  // Regroups ArgExprs into the final initializer. For incomplete arrays the
  // result is an array init list whose size is derived from the number of
  // wrapped iterations.
  Expr *generateInitLists() {
    assert(!ArgExprs.empty() &&
           "Call buildInitializerList to generate argument expressions.");
    ArgIt = ArgExprs.begin();
    if (!Wrap)
      return generateInitListsImpl(InitTy);
    llvm::SmallVector<Expr *> Inits;
    while (ArgIt != ArgExprs.end())
      Inits.push_back(generateInitListsImpl(InitTy));

    auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
                                           Inits, Inits.back()->getEndLoc());
    llvm::APInt ArySize(64, Inits.size());
    NewInit->setType(Ctx.getConstantArrayType(InitTy, ArySize, nullptr,
                                              ArraySizeModifier::Normal, 0));
    return NewInit;
  }
};
} // namespace
4897
4898// Recursively detect any incomplete array anywhere in the type graph,
4899// including arrays, struct fields, and base classes.
4900static bool containsIncompleteArrayType(QualType Ty) {
4901 Ty = Ty.getCanonicalType();
4902
4903 // Array types
4904 if (const ArrayType *AT = dyn_cast<ArrayType>(Val&: Ty)) {
4905 if (isa<IncompleteArrayType>(Val: AT))
4906 return true;
4907 return containsIncompleteArrayType(Ty: AT->getElementType());
4908 }
4909
4910 // Record (struct/class) types
4911 if (const auto *RT = Ty->getAs<RecordType>()) {
4912 const RecordDecl *RD = RT->getDecl();
4913
4914 // Walk base classes (for C++ / HLSL structs with inheritance)
4915 if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(Val: RD)) {
4916 for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
4917 if (containsIncompleteArrayType(Ty: Base.getType()))
4918 return true;
4919 }
4920 }
4921
4922 // Walk fields
4923 for (const FieldDecl *F : RD->fields()) {
4924 if (containsIncompleteArrayType(Ty: F->getType()))
4925 return true;
4926 }
4927 }
4928
4929 return false;
4930}
4931
/// Rewrites an HLSL initializer list into a fully flattened and regrouped
/// initializer matching the destination type's structure.
///
/// Returns true on success; false (after diagnosing) when the number of
/// initializer elements does not match the destination.
bool SemaHLSL::transformInitList(const InitializedEntity &Entity,
                                 InitListExpr *Init) {
  // If the initializer is a scalar, just return it.
  if (Init->getType()->isScalarType())
    return true;
  ASTContext &Ctx = SemaRef.getASTContext();
  InitListTransformer ILT(SemaRef, Entity);

  for (unsigned I = 0; I < Init->getNumInits(); ++I) {
    Expr *E = Init->getInit(I);
    // Flattening may reference a sub-expression several times; wrap
    // side-effecting expressions so they are evaluated only once.
    if (E->HasSideEffects(Ctx)) {
      QualType Ty = E->getType();
      if (Ty->isRecordType())
        E = new (Ctx) MaterializeTemporaryExpr(Ty, E, E->isLValue());
      E = new (Ctx) OpaqueValueExpr(E->getBeginLoc(), Ty, E->getValueKind(),
                                    E->getObjectKind(), E);
      Init->setInit(I, E);
    }
    if (!ILT.buildInitializerList(E))
      return false;
  }
  size_t ExpectedSize = ILT.DestTypes.size();
  size_t ActualSize = ILT.ArgExprs.size();
  if (ExpectedSize == 0 && ActualSize == 0)
    return true;

  // Reject empty initializer if *any* incomplete array exists structurally
  if (ActualSize == 0 && containsIncompleteArrayType(Entity.getType())) {
    QualType InitTy = Entity.getType().getNonReferenceType();
    if (InitTy.hasAddressSpace())
      InitTy = SemaRef.getASTContext().removeAddrSpaceQualType(InitTy);

    SemaRef.Diag(Init->getBeginLoc(), diag::err_hlsl_incorrect_num_initializers)
        << /*TooManyOrFew=*/(int)(ExpectedSize < ActualSize) << InitTy
        << /*ExpectedSize=*/ExpectedSize << /*ActualSize=*/ActualSize;
    return false;
  }

  // We infer size after validating legality.
  // For incomplete arrays it is completely arbitrary to choose whether we think
  // the user intended fewer or more elements. This implementation assumes that
  // the user intended more, and errors that there are too few initializers to
  // complete the final element.
  if (Entity.getType()->isIncompleteArrayType()) {
    assert(ExpectedSize > 0 &&
           "The expected size of an incomplete array type must be at least 1.");
    // Round ActualSize up to the next multiple of one element's flattened
    // size.
    ExpectedSize =
        ((ActualSize + ExpectedSize - 1) / ExpectedSize) * ExpectedSize;
  }

  // An initializer list might be attempting to initialize a reference or
  // rvalue-reference. When checking the initializer we should look through
  // the reference.
  QualType InitTy = Entity.getType().getNonReferenceType();
  if (InitTy.hasAddressSpace())
    InitTy = SemaRef.getASTContext().removeAddrSpaceQualType(InitTy);
  if (ExpectedSize != ActualSize) {
    int TooManyOrFew = ActualSize > ExpectedSize ? 1 : 0;
    SemaRef.Diag(Init->getBeginLoc(), diag::err_hlsl_incorrect_num_initializers)
        << TooManyOrFew << InitTy << ExpectedSize << ActualSize;
    return false;
  }

  // generateInitListsImpl will always return an InitListExpr here, because the
  // scalar case is handled above.
  auto *NewInit = cast<InitListExpr>(ILT.generateInitLists());
  Init->resizeInits(Ctx, NewInit->getNumInits());
  for (unsigned I = 0; I < NewInit->getNumInits(); ++I)
    Init->updateInit(Ctx, I, NewInit->getInit(I));
  return true;
}
5003
5004bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) {
5005 const HLSLVkConstantIdAttr *ConstIdAttr =
5006 VDecl->getAttr<HLSLVkConstantIdAttr>();
5007 if (!ConstIdAttr)
5008 return true;
5009
5010 ASTContext &Context = SemaRef.getASTContext();
5011
5012 APValue InitValue;
5013 if (!Init->isCXX11ConstantExpr(Ctx: Context, Result: &InitValue)) {
5014 Diag(Loc: VDecl->getLocation(), DiagID: diag::err_specialization_const);
5015 VDecl->setInvalidDecl();
5016 return false;
5017 }
5018
5019 Builtin::ID BID =
5020 getSpecConstBuiltinId(Type: VDecl->getType()->getUnqualifiedDesugaredType());
5021
5022 // Argument 1: The ID from the attribute
5023 int ConstantID = ConstIdAttr->getId();
5024 llvm::APInt IDVal(Context.getIntWidth(T: Context.IntTy), ConstantID);
5025 Expr *IdExpr = IntegerLiteral::Create(C: Context, V: IDVal, type: Context.IntTy,
5026 l: ConstIdAttr->getLocation());
5027
5028 SmallVector<Expr *, 2> Args = {IdExpr, Init};
5029 Expr *C = SemaRef.BuildBuiltinCallExpr(Loc: Init->getExprLoc(), Id: BID, CallArgs: Args);
5030 if (C->getType()->getCanonicalTypeUnqualified() !=
5031 VDecl->getType()->getCanonicalTypeUnqualified()) {
5032 C = SemaRef
5033 .BuildCStyleCastExpr(LParenLoc: SourceLocation(),
5034 Ty: Context.getTrivialTypeSourceInfo(
5035 T: Init->getType(), Loc: Init->getExprLoc()),
5036 RParenLoc: SourceLocation(), Op: C)
5037 .get();
5038 }
5039 Init = C;
5040 return true;
5041}
5042