1//===- SemaHLSL.cpp - Semantic Analysis for HLSL constructs ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This implements Semantic Analysis for HLSL constructs.
9//===----------------------------------------------------------------------===//
10
11#include "clang/Sema/SemaHLSL.h"
12#include "clang/AST/ASTConsumer.h"
13#include "clang/AST/ASTContext.h"
14#include "clang/AST/Attr.h"
15#include "clang/AST/Decl.h"
16#include "clang/AST/DeclBase.h"
17#include "clang/AST/DeclCXX.h"
18#include "clang/AST/DeclarationName.h"
19#include "clang/AST/DynamicRecursiveASTVisitor.h"
20#include "clang/AST/Expr.h"
21#include "clang/AST/HLSLResource.h"
22#include "clang/AST/Type.h"
23#include "clang/AST/TypeBase.h"
24#include "clang/AST/TypeLoc.h"
25#include "clang/Basic/Builtins.h"
26#include "clang/Basic/DiagnosticSema.h"
27#include "clang/Basic/IdentifierTable.h"
28#include "clang/Basic/LLVM.h"
29#include "clang/Basic/SourceLocation.h"
30#include "clang/Basic/Specifiers.h"
31#include "clang/Basic/TargetInfo.h"
32#include "clang/Sema/Initialization.h"
33#include "clang/Sema/Lookup.h"
34#include "clang/Sema/ParsedAttr.h"
35#include "clang/Sema/Sema.h"
36#include "clang/Sema/Template.h"
37#include "llvm/ADT/ArrayRef.h"
38#include "llvm/ADT/STLExtras.h"
39#include "llvm/ADT/SmallVector.h"
40#include "llvm/ADT/StringExtras.h"
41#include "llvm/ADT/StringRef.h"
42#include "llvm/ADT/Twine.h"
43#include "llvm/Frontend/HLSL/HLSLBinding.h"
44#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
45#include "llvm/Support/Casting.h"
46#include "llvm/Support/DXILABI.h"
47#include "llvm/Support/ErrorHandling.h"
48#include "llvm/Support/FormatVariadic.h"
49#include "llvm/TargetParser/Triple.h"
50#include <cmath>
51#include <cstddef>
52#include <iterator>
53#include <utility>
54
55using namespace clang;
56using namespace clang::hlsl;
57using RegisterType = HLSLResourceBindingAttr::RegisterType;
58
59static CXXRecordDecl *createHostLayoutStruct(Sema &S,
60 CXXRecordDecl *StructDecl);
61
62static RegisterType getRegisterType(ResourceClass RC) {
63 switch (RC) {
64 case ResourceClass::SRV:
65 return RegisterType::SRV;
66 case ResourceClass::UAV:
67 return RegisterType::UAV;
68 case ResourceClass::CBuffer:
69 return RegisterType::CBuffer;
70 case ResourceClass::Sampler:
71 return RegisterType::Sampler;
72 }
73 llvm_unreachable("unexpected ResourceClass value");
74}
75
76static RegisterType getRegisterType(const HLSLAttributedResourceType *ResTy) {
77 return getRegisterType(RC: ResTy->getAttrs().ResourceClass);
78}
79
80// Converts the first letter of string Slot to RegisterType.
81// Returns false if the letter does not correspond to a valid register type.
82static bool convertToRegisterType(StringRef Slot, RegisterType *RT) {
83 assert(RT != nullptr);
84 switch (Slot[0]) {
85 case 't':
86 case 'T':
87 *RT = RegisterType::SRV;
88 return true;
89 case 'u':
90 case 'U':
91 *RT = RegisterType::UAV;
92 return true;
93 case 'b':
94 case 'B':
95 *RT = RegisterType::CBuffer;
96 return true;
97 case 's':
98 case 'S':
99 *RT = RegisterType::Sampler;
100 return true;
101 case 'c':
102 case 'C':
103 *RT = RegisterType::C;
104 return true;
105 case 'i':
106 case 'I':
107 *RT = RegisterType::I;
108 return true;
109 default:
110 return false;
111 }
112}
113
114static ResourceClass getResourceClass(RegisterType RT) {
115 switch (RT) {
116 case RegisterType::SRV:
117 return ResourceClass::SRV;
118 case RegisterType::UAV:
119 return ResourceClass::UAV;
120 case RegisterType::CBuffer:
121 return ResourceClass::CBuffer;
122 case RegisterType::Sampler:
123 return ResourceClass::Sampler;
124 case RegisterType::C:
125 case RegisterType::I:
126 // Deliberately falling through to the unreachable below.
127 break;
128 }
129 llvm_unreachable("unexpected RegisterType value");
130}
131
132static Builtin::ID getSpecConstBuiltinId(const Type *Type) {
133 const auto *BT = dyn_cast<BuiltinType>(Val: Type);
134 if (!BT) {
135 if (!Type->isEnumeralType())
136 return Builtin::NotBuiltin;
137 return Builtin::BI__builtin_get_spirv_spec_constant_int;
138 }
139
140 switch (BT->getKind()) {
141 case BuiltinType::Bool:
142 return Builtin::BI__builtin_get_spirv_spec_constant_bool;
143 case BuiltinType::Short:
144 return Builtin::BI__builtin_get_spirv_spec_constant_short;
145 case BuiltinType::Int:
146 return Builtin::BI__builtin_get_spirv_spec_constant_int;
147 case BuiltinType::LongLong:
148 return Builtin::BI__builtin_get_spirv_spec_constant_longlong;
149 case BuiltinType::UShort:
150 return Builtin::BI__builtin_get_spirv_spec_constant_ushort;
151 case BuiltinType::UInt:
152 return Builtin::BI__builtin_get_spirv_spec_constant_uint;
153 case BuiltinType::ULongLong:
154 return Builtin::BI__builtin_get_spirv_spec_constant_ulonglong;
155 case BuiltinType::Half:
156 return Builtin::BI__builtin_get_spirv_spec_constant_half;
157 case BuiltinType::Float:
158 return Builtin::BI__builtin_get_spirv_spec_constant_float;
159 case BuiltinType::Double:
160 return Builtin::BI__builtin_get_spirv_spec_constant_double;
161 default:
162 return Builtin::NotBuiltin;
163 }
164}
165
// Appends a new (VarDecl, ResourceClass) binding entry and returns it.
// Precondition (asserted): no entry for this exact pair exists yet, and if
// the decl already has entries they are the most recent ones in
// BindingsList — all entries for one decl must stay contiguous so that
// getDeclBindingInfo can scan forward from the first index.
DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD,
                                                      ResourceClass ResClass) {
  assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
         "DeclBindingInfo already added");
  assert(!hasBindingInfoForDecl(VD) || BindingsList.back().Decl == VD);
  // VarDecl may have multiple entries for different resource classes.
  // DeclToBindingListIndex stores the index of the first binding we saw
  // for this decl. If there are any additional ones then that index
  // shouldn't be updated (try_emplace is a no-op for an existing key).
  DeclToBindingListIndex.try_emplace(VD, BindingsList.size());
  return &BindingsList.emplace_back(VD, ResClass);
}
178
179DeclBindingInfo *ResourceBindings::getDeclBindingInfo(const VarDecl *VD,
180 ResourceClass ResClass) {
181 auto Entry = DeclToBindingListIndex.find(Val: VD);
182 if (Entry != DeclToBindingListIndex.end()) {
183 for (unsigned Index = Entry->getSecond();
184 Index < BindingsList.size() && BindingsList[Index].Decl == VD;
185 ++Index) {
186 if (BindingsList[Index].ResClass == ResClass)
187 return &BindingsList[Index];
188 }
189 }
190 return nullptr;
191}
192
193bool ResourceBindings::hasBindingInfoForDecl(const VarDecl *VD) const {
194 return DeclToBindingListIndex.contains(Val: VD);
195}
196
197SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}
198
// Called when the parser enters a cbuffer/tbuffer block. Creates the
// HLSLBufferDecl, tags it with the matching implicit resource-class
// attribute, and makes it the current declaration context so subsequent
// declarations land inside it.
Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
                                 SourceLocation KwLoc, IdentifierInfo *Ident,
                                 SourceLocation IdentLoc,
                                 SourceLocation LBrace) {
  // For anonymous namespace, take the location of the left brace.
  DeclContext *LexicalParent = SemaRef.getCurLexicalContext();
  HLSLBufferDecl *Result = HLSLBufferDecl::Create(
      getASTContext(), LexicalParent, CBuffer, KwLoc, Ident, IdentLoc, LBrace);

  // if CBuffer is false, then it's a TBuffer, which binds as an SRV.
  auto RC = CBuffer ? llvm::hlsl::ResourceClass::CBuffer
                    : llvm::hlsl::ResourceClass::SRV;
  Result->addAttr(HLSLResourceClassAttr::CreateImplicit(getASTContext(), RC));

  // Publish the decl in the enclosing scope, then enter its context; the
  // matching PopDeclContext happens in ActOnFinishBuffer.
  SemaRef.PushOnScopeChains(Result, BufferScope);
  SemaRef.PushDeclContext(BufferScope, Result);

  return Result;
}
218
219static unsigned calculateLegacyCbufferFieldAlign(const ASTContext &Context,
220 QualType T) {
221 // Arrays, Matrices, and Structs are always aligned to new buffer rows
222 if (T->isArrayType() || T->isStructureType() || T->isConstantMatrixType())
223 return 16;
224
225 // Vectors are aligned to the type they contain
226 if (const VectorType *VT = T->getAs<VectorType>())
227 return calculateLegacyCbufferFieldAlign(Context, T: VT->getElementType());
228
229 assert(Context.getTypeSize(T) <= 64 &&
230 "Scalar bit widths larger than 64 not supported");
231
232 // Scalar types are aligned to their byte width
233 return Context.getTypeSize(T) / 8;
234}
235
// Calculate the size of a legacy cbuffer type in bytes based on
// https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-packing-rules
static unsigned calculateLegacyCbufferSize(const ASTContext &Context,
                                           QualType T) {
  // Legacy cbuffers are laid out in 16-byte rows.
  constexpr unsigned CBufferAlign = 16;
  if (const auto *RD = T->getAsRecordDecl()) {
    // Struct size is the packed sum of its fields under row-packing rules;
    // note the result is NOT rounded up to a full row here.
    unsigned Size = 0;
    for (const FieldDecl *Field : RD->fields()) {
      QualType Ty = Field->getType();
      unsigned FieldSize = calculateLegacyCbufferSize(Context, Ty);
      unsigned FieldAlign = calculateLegacyCbufferFieldAlign(Context, Ty);

      // If the field crosses the row boundary after alignment it drops to the
      // next row
      unsigned AlignSize = llvm::alignTo(Size, FieldAlign);
      if ((AlignSize % CBufferAlign) + FieldSize > CBufferAlign) {
        FieldAlign = CBufferAlign;
      }

      Size = llvm::alignTo(Size, FieldAlign);
      Size += FieldSize;
    }
    return Size;
  }

  if (const ConstantArrayType *AT = Context.getAsConstantArrayType(T)) {
    unsigned ElementCount = AT->getSize().getZExtValue();
    if (ElementCount == 0)
      return 0;

    // Every element but the last occupies a full row-aligned stride; the
    // last element contributes only its own size.
    unsigned ElementSize =
        calculateLegacyCbufferSize(Context, AT->getElementType());
    unsigned AlignedElementSize = llvm::alignTo(ElementSize, CBufferAlign);
    return AlignedElementSize * (ElementCount - 1) + ElementSize;
  }

  if (const VectorType *VT = T->getAs<VectorType>()) {
    // Vectors are densely packed: element size times element count.
    unsigned ElementCount = VT->getNumElements();
    unsigned ElementSize =
        calculateLegacyCbufferSize(Context, VT->getElementType());
    return ElementSize * ElementCount;
  }

  // Scalar: size in bytes.
  return Context.getTypeSize(T) / 8;
}
281
// Validate packoffset:
// - if packoffset is used it must be set on all declarations inside the
//   buffer
// - packoffset ranges must not overlap
// Records the result on the buffer decl via setHasValidPackoffset.
static void validatePackoffset(Sema &S, HLSLBufferDecl *BufDecl) {
  llvm::SmallVector<std::pair<VarDecl *, HLSLPackOffsetAttr *>> PackOffsetVec;

  // Make sure the packoffset annotations are either on all declarations
  // or on none.
  bool HasPackOffset = false;
  bool HasNonPackOffset = false;
  for (auto *Field : BufDecl->buffer_decls()) {
    // Only variable declarations participate; other decls are ignored.
    VarDecl *Var = dyn_cast<VarDecl>(Field);
    if (!Var)
      continue;
    if (Field->hasAttr<HLSLPackOffsetAttr>()) {
      PackOffsetVec.emplace_back(Var, Field->getAttr<HLSLPackOffsetAttr>());
      HasPackOffset = true;
    } else {
      HasNonPackOffset = true;
    }
  }

  // No packoffset at all: nothing to validate.
  if (!HasPackOffset)
    return;

  // Mixing annotated and unannotated declarations is only a warning.
  if (HasNonPackOffset)
    S.Diag(BufDecl->getLocation(), diag::warn_hlsl_packoffset_mix);

  // Make sure there is no overlap in packoffset - sort PackOffsetVec by offset
  // and compare adjacent values.
  bool IsValid = true;
  ASTContext &Context = S.getASTContext();
  std::sort(PackOffsetVec.begin(), PackOffsetVec.end(),
            [](const std::pair<VarDecl *, HLSLPackOffsetAttr *> &LHS,
               const std::pair<VarDecl *, HLSLPackOffsetAttr *> &RHS) {
              return LHS.second->getOffsetInBytes() <
                     RHS.second->getOffsetInBytes();
            });
  // PackOffsetVec has at least one element here, so size() - 1 is safe.
  for (unsigned i = 0; i < PackOffsetVec.size() - 1; i++) {
    VarDecl *Var = PackOffsetVec[i].first;
    HLSLPackOffsetAttr *Attr = PackOffsetVec[i].second;
    unsigned Size = calculateLegacyCbufferSize(Context, Var->getType());
    unsigned Begin = Attr->getOffsetInBytes();
    unsigned End = Begin + Size;
    unsigned NextBegin = PackOffsetVec[i + 1].second->getOffsetInBytes();
    if (End > NextBegin) {
      // The current range [Begin, End) overlaps the next declaration.
      VarDecl *NextVar = PackOffsetVec[i + 1].first;
      S.Diag(NextVar->getLocation(), diag::err_hlsl_packoffset_overlap)
          << NextVar << Var;
      IsValid = false;
    }
  }
  BufDecl->setHasValidPackoffset(IsValid);
}
336
337// Returns true if the array has a zero size = if any of the dimensions is 0
338static bool isZeroSizedArray(const ConstantArrayType *CAT) {
339 while (CAT && !CAT->isZeroSize())
340 CAT = dyn_cast<ConstantArrayType>(
341 Val: CAT->getElementType()->getUnqualifiedDesugaredType());
342 return CAT != nullptr;
343}
344
345static bool isResourceRecordTypeOrArrayOf(VarDecl *VD) {
346 const Type *Ty = VD->getType().getTypePtr();
347 return Ty->isHLSLResourceRecord() || Ty->isHLSLResourceRecordArray();
348}
349
350static const HLSLAttributedResourceType *
351getResourceArrayHandleType(VarDecl *VD) {
352 assert(VD->getType()->isHLSLResourceRecordArray() &&
353 "expected array of resource records");
354 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
355 while (const ArrayType *AT = dyn_cast<ArrayType>(Val: Ty))
356 Ty = AT->getArrayElementTypeNoTypeQual()->getUnqualifiedDesugaredType();
357 return HLSLAttributedResourceType::findHandleTypeOnResource(RT: Ty);
358}
359
// Returns true if the type is a leaf element type that is not valid to be
// included in HLSL Buffer, such as a resource class, empty struct, zero-sized
// array, or a builtin intangible type. Returns false it is a valid leaf
// element type or if it is a record type that needs to be inspected further
// (see requiresImplicitBufferLayoutStructure).
static bool isInvalidConstantBufferLeafElementType(const Type *Ty) {
  Ty = Ty->getUnqualifiedDesugaredType();
  // Resource records (and arrays of them) live in registers, not cbuffers.
  if (Ty->isHLSLResourceRecord() || Ty->isHLSLResourceRecordArray())
    return true;
  // Empty records are invalid leaves; non-empty records are inspected by the
  // caller field-by-field.
  if (const auto *RD = Ty->getAsCXXRecordDecl())
    return RD->isEmpty();
  if (Ty->isConstantArrayType() &&
      isZeroSizedArray(cast<ConstantArrayType>(Ty)))
    return true;
  if (Ty->isHLSLBuiltinIntangibleType() || Ty->isHLSLAttributedResourceType())
    return true;
  return false;
}
377
378// Returns true if the struct contains at least one element that prevents it
379// from being included inside HLSL Buffer as is, such as an intangible type,
380// empty struct, or zero-sized array. If it does, a new implicit layout struct
381// needs to be created for HLSL Buffer use that will exclude these unwanted
382// declarations (see createHostLayoutStruct function).
383static bool requiresImplicitBufferLayoutStructure(const CXXRecordDecl *RD) {
384 if (RD->isHLSLIntangible() || RD->isEmpty())
385 return true;
386 // check fields
387 for (const FieldDecl *Field : RD->fields()) {
388 QualType Ty = Field->getType();
389 if (isInvalidConstantBufferLeafElementType(Ty: Ty.getTypePtr()))
390 return true;
391 if (const auto *RD = Ty->getAsCXXRecordDecl();
392 RD && requiresImplicitBufferLayoutStructure(RD))
393 return true;
394 }
395 // check bases
396 for (const CXXBaseSpecifier &Base : RD->bases())
397 if (requiresImplicitBufferLayoutStructure(
398 RD: Base.getType()->castAsCXXRecordDecl()))
399 return true;
400 return false;
401}
402
403static CXXRecordDecl *findRecordDeclInContext(IdentifierInfo *II,
404 DeclContext *DC) {
405 CXXRecordDecl *RD = nullptr;
406 for (NamedDecl *Decl :
407 DC->getNonTransparentContext()->lookup(Name: DeclarationName(II))) {
408 if (CXXRecordDecl *FoundRD = dyn_cast<CXXRecordDecl>(Val: Decl)) {
409 assert(RD == nullptr &&
410 "there should be at most 1 record by a given name in a scope");
411 RD = FoundRD;
412 }
413 }
414 return RD;
415}
416
// Creates a name for buffer layout struct using the provided name base.
// If the name must be unique (not previously defined), a numeric suffix is
// added until a unique name is found. Anonymous declarations always get a
// unique "__cblayout_anon[_N]" name.
static IdentifierInfo *getHostLayoutStructName(Sema &S, NamedDecl *BaseDecl,
                                               bool MustBeUnique) {
  ASTContext &AST = S.getASTContext();

  IdentifierInfo *NameBaseII = BaseDecl->getIdentifier();
  llvm::SmallString<64> Name("__cblayout_");
  if (NameBaseII) {
    Name.append(NameBaseII->getName());
  } else {
    // anonymous struct
    Name.append("anon");
    MustBeUnique = true;
  }

  // Remember the unsuffixed length so each retry can rebuild "base_N" from
  // scratch via truncate below.
  size_t NameLength = Name.size();
  IdentifierInfo *II = &AST.Idents.get(Name, tok::TokenKind::identifier);
  if (!MustBeUnique)
    return II;

  unsigned suffix = 0;
  while (true) {
    // suffix == 0 on the first pass: try the unsuffixed name first.
    if (suffix != 0) {
      Name.append("_");
      Name.append(llvm::Twine(suffix).str());
      II = &AST.Idents.get(Name, tok::TokenKind::identifier);
    }
    if (!findRecordDeclInContext(II, BaseDecl->getDeclContext()))
      return II;
    // declaration with that name already exists - increment suffix and try
    // again until unique name is found
    suffix++;
    Name.truncate(NameLength);
  };
}
454
// Returns the host-layout version of type Ty for use inside an HLSL buffer:
// records that need filtering are replaced by their implicit layout struct,
// constant arrays are rebuilt over the converted element type, and all other
// types are returned unchanged. Returns nullptr if a needed layout struct
// could not be created (i.e. it would be empty).
static const Type *createHostLayoutType(Sema &S, const Type *Ty) {
  ASTContext &AST = S.getASTContext();
  if (auto *RD = Ty->getAsCXXRecordDecl()) {
    // Records that are already buffer-compatible are used as-is.
    if (!requiresImplicitBufferLayoutStructure(RD))
      return Ty;
    RD = createHostLayoutStruct(S, RD);
    if (!RD)
      return nullptr;
    return AST.getCanonicalTagType(RD)->getTypePtr();
  }

  if (const auto *CAT = dyn_cast<ConstantArrayType>(Ty)) {
    // Convert the element type recursively, then rebuild the array with the
    // same extent and qualifiers.
    const Type *ElementTy = createHostLayoutType(
        S, CAT->getElementType()->getUnqualifiedDesugaredType());
    if (!ElementTy)
      return nullptr;
    return AST
        .getConstantArrayType(QualType(ElementTy, 0), CAT->getSize(), nullptr,
                              CAT->getSizeModifier(),
                              CAT->getIndexTypeCVRQualifiers())
        .getTypePtr();
  }
  return Ty;
}
479
// Creates a field declaration of given name and type for HLSL buffer layout
// struct. Returns nullptr if the type cannot be used in HLSL Buffer layout
// (invalid leaf type, or its layout struct would be empty).
static FieldDecl *createFieldForHostLayoutStruct(Sema &S, const Type *Ty,
                                                 IdentifierInfo *II,
                                                 CXXRecordDecl *LayoutStruct) {
  if (isInvalidConstantBufferLeafElementType(Ty))
    return nullptr;

  // Convert to the buffer-compatible layout type (may create nested layout
  // structs).
  Ty = createHostLayoutType(S, Ty);
  if (!Ty)
    return nullptr;

  QualType QT = QualType(Ty, 0);
  ASTContext &AST = S.getASTContext();
  // The field is implicit, so it carries no real source locations.
  TypeSourceInfo *TSI = AST.getTrivialTypeSourceInfo(QT, SourceLocation());
  auto *Field = FieldDecl::Create(AST, LayoutStruct, SourceLocation(),
                                  SourceLocation(), II, QT, TSI, nullptr, false,
                                  InClassInitStyle::ICIS_NoInit);
  Field->setAccess(AccessSpecifier::AS_public);
  return Field;
}
501
// Creates host layout struct for a struct included in HLSL Buffer.
// The layout struct will include only fields that are allowed in HLSL buffer.
// These fields will be filtered out:
// - resource classes
// - empty structs
// - zero-sized arrays
// Returns nullptr if the resulting layout struct would be empty.
static CXXRecordDecl *createHostLayoutStruct(Sema &S,
                                             CXXRecordDecl *StructDecl) {
  assert(requiresImplicitBufferLayoutStructure(StructDecl) &&
         "struct is already HLSL buffer compatible");

  ASTContext &AST = S.getASTContext();
  DeclContext *DC = StructDecl->getDeclContext();
  IdentifierInfo *II = getHostLayoutStructName(S, StructDecl, false);

  // Reuse the existing layout struct if it already exists.
  if (CXXRecordDecl *RD = findRecordDeclInContext(II, DC))
    return RD;

  CXXRecordDecl *LS =
      CXXRecordDecl::Create(AST, TagDecl::TagKind::Struct, DC, SourceLocation(),
                            SourceLocation(), II);
  LS->setImplicit(true);
  // Packed: host layout uses explicit packing rules, not natural alignment.
  LS->addAttr(PackedAttr::CreateImplicit(AST));
  LS->startDefinition();

  // copy base struct, create HLSL Buffer compatible version if needed
  if (unsigned NumBases = StructDecl->getNumBases()) {
    assert(NumBases == 1 && "HLSL supports only one base type");
    (void)NumBases;
    CXXBaseSpecifier Base = *StructDecl->bases_begin();
    CXXRecordDecl *BaseDecl = Base.getType()->castAsCXXRecordDecl();
    if (requiresImplicitBufferLayoutStructure(BaseDecl)) {
      // Replace the base with its layout struct; if that fails (empty base)
      // BaseDecl is nullptr and the base is dropped entirely.
      BaseDecl = createHostLayoutStruct(S, BaseDecl);
      if (BaseDecl) {
        TypeSourceInfo *TSI =
            AST.getTrivialTypeSourceInfo(AST.getCanonicalTagType(BaseDecl));
        Base = CXXBaseSpecifier(SourceRange(), false, StructDecl->isClass(),
                                AS_none, TSI, SourceLocation());
      }
    }
    if (BaseDecl) {
      const CXXBaseSpecifier *BasesArray[1] = {&Base};
      LS->setBases(BasesArray, 1);
    }
  }

  // filter struct fields; fields whose type cannot live in a cbuffer are
  // silently dropped from the layout struct
  for (const FieldDecl *FD : StructDecl->fields()) {
    const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
    if (FieldDecl *NewFD =
            createFieldForHostLayoutStruct(S, Ty, FD->getIdentifier(), LS))
      LS->addDecl(NewFD);
  }
  LS->completeDefinition();

  // An empty layout struct is useless: report failure instead of adding it.
  if (LS->field_empty() && LS->getNumBases() == 0)
    return nullptr;

  DC->addDecl(LS);
  return LS;
}
565
// Creates host layout struct for HLSL Buffer. The struct will include only
// fields of types that are allowed in HLSL buffer and it will filter out:
// - static or groupshared variable declarations
// - resource classes
// - empty structs
// - zero-sized arrays
// - non-variable declarations
// The layout struct will be added to the HLSLBufferDecl declarations.
static void createHostLayoutStructForBuffer(Sema &S, HLSLBufferDecl *BufDecl) {
  ASTContext &AST = S.getASTContext();
  // Buffer layout struct names must be unique within the parent context.
  IdentifierInfo *II = getHostLayoutStructName(S, BufDecl, true);

  CXXRecordDecl *LS =
      CXXRecordDecl::Create(AST, TagDecl::TagKind::Struct, BufDecl,
                            SourceLocation(), SourceLocation(), II);
  LS->addAttr(PackedAttr::CreateImplicit(AST));
  LS->setImplicit(true);
  LS->startDefinition();

  for (Decl *D : BufDecl->buffer_decls()) {
    // Skip non-variables, statics, and groupshared variables: they do not
    // occupy constant-buffer storage.
    VarDecl *VD = dyn_cast<VarDecl>(D);
    if (!VD || VD->getStorageClass() == SC_Static ||
        VD->getType().getAddressSpace() == LangAS::hlsl_groupshared)
      continue;
    const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();

    FieldDecl *FD =
        createFieldForHostLayoutStruct(S, Ty, VD->getIdentifier(), LS);
    // Declarations collected for the default $Globals constant buffer have
    // already been checked to have non-empty cbuffer layout, so
    // createFieldForHostLayoutStruct should always succeed. These declarations
    // already have their address space set to hlsl_constant.
    // For declarations in a named cbuffer block
    // createFieldForHostLayoutStruct can still return nullptr if the type
    // is empty (does not have a cbuffer layout).
    assert((FD || VD->getType().getAddressSpace() != LangAS::hlsl_constant) &&
           "host layout field for $Globals decl failed to be created");
    if (FD) {
      // Add the field decl to the layout struct.
      LS->addDecl(FD);
      if (VD->getType().getAddressSpace() != LangAS::hlsl_constant) {
        // Update address space of the original decl to hlsl_constant.
        QualType NewTy =
            AST.getAddrSpaceQualType(VD->getType(), LangAS::hlsl_constant);
        VD->setType(NewTy);
      }
    }
  }
  LS->completeDefinition();
  BufDecl->addLayoutStruct(LS);
}
617
618static void addImplicitBindingAttrToDecl(Sema &S, Decl *D, RegisterType RT,
619 uint32_t ImplicitBindingOrderID) {
620 auto *Attr =
621 HLSLResourceBindingAttr::CreateImplicit(Ctx&: S.getASTContext(), Slot: "", Space: "0", Range: {});
622 Attr->setBinding(RT, SlotNum: std::nullopt, SpaceNum: 0);
623 Attr->setImplicitBindingOrderID(ImplicitBindingOrderID);
624 D->addAttr(A: Attr);
625}
626
// Handle end of cbuffer/tbuffer declaration: validate packoffset
// annotations, build the host layout struct, set up implicit binding when no
// explicit register was given, and pop the buffer's decl context.
void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) {
  auto *BufDecl = cast<HLSLBufferDecl>(Dcl);
  BufDecl->setRBraceLoc(RBrace);

  validatePackoffset(SemaRef, BufDecl);

  createHostLayoutStructForBuffer(SemaRef, BufDecl);

  // Handle implicit binding if needed.
  ResourceBindingAttrs ResourceAttrs(Dcl);
  if (!ResourceAttrs.isExplicit()) {
    SemaRef.Diag(Dcl->getLocation(), diag::warn_hlsl_implicit_binding);
    // Use HLSLResourceBindingAttr to transfer implicit binding order_ID
    // to codegen. If it does not exist, create an implicit attribute.
    uint32_t OrderID = getNextImplicitBindingOrderID();
    if (ResourceAttrs.hasBinding())
      ResourceAttrs.setImplicitOrderID(OrderID);
    else
      // cbuffer binds as a 'b' register, tbuffer as a 't' (SRV) register.
      addImplicitBindingAttrToDecl(SemaRef, BufDecl,
                                   BufDecl->isCBuffer() ? RegisterType::CBuffer
                                                        : RegisterType::SRV,
                                   OrderID);
  }

  // Matches the PushDeclContext in ActOnStartBuffer.
  SemaRef.PopDeclContext();
}
654
655HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
656 const AttributeCommonInfo &AL,
657 int X, int Y, int Z) {
658 if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
659 if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
660 Diag(Loc: NT->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL;
661 Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute);
662 }
663 return nullptr;
664 }
665 return ::new (getASTContext())
666 HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z);
667}
668
669HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D,
670 const AttributeCommonInfo &AL,
671 int Min, int Max, int Preferred,
672 int SpelledArgsCount) {
673 if (HLSLWaveSizeAttr *WS = D->getAttr<HLSLWaveSizeAttr>()) {
674 if (WS->getMin() != Min || WS->getMax() != Max ||
675 WS->getPreferred() != Preferred ||
676 WS->getSpelledArgsCount() != SpelledArgsCount) {
677 Diag(Loc: WS->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL;
678 Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute);
679 }
680 return nullptr;
681 }
682 HLSLWaveSizeAttr *Result = ::new (getASTContext())
683 HLSLWaveSizeAttr(getASTContext(), AL, Min, Max, Preferred);
684 Result->setSpelledArgsCount(SpelledArgsCount);
685 return Result;
686}
687
// Merges a [[vk::constant_id(Id)]] attribute onto D. Only valid when
// targeting SPIR-V, on a const-qualified variable of a type that maps to a
// SPIR-V specialization-constant builtin. Returns the new attribute, or
// nullptr when the attribute is ignored, invalid, or a duplicate.
HLSLVkConstantIdAttr *
SemaHLSL::mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL,
                                int Id) {

  // vk::constant_id is SPIR-V only; warn and ignore on other targets.
  auto &TargetInfo = getASTContext().getTargetInfo();
  if (TargetInfo.getTriple().getArch() != llvm::Triple::spirv) {
    Diag(AL.getLoc(), diag::warn_attribute_ignored) << AL;
    return nullptr;
  }

  auto *VD = cast<VarDecl>(D);

  // The variable type must map to one of the spec-constant builtins.
  if (getSpecConstBuiltinId(VD->getType()->getUnqualifiedDesugaredType()) ==
      Builtin::NotBuiltin) {
    Diag(VD->getLocation(), diag::err_specialization_const);
    return nullptr;
  }

  // Specialization constants must be const-qualified.
  if (!VD->getType().isConstQualified()) {
    Diag(VD->getLocation(), diag::err_specialization_const);
    return nullptr;
  }

  // Duplicate attribute: a differing id is a conflict; either way no new
  // attribute is created.
  if (HLSLVkConstantIdAttr *CI = D->getAttr<HLSLVkConstantIdAttr>()) {
    if (CI->getId() != Id) {
      Diag(CI->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
      Diag(AL.getLoc(), diag::note_conflicting_attribute);
    }
    return nullptr;
  }

  HLSLVkConstantIdAttr *Result =
      ::new (getASTContext()) HLSLVkConstantIdAttr(getASTContext(), AL, Id);
  return Result;
}
723
724HLSLShaderAttr *
725SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
726 llvm::Triple::EnvironmentType ShaderType) {
727 if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
728 if (NT->getType() != ShaderType) {
729 Diag(Loc: NT->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL;
730 Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute);
731 }
732 return nullptr;
733 }
734 return HLSLShaderAttr::Create(Ctx&: getASTContext(), Type: ShaderType, CommonInfo: AL);
735}
736
// Merges an in/out/inout parameter-modifier attribute onto D. An `in`
// paired with an `out` (in either order) collapses into a single synthesized
// `inout`; any other duplicate combination is ill-formed.
HLSLParamModifierAttr *
SemaHLSL::mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
                                 HLSLParamModifierAttr::Spelling Spelling) {
  // We can only merge an `in` attribute with an `out` attribute. All other
  // combinations of duplicated attributes are ill-formed.
  if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
    if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
        (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
      // Drop the old attribute and create a merged `inout` spanning from the
      // first modifier through the end of the second.
      D->dropAttr<HLSLParamModifierAttr>();
      SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
      return HLSLParamModifierAttr::Create(
          getASTContext(), /*MergedSpelling=*/true, AdjustedRange,
          HLSLParamModifierAttr::Keyword_inout);
    }
    Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL;
    Diag(PA->getLocation(), diag::note_conflicting_attribute);
    return nullptr;
  }
  return HLSLParamModifierAttr::Create(getASTContext(), AL);
}
757
// Post-processes a top-level function: if it is the configured HLSL entry
// point, applies any root-signature override and attaches (or validates) the
// shader-stage attribute implied by the target triple.
void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
  auto &TargetInfo = getASTContext().getTargetInfo();

  // Only the designated entry-point function is processed here.
  if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
    return;

  // If we have specified a root signature to override the entry function then
  // attach it now
  HLSLRootSignatureDecl *SignatureDecl =
      lookupRootSignatureOverrideDecl(FD->getDeclContext());
  if (SignatureDecl) {
    // The override replaces any root signature already on the function.
    FD->dropAttr<RootSignatureAttr>();
    // We could look up the SourceRange of the macro here as well
    AttributeCommonInfo AL(RootSigOverrideIdent, AttributeScopeInfo(),
                           SourceRange(), ParsedAttr::Form::Microsoft());
    FD->addAttr(::new (getASTContext()) RootSignatureAttr(
        getASTContext(), AL, RootSigOverrideIdent, SignatureDecl));
  }

  llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment();
  if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) {
    if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
      // The entry point is already annotated - check that it matches the
      // triple.
      if (Shader->getType() != Env) {
        Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
            << Shader;
        FD->setInvalidDecl();
      }
    } else {
      // Implicitly add the shader attribute if the entry function isn't
      // explicitly annotated.
      FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), Env,
                                                 FD->getBeginLoc()));
    }
  } else {
    // Library (and unknown) environments carry no single stage; anything
    // else reaching here is a triple we failed to handle above.
    switch (Env) {
    case llvm::Triple::UnknownEnvironment:
    case llvm::Triple::Library:
      break;
    case llvm::Triple::RootSignature:
      llvm_unreachable("rootsig environment has no functions");
    default:
      llvm_unreachable("Unhandled environment in triple");
    }
  }
}
805
// Returns true when, targeting Vulkan, the given semantic on this entry
// point lowers to a SPIR-V pipeline built-in rather than an ordinary
// Location-decorated variable (so explicit vk::location rules don't apply).
static bool isVkPipelineBuiltin(const ASTContext &AstContext, FunctionDecl *FD,
                                HLSLAppliedSemanticAttr *Semantic,
                                bool IsInput) {
  // Built-in mapping only exists for the Vulkan target.
  if (AstContext.getTargetInfo().getTriple().getOS() != llvm::Triple::Vulkan)
    return false;

  const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
  assert(ShaderAttr && "Entry point has no shader attribute");
  llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
  // Semantic names are compared case-insensitively.
  auto SemanticName = Semantic->getSemanticName().upper();

  // The SV_Position semantic is lowered to:
  // - Position built-in for vertex output.
  // - FragCoord built-in for fragment input.
  if (SemanticName == "SV_POSITION") {
    return (ST == llvm::Triple::Vertex && !IsInput) ||
           (ST == llvm::Triple::Pixel && IsInput);
  }
  // SV_VertexID is always a built-in.
  if (SemanticName == "SV_VERTEXID")
    return true;

  return false;
}
829
// Resolves the active semantic for a scalar (or constant-array) entry-point
// parameter/field D: materializes an HLSLAppliedSemanticAttr on OutputDecl,
// assigns a location (from the running semantic index or an explicit
// vk::location), and checks for duplicate semantic+index use. Returns false
// and emits a diagnostic on any error.
bool SemaHLSL::determineActiveSemanticOnScalar(FunctionDecl *FD,
                                               DeclaratorDecl *OutputDecl,
                                               DeclaratorDecl *D,
                                               SemanticInfo &ActiveSemantic,
                                               SemaHLSL::SemanticContext &SC) {
  // If no semantic was inherited from an enclosing struct, take the one
  // parsed directly on this declaration.
  if (ActiveSemantic.Semantic == nullptr) {
    ActiveSemantic.Semantic = D->getAttr<HLSLParsedSemanticAttr>();
    if (ActiveSemantic.Semantic)
      ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
  }

  // A scalar leaf with no semantic at all is an error.
  if (!ActiveSemantic.Semantic) {
    Diag(D->getLocation(), diag::err_hlsl_missing_semantic_annotation);
    return false;
  }

  auto *A = ::new (getASTContext())
      HLSLAppliedSemanticAttr(getASTContext(), *ActiveSemantic.Semantic,
                              ActiveSemantic.Semantic->getAttrName()->getName(),
                              ActiveSemantic.Index.value_or(0));
  if (!A)
    return false;

  checkSemanticAnnotation(FD, D, A, SC);
  OutputDecl->addAttr(A);

  unsigned Location = ActiveSemantic.Index.value_or(0);

  // Pipeline built-ins are not assigned Location slots, so vk::location
  // consistency checks only apply to non-built-in semantics.
  if (!isVkPipelineBuiltin(getASTContext(), FD, A,
                           SC.CurrentIOType & IOType::In)) {
    bool HasVkLocation = false;
    if (auto *A = D->getAttr<HLSLVkLocationAttr>()) {
      HasVkLocation = true;
      Location = A->getLocation();
    }

    // Explicit vk::location must be used on all semantics or on none;
    // value_or makes the first semantic seen establish the convention.
    if (SC.UsesExplicitVkLocations.value_or(HasVkLocation) != HasVkLocation) {
      Diag(D->getLocation(), diag::err_hlsl_semantic_partial_explicit_indexing);
      return false;
    }
    SC.UsesExplicitVkLocations = HasVkLocation;
  }

  // A constant array consumes one index per element; advance the running
  // semantic index past this declaration.
  const ConstantArrayType *AT = dyn_cast<ConstantArrayType>(D->getType());
  unsigned ElementCount = AT ? AT->getZExtSize() : 1;
  ActiveSemantic.Index = Location + ElementCount;

  // Register each "NAMEindex" slot this decl occupies and diagnose overlap
  // with previously used slots.
  Twine BaseName = Twine(ActiveSemantic.Semantic->getAttrName()->getName());
  for (unsigned I = 0; I < ElementCount; ++I) {
    Twine VariableName = BaseName.concat(Twine(Location + I));

    auto [_, Inserted] = SC.ActiveSemantics.insert(VariableName.str());
    if (!Inserted) {
      Diag(D->getLocation(), diag::err_hlsl_semantic_index_overlap)
          << VariableName.str();
      return false;
    }
  }

  return true;
}
891
892bool SemaHLSL::determineActiveSemantic(FunctionDecl *FD,
893 DeclaratorDecl *OutputDecl,
894 DeclaratorDecl *D,
895 SemanticInfo &ActiveSemantic,
896 SemaHLSL::SemanticContext &SC) {
897 if (ActiveSemantic.Semantic == nullptr) {
898 ActiveSemantic.Semantic = D->getAttr<HLSLParsedSemanticAttr>();
899 if (ActiveSemantic.Semantic)
900 ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
901 }
902
903 const Type *T = D == FD ? &*FD->getReturnType() : &*D->getType();
904 T = T->getUnqualifiedDesugaredType();
905
906 const RecordType *RT = dyn_cast<RecordType>(Val: T);
907 if (!RT)
908 return determineActiveSemanticOnScalar(FD, OutputDecl, D, ActiveSemantic,
909 SC);
910
911 const RecordDecl *RD = RT->getDecl();
912 for (FieldDecl *Field : RD->fields()) {
913 SemanticInfo Info = ActiveSemantic;
914 if (!determineActiveSemantic(FD, OutputDecl, D: Field, ActiveSemantic&: Info, SC)) {
915 Diag(Loc: Field->getLocation(), DiagID: diag::note_hlsl_semantic_used_here) << Field;
916 return false;
917 }
918 if (ActiveSemantic.Semantic)
919 ActiveSemantic = Info;
920 }
921
922 return true;
923}
924
void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
  // Validates a shader entry point: stage-specific attribute rules first
  // (numthreads / wavesize), then semantic annotations on parameters and the
  // return value. Violations mark the declaration invalid.
  const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
  assert(ShaderAttr && "Entry point has no shader attribute");
  llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
  auto &TargetInfo = getASTContext().getTargetInfo();
  VersionTuple Ver = TargetInfo.getTriple().getOSVersion();
  switch (ST) {
  // Non-compute-family stages: numthreads and wavesize are not allowed.
  case llvm::Triple::Pixel:
  case llvm::Triple::Vertex:
  case llvm::Triple::Geometry:
  case llvm::Triple::Hull:
  case llvm::Triple::Domain:
  case llvm::Triple::RayGeneration:
  case llvm::Triple::Intersection:
  case llvm::Triple::AnyHit:
  case llvm::Triple::ClosestHit:
  case llvm::Triple::Miss:
  case llvm::Triple::Callable:
    if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
      diagnoseAttrStageMismatch(A: NT, Stage: ST,
                                AllowedStages: {llvm::Triple::Compute,
                                 llvm::Triple::Amplification,
                                 llvm::Triple::Mesh});
      FD->setInvalidDecl();
    }
    if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
      diagnoseAttrStageMismatch(A: WS, Stage: ST,
                                AllowedStages: {llvm::Triple::Compute,
                                 llvm::Triple::Amplification,
                                 llvm::Triple::Mesh});
      FD->setInvalidDecl();
    }
    break;

  // Compute-family stages: numthreads is mandatory; wavesize is gated on the
  // shader model version.
  case llvm::Triple::Compute:
  case llvm::Triple::Amplification:
  case llvm::Triple::Mesh:
    if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
      Diag(Loc: FD->getLocation(), DiagID: diag::err_hlsl_missing_numthreads)
          << llvm::Triple::getEnvironmentTypeName(Kind: ST);
      FD->setInvalidDecl();
    }
    if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
      // WaveSize requires SM 6.6; its ranged (multi-argument) form requires
      // SM 6.8.
      if (Ver < VersionTuple(6, 6)) {
        Diag(Loc: WS->getLocation(), DiagID: diag::err_hlsl_attribute_in_wrong_shader_model)
            << WS << "6.6";
        FD->setInvalidDecl();
      } else if (WS->getSpelledArgsCount() > 1 && Ver < VersionTuple(6, 8)) {
        Diag(
            Loc: WS->getLocation(),
            DiagID: diag::err_hlsl_attribute_number_arguments_insufficient_shader_model)
            << WS << WS->getSpelledArgsCount() << "6.8";
        FD->setInvalidDecl();
      }
    }
    break;
  case llvm::Triple::RootSignature:
    llvm_unreachable("rootsig environment has no function entry point");
  default:
    llvm_unreachable("Unhandled environment in triple");
  }

  // Assign and validate input semantics for each entry-point parameter.
  SemaHLSL::SemanticContext InputSC = {};
  InputSC.CurrentIOType = IOType::In;

  for (ParmVarDecl *Param : FD->parameters()) {
    SemanticInfo ActiveSemantic;
    ActiveSemantic.Semantic = Param->getAttr<HLSLParsedSemanticAttr>();
    if (ActiveSemantic.Semantic)
      ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();

    // FIXME: Verify output semantics in parameters.
    if (!determineActiveSemantic(FD, OutputDecl: Param, D: Param, ActiveSemantic, SC&: InputSC)) {
      Diag(Loc: Param->getLocation(), DiagID: diag::note_previous_decl) << Param;
      FD->setInvalidDecl();
    }
  }

  // Validate output semantics on the return value (if the function returns
  // anything). The semantic may come from the function declaration itself.
  SemanticInfo ActiveSemantic;
  SemaHLSL::SemanticContext OutputSC = {};
  OutputSC.CurrentIOType = IOType::Out;
  ActiveSemantic.Semantic = FD->getAttr<HLSLParsedSemanticAttr>();
  if (ActiveSemantic.Semantic)
    ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
  if (!FD->getReturnType()->isVoidType())
    determineActiveSemantic(FD, OutputDecl: FD, D: FD, ActiveSemantic, SC&: OutputSC);
}
1012
void SemaHLSL::checkSemanticAnnotation(
    FunctionDecl *EntryPoint, const Decl *Param,
    const HLSLAppliedSemanticAttr *SemanticAttr, const SemanticContext &SC) {
  // Validates that a system-value semantic is used in a shader stage and IO
  // direction (input/output) where it is meaningful; diagnoses otherwise.
  auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
  assert(ShaderAttr && "Entry point has no shader attribute");
  llvm::Triple::EnvironmentType ST = ShaderAttr->getType();

  // Semantic names are case-insensitive; compare upper-cased.
  auto SemanticName = SemanticAttr->getSemanticName().upper();
  if (SemanticName == "SV_DISPATCHTHREADID" ||
      SemanticName == "SV_GROUPINDEX" || SemanticName == "SV_GROUPTHREADID" ||
      SemanticName == "SV_GROUPID") {

    // These are compute-only input semantics.
    if (ST != llvm::Triple::Compute)
      diagnoseSemanticStageMismatch(A: SemanticAttr, Stage: ST, CurrentIOType: SC.CurrentIOType,
                                    AllowedStages: {{.Stage: llvm::Triple::Compute, .AllowedIOTypesMask: IOType::In}});

    // They also do not accept an index suffix (e.g. SV_GroupID1).
    if (SemanticAttr->getSemanticIndex() != 0) {
      std::string PrettyName =
          "'" + SemanticAttr->getSemanticName().str() + "'";
      Diag(Loc: SemanticAttr->getLoc(),
           DiagID: diag::err_hlsl_semantic_indexing_not_supported)
          << PrettyName;
    }
    return;
  }

  if (SemanticName == "SV_POSITION") {
    // SV_Position can be an input or output in vertex shaders,
    // but only an input in pixel shaders.
    diagnoseSemanticStageMismatch(A: SemanticAttr, Stage: ST, CurrentIOType: SC.CurrentIOType,
                                  AllowedStages: {{.Stage: llvm::Triple::Vertex, .AllowedIOTypesMask: IOType::InOut},
                                   {.Stage: llvm::Triple::Pixel, .AllowedIOTypesMask: IOType::In}});
    return;
  }
  if (SemanticName == "SV_VERTEXID") {
    // SV_VertexID is only valid as a vertex-shader input.
    diagnoseSemanticStageMismatch(A: SemanticAttr, Stage: ST, CurrentIOType: SC.CurrentIOType,
                                  AllowedStages: {{.Stage: llvm::Triple::Vertex, .AllowedIOTypesMask: IOType::In}});
    return;
  }

  if (SemanticName == "SV_TARGET") {
    // SV_Target is only valid as a pixel-shader output.
    diagnoseSemanticStageMismatch(A: SemanticAttr, Stage: ST, CurrentIOType: SC.CurrentIOType,
                                  AllowedStages: {{.Stage: llvm::Triple::Pixel, .AllowedIOTypesMask: IOType::Out}});
    return;
  }

  // FIXME: catch-all for non-implemented system semantics reaching this
  // location.
  if (SemanticAttr->getAttrName()->getName().starts_with_insensitive(Prefix: "SV_"))
    llvm_unreachable("Unknown SemanticAttr");
}
1064
1065void SemaHLSL::diagnoseAttrStageMismatch(
1066 const Attr *A, llvm::Triple::EnvironmentType Stage,
1067 std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) {
1068 SmallVector<StringRef, 8> StageStrings;
1069 llvm::transform(Range&: AllowedStages, d_first: std::back_inserter(x&: StageStrings),
1070 F: [](llvm::Triple::EnvironmentType ST) {
1071 return StringRef(
1072 HLSLShaderAttr::ConvertEnvironmentTypeToStr(Val: ST));
1073 });
1074 Diag(Loc: A->getLoc(), DiagID: diag::err_hlsl_attr_unsupported_in_stage)
1075 << A->getAttrName() << llvm::Triple::getEnvironmentTypeName(Kind: Stage)
1076 << (AllowedStages.size() != 1) << join(R&: StageStrings, Separator: ", ");
1077}
1078
void SemaHLSL::diagnoseSemanticStageMismatch(
    const Attr *A, llvm::Triple::EnvironmentType Stage, IOType CurrentIOType,
    std::initializer_list<SemanticStageInfo> Allowed) {
  // Diagnoses a semantic used in the wrong stage or wrong IO direction.
  // `Allowed` lists every (stage, IO-direction mask) pair for which the
  // semantic is valid.

  for (auto &Case : Allowed) {
    if (Case.Stage != Stage)
      continue;

    // Stage matches; accept when the IO direction is allowed too.
    if (CurrentIOType & Case.AllowedIOTypesMask)
      return;

    // Stage is allowed but not for this IO direction: list the valid
    // stage/direction combinations.
    SmallVector<std::string, 8> ValidCases;
    llvm::transform(
        Range&: Allowed, d_first: std::back_inserter(x&: ValidCases), F: [](SemanticStageInfo Case) {
          SmallVector<std::string, 2> ValidType;
          if (Case.AllowedIOTypesMask & IOType::In)
            ValidType.push_back(Elt: "input");
          if (Case.AllowedIOTypesMask & IOType::Out)
            ValidType.push_back(Elt: "output");
          return std::string(
                     HLSLShaderAttr::ConvertEnvironmentTypeToStr(Val: Case.Stage)) +
                 " " + join(R&: ValidType, Separator: "/");
        });
    Diag(Loc: A->getLoc(), DiagID: diag::err_hlsl_semantic_unsupported_iotype_for_stage)
        << A->getAttrName() << (CurrentIOType & IOType::In ? "input" : "output")
        << llvm::Triple::getEnvironmentTypeName(Kind: Case.Stage)
        << join(R&: ValidCases, Separator: ", ");
    return;
  }

  // The current stage is not in the allowed list at all: report the set of
  // stages where this semantic may appear.
  SmallVector<StringRef, 8> StageStrings;
  llvm::transform(
      Range&: Allowed, d_first: std::back_inserter(x&: StageStrings), F: [](SemanticStageInfo Case) {
        return StringRef(
            HLSLShaderAttr::ConvertEnvironmentTypeToStr(Val: Case.Stage));
      });

  Diag(Loc: A->getLoc(), DiagID: diag::err_hlsl_attr_unsupported_in_stage)
      << A->getAttrName() << llvm::Triple::getEnvironmentTypeName(Kind: Stage)
      << (Allowed.size() != 1) << join(R&: StageStrings, Separator: ", ");
}
1120
1121template <CastKind Kind>
1122static void castVector(Sema &S, ExprResult &E, QualType &Ty, unsigned Sz) {
1123 if (const auto *VTy = Ty->getAs<VectorType>())
1124 Ty = VTy->getElementType();
1125 Ty = S.getASTContext().getExtVectorType(VectorType: Ty, NumElts: Sz);
1126 E = S.ImpCastExprToType(E: E.get(), Type: Ty, CK: Kind);
1127}
1128
// Implicitly cast `E` to type `Ty` using cast kind `Kind`; returns `Ty` so
// callers can use it directly as the resulting operand type.
template <CastKind Kind>
static QualType castElement(Sema &S, ExprResult &E, QualType Ty) {
  E = S.ImpCastExprToType(E: E.get(), Type: Ty, CK: Kind);
  return Ty;
}
1134
1135static QualType handleFloatVectorBinOpConversion(
1136 Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
1137 QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
1138 bool LHSFloat = LElTy->isRealFloatingType();
1139 bool RHSFloat = RElTy->isRealFloatingType();
1140
1141 if (LHSFloat && RHSFloat) {
1142 if (IsCompAssign ||
1143 SemaRef.getASTContext().getFloatingTypeOrder(LHS: LElTy, RHS: RElTy) > 0)
1144 return castElement<CK_FloatingCast>(S&: SemaRef, E&: RHS, Ty: LHSType);
1145
1146 return castElement<CK_FloatingCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
1147 }
1148
1149 if (LHSFloat)
1150 return castElement<CK_IntegralToFloating>(S&: SemaRef, E&: RHS, Ty: LHSType);
1151
1152 assert(RHSFloat);
1153 if (IsCompAssign)
1154 return castElement<clang::CK_FloatingToIntegral>(S&: SemaRef, E&: RHS, Ty: LHSType);
1155
1156 return castElement<CK_IntegralToFloating>(S&: SemaRef, E&: LHS, Ty: RHSType);
1157}
1158
// Computes the common type for a vector binary operator whose elements are
// both integral, inserting the needed implicit casts. Follows the usual
// arithmetic conversion rules, adapted so compound assignment always keeps
// the LHS type.
static QualType handleIntegerVectorBinOpConversion(
    Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
    QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {

  int IntOrder = SemaRef.Context.getIntegerTypeOrder(LHS: LElTy, RHS: RElTy);
  bool LHSSigned = LElTy->hasSignedIntegerRepresentation();
  bool RHSSigned = RElTy->hasSignedIntegerRepresentation();
  auto &Ctx = SemaRef.getASTContext();

  // If both types have the same signedness, use the higher ranked type.
  if (LHSSigned == RHSSigned) {
    if (IsCompAssign || IntOrder >= 0)
      return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);

    return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
  }

  // If the unsigned type has greater than or equal rank of the signed type, use
  // the unsigned type.
  if (IntOrder != (LHSSigned ? 1 : -1)) {
    if (IsCompAssign || RHSSigned)
      return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);
    return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
  }

  // At this point the signed type has higher rank than the unsigned type, which
  // means it will be the same size or bigger. If the signed type is bigger, it
  // can represent all the values of the unsigned type, so select it.
  if (Ctx.getIntWidth(T: LElTy) != Ctx.getIntWidth(T: RElTy)) {
    if (IsCompAssign || LHSSigned)
      return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);
    return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
  }

  // This is a bit of an odd duck case in HLSL. It shouldn't happen, but can due
  // to C/C++ leaking through. The place this happens today is long vs long
  // long. When arguments are vector<unsigned long, N> and vector<long long, N>,
  // the long long has higher rank than long even though they are the same size.

  // If this is a compound assignment cast the right hand side to the left hand
  // side's type.
  if (IsCompAssign)
    return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);

  // If this isn't a compound assignment, convert both operands to a vector of
  // the unsigned counterpart of the higher-ranked (signed) element type.
  QualType ElTy = Ctx.getCorrespondingUnsignedType(T: LHSSigned ? LElTy : RElTy);
  QualType NewTy = Ctx.getExtVectorType(
      VectorType: ElTy, NumElts: RHSType->castAs<VectorType>()->getNumElements());
  (void)castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: NewTy);

  return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: NewTy);
}
1211
1212static CastKind getScalarCastKind(ASTContext &Ctx, QualType DestTy,
1213 QualType SrcTy) {
1214 if (DestTy->isRealFloatingType() && SrcTy->isRealFloatingType())
1215 return CK_FloatingCast;
1216 if (DestTy->isIntegralType(Ctx) && SrcTy->isIntegralType(Ctx))
1217 return CK_IntegralCast;
1218 if (DestTy->isRealFloatingType())
1219 return CK_IntegralToFloating;
1220 assert(SrcTy->isRealFloatingType() && DestTy->isIntegralType(Ctx));
1221 return CK_FloatingToIntegral;
1222}
1223
1224QualType SemaHLSL::handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS,
1225 QualType LHSType,
1226 QualType RHSType,
1227 bool IsCompAssign) {
1228 const auto *LVecTy = LHSType->getAs<VectorType>();
1229 const auto *RVecTy = RHSType->getAs<VectorType>();
1230 auto &Ctx = getASTContext();
1231
1232 // If the LHS is not a vector and this is a compound assignment, we truncate
1233 // the argument to a scalar then convert it to the LHS's type.
1234 if (!LVecTy && IsCompAssign) {
1235 QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
1236 RHS = SemaRef.ImpCastExprToType(E: RHS.get(), Type: RElTy, CK: CK_HLSLVectorTruncation);
1237 RHSType = RHS.get()->getType();
1238 if (Ctx.hasSameUnqualifiedType(T1: LHSType, T2: RHSType))
1239 return LHSType;
1240 RHS = SemaRef.ImpCastExprToType(E: RHS.get(), Type: LHSType,
1241 CK: getScalarCastKind(Ctx, DestTy: LHSType, SrcTy: RHSType));
1242 return LHSType;
1243 }
1244
1245 unsigned EndSz = std::numeric_limits<unsigned>::max();
1246 unsigned LSz = 0;
1247 if (LVecTy)
1248 LSz = EndSz = LVecTy->getNumElements();
1249 if (RVecTy)
1250 EndSz = std::min(a: RVecTy->getNumElements(), b: EndSz);
1251 assert(EndSz != std::numeric_limits<unsigned>::max() &&
1252 "one of the above should have had a value");
1253
1254 // In a compound assignment, the left operand does not change type, the right
1255 // operand is converted to the type of the left operand.
1256 if (IsCompAssign && LSz != EndSz) {
1257 Diag(Loc: LHS.get()->getBeginLoc(),
1258 DiagID: diag::err_hlsl_vector_compound_assignment_truncation)
1259 << LHSType << RHSType;
1260 return QualType();
1261 }
1262
1263 if (RVecTy && RVecTy->getNumElements() > EndSz)
1264 castVector<CK_HLSLVectorTruncation>(S&: SemaRef, E&: RHS, Ty&: RHSType, Sz: EndSz);
1265 if (!IsCompAssign && LVecTy && LVecTy->getNumElements() > EndSz)
1266 castVector<CK_HLSLVectorTruncation>(S&: SemaRef, E&: LHS, Ty&: LHSType, Sz: EndSz);
1267
1268 if (!RVecTy)
1269 castVector<CK_VectorSplat>(S&: SemaRef, E&: RHS, Ty&: RHSType, Sz: EndSz);
1270 if (!IsCompAssign && !LVecTy)
1271 castVector<CK_VectorSplat>(S&: SemaRef, E&: LHS, Ty&: LHSType, Sz: EndSz);
1272
1273 // If we're at the same type after resizing we can stop here.
1274 if (Ctx.hasSameUnqualifiedType(T1: LHSType, T2: RHSType))
1275 return Ctx.getCommonSugaredType(X: LHSType, Y: RHSType);
1276
1277 QualType LElTy = LHSType->castAs<VectorType>()->getElementType();
1278 QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
1279
1280 // Handle conversion for floating point vectors.
1281 if (LElTy->isRealFloatingType() || RElTy->isRealFloatingType())
1282 return handleFloatVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
1283 LElTy, RElTy, IsCompAssign);
1284
1285 assert(LElTy->isIntegralType(Ctx) && RElTy->isIntegralType(Ctx) &&
1286 "HLSL Vectors can only contain integer or floating point types");
1287 return handleIntegerVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
1288 LElTy, RElTy, IsCompAssign);
1289}
1290
1291void SemaHLSL::emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS,
1292 BinaryOperatorKind Opc) {
1293 assert((Opc == BO_LOr || Opc == BO_LAnd) &&
1294 "Called with non-logical operator");
1295 llvm::SmallVector<char, 256> Buff;
1296 llvm::raw_svector_ostream OS(Buff);
1297 PrintingPolicy PP(SemaRef.getLangOpts());
1298 StringRef NewFnName = Opc == BO_LOr ? "or" : "and";
1299 OS << NewFnName << "(";
1300 LHS->printPretty(OS, Helper: nullptr, Policy: PP);
1301 OS << ", ";
1302 RHS->printPretty(OS, Helper: nullptr, Policy: PP);
1303 OS << ")";
1304 SourceRange FullRange = SourceRange(LHS->getBeginLoc(), RHS->getEndLoc());
1305 SemaRef.Diag(Loc: LHS->getBeginLoc(), DiagID: diag::note_function_suggestion)
1306 << NewFnName << FixItHint::CreateReplacement(RemoveRange: FullRange, Code: OS.str());
1307}
1308
1309std::pair<IdentifierInfo *, bool>
1310SemaHLSL::ActOnStartRootSignatureDecl(StringRef Signature) {
1311 llvm::hash_code Hash = llvm::hash_value(S: Signature);
1312 std::string IdStr = "__hlsl_rootsig_decl_" + std::to_string(val: Hash);
1313 IdentifierInfo *DeclIdent = &(getASTContext().Idents.get(Name: IdStr));
1314
1315 // Check if we have already found a decl of the same name.
1316 LookupResult R(SemaRef, DeclIdent, SourceLocation(),
1317 Sema::LookupOrdinaryName);
1318 bool Found = SemaRef.LookupQualifiedName(R, LookupCtx: SemaRef.CurContext);
1319 return {DeclIdent, Found};
1320}
1321
1322void SemaHLSL::ActOnFinishRootSignatureDecl(
1323 SourceLocation Loc, IdentifierInfo *DeclIdent,
1324 ArrayRef<hlsl::RootSignatureElement> RootElements) {
1325
1326 if (handleRootSignatureElements(Elements: RootElements))
1327 return;
1328
1329 SmallVector<llvm::hlsl::rootsig::RootElement> Elements;
1330 for (auto &RootSigElement : RootElements)
1331 Elements.push_back(Elt: RootSigElement.getElement());
1332
1333 auto *SignatureDecl = HLSLRootSignatureDecl::Create(
1334 C&: SemaRef.getASTContext(), /*DeclContext=*/DC: SemaRef.CurContext, Loc,
1335 ID: DeclIdent, Version: SemaRef.getLangOpts().HLSLRootSigVer, RootElements: Elements);
1336
1337 SignatureDecl->setImplicit();
1338 SemaRef.PushOnScopeChains(D: SignatureDecl, S: SemaRef.getCurScope());
1339}
1340
1341HLSLRootSignatureDecl *
1342SemaHLSL::lookupRootSignatureOverrideDecl(DeclContext *DC) const {
1343 if (RootSigOverrideIdent) {
1344 LookupResult R(SemaRef, RootSigOverrideIdent, SourceLocation(),
1345 Sema::LookupOrdinaryName);
1346 if (SemaRef.LookupQualifiedName(R, LookupCtx: DC))
1347 return dyn_cast<HLSLRootSignatureDecl>(Val: R.getFoundDecl());
1348 }
1349
1350 return nullptr;
1351}
1352
namespace {

// Tracks root-signature register bindings separately for each shader
// visibility and reports overlapping ranges. Bindings with "All" visibility
// are mirrored into every per-stage builder, since they conflict with any
// stage-specific binding.
struct PerVisibilityBindingChecker {
  SemaHLSL *S;
  // We need one builder per `llvm::dxbc::ShaderVisibility` value.
  std::array<llvm::hlsl::BindingInfoBuilder, 8> Builders;

