1//===- SemaHLSL.cpp - Semantic Analysis for HLSL constructs ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This implements Semantic Analysis for HLSL constructs.
9//===----------------------------------------------------------------------===//
10
11#include "clang/Sema/SemaHLSL.h"
12#include "clang/AST/ASTConsumer.h"
13#include "clang/AST/ASTContext.h"
14#include "clang/AST/Attr.h"
15#include "clang/AST/Attrs.inc"
16#include "clang/AST/Decl.h"
17#include "clang/AST/DeclBase.h"
18#include "clang/AST/DeclCXX.h"
19#include "clang/AST/DeclarationName.h"
20#include "clang/AST/DynamicRecursiveASTVisitor.h"
21#include "clang/AST/Expr.h"
22#include "clang/AST/Type.h"
23#include "clang/AST/TypeLoc.h"
24#include "clang/Basic/Builtins.h"
25#include "clang/Basic/DiagnosticSema.h"
26#include "clang/Basic/IdentifierTable.h"
27#include "clang/Basic/LLVM.h"
28#include "clang/Basic/SourceLocation.h"
29#include "clang/Basic/Specifiers.h"
30#include "clang/Basic/TargetInfo.h"
31#include "clang/Sema/Initialization.h"
32#include "clang/Sema/Lookup.h"
33#include "clang/Sema/ParsedAttr.h"
34#include "clang/Sema/Sema.h"
35#include "clang/Sema/Template.h"
36#include "llvm/ADT/ArrayRef.h"
37#include "llvm/ADT/STLExtras.h"
38#include "llvm/ADT/SmallVector.h"
39#include "llvm/ADT/StringExtras.h"
40#include "llvm/ADT/StringRef.h"
41#include "llvm/ADT/Twine.h"
42#include "llvm/Frontend/HLSL/HLSLRootSignatureUtils.h"
43#include "llvm/Support/Casting.h"
44#include "llvm/Support/DXILABI.h"
45#include "llvm/Support/ErrorHandling.h"
46#include "llvm/TargetParser/Triple.h"
47#include <cstddef>
48#include <iterator>
49#include <utility>
50
51using namespace clang;
52using RegisterType = HLSLResourceBindingAttr::RegisterType;
53
54static CXXRecordDecl *createHostLayoutStruct(Sema &S,
55 CXXRecordDecl *StructDecl);
56
// Maps a resource class to the register type used to bind it
// (SRV -> t, UAV -> u, CBuffer -> b, Sampler -> s registers).
// The RegisterType::C and RegisterType::I values have no corresponding
// resource class and are never produced here.
static RegisterType getRegisterType(ResourceClass RC) {
  switch (RC) {
  case ResourceClass::SRV:
    return RegisterType::SRV;
  case ResourceClass::UAV:
    return RegisterType::UAV;
  case ResourceClass::CBuffer:
    return RegisterType::CBuffer;
  case ResourceClass::Sampler:
    return RegisterType::Sampler;
  }
  llvm_unreachable("unexpected ResourceClass value");
}
70
71// Converts the first letter of string Slot to RegisterType.
72// Returns false if the letter does not correspond to a valid register type.
73static bool convertToRegisterType(StringRef Slot, RegisterType *RT) {
74 assert(RT != nullptr);
75 switch (Slot[0]) {
76 case 't':
77 case 'T':
78 *RT = RegisterType::SRV;
79 return true;
80 case 'u':
81 case 'U':
82 *RT = RegisterType::UAV;
83 return true;
84 case 'b':
85 case 'B':
86 *RT = RegisterType::CBuffer;
87 return true;
88 case 's':
89 case 'S':
90 *RT = RegisterType::Sampler;
91 return true;
92 case 'c':
93 case 'C':
94 *RT = RegisterType::C;
95 return true;
96 case 'i':
97 case 'I':
98 *RT = RegisterType::I;
99 return true;
100 default:
101 return false;
102 }
103}
104
// Inverse of getRegisterType for the four binding register types.
// RegisterType::C and RegisterType::I are not resource registers and must
// never be passed here.
static ResourceClass getResourceClass(RegisterType RT) {
  switch (RT) {
  case RegisterType::SRV:
    return ResourceClass::SRV;
  case RegisterType::UAV:
    return ResourceClass::UAV;
  case RegisterType::CBuffer:
    return ResourceClass::CBuffer;
  case RegisterType::Sampler:
    return ResourceClass::Sampler;
  case RegisterType::C:
  case RegisterType::I:
    // Deliberately falling through to the unreachable below.
    break;
  }
  llvm_unreachable("unexpected RegisterType value");
}
122
123static Builtin::ID getSpecConstBuiltinId(const Type *Type) {
124 const auto *BT = dyn_cast<BuiltinType>(Val: Type);
125 if (!BT) {
126 if (!Type->isEnumeralType())
127 return Builtin::NotBuiltin;
128 return Builtin::BI__builtin_get_spirv_spec_constant_int;
129 }
130
131 switch (BT->getKind()) {
132 case BuiltinType::Bool:
133 return Builtin::BI__builtin_get_spirv_spec_constant_bool;
134 case BuiltinType::Short:
135 return Builtin::BI__builtin_get_spirv_spec_constant_short;
136 case BuiltinType::Int:
137 return Builtin::BI__builtin_get_spirv_spec_constant_int;
138 case BuiltinType::LongLong:
139 return Builtin::BI__builtin_get_spirv_spec_constant_longlong;
140 case BuiltinType::UShort:
141 return Builtin::BI__builtin_get_spirv_spec_constant_ushort;
142 case BuiltinType::UInt:
143 return Builtin::BI__builtin_get_spirv_spec_constant_uint;
144 case BuiltinType::ULongLong:
145 return Builtin::BI__builtin_get_spirv_spec_constant_ulonglong;
146 case BuiltinType::Half:
147 return Builtin::BI__builtin_get_spirv_spec_constant_half;
148 case BuiltinType::Float:
149 return Builtin::BI__builtin_get_spirv_spec_constant_float;
150 case BuiltinType::Double:
151 return Builtin::BI__builtin_get_spirv_spec_constant_double;
152 default:
153 return Builtin::NotBuiltin;
154 }
155}
156
// Records a new binding for variable VD with resource class ResClass and
// returns a pointer to the newly created DeclBindingInfo.
// Invariant (checked by the second assert): all bindings for a given VarDecl
// are added back-to-back, so its entries are contiguous in BindingsList.
DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD,
                                                      ResourceClass ResClass) {
  assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
         "DeclBindingInfo already added");
  assert(!hasBindingInfoForDecl(VD) || BindingsList.back().Decl == VD);
  // VarDecl may have multiple entries for different resource classes.
  // DeclToBindingListIndex stores the index of the first binding we saw
  // for this decl. If there are any additional ones then that index
  // shouldn't be updated.
  DeclToBindingListIndex.try_emplace(VD, BindingsList.size());
  return &BindingsList.emplace_back(VD, ResClass);
}
169
170DeclBindingInfo *ResourceBindings::getDeclBindingInfo(const VarDecl *VD,
171 ResourceClass ResClass) {
172 auto Entry = DeclToBindingListIndex.find(Val: VD);
173 if (Entry != DeclToBindingListIndex.end()) {
174 for (unsigned Index = Entry->getSecond();
175 Index < BindingsList.size() && BindingsList[Index].Decl == VD;
176 ++Index) {
177 if (BindingsList[Index].ResClass == ResClass)
178 return &BindingsList[Index];
179 }
180 }
181 return nullptr;
182}
183
// Returns true if at least one binding (of any resource class) has been
// recorded for the given variable declaration.
bool ResourceBindings::hasBindingInfoForDecl(const VarDecl *VD) const {
  return DeclToBindingListIndex.contains(VD);
}
187
188SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}
189
// Called at the start of a cbuffer/tbuffer definition. Creates the
// HLSLBufferDecl, tags it with its implicit resource class (CBuffer for
// cbuffer, SRV for tbuffer), and pushes it as the current declaration
// context so the declarations that follow land inside the buffer.
Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
                                 SourceLocation KwLoc, IdentifierInfo *Ident,
                                 SourceLocation IdentLoc,
                                 SourceLocation LBrace) {
  // For anonymous namespace, take the location of the left brace.
  DeclContext *LexicalParent = SemaRef.getCurLexicalContext();
  HLSLBufferDecl *Result = HLSLBufferDecl::Create(
      getASTContext(), LexicalParent, CBuffer, KwLoc, Ident, IdentLoc, LBrace);

  // if CBuffer is false, then it's a TBuffer
  auto RC = CBuffer ? llvm::hlsl::ResourceClass::CBuffer
                    : llvm::hlsl::ResourceClass::SRV;
  Result->addAttr(HLSLResourceClassAttr::CreateImplicit(getASTContext(), RC));

  // Make the buffer visible in the enclosing scope, then enter its context.
  SemaRef.PushOnScopeChains(Result, BufferScope);
  SemaRef.PushDeclContext(BufferScope, Result);

  return Result;
}
209
210static unsigned calculateLegacyCbufferFieldAlign(const ASTContext &Context,
211 QualType T) {
212 // Arrays and Structs are always aligned to new buffer rows
213 if (T->isArrayType() || T->isStructureType())
214 return 16;
215
216 // Vectors are aligned to the type they contain
217 if (const VectorType *VT = T->getAs<VectorType>())
218 return calculateLegacyCbufferFieldAlign(Context, T: VT->getElementType());
219
220 assert(Context.getTypeSize(T) <= 64 &&
221 "Scalar bit widths larger than 64 not supported");
222
223 // Scalar types are aligned to their byte width
224 return Context.getTypeSize(T) / 8;
225}
226
// Calculate the size of a legacy cbuffer type in bytes based on
// https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-packing-rules
static unsigned calculateLegacyCbufferSize(const ASTContext &Context,
                                           QualType T) {
  // Legacy cbuffers pack data into 16-byte rows.
  constexpr unsigned CBufferAlign = 16;
  if (const RecordType *RT = T->getAs<RecordType>()) {
    unsigned Size = 0;
    const RecordDecl *RD = RT->getDecl();
    for (const FieldDecl *Field : RD->fields()) {
      QualType Ty = Field->getType();
      unsigned FieldSize = calculateLegacyCbufferSize(Context, Ty);
      unsigned FieldAlign = calculateLegacyCbufferFieldAlign(Context, Ty);

      // If the field crosses the row boundary after alignment it drops to the
      // next row: promote its alignment to a full row.
      unsigned AlignSize = llvm::alignTo(Size, FieldAlign);
      if ((AlignSize % CBufferAlign) + FieldSize > CBufferAlign) {
        FieldAlign = CBufferAlign;
      }

      Size = llvm::alignTo(Size, FieldAlign);
      Size += FieldSize;
    }
    return Size;
  }

  if (const ConstantArrayType *AT = Context.getAsConstantArrayType(T)) {
    unsigned ElementCount = AT->getSize().getZExtValue();
    if (ElementCount == 0)
      return 0;

    // Each array element occupies a whole number of 16-byte rows, except the
    // last element, which does not include its trailing row padding.
    unsigned ElementSize =
        calculateLegacyCbufferSize(Context, AT->getElementType());
    unsigned AlignedElementSize = llvm::alignTo(ElementSize, CBufferAlign);
    return AlignedElementSize * (ElementCount - 1) + ElementSize;
  }

  // Vectors pack their elements tightly, with no per-element padding.
  if (const VectorType *VT = T->getAs<VectorType>()) {
    unsigned ElementCount = VT->getNumElements();
    unsigned ElementSize =
        calculateLegacyCbufferSize(Context, VT->getElementType());
    return ElementSize * ElementCount;
  }

  // Scalars: size in bytes.
  return Context.getTypeSize(T) / 8;
}
273
274// Validate packoffset:
275// - if packoffset it used it must be set on all declarations inside the buffer
276// - packoffset ranges must not overlap
277static void validatePackoffset(Sema &S, HLSLBufferDecl *BufDecl) {
278 llvm::SmallVector<std::pair<VarDecl *, HLSLPackOffsetAttr *>> PackOffsetVec;
279
280 // Make sure the packoffset annotations are either on all declarations
281 // or on none.
282 bool HasPackOffset = false;
283 bool HasNonPackOffset = false;
284 for (auto *Field : BufDecl->buffer_decls()) {
285 VarDecl *Var = dyn_cast<VarDecl>(Val: Field);
286 if (!Var)
287 continue;
288 if (Field->hasAttr<HLSLPackOffsetAttr>()) {
289 PackOffsetVec.emplace_back(Args&: Var, Args: Field->getAttr<HLSLPackOffsetAttr>());
290 HasPackOffset = true;
291 } else {
292 HasNonPackOffset = true;
293 }
294 }
295
296 if (!HasPackOffset)
297 return;
298
299 if (HasNonPackOffset)
300 S.Diag(Loc: BufDecl->getLocation(), DiagID: diag::warn_hlsl_packoffset_mix);
301
302 // Make sure there is no overlap in packoffset - sort PackOffsetVec by offset
303 // and compare adjacent values.
304 bool IsValid = true;
305 ASTContext &Context = S.getASTContext();
306 std::sort(first: PackOffsetVec.begin(), last: PackOffsetVec.end(),
307 comp: [](const std::pair<VarDecl *, HLSLPackOffsetAttr *> &LHS,
308 const std::pair<VarDecl *, HLSLPackOffsetAttr *> &RHS) {
309 return LHS.second->getOffsetInBytes() <
310 RHS.second->getOffsetInBytes();
311 });
312 for (unsigned i = 0; i < PackOffsetVec.size() - 1; i++) {
313 VarDecl *Var = PackOffsetVec[i].first;
314 HLSLPackOffsetAttr *Attr = PackOffsetVec[i].second;
315 unsigned Size = calculateLegacyCbufferSize(Context, T: Var->getType());
316 unsigned Begin = Attr->getOffsetInBytes();
317 unsigned End = Begin + Size;
318 unsigned NextBegin = PackOffsetVec[i + 1].second->getOffsetInBytes();
319 if (End > NextBegin) {
320 VarDecl *NextVar = PackOffsetVec[i + 1].first;
321 S.Diag(Loc: NextVar->getLocation(), DiagID: diag::err_hlsl_packoffset_overlap)
322 << NextVar << Var;
323 IsValid = false;
324 }
325 }
326 BufDecl->setHasValidPackoffset(IsValid);
327}
328
329// Returns true if the array has a zero size = if any of the dimensions is 0
330static bool isZeroSizedArray(const ConstantArrayType *CAT) {
331 while (CAT && !CAT->isZeroSize())
332 CAT = dyn_cast<ConstantArrayType>(
333 Val: CAT->getElementType()->getUnqualifiedDesugaredType());
334 return CAT != nullptr;
335}
336
337// Returns true if the record type is an HLSL resource class or an array of
338// resource classes
339static bool isResourceRecordTypeOrArrayOf(const Type *Ty) {
340 while (const ConstantArrayType *CAT = dyn_cast<ConstantArrayType>(Val: Ty))
341 Ty = CAT->getArrayElementTypeNoTypeQual();
342 return HLSLAttributedResourceType::findHandleTypeOnResource(RT: Ty) != nullptr;
343}
344
// Convenience overload: checks the declared type of a variable.
static bool isResourceRecordTypeOrArrayOf(VarDecl *VD) {
  return isResourceRecordTypeOrArrayOf(VD->getType().getTypePtr());
}
348
349// Returns true if the type is a leaf element type that is not valid to be
350// included in HLSL Buffer, such as a resource class, empty struct, zero-sized
351// array, or a builtin intangible type. Returns false it is a valid leaf element
352// type or if it is a record type that needs to be inspected further.
353static bool isInvalidConstantBufferLeafElementType(const Type *Ty) {
354 Ty = Ty->getUnqualifiedDesugaredType();
355 if (isResourceRecordTypeOrArrayOf(Ty))
356 return true;
357 if (Ty->isRecordType())
358 return Ty->getAsCXXRecordDecl()->isEmpty();
359 if (Ty->isConstantArrayType() &&
360 isZeroSizedArray(CAT: cast<ConstantArrayType>(Val: Ty)))
361 return true;
362 if (Ty->isHLSLBuiltinIntangibleType() || Ty->isHLSLAttributedResourceType())
363 return true;
364 return false;
365}
366
// Returns true if the struct contains at least one element that prevents it
// from being included inside HLSL Buffer as is, such as an intangible type,
// empty struct, or zero-sized array. If it does, a new implicit layout struct
// needs to be created for HLSL Buffer use that will exclude these unwanted
// declarations (see createHostLayoutStruct function).
static bool requiresImplicitBufferLayoutStructure(const CXXRecordDecl *RD) {
  // Intangible or empty records always need a replacement layout struct.
  if (RD->getTypeForDecl()->isHLSLIntangibleType() || RD->isEmpty())
    return true;
  // check fields - recurse into nested record types
  for (const FieldDecl *Field : RD->fields()) {
    QualType Ty = Field->getType();
    if (isInvalidConstantBufferLeafElementType(Ty.getTypePtr()))
      return true;
    if (Ty->isRecordType() &&
        requiresImplicitBufferLayoutStructure(Ty->getAsCXXRecordDecl()))
      return true;
  }
  // check bases - a base that needs a layout struct taints the derived type
  for (const CXXBaseSpecifier &Base : RD->bases())
    if (requiresImplicitBufferLayoutStructure(
            Base.getType()->getAsCXXRecordDecl()))
      return true;
  return false;
}
391
392static CXXRecordDecl *findRecordDeclInContext(IdentifierInfo *II,
393 DeclContext *DC) {
394 CXXRecordDecl *RD = nullptr;
395 for (NamedDecl *Decl :
396 DC->getNonTransparentContext()->lookup(Name: DeclarationName(II))) {
397 if (CXXRecordDecl *FoundRD = dyn_cast<CXXRecordDecl>(Val: Decl)) {
398 assert(RD == nullptr &&
399 "there should be at most 1 record by a given name in a scope");
400 RD = FoundRD;
401 }
402 }
403 return RD;
404}
405
// Creates a name for buffer layout struct using the provided name base.
// If the name must be unique (not previously defined), a numeric suffix is
// appended and incremented until a unique name is found.
static IdentifierInfo *getHostLayoutStructName(Sema &S, NamedDecl *BaseDecl,
                                               bool MustBeUnique) {
  ASTContext &AST = S.getASTContext();

  IdentifierInfo *NameBaseII = BaseDecl->getIdentifier();
  llvm::SmallString<64> Name("__cblayout_");
  if (NameBaseII) {
    Name.append(NameBaseII->getName());
  } else {
    // anonymous struct - there is no base name to derive uniqueness from,
    // so force the suffix search below
    Name.append("anon");
    MustBeUnique = true;
  }

  // Remember the suffix-free length so each retry can restart from it.
  size_t NameLength = Name.size();
  IdentifierInfo *II = &AST.Idents.get(Name, tok::TokenKind::identifier);
  if (!MustBeUnique)
    return II;

  unsigned suffix = 0;
  while (true) {
    if (suffix != 0) {
      Name.append("_");
      Name.append(llvm::Twine(suffix).str());
      II = &AST.Idents.get(Name, tok::TokenKind::identifier);
    }
    if (!findRecordDeclInContext(II, BaseDecl->getDeclContext()))
      return II;
    // declaration with that name already exists - increment suffix and try
    // again until unique name is found
    suffix++;
    Name.truncate(NameLength);
  };
}
443
// Creates a field declaration of given name and type for HLSL buffer layout
// struct. Returns nullptr if the type cannot be used in HLSL Buffer layout.
static FieldDecl *createFieldForHostLayoutStruct(Sema &S, const Type *Ty,
                                                 IdentifierInfo *II,
                                                 CXXRecordDecl *LayoutStruct) {
  // Reject types that can never appear in a buffer layout (resources, empty
  // structs, zero-sized arrays, intangibles).
  if (isInvalidConstantBufferLeafElementType(Ty))
    return nullptr;

  // A record that is not buffer-compatible as-is gets its own host layout
  // struct, and the field uses that struct's type instead.
  if (Ty->isRecordType()) {
    CXXRecordDecl *RD = Ty->getAsCXXRecordDecl();
    if (requiresImplicitBufferLayoutStructure(RD)) {
      RD = createHostLayoutStruct(S, RD);
      if (!RD)
        return nullptr;
      Ty = RD->getTypeForDecl();
    }
  }

  QualType QT = QualType(Ty, 0);
  ASTContext &AST = S.getASTContext();
  TypeSourceInfo *TSI = AST.getTrivialTypeSourceInfo(QT, SourceLocation());
  auto *Field = FieldDecl::Create(AST, LayoutStruct, SourceLocation(),
                                  SourceLocation(), II, QT, TSI, nullptr, false,
                                  InClassInitStyle::ICIS_NoInit);
  Field->setAccess(AccessSpecifier::AS_public);
  return Field;
}
471
// Creates host layout struct for a struct included in HLSL Buffer.
// The layout struct will include only fields that are allowed in HLSL buffer.
// These fields will be filtered out:
// - resource classes
// - empty structs
// - zero-sized arrays
// Returns nullptr if the resulting layout struct would be empty.
static CXXRecordDecl *createHostLayoutStruct(Sema &S,
                                             CXXRecordDecl *StructDecl) {
  assert(requiresImplicitBufferLayoutStructure(StructDecl) &&
         "struct is already HLSL buffer compatible");

  ASTContext &AST = S.getASTContext();
  DeclContext *DC = StructDecl->getDeclContext();
  IdentifierInfo *II = getHostLayoutStructName(S, StructDecl, false);

  // Reuse the existing layout struct if it was already created.
  if (CXXRecordDecl *RD = findRecordDeclInContext(II, DC))
    return RD;

  CXXRecordDecl *LS =
      CXXRecordDecl::Create(AST, TagDecl::TagKind::Struct, DC, SourceLocation(),
                            SourceLocation(), II);
  LS->setImplicit(true);
  LS->addAttr(PackedAttr::CreateImplicit(AST));
  LS->startDefinition();

  // copy base struct, create HLSL Buffer compatible version if needed
  if (unsigned NumBases = StructDecl->getNumBases()) {
    assert(NumBases == 1 && "HLSL supports only one base type");
    (void)NumBases;
    CXXBaseSpecifier Base = *StructDecl->bases_begin();
    CXXRecordDecl *BaseDecl = Base.getType()->getAsCXXRecordDecl();
    if (requiresImplicitBufferLayoutStructure(BaseDecl)) {
      // The base itself needs a layout struct; if one can be built, point the
      // base specifier at it. If not (layout would be empty), the base is
      // dropped entirely.
      BaseDecl = createHostLayoutStruct(S, BaseDecl);
      if (BaseDecl) {
        TypeSourceInfo *TSI = AST.getTrivialTypeSourceInfo(
            QualType(BaseDecl->getTypeForDecl(), 0));
        Base = CXXBaseSpecifier(SourceRange(), false, StructDecl->isClass(),
                                AS_none, TSI, SourceLocation());
      }
    }
    if (BaseDecl) {
      const CXXBaseSpecifier *BasesArray[1] = {&Base};
      LS->setBases(BasesArray, 1);
    }
  }

  // filter struct fields - only fields valid for buffer layout are copied
  for (const FieldDecl *FD : StructDecl->fields()) {
    const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
    if (FieldDecl *NewFD =
            createFieldForHostLayoutStruct(S, Ty, FD->getIdentifier(), LS))
      LS->addDecl(NewFD);
  }
  LS->completeDefinition();

  // Everything was filtered out - an empty layout struct is useless.
  if (LS->field_empty() && LS->getNumBases() == 0)
    return nullptr;

  DC->addDecl(LS);
  return LS;
}
535
// Creates host layout struct for HLSL Buffer. The struct will include only
// fields of types that are allowed in HLSL buffer and it will filter out:
// - static or groupshared variable declarations
// - resource classes
// - empty structs
// - zero-sized arrays
// - non-variable declarations
// The layout struct will be added to the HLSLBufferDecl declarations.
void createHostLayoutStructForBuffer(Sema &S, HLSLBufferDecl *BufDecl) {
  ASTContext &AST = S.getASTContext();
  IdentifierInfo *II = getHostLayoutStructName(S, BufDecl, true);

  CXXRecordDecl *LS =
      CXXRecordDecl::Create(AST, TagDecl::TagKind::Struct, BufDecl,
                            SourceLocation(), SourceLocation(), II);
  LS->addAttr(PackedAttr::CreateImplicit(AST));
  LS->setImplicit(true);
  LS->startDefinition();

  for (Decl *D : BufDecl->buffer_decls()) {
    // Skip non-variables, statics, and groupshared variables - they are not
    // part of the buffer's host-visible layout.
    VarDecl *VD = dyn_cast<VarDecl>(D);
    if (!VD || VD->getStorageClass() == SC_Static ||
        VD->getType().getAddressSpace() == LangAS::hlsl_groupshared)
      continue;
    const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
    if (FieldDecl *FD =
            createFieldForHostLayoutStruct(S, Ty, VD->getIdentifier(), LS)) {
      // add the field decl to the layout struct
      LS->addDecl(FD);
      // update address space of the original decl to hlsl_constant
      QualType NewTy =
          AST.getAddrSpaceQualType(VD->getType(), LangAS::hlsl_constant);
      VD->setType(NewTy);
    }
  }
  LS->completeDefinition();
  BufDecl->addLayoutStruct(LS);
}
574
575static void addImplicitBindingAttrToBuffer(Sema &S, HLSLBufferDecl *BufDecl,
576 uint32_t ImplicitBindingOrderID) {
577 RegisterType RT =
578 BufDecl->isCBuffer() ? RegisterType::CBuffer : RegisterType::SRV;
579 auto *Attr =
580 HLSLResourceBindingAttr::CreateImplicit(Ctx&: S.getASTContext(), Slot: "", Space: "0", Range: {});
581 std::optional<unsigned> RegSlot;
582 Attr->setBinding(RT, SlotNum: RegSlot, SpaceNum: 0);
583 Attr->setImplicitBindingOrderID(ImplicitBindingOrderID);
584 BufDecl->addAttr(A: Attr);
585}
586
// Handle end of cbuffer/tbuffer declaration
void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) {
  auto *BufDecl = cast<HLSLBufferDecl>(Dcl);
  BufDecl->setRBraceLoc(RBrace);

  // Check the all-or-none and non-overlap rules for packoffset annotations.
  validatePackoffset(SemaRef, BufDecl);

  // create buffer layout struct
  createHostLayoutStructForBuffer(SemaRef, BufDecl);

  // Buffers without an explicit register slot get an implicit binding: warn,
  // then record the order in which implicit bindings were encountered.
  HLSLResourceBindingAttr *RBA = Dcl->getAttr<HLSLResourceBindingAttr>();
  if (!RBA || !RBA->hasRegisterSlot()) {
    SemaRef.Diag(Dcl->getLocation(), diag::warn_hlsl_implicit_binding);
    // Use HLSLResourceBindingAttr to transfer implicit binding order_ID
    // to codegen. If it does not exist, create an implicit attribute.
    uint32_t OrderID = getNextImplicitBindingOrderID();
    if (RBA)
      RBA->setImplicitBindingOrderID(OrderID);
    else
      addImplicitBindingAttrToBuffer(SemaRef, BufDecl, OrderID);
  }

  SemaRef.PopDeclContext();
}
611
612HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
613 const AttributeCommonInfo &AL,
614 int X, int Y, int Z) {
615 if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
616 if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
617 Diag(Loc: NT->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL;
618 Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute);
619 }
620 return nullptr;
621 }
622 return ::new (getASTContext())
623 HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z);
624}
625
626HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D,
627 const AttributeCommonInfo &AL,
628 int Min, int Max, int Preferred,
629 int SpelledArgsCount) {
630 if (HLSLWaveSizeAttr *WS = D->getAttr<HLSLWaveSizeAttr>()) {
631 if (WS->getMin() != Min || WS->getMax() != Max ||
632 WS->getPreferred() != Preferred ||
633 WS->getSpelledArgsCount() != SpelledArgsCount) {
634 Diag(Loc: WS->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL;
635 Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute);
636 }
637 return nullptr;
638 }
639 HLSLWaveSizeAttr *Result = ::new (getASTContext())
640 HLSLWaveSizeAttr(getASTContext(), AL, Min, Max, Preferred);
641 Result->setSpelledArgsCount(SpelledArgsCount);
642 return Result;
643}
644
// Merges a vk::constant_id(Id) attribute onto variable declaration D.
// The attribute is only meaningful for the SPIR-V target; the variable must
// be a const-qualified scalar/enum type supported as a specialization
// constant. A duplicate with a different Id is diagnosed. Returns the new
// attribute, or nullptr when the attribute is ignored, invalid, or duplicate.
HLSLVkConstantIdAttr *
SemaHLSL::mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL,
                                int Id) {

  // Silently ignored (with a warning) on non-SPIR-V targets.
  auto &TargetInfo = getASTContext().getTargetInfo();
  if (TargetInfo.getTriple().getArch() != llvm::Triple::spirv) {
    Diag(AL.getLoc(), diag::warn_attribute_ignored) << AL;
    return nullptr;
  }

  auto *VD = cast<VarDecl>(D);

  // The variable's type must map to a spec-constant builtin.
  if (getSpecConstBuiltinId(VD->getType()->getUnqualifiedDesugaredType()) ==
      Builtin::NotBuiltin) {
    Diag(VD->getLocation(), diag::err_specialization_const);
    return nullptr;
  }

  // Specialization constants must be const-qualified.
  if (!VD->getType().isConstQualified()) {
    Diag(VD->getLocation(), diag::err_specialization_const);
    return nullptr;
  }

  // Duplicate attribute: must carry the same Id; never re-added.
  if (HLSLVkConstantIdAttr *CI = D->getAttr<HLSLVkConstantIdAttr>()) {
    if (CI->getId() != Id) {
      Diag(CI->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
      Diag(AL.getLoc(), diag::note_conflicting_attribute);
    }
    return nullptr;
  }

  HLSLVkConstantIdAttr *Result =
      ::new (getASTContext()) HLSLVkConstantIdAttr(getASTContext(), AL, Id);
  return Result;
}
680
681HLSLShaderAttr *
682SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
683 llvm::Triple::EnvironmentType ShaderType) {
684 if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
685 if (NT->getType() != ShaderType) {
686 Diag(Loc: NT->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL;
687 Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute);
688 }
689 return nullptr;
690 }
691 return HLSLShaderAttr::Create(Ctx&: getASTContext(), Type: ShaderType, CommonInfo: AL);
692}
693
// Merges an HLSL parameter modifier (in/out/inout) with any modifier already
// present on the declaration. Returns the merged or newly created attribute,
// or nullptr when the combination is ill-formed.
HLSLParamModifierAttr *
SemaHLSL::mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
                                 HLSLParamModifierAttr::Spelling Spelling) {
  // We can only merge an `in` attribute with an `out` attribute. All other
  // combinations of duplicated attributes are ill-formed.
  if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
    if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
        (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
      // `in` + `out` combine into a single `inout` attribute spanning from
      // the first modifier to the end of the second.
      D->dropAttr<HLSLParamModifierAttr>();
      SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
      return HLSLParamModifierAttr::Create(
          getASTContext(), /*MergedSpelling=*/true, AdjustedRange,
          HLSLParamModifierAttr::Keyword_inout);
    }
    Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL;
    Diag(PA->getLocation(), diag::note_conflicting_attribute);
    return nullptr;
  }
  return HLSLParamModifierAttr::Create(getASTContext(), AL);
}
714
// Called for every top-level function. If FD is the configured HLSL entry
// point, verifies that any explicit shader attribute matches the target
// triple's environment, or implicitly adds the matching shader attribute
// when none is present. Non-entry functions are left untouched.
void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
  auto &TargetInfo = getASTContext().getTargetInfo();

  // Only the function named as the HLSL entry point is processed here.
  if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
    return;

  llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment();
  if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) {
    if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
      // The entry point is already annotated - check that it matches the
      // triple.
      if (Shader->getType() != Env) {
        Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
            << Shader;
        FD->setInvalidDecl();
      }
    } else {
      // Implicitly add the shader attribute if the entry function isn't
      // explicitly annotated.
      FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), Env,
                                                 FD->getBeginLoc()));
    }
  } else {
    // Library and unknown environments have no single stage to enforce.
    switch (Env) {
    case llvm::Triple::UnknownEnvironment:
    case llvm::Triple::Library:
      break;
    default:
      llvm_unreachable("Unhandled environment in triple");
    }
  }
}
747
// Validates an entry-point function against its shader stage: checks that
// stage-restricted attributes (numthreads, wavesize) only appear on stages
// that allow them, that required attributes are present, that the shader
// model supports the attributes used, and that every parameter carries a
// semantic annotation. Marks FD invalid on any failure.
void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
  const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
  assert(ShaderAttr && "Entry point has no shader attribute");
  llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
  auto &TargetInfo = getASTContext().getTargetInfo();
  VersionTuple Ver = TargetInfo.getTriple().getOSVersion();
  switch (ST) {
  // Stages below do not accept numthreads/wavesize; diagnose their presence.
  case llvm::Triple::Pixel:
  case llvm::Triple::Vertex:
  case llvm::Triple::Geometry:
  case llvm::Triple::Hull:
  case llvm::Triple::Domain:
  case llvm::Triple::RayGeneration:
  case llvm::Triple::Intersection:
  case llvm::Triple::AnyHit:
  case llvm::Triple::ClosestHit:
  case llvm::Triple::Miss:
  case llvm::Triple::Callable:
    if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
      DiagnoseAttrStageMismatch(NT, ST,
                                {llvm::Triple::Compute,
                                 llvm::Triple::Amplification,
                                 llvm::Triple::Mesh});
      FD->setInvalidDecl();
    }
    if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
      DiagnoseAttrStageMismatch(WS, ST,
                                {llvm::Triple::Compute,
                                 llvm::Triple::Amplification,
                                 llvm::Triple::Mesh});
      FD->setInvalidDecl();
    }
    break;

  // Compute-like stages require numthreads; wavesize needs shader model 6.6+
  // (6.8+ when more than one argument is spelled).
  case llvm::Triple::Compute:
  case llvm::Triple::Amplification:
  case llvm::Triple::Mesh:
    if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
      Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
          << llvm::Triple::getEnvironmentTypeName(ST);
      FD->setInvalidDecl();
    }
    if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
      if (Ver < VersionTuple(6, 6)) {
        Diag(WS->getLocation(), diag::err_hlsl_attribute_in_wrong_shader_model)
            << WS << "6.6";
        FD->setInvalidDecl();
      } else if (WS->getSpelledArgsCount() > 1 && Ver < VersionTuple(6, 8)) {
        Diag(
            WS->getLocation(),
            diag::err_hlsl_attribute_number_arguments_insufficient_shader_model)
            << WS << WS->getSpelledArgsCount() << "6.8";
        FD->setInvalidDecl();
      }
    }
    break;
  default:
    llvm_unreachable("Unhandled environment in triple");
  }

  // Every entry-point parameter must carry a semantic annotation.
  for (ParmVarDecl *Param : FD->parameters()) {
    if (const auto *AnnotationAttr = Param->getAttr<HLSLAnnotationAttr>()) {
      CheckSemanticAnnotation(FD, Param, AnnotationAttr);
    } else {
      // FIXME: Handle struct parameters where annotations are on struct fields.
      // See: https://github.com/llvm/llvm-project/issues/57875
      Diag(FD->getLocation(), diag::err_hlsl_missing_semantic_annotation);
      Diag(Param->getLocation(), diag::note_previous_decl) << Param;
      FD->setInvalidDecl();
    }
  }
  // FIXME: Verify return type semantic annotation.
}
821
// Checks that a semantic annotation on an entry-point parameter is allowed
// for the entry point's shader stage, diagnosing a stage mismatch otherwise.
void SemaHLSL::CheckSemanticAnnotation(
    FunctionDecl *EntryPoint, const Decl *Param,
    const HLSLAnnotationAttr *AnnotationAttr) {
  auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
  assert(ShaderAttr && "Entry point has no shader attribute");
  llvm::Triple::EnvironmentType ST = ShaderAttr->getType();

  switch (AnnotationAttr->getKind()) {
  // Thread/group ID semantics are compute-only.
  case attr::HLSLSV_DispatchThreadID:
  case attr::HLSLSV_GroupIndex:
  case attr::HLSLSV_GroupThreadID:
  case attr::HLSLSV_GroupID:
    if (ST == llvm::Triple::Compute)
      return;
    DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Compute});
    break;
  case attr::HLSLSV_Position:
    // TODO(#143523): allow use on other shader types & output once the overall
    // semantic logic is implemented.
    if (ST == llvm::Triple::Pixel)
      return;
    DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Pixel});
    break;
  default:
    llvm_unreachable("Unknown HLSLAnnotationAttr");
  }
}
849
850void SemaHLSL::DiagnoseAttrStageMismatch(
851 const Attr *A, llvm::Triple::EnvironmentType Stage,
852 std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) {
853 SmallVector<StringRef, 8> StageStrings;
854 llvm::transform(Range&: AllowedStages, d_first: std::back_inserter(x&: StageStrings),
855 F: [](llvm::Triple::EnvironmentType ST) {
856 return StringRef(
857 HLSLShaderAttr::ConvertEnvironmentTypeToStr(Val: ST));
858 });
859 Diag(Loc: A->getLoc(), DiagID: diag::err_hlsl_attr_unsupported_in_stage)
860 << A->getAttrName() << llvm::Triple::getEnvironmentTypeName(Kind: Stage)
861 << (AllowedStages.size() != 1) << join(R&: StageStrings, Separator: ", ");
862}
863
864template <CastKind Kind>
865static void castVector(Sema &S, ExprResult &E, QualType &Ty, unsigned Sz) {
866 if (const auto *VTy = Ty->getAs<VectorType>())
867 Ty = VTy->getElementType();
868 Ty = S.getASTContext().getExtVectorType(VectorType: Ty, NumElts: Sz);
869 E = S.ImpCastExprToType(E: E.get(), Type: Ty, CK: Kind);
870}
871
872template <CastKind Kind>
873static QualType castElement(Sema &S, ExprResult &E, QualType Ty) {
874 E = S.ImpCastExprToType(E: E.get(), Type: Ty, CK: Kind);
875 return Ty;
876}
877
// Performs element-type conversion for a binary operation on two HLSL vectors
// of equal size when at least one element type is floating point. One operand
// receives an implicit element cast so both sides share a type; the common
// vector type is returned.
//
// \param LHS/RHS - operand expressions; exactly one is rewritten.
// \param LHSType/RHSType - the operands' (size-matched) vector types.
// \param LElTy/RElTy - the operands' element types.
// \param IsCompAssign - true for compound assignment; then the LHS type is
//        fixed and only the RHS may be converted.
static QualType handleFloatVectorBinOpConversion(
    Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
    QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
  bool LHSFloat = LElTy->isRealFloatingType();
  bool RHSFloat = RElTy->isRealFloatingType();

  // Both floating point: convert toward the higher-ranked float type
  // (always toward the LHS for compound assignment).
  if (LHSFloat && RHSFloat) {
    if (IsCompAssign ||
        SemaRef.getASTContext().getFloatingTypeOrder(LHS: LElTy, RHS: RElTy) > 0)
      return castElement<CK_FloatingCast>(S&: SemaRef, E&: RHS, Ty: LHSType);

    return castElement<CK_FloatingCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
  }

  // Mixed int/float: promote the integer side to the floating-point type.
  if (LHSFloat)
    return castElement<CK_IntegralToFloating>(S&: SemaRef, E&: RHS, Ty: LHSType);

  assert(RHSFloat);
  // Compound assignment with an integer LHS: the RHS (the only operand we may
  // change) is demoted to the LHS's integer element type instead.
  if (IsCompAssign)
    return castElement<clang::CK_FloatingToIntegral>(S&: SemaRef, E&: RHS, Ty: LHSType);

  return castElement<CK_IntegralToFloating>(S&: SemaRef, E&: LHS, Ty: RHSType);
}
901
// Performs element-type conversion for a binary operation on two HLSL integer
// vectors of equal size, following C-like usual arithmetic conversions on the
// element types (rank and signedness). One operand (or, in the final fallback
// case, both) receives an implicit integral cast; the common vector type is
// returned. For compound assignment only the RHS may be converted.
static QualType handleIntegerVectorBinOpConversion(
    Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
    QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {

  // IntOrder: >0 if LElTy has higher rank, <0 if RElTy does, 0 if equal.
  int IntOrder = SemaRef.Context.getIntegerTypeOrder(LHS: LElTy, RHS: RElTy);
  bool LHSSigned = LElTy->hasSignedIntegerRepresentation();
  bool RHSSigned = RElTy->hasSignedIntegerRepresentation();
  auto &Ctx = SemaRef.getASTContext();

  // If both types have the same signedness, use the higher ranked type.
  if (LHSSigned == RHSSigned) {
    if (IsCompAssign || IntOrder >= 0)
      return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);

    return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
  }

  // If the unsigned type has greater than or equal rank of the signed type, use
  // the unsigned type.
  if (IntOrder != (LHSSigned ? 1 : -1)) {
    if (IsCompAssign || RHSSigned)
      return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);
    return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
  }

  // At this point the signed type has higher rank than the unsigned type, which
  // means it will be the same size or bigger. If the signed type is bigger, it
  // can represent all the values of the unsigned type, so select it.
  if (Ctx.getIntWidth(T: LElTy) != Ctx.getIntWidth(T: RElTy)) {
    if (IsCompAssign || LHSSigned)
      return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);
    return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
  }

  // This is a bit of an odd duck case in HLSL. It shouldn't happen, but can due
  // to C/C++ leaking through. The place this happens today is long vs long
  // long. When arguments are vector<unsigned long, N> and vector<long long, N>,
  // the long long has higher rank than long even though they are the same size.

  // If this is a compound assignment cast the right hand side to the left hand
  // side's type.
  if (IsCompAssign)
    return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);

  // If this isn't a compound assignment we convert to unsigned long long.
  QualType ElTy = Ctx.getCorrespondingUnsignedType(T: LHSSigned ? LElTy : RElTy);
  QualType NewTy = Ctx.getExtVectorType(
      VectorType: ElTy, NumElts: RHSType->castAs<VectorType>()->getNumElements());
  (void)castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: NewTy);

  return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: NewTy);
}
954
955static CastKind getScalarCastKind(ASTContext &Ctx, QualType DestTy,
956 QualType SrcTy) {
957 if (DestTy->isRealFloatingType() && SrcTy->isRealFloatingType())
958 return CK_FloatingCast;
959 if (DestTy->isIntegralType(Ctx) && SrcTy->isIntegralType(Ctx))
960 return CK_IntegralCast;
961 if (DestTy->isRealFloatingType())
962 return CK_IntegralToFloating;
963 assert(SrcTy->isRealFloatingType() && DestTy->isIntegralType(Ctx));
964 return CK_FloatingToIntegral;
965}
966
// Performs HLSL's binary-operator conversions when at least one operand is a
// vector: truncates/splats the operands to a common element count, then
// converts element types via the float/integer helpers above. Returns the
// common result type, or a null QualType after diagnosing an invalid
// compound-assignment truncation.
QualType SemaHLSL::handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS,
                                               QualType LHSType,
                                               QualType RHSType,
                                               bool IsCompAssign) {
  const auto *LVecTy = LHSType->getAs<VectorType>();
  const auto *RVecTy = RHSType->getAs<VectorType>();
  auto &Ctx = getASTContext();

  // If the LHS is not a vector and this is a compound assignment, we truncate
  // the argument to a scalar then convert it to the LHS's type.
  if (!LVecTy && IsCompAssign) {
    QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
    RHS = SemaRef.ImpCastExprToType(E: RHS.get(), Type: RElTy, CK: CK_HLSLVectorTruncation);
    RHSType = RHS.get()->getType();
    if (Ctx.hasSameUnqualifiedType(T1: LHSType, T2: RHSType))
      return LHSType;
    RHS = SemaRef.ImpCastExprToType(E: RHS.get(), Type: LHSType,
                                    CK: getScalarCastKind(Ctx, DestTy: LHSType, SrcTy: RHSType));
    return LHSType;
  }

  // EndSz is the common element count: the minimum of the operands' sizes.
  unsigned EndSz = std::numeric_limits<unsigned>::max();
  unsigned LSz = 0;
  if (LVecTy)
    LSz = EndSz = LVecTy->getNumElements();
  if (RVecTy)
    EndSz = std::min(a: RVecTy->getNumElements(), b: EndSz);
  assert(EndSz != std::numeric_limits<unsigned>::max() &&
         "one of the above should have had a value");

  // In a compound assignment, the left operand does not change type, the right
  // operand is converted to the type of the left operand.
  if (IsCompAssign && LSz != EndSz) {
    Diag(Loc: LHS.get()->getBeginLoc(),
         DiagID: diag::err_hlsl_vector_compound_assignment_truncation)
        << LHSType << RHSType;
    return QualType();
  }

  // Truncate the larger vector, then splat any scalar, so both operands end
  // up as EndSz-element vectors.
  if (RVecTy && RVecTy->getNumElements() > EndSz)
    castVector<CK_HLSLVectorTruncation>(S&: SemaRef, E&: RHS, Ty&: RHSType, Sz: EndSz);
  if (!IsCompAssign && LVecTy && LVecTy->getNumElements() > EndSz)
    castVector<CK_HLSLVectorTruncation>(S&: SemaRef, E&: LHS, Ty&: LHSType, Sz: EndSz);

  if (!RVecTy)
    castVector<CK_VectorSplat>(S&: SemaRef, E&: RHS, Ty&: RHSType, Sz: EndSz);
  if (!IsCompAssign && !LVecTy)
    castVector<CK_VectorSplat>(S&: SemaRef, E&: LHS, Ty&: LHSType, Sz: EndSz);

  // If we're at the same type after resizing we can stop here.
  if (Ctx.hasSameUnqualifiedType(T1: LHSType, T2: RHSType))
    return Ctx.getCommonSugaredType(X: LHSType, Y: RHSType);

  QualType LElTy = LHSType->castAs<VectorType>()->getElementType();
  QualType RElTy = RHSType->castAs<VectorType>()->getElementType();

  // Handle conversion for floating point vectors.
  if (LElTy->isRealFloatingType() || RElTy->isRealFloatingType())
    return handleFloatVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
                                            LElTy, RElTy, IsCompAssign);

  assert(LElTy->isIntegralType(Ctx) && RElTy->isIntegralType(Ctx) &&
         "HLSL Vectors can only contain integer or floating point types");
  return handleIntegerVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
                                            LElTy, RElTy, IsCompAssign);
}
1033
// Emits a note with a fix-it hint replacing the logical expression
// `LHS || RHS` (or `LHS && RHS`) with a call to the corresponding `or`/`and`
// builtin function, pretty-printing both operands as the call arguments.
void SemaHLSL::emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS,
                                        BinaryOperatorKind Opc) {
  assert((Opc == BO_LOr || Opc == BO_LAnd) &&
         "Called with non-logical operator");
  // Build the replacement text "or(<lhs>, <rhs>)" / "and(<lhs>, <rhs>)".
  llvm::SmallVector<char, 256> Buff;
  llvm::raw_svector_ostream OS(Buff);
  PrintingPolicy PP(SemaRef.getLangOpts());
  StringRef NewFnName = Opc == BO_LOr ? "or" : "and";
  OS << NewFnName << "(";
  LHS->printPretty(OS, Helper: nullptr, Policy: PP);
  OS << ", ";
  RHS->printPretty(OS, Helper: nullptr, Policy: PP);
  OS << ")";
  // Replace the whole binary expression, from the start of the LHS to the
  // end of the RHS.
  SourceRange FullRange = SourceRange(LHS->getBeginLoc(), RHS->getEndLoc());
  SemaRef.Diag(Loc: LHS->getBeginLoc(), DiagID: diag::note_function_suggestion)
      << NewFnName << FixItHint::CreateReplacement(RemoveRange: FullRange, Code: OS.str());
}
1051
// Derives the identifier for an implicit root-signature declaration from the
// signature string (a fixed prefix plus a hash of the string, so identical
// signature text maps to the same identifier) and checks whether a
// declaration with that name already exists in the current context.
//
// \returns the identifier and a flag that is true if a prior decl was found.
std::pair<IdentifierInfo *, bool>
SemaHLSL::ActOnStartRootSignatureDecl(StringRef Signature) {
  llvm::hash_code Hash = llvm::hash_value(S: Signature);
  std::string IdStr = "__hlsl_rootsig_decl_" + std::to_string(val: Hash);
  IdentifierInfo *DeclIdent = &(getASTContext().Idents.get(Name: IdStr));

  // Check if we have already found a decl of the same name.
  LookupResult R(SemaRef, DeclIdent, SourceLocation(),
                 Sema::LookupOrdinaryName);
  bool Found = SemaRef.LookupQualifiedName(R, LookupCtx: SemaRef.CurContext);
  return {DeclIdent, Found};
}
1064
// Creates an implicit HLSLRootSignatureDecl holding the parsed root elements,
// runs resource-range overlap analysis on it, and — if no error was reported —
// pushes it onto the current scope so later attributes can look it up by name.
void SemaHLSL::ActOnFinishRootSignatureDecl(
    SourceLocation Loc, IdentifierInfo *DeclIdent,
    SmallVector<llvm::hlsl::rootsig::RootElement> &Elements) {

  auto *SignatureDecl = HLSLRootSignatureDecl::Create(
      C&: SemaRef.getASTContext(), /*DeclContext=*/DC: SemaRef.CurContext, Loc,
      ID: DeclIdent, Version: SemaRef.getLangOpts().HLSLRootSigVer, RootElements: Elements);

  // A true return means an overlap was diagnosed; drop the decl in that case.
  if (handleRootSignatureDecl(D: SignatureDecl, Loc))
    return;

  SignatureDecl->setImplicit();
  SemaRef.PushOnScopeChains(D: SignatureDecl, S: SemaRef.getCurScope());
}
1079
1080bool SemaHLSL::handleRootSignatureDecl(HLSLRootSignatureDecl *D,
1081 SourceLocation Loc) {
1082 // The following conducts analysis on resource ranges to detect and report
1083 // any overlaps in resource ranges.
1084 //
1085 // A resource range overlaps with another resource range if they have:
1086 // - equivalent ResourceClass (SRV, UAV, CBuffer, Sampler)
1087 // - equivalent resource space
1088 // - overlapping visbility
1089 //
1090 // The following algorithm is implemented in the following steps:
1091 //
1092 // 1. Collect RangeInfo from relevant RootElements:
1093 // - RangeInfo will retain the interval, ResourceClass, Space and Visibility
1094 // 2. Sort the RangeInfo's such that they are grouped together by
1095 // ResourceClass and Space (GroupT defined below)
1096 // 3. Iterate through the collected RangeInfos by their groups
1097 // - For each group we will have a ResourceRange for each visibility
1098 // - As we iterate through we will:
1099 // A: Insert the current RangeInfo into the corresponding Visibility
1100 // ResourceRange
1101 // B: Check for overlap with any overlapping Visibility ResourceRange
1102 using RangeInfo = llvm::hlsl::rootsig::RangeInfo;
1103 using ResourceRange = llvm::hlsl::rootsig::ResourceRange;
1104 using GroupT = std::pair<ResourceClass, /*Space*/ uint32_t>;
1105
1106 // 1. Collect RangeInfos
1107 llvm::SmallVector<RangeInfo> Infos;
1108 for (const llvm::hlsl::rootsig::RootElement &Elem : D->getRootElements()) {
1109 if (const auto *Descriptor =
1110 std::get_if<llvm::hlsl::rootsig::RootDescriptor>(ptr: &Elem)) {
1111 RangeInfo Info;
1112 Info.LowerBound = Descriptor->Reg.Number;
1113 Info.UpperBound = Info.LowerBound; // use inclusive ranges []
1114
1115 Info.Class =
1116 llvm::dxil::ResourceClass(llvm::to_underlying(E: Descriptor->Type));
1117 Info.Space = Descriptor->Space;
1118 Info.Visibility = Descriptor->Visibility;
1119 Infos.push_back(Elt: Info);
1120 } else if (const auto *Constants =
1121 std::get_if<llvm::hlsl::rootsig::RootConstants>(ptr: &Elem)) {
1122 RangeInfo Info;
1123 Info.LowerBound = Constants->Reg.Number;
1124 Info.UpperBound = Info.LowerBound; // use inclusive ranges []
1125
1126 Info.Class = llvm::dxil::ResourceClass::CBuffer;
1127 Info.Space = Constants->Space;
1128 Info.Visibility = Constants->Visibility;
1129 Infos.push_back(Elt: Info);
1130 } else if (const auto *Sampler =
1131 std::get_if<llvm::hlsl::rootsig::StaticSampler>(ptr: &Elem)) {
1132 RangeInfo Info;
1133 Info.LowerBound = Sampler->Reg.Number;
1134 Info.UpperBound = Info.LowerBound; // use inclusive ranges []
1135
1136 Info.Class = llvm::dxil::ResourceClass::Sampler;
1137 Info.Space = Sampler->Space;
1138 Info.Visibility = Sampler->Visibility;
1139 Infos.push_back(Elt: Info);
1140 } else if (const auto *Clause =
1141 std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
1142 ptr: &Elem)) {
1143 RangeInfo Info;
1144 Info.LowerBound = Clause->Reg.Number;
1145 assert(0 < Clause->NumDescriptors && "Verified as part of TODO(#129940)");
1146 Info.UpperBound = Clause->NumDescriptors == RangeInfo::Unbounded
1147 ? RangeInfo::Unbounded
1148 : Info.LowerBound + Clause->NumDescriptors -
1149 1; // use inclusive ranges []
1150
1151 Info.Class = Clause->Type;
1152 Info.Space = Clause->Space;
1153 // Note: Clause does not hold the visibility this will need to
1154 Infos.push_back(Elt: Info);
1155 } else if (const auto *Table =
1156 std::get_if<llvm::hlsl::rootsig::DescriptorTable>(ptr: &Elem)) {
1157 // Table holds the Visibility of all owned Clauses in Table, so iterate
1158 // owned Clauses and update their corresponding RangeInfo
1159 assert(Table->NumClauses <= Infos.size() && "RootElement");
1160 // The last Table->NumClauses elements of Infos are the owned Clauses
1161 // generated RangeInfo
1162 auto TableInfos =
1163 MutableArrayRef<RangeInfo>(Infos).take_back(N: Table->NumClauses);
1164 for (RangeInfo &Info : TableInfos)
1165 Info.Visibility = Table->Visibility;
1166 }
1167 }
1168
1169 // 2. Sort the RangeInfo's by their GroupT to form groupings
1170 std::sort(first: Infos.begin(), last: Infos.end(), comp: [](RangeInfo A, RangeInfo B) {
1171 return std::tie(args&: A.Class, args&: A.Space) < std::tie(args&: B.Class, args&: B.Space);
1172 });
1173
1174 // 3. First we will init our state to track:
1175 if (Infos.size() == 0)
1176 return false; // No ranges to overlap
1177 GroupT CurGroup = {Infos[0].Class, Infos[0].Space};
1178 bool HadOverlap = false;
1179
1180 // Create a ResourceRange for each Visibility
1181 ResourceRange::MapT::Allocator Allocator;
1182 std::array<ResourceRange, 8> Ranges = {
1183 ResourceRange(Allocator), // All
1184 ResourceRange(Allocator), // Vertex
1185 ResourceRange(Allocator), // Hull
1186 ResourceRange(Allocator), // Domain
1187 ResourceRange(Allocator), // Geometry
1188 ResourceRange(Allocator), // Pixel
1189 ResourceRange(Allocator), // Amplification
1190 ResourceRange(Allocator), // Mesh
1191 };
1192
1193 // Reset the ResourceRanges for when we iterate through a new group
1194 auto ClearRanges = [&Ranges]() {
1195 for (ResourceRange &Range : Ranges)
1196 Range.clear();
1197 };
1198
1199 // Helper to report diagnostics
1200 auto ReportOverlap = [this, Loc, &HadOverlap](const RangeInfo *Info,
1201 const RangeInfo *OInfo) {
1202 HadOverlap = true;
1203 auto CommonVis = Info->Visibility == llvm::dxbc::ShaderVisibility::All
1204 ? OInfo->Visibility
1205 : Info->Visibility;
1206 this->Diag(Loc, DiagID: diag::err_hlsl_resource_range_overlap)
1207 << llvm::to_underlying(E: Info->Class) << Info->LowerBound
1208 << /*unbounded=*/(Info->UpperBound == RangeInfo::Unbounded)
1209 << Info->UpperBound << llvm::to_underlying(E: OInfo->Class)
1210 << OInfo->LowerBound
1211 << /*unbounded=*/(OInfo->UpperBound == RangeInfo::Unbounded)
1212 << OInfo->UpperBound << Info->Space << CommonVis;
1213 };
1214
1215 // 3: Iterate through collected RangeInfos
1216 for (const RangeInfo &Info : Infos) {
1217 GroupT InfoGroup = {Info.Class, Info.Space};
1218 // Reset our ResourceRanges when we enter a new group
1219 if (CurGroup != InfoGroup) {
1220 ClearRanges();
1221 CurGroup = InfoGroup;
1222 }
1223
1224 // 3A: Insert range info into corresponding Visibility ResourceRange
1225 ResourceRange &VisRange = Ranges[llvm::to_underlying(E: Info.Visibility)];
1226 if (std::optional<const RangeInfo *> Overlapping = VisRange.insert(Info))
1227 ReportOverlap(&Info, Overlapping.value());
1228
1229 // 3B: Check for overlap in all overlapping Visibility ResourceRanges
1230 //
1231 // If the range that we are inserting has ShaderVisiblity::All it needs to
1232 // check for an overlap in all other visibility types as well.
1233 // Otherwise, the range that is inserted needs to check that it does not
1234 // overlap with ShaderVisibility::All.
1235 //
1236 // OverlapRanges will be an ArrayRef to all non-all visibility
1237 // ResourceRanges in the former case and it will be an ArrayRef to just the
1238 // all visiblity ResourceRange in the latter case.
1239 ArrayRef<ResourceRange> OverlapRanges =
1240 Info.Visibility == llvm::dxbc::ShaderVisibility::All
1241 ? ArrayRef<ResourceRange>{Ranges}.drop_front()
1242 : ArrayRef<ResourceRange>{Ranges}.take_front();
1243
1244 for (const ResourceRange &Range : OverlapRanges)
1245 if (std::optional<const RangeInfo *> Overlapping =
1246 Range.getOverlapping(Info))
1247 ReportOverlap(&Info, Overlapping.value());
1248 }
1249
1250 return HadOverlap;
1251}
1252
// Handles the [RootSignature(...)] attribute: resolves the identifier of a
// previously created implicit root-signature decl and attaches a
// RootSignatureAttr referencing it. Duplicate attributes naming a different
// signature are an error; exact duplicates only warn.
void SemaHLSL::handleRootSignatureAttr(Decl *D, const ParsedAttr &AL) {
  if (AL.getNumArgs() != 1) {
    Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_wrong_number_arguments) << AL << 1;
    return;
  }

  IdentifierInfo *Ident = AL.getArgAsIdent(Arg: 0)->getIdentifierInfo();
  if (auto *RS = D->getAttr<RootSignatureAttr>()) {
    // A second root-signature attribute must name the same signature decl.
    if (RS->getSignatureIdent() != Ident) {
      Diag(Loc: AL.getLoc(), DiagID: diag::err_disallowed_duplicate_attribute) << RS;
      return;
    }

    Diag(Loc: AL.getLoc(), DiagID: diag::warn_duplicate_attribute_exact) << RS;
    return;
  }

  // Look up the implicit HLSLRootSignatureDecl created for the signature
  // string; silently ignore the attribute if no such decl is found.
  LookupResult R(SemaRef, Ident, SourceLocation(), Sema::LookupOrdinaryName);
  if (SemaRef.LookupQualifiedName(R, LookupCtx: D->getDeclContext()))
    if (auto *SignatureDecl =
            dyn_cast<HLSLRootSignatureDecl>(Val: R.getFoundDecl())) {
      D->addAttr(A: ::new (getASTContext()) RootSignatureAttr(
          getASTContext(), AL, Ident, SignatureDecl));
    }
}
1278
1279void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) {
1280 llvm::VersionTuple SMVersion =
1281 getASTContext().getTargetInfo().getTriple().getOSVersion();
1282 bool IsDXIL = getASTContext().getTargetInfo().getTriple().getArch() ==
1283 llvm::Triple::dxil;
1284
1285 uint32_t ZMax = 1024;
1286 uint32_t ThreadMax = 1024;
1287 if (IsDXIL && SMVersion.getMajor() <= 4) {
1288 ZMax = 1;
1289 ThreadMax = 768;
1290 } else if (IsDXIL && SMVersion.getMajor() == 5) {
1291 ZMax = 64;
1292 ThreadMax = 1024;
1293 }
1294
1295 uint32_t X;
1296 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: X))
1297 return;
1298 if (X > 1024) {
1299 Diag(Loc: AL.getArgAsExpr(Arg: 0)->getExprLoc(),
1300 DiagID: diag::err_hlsl_numthreads_argument_oor)
1301 << 0 << 1024;
1302 return;
1303 }
1304 uint32_t Y;
1305 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Y))
1306 return;
1307 if (Y > 1024) {
1308 Diag(Loc: AL.getArgAsExpr(Arg: 1)->getExprLoc(),
1309 DiagID: diag::err_hlsl_numthreads_argument_oor)
1310 << 1 << 1024;
1311 return;
1312 }
1313 uint32_t Z;
1314 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 2), Val&: Z))
1315 return;
1316 if (Z > ZMax) {
1317 SemaRef.Diag(Loc: AL.getArgAsExpr(Arg: 2)->getExprLoc(),
1318 DiagID: diag::err_hlsl_numthreads_argument_oor)
1319 << 2 << ZMax;
1320 return;
1321 }
1322
1323 if (X * Y * Z > ThreadMax) {
1324 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_numthreads_invalid) << ThreadMax;
1325 return;
1326 }
1327
1328 HLSLNumThreadsAttr *NewAttr = mergeNumThreadsAttr(D, AL, X, Y, Z);
1329 if (NewAttr)
1330 D->addAttr(A: NewAttr);
1331}
1332
1333static bool isValidWaveSizeValue(unsigned Value) {
1334 return llvm::isPowerOf2_32(Value) && Value >= 4 && Value <= 128;
1335}
1336
// Handles the [WaveSize(min[, max[, preferred]])] attribute: each spelled
// value must be a power of two in [4, 128], max must exceed min, and
// preferred must lie within [min, max]. On success attaches (or merges) an
// HLSLWaveSizeAttr carrying the values and the spelled argument count.
void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) {
  // validate that the wavesize argument is a power of 2 between 4 and 128
  // inclusive
  unsigned SpelledArgsCount = AL.getNumArgs();
  if (SpelledArgsCount == 0 || SpelledArgsCount > 3)
    return;

  uint32_t Min;
  if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: Min))
    return;

  uint32_t Max = 0;
  if (SpelledArgsCount > 1 &&
      !SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Max))
    return;

  uint32_t Preferred = 0;
  if (SpelledArgsCount > 2 &&
      !SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 2), Val&: Preferred))
    return;

  // Three arguments: validate the preferred size, both on its own and
  // against the [Min, Max] interval.
  if (SpelledArgsCount > 2) {
    if (!isValidWaveSizeValue(Value: Preferred)) {
      Diag(Loc: AL.getArgAsExpr(Arg: 2)->getExprLoc(),
           DiagID: diag::err_attribute_power_of_two_in_range)
          << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize
          << Preferred;
      return;
    }
    // Preferred not in range.
    if (Preferred < Min || Preferred > Max) {
      Diag(Loc: AL.getArgAsExpr(Arg: 2)->getExprLoc(),
           DiagID: diag::err_attribute_power_of_two_in_range)
          << AL << Min << Max << Preferred;
      return;
    }
  } else if (SpelledArgsCount > 1) {
    // Two arguments: validate the maximum and its relation to the minimum.
    if (!isValidWaveSizeValue(Value: Max)) {
      Diag(Loc: AL.getArgAsExpr(Arg: 1)->getExprLoc(),
           DiagID: diag::err_attribute_power_of_two_in_range)
          << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Max;
      return;
    }
    if (Max < Min) {
      Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_invalid) << AL << 1;
      return;
    } else if (Max == Min) {
      // Min == Max is accepted but pointless; warn and continue.
      Diag(Loc: AL.getLoc(), DiagID: diag::warn_attr_min_eq_max) << AL;
    }
  } else {
    // One argument: only the minimum wave size needs to be validated.
    if (!isValidWaveSizeValue(Value: Min)) {
      Diag(Loc: AL.getArgAsExpr(Arg: 0)->getExprLoc(),
           DiagID: diag::err_attribute_power_of_two_in_range)
          << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Min;
      return;
    }
  }

  HLSLWaveSizeAttr *NewAttr =
      mergeWaveSizeAttr(D, AL, Min, Max, Preferred, SpelledArgsCount);
  if (NewAttr)
    D->addAttr(A: NewAttr);
}
1400
1401void SemaHLSL::handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL) {
1402 uint32_t ID;
1403 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: ID))
1404 return;
1405 D->addAttr(A: ::new (getASTContext())
1406 HLSLVkExtBuiltinInputAttr(getASTContext(), AL, ID));
1407}
1408
1409void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) {
1410 uint32_t Id;
1411 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: Id))
1412 return;
1413 HLSLVkConstantIdAttr *NewAttr = mergeVkConstantIdAttr(D, AL, Id);
1414 if (NewAttr)
1415 D->addAttr(A: NewAttr);
1416}
1417
1418bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) {
1419 const auto *VT = T->getAs<VectorType>();
1420
1421 if (!T->hasUnsignedIntegerRepresentation() ||
1422 (VT && VT->getNumElements() > 3)) {
1423 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type)
1424 << AL << "uint/uint2/uint3";
1425 return false;
1426 }
1427
1428 return true;
1429}
1430
1431void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) {
1432 auto *VD = cast<ValueDecl>(Val: D);
1433 if (!diagnoseInputIDType(T: VD->getType(), AL))
1434 return;
1435
1436 D->addAttr(A: ::new (getASTContext())
1437 HLSLSV_DispatchThreadIDAttr(getASTContext(), AL));
1438}
1439
1440bool SemaHLSL::diagnosePositionType(QualType T, const ParsedAttr &AL) {
1441 const auto *VT = T->getAs<VectorType>();
1442
1443 if (!T->hasFloatingRepresentation() || (VT && VT->getNumElements() > 4)) {
1444 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type)
1445 << AL << "float/float1/float2/float3/float4";
1446 return false;
1447 }
1448
1449 return true;
1450}
1451
1452void SemaHLSL::handleSV_PositionAttr(Decl *D, const ParsedAttr &AL) {
1453 auto *VD = cast<ValueDecl>(Val: D);
1454 if (!diagnosePositionType(T: VD->getType(), AL))
1455 return;
1456
1457 D->addAttr(A: ::new (getASTContext()) HLSLSV_PositionAttr(getASTContext(), AL));
1458}
1459
1460void SemaHLSL::handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL) {
1461 auto *VD = cast<ValueDecl>(Val: D);
1462 if (!diagnoseInputIDType(T: VD->getType(), AL))
1463 return;
1464
1465 D->addAttr(A: ::new (getASTContext())
1466 HLSLSV_GroupThreadIDAttr(getASTContext(), AL));
1467}
1468
1469void SemaHLSL::handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL) {
1470 auto *VD = cast<ValueDecl>(Val: D);
1471 if (!diagnoseInputIDType(T: VD->getType(), AL))
1472 return;
1473
1474 D->addAttr(A: ::new (getASTContext()) HLSLSV_GroupIDAttr(getASTContext(), AL));
1475}
1476
1477void SemaHLSL::handlePackOffsetAttr(Decl *D, const ParsedAttr &AL) {
1478 if (!isa<VarDecl>(Val: D) || !isa<HLSLBufferDecl>(Val: D->getDeclContext())) {
1479 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_ast_node)
1480 << AL << "shader constant in a constant buffer";
1481 return;
1482 }
1483
1484 uint32_t SubComponent;
1485 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: SubComponent))
1486 return;
1487 uint32_t Component;
1488 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Component))
1489 return;
1490
1491 QualType T = cast<VarDecl>(Val: D)->getType().getCanonicalType();
1492 // Check if T is an array or struct type.
1493 // TODO: mark matrix type as aggregate type.
1494 bool IsAggregateTy = (T->isArrayType() || T->isStructureType());
1495
1496 // Check Component is valid for T.
1497 if (Component) {
1498 unsigned Size = getASTContext().getTypeSize(T);
1499 if (IsAggregateTy || Size > 128) {
1500 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_packoffset_cross_reg_boundary);
1501 return;
1502 } else {
1503 // Make sure Component + sizeof(T) <= 4.
1504 if ((Component * 32 + Size) > 128) {
1505 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_packoffset_cross_reg_boundary);
1506 return;
1507 }
1508 QualType EltTy = T;
1509 if (const auto *VT = T->getAs<VectorType>())
1510 EltTy = VT->getElementType();
1511 unsigned Align = getASTContext().getTypeAlign(T: EltTy);
1512 if (Align > 32 && Component == 1) {
1513 // NOTE: Component 3 will hit err_hlsl_packoffset_cross_reg_boundary.
1514 // So we only need to check Component 1 here.
1515 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_packoffset_alignment_mismatch)
1516 << Align << EltTy;
1517 return;
1518 }
1519 }
1520 }
1521
1522 D->addAttr(A: ::new (getASTContext()) HLSLPackOffsetAttr(
1523 getASTContext(), AL, SubComponent, Component));
1524}
1525
// Handles the [shader("...")] attribute: converts the string argument to a
// shader environment type and attaches (or merges) an HLSLShaderAttr.
// Unrecognized stage names produce a warning and the attribute is dropped.
void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) {
  StringRef Str;
  SourceLocation ArgLoc;
  if (!SemaRef.checkStringLiteralArgumentAttr(Attr: AL, ArgNum: 0, Str, ArgLocation: &ArgLoc))
    return;

  llvm::Triple::EnvironmentType ShaderType;
  if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Val: Str, Out&: ShaderType)) {
    Diag(Loc: AL.getLoc(), DiagID: diag::warn_attribute_type_not_supported)
        << AL << Str << ArgLoc;
    return;
  }

  // FIXME: check function match the shader stage.

  HLSLShaderAttr *NewAttr = mergeShaderAttr(D, AL, ShaderType);
  if (NewAttr)
    D->addAttr(A: NewAttr);
}
1545
// Combines a list of HLSL resource attributes with the wrapped type into an
// HLSLAttributedResourceType. Exactly one resource-class attribute is
// required; duplicate attributes are diagnosed and abort the creation.
//
// \param Wrapped - the type the resource attributes apply to.
// \param AttrList - the collected resource attributes (null entries skipped).
// \param ResType - out parameter receiving the new attributed type.
// \param LocInfo - optional out parameter receiving source-range and
//        contained-type location info for later TypeLoc construction.
// \returns true on success, false if a diagnostic was emitted.
bool clang::CreateHLSLAttributedResourceType(
    Sema &S, QualType Wrapped, ArrayRef<const Attr *> AttrList,
    QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo) {
  assert(AttrList.size() && "expected list of resource attributes");

  QualType ContainedTy = QualType();
  TypeSourceInfo *ContainedTyInfo = nullptr;
  // Track the overall source range spanned by the attribute list.
  SourceLocation LocBegin = AttrList[0]->getRange().getBegin();
  SourceLocation LocEnd = AttrList[0]->getRange().getEnd();

  HLSLAttributedResourceType::Attributes ResAttrs;

  bool HasResourceClass = false;
  for (const Attr *A : AttrList) {
    if (!A)
      continue;
    LocEnd = A->getRange().getEnd();
    switch (A->getKind()) {
    case attr::HLSLResourceClass: {
      ResourceClass RC = cast<HLSLResourceClassAttr>(Val: A)->getResourceClass();
      if (HasResourceClass) {
        // Repeating the same class is a milder warning than conflicting ones.
        S.Diag(Loc: A->getLocation(), DiagID: ResAttrs.ResourceClass == RC
                                    ? diag::warn_duplicate_attribute_exact
                                    : diag::warn_duplicate_attribute)
            << A;
        return false;
      }
      ResAttrs.ResourceClass = RC;
      HasResourceClass = true;
      break;
    }
    case attr::HLSLROV:
      if (ResAttrs.IsROV) {
        S.Diag(Loc: A->getLocation(), DiagID: diag::warn_duplicate_attribute_exact) << A;
        return false;
      }
      ResAttrs.IsROV = true;
      break;
    case attr::HLSLRawBuffer:
      if (ResAttrs.RawBuffer) {
        S.Diag(Loc: A->getLocation(), DiagID: diag::warn_duplicate_attribute_exact) << A;
        return false;
      }
      ResAttrs.RawBuffer = true;
      break;
    case attr::HLSLContainedType: {
      const HLSLContainedTypeAttr *CTAttr = cast<HLSLContainedTypeAttr>(Val: A);
      QualType Ty = CTAttr->getType();
      if (!ContainedTy.isNull()) {
        S.Diag(Loc: A->getLocation(), DiagID: ContainedTy == Ty
                                    ? diag::warn_duplicate_attribute_exact
                                    : diag::warn_duplicate_attribute)
            << A;
        return false;
      }
      ContainedTy = Ty;
      ContainedTyInfo = CTAttr->getTypeLoc();
      break;
    }
    default:
      llvm_unreachable("unhandled resource attribute type");
    }
  }

  if (!HasResourceClass) {
    S.Diag(Loc: AttrList.back()->getRange().getEnd(),
           DiagID: diag::err_hlsl_missing_resource_class);
    return false;
  }

  ResType = S.getASTContext().getHLSLAttributedResourceType(
      Wrapped, Contained: ContainedTy, Attrs: ResAttrs);

  if (LocInfo && ContainedTyInfo) {
    LocInfo->Range = SourceRange(LocBegin, LocEnd);
    LocInfo->ContainedTyInfo = ContainedTyInfo;
  }
  return true;
}
1625
1626// Validates and creates an HLSL attribute that is applied as type attribute on
1627// HLSL resource. The attributes are collected in HLSLResourcesTypeAttrs and at
1628// the end of the declaration they are applied to the declaration type by
1629// wrapping it in HLSLAttributedResourceType.
// Validates a single parsed resource type attribute applied to an HLSL
// resource handle type, converts it into the corresponding semantic Attr, and
// stashes it in HLSLResourcesTypeAttrs for ProcessResourceTypeAttributes to
// combine at the end of the declaration. Returns false if a diagnostic was
// emitted.
bool SemaHLSL::handleResourceTypeAttr(QualType T, const ParsedAttr &AL) {
  // only allow resource type attributes on intangible types
  if (!T->isHLSLResourceType()) {
    Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attribute_needs_intangible_type)
        << AL << getASTContext().HLSLResourceTy;
    return false;
  }

  // validate number of arguments
  if (!AL.checkExactlyNumArgs(S&: SemaRef, Num: AL.getMinArgs()))
    return false;

  Attr *A = nullptr;
  switch (AL.getKind()) {
  case ParsedAttr::AT_HLSLResourceClass: {
    // The argument must be an identifier naming a resource class
    // (e.g. SRV, UAV, CBuffer, Sampler).
    if (!AL.isArgIdent(Arg: 0)) {
      Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_type)
          << AL << AANT_ArgumentIdentifier;
      return false;
    }

    IdentifierLoc *Loc = AL.getArgAsIdent(Arg: 0);
    StringRef Identifier = Loc->getIdentifierInfo()->getName();
    SourceLocation ArgLoc = Loc->getLoc();

    // Validate resource class value
    ResourceClass RC;
    if (!HLSLResourceClassAttr::ConvertStrToResourceClass(Val: Identifier, Out&: RC)) {
      Diag(Loc: ArgLoc, DiagID: diag::warn_attribute_type_not_supported)
          << "ResourceClass" << Identifier;
      return false;
    }
    A = HLSLResourceClassAttr::Create(Ctx&: getASTContext(), ResourceClass: RC, Range: AL.getLoc());
    break;
  }

  case ParsedAttr::AT_HLSLROV:
    A = HLSLROVAttr::Create(Ctx&: getASTContext(), Range: AL.getLoc());
    break;

  case ParsedAttr::AT_HLSLRawBuffer:
    A = HLSLRawBufferAttr::Create(Ctx&: getASTContext(), Range: AL.getLoc());
    break;

  case ParsedAttr::AT_HLSLContainedType: {
    // The argument is a type; it must be complete so layout queries on the
    // resource's element type are valid.
    if (AL.getNumArgs() != 1 && !AL.hasParsedType()) {
      Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_wrong_number_arguments) << AL << 1;
      return false;
    }

    TypeSourceInfo *TSI = nullptr;
    QualType QT = SemaRef.GetTypeFromParser(Ty: AL.getTypeArg(), TInfo: &TSI);
    assert(TSI && "no type source info for attribute argument");
    if (SemaRef.RequireCompleteType(Loc: TSI->getTypeLoc().getBeginLoc(), T: QT,
                                    DiagID: diag::err_incomplete_type))
      return false;
    A = HLSLContainedTypeAttr::Create(Ctx&: getASTContext(), Type: TSI, Range: AL.getLoc());
    break;
  }

  default:
    llvm_unreachable("unhandled HLSL attribute");
  }

  HLSLResourcesTypeAttrs.emplace_back(Args&: A);
  return true;
}
1697
1698// Combines all resource type attributes and creates HLSLAttributedResourceType.
1699QualType SemaHLSL::ProcessResourceTypeAttributes(QualType CurrentType) {
1700 if (!HLSLResourcesTypeAttrs.size())
1701 return CurrentType;
1702
1703 QualType QT = CurrentType;
1704 HLSLAttributedResourceLocInfo LocInfo;
1705 if (CreateHLSLAttributedResourceType(S&: SemaRef, Wrapped: CurrentType,
1706 AttrList: HLSLResourcesTypeAttrs, ResType&: QT, LocInfo: &LocInfo)) {
1707 const HLSLAttributedResourceType *RT =
1708 cast<HLSLAttributedResourceType>(Val: QT.getTypePtr());
1709
1710 // Temporarily store TypeLoc information for the new type.
1711 // It will be transferred to HLSLAttributesResourceTypeLoc
1712 // shortly after the type is created by TypeSpecLocFiller which
1713 // will call the TakeLocForHLSLAttribute method below.
1714 LocsForHLSLAttributedResources.insert(KV: std::pair(RT, LocInfo));
1715 }
1716 HLSLResourcesTypeAttrs.clear();
1717 return QT;
1718}
1719
1720// Returns source location for the HLSLAttributedResourceType
1721HLSLAttributedResourceLocInfo
1722SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
1723 HLSLAttributedResourceLocInfo LocInfo = {};
1724 auto I = LocsForHLSLAttributedResources.find(Val: RT);
1725 if (I != LocsForHLSLAttributedResources.end()) {
1726 LocInfo = I->second;
1727 LocsForHLSLAttributedResources.erase(I);
1728 return LocInfo;
1729 }
1730 LocInfo.Range = SourceRange();
1731 return LocInfo;
1732}
1733
// Walks through the global variable declaration, collects all resource binding
// requirements and adds them to Bindings
void SemaHLSL::collectResourceBindingsOnUserRecordDecl(const VarDecl *VD,
                                                       const RecordType *RT) {
  const RecordDecl *RD = RT->getDecl();
  for (FieldDecl *FD : RD->fields()) {
    const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();

    // Unwrap arrays
    // FIXME: Calculate array size while unwrapping
    assert(!Ty->isIncompleteArrayType() &&
           "incomplete arrays inside user defined types are not supported");
    while (Ty->isConstantArrayType()) {
      const ConstantArrayType *CAT = cast<ConstantArrayType>(Val: Ty);
      Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
    }

    // Only record-typed fields can carry a resource handle or nest further
    // records; anything else contributes no binding.
    if (!Ty->isRecordType())
      continue;

    if (const HLSLAttributedResourceType *AttrResType =
            HLSLAttributedResourceType::findHandleTypeOnResource(RT: Ty)) {
      // Add a new DeclBindingInfo to Bindings if it does not already exist
      ResourceClass RC = AttrResType->getAttrs().ResourceClass;
      DeclBindingInfo *DBI = Bindings.getDeclBindingInfo(VD, ResClass: RC);
      if (!DBI)
        Bindings.addDeclBindingInfo(VD, ResClass: RC);
    } else if (const RecordType *RT = dyn_cast<RecordType>(Val: Ty)) {
      // Recursively scan embedded struct or class; it would be nice to do this
      // without recursion, but tricky to correctly calculate the size of the
      // binding, which is something we are probably going to need to do later
      // on. Hopefully nesting of structs in structs too many levels is
      // unlikely.
      collectResourceBindingsOnUserRecordDecl(VD, RT);
    }
  }
}
1771
// Diagnose localized register binding errors for a single binding; does not
// diagnose resource binding on user record types, that will be done later
// in processResourceBindingOnDecl based on the information collected in
// collectResourceBindingsOnVarDecl.
// Returns false if the register binding is not valid.
static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
                                         Decl *D, RegisterType RegType,
                                         bool SpecifiedSpace) {
  // The enum value is passed straight into the diagnostics as a selector.
  int RegTypeNum = static_cast<int>(RegType);

  // check if the decl type is groupshared
  if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
    S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
    return false;
  }

  // Cbuffers and Tbuffers are HLSLBufferDecl types
  if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(Val: D)) {
    // A cbuffer maps to the CBuffer resource class, a tbuffer to SRV.
    ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer
                                                     : ResourceClass::SRV;
    if (RegType == getRegisterType(RC))
      return true;

    S.Diag(Loc: D->getLocation(), DiagID: diag::err_hlsl_binding_type_mismatch)
        << RegTypeNum;
    return false;
  }

  // Samplers, UAVs, and SRVs are VarDecl types
  assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
  VarDecl *VD = cast<VarDecl>(Val: D);

  // Resource
  if (const HLSLAttributedResourceType *AttrResType =
          HLSLAttributedResourceType::findHandleTypeOnResource(
              RT: VD->getType().getTypePtr())) {
    // The register type must match the resource class of the handle type.
    if (RegType == getRegisterType(RC: AttrResType->getAttrs().ResourceClass))
      return true;

    S.Diag(Loc: D->getLocation(), DiagID: diag::err_hlsl_binding_type_mismatch)
        << RegTypeNum;
    return false;
  }

  // Strip array dimensions so the element type is classified below.
  const clang::Type *Ty = VD->getType().getTypePtr();
  while (Ty->isArrayType())
    Ty = Ty->getArrayElementTypeNoTypeQual();

  // Basic types
  if (Ty->isArithmeticType() || Ty->isVectorType()) {
    bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(Val: D->getDeclContext());
    // A space is only meaningful on declarations inside a cbuffer/tbuffer.
    if (SpecifiedSpace && !DeclaredInCOrTBuffer)
      S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_space_on_global_constant);

    if (!DeclaredInCOrTBuffer && (Ty->isIntegralType(Ctx: S.getASTContext()) ||
                                  Ty->isFloatingType() || Ty->isVectorType())) {
      // Register annotation on default constant buffer declaration ($Globals)
      if (RegType == RegisterType::CBuffer)
        S.Diag(Loc: ArgLoc, DiagID: diag::warn_hlsl_deprecated_register_type_b);
      else if (RegType != RegisterType::C)
        S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
      else
        return true;
    } else {
      // Inside a cbuffer/tbuffer (or for other basic types), register(c#)
      // only warns; every other register type is a mismatch error.
      if (RegType == RegisterType::C)
        S.Diag(Loc: ArgLoc, DiagID: diag::warn_hlsl_register_type_c_packoffset);
      else
        S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
    }
    return false;
  }
  if (Ty->isRecordType())
    // RecordTypes will be diagnosed in processResourceBindingOnDecl
    // that is called from ActOnVariableDeclarator
    return true;

  // Anything else is an error
  S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
  return false;
}
1852
1853static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
1854 RegisterType regType) {
1855 // make sure that there are no two register annotations
1856 // applied to the decl with the same register type
1857 bool RegisterTypesDetected[5] = {false};
1858 RegisterTypesDetected[static_cast<int>(regType)] = true;
1859
1860 for (auto it = TheDecl->attr_begin(); it != TheDecl->attr_end(); ++it) {
1861 if (HLSLResourceBindingAttr *attr =
1862 dyn_cast<HLSLResourceBindingAttr>(Val: *it)) {
1863
1864 RegisterType otherRegType = attr->getRegisterType();
1865 if (RegisterTypesDetected[static_cast<int>(otherRegType)]) {
1866 int otherRegTypeNum = static_cast<int>(otherRegType);
1867 S.Diag(Loc: TheDecl->getLocation(),
1868 DiagID: diag::err_hlsl_duplicate_register_annotation)
1869 << otherRegTypeNum;
1870 return false;
1871 }
1872 RegisterTypesDetected[static_cast<int>(otherRegType)] = true;
1873 }
1874 }
1875 return true;
1876}
1877
1878static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
1879 Decl *D, RegisterType RegType,
1880 bool SpecifiedSpace) {
1881
1882 // exactly one of these two types should be set
1883 assert(((isa<VarDecl>(D) && !isa<HLSLBufferDecl>(D)) ||
1884 (!isa<VarDecl>(D) && isa<HLSLBufferDecl>(D))) &&
1885 "expecting VarDecl or HLSLBufferDecl");
1886
1887 // check if the declaration contains resource matching the register type
1888 if (!DiagnoseLocalRegisterBinding(S, ArgLoc, D, RegType, SpecifiedSpace))
1889 return false;
1890
1891 // next, if multiple register annotations exist, check that none conflict.
1892 return ValidateMultipleRegisterAnnotations(S, TheDecl: D, regType: RegType);
1893}
1894
// Parses and validates a register binding annotation and, if valid, attaches
// an HLSLResourceBindingAttr to TheDecl.
void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
  // A variable's type must be complete before its binding can be checked.
  if (isa<VarDecl>(Val: TheDecl)) {
    if (SemaRef.RequireCompleteType(Loc: TheDecl->getBeginLoc(),
                                    T: cast<ValueDecl>(Val: TheDecl)->getType(),
                                    DiagID: diag::err_incomplete_type))
      return;
  }

  // Slot and space spellings parsed from the attribute arguments.
  StringRef Slot = "";
  StringRef Space = "";
  SourceLocation SlotLoc, SpaceLoc;

  if (!AL.isArgIdent(Arg: 0)) {
    Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_type)
        << AL << AANT_ArgumentIdentifier;
    return;
  }
  IdentifierLoc *Loc = AL.getArgAsIdent(Arg: 0);

  if (AL.getNumArgs() == 2) {
    // Two arguments: first is the slot, second must also be an identifier
    // and names the space.
    Slot = Loc->getIdentifierInfo()->getName();
    SlotLoc = Loc->getLoc();
    if (!AL.isArgIdent(Arg: 1)) {
      Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_type)
          << AL << AANT_ArgumentIdentifier;
      return;
    }
    Loc = AL.getArgAsIdent(Arg: 1);
    Space = Loc->getIdentifierInfo()->getName();
    SpaceLoc = Loc->getLoc();
  } else {
    // One argument: either a space ("space...") or a slot with the
    // default space0.
    StringRef Str = Loc->getIdentifierInfo()->getName();
    if (Str.starts_with(Prefix: "space")) {
      Space = Str;
      SpaceLoc = Loc->getLoc();
    } else {
      Slot = Str;
      SlotLoc = Loc->getLoc();
      Space = "space0";
    }
  }

  RegisterType RegType = RegisterType::SRV;
  std::optional<unsigned> SlotNum;
  unsigned SpaceNum = 0;

  // Validate slot
  if (!Slot.empty()) {
    // The leading character of the slot selects the register type.
    if (!convertToRegisterType(Slot, RT: &RegType)) {
      Diag(Loc: SlotLoc, DiagID: diag::err_hlsl_binding_type_invalid) << Slot.substr(Start: 0, N: 1);
      return;
    }
    if (RegType == RegisterType::I) {
      // Deprecated register type; warn and do not create the attribute.
      Diag(Loc: SlotLoc, DiagID: diag::warn_hlsl_deprecated_register_type_i);
      return;
    }
    StringRef SlotNumStr = Slot.substr(Start: 1);
    unsigned N;
    if (SlotNumStr.getAsInteger(Radix: 10, Result&: N)) {
      Diag(Loc: SlotLoc, DiagID: diag::err_hlsl_unsupported_register_number);
      return;
    }
    SlotNum = N;
  }

  // Validate space
  if (!Space.starts_with(Prefix: "space")) {
    Diag(Loc: SpaceLoc, DiagID: diag::err_hlsl_expected_space) << Space;
    return;
  }
  StringRef SpaceNumStr = Space.substr(Start: 5);
  if (SpaceNumStr.getAsInteger(Radix: 10, Result&: SpaceNum)) {
    Diag(Loc: SpaceLoc, DiagID: diag::err_hlsl_expected_space) << Space;
    return;
  }

  // If we have slot, diagnose it is the right register type for the decl
  if (SlotNum.has_value())
    if (!DiagnoseHLSLRegisterAttribute(S&: SemaRef, ArgLoc&: SlotLoc, D: TheDecl, RegType,
                                       SpecifiedSpace: !SpaceLoc.isInvalid()))
      return;

  HLSLResourceBindingAttr *NewAttr =
      HLSLResourceBindingAttr::Create(Ctx&: getASTContext(), Slot, Space, CommonInfo: AL);
  if (NewAttr) {
    NewAttr->setBinding(RT: RegType, SlotNum, SpaceNum);
    TheDecl->addAttr(A: NewAttr);
  }
}
1984
1985void SemaHLSL::handleParamModifierAttr(Decl *D, const ParsedAttr &AL) {
1986 HLSLParamModifierAttr *NewAttr = mergeParamModifierAttr(
1987 D, AL,
1988 Spelling: static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling()));
1989 if (NewAttr)
1990 D->addAttr(A: NewAttr);
1991}
1992
1993namespace {
1994
/// This class implements HLSL availability diagnostics for default
/// and relaxed mode
///
/// The goal of this diagnostic is to emit an error or warning when an
/// unavailable API is found in code that is reachable from the shader
/// entry function or from an exported function (when compiling a shader
/// library).
///
/// This is done by traversing the AST of all shader entry point functions
/// and of all exported functions, and any functions that are referenced
/// from this AST. In other words, any functions that are reachable from
/// the entry points.
class DiagnoseHLSLAvailability : public DynamicRecursiveASTVisitor {
  Sema &SemaRef;

  // Stack of functions to be scanned
  llvm::SmallVector<const FunctionDecl *, 8> DeclsToScan;

  // Tracks which environments functions have been scanned in.
  //
  // Maps FunctionDecl to an unsigned number that represents the set of shader
  // environments the function has been scanned for.
  // The llvm::Triple::EnvironmentType enum values for shader stages guaranteed
  // to be numbered from llvm::Triple::Pixel to llvm::Triple::Amplification
  // (verified by static_asserts in Triple.cpp), we can use it to index
  // individual bits in the set, as long as we shift the values to start with 0
  // by subtracting the value of llvm::Triple::Pixel first.
  //
  // The N'th bit in the set will be set if the function has been scanned
  // in shader environment whose llvm::Triple::EnvironmentType integer value
  // equals (llvm::Triple::Pixel + N).
  //
  // For example, if a function has been scanned in compute and pixel stage
  // environment, the value will be 0x21 (100001 binary) because:
  //
  //   (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0
  //   (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5
  //
  // A FunctionDecl is mapped to 0 (or not included in the map) if it has not
  // been scanned in any environment.
  llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls;

  // Do not access these directly, use the get/set methods below to make
  // sure the values are in sync
  llvm::Triple::EnvironmentType CurrentShaderEnvironment;
  unsigned CurrentShaderStageBit;

  // True if scanning a function that was already scanned in a different
  // shader stage context, and therefore we should not report issues that
  // depend only on shader model version because they would be duplicate.
  bool ReportOnlyShaderStageIssues;

  // Helper methods for dealing with current stage context / environment
  void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) {
    static_assert(sizeof(unsigned) >= 4);
    assert(HLSLShaderAttr::isValidShaderType(ShaderType));
    assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 &&
           "ShaderType is too big for this bitmap"); // 31 is reserved for
                                                     // "unknown"

    unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel;
    CurrentShaderEnvironment = ShaderType;
    CurrentShaderStageBit = (1 << bitmapIndex);
  }

  // Bit 31 of the stage bitmap is reserved for the unknown environment.
  void SetUnknownShaderStageContext() {
    CurrentShaderEnvironment = llvm::Triple::UnknownEnvironment;
    CurrentShaderStageBit = (1 << 31);
  }

  llvm::Triple::EnvironmentType GetCurrentShaderEnvironment() const {
    return CurrentShaderEnvironment;
  }

  bool InUnknownShaderStageContext() const {
    return CurrentShaderEnvironment == llvm::Triple::UnknownEnvironment;
  }

  // Helper methods for dealing with shader stage bitmap

  // Marks FD as scanned in the current shader stage.
  void AddToScannedFunctions(const FunctionDecl *FD) {
    unsigned &ScannedStages = ScannedDecls[FD];
    ScannedStages |= CurrentShaderStageBit;
  }

  unsigned GetScannedStages(const FunctionDecl *FD) { return ScannedDecls[FD]; }

  bool WasAlreadyScannedInCurrentStage(const FunctionDecl *FD) {
    return WasAlreadyScannedInCurrentStage(ScannerStages: GetScannedStages(FD));
  }

  bool WasAlreadyScannedInCurrentStage(unsigned ScannerStages) {
    return ScannerStages & CurrentShaderStageBit;
  }

  static bool NeverBeenScanned(unsigned ScannedStages) {
    return ScannedStages == 0;
  }

  // Scanning methods
  void HandleFunctionOrMethodRef(FunctionDecl *FD, Expr *RefExpr);
  void CheckDeclAvailability(NamedDecl *D, const AvailabilityAttr *AA,
                             SourceRange Range);
  const AvailabilityAttr *FindAvailabilityAttr(const Decl *D);
  bool HasMatchingEnvironmentOrNone(const AvailabilityAttr *AA);

