1//===--- HLSLExternalSemaSource.cpp - HLSL Sema Source --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9//
10//===----------------------------------------------------------------------===//
11
12#include "clang/Sema/HLSLExternalSemaSource.h"
13#include "HLSLBuiltinTypeDeclBuilder.h"
14#include "clang/AST/ASTContext.h"
15#include "clang/AST/Attr.h"
16#include "clang/AST/Decl.h"
17#include "clang/AST/DeclCXX.h"
18#include "clang/AST/DeclTemplate.h"
19#include "clang/AST/Expr.h"
20#include "clang/AST/Type.h"
21#include "clang/Basic/SourceLocation.h"
22#include "clang/Sema/Lookup.h"
23#include "clang/Sema/Sema.h"
24#include "clang/Sema/SemaHLSL.h"
25#include "clang/Sema/TemplateDeduction.h"
26#include "llvm/ADT/SmallVector.h"
27
28using namespace clang;
29using namespace llvm::hlsl;
30
31using clang::hlsl::BuiltinTypeDeclBuilder;
32
33void HLSLExternalSemaSource::InitializeSema(Sema &S) {
34 SemaPtr = &S;
35 ASTContext &AST = SemaPtr->getASTContext();
36 // If the translation unit has external storage force external decls to load.
37 if (AST.getTranslationUnitDecl()->hasExternalLexicalStorage())
38 (void)AST.getTranslationUnitDecl()->decls_begin();
39
40 IdentifierInfo &HLSL = AST.Idents.get(Name: "hlsl", TokenCode: tok::TokenKind::identifier);
41 LookupResult Result(S, &HLSL, SourceLocation(), Sema::LookupNamespaceName);
42 NamespaceDecl *PrevDecl = nullptr;
43 if (S.LookupQualifiedName(R&: Result, LookupCtx: AST.getTranslationUnitDecl()))
44 PrevDecl = Result.getAsSingle<NamespaceDecl>();
45 HLSLNamespace = NamespaceDecl::Create(
46 C&: AST, DC: AST.getTranslationUnitDecl(), /*Inline=*/false, StartLoc: SourceLocation(),
47 IdLoc: SourceLocation(), Id: &HLSL, PrevDecl, /*Nested=*/false);
48 HLSLNamespace->setImplicit(true);
49 HLSLNamespace->setHasExternalLexicalStorage();
50 AST.getTranslationUnitDecl()->addDecl(D: HLSLNamespace);
51
52 // Force external decls in the HLSL namespace to load from the PCH.
53 (void)HLSLNamespace->getCanonicalDecl()->decls_begin();
54 defineTrivialHLSLTypes();
55 defineHLSLTypesWithForwardDeclarations();
56
57 // This adds a `using namespace hlsl` directive. In DXC, we don't put HLSL's
58 // built in types inside a namespace, but we are planning to change that in
59 // the near future. In order to be source compatible older versions of HLSL
60 // will need to implicitly use the hlsl namespace. For now in clang everything
61 // will get added to the namespace, and we can remove the using directive for
62 // future language versions to match HLSL's evolution.
63 auto *UsingDecl = UsingDirectiveDecl::Create(
64 C&: AST, DC: AST.getTranslationUnitDecl(), UsingLoc: SourceLocation(), NamespaceLoc: SourceLocation(),
65 QualifierLoc: NestedNameSpecifierLoc(), IdentLoc: SourceLocation(), Nominated: HLSLNamespace,
66 CommonAncestor: AST.getTranslationUnitDecl());
67
68 AST.getTranslationUnitDecl()->addDecl(D: UsingDecl);
69}
70
71void HLSLExternalSemaSource::defineHLSLVectorAlias() {
72 ASTContext &AST = SemaPtr->getASTContext();
73
74 llvm::SmallVector<NamedDecl *> TemplateParams;
75
76 auto *TypeParam = TemplateTypeParmDecl::Create(
77 C: AST, DC: HLSLNamespace, KeyLoc: SourceLocation(), NameLoc: SourceLocation(), D: 0, P: 0,
78 Id: &AST.Idents.get(Name: "element", TokenCode: tok::TokenKind::identifier), Typename: false, ParameterPack: false);
79 TypeParam->setDefaultArgument(
80 C: AST, DefArg: SemaPtr->getTrivialTemplateArgumentLoc(
81 Arg: TemplateArgument(AST.FloatTy), NTTPType: QualType(), Loc: SourceLocation()));
82
83 TemplateParams.emplace_back(Args&: TypeParam);
84
85 auto *SizeParam = NonTypeTemplateParmDecl::Create(
86 C: AST, DC: HLSLNamespace, StartLoc: SourceLocation(), IdLoc: SourceLocation(), D: 0, P: 1,
87 Id: &AST.Idents.get(Name: "element_count", TokenCode: tok::TokenKind::identifier), T: AST.IntTy,
88 ParameterPack: false, TInfo: AST.getTrivialTypeSourceInfo(T: AST.IntTy));
89 llvm::APInt Val(AST.getIntWidth(T: AST.IntTy), 4);
90 TemplateArgument Default(AST, llvm::APSInt(std::move(Val)), AST.IntTy,
91 /*IsDefaulted=*/true);
92 SizeParam->setDefaultArgument(
93 C: AST, DefArg: SemaPtr->getTrivialTemplateArgumentLoc(Arg: Default, NTTPType: AST.IntTy,
94 Loc: SourceLocation(), TemplateParam: SizeParam));
95 TemplateParams.emplace_back(Args&: SizeParam);
96
97 auto *ParamList =
98 TemplateParameterList::Create(C: AST, TemplateLoc: SourceLocation(), LAngleLoc: SourceLocation(),
99 Params: TemplateParams, RAngleLoc: SourceLocation(), RequiresClause: nullptr);
100
101 IdentifierInfo &II = AST.Idents.get(Name: "vector", TokenCode: tok::TokenKind::identifier);
102
103 QualType AliasType = AST.getDependentSizedExtVectorType(
104 VectorType: AST.getTemplateTypeParmType(Depth: 0, Index: 0, ParameterPack: false, ParmDecl: TypeParam),
105 SizeExpr: DeclRefExpr::Create(
106 Context: AST, QualifierLoc: NestedNameSpecifierLoc(), TemplateKWLoc: SourceLocation(), D: SizeParam, RefersToEnclosingVariableOrCapture: false,
107 NameInfo: DeclarationNameInfo(SizeParam->getDeclName(), SourceLocation()),
108 T: AST.IntTy, VK: VK_LValue),
109 AttrLoc: SourceLocation());
110
111 auto *Record = TypeAliasDecl::Create(C&: AST, DC: HLSLNamespace, StartLoc: SourceLocation(),
112 IdLoc: SourceLocation(), Id: &II,
113 TInfo: AST.getTrivialTypeSourceInfo(T: AliasType));
114 Record->setImplicit(true);
115
116 auto *Template =
117 TypeAliasTemplateDecl::Create(C&: AST, DC: HLSLNamespace, L: SourceLocation(),
118 Name: Record->getIdentifier(), Params: ParamList, Decl: Record);
119
120 Record->setDescribedAliasTemplate(Template);
121 Template->setImplicit(true);
122 Template->setLexicalDeclContext(Record->getDeclContext());
123 HLSLNamespace->addDecl(D: Template);
124}
125
126void HLSLExternalSemaSource::defineHLSLMatrixAlias() {
127 ASTContext &AST = SemaPtr->getASTContext();
128 llvm::SmallVector<NamedDecl *> TemplateParams;
129
130 auto *TypeParam = TemplateTypeParmDecl::Create(
131 C: AST, DC: HLSLNamespace, KeyLoc: SourceLocation(), NameLoc: SourceLocation(), D: 0, P: 0,
132 Id: &AST.Idents.get(Name: "element", TokenCode: tok::TokenKind::identifier), Typename: false, ParameterPack: false);
133 TypeParam->setDefaultArgument(
134 C: AST, DefArg: SemaPtr->getTrivialTemplateArgumentLoc(
135 Arg: TemplateArgument(AST.FloatTy), NTTPType: QualType(), Loc: SourceLocation()));
136
137 TemplateParams.emplace_back(Args&: TypeParam);
138
139 // these should be 64 bit to be consistent with other clang matrices.
140 auto *RowsParam = NonTypeTemplateParmDecl::Create(
141 C: AST, DC: HLSLNamespace, StartLoc: SourceLocation(), IdLoc: SourceLocation(), D: 0, P: 1,
142 Id: &AST.Idents.get(Name: "rows_count", TokenCode: tok::TokenKind::identifier), T: AST.IntTy,
143 ParameterPack: false, TInfo: AST.getTrivialTypeSourceInfo(T: AST.IntTy));
144 llvm::APInt RVal(AST.getIntWidth(T: AST.IntTy), 4);
145 TemplateArgument RDefault(AST, llvm::APSInt(std::move(RVal)), AST.IntTy,
146 /*IsDefaulted=*/true);
147 RowsParam->setDefaultArgument(
148 C: AST, DefArg: SemaPtr->getTrivialTemplateArgumentLoc(Arg: RDefault, NTTPType: AST.IntTy,
149 Loc: SourceLocation(), TemplateParam: RowsParam));
150 TemplateParams.emplace_back(Args&: RowsParam);
151
152 auto *ColsParam = NonTypeTemplateParmDecl::Create(
153 C: AST, DC: HLSLNamespace, StartLoc: SourceLocation(), IdLoc: SourceLocation(), D: 0, P: 2,
154 Id: &AST.Idents.get(Name: "cols_count", TokenCode: tok::TokenKind::identifier), T: AST.IntTy,
155 ParameterPack: false, TInfo: AST.getTrivialTypeSourceInfo(T: AST.IntTy));
156 llvm::APInt CVal(AST.getIntWidth(T: AST.IntTy), 4);
157 TemplateArgument CDefault(AST, llvm::APSInt(std::move(CVal)), AST.IntTy,
158 /*IsDefaulted=*/true);
159 ColsParam->setDefaultArgument(
160 C: AST, DefArg: SemaPtr->getTrivialTemplateArgumentLoc(Arg: CDefault, NTTPType: AST.IntTy,
161 Loc: SourceLocation(), TemplateParam: ColsParam));
162 TemplateParams.emplace_back(Args&: ColsParam);
163
164 const unsigned MaxMatDim = SemaPtr->getLangOpts().MaxMatrixDimension;
165
166 auto *MaxRow = IntegerLiteral::Create(
167 C: AST, V: llvm::APInt(AST.getIntWidth(T: AST.IntTy), MaxMatDim), type: AST.IntTy,
168 l: SourceLocation());
169 auto *MaxCol = IntegerLiteral::Create(
170 C: AST, V: llvm::APInt(AST.getIntWidth(T: AST.IntTy), MaxMatDim), type: AST.IntTy,
171 l: SourceLocation());
172
173 auto *RowsRef = DeclRefExpr::Create(
174 Context: AST, QualifierLoc: NestedNameSpecifierLoc(), TemplateKWLoc: SourceLocation(), D: RowsParam,
175 /*RefersToEnclosingVariableOrCapture*/ false,
176 NameInfo: DeclarationNameInfo(RowsParam->getDeclName(), SourceLocation()),
177 T: AST.IntTy, VK: VK_LValue);
178 auto *ColsRef = DeclRefExpr::Create(
179 Context: AST, QualifierLoc: NestedNameSpecifierLoc(), TemplateKWLoc: SourceLocation(), D: ColsParam,
180 /*RefersToEnclosingVariableOrCapture*/ false,
181 NameInfo: DeclarationNameInfo(ColsParam->getDeclName(), SourceLocation()),
182 T: AST.IntTy, VK: VK_LValue);
183
184 auto *RowsLE = BinaryOperator::Create(C: AST, lhs: RowsRef, rhs: MaxRow, opc: BO_LE, ResTy: AST.BoolTy,
185 VK: VK_PRValue, OK: OK_Ordinary,
186 opLoc: SourceLocation(), FPFeatures: FPOptionsOverride());
187 auto *ColsLE = BinaryOperator::Create(C: AST, lhs: ColsRef, rhs: MaxCol, opc: BO_LE, ResTy: AST.BoolTy,
188 VK: VK_PRValue, OK: OK_Ordinary,
189 opLoc: SourceLocation(), FPFeatures: FPOptionsOverride());
190
191 auto *RequiresExpr = BinaryOperator::Create(
192 C: AST, lhs: RowsLE, rhs: ColsLE, opc: BO_LAnd, ResTy: AST.BoolTy, VK: VK_PRValue, OK: OK_Ordinary,
193 opLoc: SourceLocation(), FPFeatures: FPOptionsOverride());
194
195 auto *ParamList = TemplateParameterList::Create(
196 C: AST, TemplateLoc: SourceLocation(), LAngleLoc: SourceLocation(), Params: TemplateParams, RAngleLoc: SourceLocation(),
197 RequiresClause: RequiresExpr);
198
199 IdentifierInfo &II = AST.Idents.get(Name: "matrix", TokenCode: tok::TokenKind::identifier);
200
201 QualType AliasType = AST.getDependentSizedMatrixType(
202 ElementType: AST.getTemplateTypeParmType(Depth: 0, Index: 0, ParameterPack: false, ParmDecl: TypeParam),
203 RowExpr: DeclRefExpr::Create(
204 Context: AST, QualifierLoc: NestedNameSpecifierLoc(), TemplateKWLoc: SourceLocation(), D: RowsParam, RefersToEnclosingVariableOrCapture: false,
205 NameInfo: DeclarationNameInfo(RowsParam->getDeclName(), SourceLocation()),
206 T: AST.IntTy, VK: VK_LValue),
207 ColumnExpr: DeclRefExpr::Create(
208 Context: AST, QualifierLoc: NestedNameSpecifierLoc(), TemplateKWLoc: SourceLocation(), D: ColsParam, RefersToEnclosingVariableOrCapture: false,
209 NameInfo: DeclarationNameInfo(ColsParam->getDeclName(), SourceLocation()),
210 T: AST.IntTy, VK: VK_LValue),
211 AttrLoc: SourceLocation());
212
213 auto *Record = TypeAliasDecl::Create(C&: AST, DC: HLSLNamespace, StartLoc: SourceLocation(),
214 IdLoc: SourceLocation(), Id: &II,
215 TInfo: AST.getTrivialTypeSourceInfo(T: AliasType));
216 Record->setImplicit(true);
217
218 auto *Template =
219 TypeAliasTemplateDecl::Create(C&: AST, DC: HLSLNamespace, L: SourceLocation(),
220 Name: Record->getIdentifier(), Params: ParamList, Decl: Record);
221
222 Record->setDescribedAliasTemplate(Template);
223 Template->setImplicit(true);
224 Template->setLexicalDeclContext(Record->getDeclContext());
225 HLSLNamespace->addDecl(D: Template);
226}
227
228void HLSLExternalSemaSource::defineTrivialHLSLTypes() {
229 defineHLSLVectorAlias();
230 defineHLSLMatrixAlias();
231}
232
233/// Set up common members and attributes for buffer types
234static BuiltinTypeDeclBuilder setupBufferType(CXXRecordDecl *Decl, Sema &S,
235 ResourceClass RC, bool IsROV,
236 bool RawBuffer, bool HasCounter) {
237 return BuiltinTypeDeclBuilder(S, Decl)
238 .addBufferHandles(RC, IsROV, RawBuffer, HasCounter)
239 .addDefaultHandleConstructor()
240 .addCopyConstructor()
241 .addCopyAssignmentOperator()
242 .addStaticInitializationFunctions(HasCounter);
243}
244
245/// Set up common members and attributes for sampler types
246static BuiltinTypeDeclBuilder setupSamplerType(CXXRecordDecl *Decl, Sema &S) {
247 return BuiltinTypeDeclBuilder(S, Decl)
248 .addSamplerHandle()
249 .addDefaultHandleConstructor()
250 .addCopyConstructor()
251 .addCopyAssignmentOperator()
252 .addStaticInitializationFunctions(HasCounter: false);
253}
254
255/// Set up common members and attributes for texture types
256static BuiltinTypeDeclBuilder setupTextureType(CXXRecordDecl *Decl, Sema &S,
257 ResourceClass RC, bool IsROV,
258 ResourceDimension Dim) {
259 return BuiltinTypeDeclBuilder(S, Decl)
260 .addTextureHandle(RC, IsROV, RD: Dim)
261 .addTextureLoadMethods(Dim)
262 .addArraySubscriptOperators(Dim)
263 .addMipsMember(Dim)
264 .addDefaultHandleConstructor()
265 .addCopyConstructor()
266 .addCopyAssignmentOperator()
267 .addStaticInitializationFunctions(HasCounter: false)
268 .addSampleMethods(Dim)
269 .addSampleBiasMethods(Dim)
270 .addSampleGradMethods(Dim)
271 .addSampleLevelMethods(Dim)
272 .addSampleCmpMethods(Dim)
273 .addSampleCmpLevelZeroMethods(Dim)
274 .addCalculateLodMethods(Dim)
275 .addGetDimensionsMethods(Dim)
276 .addGatherMethods(Dim)
277 .addGatherCmpMethods(Dim);
278}
279
280// Add a partial specialization for a template. The `TextureTemplate` is
281// `Texture<element_type>`, and it will be specialized for vectors:
282// `Texture<vector<element_type, element_count>>`.
283static ClassTemplatePartialSpecializationDecl *
284addVectorTexturePartialSpecialization(Sema &S, NamespaceDecl *HLSLNamespace,
285 ClassTemplateDecl *TextureTemplate) {
286 ASTContext &AST = S.getASTContext();
287
288 // Create the template parameters: element_type and element_count.
289 auto *ElementType = TemplateTypeParmDecl::Create(
290 C: AST, DC: HLSLNamespace, KeyLoc: SourceLocation(), NameLoc: SourceLocation(), D: 0, P: 0,
291 Id: &AST.Idents.get(Name: "element_type"), Typename: false, ParameterPack: false);
292 auto *ElementCount = NonTypeTemplateParmDecl::Create(
293 C: AST, DC: HLSLNamespace, StartLoc: SourceLocation(), IdLoc: SourceLocation(), D: 0, P: 1,
294 Id: &AST.Idents.get(Name: "element_count"), T: AST.IntTy, ParameterPack: false,
295 TInfo: AST.getTrivialTypeSourceInfo(T: AST.IntTy));
296
297 auto *TemplateParams = TemplateParameterList::Create(
298 C: AST, TemplateLoc: SourceLocation(), LAngleLoc: SourceLocation(), Params: {ElementType, ElementCount},
299 RAngleLoc: SourceLocation(), RequiresClause: nullptr);
300
301 // Create the dependent vector type: vector<element_type, element_count>.
302 QualType VectorType = AST.getDependentSizedExtVectorType(
303 VectorType: AST.getTemplateTypeParmType(Depth: 0, Index: 0, ParameterPack: false, ParmDecl: ElementType),
304 SizeExpr: DeclRefExpr::Create(
305 Context: AST, QualifierLoc: NestedNameSpecifierLoc(), TemplateKWLoc: SourceLocation(), D: ElementCount, RefersToEnclosingVariableOrCapture: false,
306 NameInfo: DeclarationNameInfo(ElementCount->getDeclName(), SourceLocation()),
307 T: AST.IntTy, VK: VK_LValue),
308 AttrLoc: SourceLocation());
309
310 // Create the partial specialization declaration.
311 QualType CanonInjectedTST =
312 AST.getCanonicalType(T: AST.getTemplateSpecializationType(
313 Keyword: ElaboratedTypeKeyword::Class, T: TemplateName(TextureTemplate),
314 SpecifiedArgs: {TemplateArgument(VectorType)}, CanonicalArgs: {}));
315
316 auto *PartialSpec = ClassTemplatePartialSpecializationDecl::Create(
317 Context&: AST, TK: TagDecl::TagKind::Class, DC: HLSLNamespace, StartLoc: SourceLocation(),
318 IdLoc: SourceLocation(), Params: TemplateParams, SpecializedTemplate: TextureTemplate,
319 Args: {TemplateArgument(VectorType)},
320 CanonInjectedTST: CanQualType::CreateUnsafe(Other: CanonInjectedTST), PrevDecl: nullptr);
321
322 // Set the template arguments as written.
323 TemplateArgument Arg(VectorType);
324 TemplateArgumentLoc ArgLoc =
325 S.getTrivialTemplateArgumentLoc(Arg, NTTPType: QualType(), Loc: SourceLocation());
326 TemplateArgumentListInfo ArgsInfo =
327 TemplateArgumentListInfo(SourceLocation(), SourceLocation());
328 ArgsInfo.addArgument(Loc: ArgLoc);
329 PartialSpec->setTemplateArgsAsWritten(
330 ASTTemplateArgumentListInfo::Create(C: AST, List: ArgsInfo));
331
332 PartialSpec->setImplicit(true);
333 PartialSpec->setLexicalDeclContext(HLSLNamespace);
334 PartialSpec->setHasExternalLexicalStorage();
335
336 // Add the partial specialization to the namespace and the class template.
337 HLSLNamespace->addDecl(D: PartialSpec);
338 TextureTemplate->AddPartialSpecialization(D: PartialSpec, InsertPos: nullptr);
339
340 return PartialSpec;
341}
342
343// This function is responsible for constructing the constraint expression for
344// this concept:
345// template<typename T> concept is_typed_resource_element_compatible =
346// __is_typed_resource_element_compatible<T>;
347static Expr *constructTypedBufferConstraintExpr(Sema &S, SourceLocation NameLoc,
348 TemplateTypeParmDecl *T) {
349 ASTContext &Context = S.getASTContext();
350
351 // Obtain the QualType for 'bool'
352 QualType BoolTy = Context.BoolTy;
353
354 // Create a QualType that points to this TemplateTypeParmDecl
355 QualType TType = Context.getTypeDeclType(Decl: T);
356
357 // Create a TypeSourceInfo for the template type parameter 'T'
358 TypeSourceInfo *TTypeSourceInfo =
359 Context.getTrivialTypeSourceInfo(T: TType, Loc: NameLoc);
360
361 TypeTraitExpr *TypedResExpr = TypeTraitExpr::Create(
362 C: Context, T: BoolTy, Loc: NameLoc, Kind: UTT_IsTypedResourceElementCompatible,
363 Args: {TTypeSourceInfo}, RParenLoc: NameLoc, Value: true);
364
365 return TypedResExpr;
366}
367
368// This function is responsible for constructing the constraint expression for
369// this concept:
370// template<typename T> concept is_structured_resource_element_compatible =
371// !__is_intangible<T> && sizeof(T) >= 1;
372static Expr *constructStructuredBufferConstraintExpr(Sema &S,
373 SourceLocation NameLoc,
374 TemplateTypeParmDecl *T) {
375 ASTContext &Context = S.getASTContext();
376
377 // Obtain the QualType for 'bool'
378 QualType BoolTy = Context.BoolTy;
379
380 // Create a QualType that points to this TemplateTypeParmDecl
381 QualType TType = Context.getTypeDeclType(Decl: T);
382
383 // Create a TypeSourceInfo for the template type parameter 'T'
384 TypeSourceInfo *TTypeSourceInfo =
385 Context.getTrivialTypeSourceInfo(T: TType, Loc: NameLoc);
386
387 TypeTraitExpr *IsIntangibleExpr =
388 TypeTraitExpr::Create(C: Context, T: BoolTy, Loc: NameLoc, Kind: UTT_IsIntangibleType,
389 Args: {TTypeSourceInfo}, RParenLoc: NameLoc, Value: true);
390
391 // negate IsIntangibleExpr
392 UnaryOperator *NotIntangibleExpr = UnaryOperator::Create(
393 C: Context, input: IsIntangibleExpr, opc: UO_LNot, type: BoolTy, VK: VK_LValue, OK: OK_Ordinary,
394 l: NameLoc, CanOverflow: false, FPFeatures: FPOptionsOverride());
395
396 // element types also may not be of 0 size
397 UnaryExprOrTypeTraitExpr *SizeOfExpr = new (Context) UnaryExprOrTypeTraitExpr(
398 UETT_SizeOf, TTypeSourceInfo, BoolTy, NameLoc, NameLoc);
399
400 // Create a BinaryOperator that checks if the size of the type is not equal to
401 // 1 Empty structs have a size of 1 in HLSL, so we need to check for that
402 IntegerLiteral *rhs = IntegerLiteral::Create(
403 C: Context, V: llvm::APInt(Context.getTypeSize(T: Context.getSizeType()), 1, true),
404 type: Context.getSizeType(), l: NameLoc);
405
406 BinaryOperator *SizeGEQOneExpr =
407 BinaryOperator::Create(C: Context, lhs: SizeOfExpr, rhs, opc: BO_GE, ResTy: BoolTy, VK: VK_LValue,
408 OK: OK_Ordinary, opLoc: NameLoc, FPFeatures: FPOptionsOverride());
409
410 // Combine the two constraints
411 BinaryOperator *CombinedExpr = BinaryOperator::Create(
412 C: Context, lhs: NotIntangibleExpr, rhs: SizeGEQOneExpr, opc: BO_LAnd, ResTy: BoolTy, VK: VK_LValue,
413 OK: OK_Ordinary, opLoc: NameLoc, FPFeatures: FPOptionsOverride());
414
415 return CombinedExpr;
416}
417
418static ConceptDecl *constructBufferConceptDecl(Sema &S, NamespaceDecl *NSD,
419 bool isTypedBuffer) {
420 ASTContext &Context = S.getASTContext();
421 DeclContext *DC = NSD->getDeclContext();
422 SourceLocation DeclLoc = SourceLocation();
423
424 IdentifierInfo &ElementTypeII = Context.Idents.get(Name: "element_type");
425 TemplateTypeParmDecl *T = TemplateTypeParmDecl::Create(
426 C: Context, DC: NSD->getDeclContext(), KeyLoc: DeclLoc, NameLoc: DeclLoc,
427 /*D=*/0,
428 /*P=*/0,
429 /*Id=*/&ElementTypeII,
430 /*Typename=*/true,
431 /*ParameterPack=*/false);
432
433 T->setDeclContext(DC);
434 T->setReferenced();
435
436 // Create and Attach Template Parameter List to ConceptDecl
437 TemplateParameterList *ConceptParams = TemplateParameterList::Create(
438 C: Context, TemplateLoc: DeclLoc, LAngleLoc: DeclLoc, Params: {T}, RAngleLoc: DeclLoc, RequiresClause: nullptr);
439
440 DeclarationName DeclName;
441 Expr *ConstraintExpr = nullptr;
442
443 if (isTypedBuffer) {
444 DeclName = DeclarationName(
445 &Context.Idents.get(Name: "__is_typed_resource_element_compatible"));
446 ConstraintExpr = constructTypedBufferConstraintExpr(S, NameLoc: DeclLoc, T);
447 } else {
448 DeclName = DeclarationName(
449 &Context.Idents.get(Name: "__is_structured_resource_element_compatible"));
450 ConstraintExpr = constructStructuredBufferConstraintExpr(S, NameLoc: DeclLoc, T);
451 }
452
453 // Create a ConceptDecl
454 ConceptDecl *CD =
455 ConceptDecl::Create(C&: Context, DC: NSD->getDeclContext(), L: DeclLoc, Name: DeclName,
456 Params: ConceptParams, ConstraintExpr);
457
458 // Attach the template parameter list to the ConceptDecl
459 CD->setTemplateParameters(ConceptParams);
460
461 // Add the concept declaration to the Translation Unit Decl
462 NSD->getDeclContext()->addDecl(D: CD);
463
464 return CD;
465}
466
467void HLSLExternalSemaSource::defineHLSLTypesWithForwardDeclarations() {
468 ASTContext &AST = SemaPtr->getASTContext();
469 CXXRecordDecl *Decl;
470 ConceptDecl *TypedBufferConcept = constructBufferConceptDecl(
471 S&: *SemaPtr, NSD: HLSLNamespace, /*isTypedBuffer*/ true);
472 ConceptDecl *StructuredBufferConcept = constructBufferConceptDecl(
473 S&: *SemaPtr, NSD: HLSLNamespace, /*isTypedBuffer*/ false);
474
475 Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "Buffer")
476 .addSimpleTemplateParams(Names: {"element_type"}, CD: TypedBufferConcept)
477 .finalizeForwardDeclaration();
478
479 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
480 setupBufferType(Decl, S&: *SemaPtr, RC: ResourceClass::SRV, /*IsROV=*/false,
481 /*RawBuffer=*/false, /*HasCounter=*/false)
482 .addArraySubscriptOperators()
483 .addLoadMethods()
484 .addGetDimensionsMethodForBuffer()
485 .completeDefinition();
486 });
487
488 Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RWBuffer")
489 .addSimpleTemplateParams(Names: {"element_type"}, CD: TypedBufferConcept)
490 .finalizeForwardDeclaration();
491
492 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
493 setupBufferType(Decl, S&: *SemaPtr, RC: ResourceClass::UAV, /*IsROV=*/false,
494 /*RawBuffer=*/false, /*HasCounter=*/false)
495 .addArraySubscriptOperators()
496 .addLoadMethods()
497 .addGetDimensionsMethodForBuffer()
498 .completeDefinition();
499 });
500
501 Decl =
502 BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RasterizerOrderedBuffer")
503 .addSimpleTemplateParams(Names: {"element_type"}, CD: StructuredBufferConcept)
504 .finalizeForwardDeclaration();
505 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
506 setupBufferType(Decl, S&: *SemaPtr, RC: ResourceClass::UAV, /*IsROV=*/true,
507 /*RawBuffer=*/false, /*HasCounter=*/false)
508 .addArraySubscriptOperators()
509 .addLoadMethods()
510 .addGetDimensionsMethodForBuffer()
511 .completeDefinition();
512 });
513
514 Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "StructuredBuffer")
515 .addSimpleTemplateParams(Names: {"element_type"}, CD: StructuredBufferConcept)
516 .finalizeForwardDeclaration();
517 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
518 setupBufferType(Decl, S&: *SemaPtr, RC: ResourceClass::SRV, /*IsROV=*/false,
519 /*RawBuffer=*/true, /*HasCounter=*/false)
520 .addArraySubscriptOperators()
521 .addLoadMethods()
522 .addGetDimensionsMethodForBuffer()
523 .completeDefinition();
524 });
525
526 Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RWStructuredBuffer")
527 .addSimpleTemplateParams(Names: {"element_type"}, CD: StructuredBufferConcept)
528 .finalizeForwardDeclaration();
529 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
530 setupBufferType(Decl, S&: *SemaPtr, RC: ResourceClass::UAV, /*IsROV=*/false,
531 /*RawBuffer=*/true, /*HasCounter=*/true)
532 .addArraySubscriptOperators()
533 .addLoadMethods()
534 .addIncrementCounterMethod()
535 .addDecrementCounterMethod()
536 .addGetDimensionsMethodForBuffer()
537 .completeDefinition();
538 });
539
540 Decl =
541 BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "AppendStructuredBuffer")
542 .addSimpleTemplateParams(Names: {"element_type"}, CD: StructuredBufferConcept)
543 .finalizeForwardDeclaration();
544 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
545 setupBufferType(Decl, S&: *SemaPtr, RC: ResourceClass::UAV, /*IsROV=*/false,
546 /*RawBuffer=*/true, /*HasCounter=*/true)
547 .addAppendMethod()
548 .addGetDimensionsMethodForBuffer()
549 .completeDefinition();
550 });
551
552 Decl =
553 BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "ConsumeStructuredBuffer")
554 .addSimpleTemplateParams(Names: {"element_type"}, CD: StructuredBufferConcept)
555 .finalizeForwardDeclaration();
556 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
557 setupBufferType(Decl, S&: *SemaPtr, RC: ResourceClass::UAV, /*IsROV=*/false,
558 /*RawBuffer=*/true, /*HasCounter=*/true)
559 .addConsumeMethod()
560 .addGetDimensionsMethodForBuffer()
561 .completeDefinition();
562 });
563
564 Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace,
565 "RasterizerOrderedStructuredBuffer")
566 .addSimpleTemplateParams(Names: {"element_type"}, CD: StructuredBufferConcept)
567 .finalizeForwardDeclaration();
568 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
569 setupBufferType(Decl, S&: *SemaPtr, RC: ResourceClass::UAV, /*IsROV=*/true,
570 /*RawBuffer=*/true, /*HasCounter=*/true)
571 .addArraySubscriptOperators()
572 .addLoadMethods()
573 .addIncrementCounterMethod()
574 .addDecrementCounterMethod()
575 .addGetDimensionsMethodForBuffer()
576 .completeDefinition();
577 });
578
579 Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "ByteAddressBuffer")
580 .finalizeForwardDeclaration();
581 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
582 setupBufferType(Decl, S&: *SemaPtr, RC: ResourceClass::SRV, /*IsROV=*/false,
583 /*RawBuffer=*/true, /*HasCounter=*/false)
584 .addByteAddressBufferLoadMethods()
585 .addGetDimensionsMethodForBuffer()
586 .completeDefinition();
587 });
588 Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RWByteAddressBuffer")
589 .finalizeForwardDeclaration();
590 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
591 setupBufferType(Decl, S&: *SemaPtr, RC: ResourceClass::UAV, /*IsROV=*/false,
592 /*RawBuffer=*/true, /*HasCounter=*/false)
593 .addByteAddressBufferLoadMethods()
594 .addByteAddressBufferStoreMethods()
595 .addGetDimensionsMethodForBuffer()
596 .completeDefinition();
597 });
598 Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace,
599 "RasterizerOrderedByteAddressBuffer")
600 .finalizeForwardDeclaration();
601 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
602 setupBufferType(Decl, S&: *SemaPtr, RC: ResourceClass::UAV, /*IsROV=*/true,
603 /*RawBuffer=*/true, /*HasCounter=*/false)
604 .addGetDimensionsMethodForBuffer()
605 .completeDefinition();
606 });
607
608 Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "SamplerState")
609 .finalizeForwardDeclaration();
610 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
611 setupSamplerType(Decl, S&: *SemaPtr).completeDefinition();
612 });
613
614 Decl =
615 BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "SamplerComparisonState")
616 .finalizeForwardDeclaration();
617 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
618 setupSamplerType(Decl, S&: *SemaPtr).completeDefinition();
619 });
620
621 QualType Float4Ty = AST.getExtVectorType(VectorType: AST.FloatTy, NumElts: 4);
622 Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "Texture2D")
623 .addSimpleTemplateParams(Names: {"element_type"}, DefaultTypes: {Float4Ty},
624 CD: TypedBufferConcept)
625 .finalizeForwardDeclaration();
626
627 onCompletion(Record: Decl, Fn: [this](CXXRecordDecl *Decl) {
628 setupTextureType(Decl, S&: *SemaPtr, RC: ResourceClass::SRV, /*IsROV=*/false,
629 Dim: ResourceDimension::Dim2D)
630 .completeDefinition();
631 });
632
633 auto *PartialSpec = addVectorTexturePartialSpecialization(
634 S&: *SemaPtr, HLSLNamespace, TextureTemplate: Decl->getDescribedClassTemplate());
635 onCompletion(Record: PartialSpec, Fn: [this](CXXRecordDecl *Decl) {
636 setupTextureType(Decl, S&: *SemaPtr, RC: ResourceClass::SRV, /*IsROV=*/false,
637 Dim: ResourceDimension::Dim2D)
638 .completeDefinition();
639 });
640}
641
642void HLSLExternalSemaSource::onCompletion(CXXRecordDecl *Record,
643 CompletionFunction Fn) {
644 if (!Record->isCompleteDefinition())
645 Completions.insert(KV: std::make_pair(x: Record->getCanonicalDecl(), y&: Fn));
646}
647
648void HLSLExternalSemaSource::CompleteType(TagDecl *Tag) {
649 if (!isa<CXXRecordDecl>(Val: Tag))
650 return;
651 auto Record = cast<CXXRecordDecl>(Val: Tag);
652
653 // If this is a specialization, we need to get the underlying templated
654 // declaration and complete that.
655 if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(Val: Record)) {
656 if (!isa<ClassTemplatePartialSpecializationDecl>(Val: TDecl)) {
657 ClassTemplateDecl *Template = TDecl->getSpecializedTemplate();
658 llvm::SmallVector<ClassTemplatePartialSpecializationDecl *, 4> Partials;
659 Template->getPartialSpecializations(PS&: Partials);
660 ClassTemplatePartialSpecializationDecl *MatchedPartial = nullptr;
661 for (auto *Partial : Partials) {
662 sema::TemplateDeductionInfo Info(TDecl->getLocation());
663 if (SemaPtr->DeduceTemplateArguments(Partial, TemplateArgs: TDecl->getTemplateArgs(),
664 Info) ==
665 TemplateDeductionResult::Success) {
666 MatchedPartial = Partial;
667 break;
668 }
669 }
670 if (MatchedPartial)
671 Record = MatchedPartial;
672 else
673 Record = Template->getTemplatedDecl();
674 }
675 }
676 Record = Record->getCanonicalDecl();
677 auto It = Completions.find(Val: Record);
678 if (It == Completions.end())
679 return;
680 It->second(Record);
681 Completions.erase(I: It);
682}
683