  // Per-element bookkeeping: the element's visibility and whether an overlap
  // involving it has already been diagnosed (to avoid duplicate diagnostics
  // when the same pair collides in several builders).
  struct ElemInfo {
    const hlsl::RootSignatureElement *Elem;
    llvm::dxbc::ShaderVisibility Vis;
    bool Diagnosed;
  };
  llvm::SmallVector<ElemInfo> ElemInfoMap;

  PerVisibilityBindingChecker(SemaHLSL *S) : S(S) {}

  // Record one register range for `Elem` under `Visibility`. Ranges are
  // inclusive [LowerBound, UpperBound].
  void trackBinding(llvm::dxbc::ShaderVisibility Visibility,
                    llvm::dxil::ResourceClass RC, uint32_t Space,
                    uint32_t LowerBound, uint32_t UpperBound,
                    const hlsl::RootSignatureElement *Elem) {
    uint32_t BuilderIndex = llvm::to_underlying(E: Visibility);
    assert(BuilderIndex < Builders.size() &&
           "Not enough builders for visibility type");
    Builders[BuilderIndex].trackBinding(RC, Space, LowerBound, UpperBound,
                                        Cookie: static_cast<const void *>(Elem));

    // "All"-visibility bindings conflict with every stage, so mirror them
    // into each stage-specific builder as well.
    static_assert(llvm::to_underlying(E: llvm::dxbc::ShaderVisibility::All) == 0,
                  "'All' visibility must come first");
    if (Visibility == llvm::dxbc::ShaderVisibility::All)
      for (size_t I = 1, E = Builders.size(); I < E; ++I)
        Builders[I].trackBinding(RC, Space, LowerBound, UpperBound,
                                 Cookie: static_cast<const void *>(Elem));

    ElemInfoMap.push_back(Elt: {.Elem: Elem, .Vis: Visibility, .Diagnosed: false});
  }

  // Look up the bookkeeping entry for `Elem`. Requires ElemInfoMap to have
  // been sorted by element pointer (done at the start of checkOverlap).
  ElemInfo &getInfo(const hlsl::RootSignatureElement *Elem) {
    auto It = llvm::lower_bound(
        Range&: ElemInfoMap, Value&: Elem,
        C: [](const auto &LHS, const auto &RHS) { return LHS.Elem < RHS; });
    assert(It->Elem == Elem && "Element not in map");
    return *It;
  }