public:
  DiagnoseHLSLAvailability(Sema &SemaRef)
      : SemaRef(SemaRef),
        CurrentShaderEnvironment(llvm::Triple::UnknownEnvironment),
        CurrentShaderStageBit(0), ReportOnlyShaderStageIssues(false) {}

  // AST traversal methods
  void RunOnTranslationUnit(const TranslationUnitDecl *TU);
  void RunOnFunction(const FunctionDecl *FD);

  // Function references are where availability is ultimately checked.
  bool VisitDeclRefExpr(DeclRefExpr *DRE) override {
    FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(Val: DRE->getDecl());
    if (FD)
      HandleFunctionOrMethodRef(FD, RefExpr: DRE);
    return true;
  }

  bool VisitMemberExpr(MemberExpr *ME) override {
    FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(Val: ME->getMemberDecl());
    if (FD)
      HandleFunctionOrMethodRef(FD, RefExpr: ME);
    return true;
  }
};
2124
// Processes a reference to FD found while scanning: functions with a body are
// queued for scanning; declarations without a body are checked for
// availability at the reference location.
void DiagnoseHLSLAvailability::HandleFunctionOrMethodRef(FunctionDecl *FD,
                                                         Expr *RefExpr) {
  assert((isa<DeclRefExpr>(RefExpr) || isa<MemberExpr>(RefExpr)) &&
         "expected DeclRefExpr or MemberExpr");

  // has a definition -> add to stack to be scanned
  const FunctionDecl *FDWithBody = nullptr;
  if (FD->hasBody(Definition&: FDWithBody)) {
    // Skip definitions already scanned in the current shader stage.
    if (!WasAlreadyScannedInCurrentStage(FD: FDWithBody))
      DeclsToScan.push_back(Elt: FDWithBody);
    return;
  }

  // no body -> diagnose availability
  const AvailabilityAttr *AA = FindAvailabilityAttr(D: FD);
  if (AA)
    CheckDeclAvailability(
        D: FD, AA, Range: SourceRange(RefExpr->getBeginLoc(), RefExpr->getEndLoc()));
}
2144
void DiagnoseHLSLAvailability::RunOnTranslationUnit(
    const TranslationUnitDecl *TU) {

  // Iterate over all shader entry functions and library exports, and for those
  // that have a body (definition), run diag scan on each, setting appropriate
  // shader environment context based on whether it is a shader entry function
  // or an exported function. Exported functions can be in namespaces and in
  // export declarations so we need to scan those declaration contexts as well.
  llvm::SmallVector<const DeclContext *, 8> DeclContextsToScan;
  DeclContextsToScan.push_back(Elt: TU);

  while (!DeclContextsToScan.empty()) {
    const DeclContext *DC = DeclContextsToScan.pop_back_val();
    for (auto &D : DC->decls()) {
      // do not scan implicit declaration generated by the implementation
      if (D->isImplicit())
        continue;

      // for namespace or export declaration add the context to the list to be
      // scanned later
      if (llvm::dyn_cast<NamespaceDecl>(Val: D) || llvm::dyn_cast<ExportDecl>(Val: D)) {
        DeclContextsToScan.push_back(Elt: llvm::dyn_cast<DeclContext>(Val: D));
        continue;
      }

      // skip over other decls or function decls without body
      const FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(Val: D);
      if (!FD || !FD->isThisDeclarationADefinition())
        continue;

      // shader entry point
      if (HLSLShaderAttr *ShaderAttr = FD->getAttr<HLSLShaderAttr>()) {
        SetShaderStageContext(ShaderAttr->getType());
        RunOnFunction(FD);
        continue;
      }
      // exported library function
      // FIXME: replace this loop with external linkage check once issue #92071
      // is resolved
      bool isExport = FD->isInExportDeclContext();
      if (!isExport) {
        // Any redeclaration inside an export context makes FD exported.
        for (const auto *Redecl : FD->redecls()) {
          if (Redecl->isInExportDeclContext()) {
            isExport = true;
            break;
          }
        }
      }
      if (isExport) {
        // Exported functions are scanned in the unknown-stage context.
        SetUnknownShaderStageContext();
        RunOnFunction(FD);
        continue;
      }
    }
  }
}
2201
// Scans FD and everything reachable from it in the current shader stage
// context using an explicit work stack (no recursion).
void DiagnoseHLSLAvailability::RunOnFunction(const FunctionDecl *FD) {
  assert(DeclsToScan.empty() && "DeclsToScan should be empty");
  DeclsToScan.push_back(Elt: FD);

  while (!DeclsToScan.empty()) {
    // Take one decl from the stack and check it by traversing its AST.
    // For any CallExpr found during the traversal add its callee to the top of
    // the stack to be processed next. Functions already processed are stored in
    // ScannedDecls.
    const FunctionDecl *FD = DeclsToScan.pop_back_val();

    // Decl was already scanned
    const unsigned ScannedStages = GetScannedStages(FD);
    if (WasAlreadyScannedInCurrentStage(ScannerStages: ScannedStages))
      continue;

    // Suppress stage-independent reports when re-scanning in a new stage.
    ReportOnlyShaderStageIssues = !NeverBeenScanned(ScannedStages);

    AddToScannedFunctions(FD);
    TraverseStmt(S: FD->getBody());
  }
}
2224
2225bool DiagnoseHLSLAvailability::HasMatchingEnvironmentOrNone(
2226 const AvailabilityAttr *AA) {
2227 IdentifierInfo *IIEnvironment = AA->getEnvironment();
2228 if (!IIEnvironment)
2229 return true;
2230
2231 llvm::Triple::EnvironmentType CurrentEnv = GetCurrentShaderEnvironment();
2232 if (CurrentEnv == llvm::Triple::UnknownEnvironment)
2233 return false;
2234
2235 llvm::Triple::EnvironmentType AttrEnv =
2236 AvailabilityAttr::getEnvironmentType(Environment: IIEnvironment->getName());
2237
2238 return CurrentEnv == AttrEnv;
2239}
2240
2241const AvailabilityAttr *
2242DiagnoseHLSLAvailability::FindAvailabilityAttr(const Decl *D) {
2243 AvailabilityAttr const *PartialMatch = nullptr;
2244 // Check each AvailabilityAttr to find the one for this platform.
2245 // For multiple attributes with the same platform try to find one for this
2246 // environment.
2247 for (const auto *A : D->attrs()) {
2248 if (const auto *Avail = dyn_cast<AvailabilityAttr>(Val: A)) {
2249 StringRef AttrPlatform = Avail->getPlatform()->getName();
2250 StringRef TargetPlatform =
2251 SemaRef.getASTContext().getTargetInfo().getPlatformName();
2252
2253 // Match the platform name.
2254 if (AttrPlatform == TargetPlatform) {
2255 // Find the best matching attribute for this environment
2256 if (HasMatchingEnvironmentOrNone(AA: Avail))
2257 return Avail;
2258 PartialMatch = Avail;
2259 }
2260 }
2261 }
2262 return PartialMatch;
2263}
2264
// Check availability against target shader model version and current shader
// stage and emit diagnostic
void DiagnoseHLSLAvailability::CheckDeclAvailability(NamedDecl *D,
                                                     const AvailabilityAttr *AA,
                                                     SourceRange Range) {

  IdentifierInfo *IIEnv = AA->getEnvironment();

  if (!IIEnv) {
    // The availability attribute does not have environment -> it depends only
    // on shader model version and not on the specific shader stage.

    // Skip emitting the diagnostics if the diagnostic mode is set to
    // strict (-fhlsl-strict-availability) because all relevant diagnostics
    // were already emitted in the DiagnoseUnguardedAvailability scan
    // (SemaAvailability.cpp).
    if (SemaRef.getLangOpts().HLSLStrictAvailability)
      return;

    // Do not report shader-stage-independent issues if scanning a function
    // that was already scanned in a different shader stage context (they would
    // be duplicate)
    if (ReportOnlyShaderStageIssues)
      return;

  } else {
    // The availability attribute has environment -> we need to know
    // the current stage context to properly diagnose it.
    if (InUnknownShaderStageContext())
      return;
  }

  // Check introduced version and if environment matches
  bool EnvironmentMatches = HasMatchingEnvironmentOrNone(AA);
  VersionTuple Introduced = AA->getIntroduced();
  VersionTuple TargetVersion =
      SemaRef.Context.getTargetInfo().getPlatformMinVersion();

  // Available in this version and stage -> nothing to diagnose.
  if (TargetVersion >= Introduced && EnvironmentMatches)
    return;

  // Emit diagnostic message
  const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
  llvm::StringRef PlatformName(
      AvailabilityAttr::getPrettyPlatformName(Platform: TI.getPlatformName()));

  llvm::StringRef CurrentEnvStr =
      llvm::Triple::getEnvironmentTypeName(Kind: GetCurrentShaderEnvironment());

  llvm::StringRef AttrEnvStr =
      AA->getEnvironment() ? AA->getEnvironment()->getName() : "";
  bool UseEnvironment = !AttrEnvStr.empty();

  if (EnvironmentMatches) {
    // Environment matches but the target version is too old.
    SemaRef.Diag(Loc: Range.getBegin(), DiagID: diag::warn_hlsl_availability)
        << Range << D << PlatformName << Introduced.getAsString()
        << UseEnvironment << CurrentEnvStr;
  } else {
    // Not available in the current shader stage at all.
    SemaRef.Diag(Loc: Range.getBegin(), DiagID: diag::warn_hlsl_availability_unavailable)
        << Range << D;
  }

  SemaRef.Diag(Loc: D->getLocation(), DiagID: diag::note_partial_availability_specified_here)
      << D << PlatformName << Introduced.getAsString()
      << SemaRef.Context.getTargetInfo().getPlatformMinVersion().getAsString()
      << UseEnvironment << AttrEnvStr << CurrentEnvStr;
}
2332
2333} // namespace
2334
void SemaHLSL::ActOnEndOfTranslationUnit(TranslationUnitDecl *TU) {
  // process default CBuffer - create buffer layout struct and invoke codegen
  if (!DefaultCBufferDecls.empty()) {
    HLSLBufferDecl *DefaultCBuffer = HLSLBufferDecl::CreateDefaultCBuffer(
        C&: SemaRef.getASTContext(), LexicalParent: SemaRef.getCurLexicalContext(),
        DefaultCBufferDecls);
    // The implicit default buffer gets the next implicit binding order id.
    addImplicitBindingAttrToBuffer(S&: SemaRef, BufDecl: DefaultCBuffer,
                                   ImplicitBindingOrderID: getNextImplicitBindingOrderID());
    SemaRef.getCurLexicalContext()->addDecl(D: DefaultCBuffer);
    createHostLayoutStructForBuffer(S&: SemaRef, BufDecl: DefaultCBuffer);

    // Set HasValidPackoffset if any of the decls has a register(c#) annotation;
    for (const Decl *VD : DefaultCBufferDecls) {
      const HLSLResourceBindingAttr *RBA =
          VD->getAttr<HLSLResourceBindingAttr>();
      if (RBA && RBA->hasRegisterSlot() &&
          RBA->getRegisterType() == HLSLResourceBindingAttr::RegisterType::C) {
        DefaultCBuffer->setHasValidPackoffset(true);
        break;
      }
    }

    // Hand the finished buffer declaration to the AST consumer.
    DeclGroupRef DG(DefaultCBuffer);
    SemaRef.Consumer.HandleTopLevelDecl(D: DG);
  }
  diagnoseAvailabilityViolations(TU);
}
2362
2363void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
2364 // Skip running the diagnostics scan if the diagnostic mode is
2365 // strict (-fhlsl-strict-availability) and the target shader stage is known
2366 // because all relevant diagnostics were already emitted in the
2367 // DiagnoseUnguardedAvailability scan (SemaAvailability.cpp).
2368 const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
2369 if (SemaRef.getLangOpts().HLSLStrictAvailability &&
2370 TI.getTriple().getEnvironment() != llvm::Triple::EnvironmentType::Library)
2371 return;
2372
2373 DiagnoseHLSLAvailability(SemaRef).RunOnTranslationUnit(TU);
2374}
2375
2376static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
2377 assert(TheCall->getNumArgs() > 1);
2378 QualType ArgTy0 = TheCall->getArg(Arg: 0)->getType();
2379
2380 for (unsigned I = 1, N = TheCall->getNumArgs(); I < N; ++I) {
2381 if (!S->getASTContext().hasSameUnqualifiedType(
2382 T1: ArgTy0, T2: TheCall->getArg(Arg: I)->getType())) {
2383 S->Diag(Loc: TheCall->getBeginLoc(), DiagID: diag::err_vec_builtin_incompatible_vector)
2384 << TheCall->getDirectCallee() << /*useAllTerminology*/ true
2385 << SourceRange(TheCall->getArg(Arg: 0)->getBeginLoc(),
2386 TheCall->getArg(Arg: N - 1)->getEndLoc());
2387 return true;
2388 }
2389 }
2390 return false;
2391}
2392
2393static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
2394 QualType ArgType = Arg->getType();
2395 if (!S->getASTContext().hasSameUnqualifiedType(T1: ArgType, T2: ExpectedType)) {
2396 S->Diag(Loc: Arg->getBeginLoc(), DiagID: diag::err_typecheck_convert_incompatible)
2397 << ArgType << ExpectedType << 1 << 0 << 0;
2398 return true;
2399 }
2400 return false;
2401}
2402
2403static bool CheckAllArgTypesAreCorrect(
2404 Sema *S, CallExpr *TheCall,
2405 llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
2406 clang::QualType PassedType)>
2407 Check) {
2408 for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
2409 Expr *Arg = TheCall->getArg(Arg: I);
2410 if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
2411 return true;
2412 }
2413 return false;
2414}
2415
2416static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
2417 int ArgOrdinal,
2418 clang::QualType PassedType) {
2419 clang::QualType BaseType =
2420 PassedType->isVectorType()
2421 ? PassedType->castAs<clang::VectorType>()->getElementType()
2422 : PassedType;
2423 if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
2424 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
2425 << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
2426 << /* half or float */ 2 << PassedType;
2427 return false;
2428}
2429
2430static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
2431 unsigned ArgIndex) {
2432 auto *Arg = TheCall->getArg(Arg: ArgIndex);
2433 SourceLocation OrigLoc = Arg->getExprLoc();
2434 if (Arg->IgnoreCasts()->isModifiableLvalue(Ctx&: S->Context, Loc: &OrigLoc) ==
2435 Expr::MLV_Valid)
2436 return false;
2437 S->Diag(Loc: OrigLoc, DiagID: diag::error_hlsl_inout_lvalue) << Arg << 0;
2438 return true;
2439}
2440
2441static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal,
2442 clang::QualType PassedType) {
2443 const auto *VecTy = PassedType->getAs<VectorType>();
2444 if (!VecTy)
2445 return false;
2446
2447 if (VecTy->getElementType()->isDoubleType())
2448 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
2449 << ArgOrdinal << /* scalar */ 1 << /* no int */ 0 << /* fp */ 1
2450 << PassedType;
2451 return false;
2452}
2453
2454static bool CheckFloatingOrIntRepresentation(Sema *S, SourceLocation Loc,
2455 int ArgOrdinal,
2456 clang::QualType PassedType) {
2457 if (!PassedType->hasIntegerRepresentation() &&
2458 !PassedType->hasFloatingRepresentation())
2459 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
2460 << ArgOrdinal << /* scalar or vector of */ 5 << /* integer */ 1
2461 << /* fp */ 1 << PassedType;
2462 return false;
2463}
2464
2465static bool CheckUnsignedIntVecRepresentation(Sema *S, SourceLocation Loc,
2466 int ArgOrdinal,
2467 clang::QualType PassedType) {
2468 if (auto *VecTy = PassedType->getAs<VectorType>())
2469 if (VecTy->getElementType()->isUnsignedIntegerType())
2470 return false;
2471
2472 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
2473 << ArgOrdinal << /* vector of */ 4 << /* uint */ 3 << /* no fp */ 0
2474 << PassedType;
2475}
2476
2477// checks for unsigned ints of all sizes
2478static bool CheckUnsignedIntRepresentation(Sema *S, SourceLocation Loc,
2479 int ArgOrdinal,
2480 clang::QualType PassedType) {
2481 if (!PassedType->hasUnsignedIntegerRepresentation())
2482 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
2483 << ArgOrdinal << /* scalar or vector of */ 5 << /* unsigned int */ 3
2484 << /* no fp */ 0 << PassedType;
2485 return false;
2486}
2487
2488static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
2489 QualType ReturnType) {
2490 auto *VecTyA = TheCall->getArg(Arg: 0)->getType()->getAs<VectorType>();
2491 if (VecTyA)
2492 ReturnType =
2493 S->Context.getExtVectorType(VectorType: ReturnType, NumElts: VecTyA->getNumElements());
2494
2495 TheCall->setType(ReturnType);
2496}
2497
2498static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar,
2499 unsigned ArgIndex) {
2500 assert(TheCall->getNumArgs() >= ArgIndex);
2501 QualType ArgType = TheCall->getArg(Arg: ArgIndex)->getType();
2502 auto *VTy = ArgType->getAs<VectorType>();
2503 // not the scalar or vector<scalar>
2504 if (!(S->Context.hasSameUnqualifiedType(T1: ArgType, T2: Scalar) ||
2505 (VTy &&
2506 S->Context.hasSameUnqualifiedType(T1: VTy->getElementType(), T2: Scalar)))) {
2507 S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
2508 DiagID: diag::err_typecheck_expect_scalar_or_vector)
2509 << ArgType << Scalar;
2510 return true;
2511 }
2512 return false;
2513}
2514
2515static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall,
2516 unsigned ArgIndex) {
2517 assert(TheCall->getNumArgs() >= ArgIndex);
2518 QualType ArgType = TheCall->getArg(Arg: ArgIndex)->getType();
2519 auto *VTy = ArgType->getAs<VectorType>();
2520 // not the scalar or vector<scalar>
2521 if (!(ArgType->isScalarType() ||
2522 (VTy && VTy->getElementType()->isScalarType()))) {
2523 S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
2524 DiagID: diag::err_typecheck_expect_any_scalar_or_vector)
2525 << ArgType << 1;
2526 return true;
2527 }
2528 return false;
2529}
2530
2531static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
2532 QualType BoolType = S->getASTContext().BoolTy;
2533 assert(TheCall->getNumArgs() >= 1);
2534 QualType ArgType = TheCall->getArg(Arg: 0)->getType();
2535 auto *VTy = ArgType->getAs<VectorType>();
2536 // is the bool or vector<bool>
2537 if (S->Context.hasSameUnqualifiedType(T1: ArgType, T2: BoolType) ||
2538 (VTy &&
2539 S->Context.hasSameUnqualifiedType(T1: VTy->getElementType(), T2: BoolType))) {
2540 S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
2541 DiagID: diag::err_typecheck_expect_any_scalar_or_vector)
2542 << ArgType << 0;
2543 return true;
2544 }
2545 return false;
2546}
2547
2548static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
2549 assert(TheCall->getNumArgs() == 3);
2550 Expr *Arg1 = TheCall->getArg(Arg: 1);
2551 Expr *Arg2 = TheCall->getArg(Arg: 2);
2552 if (!S->Context.hasSameUnqualifiedType(T1: Arg1->getType(), T2: Arg2->getType())) {
2553 S->Diag(Loc: TheCall->getBeginLoc(),
2554 DiagID: diag::err_typecheck_call_different_arg_types)
2555 << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
2556 << Arg2->getSourceRange();
2557 return true;
2558 }
2559
2560 TheCall->setType(Arg1->getType());
2561 return false;
2562}
2563
2564static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
2565 assert(TheCall->getNumArgs() == 3);
2566 Expr *Arg1 = TheCall->getArg(Arg: 1);
2567 QualType Arg1Ty = Arg1->getType();
2568 Expr *Arg2 = TheCall->getArg(Arg: 2);
2569 QualType Arg2Ty = Arg2->getType();
2570
2571 QualType Arg1ScalarTy = Arg1Ty;
2572 if (auto VTy = Arg1ScalarTy->getAs<VectorType>())
2573 Arg1ScalarTy = VTy->getElementType();
2574
2575 QualType Arg2ScalarTy = Arg2Ty;
2576 if (auto VTy = Arg2ScalarTy->getAs<VectorType>())
2577 Arg2ScalarTy = VTy->getElementType();
2578
2579 if (!S->Context.hasSameUnqualifiedType(T1: Arg1ScalarTy, T2: Arg2ScalarTy))
2580 S->Diag(Loc: Arg1->getBeginLoc(), DiagID: diag::err_hlsl_builtin_scalar_vector_mismatch)
2581 << /* second and third */ 1 << TheCall->getCallee() << Arg1Ty << Arg2Ty;
2582
2583 QualType Arg0Ty = TheCall->getArg(Arg: 0)->getType();
2584 unsigned Arg0Length = Arg0Ty->getAs<VectorType>()->getNumElements();
2585 unsigned Arg1Length = Arg1Ty->isVectorType()
2586 ? Arg1Ty->getAs<VectorType>()->getNumElements()
2587 : 0;
2588 unsigned Arg2Length = Arg2Ty->isVectorType()
2589 ? Arg2Ty->getAs<VectorType>()->getNumElements()
2590 : 0;
2591 if (Arg1Length > 0 && Arg0Length != Arg1Length) {
2592 S->Diag(Loc: TheCall->getBeginLoc(),
2593 DiagID: diag::err_typecheck_vector_lengths_not_equal)
2594 << Arg0Ty << Arg1Ty << TheCall->getArg(Arg: 0)->getSourceRange()
2595 << Arg1->getSourceRange();
2596 return true;
2597 }
2598
2599 if (Arg2Length > 0 && Arg0Length != Arg2Length) {
2600 S->Diag(Loc: TheCall->getBeginLoc(),
2601 DiagID: diag::err_typecheck_vector_lengths_not_equal)
2602 << Arg0Ty << Arg2Ty << TheCall->getArg(Arg: 0)->getSourceRange()
2603 << Arg2->getSourceRange();
2604 return true;
2605 }
2606
2607 TheCall->setType(
2608 S->getASTContext().getExtVectorType(VectorType: Arg1ScalarTy, NumElts: Arg0Length));
2609 return false;
2610}
2611
2612static bool CheckResourceHandle(
2613 Sema *S, CallExpr *TheCall, unsigned ArgIndex,
2614 llvm::function_ref<bool(const HLSLAttributedResourceType *ResType)> Check =
2615 nullptr) {
2616 assert(TheCall->getNumArgs() >= ArgIndex);
2617 QualType ArgType = TheCall->getArg(Arg: ArgIndex)->getType();
2618 const HLSLAttributedResourceType *ResTy =
2619 ArgType.getTypePtr()->getAs<HLSLAttributedResourceType>();
2620 if (!ResTy) {
2621 S->Diag(Loc: TheCall->getArg(Arg: ArgIndex)->getBeginLoc(),
2622 DiagID: diag::err_typecheck_expect_hlsl_resource)
2623 << ArgType;
2624 return true;
2625 }
2626 if (Check && Check(ResTy)) {
2627 S->Diag(Loc: TheCall->getArg(Arg: ArgIndex)->getExprLoc(),
2628 DiagID: diag::err_invalid_hlsl_resource_type)
2629 << ArgType;
2630 return true;
2631 }
2632 return false;
2633}
2634
// Performs HLSL-specific semantic checks on a call to an HLSL (or shared
// elementwise) builtin, and for many builtins also computes the call's result
// type. Returns false when the call is valid.
//
// Note: returning true in this case results in CheckBuiltinFunctionCall
// returning an ExprError
bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
  switch (BuiltinID) {
  case Builtin::BI__builtin_hlsl_adduint64: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
      return true;

    // Both args must be vectors of unsigned integers.
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckUnsignedIntVecRepresentation))
      return true;

    auto *VTy = TheCall->getArg(Arg: 0)->getType()->getAs<VectorType>();
    // ensure arg integers are 32-bits
    uint64_t ElementBitCount = getASTContext()
                                   .getTypeSizeInChars(T: VTy->getElementType())
                                   .getQuantity() *
                               8;
    if (ElementBitCount != 32) {
      SemaRef.Diag(Loc: TheCall->getBeginLoc(),
                   DiagID: diag::err_integer_incorrect_bit_count)
          << 32 << ElementBitCount;
      return true;
    }

    // ensure both args are vectors of total bit size of a multiple of 64
    // (i.e. uint2 or uint4, each 64-bit pair forming one uint64).
    int NumElementsArg = VTy->getNumElements();
    if (NumElementsArg != 2 && NumElementsArg != 4) {
      SemaRef.Diag(Loc: TheCall->getBeginLoc(), DiagID: diag::err_vector_incorrect_bit_count)
          << 1 /*a multiple of*/ << 64 << NumElementsArg * ElementBitCount;
      return true;
    }

    // ensure first arg and second arg have the same type
    if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
      return true;

    ExprResult A = TheCall->getArg(Arg: 0);
    QualType ArgTyA = A.get()->getType();
    // return type is the same as the input type
    TheCall->setType(ArgTyA);
    break;
  }
  case Builtin::BI__builtin_hlsl_resource_getpointer: {
    // (handle, uint index) -> pointer into hlsl_device address space.
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2) ||
        CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 1),
                            ExpectedType: SemaRef.getASTContext().UnsignedIntTy))
      return true;

    // Result is a pointer to the resource's contained element type.
    auto *ResourceTy =
        TheCall->getArg(Arg: 0)->getType()->castAs<HLSLAttributedResourceType>();
    QualType ContainedTy = ResourceTy->getContainedType();
    auto ReturnType =
        SemaRef.Context.getAddrSpaceQualType(T: ContainedTy, AddressSpace: LangAS::hlsl_device);
    ReturnType = SemaRef.Context.getPointerType(T: ReturnType);
    TheCall->setType(ReturnType);
    TheCall->setValueKind(VK_LValue);

    break;
  }
  case Builtin::BI__builtin_hlsl_resource_uninitializedhandle: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1) ||
        CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0))
      return true;
    // use the type of the handle (arg0) as a return type
    QualType ResourceTy = TheCall->getArg(Arg: 0)->getType();
    TheCall->setType(ResourceTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_resource_handlefrombinding: {
    // (handle, uint register, uint space, int range, uint index, name).
    ASTContext &AST = SemaRef.getASTContext();
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 6) ||
        CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 1), ExpectedType: AST.UnsignedIntTy) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 2), ExpectedType: AST.UnsignedIntTy) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 3), ExpectedType: AST.IntTy) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 4), ExpectedType: AST.UnsignedIntTy) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 5),
                            ExpectedType: AST.getPointerType(T: AST.CharTy.withConst())))
      return true;
    // use the type of the handle (arg0) as a return type
    QualType ResourceTy = TheCall->getArg(Arg: 0)->getType();
    TheCall->setType(ResourceTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_resource_handlefromimplicitbinding: {
    // Same shape as handlefrombinding, but arg 2 is a signed order id.
    ASTContext &AST = SemaRef.getASTContext();
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 6) ||
        CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 1), ExpectedType: AST.UnsignedIntTy) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 2), ExpectedType: AST.IntTy) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 3), ExpectedType: AST.UnsignedIntTy) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 4), ExpectedType: AST.UnsignedIntTy) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 5),
                            ExpectedType: AST.getPointerType(T: AST.CharTy.withConst())))
      return true;
    // use the type of the handle (arg0) as a return type
    QualType ResourceTy = TheCall->getArg(Arg: 0)->getType();
    TheCall->setType(ResourceTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_and:
  case Builtin::BI__builtin_hlsl_or: {
    // Logical and/or over bool or vector<bool>; both args same type.
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
      return true;
    if (CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: getASTContext().BoolTy, ArgIndex: 0))
      return true;
    if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
      return true;

    ExprResult A = TheCall->getArg(Arg: 0);
    QualType ArgTyA = A.get()->getType();
    // return type is the same as the input type
    TheCall->setType(ArgTyA);
    break;
  }
  case Builtin::BI__builtin_hlsl_all:
  case Builtin::BI__builtin_hlsl_any: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;
    if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_asdouble: {
    // (uint low, uint high) -> double; both args uint or vector<uint> of the
    // same type, result is double (or a matching vector of double).
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
      return true;
    if (CheckScalarOrVector(
            S: &SemaRef, TheCall,
            /*only check for uint*/ Scalar: SemaRef.Context.UnsignedIntTy,
            /* arg index */ ArgIndex: 0))
      return true;
    if (CheckScalarOrVector(
            S: &SemaRef, TheCall,
            /*only check for uint*/ Scalar: SemaRef.Context.UnsignedIntTy,
            /* arg index */ ArgIndex: 1))
      return true;
    if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
      return true;

    SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().DoubleTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_clamp: {
    if (SemaRef.BuiltinElementwiseTernaryMath(
            TheCall, /*ArgTyRestr=*/
            Sema::EltwiseBuiltinArgTyRestriction::None))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_dot: {
    // arg count is checked by BuiltinVectorToScalarMath
    if (SemaRef.BuiltinVectorToScalarMath(TheCall))
      return true;
    // dot has no double-vector overload.
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall, Check: CheckNoDoubleVectors))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_firstbithigh:
  case Builtin::BI__builtin_hlsl_elementwise_firstbitlow: {
    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
      return true;

    // Input must be integer (scalar or vector); result is uint with the same
    // vector shape as the input.
    const Expr *Arg = TheCall->getArg(Arg: 0);
    QualType ArgTy = Arg->getType();
    QualType EltTy = ArgTy;

    QualType ResTy = SemaRef.Context.UnsignedIntTy;

    if (auto *VecTy = EltTy->getAs<VectorType>()) {
      EltTy = VecTy->getElementType();
      ResTy = SemaRef.Context.getExtVectorType(VectorType: ResTy, NumElts: VecTy->getNumElements());
    }

    if (!EltTy->isIntegerType()) {
      Diag(Loc: Arg->getBeginLoc(), DiagID: diag::err_builtin_invalid_arg_type)
          << 1 << /* scalar or vector of */ 5 << /* integer ty */ 1
          << /* no fp */ 0 << ArgTy;
      return true;
    }

    TheCall->setType(ResTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_select: {
    // Condition (arg 0) must be bool or vector<bool>; dispatch to the scalar
    // or vector flavor of the check accordingly.
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3))
      return true;
    if (CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: getASTContext().BoolTy, ArgIndex: 0))
      return true;
    QualType ArgTy = TheCall->getArg(Arg: 0)->getType();
    if (ArgTy->isBooleanType() && CheckBoolSelect(S: &SemaRef, TheCall))
      return true;
    auto *VTy = ArgTy->getAs<VectorType>();
    if (VTy && VTy->getElementType()->isBooleanType() &&
        CheckVectorSelect(S: &SemaRef, TheCall))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_saturate:
  case Builtin::BI__builtin_hlsl_elementwise_rcp: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;
    if (!TheCall->getArg(Arg: 0)
             ->getType()
             ->hasFloatingRepresentation()) // half or float or double
      return SemaRef.Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
                          DiagID: diag::err_builtin_invalid_arg_type)
             << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
             << /* fp */ 1 << TheCall->getArg(Arg: 0)->getType();
    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_degrees:
  case Builtin::BI__builtin_hlsl_elementwise_radians:
  case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
  case Builtin::BI__builtin_hlsl_elementwise_frac: {
    // One float/half (scalar or vector) argument.
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatOrHalfRepresentation))
      return true;
    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_isinf: {
    // Like the group above, but the result element type is bool.
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatOrHalfRepresentation))
      return true;
    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
      return true;
    SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().BoolTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_lerp: {
    // Three float/half args of identical type.
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatOrHalfRepresentation))
      return true;
    if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
      return true;
    if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_mad: {
    if (SemaRef.BuiltinElementwiseTernaryMath(
            TheCall, /*ArgTyRestr=*/
            Sema::EltwiseBuiltinArgTyRestriction::None))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_normalize: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatOrHalfRepresentation))
      return true;
    ExprResult A = TheCall->getArg(Arg: 0);
    QualType ArgTyA = A.get()->getType();
    // return type is the same as the input type
    TheCall->setType(ArgTyA);
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_sign: {
    // Input is any int/float scalar or vector; result element type is int.
    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatingOrIntRepresentation))
      return true;
    SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().IntTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_step: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatOrHalfRepresentation))
      return true;

    ExprResult A = TheCall->getArg(Arg: 0);
    QualType ArgTyA = A.get()->getType();
    // return type is the same as the input type
    TheCall->setType(ArgTyA);
    break;
  }
  case Builtin::BI__builtin_hlsl_wave_active_max:
  case Builtin::BI__builtin_hlsl_wave_active_sum: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;

    // Ensure input expr type is a scalar/vector and the same as the return type
    // (bool inputs are rejected by CheckWaveActive).
    if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
      return true;
    if (CheckWaveActive(S: &SemaRef, TheCall))
      return true;
    ExprResult Expr = TheCall->getArg(Arg: 0);
    QualType ArgTyExpr = Expr.get()->getType();
    TheCall->setType(ArgTyExpr);
    break;
  }
  // Note these are llvm builtins that we want to catch invalid intrinsic
  // generation. Normal handling of these builitns will occur elsewhere.
  case Builtin::BI__builtin_elementwise_bitreverse: {
    // does not include a check for number of arguments
    // because that is done previously
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckUnsignedIntRepresentation))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
      return true;

    // Ensure index parameter type can be interpreted as a uint
    ExprResult Index = TheCall->getArg(Arg: 1);
    QualType ArgTyIndex = Index.get()->getType();
    if (!ArgTyIndex->isIntegerType()) {
      SemaRef.Diag(Loc: TheCall->getArg(Arg: 1)->getBeginLoc(),
                   DiagID: diag::err_typecheck_convert_incompatible)
          << ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
      return true;
    }

    // Ensure input expr type is a scalar/vector and the same as the return type
    if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
      return true;

    ExprResult Expr = TheCall->getArg(Arg: 0);
    QualType ArgTyExpr = Expr.get()->getType();
    TheCall->setType(ArgTyExpr);
    break;
  }
  case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
    // Takes no arguments.
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 0))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
    // (double in, out uint low, out uint high); the two outputs must be
    // modifiable lvalues.
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3))
      return true;

    if (CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: SemaRef.Context.DoubleTy, ArgIndex: 0) ||
        CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: SemaRef.Context.UnsignedIntTy,
                            ArgIndex: 1) ||
        CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: SemaRef.Context.UnsignedIntTy,
                            ArgIndex: 2))
      return true;

    if (CheckModifiableLValue(S: &SemaRef, TheCall, ArgIndex: 1) ||
        CheckModifiableLValue(S: &SemaRef, TheCall, ArgIndex: 2))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_clip: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;

    if (CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: SemaRef.Context.FloatTy, ArgIndex: 0))
      return true;
    break;
  }
  case Builtin::BI__builtin_elementwise_acos:
  case Builtin::BI__builtin_elementwise_asin:
  case Builtin::BI__builtin_elementwise_atan:
  case Builtin::BI__builtin_elementwise_atan2:
  case Builtin::BI__builtin_elementwise_ceil:
  case Builtin::BI__builtin_elementwise_cos:
  case Builtin::BI__builtin_elementwise_cosh:
  case Builtin::BI__builtin_elementwise_exp:
  case Builtin::BI__builtin_elementwise_exp2:
  case Builtin::BI__builtin_elementwise_exp10:
  case Builtin::BI__builtin_elementwise_floor:
  case Builtin::BI__builtin_elementwise_fmod:
  case Builtin::BI__builtin_elementwise_log:
  case Builtin::BI__builtin_elementwise_log2:
  case Builtin::BI__builtin_elementwise_log10:
  case Builtin::BI__builtin_elementwise_pow:
  case Builtin::BI__builtin_elementwise_roundeven:
  case Builtin::BI__builtin_elementwise_sin:
  case Builtin::BI__builtin_elementwise_sinh:
  case Builtin::BI__builtin_elementwise_sqrt:
  case Builtin::BI__builtin_elementwise_tan:
  case Builtin::BI__builtin_elementwise_tanh:
  case Builtin::BI__builtin_elementwise_trunc: {
    // In HLSL these transcendental/rounding builtins only accept float/half
    // representations (no double).
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatOrHalfRepresentation))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_buffer_update_counter: {
    // Handle must be a UAV raw buffer with a contained type; the offset must
    // be the constant 1 or -1 (increment/decrement).
    auto checkResTy = [](const HLSLAttributedResourceType *ResTy) -> bool {
      return !(ResTy->getAttrs().ResourceClass == ResourceClass::UAV &&
               ResTy->getAttrs().RawBuffer && ResTy->hasContainedType());
    };
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2) ||
        CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0, Check: checkResTy) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 1),
                            ExpectedType: SemaRef.getASTContext().IntTy))
      return true;
    Expr *OffsetExpr = TheCall->getArg(Arg: 1);
    std::optional<llvm::APSInt> Offset =
        OffsetExpr->getIntegerConstantExpr(Ctx: SemaRef.getASTContext());
    if (!Offset.has_value() || std::abs(i: Offset->getExtValue()) != 1) {
      SemaRef.Diag(Loc: TheCall->getArg(Arg: 1)->getBeginLoc(),
                   DiagID: diag::err_hlsl_expect_arg_const_int_one_or_neg_one)
          << 1;
      return true;
    }
    break;
  }
  }
  return false;
}
3055
// Appends to List the scalar "leaf" types of BaseTy in declaration order:
// arrays are expanded element-by-element, vectors element-wise, and aggregate
// records field-by-field (bases included for non-standard-layout types).
// Unions and non-aggregate records are treated as opaque leaves. Uses an
// explicit work list (pushed in reverse so types pop in source order).
static void BuildFlattenedTypeList(QualType BaseTy,
                                   llvm::SmallVectorImpl<QualType> &List) {
  llvm::SmallVector<QualType, 16> WorkList;
  WorkList.push_back(Elt: BaseTy);
  while (!WorkList.empty()) {
    QualType T = WorkList.pop_back_val();
    T = T.getCanonicalType().getUnqualifiedType();
    assert(!isa<MatrixType>(T) && "Matrix types not yet supported in HLSL");
    if (const auto *AT = dyn_cast<ConstantArrayType>(Val&: T)) {
      llvm::SmallVector<QualType, 16> ElementFields;
      // Generally I've avoided recursion in this algorithm, but arrays of
      // structs could be time-consuming to flatten and churn through on the
      // work list. Hopefully nesting arrays of structs containing arrays
      // of structs too many levels deep is unlikely.
      BuildFlattenedTypeList(BaseTy: AT->getElementType(), List&: ElementFields);
      // Repeat the element's field list n times.
      for (uint64_t Ct = 0; Ct < AT->getZExtSize(); ++Ct)
        llvm::append_range(C&: List, R&: ElementFields);
      continue;
    }
    // Vectors can only have element types that are builtin types, so this can
    // add directly to the list instead of to the WorkList.
    if (const auto *VT = dyn_cast<VectorType>(Val&: T)) {
      List.insert(I: List.end(), NumToInsert: VT->getNumElements(), Elt: VT->getElementType());
      continue;
    }
    if (const auto *RT = dyn_cast<RecordType>(Val&: T)) {
      const CXXRecordDecl *RD = RT->getAsCXXRecordDecl();
      assert(RD && "HLSL record types should all be CXXRecordDecls!");

      // Standard-layout types keep all fields on a single base class.
      if (RD->isStandardLayout())
        RD = RD->getStandardLayoutBaseWithFields();

      // For types that we shouldn't decompose (unions and non-aggregates), just
      // add the type itself to the list.
      if (RD->isUnion() || !RD->isAggregate()) {
        List.push_back(Elt: T);
        continue;
      }

      llvm::SmallVector<QualType, 16> FieldTypes;
      for (const auto *FD : RD->fields())
        FieldTypes.push_back(Elt: FD->getType());
      // Reverse the newly added sub-range.
      // (The work list pops from the back, so reversing preserves
      // declaration order in the final flattened list.)
      std::reverse(first: FieldTypes.begin(), last: FieldTypes.end());
      llvm::append_range(C&: WorkList, R&: FieldTypes);

      // If this wasn't a standard layout type we may also have some base
      // classes to deal with.
      if (!RD->isStandardLayout()) {
        FieldTypes.clear();
        for (const auto &Base : RD->bases())
          FieldTypes.push_back(Elt: Base.getType());
        std::reverse(first: FieldTypes.begin(), last: FieldTypes.end());
        llvm::append_range(C&: WorkList, R&: FieldTypes);
      }
      continue;
    }
    // Scalar / builtin leaf type.
    List.push_back(Elt: T);
  }
}
3117
3118bool SemaHLSL::IsTypedResourceElementCompatible(clang::QualType QT) {
3119 // null and array types are not allowed.
3120 if (QT.isNull() || QT->isArrayType())
3121 return false;
3122
3123 // UDT types are not allowed
3124 if (QT->isRecordType())
3125 return false;
3126
3127 if (QT->isBooleanType() || QT->isEnumeralType())
3128 return false;
3129
3130 // the only other valid builtin types are scalars or vectors
3131 if (QT->isArithmeticType()) {
3132 if (SemaRef.Context.getTypeSize(T: QT) / 8 > 16)
3133 return false;
3134 return true;
3135 }
3136
3137 if (const VectorType *VT = QT->getAs<VectorType>()) {
3138 int ArraySize = VT->getNumElements();
3139
3140 if (ArraySize > 4)
3141 return false;
3142
3143 QualType ElTy = VT->getElementType();
3144 if (ElTy->isBooleanType())
3145 return false;
3146
3147 if (SemaRef.Context.getTypeSize(T: QT) / 8 > 16)
3148 return false;
3149 return true;
3150 }
3151
3152 return false;
3153}
3154
3155bool SemaHLSL::IsScalarizedLayoutCompatible(QualType T1, QualType T2) const {
3156 if (T1.isNull() || T2.isNull())
3157 return false;
3158
3159 T1 = T1.getCanonicalType().getUnqualifiedType();
3160 T2 = T2.getCanonicalType().getUnqualifiedType();
3161
3162 // If both types are the same canonical type, they're obviously compatible.
3163 if (SemaRef.getASTContext().hasSameType(T1, T2))
3164 return true;
3165
3166 llvm::SmallVector<QualType, 16> T1Types;
3167 BuildFlattenedTypeList(BaseTy: T1, List&: T1Types);
3168 llvm::SmallVector<QualType, 16> T2Types;
3169 BuildFlattenedTypeList(BaseTy: T2, List&: T2Types);
3170
3171 // Check the flattened type list
3172 return llvm::equal(LRange&: T1Types, RRange&: T2Types,
3173 P: [this](QualType LHS, QualType RHS) -> bool {
3174 return SemaRef.IsLayoutCompatible(T1: LHS, T2: RHS);
3175 });
3176}
3177
3178bool SemaHLSL::CheckCompatibleParameterABI(FunctionDecl *New,
3179 FunctionDecl *Old) {
3180 if (New->getNumParams() != Old->getNumParams())
3181 return true;
3182
3183 bool HadError = false;
3184
3185 for (unsigned i = 0, e = New->getNumParams(); i != e; ++i) {
3186 ParmVarDecl *NewParam = New->getParamDecl(i);
3187 ParmVarDecl *OldParam = Old->getParamDecl(i);
3188
3189 // HLSL parameter declarations for inout and out must match between
3190 // declarations. In HLSL inout and out are ambiguous at the call site,
3191 // but have different calling behavior, so you cannot overload a
3192 // method based on a difference between inout and out annotations.
3193 const auto *NDAttr = NewParam->getAttr<HLSLParamModifierAttr>();
3194 unsigned NSpellingIdx = (NDAttr ? NDAttr->getSpellingListIndex() : 0);
3195 const auto *ODAttr = OldParam->getAttr<HLSLParamModifierAttr>();
3196 unsigned OSpellingIdx = (ODAttr ? ODAttr->getSpellingListIndex() : 0);
3197
3198 if (NSpellingIdx != OSpellingIdx) {
3199 SemaRef.Diag(Loc: NewParam->getLocation(),
3200 DiagID: diag::err_hlsl_param_qualifier_mismatch)
3201 << NDAttr << NewParam;
3202 SemaRef.Diag(Loc: OldParam->getLocation(), DiagID: diag::note_previous_declaration_as)
3203 << ODAttr;
3204 HadError = true;
3205 }
3206 }
3207 return HadError;
3208}
3209
3210// Generally follows PerformScalarCast, with cases reordered for
3211// clarity of what types are supported
3212bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) {
3213
3214 if (!SrcTy->isScalarType() || !DestTy->isScalarType())
3215 return false;
3216
3217 if (SemaRef.getASTContext().hasSameUnqualifiedType(T1: SrcTy, T2: DestTy))
3218 return true;
3219
3220 switch (SrcTy->getScalarTypeKind()) {
3221 case Type::STK_Bool: // casting from bool is like casting from an integer
3222 case Type::STK_Integral:
3223 switch (DestTy->getScalarTypeKind()) {
3224 case Type::STK_Bool:
3225 case Type::STK_Integral:
3226 case Type::STK_Floating:
3227 return true;
3228 case Type::STK_CPointer:
3229 case Type::STK_ObjCObjectPointer:
3230 case Type::STK_BlockPointer:
3231 case Type::STK_MemberPointer:
3232 llvm_unreachable("HLSL doesn't support pointers.");
3233 case Type::STK_IntegralComplex:
3234 case Type::STK_FloatingComplex:
3235 llvm_unreachable("HLSL doesn't support complex types.");
3236 case Type::STK_FixedPoint:
3237 llvm_unreachable("HLSL doesn't support fixed point types.");
3238 }
3239 llvm_unreachable("Should have returned before this");
3240
3241 case Type::STK_Floating:
3242 switch (DestTy->getScalarTypeKind()) {
3243 case Type::STK_Floating:
3244 case Type::STK_Bool:
3245 case Type::STK_Integral:
3246 return true;
3247 case Type::STK_FloatingComplex:
3248 case Type::STK_IntegralComplex:
3249 llvm_unreachable("HLSL doesn't support complex types.");
3250 case Type::STK_FixedPoint:
3251 llvm_unreachable("HLSL doesn't support fixed point types.");
3252 case Type::STK_CPointer:
3253 case Type::STK_ObjCObjectPointer:
3254 case Type::STK_BlockPointer:
3255 case Type::STK_MemberPointer:
3256 llvm_unreachable("HLSL doesn't support pointers.");
3257 }
3258 llvm_unreachable("Should have returned before this");
3259
3260 case Type::STK_MemberPointer:
3261 case Type::STK_CPointer:
3262 case Type::STK_BlockPointer:
3263 case Type::STK_ObjCObjectPointer:
3264 llvm_unreachable("HLSL doesn't support pointers.");
3265
3266 case Type::STK_FixedPoint:
3267 llvm_unreachable("HLSL doesn't support fixed point types.");
3268
3269 case Type::STK_FloatingComplex:
3270 case Type::STK_IntegralComplex:
3271 llvm_unreachable("HLSL doesn't support complex types.");
3272 }
3273
3274 llvm_unreachable("Unhandled scalar cast");
3275}
3276
3277// Detect if a type contains a bitfield. Will be removed when
3278// bitfield support is added to HLSLElementwiseCast and HLSLAggregateSplatCast
3279bool SemaHLSL::ContainsBitField(QualType BaseTy) {
3280 llvm::SmallVector<QualType, 16> WorkList;
3281 WorkList.push_back(Elt: BaseTy);
3282 while (!WorkList.empty()) {
3283 QualType T = WorkList.pop_back_val();
3284 T = T.getCanonicalType().getUnqualifiedType();
3285 // only check aggregate types
3286 if (const auto *AT = dyn_cast<ConstantArrayType>(Val&: T)) {
3287 WorkList.push_back(Elt: AT->getElementType());
3288 continue;
3289 }
3290 if (const auto *RT = dyn_cast<RecordType>(Val&: T)) {
3291 const RecordDecl *RD = RT->getDecl();
3292 if (RD->isUnion())
3293 continue;
3294
3295 const CXXRecordDecl *CXXD = dyn_cast<CXXRecordDecl>(Val: RD);
3296
3297 if (CXXD && CXXD->isStandardLayout())
3298 RD = CXXD->getStandardLayoutBaseWithFields();
3299
3300 for (const auto *FD : RD->fields()) {
3301 if (FD->isBitField())
3302 return true;
3303 WorkList.push_back(Elt: FD->getType());
3304 }
3305 continue;
3306 }
3307 }
3308 return false;
3309}
3310
3311// Can perform an HLSL Aggregate splat cast if the Dest is an aggregate and the
3312// Src is a scalar or a vector of length 1
3313// Or if Dest is a vector and Src is a vector of length 1
3314bool SemaHLSL::CanPerformAggregateSplatCast(Expr *Src, QualType DestTy) {
3315
3316 QualType SrcTy = Src->getType();
3317 // Not a valid HLSL Aggregate Splat cast if Dest is a scalar or if this is
3318 // going to be a vector splat from a scalar.
3319 if ((SrcTy->isScalarType() && DestTy->isVectorType()) ||
3320 DestTy->isScalarType())
3321 return false;
3322
3323 const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
3324
3325 // Src isn't a scalar or a vector of length 1
3326 if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1))
3327 return false;
3328
3329 if (SrcVecTy)
3330 SrcTy = SrcVecTy->getElementType();
3331
3332 if (ContainsBitField(BaseTy: DestTy))
3333 return false;
3334
3335 llvm::SmallVector<QualType> DestTypes;
3336 BuildFlattenedTypeList(BaseTy: DestTy, List&: DestTypes);
3337
3338 for (unsigned I = 0, Size = DestTypes.size(); I < Size; ++I) {
3339 if (DestTypes[I]->isUnionType())
3340 return false;
3341 if (!CanPerformScalarCast(SrcTy, DestTy: DestTypes[I]))
3342 return false;
3343 }
3344 return true;
3345}
3346
3347// Can we perform an HLSL Elementwise cast?
3348// TODO: update this code when matrices are added; see issue #88060
3349bool SemaHLSL::CanPerformElementwiseCast(Expr *Src, QualType DestTy) {
3350
3351 // Don't handle casts where LHS and RHS are any combination of scalar/vector
3352 // There must be an aggregate somewhere
3353 QualType SrcTy = Src->getType();
3354 if (SrcTy->isScalarType()) // always a splat and this cast doesn't handle that
3355 return false;
3356
3357 if (SrcTy->isVectorType() &&
3358 (DestTy->isScalarType() || DestTy->isVectorType()))
3359 return false;
3360
3361 if (ContainsBitField(BaseTy: DestTy) || ContainsBitField(BaseTy: SrcTy))
3362 return false;
3363
3364 llvm::SmallVector<QualType> DestTypes;
3365 BuildFlattenedTypeList(BaseTy: DestTy, List&: DestTypes);
3366 llvm::SmallVector<QualType> SrcTypes;
3367 BuildFlattenedTypeList(BaseTy: SrcTy, List&: SrcTypes);
3368
3369 // Usually the size of SrcTypes must be greater than or equal to the size of
3370 // DestTypes.
3371 if (SrcTypes.size() < DestTypes.size())
3372 return false;
3373
3374 unsigned SrcSize = SrcTypes.size();
3375 unsigned DstSize = DestTypes.size();
3376 unsigned I;
3377 for (I = 0; I < DstSize && I < SrcSize; I++) {
3378 if (SrcTypes[I]->isUnionType() || DestTypes[I]->isUnionType())
3379 return false;
3380 if (!CanPerformScalarCast(SrcTy: SrcTypes[I], DestTy: DestTypes[I])) {
3381 return false;
3382 }
3383 }
3384
3385 // check the rest of the source type for unions.
3386 for (; I < SrcSize; I++) {
3387 if (SrcTypes[I]->isUnionType())
3388 return false;
3389 }
3390 return true;
3391}
3392
// Rewrites an argument bound to an `out`/`inout` parameter into an
// HLSLOutArgExpr that models both the copy-in conversion and the copy-back
// writeback. Ordinary (`in`) parameters pass through unchanged. Returns
// ExprError() after diagnosing non-lvalue arguments or disallowed
// scalar<->aggregate extensions.
ExprResult SemaHLSL::ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg) {
  assert(Param->hasAttr<HLSLParamModifierAttr>() &&
         "We should not get here without a parameter modifier expression");
  const auto *Attr = Param->getAttr<HLSLParamModifierAttr>();
  // Ordinary ABI means no out/inout semantics; use the argument as-is.
  if (Attr->getABI() == ParameterABI::Ordinary)
    return ExprResult(Arg);

  bool IsInOut = Attr->getABI() == ParameterABI::HLSLInOut;
  // The writeback needs a location to store into, so the argument must be an
  // lvalue.
  if (!Arg->isLValue()) {
    SemaRef.Diag(Loc: Arg->getBeginLoc(), DiagID: diag::error_hlsl_inout_lvalue)
        << Arg << (IsInOut ? 1 : 0);
    return ExprError();
  }

  ASTContext &Ctx = SemaRef.getASTContext();

  QualType Ty = Param->getType().getNonLValueExprType(Context: Ctx);

  // HLSL allows implicit conversions from scalars to vectors, but not the
  // inverse, so we need to disallow `inout` with scalar->vector or
  // scalar->matrix conversions.
  if (Arg->getType()->isScalarType() != Ty->isScalarType()) {
    SemaRef.Diag(Loc: Arg->getBeginLoc(), DiagID: diag::error_hlsl_inout_scalar_extension)
        << Arg << (IsInOut ? 1 : 0);
    return ExprError();
  }

  // Bind the argument to an opaque value so it is evaluated exactly once; it
  // is referenced both by the copy-in initialization and by the writeback
  // assignment built below.
  auto *ArgOpV = new (Ctx) OpaqueValueExpr(Param->getBeginLoc(), Arg->getType(),
                                           VK_LValue, OK_Ordinary, Arg);

  // Parameters are initialized via copy initialization. This allows for
  // overload resolution of argument constructors.
  InitializedEntity Entity =
      InitializedEntity::InitializeParameter(Context&: Ctx, Type: Ty, Consumed: false);
  ExprResult Res =
      SemaRef.PerformCopyInitialization(Entity, EqualLoc: Param->getBeginLoc(), Init: ArgOpV);
  if (Res.isInvalid())
    return ExprError();
  Expr *Base = Res.get();
  // After the cast, drop the reference type when creating the exprs.
  Ty = Ty.getNonLValueExprType(Context: Ctx);
  // Second opaque value: the converted "parameter" value that the writeback
  // assigns back into the original argument.
  auto *OpV = new (Ctx)
      OpaqueValueExpr(Param->getBeginLoc(), Ty, VK_LValue, OK_Ordinary, Base);

  // Writebacks are performed with `=` binary operator, which allows for
  // overload resolution on writeback result expressions.
  Res = SemaRef.ActOnBinOp(S: SemaRef.getCurScope(), TokLoc: Param->getBeginLoc(),
                           Kind: tok::equal, LHSExpr: ArgOpV, RHSExpr: OpV);

  if (Res.isInvalid())
    return ExprError();
  Expr *Writeback = Res.get();
  auto *OutExpr =
      HLSLOutArgExpr::Create(C: Ctx, Ty, Base: ArgOpV, OpV, WB: Writeback, IsInOut);

  return ExprResult(OutExpr);
}
3450
3451QualType SemaHLSL::getInoutParameterType(QualType Ty) {
3452 // If HLSL gains support for references, all the cites that use this will need
3453 // to be updated with semantic checking to produce errors for
3454 // pointers/references.
3455 assert(!Ty->isReferenceType() &&
3456 "Pointer and reference types cannot be inout or out parameters");
3457 Ty = SemaRef.getASTContext().getLValueReferenceType(T: Ty);
3458 Ty.addRestrict();
3459 return Ty;
3460}
3461
3462static bool IsDefaultBufferConstantDecl(VarDecl *VD) {
3463 QualType QT = VD->getType();
3464 return VD->getDeclContext()->isTranslationUnit() &&
3465 QT.getAddressSpace() == LangAS::Default &&
3466 VD->getStorageClass() != SC_Static &&
3467 !VD->hasAttr<HLSLVkConstantIdAttr>() &&
3468 !isInvalidConstantBufferLeafElementType(Ty: QT.getTypePtr());
3469}
3470
3471void SemaHLSL::deduceAddressSpace(VarDecl *Decl) {
3472 // The variable already has an address space (groupshared for ex).
3473 if (Decl->getType().hasAddressSpace())
3474 return;
3475
3476 if (Decl->getType()->isDependentType())
3477 return;
3478
3479 QualType Type = Decl->getType();
3480
3481 if (Decl->hasAttr<HLSLVkExtBuiltinInputAttr>()) {
3482 LangAS ImplAS = LangAS::hlsl_input;
3483 Type = SemaRef.getASTContext().getAddrSpaceQualType(T: Type, AddressSpace: ImplAS);
3484 Decl->setType(Type);
3485 return;
3486 }
3487
3488 if (Type->isSamplerT() || Type->isVoidType())
3489 return;
3490
3491 // Resource handles.
3492 if (isResourceRecordTypeOrArrayOf(Ty: Type->getUnqualifiedDesugaredType()))
3493 return;
3494
3495 // Only static globals belong to the Private address space.
3496 // Non-static globals belongs to the cbuffer.
3497 if (Decl->getStorageClass() != SC_Static && !Decl->isStaticDataMember())
3498 return;
3499
3500 LangAS ImplAS = LangAS::hlsl_private;
3501 Type = SemaRef.getASTContext().getAddrSpaceQualType(T: Type, AddressSpace: ImplAS);
3502 Decl->setType(Type);
3503}
3504
// HLSL-specific processing for a variable declarator: requires a complete
// type, assigns default-$Globals membership, collects resource bindings,
// demotes resource variables to static storage, processes explicit
// register() bindings, and finally deduces the implicit address space.
void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
  if (VD->hasGlobalStorage()) {
    // make sure the declaration has a complete type
    if (SemaRef.RequireCompleteType(
            Loc: VD->getLocation(),
            T: SemaRef.getASTContext().getBaseElementType(QT: VD->getType()),
            DiagID: diag::err_typecheck_decl_incomplete_type)) {
      VD->setInvalidDecl();
      // Still deduce the address space so later passes see a consistent type
      // even on the error path.
      deduceAddressSpace(Decl: VD);
      return;
    }

    // Global variables outside a cbuffer block that are not a resource, static,
    // groupshared, or an empty array or struct belong to the default constant
    // buffer $Globals (to be created at the end of the translation unit).
    if (IsDefaultBufferConstantDecl(VD)) {
      // update address space to hlsl_constant
      QualType NewTy = getASTContext().getAddrSpaceQualType(
          T: VD->getType(), AddressSpace: LangAS::hlsl_constant);
      VD->setType(NewTy);
      DefaultCBufferDecls.push_back(Elt: VD);
    }

    // find all resources bindings on decl
    if (VD->getType()->isHLSLIntangibleType())
      collectResourceBindingsOnVarDecl(D: VD);

    // Strip array dimensions to find the underlying element type.
    const Type *VarType = VD->getType().getTypePtr();
    while (VarType->isArrayType())
      VarType = VarType->getArrayElementTypeNoTypeQual();
    if (VarType->isHLSLResourceRecord() ||
        VD->hasAttr<HLSLVkConstantIdAttr>()) {
      // Make the variable for resources static. The global externally visible
      // storage is accessed through the handle, which is a member. The variable
      // itself is not externally visible.
      VD->setStorageClass(StorageClass::SC_Static);
    }

    // process explicit bindings
    processExplicitBindingsOnDecl(D: VD);
  }

  deduceAddressSpace(Decl: VD);
}
3549
3550static bool initVarDeclWithCtor(Sema &S, VarDecl *VD,
3551 MutableArrayRef<Expr *> Args) {
3552 InitializedEntity Entity = InitializedEntity::InitializeVariable(Var: VD);
3553 InitializationKind Kind = InitializationKind::CreateDirect(
3554 InitLoc: VD->getLocation(), LParenLoc: SourceLocation(), RParenLoc: SourceLocation());
3555
3556 InitializationSequence InitSeq(S, Entity, Kind, Args);
3557 if (InitSeq.Failed())
3558 return false;
3559
3560 ExprResult Init = InitSeq.Perform(S, Entity, Kind, Args);
3561 if (!Init.get())
3562 return false;
3563
3564 VD->setInit(S.MaybeCreateExprWithCleanups(SubExpr: Init.get()));
3565 VD->setInitStyle(VarDecl::CallInit);
3566 S.CheckCompleteVariableDeclaration(VD);
3567 return true;
3568}
3569
// Synthesizes the constructor call that initializes a global resource
// variable. Explicitly bound resources use the
// (register_slot, space, range_size, index, name) constructor; implicitly
// bound ones use (space, range_size, index, order_id, name), where order_id
// is a sequential ID consumed later to assign bindings deterministically.
// Returns the result of initVarDeclWithCtor.
bool SemaHLSL::initGlobalResourceDecl(VarDecl *VD) {
  std::optional<uint32_t> RegisterSlot;
  uint32_t SpaceNo = 0;
  HLSLResourceBindingAttr *RBA = VD->getAttr<HLSLResourceBindingAttr>();
  if (RBA) {
    if (RBA->hasRegisterSlot())
      RegisterSlot = RBA->getSlotNumber();
    SpaceNo = RBA->getSpaceNumber();
  }

  ASTContext &AST = SemaRef.getASTContext();
  uint64_t UIntTySize = AST.getTypeSize(T: AST.UnsignedIntTy);
  uint64_t IntTySize = AST.getTypeSize(T: AST.IntTy);
  // A single (non-array) resource: range of size 1, indexed at 0. RangeSize
  // is signed (IntTy); the other literals are unsigned.
  IntegerLiteral *RangeSize = IntegerLiteral::Create(
      C: AST, V: llvm::APInt(IntTySize, 1), type: AST.IntTy, l: SourceLocation());
  IntegerLiteral *Index = IntegerLiteral::Create(
      C: AST, V: llvm::APInt(UIntTySize, 0), type: AST.UnsignedIntTy, l: SourceLocation());
  IntegerLiteral *Space =
      IntegerLiteral::Create(C: AST, V: llvm::APInt(UIntTySize, SpaceNo),
                             type: AST.UnsignedIntTy, l: SourceLocation());
  // The variable's own name is passed to the constructor as a string literal.
  StringRef VarName = VD->getName();
  StringLiteral *Name = StringLiteral::Create(
      Ctx: AST, Str: VarName, Kind: StringLiteralKind::Ordinary, Pascal: false,
      Ty: AST.getStringLiteralArrayType(EltTy: AST.CharTy.withConst(), Length: VarName.size()),
      Locs: SourceLocation());

  // resource with explicit binding
  if (RegisterSlot.has_value()) {
    IntegerLiteral *RegSlot = IntegerLiteral::Create(
        C: AST, V: llvm::APInt(UIntTySize, RegisterSlot.value()), type: AST.UnsignedIntTy,
        l: SourceLocation());
    Expr *Args[] = {RegSlot, Space, RangeSize, Index, Name};
    return initVarDeclWithCtor(S&: SemaRef, VD, Args);
  }

  // resource with implicit binding
  IntegerLiteral *OrderId = IntegerLiteral::Create(
      C: AST, V: llvm::APInt(UIntTySize, getNextImplicitBindingOrderID()),
      type: AST.UnsignedIntTy, l: SourceLocation());
  Expr *Args[] = {Space, RangeSize, Index, OrderId, Name};
  return initVarDeclWithCtor(S&: SemaRef, VD, Args);
}
3612
3613// Returns true if the initialization has been handled.
3614// Returns false to use default initialization.
3615bool SemaHLSL::ActOnUninitializedVarDecl(VarDecl *VD) {
3616 // Objects in the hlsl_constant address space are initialized
3617 // externally, so don't synthesize an implicit initializer.
3618 if (VD->getType().getAddressSpace() == LangAS::hlsl_constant)
3619 return true;
3620
3621 // Initialize resources
3622 if (!isResourceRecordTypeOrArrayOf(VD))
3623 return false;
3624
3625 // FIXME: We currectly support only simple resources - no arrays of resources
3626 // or resources in user defined structs.
3627 // (llvm/llvm-project#133835, llvm/llvm-project#133837)
3628 // Initialize resources at the global scope
3629 if (VD->hasGlobalStorage() && VD->getType()->isHLSLResourceRecord())
3630 return initGlobalResourceDecl(VD);
3631
3632 return false;
3633}
3634
// Walks though the global variable declaration, collects all resource binding
// requirements and adds them to Bindings
void SemaHLSL::collectResourceBindingsOnVarDecl(VarDecl *VD) {
  assert(VD->hasGlobalStorage() && VD->getType()->isHLSLIntangibleType() &&
         "expected global variable that contains HLSL resource");

  // Cbuffers and Tbuffers are HLSLBufferDecl types
  // cbuffers bind as CBuffer; tbuffers bind as SRV.
  if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(Val: VD)) {
    Bindings.addDeclBindingInfo(VD, ResClass: CBufferOrTBuffer->isCBuffer()
                                        ? ResourceClass::CBuffer
                                        : ResourceClass::SRV);
    return;
  }

  // Unwrap arrays
  // FIXME: Calculate array size while unwrapping
  const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
  while (Ty->isConstantArrayType()) {
    const ConstantArrayType *CAT = cast<ConstantArrayType>(Val: Ty);
    Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
  }

  // Resource (or array of resources): the binding class comes from the
  // resource's handle type attributes.
  if (const HLSLAttributedResourceType *AttrResType =
          HLSLAttributedResourceType::findHandleTypeOnResource(RT: Ty)) {
    Bindings.addDeclBindingInfo(VD, ResClass: AttrResType->getAttrs().ResourceClass);
    return;
  }

  // User defined record type: recurse through its fields.
  if (const RecordType *RT = dyn_cast<RecordType>(Val: Ty))
    collectResourceBindingsOnUserRecordDecl(VD, RT);
}
3668
// Walks though the explicit resource binding attributes on the declaration,
// and makes sure there is a resource that matched the binding and updates
// DeclBindingInfoLists
void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) {
  assert(VD->hasGlobalStorage() && "expected global variable");

  bool HasBinding = false;
  for (Attr *A : VD->attrs()) {
    HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(Val: A);
    // Only explicit register() bindings (with a slot number) are processed
    // here; space-only bindings are left for implicit binding assignment.
    if (!RBA || !RBA->hasRegisterSlot())
      continue;
    HasBinding = true;

    RegisterType RT = RBA->getRegisterType();
    assert(RT != RegisterType::I && "invalid or obsolete register type should "
                                    "never have an attribute created");

    // `c` registers bind constants; warn if the declaration already has
    // resource binding info (it is not a plain constant).
    if (RT == RegisterType::C) {
      if (Bindings.hasBindingInfoForDecl(VD))
        SemaRef.Diag(Loc: VD->getLocation(),
                     DiagID: diag::warn_hlsl_user_defined_type_missing_member)
            << static_cast<int>(RT);
      continue;
    }

    // Find DeclBindingInfo for this binding and update it, or report error
    // if it does not exist (user type does not contain resources with the
    // expected resource class).
    ResourceClass RC = getResourceClass(RT);
    if (DeclBindingInfo *BI = Bindings.getDeclBindingInfo(VD, ResClass: RC)) {
      // update binding info
      BI->setBindingAttribute(A: RBA, BT: BindingType::Explicit);
    } else {
      SemaRef.Diag(Loc: VD->getLocation(),
                   DiagID: diag::warn_hlsl_user_defined_type_missing_member)
          << static_cast<int>(RT);
    }
  }

  // A resource with no explicit binding at all will be bound implicitly; warn
  // so the user knows.
  if (!HasBinding && isResourceRecordTypeOrArrayOf(VD))
    SemaRef.Diag(Loc: VD->getLocation(), DiagID: diag::warn_hlsl_implicit_binding);
}
namespace {
// Transforms HLSL initializer lists in two phases: buildInitializerList
// flattens arbitrarily nested initializers into a flat sequence of leaf
// expressions converted to the destination's flattened leaf types, and
// generateInitLists regroups that flat sequence into initializer lists
// shaped exactly like the destination type.
class InitListTransformer {
  Sema &S;
  ASTContext &Ctx;
  // The (non-reference) destination type; for incomplete arrays this is the
  // element type (see the constructor).
  QualType InitTy;
  // Cursor over DestTypes while converting flattened initializers.
  QualType *DstIt = nullptr;
  // Cursor over ArgExprs while regenerating initializer lists.
  Expr **ArgIt = nullptr;
  // Is wrapping the destination type iterator required? This is only used for
  // incomplete array types where we loop over the destination type since we
  // don't know the full number of elements from the declaration.
  bool Wrap;

