1//===--- HLSLExternalSemaSource.cpp - HLSL Sema Source --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9//
10//===----------------------------------------------------------------------===//
11
12#include "clang/Sema/HLSLExternalSemaSource.h"
13#include "HLSLBuiltinTypeDeclBuilder.h"
14#include "clang/AST/ASTContext.h"
15#include "clang/AST/Attr.h"
16#include "clang/AST/Decl.h"
17#include "clang/AST/DeclCXX.h"
18#include "clang/AST/DeclTemplate.h"
19#include "clang/AST/Expr.h"
20#include "clang/AST/Type.h"
21#include "clang/Basic/AddressSpaces.h"
22#include "clang/Basic/SourceLocation.h"
23#include "clang/Lex/Preprocessor.h"
24#include "clang/Sema/Lookup.h"
25#include "clang/Sema/Sema.h"
26#include "clang/Sema/SemaHLSL.h"
27#include "clang/Sema/TemplateDeduction.h"
28#include "llvm/ADT/STLExtras.h"
29#include "llvm/ADT/SmallVector.h"
30
31using namespace clang;
32using namespace llvm::hlsl;
33
34using clang::hlsl::BuiltinTypeDeclBuilder;
35
36void HLSLExternalSemaSource::InitializeSema(Sema &S) {
37 SemaPtr = &S;
38 ASTContext &AST = SemaPtr->getASTContext();
39 // If the translation unit has external storage force external decls to load.
40 if (AST.getTranslationUnitDecl()->hasExternalLexicalStorage())
41 (void)AST.getTranslationUnitDecl()->decls_begin();
42
43 IdentifierInfo &HLSL = AST.Idents.get(Name: "hlsl", TokenCode: tok::TokenKind::identifier);
44 LookupResult Result(S, &HLSL, SourceLocation(), Sema::LookupNamespaceName);
45 NamespaceDecl *PrevDecl = nullptr;
46 if (S.LookupQualifiedName(R&: Result, LookupCtx: AST.getTranslationUnitDecl()))
47 PrevDecl = Result.getAsSingle<NamespaceDecl>();
48 HLSLNamespace = NamespaceDecl::Create(
49 C&: AST, DC: AST.getTranslationUnitDecl(), /*Inline=*/false, StartLoc: SourceLocation(),
50 IdLoc: SourceLocation(), Id: &HLSL, PrevDecl, /*Nested=*/false);
51 HLSLNamespace->setImplicit(true);
52 HLSLNamespace->setHasExternalLexicalStorage();
53 AST.getTranslationUnitDecl()->addDecl(D: HLSLNamespace);
54
55 // Force external decls in the HLSL namespace to load from the PCH.
56 (void)HLSLNamespace->getCanonicalDecl()->decls_begin();
57 defineTrivialHLSLTypes();
58 defineHLSLTypesWithForwardDeclarations();
59 defineHLSLAtomicIntrinsics();
60
61 // This adds a `using namespace hlsl` directive. In DXC, we don't put HLSL's
62 // built in types inside a namespace, but we are planning to change that in
63 // the near future. In order to be source compatible older versions of HLSL
64 // will need to implicitly use the hlsl namespace. For now in clang everything
65 // will get added to the namespace, and we can remove the using directive for
66 // future language versions to match HLSL's evolution.
67 auto *UsingDecl = UsingDirectiveDecl::Create(
68 C&: AST, DC: AST.getTranslationUnitDecl(), UsingLoc: SourceLocation(), NamespaceLoc: SourceLocation(),
69 QualifierLoc: NestedNameSpecifierLoc(), IdentLoc: SourceLocation(), Nominated: HLSLNamespace,
70 CommonAncestor: AST.getTranslationUnitDecl());
71
72 AST.getTranslationUnitDecl()->addDecl(D: UsingDecl);
73}
74
75void HLSLExternalSemaSource::defineHLSLVectorAlias() {
76 ASTContext &AST = SemaPtr->getASTContext();
77
78 llvm::SmallVector<NamedDecl *> TemplateParams;
79
80 auto *TypeParam = TemplateTypeParmDecl::Create(
81 C: AST, DC: HLSLNamespace, KeyLoc: SourceLocation(), NameLoc: SourceLocation(), D: 0, P: 0,
82 Id: &AST.Idents.get(Name: "element", TokenCode: tok::TokenKind::identifier), Typename: false, ParameterPack: false);
83 TypeParam->setDefaultArgument(
84 C: AST, DefArg: SemaPtr->getTrivialTemplateArgumentLoc(
85 Arg: TemplateArgument(AST.FloatTy), NTTPType: QualType(), Loc: SourceLocation()));
86
87 TemplateParams.emplace_back(Args&: TypeParam);
88
89 auto *SizeParam = NonTypeTemplateParmDecl::Create(
90 C: AST, DC: HLSLNamespace, StartLoc: SourceLocation(), IdLoc: SourceLocation(), D: 0, P: 1,
91 Id: &AST.Idents.get(Name: "element_count", TokenCode: tok::TokenKind::identifier), T: AST.IntTy,
92 ParameterPack: false, TInfo: AST.getTrivialTypeSourceInfo(T: AST.IntTy));
93 llvm::APInt Val(AST.getIntWidth(T: AST.IntTy), 4);
94 TemplateArgument Default(AST, llvm::APSInt(std::move(Val)), AST.IntTy,
95 /*IsDefaulted=*/true);
96 SizeParam->setDefaultArgument(C: AST, DefArg: SemaPtr->getTrivialTemplateArgumentLoc(
97 Arg: Default, NTTPType: AST.IntTy, Loc: SourceLocation()));
98 TemplateParams.emplace_back(Args&: SizeParam);
99
100 auto *ParamList =
101 TemplateParameterList::Create(C: AST, TemplateLoc: SourceLocation(), LAngleLoc: SourceLocation(),
102 Params: TemplateParams, RAngleLoc: SourceLocation(), RequiresClause: nullptr);
103
104 IdentifierInfo &II = AST.Idents.get(Name: "vector", TokenCode: tok::TokenKind::identifier);
105
106 QualType AliasType = AST.getDependentSizedExtVectorType(
107 VectorType: AST.getTemplateTypeParmType(Depth: 0, Index: 0, ParameterPack: false, ParmDecl: TypeParam),
108 SizeExpr: DeclRefExpr::Create(
109 Context: AST, QualifierLoc: NestedNameSpecifierLoc(), TemplateKWLoc: SourceLocation(), D: SizeParam, RefersToEnclosingVariableOrCapture: false,
110 NameInfo: DeclarationNameInfo(SizeParam->getDeclName(), SourceLocation()),
111 T: AST.IntTy, VK: VK_LValue),
112 AttrLoc: SourceLocation());
113
114 auto *Record = TypeAliasDecl::Create(C&: AST, DC: HLSLNamespace, StartLoc: SourceLocation(),
115 IdLoc: SourceLocation(), Id: &II,
116 TInfo: AST.getTrivialTypeSourceInfo(T: AliasType));
117 Record->setImplicit(true);
118
119 auto *Template =
120 TypeAliasTemplateDecl::Create(C&: AST, DC: HLSLNamespace, L: SourceLocation(),
121 Name: Record->getIdentifier(), Params: ParamList, Decl: Record);
122
123 Record->setDescribedAliasTemplate(Template);
124 Template->setImplicit(true);
125 Template->setLexicalDeclContext(Record->getDeclContext());
126 HLSLNamespace->addDecl(D: Template);
127}
128
129void HLSLExternalSemaSource::defineHLSLMatrixAlias() {
130 ASTContext &AST = SemaPtr->getASTContext();
131 llvm::SmallVector<NamedDecl *> TemplateParams;
132
133 auto *TypeParam = TemplateTypeParmDecl::Create(
134 C: AST, DC: HLSLNamespace, KeyLoc: SourceLocation(), NameLoc: SourceLocation(), D: 0, P: 0,
135 Id: &AST.Idents.get(Name: "element", TokenCode: tok::TokenKind::identifier), Typename: false, ParameterPack: false);
136 TypeParam->setDefaultArgument(
137 C: AST, DefArg: SemaPtr->getTrivialTemplateArgumentLoc(
138 Arg: TemplateArgument(AST.FloatTy), NTTPType: QualType(), Loc: SourceLocation()));
139
140 TemplateParams.emplace_back(Args&: TypeParam);
141
142 // these should be 64 bit to be consistent with other clang matrices.
143 auto *RowsParam = NonTypeTemplateParmDecl::Create(
144 C: AST, DC: HLSLNamespace, StartLoc: SourceLocation(), IdLoc: SourceLocation(), D: 0, P: 1,
145 Id: &AST.Idents.get(Name: "rows_count", TokenCode: tok::TokenKind::identifier), T: AST.IntTy,
146 ParameterPack: false, TInfo: AST.getTrivialTypeSourceInfo(T: AST.IntTy));
147 llvm::APInt RVal(AST.getIntWidth(T: AST.IntTy), 4);
148 TemplateArgument RDefault(AST, llvm::APSInt(std::move(RVal)), AST.IntTy,
149 /*IsDefaulted=*/true);
150 RowsParam->setDefaultArgument(
151 C: AST, DefArg: SemaPtr->getTrivialTemplateArgumentLoc(Arg: RDefault, NTTPType: AST.IntTy,
152 Loc: SourceLocation()));
153 TemplateParams.emplace_back(Args&: RowsParam);
154
155 auto *ColsParam = NonTypeTemplateParmDecl::Create(
156 C: AST, DC: HLSLNamespace, StartLoc: SourceLocation(), IdLoc: SourceLocation(), D: 0, P: 2,
157 Id: &AST.Idents.get(Name: "cols_count", TokenCode: tok::TokenKind::identifier), T: AST.IntTy,
158 ParameterPack: false, TInfo: AST.getTrivialTypeSourceInfo(T: AST.IntTy));
159 llvm::APInt CVal(AST.getIntWidth(T: AST.IntTy), 4);
160 TemplateArgument CDefault(AST, llvm::APSInt(std::move(CVal)), AST.IntTy,
161 /*IsDefaulted=*/true);
162 ColsParam->setDefaultArgument(
163 C: AST, DefArg: SemaPtr->getTrivialTemplateArgumentLoc(Arg: CDefault, NTTPType: AST.IntTy,
164 Loc: SourceLocation()));
165 TemplateParams.emplace_back(Args&: ColsParam);
166
167 const unsigned MaxMatDim = SemaPtr->getLangOpts().MaxMatrixDimension;
168
169 auto *MaxRow = IntegerLiteral::Create(
170 C: AST, V: llvm::APInt(AST.getIntWidth(T: AST.IntTy), MaxMatDim), type: AST.IntTy,
171 l: SourceLocation());
172 auto *MaxCol = IntegerLiteral::Create(
173 C: AST, V: llvm::APInt(AST.getIntWidth(T: AST.IntTy), MaxMatDim), type: AST.IntTy,
174 l: SourceLocation());
175
176 auto *RowsRef = DeclRefExpr::Create(
177 Context: AST, QualifierLoc: NestedNameSpecifierLoc(), TemplateKWLoc: SourceLocation(), D: RowsParam,
178 /*RefersToEnclosingVariableOrCapture*/ false,
179 NameInfo: DeclarationNameInfo(RowsParam->getDeclName(), SourceLocation()),
180 T: AST.IntTy, VK: VK_LValue);
181 auto *ColsRef = DeclRefExpr::Create(
182 Context: AST, QualifierLoc: NestedNameSpecifierLoc(), TemplateKWLoc: SourceLocation(), D: ColsParam,
183 /*RefersToEnclosingVariableOrCapture*/ false,
184 NameInfo: DeclarationNameInfo(ColsParam->getDeclName(), SourceLocation()),
185 T: AST.IntTy, VK: VK_LValue);
186
187 auto *RowsLE = BinaryOperator::Create(C: AST, lhs: RowsRef, rhs: MaxRow, opc: BO_LE, ResTy: AST.BoolTy,
188 VK: VK_PRValue, OK: OK_Ordinary,
189 opLoc: SourceLocation(), FPFeatures: FPOptionsOverride());
190 auto *ColsLE = BinaryOperator::Create(C: AST, lhs: ColsRef, rhs: MaxCol, opc: BO_LE, ResTy: AST.BoolTy,
191 VK: VK_PRValue, OK: OK_Ordinary,
192 opLoc: SourceLocation(), FPFeatures: FPOptionsOverride());
193
194 auto *RequiresExpr = BinaryOperator::Create(
195 C: AST, lhs: RowsLE, rhs: ColsLE, opc: BO_LAnd, ResTy: AST.BoolTy, VK: VK_PRValue, OK: OK_Ordinary,
196 opLoc: SourceLocation(), FPFeatures: FPOptionsOverride());
197
198 auto *ParamList = TemplateParameterList::Create(
199 C: AST, TemplateLoc: SourceLocation(), LAngleLoc: SourceLocation(), Params: TemplateParams, RAngleLoc: SourceLocation(),
200 RequiresClause: RequiresExpr);
201
202 IdentifierInfo &II = AST.Idents.get(Name: "matrix", TokenCode: tok::TokenKind::identifier);
203
204 QualType AliasType = AST.getDependentSizedMatrixType(
205 ElementType: AST.getTemplateTypeParmType(Depth: 0, Index: 0, ParameterPack: false, ParmDecl: TypeParam),
206 RowExpr: DeclRefExpr::Create(
207 Context: AST, QualifierLoc: NestedNameSpecifierLoc(), TemplateKWLoc: SourceLocation(), D: RowsParam, RefersToEnclosingVariableOrCapture: false,
208 NameInfo: DeclarationNameInfo(RowsParam->getDeclName(), SourceLocation()),
209 T: AST.IntTy, VK: VK_LValue),
210 ColumnExpr: DeclRefExpr::Create(
211 Context: AST, QualifierLoc: NestedNameSpecifierLoc(), TemplateKWLoc: SourceLocation(), D: ColsParam, RefersToEnclosingVariableOrCapture: false,
212 NameInfo: DeclarationNameInfo(ColsParam->getDeclName(), SourceLocation()),
213 T: AST.IntTy, VK: VK_LValue),
214 AttrLoc: SourceLocation());
215
216 auto *Record = TypeAliasDecl::Create(C&: AST, DC: HLSLNamespace, StartLoc: SourceLocation(),
217 IdLoc: SourceLocation(), Id: &II,
218 TInfo: AST.getTrivialTypeSourceInfo(T: AliasType));
219 Record->setImplicit(true);
220
221 auto *Template =
222 TypeAliasTemplateDecl::Create(C&: AST, DC: HLSLNamespace, L: SourceLocation(),
223 Name: Record->getIdentifier(), Params: ParamList, Decl: Record);
224
225 Record->setDescribedAliasTemplate(Template);
226 Template->setImplicit(true);
227 Template->setLexicalDeclContext(Record->getDeclContext());
228 HLSLNamespace->addDecl(D: Template);
229}
230
231void HLSLExternalSemaSource::defineTrivialHLSLTypes() {
232 defineHLSLVectorAlias();
233 defineHLSLMatrixAlias();
234}
235
236/// Set up common members and attributes for buffer types
237static BuiltinTypeDeclBuilder setupBufferType(CXXRecordDecl *Decl, Sema &S,
238 ResourceClass RC, bool IsROV,
239 bool RawBuffer, bool HasCounter) {
240 return BuiltinTypeDeclBuilder(S, Decl)
241 .addBufferHandles(RC, IsROV, RawBuffer, HasCounter)
242 .addDefaultHandleConstructor()
243 .addCopyConstructor()
244 .addCopyAssignmentOperator()
245 .addStaticInitializationFunctions(HasCounter);
246}
247
248/// Set up common members and attributes for sampler types
249static BuiltinTypeDeclBuilder setupSamplerType(CXXRecordDecl *Decl, Sema &S) {
250 return BuiltinTypeDeclBuilder(S, Decl)
251 .addSamplerHandle()
252 .addDefaultHandleConstructor()
253 .addCopyConstructor()
254 .addCopyAssignmentOperator()
255 .addStaticInitializationFunctions(HasCounter: false);
256}
257
258/// Set up common members and attributes for texture types
259static BuiltinTypeDeclBuilder setupTextureType(CXXRecordDecl *Decl, Sema &S,
260 ResourceClass RC, bool IsROV,
261 bool IsArray,
262 ResourceDimension Dim) {
263 return BuiltinTypeDeclBuilder(S, Decl)
264 .addTextureHandle(RC, IsROV, IsArray, RD: Dim)
265 .addTextureLoadMethods(Dim, IsArray)
266 .addArraySubscriptOperators(Dim, IsArray)
267 .addMipsMember(Dim)
268 .addDefaultHandleConstructor()
269 .addCopyConstructor()
270 .addCopyAssignmentOperator()
271 .addStaticInitializationFunctions(HasCounter: false)
272 .addSampleMethods(Dim, IsArray)
273 .addSampleBiasMethods(Dim, IsArray)
274 .addSampleGradMethods(Dim, IsArray)
275 .addSampleLevelMethods(Dim, IsArray)
276 .addSampleCmpMethods(Dim, IsArray)
277 .addSampleCmpLevelZeroMethods(Dim, IsArray)
278 .addCalculateLodMethods(Dim)
279 .addGetDimensionsMethods(Dim)
280 .addGatherMethods(Dim, IsArray)
281 .addGatherCmpMethods(Dim, IsArray);
282}
283
284// Add a partial specialization for a template. The `TextureTemplate` is
285// `Texture<element_type>`, and it will be specialized for vectors:
286// `Texture<vector<element_type, element_count>>`.
287static ClassTemplatePartialSpecializationDecl *
288addVectorTexturePartialSpecialization(Sema &S, NamespaceDecl *HLSLNamespace,
289 ClassTemplateDecl *TextureTemplate) {
290 ASTContext &AST = S.getASTContext();
291
292 // Create the template parameters: element_type and element_count.
293 auto *ElementType = TemplateTypeParmDecl::Create(
294 C: AST, DC: HLSLNamespace, KeyLoc: SourceLocation(), NameLoc: SourceLocation(), D: 0, P: 0,
295 Id: &AST.Idents.get(Name: "element_type"), Typename: false, ParameterPack: false);
296 auto *ElementCount = NonTypeTemplateParmDecl::Create(
297 C: AST, DC: HLSLNamespace, StartLoc: SourceLocation(), IdLoc: SourceLocation(), D: 0, P: 1,
298 Id: &AST.Idents.get(Name: "element_count"), T: AST.IntTy, ParameterPack: false,
299 TInfo: AST.getTrivialTypeSourceInfo(T: AST.IntTy));
300
301 auto *TemplateParams = TemplateParameterList::Create(
302 C: AST, TemplateLoc: SourceLocation(), LAngleLoc: SourceLocation(), Params: {ElementType, ElementCount},
303 RAngleLoc: SourceLocation(), RequiresClause: nullptr);
304
305 // Create the dependent vector type: vector<element_type, element_count>.
306 QualType VectorType = AST.getDependentSizedExtVectorType(
307 VectorType: AST.getTemplateTypeParmType(Depth: 0, Index: 0, ParameterPack: false, ParmDecl: ElementType),
308 SizeExpr: DeclRefExpr::Create(
309 Context: AST, QualifierLoc: NestedNameSpecifierLoc(), TemplateKWLoc: SourceLocation(), D: ElementCount, RefersToEnclosingVariableOrCapture: false,
310 NameInfo: DeclarationNameInfo(ElementCount->getDeclName(), SourceLocation()),
311 T: AST.IntTy, VK: VK_LValue),
312 AttrLoc: SourceLocation());
313
314 // Create the partial specialization declaration.
315 QualType CanonInjectedTST =
316 AST.getCanonicalType(T: AST.getTemplateSpecializationType(
317 Keyword: ElaboratedTypeKeyword::Class, T: TemplateName(TextureTemplate),
318 SpecifiedArgs: {TemplateArgument(VectorType)}, CanonicalArgs: {}));
319
320 auto *PartialSpec = ClassTemplatePartialSpecializationDecl::Create(
321 Context&: AST, TK: TagDecl::TagKind::Class, DC: HLSLNamespace, StartLoc: SourceLocation(),
322 IdLoc: SourceLocation(), Params: TemplateParams, SpecializedTemplate: TextureTemplate,
323 Args: {TemplateArgument(VectorType)},
324 CanonInjectedTST: CanQualType::CreateUnsafe(Other: CanonInjectedTST), PrevDecl: nullptr);
325
326 // Set the template arguments as written.
327 TemplateArgument Arg(VectorType);
328 TemplateArgumentLoc ArgLoc =
329 S.getTrivialTemplateArgumentLoc(Arg, NTTPType: QualType(), Loc: SourceLocation());
330 TemplateArgumentListInfo ArgsInfo =
331 TemplateArgumentListInfo(SourceLocation(), SourceLocation());
332 ArgsInfo.addArgument(Loc: ArgLoc);
333 PartialSpec->setTemplateArgsAsWritten(
334 ASTTemplateArgumentListInfo::Create(C: AST, List: ArgsInfo));
335
336 PartialSpec->setImplicit(true);
337 PartialSpec->setLexicalDeclContext(HLSLNamespace);
338 PartialSpec->setHasExternalLexicalStorage();
339
340 // Add the partial specialization to the namespace and the class template.
341 HLSLNamespace->addDecl(D: PartialSpec);
342 TextureTemplate->AddPartialSpecialization(D: PartialSpec, InsertPos: nullptr);
343
344 return PartialSpec;
345}
346
347// This function is responsible for constructing the constraint expression for
348// this concept:
349// template<typename T> concept is_typed_resource_element_compatible =
350// __is_typed_resource_element_compatible<T>;
351static Expr *constructTypedBufferConstraintExpr(Sema &S, SourceLocation NameLoc,
352 TemplateTypeParmDecl *T) {
353 ASTContext &Context = S.getASTContext();
354
355 // Obtain the QualType for 'bool'
356 QualType BoolTy = Context.BoolTy;
357
358 // Create a QualType that points to this TemplateTypeParmDecl
359 QualType TType = Context.getTypeDeclType(Decl: T);
360
361 // Create a TypeSourceInfo for the template type parameter 'T'
362 TypeSourceInfo *TTypeSourceInfo =
363 Context.getTrivialTypeSourceInfo(T: TType, Loc: NameLoc);
364
365 TypeTraitExpr *TypedResExpr = TypeTraitExpr::Create(
366 C: Context, T: BoolTy, Loc: NameLoc, Kind: UTT_IsTypedResourceElementCompatible,
367 Args: {TTypeSourceInfo}, RParenLoc: NameLoc, Value: true);
368
369 return TypedResExpr;
370}
371
372// This function is responsible for constructing the constraint expression for
373// this concept:
374// template<typename T> concept is_constant_buffer_element_compatible =
375// std::is_class_v<T> && !__is_intangible(T);
376static Expr *constructConstantBufferConstraintExpr(Sema &S,
377 SourceLocation NameLoc,
378 TemplateTypeParmDecl *T) {
379 ASTContext &Context = S.getASTContext();
380
381 // Obtain the QualType for 'bool'
382 QualType BoolTy = Context.BoolTy;
383
384 // Create a QualType that points to this TemplateTypeParmDecl
385 QualType TType = Context.getTypeDeclType(Decl: T);
386
387 // Create a TypeSourceInfo for the template type parameter 'T'
388 TypeSourceInfo *TTypeSourceInfo =
389 Context.getTrivialTypeSourceInfo(T: TType, Loc: NameLoc);
390
391 TypeTraitExpr *ResExpr = TypeTraitExpr::Create(
392 C: Context, T: BoolTy, Loc: NameLoc, Kind: UTT_IsConstantBufferElementCompatible,
393 Args: {TTypeSourceInfo}, RParenLoc: NameLoc, Value: true);
394
395 return ResExpr;
396}
397
398// This function is responsible for constructing the constraint expression for
399// this concept:
400// template<typename T> concept is_structured_resource_element_compatible =
401// !__is_intangible<T> && sizeof(T) >= 1;
402static Expr *constructStructuredBufferConstraintExpr(Sema &S,
403 SourceLocation NameLoc,
404 TemplateTypeParmDecl *T) {
405 ASTContext &Context = S.getASTContext();
406
407 // Obtain the QualType for 'bool'
408 QualType BoolTy = Context.BoolTy;
409
410 // Create a QualType that points to this TemplateTypeParmDecl
411 QualType TType = Context.getTypeDeclType(Decl: T);
412
413 // Create a TypeSourceInfo for the template type parameter 'T'
414 TypeSourceInfo *TTypeSourceInfo =
415 Context.getTrivialTypeSourceInfo(T: TType, Loc: NameLoc);
416
417 TypeTraitExpr *IsIntangibleExpr =
418 TypeTraitExpr::Create(C: Context, T: BoolTy, Loc: NameLoc, Kind: UTT_IsIntangibleType,
419 Args: {TTypeSourceInfo}, RParenLoc: NameLoc, Value: true);
420
421 // negate IsIntangibleExpr
422 UnaryOperator *NotIntangibleExpr = UnaryOperator::Create(
423 C: Context, input: IsIntangibleExpr, opc: UO_LNot, type: BoolTy, VK: VK_LValue, OK: OK_Ordinary,
424 l: NameLoc, CanOverflow: false, FPFeatures: FPOptionsOverride());
425
426 // element types also may not be of 0 size
427 UnaryExprOrTypeTraitExpr *SizeOfExpr = new (Context) UnaryExprOrTypeTraitExpr(
428 UETT_SizeOf, TTypeSourceInfo, BoolTy, NameLoc, NameLoc);
429
430 // Create a BinaryOperator that checks if the size of the type is not equal to
431 // 1 Empty structs have a size of 1 in HLSL, so we need to check for that
432 IntegerLiteral *rhs = IntegerLiteral::Create(
433 C: Context, V: llvm::APInt(Context.getTypeSize(T: Context.getSizeType()), 1, true),
434 type: Context.getSizeType(), l: NameLoc);
435
436 BinaryOperator *SizeGEQOneExpr =
437 BinaryOperator::Create(C: Context, lhs: SizeOfExpr, rhs, opc: BO_GE, ResTy: BoolTy, VK: VK_LValue,
438 OK: OK_Ordinary, opLoc: NameLoc, FPFeatures: FPOptionsOverride());
439
440 // Combine the two constraints
441 BinaryOperator *CombinedExpr = BinaryOperator::Create(
442 C: Context, lhs: NotIntangibleExpr, rhs: SizeGEQOneExpr, opc: BO_LAnd, ResTy: BoolTy, VK: VK_LValue,
443 OK: OK_Ordinary, opLoc: NameLoc, FPFeatures: FPOptionsOverride());
444
445 return CombinedExpr;
446}
447
448enum class HLSLBufferType { Typed, Structured, Constant };
449
450static ConceptDecl *constructBufferConceptDecl(Sema &S, NamespaceDecl *NSD,
451 HLSLBufferType BT) {
452 ASTContext &Context = S.getASTContext();
453 DeclContext *DC = NSD->getDeclContext();
454 SourceLocation DeclLoc = SourceLocation();
455
456 IdentifierInfo &ElementTypeII = Context.Idents.get(Name: "element_type");
457 TemplateTypeParmDecl *T = TemplateTypeParmDecl::Create(
458 C: Context, DC: NSD->getDeclContext(), KeyLoc: DeclLoc, NameLoc: DeclLoc,
459 /*D=*/0,
460 /*P=*/0,
461 /*Id=*/&ElementTypeII,
462 /*Typename=*/true,
463 /*ParameterPack=*/false);
464
465 T->setDeclContext(DC);
466 T->setReferenced();
467
468 // Create and Attach Template Parameter List to ConceptDecl
469 TemplateParameterList *ConceptParams = TemplateParameterList::Create(
470 C: Context, TemplateLoc: DeclLoc, LAngleLoc: DeclLoc, Params: {T}, RAngleLoc: DeclLoc, RequiresClause: nullptr);
471
472 DeclarationName DeclName;
473 Expr *ConstraintExpr = nullptr;
474
475 switch (BT) {
476 case HLSLBufferType::Typed:
477 DeclName = DeclarationName(
478 &Context.Idents.get(Name: "__is_typed_resource_element_compatible"));
479 ConstraintExpr = constructTypedBufferConstraintExpr(S, NameLoc: DeclLoc, T);
480 break;
481 case HLSLBufferType::Structured:
482 DeclName = DeclarationName(
483 &Context.Idents.get(Name: "__is_structured_resource_element_compatible"));
484 ConstraintExpr = constructStructuredBufferConstraintExpr(S, NameLoc: DeclLoc, T);
485 break;
486 case HLSLBufferType::Constant:
487 DeclName = DeclarationName(
488 &Context.Idents.get(Name: "__is_constant_buffer_element_compatible"));
489 ConstraintExpr = constructConstantBufferConstraintExpr(S, NameLoc: DeclLoc, T);
490 break;
491 }
492
493 // Create a ConceptDecl
494 ConceptDecl *CD =
495 ConceptDecl::Create(C&: Context, DC: NSD->getDeclContext(), L: DeclLoc, Name: DeclName,
496 Params: ConceptParams, ConstraintExpr);
497
498 // Attach the template parameter list to the ConceptDecl
499 CD->setTemplateParameters(ConceptParams);
500
501 // Add the concept declaration to the Translation Unit Decl
502 NSD->getDeclContext()->addDecl(D: CD);
503
504 return CD;
505}
506
507void HLSLExternalSemaSource::defineHLSLTypesWithForwardDeclarations() {
508 ASTContext &AST = SemaPtr->getASTContext();
509 CXXRecordDecl *Decl;
510 ConceptDecl *TypedBufferConcept = constructBufferConceptDecl(
511 S&: *SemaPtr, NSD: HLSLNamespace, BT: HLSLBufferType::Typed);
512 ConceptDecl *StructuredBufferConcept = constructBufferConceptDecl(
513 S&: *SemaPtr, NSD: HLSLNamespace, BT: HLSLBufferType::Structured);
514 ConceptDecl *ConstantBufferConcept = constructBufferConceptDecl(
515 S&: *SemaPtr, NSD: HLSLNamespace, BT: HLSLBufferType::Constant);
516
517 Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "ConstantBuffer")
518 .addSimpleTemplateParams(Names: {"element_type"}, CD: ConstantBufferConcept)
519 .finalizeForwardDeclaration();
520
521 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
522 setupBufferType(Decl, S&: *SemaPtr, RC: ResourceClass::CBuffer, /*IsROV=*/false,
523 /*RawBuffer=*/false, /*HasCounter=*/false)
524 .addConstantBufferConversionToType()
525 .completeDefinition();
526 });
527
528 Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "Buffer")
529 .addSimpleTemplateParams(Names: {"element_type"}, CD: TypedBufferConcept)
530 .finalizeForwardDeclaration();
531
532 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
533 setupBufferType(Decl, S&: *SemaPtr, RC: ResourceClass::SRV, /*IsROV=*/false,
534 /*RawBuffer=*/false, /*HasCounter=*/false)
535 .addArraySubscriptOperators()
536 .addLoadMethods()
537 .addGetDimensionsMethodForBuffer()
538 .completeDefinition();
539 });
540
541 Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RWBuffer")
542 .addSimpleTemplateParams(Names: {"element_type"}, CD: TypedBufferConcept)
543 .finalizeForwardDeclaration();
544
545 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
546 setupBufferType(Decl, S&: *SemaPtr, RC: ResourceClass::UAV, /*IsROV=*/false,
547 /*RawBuffer=*/false, /*HasCounter=*/false)
548 .addArraySubscriptOperators()
549 .addLoadMethods()
550 .addGetDimensionsMethodForBuffer()
551 .completeDefinition();
552 });
553
554 Decl =
555 BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RasterizerOrderedBuffer")
556 .addSimpleTemplateParams(Names: {"element_type"}, CD: StructuredBufferConcept)
557 .finalizeForwardDeclaration();
558 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
559 setupBufferType(Decl, S&: *SemaPtr, RC: ResourceClass::UAV, /*IsROV=*/true,
560 /*RawBuffer=*/false, /*HasCounter=*/false)
561 .addArraySubscriptOperators()
562 .addLoadMethods()
563 .addGetDimensionsMethodForBuffer()
564 .completeDefinition();
565 });
566
567 Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "StructuredBuffer")
568 .addSimpleTemplateParams(Names: {"element_type"}, CD: StructuredBufferConcept)
569 .finalizeForwardDeclaration();
570 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
571 setupBufferType(Decl, S&: *SemaPtr, RC: ResourceClass::SRV, /*IsROV=*/false,
572 /*RawBuffer=*/true, /*HasCounter=*/false)
573 .addArraySubscriptOperators()
574 .addLoadMethods()
575 .addGetDimensionsMethodForBuffer()
576 .completeDefinition();
577 });
578
579 Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RWStructuredBuffer")
580 .addSimpleTemplateParams(Names: {"element_type"}, CD: StructuredBufferConcept)
581 .finalizeForwardDeclaration();
582 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
583 setupBufferType(Decl, S&: *SemaPtr, RC: ResourceClass::UAV, /*IsROV=*/false,
584 /*RawBuffer=*/true, /*HasCounter=*/true)
585 .addArraySubscriptOperators()
586 .addLoadMethods()
587 .addIncrementCounterMethod()
588 .addDecrementCounterMethod()
589 .addGetDimensionsMethodForBuffer()
590 .completeDefinition();
591 });
592
593 Decl =
594 BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "AppendStructuredBuffer")
595 .addSimpleTemplateParams(Names: {"element_type"}, CD: StructuredBufferConcept)
596 .finalizeForwardDeclaration();
597 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
598 setupBufferType(Decl, S&: *SemaPtr, RC: ResourceClass::UAV, /*IsROV=*/false,
599 /*RawBuffer=*/true, /*HasCounter=*/true)
600 .addAppendMethod()
601 .addGetDimensionsMethodForBuffer()
602 .completeDefinition();
603 });
604
605 Decl =
606 BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "ConsumeStructuredBuffer")
607 .addSimpleTemplateParams(Names: {"element_type"}, CD: StructuredBufferConcept)
608 .finalizeForwardDeclaration();
609 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
610 setupBufferType(Decl, S&: *SemaPtr, RC: ResourceClass::UAV, /*IsROV=*/false,
611 /*RawBuffer=*/true, /*HasCounter=*/true)
612 .addConsumeMethod()
613 .addGetDimensionsMethodForBuffer()
614 .completeDefinition();
615 });
616
617 Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace,
618 "RasterizerOrderedStructuredBuffer")
619 .addSimpleTemplateParams(Names: {"element_type"}, CD: StructuredBufferConcept)
620 .finalizeForwardDeclaration();
621 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
622 setupBufferType(Decl, S&: *SemaPtr, RC: ResourceClass::UAV, /*IsROV=*/true,
623 /*RawBuffer=*/true, /*HasCounter=*/true)
624 .addArraySubscriptOperators()
625 .addLoadMethods()
626 .addIncrementCounterMethod()
627 .addDecrementCounterMethod()
628 .addGetDimensionsMethodForBuffer()
629 .completeDefinition();
630 });
631
632 Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "ByteAddressBuffer")
633 .finalizeForwardDeclaration();
634 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
635 setupBufferType(Decl, S&: *SemaPtr, RC: ResourceClass::SRV, /*IsROV=*/false,
636 /*RawBuffer=*/true, /*HasCounter=*/false)
637 .addByteAddressBufferLoadMethods()
638 .addGetDimensionsMethodForBuffer()
639 .completeDefinition();
640 });
641 Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RWByteAddressBuffer")
642 .finalizeForwardDeclaration();
643 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
644 setupBufferType(Decl, S&: *SemaPtr, RC: ResourceClass::UAV, /*IsROV=*/false,
645 /*RawBuffer=*/true, /*HasCounter=*/false)
646 .addByteAddressBufferLoadMethods()
647 .addByteAddressBufferStoreMethods()
648 .addGetDimensionsMethodForBuffer()
649 .completeDefinition();
650 });
651 Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace,
652 "RasterizerOrderedByteAddressBuffer")
653 .finalizeForwardDeclaration();
654 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
655 setupBufferType(Decl, S&: *SemaPtr, RC: ResourceClass::UAV, /*IsROV=*/true,
656 /*RawBuffer=*/true, /*HasCounter=*/false)
657 .addGetDimensionsMethodForBuffer()
658 .completeDefinition();
659 });
660
661 Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "SamplerState")
662 .finalizeForwardDeclaration();
663 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
664 setupSamplerType(Decl, S&: *SemaPtr).completeDefinition();
665 });
666
667 Decl =
668 BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "SamplerComparisonState")
669 .finalizeForwardDeclaration();
670 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
671 setupSamplerType(Decl, S&: *SemaPtr).completeDefinition();
672 });
673
674 QualType Float4Ty = AST.getExtVectorType(VectorType: AST.FloatTy, NumElts: 4);
675 Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "Texture2D")
676 .addSimpleTemplateParams(Names: {"element_type"}, DefaultTypes: {Float4Ty},
677 CD: TypedBufferConcept)
678 .finalizeForwardDeclaration();
679
680 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
681 setupTextureType(Decl, S&: *SemaPtr, RC: ResourceClass::SRV, /*IsROV=*/false,
682 /*IsArray=*/false, Dim: ResourceDimension::Dim2D)
683 .completeDefinition();
684 });
685
686 auto *PartialSpec = addVectorTexturePartialSpecialization(
687 S&: *SemaPtr, HLSLNamespace, TextureTemplate: Decl->getDescribedClassTemplate());
688 onCompletion(Record: PartialSpec, Fn: [this](CXXRecordDecl *Decl) {
689 setupTextureType(Decl, S&: *SemaPtr, RC: ResourceClass::SRV, /*IsROV=*/false,
690 /*IsArray=*/false, Dim: ResourceDimension::Dim2D)
691 .completeDefinition();
692 });
693
694 // Texture2DArray — same as Texture2D but IsArray=true
695 Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "Texture2DArray")
696 .addSimpleTemplateParams(Names: {"element_type"}, DefaultTypes: {Float4Ty},
697 CD: TypedBufferConcept)
698 .finalizeForwardDeclaration();
699
700 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
701 setupTextureType(Decl, S&: *SemaPtr, RC: ResourceClass::SRV, /*IsROV=*/false,
702 /*IsArray=*/true, Dim: ResourceDimension::Dim2D)
703 .completeDefinition();
704 });
705
706 auto *PartialSpec2DA = addVectorTexturePartialSpecialization(
707 S&: *SemaPtr, HLSLNamespace, TextureTemplate: Decl->getDescribedClassTemplate());
708 onCompletion(Record: PartialSpec2DA, Fn: [this](CXXRecordDecl *Decl) {
709 setupTextureType(Decl, S&: *SemaPtr, RC: ResourceClass::SRV, /*IsROV=*/false,
710 /*IsArray=*/true, Dim: ResourceDimension::Dim2D)
711 .completeDefinition();
712 });
713}
714
715// Build a single overload of an HLSL atomic intrinsic in the hlsl namespace.
716// `dest` is an address-space-qualified reference; `original_value` (when
717// present) is a plain reference. The synthesized FunctionDecl aliases the
718// underlying clang builtin via BuiltinAliasAttr.
719static void buildAtomicOverload(Sema &S, NamespaceDecl *NS, StringRef FuncName,
720 StringRef BuiltinName, QualType ElemTy,
721 LangAS DestAS, bool ThreeArg) {
722 ASTContext &AST = S.getASTContext();
723
724 QualType DestTy =
725 AST.getLValueReferenceType(T: AST.getAddrSpaceQualType(T: ElemTy, AddressSpace: DestAS));
726 QualType OrigRefTy = AST.getLValueReferenceType(T: ElemTy);
727
728 SmallVector<QualType, 3> ParamTypes;
729 ParamTypes.push_back(Elt: DestTy);
730 ParamTypes.push_back(Elt: ElemTy);
731 if (ThreeArg)
732 ParamTypes.push_back(Elt: OrigRefTy);
733
734 FunctionProtoType::ExtProtoInfo EPI;
735 QualType FuncTy = AST.getFunctionType(ResultTy: AST.VoidTy, Args: ParamTypes, EPI);
736 auto *TSInfo = AST.getTrivialTypeSourceInfo(T: FuncTy, Loc: SourceLocation());
737
738 IdentifierInfo &FuncII = AST.Idents.get(Name: FuncName, TokenCode: tok::TokenKind::identifier);
739 DeclarationName FuncDeclName(&FuncII);
740
741 FunctionDecl *FD = FunctionDecl::Create(
742 C&: AST, DC: NS, StartLoc: SourceLocation(), NLoc: SourceLocation(), N: FuncDeclName, T: FuncTy, TInfo: TSInfo,
743 SC: SC_Extern, /*UsesFPIntrin=*/false, /*isInlineSpecified=*/false,
744 /*hasWrittenPrototype=*/true);
745
746 constexpr const char *ParamNames[] = {"dest", "value", "original_value"};
747 SmallVector<ParmVarDecl *, 3> ParmDecls;
748 unsigned I = 0;
749 for (auto [ParamType, ParamName] : llvm::zip(t&: ParamTypes, u: ParamNames)) {
750 IdentifierInfo &PII = AST.Idents.get(Name: ParamName, TokenCode: tok::TokenKind::identifier);
751 ParmVarDecl *Parm = ParmVarDecl::Create(
752 C&: AST, DC: FD, StartLoc: SourceLocation(), IdLoc: SourceLocation(), Id: &PII, T: ParamType,
753 TInfo: AST.getTrivialTypeSourceInfo(T: ParamType, Loc: SourceLocation()), S: SC_None,
754 DefArg: nullptr);
755 Parm->setScopeInfo(scopeDepth: 0, parameterIndex: I++);
756 ParmDecls.push_back(Elt: Parm);
757 }
758 FD->setParams(ParmDecls);
759
760 IdentifierInfo &BuiltinII =
761 S.getPreprocessor().getIdentifierTable().get(Name: BuiltinName);
762 FD->addAttr(A: BuiltinAliasAttr::CreateImplicit(Ctx&: AST, BuiltinName: &BuiltinII));
763 FD->setImplicit();
764 NS->addDecl(D: FD);
765}
766
767// Synthesize the InterlockedAdd overload set: {int, uint, int64_t, uint64_t}
768// x {groupshared, device} x {2-arg, 3-arg}.
769static void defineHLSLInterlockedAdd(Sema &S, NamespaceDecl *NS) {
770 ASTContext &AST = S.getASTContext();
771 // HLSL: int64_t == long, uint64_t == unsigned long (see hlsl_basic_types.h).
772 QualType Elems[] = {AST.IntTy, AST.UnsignedIntTy, AST.LongTy,
773 AST.UnsignedLongTy};
774 LangAS AddrSpaces[] = {LangAS::hlsl_groupshared, LangAS::hlsl_device};
775
776 for (QualType ElemTy : Elems)
777 for (LangAS AS : AddrSpaces)
778 for (bool ThreeArg : {false, true})
779 buildAtomicOverload(S, NS, FuncName: "InterlockedAdd",
780 BuiltinName: "__builtin_hlsl_interlocked_add", ElemTy, DestAS: AS,
781 ThreeArg);
782}
783
784void HLSLExternalSemaSource::defineHLSLAtomicIntrinsics() {
785 defineHLSLInterlockedAdd(S&: *SemaPtr, NS: HLSLNamespace);
786}
787
788void HLSLExternalSemaSource::onCompletion(CXXRecordDecl *Record,
789 CompletionFunction Fn) {
790 if (!Record->isCompleteDefinition())
791 Completions.insert(KV: std::make_pair(x: Record->getCanonicalDecl(), y&: Fn));
792}
793
794void HLSLExternalSemaSource::CompleteType(TagDecl *Tag) {
795 if (!isa<CXXRecordDecl>(Val: Tag))
796 return;
797 auto Record = cast<CXXRecordDecl>(Val: Tag);
798
799 // If this is a specialization, we need to get the underlying templated
800 // declaration and complete that.
801 if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(Val: Record)) {
802 if (!isa<ClassTemplatePartialSpecializationDecl>(Val: TDecl)) {
803 ClassTemplateDecl *Template = TDecl->getSpecializedTemplate();
804 llvm::SmallVector<ClassTemplatePartialSpecializationDecl *, 4> Partials;
805 Template->getPartialSpecializations(PS&: Partials);
806 ClassTemplatePartialSpecializationDecl *MatchedPartial = nullptr;
807 for (auto *Partial : Partials) {
808 sema::TemplateDeductionInfo Info(TDecl->getLocation());
809 if (SemaPtr->DeduceTemplateArguments(Partial, TemplateArgs: TDecl->getTemplateArgs(),
810 Info) ==
811 TemplateDeductionResult::Success) {
812 MatchedPartial = Partial;
813 break;
814 }
815 }
816 if (MatchedPartial)
817 Record = MatchedPartial;
818 else
819 Record = Template->getTemplatedDecl();
820 }
821 }
822 Record = Record->getCanonicalDecl();
823 auto It = Completions.find(Val: Record);
824 if (It == Completions.end())
825 return;
826 // Move out the callback and erase before invoking it: the callback can
827 // re-enter CompleteType and mutate Completions, which invalidates It under
828 // backward-shift deletion.
829 CompletionFunction Fn = std::move(It->second);
830 Completions.erase(I: It);
831 Fn(Record);
832}
833