  // Run overlap analysis in every visibility bucket, diagnosing each
  // conflicting pair once. Returns true if any overlap was found.
  bool checkOverlap() {
    // Sort by element pointer so getInfo can binary-search.
    llvm::sort(C&: ElemInfoMap, Comp: [](const auto &LHS, const auto &RHS) {
      return LHS.Elem < RHS.Elem;
    });

    bool HadOverlap = false;

    using llvm::hlsl::BindingInfoBuilder;
    auto ReportOverlap = [this,
                          &HadOverlap](const BindingInfoBuilder &Builder,
                                       const llvm::hlsl::Binding &Reported) {
      HadOverlap = true;

      const auto *Elem =
          static_cast<const hlsl::RootSignatureElement *>(Reported.Cookie);
      const llvm::hlsl::Binding &Previous = Builder.findOverlapping(ReportedBinding: Reported);
      const auto *PrevElem =
          static_cast<const hlsl::RootSignatureElement *>(Previous.Cookie);

      ElemInfo &Info = getInfo(Elem);
      // We will have already diagnosed this binding if there's overlap in the
      // "All" visibility as well as any particular visibility.
      if (Info.Diagnosed)
        return;
      Info.Diagnosed = true;

      // Report the most specific visibility the two bindings share.
      ElemInfo &PrevInfo = getInfo(Elem: PrevElem);
      llvm::dxbc::ShaderVisibility CommonVis =
          Info.Vis == llvm::dxbc::ShaderVisibility::All ? PrevInfo.Vis
                                                        : Info.Vis;

      this->S->Diag(Loc: Elem->getLocation(), DiagID: diag::err_hlsl_resource_range_overlap)
          << llvm::to_underlying(E: Reported.RC) << Reported.LowerBound
          << Reported.isUnbounded() << Reported.UpperBound
          << llvm::to_underlying(E: Previous.RC) << Previous.LowerBound
          << Previous.isUnbounded() << Previous.UpperBound << Reported.Space
          << CommonVis;

      this->S->Diag(Loc: PrevElem->getLocation(),
                    DiagID: diag::note_hlsl_resource_range_here);
    };

    for (BindingInfoBuilder &Builder : Builders)
      Builder.calculateBindingInfo(ReportOverlap);

    return HadOverlap;
  }
};

// Looks up a member function named `Name` on `RecordDecl`; returns null when
// lookup finds nothing.
static CXXMethodDecl *lookupMethod(Sema &S, CXXRecordDecl *RecordDecl,
                                   StringRef Name, SourceLocation Loc) {
  DeclarationName DeclName(&S.getASTContext().Idents.get(Name));
  LookupResult Result(S, DeclName, Loc, Sema::LookupMemberName);
  if (!S.LookupQualifiedName(R&: Result, LookupCtx: static_cast<DeclContext *>(RecordDecl)))
    return nullptr;
  return cast<CXXMethodDecl>(Val: Result.getFoundDecl());
}

} // end anonymous namespace
1456
1457static bool hasCounterHandle(const CXXRecordDecl *RD) {
1458 if (RD->field_empty())
1459 return false;
1460 auto It = std::next(x: RD->field_begin());
1461 if (It == RD->field_end())
1462 return false;
1463 const FieldDecl *SecondField = *It;
1464 if (const auto *ResTy =
1465 SecondField->getType()->getAs<HLSLAttributedResourceType>()) {
1466 return ResTy->getAttrs().IsCounter;
1467 }
1468 return false;
1469}
1470
bool SemaHLSL::handleRootSignatureElements(
    ArrayRef<hlsl::RootSignatureElement> Elements) {
  // Validates the parsed root-signature elements in two passes: basic
  // per-element value/flag checks first, then register-range overlap
  // analysis per shader visibility. Returns true if any error was diagnosed.

  // Define some common error handling functions
  bool HadError = false;
  auto ReportError = [this, &HadError](SourceLocation Loc, uint32_t LowerBound,
                                       uint32_t UpperBound) {
    HadError = true;
    this->Diag(Loc, DiagID: diag::err_hlsl_invalid_rootsig_value)
        << LowerBound << UpperBound;
  };

  auto ReportFloatError = [this, &HadError](SourceLocation Loc,
                                            float LowerBound,
                                            float UpperBound) {
    HadError = true;
    this->Diag(Loc, DiagID: diag::err_hlsl_invalid_rootsig_value)
        << llvm::formatv(Fmt: "{0:f}", Vals&: LowerBound).sstr<6>()
        << llvm::formatv(Fmt: "{0:f}", Vals&: UpperBound).sstr<6>();
  };

  auto VerifyRegister = [ReportError](SourceLocation Loc, uint32_t Register) {
    if (!llvm::hlsl::rootsig::verifyRegisterValue(RegisterValue: Register))
      ReportError(Loc, 0, 0xfffffffe);
  };

  auto VerifySpace = [ReportError](SourceLocation Loc, uint32_t Space) {
    if (!llvm::hlsl::rootsig::verifyRegisterSpace(RegisterSpace: Space))
      ReportError(Loc, 0, 0xffffffef);
  };

  // Flag validity depends on the root-signature version in use.
  const uint32_t Version =
      llvm::to_underlying(E: SemaRef.getLangOpts().HLSLRootSigVer);
  const uint32_t VersionEnum = Version - 1;
  auto ReportFlagError = [this, &HadError, VersionEnum](SourceLocation Loc) {
    HadError = true;
    this->Diag(Loc, DiagID: diag::err_hlsl_invalid_rootsig_flag)
        << /*version minor*/ VersionEnum;
  };

  // Iterate through the elements and do basic validations
  for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
    SourceLocation Loc = RootSigElem.getLocation();
    const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
    if (const auto *Descriptor =
            std::get_if<llvm::hlsl::rootsig::RootDescriptor>(ptr: &Elem)) {
      VerifyRegister(Loc, Descriptor->Reg.Number);
      VerifySpace(Loc, Descriptor->Space);

      if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(Version,
                                                         Flags: Descriptor->Flags))
        ReportFlagError(Loc);
    } else if (const auto *Constants =
                   std::get_if<llvm::hlsl::rootsig::RootConstants>(ptr: &Elem)) {
      VerifyRegister(Loc, Constants->Reg.Number);
      VerifySpace(Loc, Constants->Space);
    } else if (const auto *Sampler =
                   std::get_if<llvm::hlsl::rootsig::StaticSampler>(ptr: &Elem)) {
      VerifyRegister(Loc, Sampler->Reg.Number);
      VerifySpace(Loc, Sampler->Space);

      assert(!std::isnan(Sampler->MaxLOD) && !std::isnan(Sampler->MinLOD) &&
             "By construction, parseFloatParam can't produce a NaN from a "
             "float_literal token");

      if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(MaxAnisotropy: Sampler->MaxAnisotropy))
        ReportError(Loc, 0, 16);
      if (!llvm::hlsl::rootsig::verifyMipLODBias(MipLODBias: Sampler->MipLODBias))
        ReportFloatError(Loc, -16.f, 15.99f);
    } else if (const auto *Clause =
                   std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
                       ptr: &Elem)) {
      VerifyRegister(Loc, Clause->Reg.Number);
      VerifySpace(Loc, Clause->Space);

      if (!llvm::hlsl::rootsig::verifyNumDescriptors(NumDescriptors: Clause->NumDescriptors)) {
        // NumDescriptor could techincally be ~0u but that is reserved for
        // unbounded, so the diagnostic will not report that as a valid int
        // value
        ReportError(Loc, 1, 0xfffffffe);
      }

      if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(Version, Type: Clause->Type,
                                                          Flags: Clause->Flags))
        ReportFlagError(Loc);
    }
  }

  // Second pass: register-range overlap analysis. Descriptor-table clauses
  // are buffered until their enclosing table element (which carries the
  // visibility) is seen.
  PerVisibilityBindingChecker BindingChecker(this);
  SmallVector<std::pair<const llvm::hlsl::rootsig::DescriptorTableClause *,
                        const hlsl::RootSignatureElement *>>
      UnboundClauses;

  for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
    const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
    if (const auto *Descriptor =
            std::get_if<llvm::hlsl::rootsig::RootDescriptor>(ptr: &Elem)) {
      uint32_t LowerBound(Descriptor->Reg.Number);
      uint32_t UpperBound(LowerBound); // inclusive range

      BindingChecker.trackBinding(
          Visibility: Descriptor->Visibility,
          RC: static_cast<llvm::dxil::ResourceClass>(Descriptor->Type),
          Space: Descriptor->Space, LowerBound, UpperBound, Elem: &RootSigElem);
    } else if (const auto *Constants =
                   std::get_if<llvm::hlsl::rootsig::RootConstants>(ptr: &Elem)) {
      uint32_t LowerBound(Constants->Reg.Number);
      uint32_t UpperBound(LowerBound); // inclusive range

      BindingChecker.trackBinding(
          Visibility: Constants->Visibility, RC: llvm::dxil::ResourceClass::CBuffer,
          Space: Constants->Space, LowerBound, UpperBound, Elem: &RootSigElem);
    } else if (const auto *Sampler =
                   std::get_if<llvm::hlsl::rootsig::StaticSampler>(ptr: &Elem)) {
      uint32_t LowerBound(Sampler->Reg.Number);
      uint32_t UpperBound(LowerBound); // inclusive range

      BindingChecker.trackBinding(
          Visibility: Sampler->Visibility, RC: llvm::dxil::ResourceClass::Sampler,
          Space: Sampler->Space, LowerBound, UpperBound, Elem: &RootSigElem);
    } else if (const auto *Clause =
                   std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
                       ptr: &Elem)) {
      // We'll process these once we see the table element.
      UnboundClauses.emplace_back(Args&: Clause, Args: &RootSigElem);
    } else if (const auto *Table =
                   std::get_if<llvm::hlsl::rootsig::DescriptorTable>(ptr: &Elem)) {
      assert(UnboundClauses.size() == Table->NumClauses &&
             "Number of unbound elements must match the number of clauses");
      bool HasAnySampler = false;
      bool HasAnyNonSampler = false;
      uint64_t Offset = 0;
      bool IsPrevUnbound = false;
      for (const auto &[Clause, ClauseElem] : UnboundClauses) {
        SourceLocation Loc = ClauseElem->getLocation();
        // A table may contain either samplers or non-samplers, never both.
        if (Clause->Type == llvm::dxil::ResourceClass::Sampler)
          HasAnySampler = true;
        else
          HasAnyNonSampler = true;

        if (HasAnySampler && HasAnyNonSampler)
          Diag(Loc, DiagID: diag::err_hlsl_invalid_mixed_resources);

        // Relevant error will have already been reported above and needs to be
        // fixed before we can conduct further analysis, so shortcut error
        // return
        if (Clause->NumDescriptors == 0)
          return true;

        // An explicit offset restarts the running table offset; otherwise
        // this clause appends after the previous one.
        bool IsAppending =
            Clause->Offset == llvm::hlsl::rootsig::DescriptorTableOffsetAppend;
        if (!IsAppending)
          Offset = Clause->Offset;

        uint64_t RangeBound = llvm::hlsl::rootsig::computeRangeBound(
            Offset, Size: Clause->NumDescriptors);

        // Appending after an unbounded range is impossible; otherwise check
        // the computed bound did not overflow.
        if (IsPrevUnbound && IsAppending)
          Diag(Loc, DiagID: diag::err_hlsl_appending_onto_unbound);
        else if (!llvm::hlsl::rootsig::verifyNoOverflowedOffset(Offset: RangeBound))
          Diag(Loc, DiagID: diag::err_hlsl_offset_overflow) << Offset << RangeBound;

        // Update offset to be 1 past this range's bound
        Offset = RangeBound + 1;
        IsPrevUnbound = Clause->NumDescriptors ==
                        llvm::hlsl::rootsig::NumDescriptorsUnbounded;

        // Compute the register bounds and track resource binding
        uint32_t LowerBound(Clause->Reg.Number);
        uint32_t UpperBound = llvm::hlsl::rootsig::computeRangeBound(
            Offset: LowerBound, Size: Clause->NumDescriptors);

        BindingChecker.trackBinding(
            Visibility: Table->Visibility,
            RC: static_cast<llvm::dxil::ResourceClass>(Clause->Type), Space: Clause->Space,
            LowerBound, UpperBound, Elem: ClauseElem);
      }
      UnboundClauses.clear();
    }
  }

  return BindingChecker.checkOverlap();
}
1653
1654void SemaHLSL::handleRootSignatureAttr(Decl *D, const ParsedAttr &AL) {
1655 if (AL.getNumArgs() != 1) {
1656 Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_wrong_number_arguments) << AL << 1;
1657 return;
1658 }
1659
1660 IdentifierInfo *Ident = AL.getArgAsIdent(Arg: 0)->getIdentifierInfo();
1661 if (auto *RS = D->getAttr<RootSignatureAttr>()) {
1662 if (RS->getSignatureIdent() != Ident) {
1663 Diag(Loc: AL.getLoc(), DiagID: diag::err_disallowed_duplicate_attribute) << RS;
1664 return;
1665 }
1666
1667 Diag(Loc: AL.getLoc(), DiagID: diag::warn_duplicate_attribute_exact) << RS;
1668 return;
1669 }
1670
1671 LookupResult R(SemaRef, Ident, SourceLocation(), Sema::LookupOrdinaryName);
1672 if (SemaRef.LookupQualifiedName(R, LookupCtx: D->getDeclContext()))
1673 if (auto *SignatureDecl =
1674 dyn_cast<HLSLRootSignatureDecl>(Val: R.getFoundDecl())) {
1675 D->addAttr(A: ::new (getASTContext()) RootSignatureAttr(
1676 getASTContext(), AL, Ident, SignatureDecl));
1677 }
1678}
1679
1680void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) {
1681 llvm::VersionTuple SMVersion =
1682 getASTContext().getTargetInfo().getTriple().getOSVersion();
1683 bool IsDXIL = getASTContext().getTargetInfo().getTriple().getArch() ==
1684 llvm::Triple::dxil;
1685
1686 uint32_t ZMax = 1024;
1687 uint32_t ThreadMax = 1024;
1688 if (IsDXIL && SMVersion.getMajor() <= 4) {
1689 ZMax = 1;
1690 ThreadMax = 768;
1691 } else if (IsDXIL && SMVersion.getMajor() == 5) {
1692 ZMax = 64;
1693 ThreadMax = 1024;
1694 }
1695
1696 uint32_t X;
1697 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: X))
1698 return;
1699 if (X > 1024) {
1700 Diag(Loc: AL.getArgAsExpr(Arg: 0)->getExprLoc(),
1701 DiagID: diag::err_hlsl_numthreads_argument_oor)
1702 << 0 << 1024;
1703 return;
1704 }
1705 uint32_t Y;
1706 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Y))
1707 return;
1708 if (Y > 1024) {
1709 Diag(Loc: AL.getArgAsExpr(Arg: 1)->getExprLoc(),
1710 DiagID: diag::err_hlsl_numthreads_argument_oor)
1711 << 1 << 1024;
1712 return;
1713 }
1714 uint32_t Z;
1715 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 2), Val&: Z))
1716 return;
1717 if (Z > ZMax) {
1718 SemaRef.Diag(Loc: AL.getArgAsExpr(Arg: 2)->getExprLoc(),
1719 DiagID: diag::err_hlsl_numthreads_argument_oor)
1720 << 2 << ZMax;
1721 return;
1722 }
1723
1724 if (X * Y * Z > ThreadMax) {
1725 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_numthreads_invalid) << ThreadMax;
1726 return;
1727 }
1728
1729 HLSLNumThreadsAttr *NewAttr = mergeNumThreadsAttr(D, AL, X, Y, Z);
1730 if (NewAttr)
1731 D->addAttr(A: NewAttr);
1732}
1733
1734static bool isValidWaveSizeValue(unsigned Value) {
1735 return llvm::isPowerOf2_32(Value) && Value >= 4 && Value <= 128;
1736}
1737
void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) {
  // validate that the wavesize argument is a power of 2 between 4 and 128
  // inclusive
  unsigned SpelledArgsCount = AL.getNumArgs();
  if (SpelledArgsCount == 0 || SpelledArgsCount > 3)
    return;

  // First argument: required minimum wave size.
  uint32_t Min;
  if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: Min))
    return;

  // Second argument: optional maximum wave size.
  uint32_t Max = 0;
  if (SpelledArgsCount > 1 &&
      !SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Max))
    return;

  // Third argument: optional preferred wave size.
  uint32_t Preferred = 0;
  if (SpelledArgsCount > 2 &&
      !SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 2), Val&: Preferred))
    return;

  // Three-argument form: Preferred must itself be a valid wave size and
  // must lie within [Min, Max].
  // NOTE(review): in this form Min and Max are not individually validated
  // here -- presumably mergeWaveSizeAttr (or later checks) covers them;
  // confirm before relying on this.
  if (SpelledArgsCount > 2) {
    if (!isValidWaveSizeValue(Value: Preferred)) {
      Diag(Loc: AL.getArgAsExpr(Arg: 2)->getExprLoc(),
           DiagID: diag::err_attribute_power_of_two_in_range)
          << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize
          << Preferred;
      return;
    }
    // Preferred not in range.
    if (Preferred < Min || Preferred > Max) {
      Diag(Loc: AL.getArgAsExpr(Arg: 2)->getExprLoc(),
           DiagID: diag::err_attribute_power_of_two_in_range)
          << AL << Min << Max << Preferred;
      return;
    }
  } else if (SpelledArgsCount > 1) {
    // Two-argument form: Max must be a valid wave size; Max < Min is an
    // error and Max == Min only warns (the attribute is still attached).
    if (!isValidWaveSizeValue(Value: Max)) {
      Diag(Loc: AL.getArgAsExpr(Arg: 1)->getExprLoc(),
           DiagID: diag::err_attribute_power_of_two_in_range)
          << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Max;
      return;
    }
    if (Max < Min) {
      Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_invalid) << AL << 1;
      return;
    } else if (Max == Min) {
      Diag(Loc: AL.getLoc(), DiagID: diag::warn_attr_min_eq_max) << AL;
    }
  } else {
    // Single-argument form: only Min is present and must be valid.
    if (!isValidWaveSizeValue(Value: Min)) {
      Diag(Loc: AL.getArgAsExpr(Arg: 0)->getExprLoc(),
           DiagID: diag::err_attribute_power_of_two_in_range)
          << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Min;
      return;
    }
  }

  // Merging diagnoses conflicts with a previously attached wave-size
  // attribute and may return null.
  HLSLWaveSizeAttr *NewAttr =
      mergeWaveSizeAttr(D, AL, Min, Max, Preferred, SpelledArgsCount);
  if (NewAttr)
    D->addAttr(A: NewAttr);
}
1801
1802void SemaHLSL::handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL) {
1803 uint32_t ID;
1804 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: ID))
1805 return;
1806 D->addAttr(A: ::new (getASTContext())
1807 HLSLVkExtBuiltinInputAttr(getASTContext(), AL, ID));
1808}
1809
1810void SemaHLSL::handleVkPushConstantAttr(Decl *D, const ParsedAttr &AL) {
1811 D->addAttr(A: ::new (getASTContext())
1812 HLSLVkPushConstantAttr(getASTContext(), AL));
1813}
1814
1815void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) {
1816 uint32_t Id;
1817 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: Id))
1818 return;
1819 HLSLVkConstantIdAttr *NewAttr = mergeVkConstantIdAttr(D, AL, Id);
1820 if (NewAttr)
1821 D->addAttr(A: NewAttr);
1822}
1823
1824void SemaHLSL::handleVkBindingAttr(Decl *D, const ParsedAttr &AL) {
1825 uint32_t Binding = 0;
1826 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: Binding))
1827 return;
1828 uint32_t Set = 0;
1829 if (AL.getNumArgs() > 1 &&
1830 !SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Set))
1831 return;
1832
1833 D->addAttr(A: ::new (getASTContext())
1834 HLSLVkBindingAttr(getASTContext(), AL, Binding, Set));
1835}
1836
1837void SemaHLSL::handleVkLocationAttr(Decl *D, const ParsedAttr &AL) {
1838 uint32_t Location;
1839 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: Location))
1840 return;
1841
1842 D->addAttr(A: ::new (getASTContext())
1843 HLSLVkLocationAttr(getASTContext(), AL, Location));
1844}
1845
1846bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) {
1847 const auto *VT = T->getAs<VectorType>();
1848
1849 if (!T->hasUnsignedIntegerRepresentation() ||
1850 (VT && VT->getNumElements() > 3)) {
1851 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type)
1852 << AL << "uint/uint2/uint3";
1853 return false;
1854 }
1855
1856 return true;
1857}
1858
1859bool SemaHLSL::diagnosePositionType(QualType T, const ParsedAttr &AL) {
1860 const auto *VT = T->getAs<VectorType>();
1861 if (!T->hasFloatingRepresentation() || (VT && VT->getNumElements() > 4)) {
1862 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type)
1863 << AL << "float/float1/float2/float3/float4";
1864 return false;
1865 }
1866
1867 return true;
1868}
1869
1870void SemaHLSL::diagnoseSystemSemanticAttr(Decl *D, const ParsedAttr &AL,
1871 std::optional<unsigned> Index) {
1872 std::string SemanticName = AL.getAttrName()->getName().upper();
1873
1874 auto *VD = cast<ValueDecl>(Val: D);
1875 QualType ValueType = VD->getType();
1876 if (auto *FD = dyn_cast<FunctionDecl>(Val: D))
1877 ValueType = FD->getReturnType();
1878
1879 bool IsOutput = false;
1880 if (HLSLParamModifierAttr *MA = D->getAttr<HLSLParamModifierAttr>()) {
1881 if (MA->isOut()) {
1882 IsOutput = true;
1883 ValueType = cast<ReferenceType>(Val&: ValueType)->getPointeeType();
1884 }
1885 }
1886
1887 if (SemanticName == "SV_DISPATCHTHREADID") {
1888 diagnoseInputIDType(T: ValueType, AL);
1889 if (IsOutput)
1890 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_output_not_supported) << AL;
1891 if (Index.has_value())
1892 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_indexing_not_supported) << AL;
1893 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1894 return;
1895 }
1896
1897 if (SemanticName == "SV_GROUPINDEX") {
1898 if (IsOutput)
1899 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_output_not_supported) << AL;
1900 if (Index.has_value())
1901 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_indexing_not_supported) << AL;
1902 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1903 return;
1904 }
1905
1906 if (SemanticName == "SV_GROUPTHREADID") {
1907 diagnoseInputIDType(T: ValueType, AL);
1908 if (IsOutput)
1909 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_output_not_supported) << AL;
1910 if (Index.has_value())
1911 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_indexing_not_supported) << AL;
1912 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1913 return;
1914 }
1915
1916 if (SemanticName == "SV_GROUPID") {
1917 diagnoseInputIDType(T: ValueType, AL);
1918 if (IsOutput)
1919 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_output_not_supported) << AL;
1920 if (Index.has_value())
1921 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_semantic_indexing_not_supported) << AL;
1922 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1923 return;
1924 }
1925
1926 if (SemanticName == "SV_POSITION") {
1927 const auto *VT = ValueType->getAs<VectorType>();
1928 if (!ValueType->hasFloatingRepresentation() ||
1929 (VT && VT->getNumElements() > 4))
1930 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type)
1931 << AL << "float/float1/float2/float3/float4";
1932 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1933 return;
1934 }
1935
1936 if (SemanticName == "SV_VERTEXID") {
1937 uint64_t SizeInBits = SemaRef.Context.getTypeSize(T: ValueType);
1938 if (!ValueType->isUnsignedIntegerType() || SizeInBits != 32)
1939 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type) << AL << "uint";
1940 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1941 return;
1942 }
1943
1944 if (SemanticName == "SV_TARGET") {
1945 const auto *VT = ValueType->getAs<VectorType>();
1946 if (!ValueType->hasFloatingRepresentation() ||
1947 (VT && VT->getNumElements() > 4))
1948 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type)
1949 << AL << "float/float1/float2/float3/float4";
1950 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1951 return;
1952 }
1953
1954 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_unknown_semantic) << AL;
1955}
1956
1957void SemaHLSL::handleSemanticAttr(Decl *D, const ParsedAttr &AL) {
1958 uint32_t IndexValue(0), ExplicitIndex(0);
1959 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: IndexValue) ||
1960 !SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: ExplicitIndex)) {
1961 assert(0 && "HLSLUnparsedSemantic is expected to have 2 int arguments.");
1962 }
1963 assert(IndexValue > 0 ? ExplicitIndex : true);
1964 std::optional<unsigned> Index =
1965 ExplicitIndex ? std::optional<unsigned>(IndexValue) : std::nullopt;
1966
1967 if (AL.getAttrName()->getName().starts_with_insensitive(Prefix: "SV_"))
1968 diagnoseSystemSemanticAttr(D, AL, Index);
1969 else
1970 D->addAttr(A: createSemanticAttr<HLSLParsedSemanticAttr>(ACI: AL, Location: Index));
1971}
1972
1973void SemaHLSL::handlePackOffsetAttr(Decl *D, const ParsedAttr &AL) {
1974 if (!isa<VarDecl>(Val: D) || !isa<HLSLBufferDecl>(Val: D->getDeclContext())) {
1975 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_ast_node)
1976 << AL << "shader constant in a constant buffer";
1977 return;
1978 }
1979
1980 uint32_t SubComponent;
1981 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: SubComponent))
1982 return;
1983 uint32_t Component;
1984 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Component))
1985 return;
1986
1987 QualType T = cast<VarDecl>(Val: D)->getType().getCanonicalType();
1988 // Check if T is an array or struct type.
1989 // TODO: mark matrix type as aggregate type.
1990 bool IsAggregateTy = (T->isArrayType() || T->isStructureType());
1991
1992 // Check Component is valid for T.
1993 if (Component) {
1994 unsigned Size = getASTContext().getTypeSize(T);
1995 if (IsAggregateTy) {
1996 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_invalid_register_or_packoffset);
1997 return;
1998 } else {
1999 // Make sure Component + sizeof(T) <= 4.
2000 if ((Component * 32 + Size) > 128) {
2001 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_packoffset_cross_reg_boundary);
2002 return;
2003 }
2004 QualType EltTy = T;
2005 if (const auto *VT = T->getAs<VectorType>())
2006 EltTy = VT->getElementType();
2007 unsigned Align = getASTContext().getTypeAlign(T: EltTy);
2008 if (Align > 32 && Component == 1) {
2009 // NOTE: Component 3 will hit err_hlsl_packoffset_cross_reg_boundary.
2010 // So we only need to check Component 1 here.
2011 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_packoffset_alignment_mismatch)
2012 << Align << EltTy;
2013 return;
2014 }
2015 }
2016 }
2017
2018 D->addAttr(A: ::new (getASTContext()) HLSLPackOffsetAttr(
2019 getASTContext(), AL, SubComponent, Component));
2020}
2021
2022void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) {
2023 StringRef Str;
2024 SourceLocation ArgLoc;
2025 if (!SemaRef.checkStringLiteralArgumentAttr(Attr: AL, ArgNum: 0, Str, ArgLocation: &ArgLoc))
2026 return;
2027
2028 llvm::Triple::EnvironmentType ShaderType;
2029 if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Val: Str, Out&: ShaderType)) {
2030 Diag(Loc: AL.getLoc(), DiagID: diag::warn_attribute_type_not_supported)
2031 << AL << Str << ArgLoc;
2032 return;
2033 }
2034
2035 // FIXME: check function match the shader stage.
2036
2037 HLSLShaderAttr *NewAttr = mergeShaderAttr(D, AL, ShaderType);
2038 if (NewAttr)
2039 D->addAttr(A: NewAttr);
2040}
2041
// Combines the HLSL resource attributes in AttrList into a single
// HLSLAttributedResourceType wrapping `Wrapped`, storing the result in
// ResType. Diagnoses duplicate attributes and a missing resource class.
// When LocInfo is provided and a contained type was seen, fills in the
// source range and contained-type location info. Returns true on success.
bool clang::CreateHLSLAttributedResourceType(
    Sema &S, QualType Wrapped, ArrayRef<const Attr *> AttrList,
    QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo) {
  assert(AttrList.size() && "expected list of resource attributes");

  QualType ContainedTy = QualType();
  TypeSourceInfo *ContainedTyInfo = nullptr;
  // The reported range spans from the first attribute to the last non-null
  // one (LocEnd is updated inside the loop).
  SourceLocation LocBegin = AttrList[0]->getRange().getBegin();
  SourceLocation LocEnd = AttrList[0]->getRange().getEnd();

  HLSLAttributedResourceType::Attributes ResAttrs;

  bool HasResourceClass = false;
  bool HasResourceDimension = false;
  for (const Attr *A : AttrList) {
    if (!A)
      continue;
    LocEnd = A->getRange().getEnd();
    switch (A->getKind()) {
    case attr::HLSLResourceClass: {
      ResourceClass RC = cast<HLSLResourceClassAttr>(Val: A)->getResourceClass();
      // A repeated resource class warns (exact-duplicate wording when the
      // value matches) and fails the whole type construction.
      if (HasResourceClass) {
        S.Diag(Loc: A->getLocation(), DiagID: ResAttrs.ResourceClass == RC
                                   ? diag::warn_duplicate_attribute_exact
                                   : diag::warn_duplicate_attribute)
            << A;
        return false;
      }
      ResAttrs.ResourceClass = RC;
      HasResourceClass = true;
      break;
    }
    case attr::HLSLResourceDimension: {
      llvm::dxil::ResourceDimension RD =
          cast<HLSLResourceDimensionAttr>(Val: A)->getDimension();
      if (HasResourceDimension) {
        S.Diag(Loc: A->getLocation(), DiagID: ResAttrs.ResourceDimension == RD
                                   ? diag::warn_duplicate_attribute_exact
                                   : diag::warn_duplicate_attribute)
            << A;
        return false;
      }
      ResAttrs.ResourceDimension = RD;
      HasResourceDimension = true;
      break;
    }
    // The three flag attributes below are valueless, so any repetition is
    // an exact duplicate.
    case attr::HLSLROV:
      if (ResAttrs.IsROV) {
        S.Diag(Loc: A->getLocation(), DiagID: diag::warn_duplicate_attribute_exact) << A;
        return false;
      }
      ResAttrs.IsROV = true;
      break;
    case attr::HLSLRawBuffer:
      if (ResAttrs.RawBuffer) {
        S.Diag(Loc: A->getLocation(), DiagID: diag::warn_duplicate_attribute_exact) << A;
        return false;
      }
      ResAttrs.RawBuffer = true;
      break;
    case attr::HLSLIsCounter:
      if (ResAttrs.IsCounter) {
        S.Diag(Loc: A->getLocation(), DiagID: diag::warn_duplicate_attribute_exact) << A;
        return false;
      }
      ResAttrs.IsCounter = true;
      break;
    case attr::HLSLContainedType: {
      const HLSLContainedTypeAttr *CTAttr = cast<HLSLContainedTypeAttr>(Val: A);
      QualType Ty = CTAttr->getType();
      if (!ContainedTy.isNull()) {
        S.Diag(Loc: A->getLocation(), DiagID: ContainedTy == Ty
                                   ? diag::warn_duplicate_attribute_exact
                                   : diag::warn_duplicate_attribute)
            << A;
        return false;
      }
      ContainedTy = Ty;
      ContainedTyInfo = CTAttr->getTypeLoc();
      break;
    }
    default:
      llvm_unreachable("unhandled resource attribute type");
    }
  }

  // A resource class is mandatory; everything else is optional.
  if (!HasResourceClass) {
    S.Diag(Loc: AttrList.back()->getRange().getEnd(),
           DiagID: diag::err_hlsl_missing_resource_class);
    return false;
  }

  ResType = S.getASTContext().getHLSLAttributedResourceType(
      Wrapped, Contained: ContainedTy, Attrs: ResAttrs);

  if (LocInfo && ContainedTyInfo) {
    LocInfo->Range = SourceRange(LocBegin, LocEnd);
    LocInfo->ContainedTyInfo = ContainedTyInfo;
  }
  return true;
}
2143
2144// Validates and creates an HLSL attribute that is applied as type attribute on
2145// HLSL resource. The attributes are collected in HLSLResourcesTypeAttrs and at
2146// the end of the declaration they are applied to the declaration type by
2147// wrapping it in HLSLAttributedResourceType.
bool SemaHLSL::handleResourceTypeAttr(QualType T, const ParsedAttr &AL) {
  // only allow resource type attributes on intangible types
  if (!T->isHLSLResourceType()) {
    Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attribute_needs_intangible_type)
        << AL << getASTContext().HLSLResourceTy;
    return false;
  }

  // validate number of arguments
  if (!AL.checkExactlyNumArgs(S&: SemaRef, Num: AL.getMinArgs()))
    return false;

  Attr *A = nullptr;

  // Common info shared by whichever semantic attribute is created below;
  // these attributes have no dedicated Sema handler of their own.
  AttributeCommonInfo ACI(
      AL.getLoc(), AttributeScopeInfo(AL.getScopeName(), AL.getScopeLoc()),
      AttributeCommonInfo::NoSemaHandlerAttribute,
      {
          AttributeCommonInfo::AS_CXX11, 0, false /*IsAlignas*/,
          false /*IsRegularKeywordAttribute*/
      });

  switch (AL.getKind()) {
  case ParsedAttr::AT_HLSLResourceClass: {
    // The argument must be an identifier naming a known resource class
    // (SRV, UAV, CBuffer, Sampler).
    if (!AL.isArgIdent(Arg: 0)) {
      Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_type)
          << AL << AANT_ArgumentIdentifier;
      return false;
    }

    IdentifierLoc *Loc = AL.getArgAsIdent(Arg: 0);
    StringRef Identifier = Loc->getIdentifierInfo()->getName();
    SourceLocation ArgLoc = Loc->getLoc();

    // Validate resource class value
    ResourceClass RC;
    if (!HLSLResourceClassAttr::ConvertStrToResourceClass(Val: Identifier, Out&: RC)) {
      Diag(Loc: ArgLoc, DiagID: diag::warn_attribute_type_not_supported)
          << "ResourceClass" << Identifier;
      return false;
    }
    A = HLSLResourceClassAttr::Create(Ctx&: getASTContext(), ResourceClass: RC, CommonInfo: ACI);
    break;
  }

  case ParsedAttr::AT_HLSLResourceDimension: {
    // The dimension is spelled as a string literal argument.
    StringRef Identifier;
    SourceLocation ArgLoc;
    if (!SemaRef.checkStringLiteralArgumentAttr(Attr: AL, ArgNum: 0, Str&: Identifier, ArgLocation: &ArgLoc))
      return false;

    // Validate resource dimension value
    llvm::dxil::ResourceDimension RD;
    if (!HLSLResourceDimensionAttr::ConvertStrToResourceDimension(Val: Identifier,
                                                                  Out&: RD)) {
      Diag(Loc: ArgLoc, DiagID: diag::warn_attribute_type_not_supported)
          << "ResourceDimension" << Identifier;
      return false;
    }
    A = HLSLResourceDimensionAttr::Create(Ctx&: getASTContext(), Dimension: RD, CommonInfo: ACI);
    break;
  }

  // The following three are valueless flag attributes.
  case ParsedAttr::AT_HLSLROV:
    A = HLSLROVAttr::Create(Ctx&: getASTContext(), CommonInfo: ACI);
    break;

  case ParsedAttr::AT_HLSLRawBuffer:
    A = HLSLRawBufferAttr::Create(Ctx&: getASTContext(), CommonInfo: ACI);
    break;

  case ParsedAttr::AT_HLSLIsCounter:
    A = HLSLIsCounterAttr::Create(Ctx&: getASTContext(), CommonInfo: ACI);
    break;

  case ParsedAttr::AT_HLSLContainedType: {
    // The argument is a type; it must be complete to be usable as the
    // resource's contained element type.
    if (AL.getNumArgs() != 1 && !AL.hasParsedType()) {
      Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_wrong_number_arguments) << AL << 1;
      return false;
    }

    TypeSourceInfo *TSI = nullptr;
    QualType QT = SemaRef.GetTypeFromParser(Ty: AL.getTypeArg(), TInfo: &TSI);
    assert(TSI && "no type source info for attribute argument");
    if (SemaRef.RequireCompleteType(Loc: TSI->getTypeLoc().getBeginLoc(), T: QT,
                                    DiagID: diag::err_incomplete_type))
      return false;
    A = HLSLContainedTypeAttr::Create(Ctx&: getASTContext(), Type: TSI, CommonInfo: ACI);
    break;
  }

  default:
    llvm_unreachable("unhandled HLSL attribute");
  }

  // Collected attributes are combined into one HLSLAttributedResourceType
  // later, in ProcessResourceTypeAttributes.
  HLSLResourcesTypeAttrs.emplace_back(Args&: A);
  return true;
}
2246
2247// Combines all resource type attributes and creates HLSLAttributedResourceType.
2248QualType SemaHLSL::ProcessResourceTypeAttributes(QualType CurrentType) {
2249 if (!HLSLResourcesTypeAttrs.size())
2250 return CurrentType;
2251
2252 QualType QT = CurrentType;
2253 HLSLAttributedResourceLocInfo LocInfo;
2254 if (CreateHLSLAttributedResourceType(S&: SemaRef, Wrapped: CurrentType,
2255 AttrList: HLSLResourcesTypeAttrs, ResType&: QT, LocInfo: &LocInfo)) {
2256 const HLSLAttributedResourceType *RT =
2257 cast<HLSLAttributedResourceType>(Val: QT.getTypePtr());
2258
2259 // Temporarily store TypeLoc information for the new type.
2260 // It will be transferred to HLSLAttributesResourceTypeLoc
2261 // shortly after the type is created by TypeSpecLocFiller which
2262 // will call the TakeLocForHLSLAttribute method below.
2263 LocsForHLSLAttributedResources.insert(KV: std::pair(RT, LocInfo));
2264 }
2265 HLSLResourcesTypeAttrs.clear();
2266 return QT;
2267}
2268
2269// Returns source location for the HLSLAttributedResourceType
2270HLSLAttributedResourceLocInfo
2271SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
2272 HLSLAttributedResourceLocInfo LocInfo = {};
2273 auto I = LocsForHLSLAttributedResources.find(Val: RT);
2274 if (I != LocsForHLSLAttributedResources.end()) {
2275 LocInfo = I->second;
2276 LocsForHLSLAttributedResources.erase(I);
2277 return LocInfo;
2278 }
2279 LocInfo.Range = SourceRange();
2280 return LocInfo;
2281}
2282
2283// Walks though the global variable declaration, collects all resource binding
2284// requirements and adds them to Bindings
2285void SemaHLSL::collectResourceBindingsOnUserRecordDecl(const VarDecl *VD,
2286 const RecordType *RT) {
2287 const RecordDecl *RD = RT->getDecl()->getDefinitionOrSelf();
2288 for (FieldDecl *FD : RD->fields()) {
2289 const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
2290
2291 // Unwrap arrays
2292 // FIXME: Calculate array size while unwrapping
2293 assert(!Ty->isIncompleteArrayType() &&
2294 "incomplete arrays inside user defined types are not supported");
2295 while (Ty->isConstantArrayType()) {
2296 const ConstantArrayType *CAT = cast<ConstantArrayType>(Val: Ty);
2297 Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
2298 }
2299
2300 if (!Ty->isRecordType())
2301 continue;
2302
2303 if (const HLSLAttributedResourceType *AttrResType =
2304 HLSLAttributedResourceType::findHandleTypeOnResource(RT: Ty)) {
2305 // Add a new DeclBindingInfo to Bindings if it does not already exist
2306 ResourceClass RC = AttrResType->getAttrs().ResourceClass;
2307 DeclBindingInfo *DBI = Bindings.getDeclBindingInfo(VD, ResClass: RC);
2308 if (!DBI)
2309 Bindings.addDeclBindingInfo(VD, ResClass: RC);
2310 } else if (const RecordType *RT = dyn_cast<RecordType>(Val: Ty)) {
2311 // Recursively scan embedded struct or class; it would be nice to do this
2312 // without recursion, but tricky to correctly calculate the size of the
2313 // binding, which is something we are probably going to need to do later
2314 // on. Hopefully nesting of structs in structs too many levels is
2315 // unlikely.
2316 collectResourceBindingsOnUserRecordDecl(VD, RT);
2317 }
2318 }
2319}
2320
2321// Diagnose localized register binding errors for a single binding; does not
2322// diagnose resource binding on user record types, that will be done later
2323// in processResourceBindingOnDecl based on the information collected in
2324// collectResourceBindingsOnVarDecl.
2325// Returns false if the register binding is not valid.
static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
                                         Decl *D, RegisterType RegType,
                                         bool SpecifiedSpace) {
  // The diagnostics index their message text by the register type number.
  int RegTypeNum = static_cast<int>(RegType);

  // check if the decl type is groupshared
  if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
    S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
    return false;
  }

  // Cbuffers and Tbuffers are HLSLBufferDecl types
  if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(Val: D)) {
    // cbuffer binds with 'b' registers, tbuffer with 't' (SRV) registers.
    ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer
                                                     : ResourceClass::SRV;
    if (RegType == getRegisterType(RC))
      return true;

    S.Diag(Loc: D->getLocation(), DiagID: diag::err_hlsl_binding_type_mismatch)
        << RegTypeNum;
    return false;
  }

  // Samplers, UAVs, and SRVs are VarDecl types
  assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
  VarDecl *VD = cast<VarDecl>(Val: D);

  // Resource
  if (const HLSLAttributedResourceType *AttrResType =
          HLSLAttributedResourceType::findHandleTypeOnResource(
              RT: VD->getType().getTypePtr())) {
    // The register letter must match the resource's class.
    if (RegType == getRegisterType(ResTy: AttrResType))
      return true;

    S.Diag(Loc: D->getLocation(), DiagID: diag::err_hlsl_binding_type_mismatch)
        << RegTypeNum;
    return false;
  }

  // Arrays of basic types bind like their element type.
  const clang::Type *Ty = VD->getType().getTypePtr();
  while (Ty->isArrayType())
    Ty = Ty->getArrayElementTypeNoTypeQual();

  // Basic types
  if (Ty->isArithmeticType() || Ty->isVectorType()) {
    bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(Val: D->getDeclContext());
    // An explicit space is only meaningful inside a cbuffer/tbuffer.
    if (SpecifiedSpace && !DeclaredInCOrTBuffer)
      S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_space_on_global_constant);

    if (!DeclaredInCOrTBuffer && (Ty->isIntegralType(Ctx: S.getASTContext()) ||
                                  Ty->isFloatingType() || Ty->isVectorType())) {
      // Register annotation on default constant buffer declaration ($Globals)
      // Only 'c' is valid here; 'b' is deprecated, anything else mismatches.
      if (RegType == RegisterType::CBuffer)
        S.Diag(Loc: ArgLoc, DiagID: diag::warn_hlsl_deprecated_register_type_b);
      else if (RegType != RegisterType::C)
        S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
      else
        return true;
    } else {
      // Inside a cbuffer/tbuffer 'c' registers are packoffset territory.
      if (RegType == RegisterType::C)
        S.Diag(Loc: ArgLoc, DiagID: diag::warn_hlsl_register_type_c_packoffset);
      else
        S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
    }
    return false;
  }
  if (Ty->isRecordType())
    // RecordTypes will be diagnosed in processResourceBindingOnDecl
    // that is called from ActOnVariableDeclarator
    return true;

  // Anything else is an error
  S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
  return false;
}
2401
2402static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
2403 RegisterType regType) {
2404 // make sure that there are no two register annotations
2405 // applied to the decl with the same register type
2406 bool RegisterTypesDetected[5] = {false};
2407 RegisterTypesDetected[static_cast<int>(regType)] = true;
2408
2409 for (auto it = TheDecl->attr_begin(); it != TheDecl->attr_end(); ++it) {
2410 if (HLSLResourceBindingAttr *attr =
2411 dyn_cast<HLSLResourceBindingAttr>(Val: *it)) {
2412
2413 RegisterType otherRegType = attr->getRegisterType();
2414 if (RegisterTypesDetected[static_cast<int>(otherRegType)]) {
2415 int otherRegTypeNum = static_cast<int>(otherRegType);
2416 S.Diag(Loc: TheDecl->getLocation(),
2417 DiagID: diag::err_hlsl_duplicate_register_annotation)
2418 << otherRegTypeNum;
2419 return false;
2420 }
2421 RegisterTypesDetected[static_cast<int>(otherRegType)] = true;
2422 }
2423 }
2424 return true;
2425}
2426
2427static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
2428 Decl *D, RegisterType RegType,
2429 bool SpecifiedSpace) {
2430
2431 // exactly one of these two types should be set
2432 assert(((isa<VarDecl>(D) && !isa<HLSLBufferDecl>(D)) ||
2433 (!isa<VarDecl>(D) && isa<HLSLBufferDecl>(D))) &&
2434 "expecting VarDecl or HLSLBufferDecl");
2435
2436 // check if the declaration contains resource matching the register type
2437 if (!DiagnoseLocalRegisterBinding(S, ArgLoc, D, RegType, SpecifiedSpace))
2438 return false;
2439
2440 // next, if multiple register annotations exist, check that none conflict.
2441 return ValidateMultipleRegisterAnnotations(S, TheDecl: D, regType: RegType);
2442}
2443
// Recursively walk the type `Ty` and advance `StartSlot` past every register
// slot consumed by resources whose class matches `ResClass`.
//
// `ArrayCount` is the number of instances of `Ty` being bound — the product
// of the extents of all enclosing constant arrays. A resource leaf consumes
// `ArrayCount` consecutive slots starting at `StartSlot`.
//
// return false if the slot count exceeds the limit, true otherwise
static bool AccumulateHLSLResourceSlots(QualType Ty, uint64_t &StartSlot,
                                        const uint64_t &Limit,
                                        const ResourceClass ResClass,
                                        ASTContext &Ctx,
                                        uint64_t ArrayCount = 1) {
  // Strip sugar so the dyn_casts below see the canonical type node.
  Ty = Ty.getCanonicalType();
  const Type *T = Ty.getTypePtr();

  // Early exit if already overflowed
  if (StartSlot > Limit)
    return false;

  // Case 1: array type — multiply the instance count by the array extent and
  // recurse on the element type. A non-constant array contributes a count
  // of 1 (its element is still validated once).
  if (const auto *AT = dyn_cast<ArrayType>(Val: T)) {
    uint64_t Count = 1;

    if (const auto *CAT = dyn_cast<ConstantArrayType>(Val: AT))
      Count = CAT->getSize().getZExtValue();

    QualType ElemTy = AT->getElementType();
    return AccumulateHLSLResourceSlots(Ty: ElemTy, StartSlot, Limit, ResClass, Ctx,
                                       ArrayCount: ArrayCount * Count);
  }

  // Case 2: resource leaf
  if (auto ResTy = dyn_cast<HLSLAttributedResourceType>(Val: T)) {
    // First ensure this resource counts towards the corresponding
    // register type limit.
    if (ResTy->getAttrs().ResourceClass != ResClass)
      return true;

    // Validate highest slot used.
    // NOTE(review): StartSlot + ArrayCount could wrap around uint64_t for an
    // enormous combined array extent — presumably callers bound the extents
    // earlier; confirm.
    uint64_t EndSlot = StartSlot + ArrayCount - 1;
    if (EndSlot > Limit)
      return false;

    // Advance SlotCount past the consumed range
    StartSlot = EndSlot + 1;
    return true;
  }

  // Case 3: struct / record — recurse into all bases and fields, keeping the
  // same ArrayCount (each instance of the struct carries its own copy of
  // every contained resource).
  if (const auto *RT = dyn_cast<RecordType>(Val: T)) {
    const RecordDecl *RD = RT->getDecl();

    if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(Val: RD)) {
      for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
        if (!AccumulateHLSLResourceSlots(Ty: Base.getType(), StartSlot, Limit,
                                         ResClass, Ctx, ArrayCount))
          return false;
      }
    }

    for (const FieldDecl *Field : RD->fields()) {
      if (!AccumulateHLSLResourceSlots(Ty: Field->getType(), StartSlot, Limit,
                                       ResClass, Ctx, ArrayCount))
        return false;
    }

    return true;
  }

  // Case 4: everything else — non-resource types consume no slots.
  return true;
}
2510
// Validate that a register(...) slot number, together with the full extent of
// any resource array it binds, stays within the 32-bit register-number range.
//
// return true if there is something invalid, false otherwise
static bool ValidateRegisterNumber(uint64_t SlotNum, Decl *TheDecl,
                                   ASTContext &Ctx, RegisterType RegTy) {
  const uint64_t Limit = UINT32_MAX;
  if (SlotNum > Limit)
    return true;

  // after verifying the number doesn't exceed uint32max, we don't need
  // to look further into c or i register types
  if (RegTy == RegisterType::C || RegTy == RegisterType::I)
    return false;

  if (VarDecl *VD = dyn_cast<VarDecl>(Val: TheDecl)) {
    uint64_t BaseSlot = SlotNum;

    // Walks the declared type; returns false if any contained resource of
    // this register class would occupy a slot beyond Limit.
    if (!AccumulateHLSLResourceSlots(Ty: VD->getType(), StartSlot&: SlotNum, Limit,
                                     ResClass: getResourceClass(RT: RegTy), Ctx))
      return true;

    // After AccumulateHLSLResourceSlots runs, SlotNum is now
    // the first free slot; last used was SlotNum - 1
    // NOTE(review): BaseSlot was already checked against Limit at the top of
    // this function, so this condition is always false — presumably kept for
    // symmetry/defensiveness; confirm intent.
    return (BaseSlot > Limit);
  }
  // handle the cbuffer/tbuffer case
  if (isa<HLSLBufferDecl>(Val: TheDecl))
    // resources cannot be put within a cbuffer, so no need
    // to analyze the structure since the register number
    // won't be pushed any higher.
    return (SlotNum > Limit);

  // we don't expect any other decl type, so fail
  llvm_unreachable("unexpected decl type");
}
2544
// Process the HLSL register(...) attribute: parse the slot and space
// identifiers, validate them, and attach an HLSLResourceBindingAttr to the
// declaration.
void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
  // For variables, the bound type must be complete before we can analyze the
  // resources inside it (an unbounded array's element type is checked
  // instead).
  if (VarDecl *VD = dyn_cast<VarDecl>(Val: TheDecl)) {
    QualType Ty = VD->getType();
    if (const auto *IAT = dyn_cast<IncompleteArrayType>(Val&: Ty))
      Ty = IAT->getElementType();
    if (SemaRef.RequireCompleteType(Loc: TheDecl->getBeginLoc(), T: Ty,
                                    DiagID: diag::err_incomplete_type))
      return;
  }

  StringRef Slot = "";
  StringRef Space = "";
  SourceLocation SlotLoc, SpaceLoc;

  // Both attribute arguments must be identifiers (e.g. "t3", "space1").
  if (!AL.isArgIdent(Arg: 0)) {
    Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_type)
        << AL << AANT_ArgumentIdentifier;
    return;
  }
  IdentifierLoc *Loc = AL.getArgAsIdent(Arg: 0);

  if (AL.getNumArgs() == 2) {
    // register(slot, space) form.
    Slot = Loc->getIdentifierInfo()->getName();
    SlotLoc = Loc->getLoc();
    if (!AL.isArgIdent(Arg: 1)) {
      Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_type)
          << AL << AANT_ArgumentIdentifier;
      return;
    }
    Loc = AL.getArgAsIdent(Arg: 1);
    Space = Loc->getIdentifierInfo()->getName();
    SpaceLoc = Loc->getLoc();
  } else {
    // Single-argument form: either register(spaceN) or register(slot) with an
    // implicit default space0.
    StringRef Str = Loc->getIdentifierInfo()->getName();
    if (Str.starts_with(Prefix: "space")) {
      Space = Str;
      SpaceLoc = Loc->getLoc();
    } else {
      Slot = Str;
      SlotLoc = Loc->getLoc();
      Space = "space0";
    }
  }

  RegisterType RegType = RegisterType::SRV;
  std::optional<unsigned> SlotNum;
  unsigned SpaceNum = 0;

  // Validate slot
  if (!Slot.empty()) {
    // The leading character selects the register class (t/u/b/s/c/i).
    if (!convertToRegisterType(Slot, RT: &RegType)) {
      Diag(Loc: SlotLoc, DiagID: diag::err_hlsl_binding_type_invalid) << Slot.substr(Start: 0, N: 1);
      return;
    }
    // The legacy 'i' register type is deprecated; warn and drop the binding.
    if (RegType == RegisterType::I) {
      Diag(Loc: SlotLoc, DiagID: diag::warn_hlsl_deprecated_register_type_i);
      return;
    }
    const StringRef SlotNumStr = Slot.substr(Start: 1);

    uint64_t N;

    // validate that the slot number is a non-empty number
    if (SlotNumStr.getAsInteger(Radix: 10, Result&: N)) {
      Diag(Loc: SlotLoc, DiagID: diag::err_hlsl_unsupported_register_number);
      return;
    }

    // Validate register number. It should not exceed UINT32_MAX,
    // including if the resource type is an array that starts
    // before UINT32_MAX, but ends afterwards.
    if (ValidateRegisterNumber(SlotNum: N, TheDecl, Ctx&: getASTContext(), RegTy: RegType)) {
      Diag(Loc: SlotLoc, DiagID: diag::err_hlsl_register_number_too_large);
      return;
    }

    // the slot number has been validated and does not exceed UINT32_MAX
    SlotNum = (unsigned)N;
  }

  // Validate space: must be "space" followed by a decimal number.
  if (!Space.starts_with(Prefix: "space")) {
    Diag(Loc: SpaceLoc, DiagID: diag::err_hlsl_expected_space) << Space;
    return;
  }
  StringRef SpaceNumStr = Space.substr(Start: 5);
  if (SpaceNumStr.getAsInteger(Radix: 10, Result&: SpaceNum)) {
    Diag(Loc: SpaceLoc, DiagID: diag::err_hlsl_expected_space) << Space;
    return;
  }

  // If we have slot, diagnose it is the right register type for the decl.
  // SpaceLoc is only valid when the user explicitly wrote a space argument.
  if (SlotNum.has_value())
    if (!DiagnoseHLSLRegisterAttribute(S&: SemaRef, ArgLoc&: SlotLoc, D: TheDecl, RegType,
                                       SpecifiedSpace: !SpaceLoc.isInvalid()))
      return;

  // All checks passed: build and attach the binding attribute.
  HLSLResourceBindingAttr *NewAttr =
      HLSLResourceBindingAttr::Create(Ctx&: getASTContext(), Slot, Space, CommonInfo: AL);
  if (NewAttr) {
    NewAttr->setBinding(RT: RegType, SlotNum, SpaceNum);
    TheDecl->addAttr(A: NewAttr);
  }
}
2649
2650void SemaHLSL::handleParamModifierAttr(Decl *D, const ParsedAttr &AL) {
2651 HLSLParamModifierAttr *NewAttr = mergeParamModifierAttr(
2652 D, AL,
2653 Spelling: static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling()));
2654 if (NewAttr)
2655 D->addAttr(A: NewAttr);
2656}
2657
2658namespace {
2659
/// This class implements HLSL availability diagnostics for default
/// and relaxed mode
///
/// The goal of this diagnostic is to emit an error or warning when an
/// unavailable API is found in code that is reachable from the shader
/// entry function or from an exported function (when compiling a shader
/// library).
///
/// This is done by traversing the AST of all shader entry point functions
/// and of all exported functions, and any functions that are referenced
/// from this AST. In other words, any functions that are reachable from
/// the entry points.
class DiagnoseHLSLAvailability : public DynamicRecursiveASTVisitor {
  Sema &SemaRef;

  // Stack (worklist) of functions still to be scanned
  llvm::SmallVector<const FunctionDecl *, 8> DeclsToScan;

  // Tracks which environments functions have been scanned in.
  //
  // Maps FunctionDecl to an unsigned number that represents the set of shader
  // environments the function has been scanned for.
  // The llvm::Triple::EnvironmentType enum values for shader stages guaranteed
  // to be numbered from llvm::Triple::Pixel to llvm::Triple::Amplification
  // (verified by static_asserts in Triple.cpp), we can use it to index
  // individual bits in the set, as long as we shift the values to start with 0
  // by subtracting the value of llvm::Triple::Pixel first.
  //
  // The N'th bit in the set will be set if the function has been scanned
  // in shader environment whose llvm::Triple::EnvironmentType integer value
  // equals (llvm::Triple::Pixel + N).
  //
  // For example, if a function has been scanned in compute and pixel stage
  // environment, the value will be 0x21 (100001 binary) because:
  //
  //   (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0
  //   (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5
  //
  // A FunctionDecl is mapped to 0 (or not included in the map) if it has not
  // been scanned in any environment.
  llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls;

  // Do not access these directly, use the get/set methods below to make
  // sure the values are in sync
  llvm::Triple::EnvironmentType CurrentShaderEnvironment;
  unsigned CurrentShaderStageBit;

  // True if scanning a function that was already scanned in a different
  // shader stage context, and therefore we should not report issues that
  // depend only on shader model version because they would be duplicate.
  bool ReportOnlyShaderStageIssues;

  // Helper methods for dealing with current stage context / environment

  // Enter a known shader stage: record the environment and compute the
  // single bit that represents it in the ScannedDecls bitmap.
  void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) {
    static_assert(sizeof(unsigned) >= 4);
    assert(HLSLShaderAttr::isValidShaderType(ShaderType));
    assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 &&
           "ShaderType is too big for this bitmap"); // 31 is reserved for
                                                     // "unknown"

    unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel;
    CurrentShaderEnvironment = ShaderType;
    CurrentShaderStageBit = (1 << bitmapIndex);
  }

  // Enter an unknown stage (used for exported library functions); bit 31 is
  // reserved for this case.
  void SetUnknownShaderStageContext() {
    CurrentShaderEnvironment = llvm::Triple::UnknownEnvironment;
    CurrentShaderStageBit = (1 << 31);
  }

  llvm::Triple::EnvironmentType GetCurrentShaderEnvironment() const {
    return CurrentShaderEnvironment;
  }

  bool InUnknownShaderStageContext() const {
    return CurrentShaderEnvironment == llvm::Triple::UnknownEnvironment;
  }

  // Helper methods for dealing with shader stage bitmap

  // Mark FD as scanned in the current stage (inserts a map entry if needed).
  void AddToScannedFunctions(const FunctionDecl *FD) {
    unsigned &ScannedStages = ScannedDecls[FD];
    ScannedStages |= CurrentShaderStageBit;
  }

  unsigned GetScannedStages(const FunctionDecl *FD) { return ScannedDecls[FD]; }

  bool WasAlreadyScannedInCurrentStage(const FunctionDecl *FD) {
    return WasAlreadyScannedInCurrentStage(ScannerStages: GetScannedStages(FD));
  }

  bool WasAlreadyScannedInCurrentStage(unsigned ScannerStages) {
    return ScannerStages & CurrentShaderStageBit;
  }

  static bool NeverBeenScanned(unsigned ScannedStages) {
    return ScannedStages == 0;
  }

  // Scanning methods
  void HandleFunctionOrMethodRef(FunctionDecl *FD, Expr *RefExpr);
  void CheckDeclAvailability(NamedDecl *D, const AvailabilityAttr *AA,
                             SourceRange Range);
  const AvailabilityAttr *FindAvailabilityAttr(const Decl *D);
  bool HasMatchingEnvironmentOrNone(const AvailabilityAttr *AA);