  // Copy-initializes the current destination leaf type from E and appends
  // the converted expression to ArgExprs. Returns false only when the
  // conversion itself is invalid.
  bool castInitializer(Expr *E) {
    assert(DstIt && "This should always be something!");
    if (DstIt == DestTypes.end()) {
      if (!Wrap) {
        ArgExprs.push_back(Elt: E);
        // This is odd, but it isn't technically a failure due to conversion, we
        // handle mismatched counts of arguments differently.
        return true;
      }
      // Incomplete array: wrap around and start a new element's leaf types.
      DstIt = DestTypes.begin();
    }
    InitializedEntity Entity = InitializedEntity::InitializeParameter(
        Context&: Ctx, Type: *DstIt, /* Consumed (ObjC) */ Consumed: false);
    ExprResult Res = S.PerformCopyInitialization(Entity, EqualLoc: E->getBeginLoc(), Init: E);
    if (Res.isInvalid())
      return false;
    Expr *Init = Res.get();
    ArgExprs.push_back(Elt: Init);
    DstIt++;
    return true;
  }

  // Recursively decomposes E (nested init lists, vectors, constant arrays,
  // records and their bases) into leaf expressions fed to castInitializer.
  bool buildInitializerListImpl(Expr *E) {
    // If this is an initialization list, traverse the sub initializers.
    if (auto *Init = dyn_cast<InitListExpr>(Val: E)) {
      for (auto *SubInit : Init->inits())
        if (!buildInitializerListImpl(E: SubInit))
          return false;
      return true;
    }

    // If this is a scalar type, just enqueue the expression.
    QualType Ty = E->getType();

    // Non-aggregate records (e.g. intangible/resource types) are leaves too.
    if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()))
      return castInitializer(E);

    if (auto *VecTy = Ty->getAs<VectorType>()) {
      uint64_t Size = VecTy->getNumElements();

      QualType SizeTy = Ctx.getSizeType();
      uint64_t SizeTySize = Ctx.getTypeSize(T: SizeTy);
      // Emit one subscript per vector element; vector elements are always
      // scalar leaves.
      for (uint64_t I = 0; I < Size; ++I) {
        auto *Idx = IntegerLiteral::Create(C: Ctx, V: llvm::APInt(SizeTySize, I),
                                           type: SizeTy, l: SourceLocation());

        ExprResult ElExpr = S.CreateBuiltinArraySubscriptExpr(
            Base: E, LLoc: E->getBeginLoc(), Idx, RLoc: E->getEndLoc());
        if (ElExpr.isInvalid())
          return false;
        if (!castInitializer(E: ElExpr.get()))
          return false;
      }
      return true;
    }

    if (auto *ArrTy = dyn_cast<ConstantArrayType>(Val: Ty.getTypePtr())) {
      uint64_t Size = ArrTy->getZExtSize();
      QualType SizeTy = Ctx.getSizeType();
      uint64_t SizeTySize = Ctx.getTypeSize(T: SizeTy);
      // Recurse into each array element; elements may themselves be
      // aggregates.
      for (uint64_t I = 0; I < Size; ++I) {
        auto *Idx = IntegerLiteral::Create(C: Ctx, V: llvm::APInt(SizeTySize, I),
                                           type: SizeTy, l: SourceLocation());
        ExprResult ElExpr = S.CreateBuiltinArraySubscriptExpr(
            Base: E, LLoc: E->getBeginLoc(), Idx, RLoc: E->getEndLoc());
        if (ElExpr.isInvalid())
          return false;
        if (!buildInitializerListImpl(E: ElExpr.get()))
          return false;
      }
      return true;
    }