public:
  DiagnoseHLSLAvailability(Sema &SemaRef)
      : SemaRef(SemaRef),
        CurrentShaderEnvironment(llvm::Triple::UnknownEnvironment),
        CurrentShaderStageBit(0), ReportOnlyShaderStageIssues(false) {}

  // AST traversal methods
  void RunOnTranslationUnit(const TranslationUnitDecl *TU);
  void RunOnFunction(const FunctionDecl *FD);

  // Any reference to a function (direct call or address-taken) is routed
  // through HandleFunctionOrMethodRef.
  bool VisitDeclRefExpr(DeclRefExpr *DRE) override {
    FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(Val: DRE->getDecl());
    if (FD)
      HandleFunctionOrMethodRef(FD, RefExpr: DRE);
    return true;
  }

  // Same handling for member-function references (obj.Method(...)).
  bool VisitMemberExpr(MemberExpr *ME) override {
    FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(Val: ME->getMemberDecl());
    if (FD)
      HandleFunctionOrMethodRef(FD, RefExpr: ME);
    return true;
  }
};
2789
2790void DiagnoseHLSLAvailability::HandleFunctionOrMethodRef(FunctionDecl *FD,
2791 Expr *RefExpr) {
2792 assert((isa<DeclRefExpr>(RefExpr) || isa<MemberExpr>(RefExpr)) &&
2793 "expected DeclRefExpr or MemberExpr");
2794
2795 // has a definition -> add to stack to be scanned
2796 const FunctionDecl *FDWithBody = nullptr;
2797 if (FD->hasBody(Definition&: FDWithBody)) {
2798 if (!WasAlreadyScannedInCurrentStage(FD: FDWithBody))
2799 DeclsToScan.push_back(Elt: FDWithBody);
2800 return;
2801 }
2802
2803 // no body -> diagnose availability
2804 const AvailabilityAttr *AA = FindAvailabilityAttr(D: FD);
2805 if (AA)
2806 CheckDeclAvailability(
2807 D: FD, AA, Range: SourceRange(RefExpr->getBeginLoc(), RefExpr->getEndLoc()));
2808}
2809
void DiagnoseHLSLAvailability::RunOnTranslationUnit(
    const TranslationUnitDecl *TU) {

  // Iterate over all shader entry functions and library exports, and for those
  // that have a body (definition), run diag scan on each, setting appropriate
  // shader environment context based on whether it is a shader entry function
  // or an exported function. Exported functions can be in namespaces and in
  // export declarations so we need to scan those declaration contexts as well.
  llvm::SmallVector<const DeclContext *, 8> DeclContextsToScan;
  DeclContextsToScan.push_back(Elt: TU);

  while (!DeclContextsToScan.empty()) {
    const DeclContext *DC = DeclContextsToScan.pop_back_val();
    for (auto &D : DC->decls()) {
      // do not scan implicit declaration generated by the implementation
      if (D->isImplicit())
        continue;

      // for namespace or export declaration add the context to the list to be
      // scanned later
      if (llvm::dyn_cast<NamespaceDecl>(Val: D) || llvm::dyn_cast<ExportDecl>(Val: D)) {
        DeclContextsToScan.push_back(Elt: llvm::dyn_cast<DeclContext>(Val: D));
        continue;
      }

      // skip over other decls or function decls without body
      const FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(Val: D);
      if (!FD || !FD->isThisDeclarationADefinition())
        continue;

      // shader entry point: scan it in its specific stage context
      if (HLSLShaderAttr *ShaderAttr = FD->getAttr<HLSLShaderAttr>()) {
        SetShaderStageContext(ShaderAttr->getType());
        RunOnFunction(FD);
        continue;
      }
      // exported library function: a function counts as exported if it or any
      // of its redeclarations lives in an export declaration context.
      // FIXME: replace this loop with external linkage check once issue #92071
      // is resolved
      bool isExport = FD->isInExportDeclContext();
      if (!isExport) {
        for (const auto *Redecl : FD->redecls()) {
          if (Redecl->isInExportDeclContext()) {
            isExport = true;
            break;
          }
        }
      }
      if (isExport) {
        // The calling stage of an exported function is unknown at this point.
        SetUnknownShaderStageContext();
        RunOnFunction(FD);
        continue;
      }
    }
  }
}
2866
2867void DiagnoseHLSLAvailability::RunOnFunction(const FunctionDecl *FD) {
2868 assert(DeclsToScan.empty() && "DeclsToScan should be empty");
2869 DeclsToScan.push_back(Elt: FD);
2870
2871 while (!DeclsToScan.empty()) {
2872 // Take one decl from the stack and check it by traversing its AST.
2873 // For any CallExpr found during the traversal add it's callee to the top of
2874 // the stack to be processed next. Functions already processed are stored in
2875 // ScannedDecls.
2876 const FunctionDecl *FD = DeclsToScan.pop_back_val();
2877
2878 // Decl was already scanned
2879 const unsigned ScannedStages = GetScannedStages(FD);
2880 if (WasAlreadyScannedInCurrentStage(ScannerStages: ScannedStages))
2881 continue;
2882
2883 ReportOnlyShaderStageIssues = !NeverBeenScanned(ScannedStages);
2884
2885 AddToScannedFunctions(FD);
2886 TraverseStmt(S: FD->getBody());
2887 }
2888}
2889
2890bool DiagnoseHLSLAvailability::HasMatchingEnvironmentOrNone(
2891 const AvailabilityAttr *AA) {
2892 const IdentifierInfo *IIEnvironment = AA->getEnvironment();
2893 if (!IIEnvironment)
2894 return true;
2895
2896 llvm::Triple::EnvironmentType CurrentEnv = GetCurrentShaderEnvironment();
2897 if (CurrentEnv == llvm::Triple::UnknownEnvironment)
2898 return false;
2899
2900 llvm::Triple::EnvironmentType AttrEnv =
2901 AvailabilityAttr::getEnvironmentType(Environment: IIEnvironment->getName());
2902
2903 return CurrentEnv == AttrEnv;
2904}
2905
2906const AvailabilityAttr *
2907DiagnoseHLSLAvailability::FindAvailabilityAttr(const Decl *D) {
2908 AvailabilityAttr const *PartialMatch = nullptr;
2909 // Check each AvailabilityAttr to find the one for this platform.
2910 // For multiple attributes with the same platform try to find one for this
2911 // environment.
2912 for (const auto *A : D->attrs()) {
2913 if (const auto *Avail = dyn_cast<AvailabilityAttr>(Val: A)) {
2914 StringRef AttrPlatform = Avail->getPlatform()->getName();
2915 StringRef TargetPlatform =
2916 SemaRef.getASTContext().getTargetInfo().getPlatformName();
2917
2918 // Match the platform name.
2919 if (AttrPlatform == TargetPlatform) {
2920 // Find the best matching attribute for this environment
2921 if (HasMatchingEnvironmentOrNone(AA: Avail))
2922 return Avail;
2923 PartialMatch = Avail;
2924 }
2925 }
2926 }
2927 return PartialMatch;
2928}
2929
// Check availability against target shader model version and current shader
// stage and emit diagnostic
void DiagnoseHLSLAvailability::CheckDeclAvailability(NamedDecl *D,
                                                     const AvailabilityAttr *AA,
                                                     SourceRange Range) {

  const IdentifierInfo *IIEnv = AA->getEnvironment();

  if (!IIEnv) {
    // The availability attribute does not have environment -> it depends only
    // on shader model version and not on specific the shader stage.

    // Skip emitting the diagnostics if the diagnostic mode is set to
    // strict (-fhlsl-strict-availability) because all relevant diagnostics
    // were already emitted in the DiagnoseUnguardedAvailability scan
    // (SemaAvailability.cpp).
    if (SemaRef.getLangOpts().HLSLStrictAvailability)
      return;

    // Do not report shader-stage-independent issues if scanning a function
    // that was already scanned in a different shader stage context (they would
    // be duplicate)
    if (ReportOnlyShaderStageIssues)
      return;

  } else {
    // The availability attribute has environment -> we need to know
    // the current stage context to property diagnose it.
    if (InUnknownShaderStageContext())
      return;
  }

  // Check introduced version and if environment matches; nothing to report
  // when the API is available in this version for this environment.
  bool EnvironmentMatches = HasMatchingEnvironmentOrNone(AA);
  VersionTuple Introduced = AA->getIntroduced();
  VersionTuple TargetVersion =
      SemaRef.Context.getTargetInfo().getPlatformMinVersion();

  if (TargetVersion >= Introduced && EnvironmentMatches)
    return;

  // Emit diagnostic message
  const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
  llvm::StringRef PlatformName(
      AvailabilityAttr::getPrettyPlatformName(Platform: TI.getPlatformName()));

  llvm::StringRef CurrentEnvStr =
      llvm::Triple::getEnvironmentTypeName(Kind: GetCurrentShaderEnvironment());

  llvm::StringRef AttrEnvStr =
      AA->getEnvironment() ? AA->getEnvironment()->getName() : "";
  bool UseEnvironment = !AttrEnvStr.empty();

  // Environment matches but the version is too old -> "introduced in" warning;
  // environment does not match -> "unavailable" warning.
  if (EnvironmentMatches) {
    SemaRef.Diag(Loc: Range.getBegin(), DiagID: diag::warn_hlsl_availability)
        << Range << D << PlatformName << Introduced.getAsString()
        << UseEnvironment << CurrentEnvStr;
  } else {
    SemaRef.Diag(Loc: Range.getBegin(), DiagID: diag::warn_hlsl_availability_unavailable)
        << Range << D;
  }

  // Point at where the availability was declared.
  SemaRef.Diag(Loc: D->getLocation(), DiagID: diag::note_partial_availability_specified_here)
      << D << PlatformName << Introduced.getAsString()
      << SemaRef.Context.getTargetInfo().getPlatformMinVersion().getAsString()
      << UseEnvironment << AttrEnvStr << CurrentEnvStr;
}
2997
2998} // namespace
2999
// End-of-TU processing: materialize the implicit default $Globals cbuffer
// (if any declarations landed in it) and run the availability diagnostics
// over the whole translation unit.
void SemaHLSL::ActOnEndOfTranslationUnit(TranslationUnitDecl *TU) {
  // process default CBuffer - create buffer layout struct and invoke codegenCGH
  if (!DefaultCBufferDecls.empty()) {
    HLSLBufferDecl *DefaultCBuffer = HLSLBufferDecl::CreateDefaultCBuffer(
        C&: SemaRef.getASTContext(), LexicalParent: SemaRef.getCurLexicalContext(),
        DefaultCBufferDecls);
    // The default cbuffer gets an implicit binding, ordered after all
    // explicit bindings.
    addImplicitBindingAttrToDecl(S&: SemaRef, D: DefaultCBuffer, RT: RegisterType::CBuffer,
                                 ImplicitBindingOrderID: getNextImplicitBindingOrderID());
    SemaRef.getCurLexicalContext()->addDecl(D: DefaultCBuffer);
    createHostLayoutStructForBuffer(S&: SemaRef, BufDecl: DefaultCBuffer);

    // Set HasValidPackoffset if any of the decls has a register(c#) annotation;
    for (const Decl *VD : DefaultCBufferDecls) {
      const HLSLResourceBindingAttr *RBA =
          VD->getAttr<HLSLResourceBindingAttr>();
      if (RBA && RBA->hasRegisterSlot() &&
          RBA->getRegisterType() == HLSLResourceBindingAttr::RegisterType::C) {
        DefaultCBuffer->setHasValidPackoffset(true);
        break;
      }
    }

    // Hand the new top-level decl to the AST consumer (e.g. CodeGen).
    DeclGroupRef DG(DefaultCBuffer);
    SemaRef.Consumer.HandleTopLevelDecl(D: DG);
  }
  diagnoseAvailabilityViolations(TU);
}
3027
3028void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
3029 // Skip running the diagnostics scan if the diagnostic mode is
3030 // strict (-fhlsl-strict-availability) and the target shader stage is known
3031 // because all relevant diagnostics were already emitted in the
3032 // DiagnoseUnguardedAvailability scan (SemaAvailability.cpp).
3033 const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
3034 if (SemaRef.getLangOpts().HLSLStrictAvailability &&
3035 TI.getTriple().getEnvironment() != llvm::Triple::EnvironmentType::Library)
3036 return;
3037
3038 DiagnoseHLSLAvailability(SemaRef).RunOnTranslationUnit(TU);
3039}
3040
3041static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
3042 assert(TheCall->getNumArgs() > 1);
3043 QualType ArgTy0 = TheCall->getArg(Arg: 0)->getType();
3044
3045 for (unsigned I = 1, N = TheCall->getNumArgs(); I < N; ++I) {
3046 if (!S->getASTContext().hasSameUnqualifiedType(
3047 T1: ArgTy0, T2: TheCall->getArg(Arg: I)->getType())) {
3048 S->Diag(Loc: TheCall->getBeginLoc(), DiagID: diag::err_vec_builtin_incompatible_vector)
3049 << TheCall->getDirectCallee() << /*useAllTerminology*/ true
3050 << SourceRange(TheCall->getArg(Arg: 0)->getBeginLoc(),
3051 TheCall->getArg(Arg: N - 1)->getEndLoc());
3052 return true;
3053 }
3054 }
3055 return false;
3056}
3057
3058static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
3059 QualType ArgType = Arg->getType();
3060 if (!S->getASTContext().hasSameUnqualifiedType(T1: ArgType, T2: ExpectedType)) {
3061 S->Diag(Loc: Arg->getBeginLoc(), DiagID: diag::err_typecheck_convert_incompatible)
3062 << ArgType << ExpectedType << 1 << 0 << 0;
3063 return true;
3064 }
3065 return false;
3066}
3067
3068static bool CheckAllArgTypesAreCorrect(
3069 Sema *S, CallExpr *TheCall,
3070 llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
3071 clang::QualType PassedType)>
3072 Check) {
3073 for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
3074 Expr *Arg = TheCall->getArg(Arg: I);
3075 if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
3076 return true;
3077 }
3078 return false;
3079}
3080
3081static bool CheckFloatRepresentation(Sema *S, SourceLocation Loc,
3082 int ArgOrdinal,
3083 clang::QualType PassedType) {
3084 clang::QualType BaseType =
3085 PassedType->isVectorType()
3086 ? PassedType->castAs<clang::VectorType>()->getElementType()
3087 : PassedType;
3088 if (!BaseType->isFloat32Type())
3089 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3090 << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
3091 << /* float */ 1 << PassedType;
3092 return false;
3093}
3094
3095static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
3096 int ArgOrdinal,
3097 clang::QualType PassedType) {
3098 clang::QualType BaseType =
3099 PassedType->isVectorType()
3100 ? PassedType->castAs<clang::VectorType>()->getElementType()
3101 : PassedType;
3102 if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
3103 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3104 << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
3105 << /* half or float */ 2 << PassedType;
3106 return false;
3107}
3108
3109static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
3110 unsigned ArgIndex) {
3111 auto *Arg = TheCall->getArg(Arg: ArgIndex);
3112 SourceLocation OrigLoc = Arg->getExprLoc();
3113 if (Arg->IgnoreCasts()->isModifiableLvalue(Ctx&: S->Context, Loc: &OrigLoc) ==
3114 Expr::MLV_Valid)
3115 return false;
3116 S->Diag(Loc: OrigLoc, DiagID: diag::error_hlsl_inout_lvalue) << Arg << 0;
3117 return true;
3118}
3119
3120static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal,
3121 clang::QualType PassedType) {
3122 const auto *VecTy = PassedType->getAs<VectorType>();
3123 if (!VecTy)
3124 return false;
3125
3126 if (VecTy->getElementType()->isDoubleType())
3127 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3128 << ArgOrdinal << /* scalar */ 1 << /* no int */ 0 << /* fp */ 1
3129 << PassedType;
3130 return false;
3131}
3132
3133static bool CheckFloatingOrIntRepresentation(Sema *S, SourceLocation Loc,
3134 int ArgOrdinal,
3135 clang::QualType PassedType) {
3136 if (!PassedType->hasIntegerRepresentation() &&
3137 !PassedType->hasFloatingRepresentation())
3138 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3139 << ArgOrdinal << /* scalar or vector of */ 5 << /* integer */ 1
3140 << /* fp */ 1 << PassedType;
3141 return false;
3142}
3143
3144static bool CheckUnsignedIntVecRepresentation(Sema *S, SourceLocation Loc,
3145 int ArgOrdinal,
3146 clang::QualType PassedType) {
3147 if (auto *VecTy = PassedType->getAs<VectorType>())
3148 if (VecTy->getElementType()->isUnsignedIntegerType())
3149 return false;
3150
3151 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3152 << ArgOrdinal << /* vector of */ 4 << /* uint */ 3 << /* no fp */ 0
3153 << PassedType;
3154}
3155
3156// checks for unsigned ints of all sizes
3157static bool CheckUnsignedIntRepresentation(Sema *S, SourceLocation Loc,
3158 int ArgOrdinal,
3159 clang::QualType PassedType) {
3160 if (!PassedType->hasUnsignedIntegerRepresentation())
3161 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
3162 << ArgOrdinal << /* scalar or vector of */ 5 << /* unsigned int */ 3
3163 << /* no fp */ 0 << PassedType;
3164 return false;
3165}
3166
// Verify that the (element) type of the call's first argument is exactly
// `Width` bits wide; diagnoses and returns true otherwise.
//
// NOTE(review): the `ArgOrdinal` parameter is never used — the function
// always inspects argument 0. Presumably all current callers check the first
// argument; confirm whether `ArgOrdinal` was meant to select the argument.
static bool CheckExpectedBitWidth(Sema *S, CallExpr *TheCall,
                                  unsigned ArgOrdinal, unsigned Width) {
  QualType ArgTy = TheCall->getArg(Arg: 0)->getType();
  // For vectors, measure the element type, not the whole vector.
  if (auto *VTy = ArgTy->getAs<VectorType>())
    ArgTy = VTy->getElementType();
  // ensure arg type has expected bit width
  uint64_t ElementBitCount =
      S->getASTContext().getTypeSizeInChars(T: ArgTy).getQuantity() * 8;
  if (ElementBitCount != Width) {
    S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
            DiagID: diag::err_integer_incorrect_bit_count)
        << Width << ElementBitCount;
    return true;
  }
  return false;
}
3183
3184static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
3185 QualType ReturnType) {
3186 auto *VecTyA = TheCall->getArg(Arg: 0)->getType()->getAs<VectorType>();
3187 if (VecTyA)
3188 ReturnType =
3189 S->Context.getExtVectorType(VectorType: ReturnType, NumElts: VecTyA->getNumElements());
3190
3191 TheCall->setType(ReturnType);
3192}
3193
3194static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar,
3195 unsigned ArgIndex) {
3196 assert(TheCall->getNumArgs() >= ArgIndex);
3197 QualType ArgType = TheCall->getArg(Arg: ArgIndex)->getType();
3198 auto *VTy = ArgType->getAs<VectorType>();
3199 // not the scalar or vector<scalar>
3200 if (!(S->Context.hasSameUnqualifiedType(T1: ArgType, T2: Scalar) ||
3201 (VTy &&
3202 S->Context.hasSameUnqualifiedType(T1: VTy->getElementType(), T2: Scalar)))) {
3203 S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
3204 DiagID: diag::err_typecheck_expect_scalar_or_vector)
3205 << ArgType << Scalar;
3206 return true;
3207 }
3208 return false;
3209}
3210
3211static bool CheckScalarOrVectorOrMatrix(Sema *S, CallExpr *TheCall,
3212 QualType Scalar, unsigned ArgIndex) {
3213 assert(TheCall->getNumArgs() > ArgIndex);
3214
3215 Expr *Arg = TheCall->getArg(Arg: ArgIndex);
3216 QualType ArgType = Arg->getType();
3217
3218 // Scalar: T
3219 if (S->Context.hasSameUnqualifiedType(T1: ArgType, T2: Scalar))
3220 return false;
3221
3222 // Vector: vector<T>
3223 if (const auto *VTy = ArgType->getAs<VectorType>()) {
3224 if (S->Context.hasSameUnqualifiedType(T1: VTy->getElementType(), T2: Scalar))
3225 return false;
3226 }
3227
3228 // Matrix: ConstantMatrixType with element type T
3229 if (const auto *MTy = ArgType->getAs<ConstantMatrixType>()) {
3230 if (S->Context.hasSameUnqualifiedType(T1: MTy->getElementType(), T2: Scalar))
3231 return false;
3232 }
3233
3234 // Not a scalar/vector/matrix-of-scalar
3235 S->Diag(Loc: Arg->getBeginLoc(),
3236 DiagID: diag::err_typecheck_expect_scalar_or_vector_or_matrix)
3237 << ArgType << Scalar;
3238 return true;
3239}
3240
3241static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall,
3242 unsigned ArgIndex) {
3243 assert(TheCall->getNumArgs() >= ArgIndex);
3244 QualType ArgType = TheCall->getArg(Arg: ArgIndex)->getType();
3245 auto *VTy = ArgType->getAs<VectorType>();
3246 // not the scalar or vector<scalar>
3247 if (!(ArgType->isScalarType() ||
3248 (VTy && VTy->getElementType()->isScalarType()))) {
3249 S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
3250 DiagID: diag::err_typecheck_expect_any_scalar_or_vector)
3251 << ArgType << 1;
3252 return true;
3253 }
3254 return false;
3255}
3256
3257// Check that the argument is not a bool or vector<bool>
3258// Returns true on error
3259static bool CheckNotBoolScalarOrVector(Sema *S, CallExpr *TheCall,
3260 unsigned ArgIndex) {
3261 QualType BoolType = S->getASTContext().BoolTy;
3262 assert(ArgIndex < TheCall->getNumArgs());
3263 QualType ArgType = TheCall->getArg(Arg: ArgIndex)->getType();
3264 auto *VTy = ArgType->getAs<VectorType>();
3265 // is the bool or vector<bool>
3266 if (S->Context.hasSameUnqualifiedType(T1: ArgType, T2: BoolType) ||
3267 (VTy &&
3268 S->Context.hasSameUnqualifiedType(T1: VTy->getElementType(), T2: BoolType))) {
3269 S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
3270 DiagID: diag::err_typecheck_expect_any_scalar_or_vector)
3271 << ArgType << 0;
3272 return true;
3273 }
3274 return false;
3275}
3276
3277static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
3278 if (CheckNotBoolScalarOrVector(S, TheCall, ArgIndex: 0))
3279 return true;
3280 return false;
3281}
3282
3283static bool CheckWavePrefix(Sema *S, CallExpr *TheCall) {
3284 if (CheckNotBoolScalarOrVector(S, TheCall, ArgIndex: 0))
3285 return true;
3286 return false;
3287}
3288
3289static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
3290 assert(TheCall->getNumArgs() == 3);
3291 Expr *Arg1 = TheCall->getArg(Arg: 1);
3292 Expr *Arg2 = TheCall->getArg(Arg: 2);
3293 if (!S->Context.hasSameUnqualifiedType(T1: Arg1->getType(), T2: Arg2->getType())) {
3294 S->Diag(Loc: TheCall->getBeginLoc(),
3295 DiagID: diag::err_typecheck_call_different_arg_types)
3296 << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
3297 << Arg2->getSourceRange();
3298 return true;
3299 }
3300
3301 TheCall->setType(Arg1->getType());
3302 return false;
3303}
3304
// select(vector<bool,N> cond, a, b): validates that the second and third
// operands share a scalar element type and that any vector operands match
// the condition's element count N; the result type is vector<elem, N>.
// Returns true on error.
static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
  assert(TheCall->getNumArgs() == 3);
  Expr *Arg1 = TheCall->getArg(Arg: 1);
  QualType Arg1Ty = Arg1->getType();
  Expr *Arg2 = TheCall->getArg(Arg: 2);
  QualType Arg2Ty = Arg2->getType();

  // Reduce both value operands to their scalar element type (identity for
  // scalars).
  QualType Arg1ScalarTy = Arg1Ty;
  if (auto VTy = Arg1ScalarTy->getAs<VectorType>())
    Arg1ScalarTy = VTy->getElementType();

  QualType Arg2ScalarTy = Arg2Ty;
  if (auto VTy = Arg2ScalarTy->getAs<VectorType>())
    Arg2ScalarTy = VTy->getElementType();

  // NOTE(review): this diagnostic does not return early — validation
  // continues (and the result type below is built from Arg1's element type);
  // presumably intentional for error recovery, confirm.
  if (!S->Context.hasSameUnqualifiedType(T1: Arg1ScalarTy, T2: Arg2ScalarTy))
    S->Diag(Loc: Arg1->getBeginLoc(), DiagID: diag::err_hlsl_builtin_scalar_vector_mismatch)
        << /* second and third */ 1 << TheCall->getCallee() << Arg1Ty << Arg2Ty;

  // The condition must be a vector here (callers are expected to have
  // established that); its length is N. Value operands contribute a length
  // of 0 when scalar, which skips the mismatch checks below.
  QualType Arg0Ty = TheCall->getArg(Arg: 0)->getType();
  unsigned Arg0Length = Arg0Ty->getAs<VectorType>()->getNumElements();
  unsigned Arg1Length = Arg1Ty->isVectorType()
                            ? Arg1Ty->getAs<VectorType>()->getNumElements()
                            : 0;
  unsigned Arg2Length = Arg2Ty->isVectorType()
                            ? Arg2Ty->getAs<VectorType>()->getNumElements()
                            : 0;
  if (Arg1Length > 0 && Arg0Length != Arg1Length) {
    S->Diag(Loc: TheCall->getBeginLoc(),
            DiagID: diag::err_typecheck_vector_lengths_not_equal)
        << Arg0Ty << Arg1Ty << TheCall->getArg(Arg: 0)->getSourceRange()
        << Arg1->getSourceRange();
    return true;
  }

  if (Arg2Length > 0 && Arg0Length != Arg2Length) {
    S->Diag(Loc: TheCall->getBeginLoc(),
            DiagID: diag::err_typecheck_vector_lengths_not_equal)
        << Arg0Ty << Arg2Ty << TheCall->getArg(Arg: 0)->getSourceRange()
        << Arg2->getSourceRange();
    return true;
  }

  // Result: vector of the value element type with the condition's length.
  TheCall->setType(
      S->getASTContext().getExtVectorType(VectorType: Arg1ScalarTy, NumElts: Arg0Length));
  return false;
}
3352
3353static bool CheckResourceHandle(
3354 Sema *S, CallExpr *TheCall, unsigned ArgIndex,
3355 llvm::function_ref<bool(const HLSLAttributedResourceType *ResType)> Check =
3356 nullptr) {
3357 assert(TheCall->getNumArgs() >= ArgIndex);
3358 QualType ArgType = TheCall->getArg(Arg: ArgIndex)->getType();
3359 const HLSLAttributedResourceType *ResTy =
3360 ArgType.getTypePtr()->getAs<HLSLAttributedResourceType>();
3361 if (!ResTy) {
3362 S->Diag(Loc: TheCall->getArg(Arg: ArgIndex)->getBeginLoc(),
3363 DiagID: diag::err_typecheck_expect_hlsl_resource)
3364 << ArgType;
3365 return true;
3366 }
3367 if (Check && Check(ResTy)) {
3368 S->Diag(Loc: TheCall->getArg(Arg: ArgIndex)->getExprLoc(),
3369 DiagID: diag::err_invalid_hlsl_resource_type)
3370 << ArgType;
3371 return true;
3372 }
3373 return false;
3374}
3375
3376static bool CheckVectorElementCount(Sema *S, QualType PassedType,
3377 QualType BaseType, unsigned ExpectedCount,
3378 SourceLocation Loc) {
3379 unsigned PassedCount = 1;
3380 if (const auto *VecTy = PassedType->getAs<VectorType>())
3381 PassedCount = VecTy->getNumElements();
3382
3383 if (PassedCount != ExpectedCount) {
3384 QualType ExpectedType =
3385 S->Context.getExtVectorType(VectorType: BaseType, NumElts: ExpectedCount);
3386 S->Diag(Loc, DiagID: diag::err_typecheck_convert_incompatible)
3387 << PassedType << ExpectedType << 1 << 0 << 0;
3388 return true;
3389 }
3390 return false;
3391}
3392
3393enum class SampleKind { Sample, Bias, Grad, Level, Cmp, CmpLevelZero };
3394
3395static bool CheckTextureSamplerAndLocation(Sema &S, CallExpr *TheCall) {
3396 // Check the texture handle.
3397 if (CheckResourceHandle(S: &S, TheCall, ArgIndex: 0,
3398 Check: [](const HLSLAttributedResourceType *ResType) {
3399 return ResType->getAttrs().ResourceDimension ==
3400 llvm::dxil::ResourceDimension::Unknown;
3401 }))
3402 return true;
3403
3404 // Check the sampler handle.
3405 if (CheckResourceHandle(S: &S, TheCall, ArgIndex: 1,
3406 Check: [](const HLSLAttributedResourceType *ResType) {
3407 return ResType->getAttrs().ResourceClass !=
3408 llvm::hlsl::ResourceClass::Sampler;
3409 }))
3410 return true;
3411
3412 auto *ResourceTy =
3413 TheCall->getArg(Arg: 0)->getType()->castAs<HLSLAttributedResourceType>();
3414
3415 // Check the location.
3416 unsigned ExpectedDim =
3417 getResourceDimensions(Dim: ResourceTy->getAttrs().ResourceDimension);
3418 if (CheckVectorElementCount(S: &S, PassedType: TheCall->getArg(Arg: 2)->getType(),
3419 BaseType: S.Context.FloatTy, ExpectedCount: ExpectedDim,
3420 Loc: TheCall->getBeginLoc()))
3421 return true;
3422
3423 return false;
3424}
3425
// Validates __builtin_hlsl_resource_gather / __builtin_hlsl_resource_gather_cmp.
// Argument layout:
//   handle, sampler, coordinates, [compare value -- cmp only],
//   component selector, [optional integer offset].
// On success the call's type becomes a four-element ext-vector of the
// resource's element type (the four gathered texels).
static bool CheckGatherBuiltin(Sema &S, CallExpr *TheCall, bool IsCmp) {
  if (S.checkArgCountRange(TheCall, IsCmp ? 5 : 4, IsCmp ? 6 : 5))
    return true;

  // Args 0-2: texture handle, sampler handle, and coordinates.
  if (CheckTextureSamplerAndLocation(S, TheCall))
    return true;

  // NextIdx walks the remaining, partially optional operands.
  unsigned NextIdx = 3;
  if (IsCmp) {
    // Check the compare value: must be a scalar floating-point value.
    QualType CmpTy = TheCall->getArg(NextIdx)->getType();
    if (!CmpTy->isFloatingType() || CmpTy->isVectorType()) {
      S.Diag(TheCall->getArg(NextIdx)->getBeginLoc(),
             diag::err_typecheck_convert_incompatible)
          << CmpTy << S.Context.FloatTy << 1 << 0 << 0;
      return true;
    }
    NextIdx++;
  }

  // Check the component operand: must be a scalar integer.
  Expr *ComponentArg = TheCall->getArg(NextIdx);
  QualType ComponentTy = ComponentArg->getType();
  if (!ComponentTy->isIntegerType() || ComponentTy->isVectorType()) {
    S.Diag(ComponentArg->getBeginLoc(),
           diag::err_typecheck_convert_incompatible)
        << ComponentTy << S.Context.UnsignedIntTy << 1 << 0 << 0;
    return true;
  }

  // GatherCmp operations on Vulkan target must use component 0 (Red).
  // NOTE(review): this is only enforced when the component folds to an
  // integer constant; non-constant components pass through unchecked here.
  if (IsCmp && S.getASTContext().getTargetInfo().getTriple().isSPIRV()) {
    std::optional<llvm::APSInt> ComponentOpt =
        ComponentArg->getIntegerConstantExpr(S.getASTContext());
    if (ComponentOpt) {
      int64_t ComponentVal = ComponentOpt->getSExtValue();
      if (ComponentVal != 0) {
        // Issue an error if the component is not 0 (Red).
        // 0 -> Red, 1 -> Green, 2 -> Blue, 3 -> Alpha
        assert(ComponentVal >= 0 && ComponentVal <= 3 &&
               "The component is not in the expected range.");
        S.Diag(ComponentArg->getBeginLoc(),
               diag::err_hlsl_gathercmp_invalid_component)
            << ComponentVal;
        return true;
      }
    }
  }

  NextIdx++;

  // Check the (optional) offset operand: an integer vector with one element
  // per texture dimension.
  const HLSLAttributedResourceType *ResourceTy =
      TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
  if (TheCall->getNumArgs() > NextIdx) {
    unsigned ExpectedDim =
        getResourceDimensions(ResourceTy->getAttrs().ResourceDimension);
    if (CheckVectorElementCount(&S, TheCall->getArg(NextIdx)->getType(),
                                S.Context.IntTy, ExpectedDim,
                                TheCall->getArg(NextIdx)->getBeginLoc()))
      return true;
    NextIdx++;
  }

  assert(ResourceTy->hasContainedType() &&
         "Expecting a contained type for resource with a dimension "
         "attribute.");
  QualType ReturnType = ResourceTy->getContainedType();

  // Comparison gathers only make sense on floating-point resources.
  // Note: reuses the SampleCmp "requires float" diagnostic.
  if (IsCmp) {
    if (!ReturnType->hasFloatingRepresentation()) {
      S.Diag(TheCall->getBeginLoc(), diag::err_hlsl_samplecmp_requires_float);
      return true;
    }
  }

  // Gather always yields four texels of the resource's element type.
  if (const auto *VecTy = ReturnType->getAs<VectorType>())
    ReturnType = VecTy->getElementType();
  ReturnType = S.Context.getExtVectorType(ReturnType, 4);

  TheCall->setType(ReturnType);

  return false;
}
3510static bool CheckLoadLevelBuiltin(Sema &S, CallExpr *TheCall) {
3511 if (S.checkArgCountRange(Call: TheCall, MinArgCount: 2, MaxArgCount: 3))
3512 return true;
3513
3514 // Check the texture handle.
3515 if (CheckResourceHandle(S: &S, TheCall, ArgIndex: 0,
3516 Check: [](const HLSLAttributedResourceType *ResType) {
3517 return ResType->getAttrs().ResourceDimension ==
3518 llvm::dxil::ResourceDimension::Unknown;
3519 }))
3520 return true;
3521
3522 auto *ResourceTy =
3523 TheCall->getArg(Arg: 0)->getType()->castAs<HLSLAttributedResourceType>();
3524
3525 // Check the location + lod (int3 for Texture2D).
3526 unsigned ExpectedDim =
3527 getResourceDimensions(Dim: ResourceTy->getAttrs().ResourceDimension);
3528 QualType CoordLODTy = TheCall->getArg(Arg: 1)->getType();
3529 if (CheckVectorElementCount(S: &S, PassedType: CoordLODTy, BaseType: S.Context.IntTy, ExpectedCount: ExpectedDim + 1,
3530 Loc: TheCall->getArg(Arg: 1)->getBeginLoc()))
3531 return true;
3532
3533 QualType EltTy = CoordLODTy;
3534 if (const auto *VTy = EltTy->getAs<VectorType>())
3535 EltTy = VTy->getElementType();
3536 if (!EltTy->isIntegerType()) {
3537 S.Diag(Loc: TheCall->getArg(Arg: 1)->getBeginLoc(), DiagID: diag::err_typecheck_expect_int)
3538 << CoordLODTy;
3539 return true;
3540 }
3541
3542 // Check the offset operand.
3543 if (TheCall->getNumArgs() > 2) {
3544 if (CheckVectorElementCount(S: &S, PassedType: TheCall->getArg(Arg: 2)->getType(),
3545 BaseType: S.Context.IntTy, ExpectedCount: ExpectedDim,
3546 Loc: TheCall->getArg(Arg: 2)->getBeginLoc()))
3547 return true;
3548 }
3549
3550 TheCall->setType(ResourceTy->getContainedType());
3551 return false;
3552}
3553
3554static bool CheckSamplingBuiltin(Sema &S, CallExpr *TheCall, SampleKind Kind) {
3555 unsigned MinArgs, MaxArgs;
3556 if (Kind == SampleKind::Sample) {
3557 MinArgs = 3;
3558 MaxArgs = 5;
3559 } else if (Kind == SampleKind::Bias) {
3560 MinArgs = 4;
3561 MaxArgs = 6;
3562 } else if (Kind == SampleKind::Grad) {
3563 MinArgs = 5;
3564 MaxArgs = 7;
3565 } else if (Kind == SampleKind::Level) {
3566 MinArgs = 4;
3567 MaxArgs = 5;
3568 } else if (Kind == SampleKind::Cmp) {
3569 MinArgs = 4;
3570 MaxArgs = 6;
3571 } else {
3572 assert(Kind == SampleKind::CmpLevelZero);
3573 MinArgs = 4;
3574 MaxArgs = 5;
3575 }
3576
3577 if (S.checkArgCountRange(Call: TheCall, MinArgCount: MinArgs, MaxArgCount: MaxArgs))
3578 return true;
3579
3580 if (CheckTextureSamplerAndLocation(S, TheCall))
3581 return true;
3582
3583 const HLSLAttributedResourceType *ResourceTy =
3584 TheCall->getArg(Arg: 0)->getType()->castAs<HLSLAttributedResourceType>();
3585 unsigned ExpectedDim =
3586 getResourceDimensions(Dim: ResourceTy->getAttrs().ResourceDimension);
3587
3588 unsigned NextIdx = 3;
3589 if (Kind == SampleKind::Bias || Kind == SampleKind::Level ||
3590 Kind == SampleKind::Cmp || Kind == SampleKind::CmpLevelZero) {
3591 // Check the bias, lod level, or compare value, depending on the kind.
3592 // All of them must be a scalar float value.
3593 QualType BiasOrLODOrCmpTy = TheCall->getArg(Arg: NextIdx)->getType();
3594 if (!BiasOrLODOrCmpTy->isFloatingType() ||
3595 BiasOrLODOrCmpTy->isVectorType()) {
3596 S.Diag(Loc: TheCall->getArg(Arg: NextIdx)->getBeginLoc(),
3597 DiagID: diag::err_typecheck_convert_incompatible)
3598 << BiasOrLODOrCmpTy << S.Context.FloatTy << 1 << 0 << 0;
3599 return true;
3600 }
3601 NextIdx++;
3602 } else if (Kind == SampleKind::Grad) {
3603 // Check the DDX operand.
3604 if (CheckVectorElementCount(S: &S, PassedType: TheCall->getArg(Arg: NextIdx)->getType(),
3605 BaseType: S.Context.FloatTy, ExpectedCount: ExpectedDim,
3606 Loc: TheCall->getArg(Arg: NextIdx)->getBeginLoc()))
3607 return true;
3608
3609 // Check the DDY operand.
3610 if (CheckVectorElementCount(S: &S, PassedType: TheCall->getArg(Arg: NextIdx + 1)->getType(),
3611 BaseType: S.Context.FloatTy, ExpectedCount: ExpectedDim,
3612 Loc: TheCall->getArg(Arg: NextIdx + 1)->getBeginLoc()))
3613 return true;
3614 NextIdx += 2;
3615 }
3616
3617 // Check the offset operand.
3618 if (TheCall->getNumArgs() > NextIdx) {
3619 if (CheckVectorElementCount(S: &S, PassedType: TheCall->getArg(Arg: NextIdx)->getType(),
3620 BaseType: S.Context.IntTy, ExpectedCount: ExpectedDim,
3621 Loc: TheCall->getArg(Arg: NextIdx)->getBeginLoc()))
3622 return true;
3623 NextIdx++;
3624 }
3625
3626 // Check the clamp operand.
3627 if (Kind != SampleKind::Level && Kind != SampleKind::CmpLevelZero &&
3628 TheCall->getNumArgs() > NextIdx) {
3629 QualType ClampTy = TheCall->getArg(Arg: NextIdx)->getType();
3630 if (!ClampTy->isFloatingType() || ClampTy->isVectorType()) {
3631 S.Diag(Loc: TheCall->getArg(Arg: NextIdx)->getBeginLoc(),
3632 DiagID: diag::err_typecheck_convert_incompatible)
3633 << ClampTy << S.Context.FloatTy << 1 << 0 << 0;
3634 return true;
3635 }
3636 }
3637
3638 assert(ResourceTy->hasContainedType() &&
3639 "Expecting a contained type for resource with a dimension "
3640 "attribute.");
3641 QualType ReturnType = ResourceTy->getContainedType();
3642 if (Kind == SampleKind::Cmp || Kind == SampleKind::CmpLevelZero) {
3643 if (!ReturnType->hasFloatingRepresentation()) {
3644 S.Diag(Loc: TheCall->getBeginLoc(), DiagID: diag::err_hlsl_samplecmp_requires_float);
3645 return true;
3646 }
3647 ReturnType = S.Context.FloatTy;
3648 }
3649 TheCall->setType(ReturnType);
3650
3651 return false;
3652}
3653
3654// Note: returning true in this case results in CheckBuiltinFunctionCall
3655// returning an ExprError
3656bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
3657 switch (BuiltinID) {
3658 case Builtin::BI__builtin_hlsl_adduint64: {
3659 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
3660 return true;
3661
3662 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
3663 Check: CheckUnsignedIntVecRepresentation))
3664 return true;
3665
3666 // ensure arg integers are 32-bits
3667 if (CheckExpectedBitWidth(S: &SemaRef, TheCall, ArgOrdinal: 0, Width: 32))
3668 return true;
3669
3670 // ensure both args are vectors of total bit size of a multiple of 64
3671 auto *VTy = TheCall->getArg(Arg: 0)->getType()->getAs<VectorType>();
3672 int NumElementsArg = VTy->getNumElements();
3673 if (NumElementsArg != 2 && NumElementsArg != 4) {
3674 SemaRef.Diag(Loc: TheCall->getBeginLoc(), DiagID: diag::err_vector_incorrect_bit_count)
3675 << 1 /*a multiple of*/ << 64 << NumElementsArg * 32;
3676 return true;
3677 }
3678
3679 // ensure first arg and second arg have the same type
3680 if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
3681 return true;
3682
3683 ExprResult A = TheCall->getArg(Arg: 0);
3684 QualType ArgTyA = A.get()->getType();
3685 // return type is the same as the input type
3686 TheCall->setType(ArgTyA);
3687 break;
3688 }
3689 case Builtin::BI__builtin_hlsl_resource_getpointer: {
3690 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2) ||
3691 CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0) ||
3692 CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 1),
3693 ExpectedType: SemaRef.getASTContext().UnsignedIntTy))
3694 return true;
3695
3696 auto *ResourceTy =
3697 TheCall->getArg(Arg: 0)->getType()->castAs<HLSLAttributedResourceType>();
3698 QualType ContainedTy = ResourceTy->getContainedType();
3699 auto ReturnType =
3700 SemaRef.Context.getAddrSpaceQualType(T: ContainedTy, AddressSpace: LangAS::hlsl_device);
3701 ReturnType = SemaRef.Context.getPointerType(T: ReturnType);
3702 TheCall->setType(ReturnType);
3703 TheCall->setValueKind(VK_LValue);
3704
3705 break;
3706 }
3707 case Builtin::BI__builtin_hlsl_resource_getpointer_typed: {
3708 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3) ||
3709 CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0) ||
3710 CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 1),
3711 ExpectedType: SemaRef.getASTContext().UnsignedIntTy))
3712 return true;
3713
3714 QualType ElementTy = TheCall->getArg(Arg: 2)->getType();
3715 assert(ElementTy->isPointerType() &&
3716 "expected pointer type for second argument");
3717 ElementTy = ElementTy->getPointeeType();
3718
3719 // Reject array types
3720 if (ElementTy->isArrayType())
3721 return SemaRef.Diag(
3722 Loc: cast<FunctionDecl>(Val: SemaRef.CurContext)->getPointOfInstantiation(),
3723 DiagID: diag::err_invalid_use_of_array_type);
3724
3725 auto ReturnType =
3726 SemaRef.Context.getAddrSpaceQualType(T: ElementTy, AddressSpace: LangAS::hlsl_device);
3727 ReturnType = SemaRef.Context.getPointerType(T: ReturnType);
3728 TheCall->setType(ReturnType);
3729
3730 break;
3731 }
3732 case Builtin::BI__builtin_hlsl_resource_load_with_status: {
3733 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3) ||
3734 CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0) ||
3735 CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 1),
3736 ExpectedType: SemaRef.getASTContext().UnsignedIntTy) ||
3737 CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 2),
3738 ExpectedType: SemaRef.getASTContext().UnsignedIntTy) ||
3739 CheckModifiableLValue(S: &SemaRef, TheCall, ArgIndex: 2))
3740 return true;
3741
3742 auto *ResourceTy =
3743 TheCall->getArg(Arg: 0)->getType()->castAs<HLSLAttributedResourceType>();
3744 QualType ReturnType = ResourceTy->getContainedType();
3745 TheCall->setType(ReturnType);
3746
3747 break;
3748 }
3749 case Builtin::BI__builtin_hlsl_resource_load_with_status_typed: {
3750 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 4) ||
3751 CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0) ||
3752 CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 1),
3753 ExpectedType: SemaRef.getASTContext().UnsignedIntTy) ||
3754 CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 2),
3755 ExpectedType: SemaRef.getASTContext().UnsignedIntTy) ||
3756 CheckModifiableLValue(S: &SemaRef, TheCall, ArgIndex: 2))
3757 return true;
3758
3759 QualType ReturnType = TheCall->getArg(Arg: 3)->getType();
3760 assert(ReturnType->isPointerType() &&
3761 "expected pointer type for second argument");
3762 ReturnType = ReturnType->getPointeeType();
3763
3764 // Reject array types
3765 if (ReturnType->isArrayType())
3766 return SemaRef.Diag(
3767 Loc: cast<FunctionDecl>(Val: SemaRef.CurContext)->getPointOfInstantiation(),
3768 DiagID: diag::err_invalid_use_of_array_type);
3769
3770 TheCall->setType(ReturnType);
3771
3772 break;
3773 }
3774 case Builtin::BI__builtin_hlsl_resource_load_level:
3775 return CheckLoadLevelBuiltin(S&: SemaRef, TheCall);
3776 case Builtin::BI__builtin_hlsl_resource_sample:
3777 return CheckSamplingBuiltin(S&: SemaRef, TheCall, Kind: SampleKind::Sample);
3778 case Builtin::BI__builtin_hlsl_resource_sample_bias:
3779 return CheckSamplingBuiltin(S&: SemaRef, TheCall, Kind: SampleKind::Bias);
3780 case Builtin::BI__builtin_hlsl_resource_sample_grad:
3781 return CheckSamplingBuiltin(S&: SemaRef, TheCall, Kind: SampleKind::Grad);
3782 case Builtin::BI__builtin_hlsl_resource_sample_level:
3783 return CheckSamplingBuiltin(S&: SemaRef, TheCall, Kind: SampleKind::Level);
3784 case Builtin::BI__builtin_hlsl_resource_sample_cmp:
3785 return CheckSamplingBuiltin(S&: SemaRef, TheCall, Kind: SampleKind::Cmp);
3786 case Builtin::BI__builtin_hlsl_resource_sample_cmp_level_zero:
3787 return CheckSamplingBuiltin(S&: SemaRef, TheCall, Kind: SampleKind::CmpLevelZero);
3788 case Builtin::BI__builtin_hlsl_resource_gather:
3789 return CheckGatherBuiltin(S&: SemaRef, TheCall, /*IsCmp=*/false);
3790 case Builtin::BI__builtin_hlsl_resource_gather_cmp:
3791 return CheckGatherBuiltin(S&: SemaRef, TheCall, /*IsCmp=*/true);
3792 case Builtin::BI__builtin_hlsl_resource_uninitializedhandle: {
3793 assert(TheCall->getNumArgs() == 1 && "expected 1 arg");
3794 // Update return type to be the attributed resource type from arg0.
3795 QualType ResourceTy = TheCall->getArg(Arg: 0)->getType();
3796 TheCall->setType(ResourceTy);
3797 break;
3798 }
3799 case Builtin::BI__builtin_hlsl_resource_handlefrombinding: {
3800 assert(TheCall->getNumArgs() == 6 && "expected 6 args");
3801 // Update return type to be the attributed resource type from arg0.
3802 QualType ResourceTy = TheCall->getArg(Arg: 0)->getType();
3803 TheCall->setType(ResourceTy);
3804 break;
3805 }
3806 case Builtin::BI__builtin_hlsl_resource_handlefromimplicitbinding: {
3807 assert(TheCall->getNumArgs() == 6 && "expected 6 args");
3808 // Update return type to be the attributed resource type from arg0.
3809 QualType ResourceTy = TheCall->getArg(Arg: 0)->getType();
3810 TheCall->setType(ResourceTy);
3811 break;
3812 }
3813 case Builtin::BI__builtin_hlsl_resource_counterhandlefromimplicitbinding: {
3814 assert(TheCall->getNumArgs() == 3 && "expected 3 args");
3815 ASTContext &AST = SemaRef.getASTContext();
3816 QualType MainHandleTy = TheCall->getArg(Arg: 0)->getType();
3817 auto *MainResType = MainHandleTy->getAs<HLSLAttributedResourceType>();
3818 auto MainAttrs = MainResType->getAttrs();
3819 assert(!MainAttrs.IsCounter && "cannot create a counter from a counter");
3820 MainAttrs.IsCounter = true;
3821 QualType CounterHandleTy = AST.getHLSLAttributedResourceType(
3822 Wrapped: MainResType->getWrappedType(), Contained: MainResType->getContainedType(),
3823 Attrs: MainAttrs);
3824 // Update return type to be the attributed resource type from arg0
3825 // with added IsCounter flag.
3826 TheCall->setType(CounterHandleTy);
3827 break;
3828 }
3829 case Builtin::BI__builtin_hlsl_and:
3830 case Builtin::BI__builtin_hlsl_or: {
3831 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
3832 return true;
3833 if (CheckScalarOrVectorOrMatrix(S: &SemaRef, TheCall, Scalar: getASTContext().BoolTy,
3834 ArgIndex: 0))
3835 return true;
3836 if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
3837 return true;
3838
3839 ExprResult A = TheCall->getArg(Arg: 0);
3840 QualType ArgTyA = A.get()->getType();
3841 // return type is the same as the input type
3842 TheCall->setType(ArgTyA);
3843 break;
3844 }
3845 case Builtin::BI__builtin_hlsl_all:
3846 case Builtin::BI__builtin_hlsl_any: {
3847 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
3848 return true;
3849 if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
3850 return true;
3851 break;
3852 }
3853 case Builtin::BI__builtin_hlsl_asdouble: {
3854 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
3855 return true;
3856 if (CheckScalarOrVector(
3857 S: &SemaRef, TheCall,
3858 /*only check for uint*/ Scalar: SemaRef.Context.UnsignedIntTy,
3859 /* arg index */ ArgIndex: 0))
3860 return true;
3861 if (CheckScalarOrVector(
3862 S: &SemaRef, TheCall,
3863 /*only check for uint*/ Scalar: SemaRef.Context.UnsignedIntTy,
3864 /* arg index */ ArgIndex: 1))
3865 return true;
3866 if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
3867 return true;
3868
3869 SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().DoubleTy);
3870 break;
3871 }
3872 case Builtin::BI__builtin_hlsl_elementwise_clamp: {
3873 if (SemaRef.BuiltinElementwiseTernaryMath(
3874 TheCall, /*ArgTyRestr=*/
3875 Sema::EltwiseBuiltinArgTyRestriction::None))
3876 return true;
3877 break;
3878 }
3879 case Builtin::BI__builtin_hlsl_dot: {
3880 // arg count is checked by BuiltinVectorToScalarMath
3881 if (SemaRef.BuiltinVectorToScalarMath(TheCall))
3882 return true;
3883 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall, Check: CheckNoDoubleVectors))
3884 return true;
3885 break;
3886 }
3887 case Builtin::BI__builtin_hlsl_elementwise_firstbithigh:
3888 case Builtin::BI__builtin_hlsl_elementwise_firstbitlow: {
3889 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
3890 return true;
3891
3892 const Expr *Arg = TheCall->getArg(Arg: 0);
3893 QualType ArgTy = Arg->getType();
3894 QualType EltTy = ArgTy;
3895
3896 QualType ResTy = SemaRef.Context.UnsignedIntTy;
3897
3898 if (auto *VecTy = EltTy->getAs<VectorType>()) {
3899 EltTy = VecTy->getElementType();
3900 ResTy = SemaRef.Context.getExtVectorType(VectorType: ResTy, NumElts: VecTy->getNumElements());
3901 }
3902
3903 if (!EltTy->isIntegerType()) {
3904 Diag(Loc: Arg->getBeginLoc(), DiagID: diag::err_builtin_invalid_arg_type)
3905 << 1 << /* scalar or vector of */ 5 << /* integer ty */ 1
3906 << /* no fp */ 0 << ArgTy;
3907 return true;
3908 }
3909
3910 TheCall->setType(ResTy);
3911 break;
3912 }
3913 case Builtin::BI__builtin_hlsl_select: {
3914 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3))
3915 return true;
3916 if (CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: getASTContext().BoolTy, ArgIndex: 0))
3917 return true;
3918 QualType ArgTy = TheCall->getArg(Arg: 0)->getType();
3919 if (ArgTy->isBooleanType() && CheckBoolSelect(S: &SemaRef, TheCall))
3920 return true;
3921 auto *VTy = ArgTy->getAs<VectorType>();
3922 if (VTy && VTy->getElementType()->isBooleanType() &&
3923 CheckVectorSelect(S: &SemaRef, TheCall))
3924 return true;
3925 break;
3926 }
3927 case Builtin::BI__builtin_hlsl_elementwise_saturate:
3928 case Builtin::BI__builtin_hlsl_elementwise_rcp: {
3929 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
3930 return true;
3931 if (!TheCall->getArg(Arg: 0)
3932 ->getType()
3933 ->hasFloatingRepresentation()) // half or float or double
3934 return SemaRef.Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
3935 DiagID: diag::err_builtin_invalid_arg_type)
3936 << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
3937 << /* fp */ 1 << TheCall->getArg(Arg: 0)->getType();
3938 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
3939 return true;
3940 break;
3941 }
3942 case Builtin::BI__builtin_hlsl_elementwise_degrees:
3943 case Builtin::BI__builtin_hlsl_elementwise_radians:
3944 case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
3945 case Builtin::BI__builtin_hlsl_elementwise_frac:
3946 case Builtin::BI__builtin_hlsl_elementwise_ddx_coarse:
3947 case Builtin::BI__builtin_hlsl_elementwise_ddy_coarse:
3948 case Builtin::BI__builtin_hlsl_elementwise_ddx_fine:
3949 case Builtin::BI__builtin_hlsl_elementwise_ddy_fine: {
3950 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
3951 return true;
3952 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
3953 Check: CheckFloatOrHalfRepresentation))
3954 return true;
3955 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
3956 return true;
3957 break;
3958 }
3959 case Builtin::BI__builtin_hlsl_elementwise_isinf:
3960 case Builtin::BI__builtin_hlsl_elementwise_isnan: {
3961 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
3962 return true;
3963 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
3964 Check: CheckFloatOrHalfRepresentation))
3965 return true;
3966 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
3967 return true;
3968 SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().BoolTy);
3969 break;
3970 }
3971 case Builtin::BI__builtin_hlsl_lerp: {
3972 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3))
3973 return true;
3974 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
3975 Check: CheckFloatOrHalfRepresentation))
3976 return true;
3977 if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
3978 return true;
3979 if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
3980 return true;
3981 break;
3982 }
3983 case Builtin::BI__builtin_hlsl_mad: {
3984 if (SemaRef.BuiltinElementwiseTernaryMath(
3985 TheCall, /*ArgTyRestr=*/
3986 Sema::EltwiseBuiltinArgTyRestriction::None))
3987 return true;
3988 break;
3989 }
3990 case Builtin::BI__builtin_hlsl_mul: {
3991 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
3992 return true;
3993
3994 Expr *Arg0 = TheCall->getArg(Arg: 0);
3995 Expr *Arg1 = TheCall->getArg(Arg: 1);
3996 QualType Ty0 = Arg0->getType();
3997 QualType Ty1 = Arg1->getType();
3998
3999 auto getElemType = [](QualType T) -> QualType {
4000 if (const auto *VTy = T->getAs<VectorType>())
4001 return VTy->getElementType();
4002 if (const auto *MTy = T->getAs<ConstantMatrixType>())
4003 return MTy->getElementType();
4004 return T;
4005 };
4006
4007 QualType EltTy0 = getElemType(Ty0);
4008
4009 bool IsVec0 = Ty0->isVectorType();
4010 bool IsMat0 = Ty0->isConstantMatrixType();
4011 bool IsVec1 = Ty1->isVectorType();
4012 bool IsMat1 = Ty1->isConstantMatrixType();
4013
4014 QualType RetTy;
4015
4016 if (IsVec0 && IsMat1) {
4017 auto *MatTy = Ty1->castAs<ConstantMatrixType>();
4018 RetTy = getASTContext().getExtVectorType(VectorType: EltTy0, NumElts: MatTy->getNumColumns());
4019 } else if (IsMat0 && IsVec1) {
4020 auto *MatTy = Ty0->castAs<ConstantMatrixType>();
4021 RetTy = getASTContext().getExtVectorType(VectorType: EltTy0, NumElts: MatTy->getNumRows());
4022 } else {
4023 assert(IsMat0 && IsMat1);
4024 auto *MatTy0 = Ty0->castAs<ConstantMatrixType>();
4025 auto *MatTy1 = Ty1->castAs<ConstantMatrixType>();
4026 RetTy = getASTContext().getConstantMatrixType(
4027 ElementType: EltTy0, NumRows: MatTy0->getNumRows(), NumColumns: MatTy1->getNumColumns());
4028 }
4029
4030 TheCall->setType(RetTy);
4031 break;
4032 }
4033 case Builtin::BI__builtin_hlsl_normalize: {
4034 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4035 return true;
4036 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4037 Check: CheckFloatOrHalfRepresentation))
4038 return true;
4039 ExprResult A = TheCall->getArg(Arg: 0);
4040 QualType ArgTyA = A.get()->getType();
4041 // return type is the same as the input type
4042 TheCall->setType(ArgTyA);
4043 break;
4044 }
4045 case Builtin::BI__builtin_hlsl_transpose: {
4046 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4047 return true;
4048
4049 Expr *Arg = TheCall->getArg(Arg: 0);
4050 QualType ArgTy = Arg->getType();
4051
4052 const auto *MatTy = ArgTy->getAs<ConstantMatrixType>();
4053 if (!MatTy) {
4054 SemaRef.Diag(Loc: Arg->getBeginLoc(), DiagID: diag::err_builtin_invalid_arg_type)
4055 << 1 << /* matrix */ 3 << /* no int */ 0 << /* no fp */ 0 << ArgTy;
4056 return true;
4057 }
4058
4059 QualType RetTy = getASTContext().getConstantMatrixType(
4060 ElementType: MatTy->getElementType(), NumRows: MatTy->getNumColumns(), NumColumns: MatTy->getNumRows());
4061 TheCall->setType(RetTy);
4062 break;
4063 }
4064 case Builtin::BI__builtin_hlsl_elementwise_sign: {
4065 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
4066 return true;
4067 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4068 Check: CheckFloatingOrIntRepresentation))
4069 return true;
4070 SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().IntTy);
4071 break;
4072 }
4073 case Builtin::BI__builtin_hlsl_step: {
4074 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
4075 return true;
4076 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4077 Check: CheckFloatOrHalfRepresentation))
4078 return true;
4079
4080 ExprResult A = TheCall->getArg(Arg: 0);
4081 QualType ArgTyA = A.get()->getType();
4082 // return type is the same as the input type
4083 TheCall->setType(ArgTyA);
4084 break;
4085 }
4086 case Builtin::BI__builtin_hlsl_wave_active_all_equal: {
4087 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4088 return true;
4089
4090 // Ensure input expr type is a scalar/vector
4091 if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
4092 return true;
4093
4094 QualType InputTy = TheCall->getArg(Arg: 0)->getType();
4095 ASTContext &Ctx = getASTContext();
4096
4097 QualType RetTy;
4098
4099 // If vector, construct bool vector of same size
4100 if (const auto *VecTy = InputTy->getAs<ExtVectorType>()) {
4101 unsigned NumElts = VecTy->getNumElements();
4102 RetTy = Ctx.getExtVectorType(VectorType: Ctx.BoolTy, NumElts);
4103 } else {
4104 // Scalar case
4105 RetTy = Ctx.BoolTy;
4106 }
4107
4108 TheCall->setType(RetTy);
4109 break;
4110 }
4111 case Builtin::BI__builtin_hlsl_wave_active_max:
4112 case Builtin::BI__builtin_hlsl_wave_active_min:
4113 case Builtin::BI__builtin_hlsl_wave_active_sum:
4114 case Builtin::BI__builtin_hlsl_wave_active_product: {
4115 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4116 return true;
4117
4118 // Ensure input expr type is a scalar/vector and the same as the return type
4119 if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
4120 return true;
4121 if (CheckWaveActive(S: &SemaRef, TheCall))
4122 return true;
4123 ExprResult Expr = TheCall->getArg(Arg: 0);
4124 QualType ArgTyExpr = Expr.get()->getType();
4125 TheCall->setType(ArgTyExpr);
4126 break;
4127 }
4128 case Builtin::BI__builtin_hlsl_wave_active_bit_xor:
4129 case Builtin::BI__builtin_hlsl_wave_active_bit_or: {
4130 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4131 return true;
4132
4133 // Ensure input expr type is a scalar/vector
4134 if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
4135 return true;
4136
4137 if (CheckWaveActive(S: &SemaRef, TheCall))
4138 return true;
4139
4140 // Ensure the expr type is interpretable as a uint or vector<uint>
4141 ExprResult Expr = TheCall->getArg(Arg: 0);
4142 QualType ArgTyExpr = Expr.get()->getType();
4143 auto *VTy = ArgTyExpr->getAs<VectorType>();
4144 if (!(ArgTyExpr->isIntegerType() ||
4145 (VTy && VTy->getElementType()->isIntegerType()))) {
4146 SemaRef.Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
4147 DiagID: diag::err_builtin_invalid_arg_type)
4148 << ArgTyExpr << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
4149 return true;
4150 }
4151
4152 // Ensure input expr type is the same as the return type
4153 TheCall->setType(ArgTyExpr);
4154 break;
4155 }
4156 // Note these are llvm builtins that we want to catch invalid intrinsic
4157 // generation. Normal handling of these builtins will occur elsewhere.
4158 case Builtin::BI__builtin_elementwise_bitreverse: {
4159 // does not include a check for number of arguments
4160 // because that is done previously
4161 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4162 Check: CheckUnsignedIntRepresentation))
4163 return true;
4164 break;
4165 }
4166 case Builtin::BI__builtin_hlsl_wave_prefix_count_bits: {
4167 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4168 return true;
4169
4170 QualType ArgType = TheCall->getArg(Arg: 0)->getType();
4171
4172 if (!(ArgType->isScalarType())) {
4173 SemaRef.Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
4174 DiagID: diag::err_typecheck_expect_any_scalar_or_vector)
4175 << ArgType << 0;
4176 return true;
4177 }
4178
4179 if (!(ArgType->isBooleanType())) {
4180 SemaRef.Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
4181 DiagID: diag::err_typecheck_expect_any_scalar_or_vector)
4182 << ArgType << 0;
4183 return true;
4184 }
4185
4186 break;
4187 }
4188 case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
4189 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
4190 return true;
4191
4192 // Ensure index parameter type can be interpreted as a uint
4193 ExprResult Index = TheCall->getArg(Arg: 1);
4194 QualType ArgTyIndex = Index.get()->getType();
4195 if (!ArgTyIndex->isIntegerType()) {
4196 SemaRef.Diag(Loc: TheCall->getArg(Arg: 1)->getBeginLoc(),
4197 DiagID: diag::err_typecheck_convert_incompatible)
4198 << ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
4199 return true;
4200 }
4201
4202 // Ensure input expr type is a scalar/vector and the same as the return type
4203 if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
4204 return true;
4205
4206 ExprResult Expr = TheCall->getArg(Arg: 0);
4207 QualType ArgTyExpr = Expr.get()->getType();
4208 TheCall->setType(ArgTyExpr);
4209 break;
4210 }
4211 case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
4212 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 0))
4213 return true;
4214 break;
4215 }
4216 case Builtin::BI__builtin_hlsl_wave_prefix_sum:
4217 case Builtin::BI__builtin_hlsl_wave_prefix_product: {
4218 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4219 return true;
4220
4221 // Ensure input expr type is a scalar/vector and the same as the return type
4222 if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
4223 return true;
4224 if (CheckWavePrefix(S: &SemaRef, TheCall))
4225 return true;
4226 ExprResult Expr = TheCall->getArg(Arg: 0);
4227 QualType ArgTyExpr = Expr.get()->getType();
4228 TheCall->setType(ArgTyExpr);
4229 break;
4230 }
4231 case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
4232 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3))
4233 return true;
4234
4235 if (CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: SemaRef.Context.DoubleTy, ArgIndex: 0) ||
4236 CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: SemaRef.Context.UnsignedIntTy,
4237 ArgIndex: 1) ||
4238 CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: SemaRef.Context.UnsignedIntTy,
4239 ArgIndex: 2))
4240 return true;
4241
4242 if (CheckModifiableLValue(S: &SemaRef, TheCall, ArgIndex: 1) ||
4243 CheckModifiableLValue(S: &SemaRef, TheCall, ArgIndex: 2))
4244 return true;
4245 break;
4246 }
4247 case Builtin::BI__builtin_hlsl_elementwise_clip: {
4248 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4249 return true;
4250
4251 if (CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: SemaRef.Context.FloatTy, ArgIndex: 0))
4252 return true;
4253 break;
4254 }
4255 case Builtin::BI__builtin_elementwise_acos:
4256 case Builtin::BI__builtin_elementwise_asin:
4257 case Builtin::BI__builtin_elementwise_atan:
4258 case Builtin::BI__builtin_elementwise_atan2:
4259 case Builtin::BI__builtin_elementwise_ceil:
4260 case Builtin::BI__builtin_elementwise_cos:
4261 case Builtin::BI__builtin_elementwise_cosh:
4262 case Builtin::BI__builtin_elementwise_exp:
4263 case Builtin::BI__builtin_elementwise_exp2:
4264 case Builtin::BI__builtin_elementwise_exp10:
4265 case Builtin::BI__builtin_elementwise_floor:
4266 case Builtin::BI__builtin_elementwise_fmod:
4267 case Builtin::BI__builtin_elementwise_log:
4268 case Builtin::BI__builtin_elementwise_log2:
4269 case Builtin::BI__builtin_elementwise_log10:
4270 case Builtin::BI__builtin_elementwise_pow:
4271 case Builtin::BI__builtin_elementwise_roundeven:
4272 case Builtin::BI__builtin_elementwise_sin:
4273 case Builtin::BI__builtin_elementwise_sinh:
4274 case Builtin::BI__builtin_elementwise_sqrt:
4275 case Builtin::BI__builtin_elementwise_tan:
4276 case Builtin::BI__builtin_elementwise_tanh:
4277 case Builtin::BI__builtin_elementwise_trunc: {
4278 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4279 Check: CheckFloatOrHalfRepresentation))
4280 return true;
4281 break;
4282 }
4283 case Builtin::BI__builtin_hlsl_buffer_update_counter: {
4284 assert(TheCall->getNumArgs() == 2 && "expected 2 args");
4285 auto checkResTy = [](const HLSLAttributedResourceType *ResTy) -> bool {
4286 return !(ResTy->getAttrs().ResourceClass == ResourceClass::UAV &&
4287 ResTy->getAttrs().RawBuffer && ResTy->hasContainedType());
4288 };
4289 if (CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0, Check: checkResTy))
4290 return true;
4291 Expr *OffsetExpr = TheCall->getArg(Arg: 1);
4292 std::optional<llvm::APSInt> Offset =
4293 OffsetExpr->getIntegerConstantExpr(Ctx: SemaRef.getASTContext());
4294 if (!Offset.has_value() || std::abs(i: Offset->getExtValue()) != 1) {
4295 SemaRef.Diag(Loc: TheCall->getArg(Arg: 1)->getBeginLoc(),
4296 DiagID: diag::err_hlsl_expect_arg_const_int_one_or_neg_one)
4297 << 1;
4298 return true;
4299 }
4300 break;
4301 }
4302 case Builtin::BI__builtin_hlsl_elementwise_f16tof32: {
4303 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4304 return true;
4305 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
4306 Check: CheckUnsignedIntRepresentation))
4307 return true;
4308 // ensure arg integers are 32 bits
4309 if (CheckExpectedBitWidth(S: &SemaRef, TheCall, ArgOrdinal: 0, Width: 32))
4310 return true;
4311 // check it wasn't a bool type
4312 QualType ArgTy = TheCall->getArg(Arg: 0)->getType();
4313 if (auto *VTy = ArgTy->getAs<VectorType>())
4314 ArgTy = VTy->getElementType();
4315 if (ArgTy->isBooleanType()) {
4316 SemaRef.Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
4317 DiagID: diag::err_builtin_invalid_arg_type)
4318 << 1 << /* scalar or vector of */ 5 << /* unsigned int */ 3
4319 << /* no fp */ 0 << TheCall->getArg(Arg: 0)->getType();
4320 return true;
4321 }
4322
4323 SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().FloatTy);
4324 break;
4325 }
4326 case Builtin::BI__builtin_hlsl_elementwise_f32tof16: {
4327 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
4328 return true;
4329 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall, Check: CheckFloatRepresentation))
4330 return true;
4331 SetElementTypeAsReturnType(S: &SemaRef, TheCall,
4332 ReturnType: getASTContext().UnsignedIntTy);
4333 break;
4334 }
4335 }
4336 return false;
4337}
4338
// Flattens BaseTy into the ordered sequence of its scalar "leaf" types and
// appends them to List:
//  * constant arrays repeat their element's flattened list by the array size,
//  * vectors and matrices expand to N copies of their element type,
//  * aggregate records expand to their fields (and then base classes) in
//    declaration order,
//  * unions, non-aggregate records, and all remaining types are appended
//    to the list as-is.
static void BuildFlattenedTypeList(QualType BaseTy,
                                   llvm::SmallVectorImpl<QualType> &List) {
  llvm::SmallVector<QualType, 16> WorkList;
  WorkList.push_back(Elt: BaseTy);
  while (!WorkList.empty()) {
    QualType T = WorkList.pop_back_val();
    // Operate on canonical, unqualified types so sugar and qualifiers do not
    // affect the flattened result.
    T = T.getCanonicalType().getUnqualifiedType();
    if (const auto *AT = dyn_cast<ConstantArrayType>(Val&: T)) {
      llvm::SmallVector<QualType, 16> ElementFields;
      // Generally I've avoided recursion in this algorithm, but arrays of
      // structs could be time-consuming to flatten and churn through on the
      // work list. Hopefully nesting arrays of structs containing arrays
      // of structs too many levels deep is unlikely.
      BuildFlattenedTypeList(BaseTy: AT->getElementType(), List&: ElementFields);
      // Repeat the element's field list n times.
      for (uint64_t Ct = 0; Ct < AT->getZExtSize(); ++Ct)
        llvm::append_range(C&: List, R&: ElementFields);
      continue;
    }
    // Vectors can only have element types that are builtin types, so this can
    // add directly to the list instead of to the WorkList.
    if (const auto *VT = dyn_cast<VectorType>(Val&: T)) {
      List.insert(I: List.end(), NumToInsert: VT->getNumElements(), Elt: VT->getElementType());
      continue;
    }
    // Matrices likewise have builtin element types; expand them directly.
    if (const auto *MT = dyn_cast<ConstantMatrixType>(Val&: T)) {
      List.insert(I: List.end(), NumToInsert: MT->getNumElementsFlattened(),
                  Elt: MT->getElementType());
      continue;
    }
    if (const auto *RD = T->getAsCXXRecordDecl()) {
      if (RD->isStandardLayout())
        RD = RD->getStandardLayoutBaseWithFields();

      // For types that we shouldn't decompose (unions and non-aggregates), just
      // add the type itself to the list.
      if (RD->isUnion() || !RD->isAggregate()) {
        List.push_back(Elt: T);
        continue;
      }

      llvm::SmallVector<QualType, 16> FieldTypes;
      for (const auto *FD : RD->fields())
        if (!FD->isUnnamedBitField())
          FieldTypes.push_back(Elt: FD->getType());
      // Reverse the newly added sub-range so that popping from the LIFO
      // WorkList visits the fields in declaration order.
      std::reverse(first: FieldTypes.begin(), last: FieldTypes.end());
      llvm::append_range(C&: WorkList, R&: FieldTypes);

      // If this wasn't a standard layout type we may also have some base
      // classes to deal with.
      if (!RD->isStandardLayout()) {
        FieldTypes.clear();
        for (const auto &Base : RD->bases())
          FieldTypes.push_back(Elt: Base.getType());
        std::reverse(first: FieldTypes.begin(), last: FieldTypes.end());
        llvm::append_range(C&: WorkList, R&: FieldTypes);
      }
      continue;
    }
    // Leaf type: nothing further to decompose.
    List.push_back(Elt: T);
  }
}
4402
4403bool SemaHLSL::IsTypedResourceElementCompatible(clang::QualType QT) {
4404 // null and array types are not allowed.
4405 if (QT.isNull() || QT->isArrayType())
4406 return false;
4407
4408 // UDT types are not allowed
4409 if (QT->isRecordType())
4410 return false;
4411
4412 if (QT->isBooleanType() || QT->isEnumeralType())
4413 return false;
4414
4415 // the only other valid builtin types are scalars or vectors
4416 if (QT->isArithmeticType()) {
4417 if (SemaRef.Context.getTypeSize(T: QT) / 8 > 16)
4418 return false;
4419 return true;
4420 }
4421
4422 if (const VectorType *VT = QT->getAs<VectorType>()) {
4423 int ArraySize = VT->getNumElements();
4424
4425 if (ArraySize > 4)
4426 return false;
4427
4428 QualType ElTy = VT->getElementType();
4429 if (ElTy->isBooleanType())
4430 return false;
4431
4432 if (SemaRef.Context.getTypeSize(T: QT) / 8 > 16)
4433 return false;
4434 return true;
4435 }
4436
4437 return false;
4438}
4439
4440bool SemaHLSL::IsScalarizedLayoutCompatible(QualType T1, QualType T2) const {
4441 if (T1.isNull() || T2.isNull())
4442 return false;
4443
4444 T1 = T1.getCanonicalType().getUnqualifiedType();
4445 T2 = T2.getCanonicalType().getUnqualifiedType();
4446
4447 // If both types are the same canonical type, they're obviously compatible.
4448 if (SemaRef.getASTContext().hasSameType(T1, T2))
4449 return true;
4450
4451 llvm::SmallVector<QualType, 16> T1Types;
4452 BuildFlattenedTypeList(BaseTy: T1, List&: T1Types);
4453 llvm::SmallVector<QualType, 16> T2Types;
4454 BuildFlattenedTypeList(BaseTy: T2, List&: T2Types);
4455
4456 // Check the flattened type list
4457 return llvm::equal(LRange&: T1Types, RRange&: T2Types,
4458 P: [this](QualType LHS, QualType RHS) -> bool {
4459 return SemaRef.IsLayoutCompatible(T1: LHS, T2: RHS);
4460 });
4461}
4462
4463bool SemaHLSL::CheckCompatibleParameterABI(FunctionDecl *New,
4464 FunctionDecl *Old) {
4465 if (New->getNumParams() != Old->getNumParams())
4466 return true;
4467
4468 bool HadError = false;
4469
4470 for (unsigned i = 0, e = New->getNumParams(); i != e; ++i) {
4471 ParmVarDecl *NewParam = New->getParamDecl(i);
4472 ParmVarDecl *OldParam = Old->getParamDecl(i);
4473
4474 // HLSL parameter declarations for inout and out must match between
4475 // declarations. In HLSL inout and out are ambiguous at the call site,
4476 // but have different calling behavior, so you cannot overload a
4477 // method based on a difference between inout and out annotations.
4478 const auto *NDAttr = NewParam->getAttr<HLSLParamModifierAttr>();
4479 unsigned NSpellingIdx = (NDAttr ? NDAttr->getSpellingListIndex() : 0);
4480 const auto *ODAttr = OldParam->getAttr<HLSLParamModifierAttr>();
4481 unsigned OSpellingIdx = (ODAttr ? ODAttr->getSpellingListIndex() : 0);
4482
4483 if (NSpellingIdx != OSpellingIdx) {
4484 SemaRef.Diag(Loc: NewParam->getLocation(),
4485 DiagID: diag::err_hlsl_param_qualifier_mismatch)
4486 << NDAttr << NewParam;
4487 SemaRef.Diag(Loc: OldParam->getLocation(), DiagID: diag::note_previous_declaration_as)
4488 << ODAttr;
4489 HadError = true;
4490 }
4491 }
4492 return HadError;
4493}
4494
// Generally follows PerformScalarCast, with cases reordered for
// clarity of what types are supported
//
// Returns true when a value of scalar type SrcTy can be cast to scalar type
// DestTy under HLSL rules. HLSL scalars are only bool, integral, or floating;
// any pointer, complex, or fixed-point kind reaching here is a bug, hence the
// llvm_unreachable branches.
bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) {

  // Both sides must be scalar types for a scalar cast to apply.
  if (!SrcTy->isScalarType() || !DestTy->isScalarType())
    return false;

  // Same type (modulo qualifiers) needs no conversion at all.
  if (SemaRef.getASTContext().hasSameUnqualifiedType(T1: SrcTy, T2: DestTy))
    return true;

  switch (SrcTy->getScalarTypeKind()) {
  case Type::STK_Bool: // casting from bool is like casting from an integer
  case Type::STK_Integral:
    switch (DestTy->getScalarTypeKind()) {
    case Type::STK_Bool:
    case Type::STK_Integral:
    case Type::STK_Floating:
      return true;
    case Type::STK_CPointer:
    case Type::STK_ObjCObjectPointer:
    case Type::STK_BlockPointer:
    case Type::STK_MemberPointer:
      llvm_unreachable("HLSL doesn't support pointers.");
    case Type::STK_IntegralComplex:
    case Type::STK_FloatingComplex:
      llvm_unreachable("HLSL doesn't support complex types.");
    case Type::STK_FixedPoint:
      llvm_unreachable("HLSL doesn't support fixed point types.");
    }
    llvm_unreachable("Should have returned before this");

  case Type::STK_Floating:
    switch (DestTy->getScalarTypeKind()) {
    case Type::STK_Floating:
    case Type::STK_Bool:
    case Type::STK_Integral:
      return true;
    case Type::STK_FloatingComplex:
    case Type::STK_IntegralComplex:
      llvm_unreachable("HLSL doesn't support complex types.");
    case Type::STK_FixedPoint:
      llvm_unreachable("HLSL doesn't support fixed point types.");
    case Type::STK_CPointer:
    case Type::STK_ObjCObjectPointer:
    case Type::STK_BlockPointer:
    case Type::STK_MemberPointer:
      llvm_unreachable("HLSL doesn't support pointers.");
    }
    llvm_unreachable("Should have returned before this");

  // Source kinds below cannot occur for valid HLSL scalars.
  case Type::STK_MemberPointer:
  case Type::STK_CPointer:
  case Type::STK_BlockPointer:
  case Type::STK_ObjCObjectPointer:
    llvm_unreachable("HLSL doesn't support pointers.");

  case Type::STK_FixedPoint:
    llvm_unreachable("HLSL doesn't support fixed point types.");

  case Type::STK_FloatingComplex:
  case Type::STK_IntegralComplex:
    llvm_unreachable("HLSL doesn't support complex types.");
  }

  llvm_unreachable("Unhandled scalar cast");
}
4561
4562// Can perform an HLSL Aggregate splat cast if the Dest is an aggregate and the
4563// Src is a scalar or a vector of length 1
4564// Or if Dest is a vector and Src is a vector of length 1
4565bool SemaHLSL::CanPerformAggregateSplatCast(Expr *Src, QualType DestTy) {
4566
4567 QualType SrcTy = Src->getType();
4568 // Not a valid HLSL Aggregate Splat cast if Dest is a scalar or if this is
4569 // going to be a vector splat from a scalar.
4570 if ((SrcTy->isScalarType() && DestTy->isVectorType()) ||
4571 DestTy->isScalarType())
4572 return false;
4573
4574 const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
4575
4576 // Src isn't a scalar or a vector of length 1
4577 if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1))
4578 return false;
4579
4580 if (SrcVecTy)
4581 SrcTy = SrcVecTy->getElementType();
4582
4583 llvm::SmallVector<QualType> DestTypes;
4584 BuildFlattenedTypeList(BaseTy: DestTy, List&: DestTypes);
4585
4586 for (unsigned I = 0, Size = DestTypes.size(); I < Size; ++I) {
4587 if (DestTypes[I]->isUnionType())
4588 return false;
4589 if (!CanPerformScalarCast(SrcTy, DestTy: DestTypes[I]))
4590 return false;
4591 }
4592 return true;
4593}
4594
4595// Can we perform an HLSL Elementwise cast?
4596bool SemaHLSL::CanPerformElementwiseCast(Expr *Src, QualType DestTy) {
4597
4598 // Don't handle casts where LHS and RHS are any combination of scalar/vector
4599 // There must be an aggregate somewhere
4600 QualType SrcTy = Src->getType();
4601 if (SrcTy->isScalarType()) // always a splat and this cast doesn't handle that
4602 return false;
4603
4604 if (SrcTy->isVectorType() &&
4605 (DestTy->isScalarType() || DestTy->isVectorType()))
4606 return false;
4607
4608 if (SrcTy->isConstantMatrixType() &&
4609 (DestTy->isScalarType() || DestTy->isConstantMatrixType()))
4610 return false;
4611
4612 llvm::SmallVector<QualType> DestTypes;
4613 BuildFlattenedTypeList(BaseTy: DestTy, List&: DestTypes);
4614 llvm::SmallVector<QualType> SrcTypes;
4615 BuildFlattenedTypeList(BaseTy: SrcTy, List&: SrcTypes);
4616
4617 // Usually the size of SrcTypes must be greater than or equal to the size of
4618 // DestTypes.
4619 if (SrcTypes.size() < DestTypes.size())
4620 return false;
4621
4622 unsigned SrcSize = SrcTypes.size();
4623 unsigned DstSize = DestTypes.size();
4624 unsigned I;
4625 for (I = 0; I < DstSize && I < SrcSize; I++) {
4626 if (SrcTypes[I]->isUnionType() || DestTypes[I]->isUnionType())
4627 return false;
4628 if (!CanPerformScalarCast(SrcTy: SrcTypes[I], DestTy: DestTypes[I])) {
4629 return false;
4630 }
4631 }
4632
4633 // check the rest of the source type for unions.
4634 for (; I < SrcSize; I++) {
4635 if (SrcTypes[I]->isUnionType())
4636 return false;
4637 }
4638 return true;
4639}
4640
// Builds the argument expression for an HLSL `out`/`inout` parameter.
// Ordinary parameters pass through unchanged. For out/inout, the argument
// must be an lvalue; we wrap it in opaque values and synthesize (a) a copy
// initialization of the parameter from the argument and (b) a writeback
// assignment back into the argument, packaged as an HLSLOutArgExpr.
// Returns ExprError() on diagnosed misuse.
ExprResult SemaHLSL::ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg) {
  assert(Param->hasAttr<HLSLParamModifierAttr>() &&
         "We should not get here without a parameter modifier expression");
  const auto *Attr = Param->getAttr<HLSLParamModifierAttr>();
  // Ordinary (by-value) parameters need no out-argument machinery.
  if (Attr->getABI() == ParameterABI::Ordinary)
    return ExprResult(Arg);

  bool IsInOut = Attr->getABI() == ParameterABI::HLSLInOut;
  // Out/inout arguments must be writable storage.
  if (!Arg->isLValue()) {
    SemaRef.Diag(Loc: Arg->getBeginLoc(), DiagID: diag::error_hlsl_inout_lvalue)
        << Arg << (IsInOut ? 1 : 0);
    return ExprError();
  }

  ASTContext &Ctx = SemaRef.getASTContext();

  QualType Ty = Param->getType().getNonLValueExprType(Context: Ctx);

  // HLSL allows implicit conversions from scalars to vectors, but not the
  // inverse, so we need to disallow `inout` with scalar->vector or
  // scalar->matrix conversions.
  if (Arg->getType()->isScalarType() != Ty->isScalarType()) {
    SemaRef.Diag(Loc: Arg->getBeginLoc(), DiagID: diag::error_hlsl_inout_scalar_extension)
        << Arg << (IsInOut ? 1 : 0);
    return ExprError();
  }

  // Opaque value for the caller's argument lvalue; reused in both the
  // parameter initialization and the writeback below.
  auto *ArgOpV = new (Ctx) OpaqueValueExpr(Param->getBeginLoc(), Arg->getType(),
                                           VK_LValue, OK_Ordinary, Arg);

  // Parameters are initialized via copy initialization. This allows for
  // overload resolution of argument constructors.
  InitializedEntity Entity =
      InitializedEntity::InitializeParameter(Context&: Ctx, Type: Ty, Consumed: false);
  ExprResult Res =
      SemaRef.PerformCopyInitialization(Entity, EqualLoc: Param->getBeginLoc(), Init: ArgOpV);
  if (Res.isInvalid())
    return ExprError();
  Expr *Base = Res.get();
  // After the cast, drop the reference type when creating the exprs.
  Ty = Ty.getNonLValueExprType(Context: Ctx);
  // Opaque value for the converted/initialized parameter value.
  auto *OpV = new (Ctx)
      OpaqueValueExpr(Param->getBeginLoc(), Ty, VK_LValue, OK_Ordinary, Base);

  // Writebacks are performed with `=` binary operator, which allows for
  // overload resolution on writeback result expressions.
  Res = SemaRef.ActOnBinOp(S: SemaRef.getCurScope(), TokLoc: Param->getBeginLoc(),
                           Kind: tok::equal, LHSExpr: ArgOpV, RHSExpr: OpV);

  if (Res.isInvalid())
    return ExprError();
  Expr *Writeback = Res.get();
  auto *OutExpr =
      HLSLOutArgExpr::Create(C: Ctx, Ty, Base: ArgOpV, OpV, WB: Writeback, IsInOut);

  return ExprResult(OutExpr);
}
4698
4699QualType SemaHLSL::getInoutParameterType(QualType Ty) {
4700 // If HLSL gains support for references, all the cites that use this will need
4701 // to be updated with semantic checking to produce errors for
4702 // pointers/references.
4703 assert(!Ty->isReferenceType() &&
4704 "Pointer and reference types cannot be inout or out parameters");
4705 Ty = SemaRef.getASTContext().getLValueReferenceType(T: Ty);
4706 Ty.addRestrict();
4707 return Ty;
4708}
4709
4710// Returns true if the type has a non-empty constant buffer layout (if it is
4711// scalar, vector or matrix, or if it contains any of these.
4712static bool hasConstantBufferLayout(QualType QT) {
4713 const Type *Ty = QT->getUnqualifiedDesugaredType();
4714 if (Ty->isScalarType() || Ty->isVectorType() || Ty->isMatrixType())
4715 return true;
4716
4717 if (Ty->isHLSLResourceRecord() || Ty->isHLSLResourceRecordArray())
4718 return false;
4719
4720 if (const auto *RD = Ty->getAsCXXRecordDecl()) {
4721 for (const auto *FD : RD->fields()) {
4722 if (hasConstantBufferLayout(QT: FD->getType()))
4723 return true;
4724 }
4725 assert(RD->getNumBases() <= 1 &&
4726 "HLSL doesn't support multiple inheritance");
4727 return RD->getNumBases()
4728 ? hasConstantBufferLayout(QT: RD->bases_begin()->getType())
4729 : false;
4730 }
4731
4732 if (const auto *AT = dyn_cast<ArrayType>(Val: Ty)) {
4733 if (const auto *CAT = dyn_cast<ConstantArrayType>(Val: AT))
4734 if (isZeroSizedArray(CAT))
4735 return false;
4736 return hasConstantBufferLayout(QT: AT->getElementType());
4737 }
4738
4739 return false;
4740}
4741
4742static bool IsDefaultBufferConstantDecl(const ASTContext &Ctx, VarDecl *VD) {
4743 bool IsVulkan =
4744 Ctx.getTargetInfo().getTriple().getOS() == llvm::Triple::Vulkan;
4745 bool IsVKPushConstant = IsVulkan && VD->hasAttr<HLSLVkPushConstantAttr>();
4746 QualType QT = VD->getType();
4747 return VD->getDeclContext()->isTranslationUnit() &&
4748 QT.getAddressSpace() == LangAS::Default &&
4749 VD->getStorageClass() != SC_Static &&
4750 !VD->hasAttr<HLSLVkConstantIdAttr>() && !IsVKPushConstant &&
4751 hasConstantBufferLayout(QT);
4752}
4753
4754void SemaHLSL::deduceAddressSpace(VarDecl *Decl) {
4755 // The variable already has an address space (groupshared for ex).
4756 if (Decl->getType().hasAddressSpace())
4757 return;
4758
4759 if (Decl->getType()->isDependentType())
4760 return;
4761
4762 QualType Type = Decl->getType();
4763
4764 if (Decl->hasAttr<HLSLVkExtBuiltinInputAttr>()) {
4765 LangAS ImplAS = LangAS::hlsl_input;
4766 Type = SemaRef.getASTContext().getAddrSpaceQualType(T: Type, AddressSpace: ImplAS);
4767 Decl->setType(Type);
4768 return;
4769 }
4770
4771 bool IsVulkan = getASTContext().getTargetInfo().getTriple().getOS() ==
4772 llvm::Triple::Vulkan;
4773 if (IsVulkan && Decl->hasAttr<HLSLVkPushConstantAttr>()) {
4774 if (HasDeclaredAPushConstant)
4775 SemaRef.Diag(Loc: Decl->getLocation(), DiagID: diag::err_hlsl_push_constant_unique);
4776
4777 LangAS ImplAS = LangAS::hlsl_push_constant;
4778 Type = SemaRef.getASTContext().getAddrSpaceQualType(T: Type, AddressSpace: ImplAS);
4779 Decl->setType(Type);
4780 HasDeclaredAPushConstant = true;
4781 return;
4782 }
4783
4784 if (Type->isSamplerT() || Type->isVoidType())
4785 return;
4786
4787 // Resource handles.
4788 if (Type->isHLSLResourceRecord() || Type->isHLSLResourceRecordArray())
4789 return;
4790
4791 // Only static globals belong to the Private address space.
4792 // Non-static globals belongs to the cbuffer.
4793 if (Decl->getStorageClass() != SC_Static && !Decl->isStaticDataMember())
4794 return;
4795
4796 LangAS ImplAS = LangAS::hlsl_private;
4797 Type = SemaRef.getASTContext().getAddrSpaceQualType(T: Type, AddressSpace: ImplAS);
4798 Decl->setType(Type);
4799}
4800
// Semantic processing for an HLSL variable declarator. For globals this
// checks type completeness, routes plain constants into the default $Globals
// buffer, collects and processes resource bindings (explicit and implicit),
// adjusts storage class/linkage for specialization constants, resources, and
// groupshared variables, and finally deduces the implicit address space.
void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
  if (VD->hasGlobalStorage()) {
    // make sure the declaration has a complete type
    if (SemaRef.RequireCompleteType(
            Loc: VD->getLocation(),
            T: SemaRef.getASTContext().getBaseElementType(QT: VD->getType()),
            DiagID: diag::err_typecheck_decl_incomplete_type)) {
      VD->setInvalidDecl();
      // Still deduce an address space so downstream code sees a consistent
      // (if invalid) declaration.
      deduceAddressSpace(Decl: VD);
      return;
    }

    // Global variables outside a cbuffer block that are not a resource, static,
    // groupshared, or an empty array or struct belong to the default constant
    // buffer $Globals (to be created at the end of the translation unit).
    if (IsDefaultBufferConstantDecl(Ctx: getASTContext(), VD)) {
      // update address space to hlsl_constant
      QualType NewTy = getASTContext().getAddrSpaceQualType(
          T: VD->getType(), AddressSpace: LangAS::hlsl_constant);
      VD->setType(NewTy);
      DefaultCBufferDecls.push_back(Elt: VD);
    }

    // find all resources bindings on decl
    if (VD->getType()->isHLSLIntangibleType())
      collectResourceBindingsOnVarDecl(D: VD);

    // Specialization constants are internal to the module.
    if (VD->hasAttr<HLSLVkConstantIdAttr>())
      VD->setStorageClass(StorageClass::SC_Static);

    if (isResourceRecordTypeOrArrayOf(VD) &&
        VD->getStorageClass() != SC_Static) {
      // Add internal linkage attribute to non-static resource variables. The
      // global externally visible storage is accessed through the handle, which
      // is a member. The variable itself is not externally visible.
      VD->addAttr(A: InternalLinkageAttr::CreateImplicit(Ctx&: getASTContext()));
    }

    // process explicit bindings
    processExplicitBindingsOnDecl(D: VD);

    // Add implicit binding attribute to non-static resource arrays.
    if (VD->getType()->isHLSLResourceRecordArray() &&
        VD->getStorageClass() != SC_Static) {
      // If the resource array does not have an explicit binding attribute,
      // create an implicit one. It will be used to transfer implicit binding
      // order_ID to codegen.
      ResourceBindingAttrs Binding(VD);
      if (!Binding.isExplicit()) {
        uint32_t OrderID = getNextImplicitBindingOrderID();
        if (Binding.hasBinding())
          Binding.setImplicitOrderID(OrderID);
        else {
          addImplicitBindingAttrToDecl(
              S&: SemaRef, D: VD, RT: getRegisterType(ResTy: getResourceArrayHandleType(VD)),
              ImplicitBindingOrderID: OrderID);
          // Re-create the binding object to pick up the new attribute.
          Binding = ResourceBindingAttrs(VD);
        }
      }

      // Get to the base type of a potentially multi-dimensional array.
      QualType Ty = getASTContext().getBaseElementType(QT: VD->getType());

      // Resources with a counter handle also need an implicit order ID for
      // the counter's binding.
      const CXXRecordDecl *RD = Ty->getAsCXXRecordDecl();
      if (hasCounterHandle(RD)) {
        if (!Binding.hasCounterImplicitOrderID()) {
          uint32_t OrderID = getNextImplicitBindingOrderID();
          Binding.setCounterImplicitOrderID(OrderID);
        }
      }
    }

    // Mark groupshared variables as extern so they will have
    // external storage and won't be default initialized
    if (VD->hasAttr<HLSLGroupSharedAddressSpaceAttr>())
      VD->setStorageClass(StorageClass::SC_Extern);
  }

  deduceAddressSpace(Decl: VD);
}
4882
// Synthesizes the initializer for a global resource variable: a call to the
// resource record's static create method, passing the binding slot (or
// implicit binding order ID), register space, range size, index, variable
// name, and — when the resource has a counter handle — a counter order ID.
// Returns false when the record lacks the expected create method, in which
// case no binding is generated for the variable.
bool SemaHLSL::initGlobalResourceDecl(VarDecl *VD) {
  assert(VD->getType()->isHLSLResourceRecord() &&
         "expected resource record type");

  ASTContext &AST = SemaRef.getASTContext();
  uint64_t UIntTySize = AST.getTypeSize(T: AST.UnsignedIntTy);
  uint64_t IntTySize = AST.getTypeSize(T: AST.IntTy);

  // Gather resource binding attributes.
  ResourceBindingAttrs Binding(VD);

  // Find correct initialization method and create its arguments.
  QualType ResourceTy = VD->getType();
  CXXRecordDecl *ResourceDecl = ResourceTy->getAsCXXRecordDecl();
  CXXMethodDecl *CreateMethod = nullptr;
  llvm::SmallVector<Expr *> Args;

  // The method name encodes both the binding kind (explicit register vs.
  // implicit binding) and whether the resource carries a counter handle.
  bool HasCounter = hasCounterHandle(RD: ResourceDecl);
  const char *CreateMethodName;
  if (Binding.isExplicit())
    CreateMethodName = HasCounter ? "__createFromBindingWithImplicitCounter"
                                  : "__createFromBinding";
  else
    CreateMethodName = HasCounter
                           ? "__createFromImplicitBindingWithImplicitCounter"
                           : "__createFromImplicitBinding";

  CreateMethod =
      lookupMethod(S&: SemaRef, RecordDecl: ResourceDecl, Name: CreateMethodName, Loc: VD->getLocation());

  if (!CreateMethod)
    // This can happen if someone creates a struct that looks like an HLSL
    // resource record but does not have the required static create method.
    // No binding will be generated for it.
    return false;

  // First argument: explicit register slot, or the implicit binding order ID.
  if (Binding.isExplicit()) {
    IntegerLiteral *RegSlot =
        IntegerLiteral::Create(C: AST, V: llvm::APInt(UIntTySize, Binding.getSlot()),
                               type: AST.UnsignedIntTy, l: SourceLocation());
    Args.push_back(Elt: RegSlot);
  } else {
    uint32_t OrderID = (Binding.hasImplicitOrderID())
                           ? Binding.getImplicitOrderID()
                           : getNextImplicitBindingOrderID();
    IntegerLiteral *OrderId =
        IntegerLiteral::Create(C: AST, V: llvm::APInt(UIntTySize, OrderID),
                               type: AST.UnsignedIntTy, l: SourceLocation());
    Args.push_back(Elt: OrderId);
  }

  // Register space.
  IntegerLiteral *Space =
      IntegerLiteral::Create(C: AST, V: llvm::APInt(UIntTySize, Binding.getSpace()),
                             type: AST.UnsignedIntTy, l: SourceLocation());
  Args.push_back(Elt: Space);

  // Range size of 1 and index 0: a single (non-array) resource.
  IntegerLiteral *RangeSize = IntegerLiteral::Create(
      C: AST, V: llvm::APInt(IntTySize, 1), type: AST.IntTy, l: SourceLocation());
  Args.push_back(Elt: RangeSize);

  IntegerLiteral *Index = IntegerLiteral::Create(
      C: AST, V: llvm::APInt(UIntTySize, 0), type: AST.UnsignedIntTy, l: SourceLocation());
  Args.push_back(Elt: Index);

  // The variable name, decayed to a const char* argument.
  StringRef VarName = VD->getName();
  StringLiteral *Name = StringLiteral::Create(
      Ctx: AST, Str: VarName, Kind: StringLiteralKind::Ordinary, Pascal: false,
      Ty: AST.getStringLiteralArrayType(EltTy: AST.CharTy.withConst(), Length: VarName.size()),
      Locs: SourceLocation());
  ImplicitCastExpr *NameCast = ImplicitCastExpr::Create(
      Context: AST, T: AST.getPointerType(T: AST.CharTy.withConst()), Kind: CK_ArrayToPointerDecay,
      Operand: Name, BasePath: nullptr, Cat: VK_PRValue, FPO: FPOptionsOverride());
  Args.push_back(Elt: NameCast);

  if (HasCounter) {
    // Will this be in the correct order?
    uint32_t CounterOrderID = getNextImplicitBindingOrderID();
    IntegerLiteral *CounterId =
        IntegerLiteral::Create(C: AST, V: llvm::APInt(UIntTySize, CounterOrderID),
                               type: AST.UnsignedIntTy, l: SourceLocation());
    Args.push_back(Elt: CounterId);
  }

  // Make sure the create method template is instantiated and emitted.
  if (!CreateMethod->isDefined() && CreateMethod->isTemplateInstantiation())
    SemaRef.InstantiateFunctionDefinition(PointOfInstantiation: VD->getLocation(), Function: CreateMethod,
                                          Recursive: true);

  // Create CallExpr with a call to the static method and set it as the decl
  // initialization.
  DeclRefExpr *DRE = DeclRefExpr::Create(
      Context: AST, QualifierLoc: NestedNameSpecifierLoc(), TemplateKWLoc: SourceLocation(), D: CreateMethod, RefersToEnclosingVariableOrCapture: false,
      NameInfo: CreateMethod->getNameInfo(), T: CreateMethod->getType(), VK: VK_PRValue);

  auto *ImpCast = ImplicitCastExpr::Create(
      Context: AST, T: AST.getPointerType(T: CreateMethod->getType()),
      Kind: CK_FunctionToPointerDecay, Operand: DRE, BasePath: nullptr, Cat: VK_PRValue, FPO: FPOptionsOverride());

  CallExpr *InitExpr =
      CallExpr::Create(Ctx: AST, Fn: ImpCast, Args, Ty: ResourceTy, VK: VK_PRValue,
                       RParenLoc: SourceLocation(), FPFeatures: FPOptionsOverride());
  VD->setInit(InitExpr);
  VD->setInitStyle(VarDecl::CallInit);
  SemaRef.CheckCompleteVariableDeclaration(VD);
  return true;
}
4989
4990bool SemaHLSL::initGlobalResourceArrayDecl(VarDecl *VD) {
4991 assert(VD->getType()->isHLSLResourceRecordArray() &&
4992 "expected array of resource records");
4993
4994 // Individual resources in a resource array are not initialized here. They
4995 // are initialized later on during codegen when the individual resources are
4996 // accessed. Codegen will emit a call to the resource initialization method
4997 // with the specified array index. We need to make sure though that the method
4998 // for the specific resource type is instantiated, so codegen can emit a call
4999 // to it when the array element is accessed.
5000
5001 // Find correct initialization method based on the resource binding
5002 // information.
5003 ASTContext &AST = SemaRef.getASTContext();
5004 QualType ResElementTy = AST.getBaseElementType(QT: VD->getType());
5005 CXXRecordDecl *ResourceDecl = ResElementTy->getAsCXXRecordDecl();
5006 CXXMethodDecl *CreateMethod = nullptr;
5007
5008 bool HasCounter = hasCounterHandle(RD: ResourceDecl);
5009 ResourceBindingAttrs ResourceAttrs(VD);
5010 if (ResourceAttrs.isExplicit())
5011 // Resource has explicit binding.
5012 CreateMethod =
5013 lookupMethod(S&: SemaRef, RecordDecl: ResourceDecl,
5014 Name: HasCounter ? "__createFromBindingWithImplicitCounter"
5015 : "__createFromBinding",
5016 Loc: VD->getLocation());
5017 else
5018 // Resource has implicit binding.
5019 CreateMethod = lookupMethod(
5020 S&: SemaRef, RecordDecl: ResourceDecl,
5021 Name: HasCounter ? "__createFromImplicitBindingWithImplicitCounter"
5022 : "__createFromImplicitBinding",
5023 Loc: VD->getLocation());
5024
5025 if (!CreateMethod)
5026 return false;
5027
5028 // Make sure the create method template is instantiated and emitted.
5029 if (!CreateMethod->isDefined() && CreateMethod->isTemplateInstantiation())
5030 SemaRef.InstantiateFunctionDefinition(PointOfInstantiation: VD->getLocation(), Function: CreateMethod,
5031 Recursive: true);
5032 return true;
5033}
5034
5035// Returns true if the initialization has been handled.
5036// Returns false to use default initialization.
5037bool SemaHLSL::ActOnUninitializedVarDecl(VarDecl *VD) {
5038 // Objects in the hlsl_constant address space are initialized
5039 // externally, so don't synthesize an implicit initializer.
5040 if (VD->getType().getAddressSpace() == LangAS::hlsl_constant)
5041 return true;
5042
5043 // Initialize non-static resources at the global scope.
5044 if (VD->hasGlobalStorage() && VD->getStorageClass() != SC_Static) {
5045 const Type *Ty = VD->getType().getTypePtr();
5046 if (Ty->isHLSLResourceRecord())
5047 return initGlobalResourceDecl(VD);
5048 if (Ty->isHLSLResourceRecordArray())
5049 return initGlobalResourceArrayDecl(VD);
5050 }
5051 return false;
5052}
5053
5054std::optional<const DeclBindingInfo *> SemaHLSL::inferGlobalBinding(Expr *E) {
5055 if (auto *Ternary = dyn_cast<ConditionalOperator>(Val: E)) {
5056 auto TrueInfo = inferGlobalBinding(E: Ternary->getTrueExpr());
5057 auto FalseInfo = inferGlobalBinding(E: Ternary->getFalseExpr());
5058 if (!TrueInfo || !FalseInfo)
5059 return std::nullopt;
5060 if (*TrueInfo != *FalseInfo)
5061 return std::nullopt;
5062 return TrueInfo;
5063 }
5064
5065 if (auto *ASE = dyn_cast<ArraySubscriptExpr>(Val: E))
5066 E = ASE->getBase()->IgnoreParenImpCasts();
5067
5068 if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(Val: E->IgnoreParens()))
5069 if (VarDecl *VD = dyn_cast<VarDecl>(Val: DRE->getDecl())) {
5070 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
5071 if (Ty->isArrayType())
5072 Ty = Ty->getArrayElementTypeNoTypeQual();
5073
5074 if (const auto *AttrResType =
5075 HLSLAttributedResourceType::findHandleTypeOnResource(RT: Ty)) {
5076 ResourceClass RC = AttrResType->getAttrs().ResourceClass;
5077 return Bindings.getDeclBindingInfo(VD, ResClass: RC);
5078 }
5079 }
5080
5081 return nullptr;
5082}
5083
5084void SemaHLSL::trackLocalResource(VarDecl *VD, Expr *E) {
5085 std::optional<const DeclBindingInfo *> ExprBinding = inferGlobalBinding(E);
5086 if (!ExprBinding) {
5087 SemaRef.Diag(Loc: E->getBeginLoc(),
5088 DiagID: diag::warn_hlsl_assigning_local_resource_is_not_unique)
5089 << E << VD;
5090 return; // Expr use multiple resources
5091 }
5092
5093 if (*ExprBinding == nullptr)
5094 return; // No binding could be inferred to track, return without error
5095
5096 auto PrevBinding = Assigns.find(Val: VD);
5097 if (PrevBinding == Assigns.end()) {
5098 // No previous binding recorded, simply record the new assignment
5099 Assigns.insert(KV: {VD, *ExprBinding});
5100 return;
5101 }
5102
5103 // Otherwise, warn if the assignment implies different resource bindings
5104 if (*ExprBinding != PrevBinding->second) {
5105 SemaRef.Diag(Loc: E->getBeginLoc(),
5106 DiagID: diag::warn_hlsl_assigning_local_resource_is_not_unique)
5107 << E << VD;
5108 SemaRef.Diag(Loc: VD->getLocation(), DiagID: diag::note_var_declared_here) << VD;
5109 return;
5110 }
5111
5112 return;
5113}
5114
5115bool SemaHLSL::CheckResourceBinOp(BinaryOperatorKind Opc, Expr *LHSExpr,
5116 Expr *RHSExpr, SourceLocation Loc) {
5117 assert((LHSExpr->getType()->isHLSLResourceRecord() ||
5118 LHSExpr->getType()->isHLSLResourceRecordArray()) &&
5119 "expected LHS to be a resource record or array of resource records");
5120 if (Opc != BO_Assign)
5121 return true;
5122
5123 // If LHS is an array subscript, get the underlying declaration.
5124 Expr *E = LHSExpr;
5125 while (auto *ASE = dyn_cast<ArraySubscriptExpr>(Val: E))
5126 E = ASE->getBase()->IgnoreParenImpCasts();
5127
5128 // Report error if LHS is a non-static resource declared at a global scope.
5129 if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(Val: E->IgnoreParens())) {
5130 if (VarDecl *VD = dyn_cast<VarDecl>(Val: DRE->getDecl())) {
5131 if (VD->hasGlobalStorage() && VD->getStorageClass() != SC_Static) {
5132 // assignment to global resource is not allowed
5133 SemaRef.Diag(Loc, DiagID: diag::err_hlsl_assign_to_global_resource) << VD;
5134 SemaRef.Diag(Loc: VD->getLocation(), DiagID: diag::note_var_declared_here) << VD;
5135 return false;
5136 }
5137
5138 trackLocalResource(VD, E: RHSExpr);
5139 }
5140 }
5141 return true;
5142}
5143
5144// Walks though the global variable declaration, collects all resource binding
5145// requirements and adds them to Bindings
5146void SemaHLSL::collectResourceBindingsOnVarDecl(VarDecl *VD) {
5147 assert(VD->hasGlobalStorage() && VD->getType()->isHLSLIntangibleType() &&
5148 "expected global variable that contains HLSL resource");
5149
5150 // Cbuffers and Tbuffers are HLSLBufferDecl types
5151 if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(Val: VD)) {
5152 Bindings.addDeclBindingInfo(VD, ResClass: CBufferOrTBuffer->isCBuffer()
5153 ? ResourceClass::CBuffer
5154 : ResourceClass::SRV);
5155 return;
5156 }
5157
5158 // Unwrap arrays
5159 // FIXME: Calculate array size while unwrapping
5160 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
5161 while (Ty->isArrayType()) {
5162 const ArrayType *AT = cast<ArrayType>(Val: Ty);
5163 Ty = AT->getElementType()->getUnqualifiedDesugaredType();
5164 }
5165
5166 // Resource (or array of resources)
5167 if (const HLSLAttributedResourceType *AttrResType =
5168 HLSLAttributedResourceType::findHandleTypeOnResource(RT: Ty)) {
5169 Bindings.addDeclBindingInfo(VD, ResClass: AttrResType->getAttrs().ResourceClass);
5170 return;
5171 }
5172
5173 // User defined record type
5174 if (const RecordType *RT = dyn_cast<RecordType>(Val: Ty))
5175 collectResourceBindingsOnUserRecordDecl(VD, RT);
5176}
5177
5178// Walks though the explicit resource binding attributes on the declaration,
5179// and makes sure there is a resource that matched the binding and updates
5180// DeclBindingInfoLists
5181void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) {
5182 assert(VD->hasGlobalStorage() && "expected global variable");
5183
5184 bool HasBinding = false;
5185 for (Attr *A : VD->attrs()) {
5186 if (isa<HLSLVkBindingAttr>(Val: A)) {
5187 HasBinding = true;
5188 if (auto PA = VD->getAttr<HLSLVkPushConstantAttr>())
5189 Diag(Loc: PA->getLoc(), DiagID: diag::err_hlsl_attr_incompatible) << A << PA;
5190 }
5191
5192 HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(Val: A);
5193 if (!RBA || !RBA->hasRegisterSlot())
5194 continue;
5195 HasBinding = true;
5196
5197 RegisterType RT = RBA->getRegisterType();
5198 assert(RT != RegisterType::I && "invalid or obsolete register type should "
5199 "never have an attribute created");
5200
5201 if (RT == RegisterType::C) {
5202 if (Bindings.hasBindingInfoForDecl(VD))
5203 SemaRef.Diag(Loc: VD->getLocation(),
5204 DiagID: diag::warn_hlsl_user_defined_type_missing_member)
5205 << static_cast<int>(RT);
5206 continue;
5207 }
5208
5209 // Find DeclBindingInfo for this binding and update it, or report error
5210 // if it does not exist (user type does to contain resources with the
5211 // expected resource class).
5212 ResourceClass RC = getResourceClass(RT);
5213 if (DeclBindingInfo *BI = Bindings.getDeclBindingInfo(VD, ResClass: RC)) {
5214 // update binding info
5215 BI->setBindingAttribute(A: RBA, BT: BindingType::Explicit);
5216 } else {
5217 SemaRef.Diag(Loc: VD->getLocation(),
5218 DiagID: diag::warn_hlsl_user_defined_type_missing_member)
5219 << static_cast<int>(RT);
5220 }
5221 }
5222
5223 if (!HasBinding && isResourceRecordTypeOrArrayOf(VD))
5224 SemaRef.Diag(Loc: VD->getLocation(), DiagID: diag::warn_hlsl_implicit_binding);
5225}
5226namespace {
5227class InitListTransformer {
5228 Sema &S;
5229 ASTContext &Ctx;
5230 QualType InitTy;
5231 QualType *DstIt = nullptr;
5232 Expr **ArgIt = nullptr;
5233 // Is wrapping the destination type iterator required? This is only used for
5234 // incomplete array types where we loop over the destination type since we
5235 // don't know the full number of elements from the declaration.
5236 bool Wrap;
5237
5238 bool castInitializer(Expr *E) {
5239 assert(DstIt && "This should always be something!");
5240 if (DstIt == DestTypes.end()) {
5241 if (!Wrap) {
5242 ArgExprs.push_back(Elt: E);
5243 // This is odd, but it isn't technically a failure due to conversion, we
5244 // handle mismatched counts of arguments differently.
5245 return true;
5246 }
5247 DstIt = DestTypes.begin();
5248 }
5249 InitializedEntity Entity = InitializedEntity::InitializeParameter(
5250 Context&: Ctx, Type: *DstIt, /* Consumed (ObjC) */ Consumed: false);
5251 ExprResult Res = S.PerformCopyInitialization(Entity, EqualLoc: E->getBeginLoc(), Init: E);
5252 if (Res.isInvalid())
5253 return false;
5254 Expr *Init = Res.get();
5255 ArgExprs.push_back(Elt: Init);
5256 DstIt++;
5257 return true;
5258 }
5259
5260 bool buildInitializerListImpl(Expr *E) {
5261 // If this is an initialization list, traverse the sub initializers.
5262 if (auto *Init = dyn_cast<InitListExpr>(Val: E)) {
5263 for (auto *SubInit : Init->inits())
5264 if (!buildInitializerListImpl(E: SubInit))
5265 return false;
5266 return true;
5267 }
5268
5269 // If this is a scalar type, just enqueue the expression.
5270 QualType Ty = E->getType();
5271
5272 if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()))
5273 return castInitializer(E);
5274
5275 if (auto *VecTy = Ty->getAs<VectorType>()) {
5276 uint64_t Size = VecTy->getNumElements();
5277
5278 QualType SizeTy = Ctx.getSizeType();
5279 uint64_t SizeTySize = Ctx.getTypeSize(T: SizeTy);
5280 for (uint64_t I = 0; I < Size; ++I) {
5281 auto *Idx = IntegerLiteral::Create(C: Ctx, V: llvm::APInt(SizeTySize, I),
5282 type: SizeTy, l: SourceLocation());
5283
5284 ExprResult ElExpr = S.CreateBuiltinArraySubscriptExpr(
5285 Base: E, LLoc: E->getBeginLoc(), Idx, RLoc: E->getEndLoc());
5286 if (ElExpr.isInvalid())
5287 return false;
5288 if (!castInitializer(E: ElExpr.get()))
5289 return false;
5290 }
5291 return true;
5292 }
5293 if (auto *MTy = Ty->getAs<ConstantMatrixType>()) {
5294 unsigned Rows = MTy->getNumRows();
5295 unsigned Cols = MTy->getNumColumns();
5296 QualType ElemTy = MTy->getElementType();
5297
5298 for (unsigned R = 0; R < Rows; ++R) {
5299 for (unsigned C = 0; C < Cols; ++C) {
5300 // row index literal
5301 Expr *RowIdx = IntegerLiteral::Create(
5302 C: Ctx, V: llvm::APInt(Ctx.getIntWidth(T: Ctx.IntTy), R), type: Ctx.IntTy,
5303 l: E->getBeginLoc());
5304 // column index literal
5305 Expr *ColIdx = IntegerLiteral::Create(
5306 C: Ctx, V: llvm::APInt(Ctx.getIntWidth(T: Ctx.IntTy), C), type: Ctx.IntTy,
5307 l: E->getBeginLoc());
5308 ExprResult ElExpr = S.CreateBuiltinMatrixSubscriptExpr(
5309 Base: E, RowIdx, ColumnIdx: ColIdx, RBLoc: E->getEndLoc());
5310 if (ElExpr.isInvalid())
5311 return false;
5312 if (!castInitializer(E: ElExpr.get()))
5313 return false;
5314 ElExpr.get()->setType(ElemTy);
5315 }
5316 }
5317 return true;
5318 }
5319
5320 if (auto *ArrTy = dyn_cast<ConstantArrayType>(Val: Ty.getTypePtr())) {
5321 uint64_t Size = ArrTy->getZExtSize();
5322 QualType SizeTy = Ctx.getSizeType();
5323 uint64_t SizeTySize = Ctx.getTypeSize(T: SizeTy);
5324 for (uint64_t I = 0; I < Size; ++I) {
5325 auto *Idx = IntegerLiteral::Create(C: Ctx, V: llvm::APInt(SizeTySize, I),
5326 type: SizeTy, l: SourceLocation());
5327 ExprResult ElExpr = S.CreateBuiltinArraySubscriptExpr(
5328 Base: E, LLoc: E->getBeginLoc(), Idx, RLoc: E->getEndLoc());
5329 if (ElExpr.isInvalid())
5330 return false;
5331 if (!buildInitializerListImpl(E: ElExpr.get()))
5332 return false;
5333 }
5334 return true;
5335 }
5336
5337 if (auto *RD = Ty->getAsCXXRecordDecl()) {
5338 llvm::SmallVector<CXXRecordDecl *> RecordDecls;
5339 RecordDecls.push_back(Elt: RD);
5340 while (RecordDecls.back()->getNumBases()) {
5341 CXXRecordDecl *D = RecordDecls.back();
5342 assert(D->getNumBases() == 1 &&
5343 "HLSL doesn't support multiple inheritance");
5344 RecordDecls.push_back(
5345 Elt: D->bases_begin()->getType()->castAsCXXRecordDecl());
5346 }
5347 while (!RecordDecls.empty()) {
5348 CXXRecordDecl *RD = RecordDecls.pop_back_val();
5349 for (auto *FD : RD->fields()) {
5350 if (FD->isUnnamedBitField())
5351 continue;
5352 DeclAccessPair Found = DeclAccessPair::make(D: FD, AS: FD->getAccess());
5353 DeclarationNameInfo NameInfo(FD->getDeclName(), E->getBeginLoc());
5354 ExprResult Res = S.BuildFieldReferenceExpr(
5355 BaseExpr: E, IsArrow: false, OpLoc: E->getBeginLoc(), SS: CXXScopeSpec(), Field: FD, FoundDecl: Found, MemberNameInfo: NameInfo);
5356 if (Res.isInvalid())
5357 return false;
5358 if (!buildInitializerListImpl(E: Res.get()))
5359 return false;
5360 }
5361 }
5362 }
5363 return true;
5364 }
5365
5366 Expr *generateInitListsImpl(QualType Ty) {
5367 assert(ArgIt != ArgExprs.end() && "Something is off in iteration!");
5368 if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()))
5369 return *(ArgIt++);
5370
5371 llvm::SmallVector<Expr *> Inits;
5372 Ty = Ty.getDesugaredType(Context: Ctx);
5373 if (Ty->isVectorType() || Ty->isConstantArrayType() ||
5374 Ty->isConstantMatrixType()) {
5375 QualType ElTy;
5376 uint64_t Size = 0;
5377 if (auto *ATy = Ty->getAs<VectorType>()) {
5378 ElTy = ATy->getElementType();
5379 Size = ATy->getNumElements();
5380 } else if (auto *CMTy = Ty->getAs<ConstantMatrixType>()) {
5381 ElTy = CMTy->getElementType();
5382 Size = CMTy->getNumElementsFlattened();
5383 } else {
5384 auto *VTy = cast<ConstantArrayType>(Val: Ty.getTypePtr());
5385 ElTy = VTy->getElementType();
5386 Size = VTy->getZExtSize();
5387 }
5388 for (uint64_t I = 0; I < Size; ++I)
5389 Inits.push_back(Elt: generateInitListsImpl(Ty: ElTy));
5390 }
5391 if (auto *RD = Ty->getAsCXXRecordDecl()) {
5392 llvm::SmallVector<CXXRecordDecl *> RecordDecls;
5393 RecordDecls.push_back(Elt: RD);
5394 while (RecordDecls.back()->getNumBases()) {
5395 CXXRecordDecl *D = RecordDecls.back();
5396 assert(D->getNumBases() == 1 &&
5397 "HLSL doesn't support multiple inheritance");
5398 RecordDecls.push_back(
5399 Elt: D->bases_begin()->getType()->castAsCXXRecordDecl());
5400 }
5401 while (!RecordDecls.empty()) {
5402 CXXRecordDecl *RD = RecordDecls.pop_back_val();
5403 for (auto *FD : RD->fields())
5404 if (!FD->isUnnamedBitField())
5405 Inits.push_back(Elt: generateInitListsImpl(Ty: FD->getType()));
5406 }
5407 }
5408 auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
5409 Inits, Inits.back()->getEndLoc());
5410 NewInit->setType(Ty);
5411 return NewInit;
5412 }
5413
5414public:
5415 llvm::SmallVector<QualType, 16> DestTypes;
5416 llvm::SmallVector<Expr *, 16> ArgExprs;
5417 InitListTransformer(Sema &SemaRef, const InitializedEntity &Entity)
5418 : S(SemaRef), Ctx(SemaRef.getASTContext()),
5419 Wrap(Entity.getType()->isIncompleteArrayType()) {
5420 InitTy = Entity.getType().getNonReferenceType();
5421 // When we're generating initializer lists for incomplete array types we
5422 // need to wrap around both when building the initializers and when
5423 // generating the final initializer lists.
5424 if (Wrap) {
5425 assert(InitTy->isIncompleteArrayType());
5426 const IncompleteArrayType *IAT = Ctx.getAsIncompleteArrayType(T: InitTy);
5427 InitTy = IAT->getElementType();
5428 }
5429 BuildFlattenedTypeList(BaseTy: InitTy, List&: DestTypes);
5430 DstIt = DestTypes.begin();
5431 }
5432
5433 bool buildInitializerList(Expr *E) { return buildInitializerListImpl(E); }
5434
5435 Expr *generateInitLists() {
5436 assert(!ArgExprs.empty() &&
5437 "Call buildInitializerList to generate argument expressions.");
5438 ArgIt = ArgExprs.begin();
5439 if (!Wrap)
5440 return generateInitListsImpl(Ty: InitTy);
5441 llvm::SmallVector<Expr *> Inits;
5442 while (ArgIt != ArgExprs.end())
5443 Inits.push_back(Elt: generateInitListsImpl(Ty: InitTy));
5444
5445 auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
5446 Inits, Inits.back()->getEndLoc());
5447 llvm::APInt ArySize(64, Inits.size());
5448 NewInit->setType(Ctx.getConstantArrayType(EltTy: InitTy, ArySize, SizeExpr: nullptr,
5449 ASM: ArraySizeModifier::Normal, IndexTypeQuals: 0));
5450 return NewInit;
5451 }
5452};
5453} // namespace
5454
5455// Recursively detect any incomplete array anywhere in the type graph,
5456// including arrays, struct fields, and base classes.
5457static bool containsIncompleteArrayType(QualType Ty) {
5458 Ty = Ty.getCanonicalType();
5459
5460 // Array types
5461 if (const ArrayType *AT = dyn_cast<ArrayType>(Val&: Ty)) {
5462 if (isa<IncompleteArrayType>(Val: AT))
5463 return true;
5464 return containsIncompleteArrayType(Ty: AT->getElementType());
5465 }
5466
5467 // Record (struct/class) types
5468 if (const auto *RT = Ty->getAs<RecordType>()) {
5469 const RecordDecl *RD = RT->getDecl();
5470
5471 // Walk base classes (for C++ / HLSL structs with inheritance)
5472 if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(Val: RD)) {
5473 for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
5474 if (containsIncompleteArrayType(Ty: Base.getType()))
5475 return true;
5476 }
5477 }
5478
5479 // Walk fields
5480 for (const FieldDecl *F : RD->fields()) {
5481 if (containsIncompleteArrayType(Ty: F->getType()))
5482 return true;
5483 }
5484 }
5485
5486 return false;
5487}
5488
5489bool SemaHLSL::transformInitList(const InitializedEntity &Entity,
5490 InitListExpr *Init) {
5491 // If the initializer is a scalar, just return it.
5492 if (Init->getType()->isScalarType())
5493 return true;
5494 ASTContext &Ctx = SemaRef.getASTContext();
5495 InitListTransformer ILT(SemaRef, Entity);
5496
5497 for (unsigned I = 0; I < Init->getNumInits(); ++I) {
5498 Expr *E = Init->getInit(Init: I);
5499 if (E->HasSideEffects(Ctx)) {
5500 QualType Ty = E->getType();
5501 if (Ty->isRecordType())
5502 E = new (Ctx) MaterializeTemporaryExpr(Ty, E, E->isLValue());
5503 E = new (Ctx) OpaqueValueExpr(E->getBeginLoc(), Ty, E->getValueKind(),
5504 E->getObjectKind(), E);
5505 Init->setInit(Init: I, expr: E);
5506 }
5507 if (!ILT.buildInitializerList(E))
5508 return false;
5509 }
5510 size_t ExpectedSize = ILT.DestTypes.size();
5511 size_t ActualSize = ILT.ArgExprs.size();
5512 if (ExpectedSize == 0 && ActualSize == 0)
5513 return true;
5514
5515 // Reject empty initializer if *any* incomplete array exists structurally
5516 if (ActualSize == 0 && containsIncompleteArrayType(Ty: Entity.getType())) {
5517 QualType InitTy = Entity.getType().getNonReferenceType();
5518 if (InitTy.hasAddressSpace())
5519 InitTy = SemaRef.getASTContext().removeAddrSpaceQualType(T: InitTy);
5520
5521 SemaRef.Diag(Loc: Init->getBeginLoc(), DiagID: diag::err_hlsl_incorrect_num_initializers)
5522 << /*TooManyOrFew=*/(int)(ExpectedSize < ActualSize) << InitTy
5523 << /*ExpectedSize=*/ExpectedSize << /*ActualSize=*/ActualSize;
5524 return false;
5525 }
5526
5527 // We infer size after validating legality.
5528 // For incomplete arrays it is completely arbitrary to choose whether we think
5529 // the user intended fewer or more elements. This implementation assumes that
5530 // the user intended more, and errors that there are too few initializers to
5531 // complete the final element.
5532 if (Entity.getType()->isIncompleteArrayType()) {
5533 assert(ExpectedSize > 0 &&
5534 "The expected size of an incomplete array type must be at least 1.");
5535 ExpectedSize =
5536 ((ActualSize + ExpectedSize - 1) / ExpectedSize) * ExpectedSize;
5537 }
5538
5539 // An initializer list might be attempting to initialize a reference or
5540 // rvalue-reference. When checking the initializer we should look through
5541 // the reference.
5542 QualType InitTy = Entity.getType().getNonReferenceType();
5543 if (InitTy.hasAddressSpace())
5544 InitTy = SemaRef.getASTContext().removeAddrSpaceQualType(T: InitTy);
5545 if (ExpectedSize != ActualSize) {
5546 int TooManyOrFew = ActualSize > ExpectedSize ? 1 : 0;
5547 SemaRef.Diag(Loc: Init->getBeginLoc(), DiagID: diag::err_hlsl_incorrect_num_initializers)
5548 << TooManyOrFew << InitTy << ExpectedSize << ActualSize;
5549 return false;
5550 }
5551
5552 // generateInitListsImpl will always return an InitListExpr here, because the
5553 // scalar case is handled above.
5554 auto *NewInit = cast<InitListExpr>(Val: ILT.generateInitLists());
5555 Init->resizeInits(Context: Ctx, NumInits: NewInit->getNumInits());
5556 for (unsigned I = 0; I < NewInit->getNumInits(); ++I)
5557 Init->updateInit(C: Ctx, Init: I, expr: NewInit->getInit(Init: I));
5558 return true;
5559}
5560
5561static QualType ReportMatrixInvalidMember(Sema &S, StringRef Name,
5562 StringRef Expected,
5563 SourceLocation OpLoc,
5564 SourceLocation CompLoc) {
5565 S.Diag(Loc: OpLoc, DiagID: diag::err_builtin_matrix_invalid_member)
5566 << Name << Expected << SourceRange(CompLoc);
5567 return QualType();
5568}
5569
5570QualType SemaHLSL::checkMatrixComponent(Sema &S, QualType baseType,
5571 ExprValueKind &VK, SourceLocation OpLoc,
5572 const IdentifierInfo *CompName,
5573 SourceLocation CompLoc) {
5574 const auto *MT = baseType->castAs<ConstantMatrixType>();
5575 StringRef AccessorName = CompName->getName();
5576 assert(!AccessorName.empty() && "Matrix Accessor must have a name");
5577
5578 unsigned Rows = MT->getNumRows();
5579 unsigned Cols = MT->getNumColumns();
5580 bool IsZeroBasedAccessor = false;
5581 unsigned ChunkLen = 0;
5582 if (AccessorName.size() < 2)
5583 return ReportMatrixInvalidMember(S, Name: AccessorName,
5584 Expected: "length 4 for zero based: \'_mRC\' or "
5585 "length 3 for one-based: \'_RC\' accessor",
5586 OpLoc, CompLoc);
5587
5588 if (AccessorName[0] == '_') {
5589 if (AccessorName[1] == 'm') {
5590 IsZeroBasedAccessor = true;
5591 ChunkLen = 4; // zero-based: "_mRC"
5592 } else {
5593 ChunkLen = 3; // one-based: "_RC"
5594 }
5595 } else
5596 return ReportMatrixInvalidMember(
5597 S, Name: AccessorName, Expected: "zero based: \'_mRC\' or one-based: \'_RC\' accessor",
5598 OpLoc, CompLoc);
5599
5600 if (AccessorName.size() % ChunkLen != 0) {
5601 const llvm::StringRef Expected = IsZeroBasedAccessor
5602 ? "zero based: '_mRC' accessor"
5603 : "one-based: '_RC' accessor";
5604
5605 return ReportMatrixInvalidMember(S, Name: AccessorName, Expected, OpLoc, CompLoc);
5606 }
5607
5608 auto isDigit = [](char c) { return c >= '0' && c <= '9'; };
5609 auto isZeroBasedIndex = [](unsigned i) { return i <= 3; };
5610 auto isOneBasedIndex = [](unsigned i) { return i >= 1 && i <= 4; };
5611
5612 bool HasRepeated = false;
5613 SmallVector<bool, 16> Seen(Rows * Cols, false);
5614 unsigned NumComponents = 0;
5615 const char *Begin = AccessorName.data();
5616
5617 for (unsigned I = 0, E = AccessorName.size(); I < E; I += ChunkLen) {
5618 const char *Chunk = Begin + I;
5619 char RowChar = 0, ColChar = 0;
5620 if (IsZeroBasedAccessor) {
5621 // Zero-based: "_mRC"
5622 if (Chunk[0] != '_' || Chunk[1] != 'm') {
5623 char Bad = (Chunk[0] != '_') ? Chunk[0] : Chunk[1];
5624 return ReportMatrixInvalidMember(
5625 S, Name: StringRef(&Bad, 1), Expected: "\'_m\' prefix",
5626 OpLoc: OpLoc.getLocWithOffset(Offset: I + (Bad == Chunk[0] ? 1 : 2)), CompLoc);
5627 }
5628 RowChar = Chunk[2];
5629 ColChar = Chunk[3];
5630 } else {
5631 // One-based: "_RC"
5632 if (Chunk[0] != '_')
5633 return ReportMatrixInvalidMember(
5634 S, Name: StringRef(&Chunk[0], 1), Expected: "\'_\' prefix",
5635 OpLoc: OpLoc.getLocWithOffset(Offset: I + 1), CompLoc);
5636 RowChar = Chunk[1];
5637 ColChar = Chunk[2];
5638 }
5639
5640 // Must be digits.
5641 bool IsDigitsError = false;
5642 if (!isDigit(RowChar)) {
5643 unsigned BadPos = IsZeroBasedAccessor ? 2 : 1;
5644 ReportMatrixInvalidMember(S, Name: StringRef(&RowChar, 1), Expected: "row as integer",
5645 OpLoc: OpLoc.getLocWithOffset(Offset: I + BadPos + 1),
5646 CompLoc);
5647 IsDigitsError = true;
5648 }
5649
5650 if (!isDigit(ColChar)) {
5651 unsigned BadPos = IsZeroBasedAccessor ? 3 : 2;
5652 ReportMatrixInvalidMember(S, Name: StringRef(&ColChar, 1), Expected: "column as integer",
5653 OpLoc: OpLoc.getLocWithOffset(Offset: I + BadPos + 1),
5654 CompLoc);
5655 IsDigitsError = true;
5656 }
5657 if (IsDigitsError)
5658 return QualType();
5659
5660 unsigned Row = RowChar - '0';
5661 unsigned Col = ColChar - '0';
5662
5663 bool HasIndexingError = false;
5664 if (IsZeroBasedAccessor) {
5665 // 0-based [0..3]
5666 if (!isZeroBasedIndex(Row)) {
5667 S.Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_element_not_in_bounds)
5668 << /*row*/ 0 << /*zero-based*/ 0 << SourceRange(CompLoc);
5669 HasIndexingError = true;
5670 }
5671 if (!isZeroBasedIndex(Col)) {
5672 S.Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_element_not_in_bounds)
5673 << /*col*/ 1 << /*zero-based*/ 0 << SourceRange(CompLoc);
5674 HasIndexingError = true;
5675 }
5676 } else {
5677 // 1-based [1..4]
5678 if (!isOneBasedIndex(Row)) {
5679 S.Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_element_not_in_bounds)
5680 << /*row*/ 0 << /*one-based*/ 1 << SourceRange(CompLoc);
5681 HasIndexingError = true;
5682 }
5683 if (!isOneBasedIndex(Col)) {
5684 S.Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_element_not_in_bounds)
5685 << /*col*/ 1 << /*one-based*/ 1 << SourceRange(CompLoc);
5686 HasIndexingError = true;
5687 }
5688 // Convert to 0-based after range checking.
5689 --Row;
5690 --Col;
5691 }
5692
5693 if (HasIndexingError)
5694 return QualType();
5695
5696 // Note: matrix swizzle index is hard coded. That means Row and Col can
5697 // potentially be larger than Rows and Cols if matrix size is less than
5698 // the max index size.
5699 bool HasBoundsError = false;
5700 if (Row >= Rows) {
5701 Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_index_out_of_bounds)
5702 << /*Row*/ 0 << Row << Rows << SourceRange(CompLoc);
5703 HasBoundsError = true;
5704 }
5705 if (Col >= Cols) {
5706 Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_index_out_of_bounds)
5707 << /*Col*/ 1 << Col << Cols << SourceRange(CompLoc);
5708 HasBoundsError = true;
5709 }
5710 if (HasBoundsError)
5711 return QualType();
5712
5713 unsigned FlatIndex = Row * Cols + Col;
5714 if (Seen[FlatIndex])
5715 HasRepeated = true;
5716 Seen[FlatIndex] = true;
5717 ++NumComponents;
5718 }
5719 if (NumComponents == 0 || NumComponents > 4) {
5720 S.Diag(Loc: OpLoc, DiagID: diag::err_hlsl_matrix_swizzle_invalid_length)
5721 << NumComponents << SourceRange(CompLoc);
5722 return QualType();
5723 }
5724
5725 QualType ElemTy = MT->getElementType();
5726 if (NumComponents == 1)
5727 return ElemTy;
5728 QualType VT = S.Context.getExtVectorType(VectorType: ElemTy, NumElts: NumComponents);
5729 if (HasRepeated)
5730 VK = VK_PRValue;
5731
5732 for (Sema::ExtVectorDeclsType::iterator
5733 I = S.ExtVectorDecls.begin(source: S.getExternalSource()),
5734 E = S.ExtVectorDecls.end();
5735 I != E; ++I) {
5736 if ((*I)->getUnderlyingType() == VT)
5737 return S.Context.getTypedefType(Keyword: ElaboratedTypeKeyword::None,
5738 /*Qualifier=*/std::nullopt, Decl: *I);
5739 }
5740
5741 return VT;
5742}
5743
5744bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) {
5745 // If initializing a local resource, track the resource binding it is using
5746 if (VDecl->getType()->isHLSLResourceRecord() && !VDecl->hasGlobalStorage())
5747 trackLocalResource(VD: VDecl, E: Init);
5748
5749 const HLSLVkConstantIdAttr *ConstIdAttr =
5750 VDecl->getAttr<HLSLVkConstantIdAttr>();
5751 if (!ConstIdAttr)
5752 return true;
5753
5754 ASTContext &Context = SemaRef.getASTContext();
5755
5756 APValue InitValue;
5757 if (!Init->isCXX11ConstantExpr(Ctx: Context, Result: &InitValue)) {
5758 Diag(Loc: VDecl->getLocation(), DiagID: diag::err_specialization_const);
5759 VDecl->setInvalidDecl();
5760 return false;
5761 }
5762
5763 Builtin::ID BID =
5764 getSpecConstBuiltinId(Type: VDecl->getType()->getUnqualifiedDesugaredType());
5765
5766 // Argument 1: The ID from the attribute
5767 int ConstantID = ConstIdAttr->getId();
5768 llvm::APInt IDVal(Context.getIntWidth(T: Context.IntTy), ConstantID);
5769 Expr *IdExpr = IntegerLiteral::Create(C: Context, V: IDVal, type: Context.IntTy,
5770 l: ConstIdAttr->getLocation());
5771
5772 SmallVector<Expr *, 2> Args = {IdExpr, Init};
5773 Expr *C = SemaRef.BuildBuiltinCallExpr(Loc: Init->getExprLoc(), Id: BID, CallArgs: Args);
5774 if (C->getType()->getCanonicalTypeUnqualified() !=
5775 VDecl->getType()->getCanonicalTypeUnqualified()) {
5776 C = SemaRef
5777 .BuildCStyleCastExpr(LParenLoc: SourceLocation(),
5778 Ty: Context.getTrivialTypeSourceInfo(
5779 T: Init->getType(), Loc: Init->getExprLoc()),
5780 RParenLoc: SourceLocation(), Op: C)
5781 .get();
5782 }
5783 Init = C;
5784 return true;
5785}
5786
5787QualType SemaHLSL::ActOnTemplateShorthand(TemplateDecl *Template,
5788 SourceLocation NameLoc) {
5789 if (!Template)
5790 return QualType();
5791
5792 DeclContext *DC = Template->getDeclContext();
5793 if (!DC->isNamespace() || !cast<NamespaceDecl>(Val: DC)->getIdentifier() ||
5794 cast<NamespaceDecl>(Val: DC)->getName() != "hlsl")
5795 return QualType();
5796
5797 TemplateParameterList *Params = Template->getTemplateParameters();
5798 if (!Params || Params->size() != 1)
5799 return QualType();
5800
5801 if (!Template->isImplicit())
5802 return QualType();
5803
5804 // We manually extract default arguments here instead of letting
5805 // CheckTemplateIdType handle it. This ensures that for resource types that
5806 // lack a default argument (like Buffer), we return a null QualType, which
5807 // triggers the "requires template arguments" error rather than a less
5808 // descriptive "too few template arguments" error.
5809 TemplateArgumentListInfo TemplateArgs(NameLoc, NameLoc);
5810 for (NamedDecl *P : *Params) {
5811 if (auto *TTP = dyn_cast<TemplateTypeParmDecl>(Val: P)) {
5812 if (TTP->hasDefaultArgument()) {
5813 TemplateArgs.addArgument(Loc: TTP->getDefaultArgument());
5814 continue;
5815 }
5816 } else if (auto *NTTP = dyn_cast<NonTypeTemplateParmDecl>(Val: P)) {
5817 if (NTTP->hasDefaultArgument()) {
5818 TemplateArgs.addArgument(Loc: NTTP->getDefaultArgument());
5819 continue;
5820 }
5821 } else if (auto *TTPD = dyn_cast<TemplateTemplateParmDecl>(Val: P)) {
5822 if (TTPD->hasDefaultArgument()) {
5823 TemplateArgs.addArgument(Loc: TTPD->getDefaultArgument());
5824 continue;
5825 }
5826 }
5827 return QualType();
5828 }
5829
5830 return SemaRef.CheckTemplateIdType(
5831 Keyword: ElaboratedTypeKeyword::None, Template: TemplateName(Template), TemplateLoc: NameLoc,
5832 TemplateArgs, Scope: nullptr, /*ForNestedNameSpecifier=*/false);
5833}
5834