    if (auto *RTy = Ty->getAs<RecordType>()) {
      // Collect the record and its (single-inheritance) base chain so fields
      // are visited base-first, matching layout order.
      llvm::SmallVector<const RecordType *> RecordTypes;
      RecordTypes.push_back(Elt: RTy);
      while (RecordTypes.back()->getAsCXXRecordDecl()->getNumBases()) {
        CXXRecordDecl *D = RecordTypes.back()->getAsCXXRecordDecl();
        assert(D->getNumBases() == 1 &&
               "HLSL doesn't support multiple inheritance");
        RecordTypes.push_back(Elt: D->bases_begin()->getType()->getAs<RecordType>());
      }
      while (!RecordTypes.empty()) {
        const RecordType *RT = RecordTypes.pop_back_val();
        for (auto *FD : RT->getDecl()->fields()) {
          DeclAccessPair Found = DeclAccessPair::make(D: FD, AS: FD->getAccess());
          DeclarationNameInfo NameInfo(FD->getDeclName(), E->getBeginLoc());
          ExprResult Res = S.BuildFieldReferenceExpr(
              BaseExpr: E, IsArrow: false, OpLoc: E->getBeginLoc(), SS: CXXScopeSpec(), Field: FD, FoundDecl: Found, MemberNameInfo: NameInfo);
          if (Res.isInvalid())
            return false;
          if (!buildInitializerListImpl(E: Res.get()))
            return false;
        }
      }
    }
    return true;
  }

  // Consumes expressions from ArgIt and rebuilds an initializer tree shaped
  // like Ty: leaves are returned directly, aggregates become InitListExprs.
  Expr *generateInitListsImpl(QualType Ty) {
    assert(ArgIt != ArgExprs.end() && "Something is off in iteration!");
    if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()))
      return *(ArgIt++);

    llvm::SmallVector<Expr *> Inits;
    assert(!isa<MatrixType>(Ty) && "Matrix types not yet supported in HLSL");
    Ty = Ty.getDesugaredType(Context: Ctx);
    if (Ty->isVectorType() || Ty->isConstantArrayType()) {
      QualType ElTy;
      uint64_t Size = 0;
      if (auto *ATy = Ty->getAs<VectorType>()) {
        ElTy = ATy->getElementType();
        Size = ATy->getNumElements();
      } else {
        auto *VTy = cast<ConstantArrayType>(Val: Ty.getTypePtr());
        ElTy = VTy->getElementType();
        Size = VTy->getZExtSize();
      }
      for (uint64_t I = 0; I < Size; ++I)
        Inits.push_back(Elt: generateInitListsImpl(Ty: ElTy));
    }
    if (auto *RTy = Ty->getAs<RecordType>()) {
      // Base-first traversal, mirroring buildInitializerListImpl.
      llvm::SmallVector<const RecordType *> RecordTypes;
      RecordTypes.push_back(Elt: RTy);
      while (RecordTypes.back()->getAsCXXRecordDecl()->getNumBases()) {
        CXXRecordDecl *D = RecordTypes.back()->getAsCXXRecordDecl();
        assert(D->getNumBases() == 1 &&
               "HLSL doesn't support multiple inheritance");
        RecordTypes.push_back(Elt: D->bases_begin()->getType()->getAs<RecordType>());
      }
      while (!RecordTypes.empty()) {
        const RecordType *RT = RecordTypes.pop_back_val();
        for (auto *FD : RT->getDecl()->fields()) {
          Inits.push_back(Elt: generateInitListsImpl(Ty: FD->getType()));
        }
      }
    }
    // NOTE(review): assumes every aggregate reaching this point flattens to
    // at least one leaf, so Inits is non-empty for front()/back() below —
    // confirm empty structs/zero-sized arrays cannot reach here.
    auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
                                           Inits, Inits.back()->getEndLoc());
    NewInit->setType(Ty);
    return NewInit;
  }

public:
  // Flattened leaf types of the destination (element type for incomplete
  // arrays).
  llvm::SmallVector<QualType, 16> DestTypes;
  // Flattened, converted leaf initializer expressions.
  llvm::SmallVector<Expr *, 16> ArgExprs;
  InitListTransformer(Sema &SemaRef, const InitializedEntity &Entity)
      : S(SemaRef), Ctx(SemaRef.getASTContext()),
        Wrap(Entity.getType()->isIncompleteArrayType()) {
    InitTy = Entity.getType().getNonReferenceType();
    // When we're generating initializer lists for incomplete array types we
    // need to wrap around both when building the initializers and when
    // generating the final initializer lists.
    if (Wrap) {
      assert(InitTy->isIncompleteArrayType());
      const IncompleteArrayType *IAT = Ctx.getAsIncompleteArrayType(T: InitTy);
      InitTy = IAT->getElementType();
    }
    BuildFlattenedTypeList(BaseTy: InitTy, List&: DestTypes);
    DstIt = DestTypes.begin();
  }

  // Phase 1: flatten E's initializers into ArgExprs (may be called once per
  // top-level initializer).
  bool buildInitializerList(Expr *E) { return buildInitializerListImpl(E); }

  // Phase 2: rebuild the destination-shaped initializer list(s). For
  // incomplete arrays, one element list is generated per wrap of the
  // flattened arguments and the result is typed as a constant array.
  Expr *generateInitLists() {
    assert(!ArgExprs.empty() &&
           "Call buildInitializerList to generate argument expressions.");
    ArgIt = ArgExprs.begin();
    if (!Wrap)
      return generateInitListsImpl(Ty: InitTy);
    llvm::SmallVector<Expr *> Inits;
    while (ArgIt != ArgExprs.end())
      Inits.push_back(Elt: generateInitListsImpl(Ty: InitTy));

    auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
                                           Inits, Inits.back()->getEndLoc());
    llvm::APInt ArySize(64, Inits.size());
    NewInit->setType(Ctx.getConstantArrayType(EltTy: InitTy, ArySize, SizeExpr: nullptr,
                                              ASM: ArraySizeModifier::Normal, IndexTypeQuals: 0));
    return NewInit;
  }
};
} // namespace
3906
// Flattens and re-shapes an HLSL initializer list so it matches the layout
// of the destination type. Returns false (after diagnosing) when the count
// of flattened initializers does not match the destination's leaf count.
bool SemaHLSL::transformInitList(const InitializedEntity &Entity,
                                 InitListExpr *Init) {
  // If the initializer is a scalar, just return it.
  if (Init->getType()->isScalarType())
    return true;
  ASTContext &Ctx = SemaRef.getASTContext();
  InitListTransformer ILT(SemaRef, Entity);

  for (unsigned I = 0; I < Init->getNumInits(); ++I) {
    Expr *E = Init->getInit(Init: I);
    // Flattening can reference an initializer multiple times (e.g. one
    // subscript per vector element), so bind side-effecting initializers to
    // an opaque value (materializing record temporaries first) to keep them
    // evaluated exactly once.
    if (E->HasSideEffects(Ctx)) {
      QualType Ty = E->getType();
      if (Ty->isRecordType())
        E = new (Ctx) MaterializeTemporaryExpr(Ty, E, E->isLValue());
      E = new (Ctx) OpaqueValueExpr(E->getBeginLoc(), Ty, E->getValueKind(),
                                    E->getObjectKind(), E);
      Init->setInit(Init: I, expr: E);
    }
    if (!ILT.buildInitializerList(E))
      return false;
  }
  size_t ExpectedSize = ILT.DestTypes.size();
  size_t ActualSize = ILT.ArgExprs.size();
  // For incomplete arrays it is completely arbitrary to choose whether we think
  // the user intended fewer or more elements. This implementation assumes that
  // the user intended more, and errors that there are too few initializers to
  // complete the final element.
  if (Entity.getType()->isIncompleteArrayType())
    ExpectedSize =
        ((ActualSize + ExpectedSize - 1) / ExpectedSize) * ExpectedSize;

  // An initializer list might be attempting to initialize a reference or
  // rvalue-reference. When checking the initializer we should look through
  // the reference.
  QualType InitTy = Entity.getType().getNonReferenceType();
  if (InitTy.hasAddressSpace())
    InitTy = SemaRef.getASTContext().removeAddrSpaceQualType(T: InitTy);
  if (ExpectedSize != ActualSize) {
    int TooManyOrFew = ActualSize > ExpectedSize ? 1 : 0;
    SemaRef.Diag(Loc: Init->getBeginLoc(), DiagID: diag::err_hlsl_incorrect_num_initializers)
        << TooManyOrFew << InitTy << ExpectedSize << ActualSize;
    return false;
  }

  // generateInitListsImpl will always return an InitListExpr here, because the
  // scalar case is handled above.
  auto *NewInit = cast<InitListExpr>(Val: ILT.generateInitLists());
  Init->resizeInits(Context: Ctx, NumInits: NewInit->getNumInits());
  for (unsigned I = 0; I < NewInit->getNumInits(); ++I)
    Init->updateInit(C: Ctx, Init: I, expr: NewInit->getInit(Init: I));
  return true;
}
3959
// Handles initialization of a variable carrying [[vk::constant_id]]: the
// initializer must be a constant expression and is rewritten into a call to
// the matching specialization-constant builtin. Returns true when handled
// (or when the attribute is absent); false (with the decl invalidated) on a
// non-constant initializer.
bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) {
  const HLSLVkConstantIdAttr *ConstIdAttr =
      VDecl->getAttr<HLSLVkConstantIdAttr>();
  if (!ConstIdAttr)
    return true;

  ASTContext &Context = SemaRef.getASTContext();

  // Specialization constants require a constant-expression default value.
  APValue InitValue;
  if (!Init->isCXX11ConstantExpr(Ctx: Context, Result: &InitValue)) {
    Diag(Loc: VDecl->getLocation(), DiagID: diag::err_specialization_const);
    VDecl->setInvalidDecl();
    return false;
  }

  // Pick the builtin overload matching the variable's underlying type.
  Builtin::ID BID =
      getSpecConstBuiltinId(Type: VDecl->getType()->getUnqualifiedDesugaredType());

  // Argument 1: The ID from the attribute
  int ConstantID = ConstIdAttr->getId();
  llvm::APInt IDVal(Context.getIntWidth(T: Context.IntTy), ConstantID);
  Expr *IdExpr = IntegerLiteral::Create(C: Context, V: IDVal, type: Context.IntTy,
                                        l: ConstIdAttr->getLocation());

  SmallVector<Expr *, 2> Args = {IdExpr, Init};
  Expr *C = SemaRef.BuildBuiltinCallExpr(Loc: Init->getExprLoc(), Id: BID, CallArgs: Args);
  // If the builtin's result type differs from the declared type, cast the
  // call back so the initializer's type matches the declaration.
  if (C->getType()->getCanonicalTypeUnqualified() !=
      VDecl->getType()->getCanonicalTypeUnqualified()) {
    C = SemaRef
            .BuildCStyleCastExpr(LParenLoc: SourceLocation(),
                                 Ty: Context.getTrivialTypeSourceInfo(
                                     T: Init->getType(), Loc: Init->getExprLoc()),
                                 RParenLoc: SourceLocation(), Op: C)
            .get();
  }
  Init = C;
  return true;
}
3998