1 | //===- SemaHLSL.cpp - Semantic Analysis for HLSL constructs ---------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // This implements Semantic Analysis for HLSL constructs. |
9 | //===----------------------------------------------------------------------===// |
10 | |
11 | #include "clang/Sema/SemaHLSL.h" |
12 | #include "clang/AST/Decl.h" |
13 | #include "clang/AST/Expr.h" |
14 | #include "clang/AST/RecursiveASTVisitor.h" |
15 | #include "clang/Basic/DiagnosticSema.h" |
16 | #include "clang/Basic/LLVM.h" |
17 | #include "clang/Basic/TargetInfo.h" |
18 | #include "clang/Sema/ParsedAttr.h" |
19 | #include "clang/Sema/Sema.h" |
20 | #include "llvm/ADT/STLExtras.h" |
21 | #include "llvm/ADT/StringExtras.h" |
22 | #include "llvm/ADT/StringRef.h" |
23 | #include "llvm/Support/Casting.h" |
24 | #include "llvm/Support/ErrorHandling.h" |
25 | #include "llvm/TargetParser/Triple.h" |
26 | #include <iterator> |
27 | |
28 | using namespace clang; |
29 | |
30 | SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {} |
31 | |
32 | Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer, |
33 | SourceLocation KwLoc, IdentifierInfo *Ident, |
34 | SourceLocation IdentLoc, |
35 | SourceLocation LBrace) { |
36 | // For anonymous namespace, take the location of the left brace. |
37 | DeclContext *LexicalParent = SemaRef.getCurLexicalContext(); |
38 | HLSLBufferDecl *Result = HLSLBufferDecl::Create( |
39 | C&: getASTContext(), LexicalParent, CBuffer, KwLoc, ID: Ident, IDLoc: IdentLoc, LBrace); |
40 | |
41 | SemaRef.PushOnScopeChains(D: Result, S: BufferScope); |
42 | SemaRef.PushDeclContext(S: BufferScope, DC: Result); |
43 | |
44 | return Result; |
45 | } |
46 | |
47 | // Calculate the size of a legacy cbuffer type based on |
48 | // https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-packing-rules |
49 | static unsigned calculateLegacyCbufferSize(const ASTContext &Context, |
50 | QualType T) { |
51 | unsigned Size = 0; |
52 | constexpr unsigned CBufferAlign = 128; |
53 | if (const RecordType *RT = T->getAs<RecordType>()) { |
54 | const RecordDecl *RD = RT->getDecl(); |
55 | for (const FieldDecl *Field : RD->fields()) { |
56 | QualType Ty = Field->getType(); |
57 | unsigned FieldSize = calculateLegacyCbufferSize(Context, T: Ty); |
58 | unsigned FieldAlign = 32; |
59 | if (Ty->isAggregateType()) |
60 | FieldAlign = CBufferAlign; |
61 | Size = llvm::alignTo(Value: Size, Align: FieldAlign); |
62 | Size += FieldSize; |
63 | } |
64 | } else if (const ConstantArrayType *AT = Context.getAsConstantArrayType(T)) { |
65 | if (unsigned ElementCount = AT->getSize().getZExtValue()) { |
66 | unsigned ElementSize = |
67 | calculateLegacyCbufferSize(Context, T: AT->getElementType()); |
68 | unsigned AlignedElementSize = llvm::alignTo(Value: ElementSize, Align: CBufferAlign); |
69 | Size = AlignedElementSize * (ElementCount - 1) + ElementSize; |
70 | } |
71 | } else if (const VectorType *VT = T->getAs<VectorType>()) { |
72 | unsigned ElementCount = VT->getNumElements(); |
73 | unsigned ElementSize = |
74 | calculateLegacyCbufferSize(Context, T: VT->getElementType()); |
75 | Size = ElementSize * ElementCount; |
76 | } else { |
77 | Size = Context.getTypeSize(T); |
78 | } |
79 | return Size; |
80 | } |
81 | |
82 | void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) { |
83 | auto *BufDecl = cast<HLSLBufferDecl>(Val: Dcl); |
84 | BufDecl->setRBraceLoc(RBrace); |
85 | |
86 | // Validate packoffset. |
87 | llvm::SmallVector<std::pair<VarDecl *, HLSLPackOffsetAttr *>> PackOffsetVec; |
88 | bool HasPackOffset = false; |
89 | bool HasNonPackOffset = false; |
90 | for (auto *Field : BufDecl->decls()) { |
91 | VarDecl *Var = dyn_cast<VarDecl>(Val: Field); |
92 | if (!Var) |
93 | continue; |
94 | if (Field->hasAttr<HLSLPackOffsetAttr>()) { |
95 | PackOffsetVec.emplace_back(Args&: Var, Args: Field->getAttr<HLSLPackOffsetAttr>()); |
96 | HasPackOffset = true; |
97 | } else { |
98 | HasNonPackOffset = true; |
99 | } |
100 | } |
101 | |
102 | if (HasPackOffset && HasNonPackOffset) |
103 | Diag(Loc: BufDecl->getLocation(), DiagID: diag::warn_hlsl_packoffset_mix); |
104 | |
105 | if (HasPackOffset) { |
106 | ASTContext &Context = getASTContext(); |
107 | // Make sure no overlap in packoffset. |
108 | // Sort PackOffsetVec by offset. |
109 | std::sort(first: PackOffsetVec.begin(), last: PackOffsetVec.end(), |
110 | comp: [](const std::pair<VarDecl *, HLSLPackOffsetAttr *> &LHS, |
111 | const std::pair<VarDecl *, HLSLPackOffsetAttr *> &RHS) { |
112 | return LHS.second->getOffset() < RHS.second->getOffset(); |
113 | }); |
114 | |
115 | for (unsigned i = 0; i < PackOffsetVec.size() - 1; i++) { |
116 | VarDecl *Var = PackOffsetVec[i].first; |
117 | HLSLPackOffsetAttr *Attr = PackOffsetVec[i].second; |
118 | unsigned Size = calculateLegacyCbufferSize(Context, T: Var->getType()); |
119 | unsigned Begin = Attr->getOffset() * 32; |
120 | unsigned End = Begin + Size; |
121 | unsigned NextBegin = PackOffsetVec[i + 1].second->getOffset() * 32; |
122 | if (End > NextBegin) { |
123 | VarDecl *NextVar = PackOffsetVec[i + 1].first; |
124 | Diag(Loc: NextVar->getLocation(), DiagID: diag::err_hlsl_packoffset_overlap) |
125 | << NextVar << Var; |
126 | } |
127 | } |
128 | } |
129 | |
130 | SemaRef.PopDeclContext(); |
131 | } |
132 | |
133 | HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D, |
134 | const AttributeCommonInfo &AL, |
135 | int X, int Y, int Z) { |
136 | if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) { |
137 | if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) { |
138 | Diag(Loc: NT->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL; |
139 | Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute); |
140 | } |
141 | return nullptr; |
142 | } |
143 | return ::new (getASTContext()) |
144 | HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z); |
145 | } |
146 | |
147 | HLSLShaderAttr * |
148 | SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL, |
149 | llvm::Triple::EnvironmentType ShaderType) { |
150 | if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) { |
151 | if (NT->getType() != ShaderType) { |
152 | Diag(Loc: NT->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL; |
153 | Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute); |
154 | } |
155 | return nullptr; |
156 | } |
157 | return HLSLShaderAttr::Create(Ctx&: getASTContext(), Type: ShaderType, CommonInfo: AL); |
158 | } |
159 | |
160 | HLSLParamModifierAttr * |
161 | SemaHLSL::mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL, |
162 | HLSLParamModifierAttr::Spelling Spelling) { |
163 | // We can only merge an `in` attribute with an `out` attribute. All other |
164 | // combinations of duplicated attributes are ill-formed. |
165 | if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) { |
166 | if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) || |
167 | (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) { |
168 | D->dropAttr<HLSLParamModifierAttr>(); |
169 | SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()}; |
170 | return HLSLParamModifierAttr::Create( |
171 | Ctx&: getASTContext(), /*MergedSpelling=*/true, Range: AdjustedRange, |
172 | S: HLSLParamModifierAttr::Keyword_inout); |
173 | } |
174 | Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_duplicate_parameter_modifier) << AL; |
175 | Diag(Loc: PA->getLocation(), DiagID: diag::note_conflicting_attribute); |
176 | return nullptr; |
177 | } |
178 | return HLSLParamModifierAttr::Create(Ctx&: getASTContext(), CommonInfo: AL); |
179 | } |
180 | |
181 | void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) { |
182 | auto &TargetInfo = getASTContext().getTargetInfo(); |
183 | |
184 | if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry) |
185 | return; |
186 | |
187 | llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment(); |
188 | if (HLSLShaderAttr::isValidShaderType(ShaderType: Env) && Env != llvm::Triple::Library) { |
189 | if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) { |
190 | // The entry point is already annotated - check that it matches the |
191 | // triple. |
192 | if (Shader->getType() != Env) { |
193 | Diag(Loc: Shader->getLocation(), DiagID: diag::err_hlsl_entry_shader_attr_mismatch) |
194 | << Shader; |
195 | FD->setInvalidDecl(); |
196 | } |
197 | } else { |
198 | // Implicitly add the shader attribute if the entry function isn't |
199 | // explicitly annotated. |
200 | FD->addAttr(A: HLSLShaderAttr::CreateImplicit(Ctx&: getASTContext(), Type: Env, |
201 | Range: FD->getBeginLoc())); |
202 | } |
203 | } else { |
204 | switch (Env) { |
205 | case llvm::Triple::UnknownEnvironment: |
206 | case llvm::Triple::Library: |
207 | break; |
208 | default: |
209 | llvm_unreachable("Unhandled environment in triple" ); |
210 | } |
211 | } |
212 | } |
213 | |
214 | void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) { |
215 | const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>(); |
216 | assert(ShaderAttr && "Entry point has no shader attribute" ); |
217 | llvm::Triple::EnvironmentType ST = ShaderAttr->getType(); |
218 | |
219 | switch (ST) { |
220 | case llvm::Triple::Pixel: |
221 | case llvm::Triple::Vertex: |
222 | case llvm::Triple::Geometry: |
223 | case llvm::Triple::Hull: |
224 | case llvm::Triple::Domain: |
225 | case llvm::Triple::RayGeneration: |
226 | case llvm::Triple::Intersection: |
227 | case llvm::Triple::AnyHit: |
228 | case llvm::Triple::ClosestHit: |
229 | case llvm::Triple::Miss: |
230 | case llvm::Triple::Callable: |
231 | if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) { |
232 | DiagnoseAttrStageMismatch(A: NT, Stage: ST, |
233 | AllowedStages: {llvm::Triple::Compute, |
234 | llvm::Triple::Amplification, |
235 | llvm::Triple::Mesh}); |
236 | FD->setInvalidDecl(); |
237 | } |
238 | break; |
239 | |
240 | case llvm::Triple::Compute: |
241 | case llvm::Triple::Amplification: |
242 | case llvm::Triple::Mesh: |
243 | if (!FD->hasAttr<HLSLNumThreadsAttr>()) { |
244 | Diag(Loc: FD->getLocation(), DiagID: diag::err_hlsl_missing_numthreads) |
245 | << llvm::Triple::getEnvironmentTypeName(Kind: ST); |
246 | FD->setInvalidDecl(); |
247 | } |
248 | break; |
249 | default: |
250 | llvm_unreachable("Unhandled environment in triple" ); |
251 | } |
252 | |
253 | for (ParmVarDecl *Param : FD->parameters()) { |
254 | if (const auto *AnnotationAttr = Param->getAttr<HLSLAnnotationAttr>()) { |
255 | CheckSemanticAnnotation(EntryPoint: FD, Param, AnnotationAttr); |
256 | } else { |
257 | // FIXME: Handle struct parameters where annotations are on struct fields. |
258 | // See: https://github.com/llvm/llvm-project/issues/57875 |
259 | Diag(Loc: FD->getLocation(), DiagID: diag::err_hlsl_missing_semantic_annotation); |
260 | Diag(Loc: Param->getLocation(), DiagID: diag::note_previous_decl) << Param; |
261 | FD->setInvalidDecl(); |
262 | } |
263 | } |
264 | // FIXME: Verify return type semantic annotation. |
265 | } |
266 | |
267 | void SemaHLSL::CheckSemanticAnnotation( |
268 | FunctionDecl *EntryPoint, const Decl *Param, |
269 | const HLSLAnnotationAttr *AnnotationAttr) { |
270 | auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>(); |
271 | assert(ShaderAttr && "Entry point has no shader attribute" ); |
272 | llvm::Triple::EnvironmentType ST = ShaderAttr->getType(); |
273 | |
274 | switch (AnnotationAttr->getKind()) { |
275 | case attr::HLSLSV_DispatchThreadID: |
276 | case attr::HLSLSV_GroupIndex: |
277 | if (ST == llvm::Triple::Compute) |
278 | return; |
279 | DiagnoseAttrStageMismatch(A: AnnotationAttr, Stage: ST, AllowedStages: {llvm::Triple::Compute}); |
280 | break; |
281 | default: |
282 | llvm_unreachable("Unknown HLSLAnnotationAttr" ); |
283 | } |
284 | } |
285 | |
286 | void SemaHLSL::DiagnoseAttrStageMismatch( |
287 | const Attr *A, llvm::Triple::EnvironmentType Stage, |
288 | std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) { |
289 | SmallVector<StringRef, 8> StageStrings; |
290 | llvm::transform(Range&: AllowedStages, d_first: std::back_inserter(x&: StageStrings), |
291 | F: [](llvm::Triple::EnvironmentType ST) { |
292 | return StringRef( |
293 | HLSLShaderAttr::ConvertEnvironmentTypeToStr(Val: ST)); |
294 | }); |
295 | Diag(Loc: A->getLoc(), DiagID: diag::err_hlsl_attr_unsupported_in_stage) |
296 | << A << llvm::Triple::getEnvironmentTypeName(Kind: Stage) |
297 | << (AllowedStages.size() != 1) << join(R&: StageStrings, Separator: ", " ); |
298 | } |
299 | |
300 | void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) { |
301 | llvm::VersionTuple SMVersion = |
302 | getASTContext().getTargetInfo().getTriple().getOSVersion(); |
303 | uint32_t ZMax = 1024; |
304 | uint32_t ThreadMax = 1024; |
305 | if (SMVersion.getMajor() <= 4) { |
306 | ZMax = 1; |
307 | ThreadMax = 768; |
308 | } else if (SMVersion.getMajor() == 5) { |
309 | ZMax = 64; |
310 | ThreadMax = 1024; |
311 | } |
312 | |
313 | uint32_t X; |
314 | if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: X)) |
315 | return; |
316 | if (X > 1024) { |
317 | Diag(Loc: AL.getArgAsExpr(Arg: 0)->getExprLoc(), |
318 | DiagID: diag::err_hlsl_numthreads_argument_oor) |
319 | << 0 << 1024; |
320 | return; |
321 | } |
322 | uint32_t Y; |
323 | if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Y)) |
324 | return; |
325 | if (Y > 1024) { |
326 | Diag(Loc: AL.getArgAsExpr(Arg: 1)->getExprLoc(), |
327 | DiagID: diag::err_hlsl_numthreads_argument_oor) |
328 | << 1 << 1024; |
329 | return; |
330 | } |
331 | uint32_t Z; |
332 | if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 2), Val&: Z)) |
333 | return; |
334 | if (Z > ZMax) { |
335 | SemaRef.Diag(Loc: AL.getArgAsExpr(Arg: 2)->getExprLoc(), |
336 | DiagID: diag::err_hlsl_numthreads_argument_oor) |
337 | << 2 << ZMax; |
338 | return; |
339 | } |
340 | |
341 | if (X * Y * Z > ThreadMax) { |
342 | Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_numthreads_invalid) << ThreadMax; |
343 | return; |
344 | } |
345 | |
346 | HLSLNumThreadsAttr *NewAttr = mergeNumThreadsAttr(D, AL, X, Y, Z); |
347 | if (NewAttr) |
348 | D->addAttr(A: NewAttr); |
349 | } |
350 | |
351 | static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) { |
352 | if (!T->hasUnsignedIntegerRepresentation()) |
353 | return false; |
354 | if (const auto *VT = T->getAs<VectorType>()) |
355 | return VT->getNumElements() <= 3; |
356 | return true; |
357 | } |
358 | |
359 | void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) { |
360 | auto *VD = cast<ValueDecl>(Val: D); |
361 | if (!isLegalTypeForHLSLSV_DispatchThreadID(T: VD->getType())) { |
362 | Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type) |
363 | << AL << "uint/uint2/uint3" ; |
364 | return; |
365 | } |
366 | |
367 | D->addAttr(A: ::new (getASTContext()) |
368 | HLSLSV_DispatchThreadIDAttr(getASTContext(), AL)); |
369 | } |
370 | |
371 | void SemaHLSL::handlePackOffsetAttr(Decl *D, const ParsedAttr &AL) { |
372 | if (!isa<VarDecl>(Val: D) || !isa<HLSLBufferDecl>(Val: D->getDeclContext())) { |
373 | Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_ast_node) |
374 | << AL << "shader constant in a constant buffer" ; |
375 | return; |
376 | } |
377 | |
378 | uint32_t SubComponent; |
379 | if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: SubComponent)) |
380 | return; |
381 | uint32_t Component; |
382 | if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Component)) |
383 | return; |
384 | |
385 | QualType T = cast<VarDecl>(Val: D)->getType().getCanonicalType(); |
386 | // Check if T is an array or struct type. |
387 | // TODO: mark matrix type as aggregate type. |
388 | bool IsAggregateTy = (T->isArrayType() || T->isStructureType()); |
389 | |
390 | // Check Component is valid for T. |
391 | if (Component) { |
392 | unsigned Size = getASTContext().getTypeSize(T); |
393 | if (IsAggregateTy || Size > 128) { |
394 | Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_packoffset_cross_reg_boundary); |
395 | return; |
396 | } else { |
397 | // Make sure Component + sizeof(T) <= 4. |
398 | if ((Component * 32 + Size) > 128) { |
399 | Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_packoffset_cross_reg_boundary); |
400 | return; |
401 | } |
402 | QualType EltTy = T; |
403 | if (const auto *VT = T->getAs<VectorType>()) |
404 | EltTy = VT->getElementType(); |
405 | unsigned Align = getASTContext().getTypeAlign(T: EltTy); |
406 | if (Align > 32 && Component == 1) { |
407 | // NOTE: Component 3 will hit err_hlsl_packoffset_cross_reg_boundary. |
408 | // So we only need to check Component 1 here. |
409 | Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_packoffset_alignment_mismatch) |
410 | << Align << EltTy; |
411 | return; |
412 | } |
413 | } |
414 | } |
415 | |
416 | D->addAttr(A: ::new (getASTContext()) HLSLPackOffsetAttr( |
417 | getASTContext(), AL, SubComponent, Component)); |
418 | } |
419 | |
420 | void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) { |
421 | StringRef Str; |
422 | SourceLocation ArgLoc; |
423 | if (!SemaRef.checkStringLiteralArgumentAttr(Attr: AL, ArgNum: 0, Str, ArgLocation: &ArgLoc)) |
424 | return; |
425 | |
426 | llvm::Triple::EnvironmentType ShaderType; |
427 | if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Val: Str, Out&: ShaderType)) { |
428 | Diag(Loc: AL.getLoc(), DiagID: diag::warn_attribute_type_not_supported) |
429 | << AL << Str << ArgLoc; |
430 | return; |
431 | } |
432 | |
433 | // FIXME: check function match the shader stage. |
434 | |
435 | HLSLShaderAttr *NewAttr = mergeShaderAttr(D, AL, ShaderType); |
436 | if (NewAttr) |
437 | D->addAttr(A: NewAttr); |
438 | } |
439 | |
440 | void SemaHLSL::handleResourceClassAttr(Decl *D, const ParsedAttr &AL) { |
441 | if (!AL.isArgIdent(Arg: 0)) { |
442 | Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_type) |
443 | << AL << AANT_ArgumentIdentifier; |
444 | return; |
445 | } |
446 | |
447 | IdentifierLoc *Loc = AL.getArgAsIdent(Arg: 0); |
448 | StringRef Identifier = Loc->Ident->getName(); |
449 | SourceLocation ArgLoc = Loc->Loc; |
450 | |
451 | // Validate. |
452 | llvm::dxil::ResourceClass RC; |
453 | if (!HLSLResourceClassAttr::ConvertStrToResourceClass(Val: Identifier, Out&: RC)) { |
454 | Diag(Loc: ArgLoc, DiagID: diag::warn_attribute_type_not_supported) |
455 | << "ResourceClass" << Identifier; |
456 | return; |
457 | } |
458 | |
459 | D->addAttr(A: HLSLResourceClassAttr::Create(Ctx&: getASTContext(), ResourceClass: RC, Range: ArgLoc)); |
460 | } |
461 | |
462 | void SemaHLSL::handleResourceBindingAttr(Decl *D, const ParsedAttr &AL) { |
463 | StringRef Space = "space0" ; |
464 | StringRef Slot = "" ; |
465 | |
466 | if (!AL.isArgIdent(Arg: 0)) { |
467 | Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_type) |
468 | << AL << AANT_ArgumentIdentifier; |
469 | return; |
470 | } |
471 | |
472 | IdentifierLoc *Loc = AL.getArgAsIdent(Arg: 0); |
473 | StringRef Str = Loc->Ident->getName(); |
474 | SourceLocation ArgLoc = Loc->Loc; |
475 | |
476 | SourceLocation SpaceArgLoc; |
477 | if (AL.getNumArgs() == 2) { |
478 | Slot = Str; |
479 | if (!AL.isArgIdent(Arg: 1)) { |
480 | Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_type) |
481 | << AL << AANT_ArgumentIdentifier; |
482 | return; |
483 | } |
484 | |
485 | IdentifierLoc *Loc = AL.getArgAsIdent(Arg: 1); |
486 | Space = Loc->Ident->getName(); |
487 | SpaceArgLoc = Loc->Loc; |
488 | } else { |
489 | Slot = Str; |
490 | } |
491 | |
492 | // Validate. |
493 | if (!Slot.empty()) { |
494 | switch (Slot[0]) { |
495 | case 'u': |
496 | case 'b': |
497 | case 's': |
498 | case 't': |
499 | break; |
500 | default: |
501 | Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_unsupported_register_type) |
502 | << Slot.substr(Start: 0, N: 1); |
503 | return; |
504 | } |
505 | |
506 | StringRef SlotNum = Slot.substr(Start: 1); |
507 | unsigned Num = 0; |
508 | if (SlotNum.getAsInteger(Radix: 10, Result&: Num)) { |
509 | Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_unsupported_register_number); |
510 | return; |
511 | } |
512 | } |
513 | |
514 | if (!Space.starts_with(Prefix: "space" )) { |
515 | Diag(Loc: SpaceArgLoc, DiagID: diag::err_hlsl_expected_space) << Space; |
516 | return; |
517 | } |
518 | StringRef SpaceNum = Space.substr(Start: 5); |
519 | unsigned Num = 0; |
520 | if (SpaceNum.getAsInteger(Radix: 10, Result&: Num)) { |
521 | Diag(Loc: SpaceArgLoc, DiagID: diag::err_hlsl_expected_space) << Space; |
522 | return; |
523 | } |
524 | |
525 | // FIXME: check reg type match decl. Issue |
526 | // https://github.com/llvm/llvm-project/issues/57886. |
527 | HLSLResourceBindingAttr *NewAttr = |
528 | HLSLResourceBindingAttr::Create(Ctx&: getASTContext(), Slot, Space, CommonInfo: AL); |
529 | if (NewAttr) |
530 | D->addAttr(A: NewAttr); |
531 | } |
532 | |
533 | void SemaHLSL::handleParamModifierAttr(Decl *D, const ParsedAttr &AL) { |
534 | HLSLParamModifierAttr *NewAttr = mergeParamModifierAttr( |
535 | D, AL, |
536 | Spelling: static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling())); |
537 | if (NewAttr) |
538 | D->addAttr(A: NewAttr); |
539 | } |
540 | |
541 | namespace { |
542 | |
543 | /// This class implements HLSL availability diagnostics for default |
544 | /// and relaxed mode |
545 | /// |
546 | /// The goal of this diagnostic is to emit an error or warning when an |
547 | /// unavailable API is found in code that is reachable from the shader |
548 | /// entry function or from an exported function (when compiling a shader |
549 | /// library). |
550 | /// |
551 | /// This is done by traversing the AST of all shader entry point functions |
552 | /// and of all exported functions, and any functions that are referenced |
553 | /// from this AST. In other words, any functions that are reachable from |
554 | /// the entry points. |
555 | class DiagnoseHLSLAvailability |
556 | : public RecursiveASTVisitor<DiagnoseHLSLAvailability> { |
557 | |
558 | Sema &SemaRef; |
559 | |
560 | // Stack of functions to be scaned |
561 | llvm::SmallVector<const FunctionDecl *, 8> DeclsToScan; |
562 | |
563 | // Tracks which environments functions have been scanned in. |
564 | // |
565 | // Maps FunctionDecl to an unsigned number that represents the set of shader |
566 | // environments the function has been scanned for. |
567 | // The llvm::Triple::EnvironmentType enum values for shader stages guaranteed |
568 | // to be numbered from llvm::Triple::Pixel to llvm::Triple::Amplification |
569 | // (verified by static_asserts in Triple.cpp), we can use it to index |
570 | // individual bits in the set, as long as we shift the values to start with 0 |
571 | // by subtracting the value of llvm::Triple::Pixel first. |
572 | // |
573 | // The N'th bit in the set will be set if the function has been scanned |
574 | // in shader environment whose llvm::Triple::EnvironmentType integer value |
575 | // equals (llvm::Triple::Pixel + N). |
576 | // |
577 | // For example, if a function has been scanned in compute and pixel stage |
578 | // environment, the value will be 0x21 (100001 binary) because: |
579 | // |
580 | // (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0 |
581 | // (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5 |
582 | // |
583 | // A FunctionDecl is mapped to 0 (or not included in the map) if it has not |
584 | // been scanned in any environment. |
585 | llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls; |
586 | |
587 | // Do not access these directly, use the get/set methods below to make |
588 | // sure the values are in sync |
589 | llvm::Triple::EnvironmentType CurrentShaderEnvironment; |
590 | unsigned CurrentShaderStageBit; |
591 | |
592 | // True if scanning a function that was already scanned in a different |
593 | // shader stage context, and therefore we should not report issues that |
594 | // depend only on shader model version because they would be duplicate. |
595 | bool ReportOnlyShaderStageIssues; |
596 | |
597 | // Helper methods for dealing with current stage context / environment |
598 | void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) { |
599 | static_assert(sizeof(unsigned) >= 4); |
600 | assert(HLSLShaderAttr::isValidShaderType(ShaderType)); |
601 | assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 && |
602 | "ShaderType is too big for this bitmap" ); // 31 is reserved for |
603 | // "unknown" |
604 | |
605 | unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel; |
606 | CurrentShaderEnvironment = ShaderType; |
607 | CurrentShaderStageBit = (1 << bitmapIndex); |
608 | } |
609 | |
610 | void SetUnknownShaderStageContext() { |
611 | CurrentShaderEnvironment = llvm::Triple::UnknownEnvironment; |
612 | CurrentShaderStageBit = (1 << 31); |
613 | } |
614 | |
615 | llvm::Triple::EnvironmentType GetCurrentShaderEnvironment() const { |
616 | return CurrentShaderEnvironment; |
617 | } |
618 | |
619 | bool InUnknownShaderStageContext() const { |
620 | return CurrentShaderEnvironment == llvm::Triple::UnknownEnvironment; |
621 | } |
622 | |
623 | // Helper methods for dealing with shader stage bitmap |
624 | void AddToScannedFunctions(const FunctionDecl *FD) { |
625 | unsigned &ScannedStages = ScannedDecls.getOrInsertDefault(Key: FD); |
626 | ScannedStages |= CurrentShaderStageBit; |
627 | } |
628 | |
629 | unsigned GetScannedStages(const FunctionDecl *FD) { |
630 | return ScannedDecls.getOrInsertDefault(Key: FD); |
631 | } |
632 | |
633 | bool WasAlreadyScannedInCurrentStage(const FunctionDecl *FD) { |
634 | return WasAlreadyScannedInCurrentStage(ScannerStages: GetScannedStages(FD)); |
635 | } |
636 | |
637 | bool WasAlreadyScannedInCurrentStage(unsigned ScannerStages) { |
638 | return ScannerStages & CurrentShaderStageBit; |
639 | } |
640 | |
641 | static bool NeverBeenScanned(unsigned ScannedStages) { |
642 | return ScannedStages == 0; |
643 | } |
644 | |
645 | // Scanning methods |
646 | void HandleFunctionOrMethodRef(FunctionDecl *FD, Expr *RefExpr); |
647 | void CheckDeclAvailability(NamedDecl *D, const AvailabilityAttr *AA, |
648 | SourceRange Range); |
649 | const AvailabilityAttr *FindAvailabilityAttr(const Decl *D); |
650 | bool HasMatchingEnvironmentOrNone(const AvailabilityAttr *AA); |
651 | |
652 | public: |
653 | DiagnoseHLSLAvailability(Sema &SemaRef) : SemaRef(SemaRef) {} |
654 | |
655 | // AST traversal methods |
656 | void RunOnTranslationUnit(const TranslationUnitDecl *TU); |
657 | void RunOnFunction(const FunctionDecl *FD); |
658 | |
659 | bool VisitDeclRefExpr(DeclRefExpr *DRE) { |
660 | FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(Val: DRE->getDecl()); |
661 | if (FD) |
662 | HandleFunctionOrMethodRef(FD, RefExpr: DRE); |
663 | return true; |
664 | } |
665 | |
666 | bool VisitMemberExpr(MemberExpr *ME) { |
667 | FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(Val: ME->getMemberDecl()); |
668 | if (FD) |
669 | HandleFunctionOrMethodRef(FD, RefExpr: ME); |
670 | return true; |
671 | } |
672 | }; |
673 | |
674 | void DiagnoseHLSLAvailability::HandleFunctionOrMethodRef(FunctionDecl *FD, |
675 | Expr *RefExpr) { |
676 | assert((isa<DeclRefExpr>(RefExpr) || isa<MemberExpr>(RefExpr)) && |
677 | "expected DeclRefExpr or MemberExpr" ); |
678 | |
679 | // has a definition -> add to stack to be scanned |
680 | const FunctionDecl *FDWithBody = nullptr; |
681 | if (FD->hasBody(Definition&: FDWithBody)) { |
682 | if (!WasAlreadyScannedInCurrentStage(FD: FDWithBody)) |
683 | DeclsToScan.push_back(Elt: FDWithBody); |
684 | return; |
685 | } |
686 | |
687 | // no body -> diagnose availability |
688 | const AvailabilityAttr *AA = FindAvailabilityAttr(D: FD); |
689 | if (AA) |
690 | CheckDeclAvailability( |
691 | D: FD, AA, Range: SourceRange(RefExpr->getBeginLoc(), RefExpr->getEndLoc())); |
692 | } |
693 | |
694 | void DiagnoseHLSLAvailability::RunOnTranslationUnit( |
695 | const TranslationUnitDecl *TU) { |
696 | |
697 | // Iterate over all shader entry functions and library exports, and for those |
698 | // that have a body (definiton), run diag scan on each, setting appropriate |
699 | // shader environment context based on whether it is a shader entry function |
700 | // or an exported function. Exported functions can be in namespaces and in |
701 | // export declarations so we need to scan those declaration contexts as well. |
702 | llvm::SmallVector<const DeclContext *, 8> DeclContextsToScan; |
703 | DeclContextsToScan.push_back(Elt: TU); |
704 | |
705 | while (!DeclContextsToScan.empty()) { |
706 | const DeclContext *DC = DeclContextsToScan.pop_back_val(); |
707 | for (auto &D : DC->decls()) { |
708 | // do not scan implicit declaration generated by the implementation |
709 | if (D->isImplicit()) |
710 | continue; |
711 | |
712 | // for namespace or export declaration add the context to the list to be |
713 | // scanned later |
714 | if (llvm::dyn_cast<NamespaceDecl>(Val: D) || llvm::dyn_cast<ExportDecl>(Val: D)) { |
715 | DeclContextsToScan.push_back(Elt: llvm::dyn_cast<DeclContext>(Val: D)); |
716 | continue; |
717 | } |
718 | |
719 | // skip over other decls or function decls without body |
720 | const FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(Val: D); |
721 | if (!FD || !FD->isThisDeclarationADefinition()) |
722 | continue; |
723 | |
724 | // shader entry point |
725 | if (HLSLShaderAttr *ShaderAttr = FD->getAttr<HLSLShaderAttr>()) { |
726 | SetShaderStageContext(ShaderAttr->getType()); |
727 | RunOnFunction(FD); |
728 | continue; |
729 | } |
730 | // exported library function |
731 | // FIXME: replace this loop with external linkage check once issue #92071 |
732 | // is resolved |
733 | bool isExport = FD->isInExportDeclContext(); |
734 | if (!isExport) { |
735 | for (const auto *Redecl : FD->redecls()) { |
736 | if (Redecl->isInExportDeclContext()) { |
737 | isExport = true; |
738 | break; |
739 | } |
740 | } |
741 | } |
742 | if (isExport) { |
743 | SetUnknownShaderStageContext(); |
744 | RunOnFunction(FD); |
745 | continue; |
746 | } |
747 | } |
748 | } |
749 | } |
750 | |
751 | void DiagnoseHLSLAvailability::RunOnFunction(const FunctionDecl *FD) { |
752 | assert(DeclsToScan.empty() && "DeclsToScan should be empty" ); |
753 | DeclsToScan.push_back(Elt: FD); |
754 | |
755 | while (!DeclsToScan.empty()) { |
756 | // Take one decl from the stack and check it by traversing its AST. |
757 | // For any CallExpr found during the traversal add it's callee to the top of |
758 | // the stack to be processed next. Functions already processed are stored in |
759 | // ScannedDecls. |
760 | const FunctionDecl *FD = DeclsToScan.pop_back_val(); |
761 | |
762 | // Decl was already scanned |
763 | const unsigned ScannedStages = GetScannedStages(FD); |
764 | if (WasAlreadyScannedInCurrentStage(ScannerStages: ScannedStages)) |
765 | continue; |
766 | |
767 | ReportOnlyShaderStageIssues = !NeverBeenScanned(ScannedStages); |
768 | |
769 | AddToScannedFunctions(FD); |
770 | TraverseStmt(S: FD->getBody()); |
771 | } |
772 | } |
773 | |
774 | bool DiagnoseHLSLAvailability::HasMatchingEnvironmentOrNone( |
775 | const AvailabilityAttr *AA) { |
776 | IdentifierInfo *IIEnvironment = AA->getEnvironment(); |
777 | if (!IIEnvironment) |
778 | return true; |
779 | |
780 | llvm::Triple::EnvironmentType CurrentEnv = GetCurrentShaderEnvironment(); |
781 | if (CurrentEnv == llvm::Triple::UnknownEnvironment) |
782 | return false; |
783 | |
784 | llvm::Triple::EnvironmentType AttrEnv = |
785 | AvailabilityAttr::getEnvironmentType(Environment: IIEnvironment->getName()); |
786 | |
787 | return CurrentEnv == AttrEnv; |
788 | } |
789 | |
790 | const AvailabilityAttr * |
791 | DiagnoseHLSLAvailability::FindAvailabilityAttr(const Decl *D) { |
792 | AvailabilityAttr const *PartialMatch = nullptr; |
793 | // Check each AvailabilityAttr to find the one for this platform. |
794 | // For multiple attributes with the same platform try to find one for this |
795 | // environment. |
796 | for (const auto *A : D->attrs()) { |
797 | if (const auto *Avail = dyn_cast<AvailabilityAttr>(Val: A)) { |
798 | StringRef AttrPlatform = Avail->getPlatform()->getName(); |
799 | StringRef TargetPlatform = |
800 | SemaRef.getASTContext().getTargetInfo().getPlatformName(); |
801 | |
802 | // Match the platform name. |
803 | if (AttrPlatform == TargetPlatform) { |
804 | // Find the best matching attribute for this environment |
805 | if (HasMatchingEnvironmentOrNone(AA: Avail)) |
806 | return Avail; |
807 | PartialMatch = Avail; |
808 | } |
809 | } |
810 | } |
811 | return PartialMatch; |
812 | } |
813 | |
814 | // Check availability against target shader model version and current shader |
815 | // stage and emit diagnostic |
816 | void DiagnoseHLSLAvailability::CheckDeclAvailability(NamedDecl *D, |
817 | const AvailabilityAttr *AA, |
818 | SourceRange Range) { |
819 | |
820 | IdentifierInfo *IIEnv = AA->getEnvironment(); |
821 | |
822 | if (!IIEnv) { |
823 | // The availability attribute does not have environment -> it depends only |
824 | // on shader model version and not on specific the shader stage. |
825 | |
826 | // Skip emitting the diagnostics if the diagnostic mode is set to |
827 | // strict (-fhlsl-strict-availability) because all relevant diagnostics |
828 | // were already emitted in the DiagnoseUnguardedAvailability scan |
829 | // (SemaAvailability.cpp). |
830 | if (SemaRef.getLangOpts().HLSLStrictAvailability) |
831 | return; |
832 | |
833 | // Do not report shader-stage-independent issues if scanning a function |
834 | // that was already scanned in a different shader stage context (they would |
835 | // be duplicate) |
836 | if (ReportOnlyShaderStageIssues) |
837 | return; |
838 | |
839 | } else { |
840 | // The availability attribute has environment -> we need to know |
841 | // the current stage context to property diagnose it. |
842 | if (InUnknownShaderStageContext()) |
843 | return; |
844 | } |
845 | |
846 | // Check introduced version and if environment matches |
847 | bool EnvironmentMatches = HasMatchingEnvironmentOrNone(AA); |
848 | VersionTuple Introduced = AA->getIntroduced(); |
849 | VersionTuple TargetVersion = |
850 | SemaRef.Context.getTargetInfo().getPlatformMinVersion(); |
851 | |
852 | if (TargetVersion >= Introduced && EnvironmentMatches) |
853 | return; |
854 | |
855 | // Emit diagnostic message |
856 | const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo(); |
857 | llvm::StringRef PlatformName( |
858 | AvailabilityAttr::getPrettyPlatformName(Platform: TI.getPlatformName())); |
859 | |
860 | llvm::StringRef CurrentEnvStr = |
861 | llvm::Triple::getEnvironmentTypeName(Kind: GetCurrentShaderEnvironment()); |
862 | |
863 | llvm::StringRef AttrEnvStr = |
864 | AA->getEnvironment() ? AA->getEnvironment()->getName() : "" ; |
865 | bool UseEnvironment = !AttrEnvStr.empty(); |
866 | |
867 | if (EnvironmentMatches) { |
868 | SemaRef.Diag(Loc: Range.getBegin(), DiagID: diag::warn_hlsl_availability) |
869 | << Range << D << PlatformName << Introduced.getAsString() |
870 | << UseEnvironment << CurrentEnvStr; |
871 | } else { |
872 | SemaRef.Diag(Loc: Range.getBegin(), DiagID: diag::warn_hlsl_availability_unavailable) |
873 | << Range << D; |
874 | } |
875 | |
876 | SemaRef.Diag(Loc: D->getLocation(), DiagID: diag::note_partial_availability_specified_here) |
877 | << D << PlatformName << Introduced.getAsString() |
878 | << SemaRef.Context.getTargetInfo().getPlatformMinVersion().getAsString() |
879 | << UseEnvironment << AttrEnvStr << CurrentEnvStr; |
880 | } |
881 | |
882 | } // namespace |
883 | |
884 | void SemaHLSL::DiagnoseAvailabilityViolations(TranslationUnitDecl *TU) { |
885 | // Skip running the diagnostics scan if the diagnostic mode is |
886 | // strict (-fhlsl-strict-availability) and the target shader stage is known |
887 | // because all relevant diagnostics were already emitted in the |
888 | // DiagnoseUnguardedAvailability scan (SemaAvailability.cpp). |
889 | const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo(); |
890 | if (SemaRef.getLangOpts().HLSLStrictAvailability && |
891 | TI.getTriple().getEnvironment() != llvm::Triple::EnvironmentType::Library) |
892 | return; |
893 | |
894 | DiagnoseHLSLAvailability(SemaRef).RunOnTranslationUnit(TU); |
895 | } |
896 | |
897 | // Helper function for CheckHLSLBuiltinFunctionCall |
898 | bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) { |
899 | assert(TheCall->getNumArgs() > 1); |
900 | ExprResult A = TheCall->getArg(Arg: 0); |
901 | |
902 | QualType ArgTyA = A.get()->getType(); |
903 | |
904 | auto *VecTyA = ArgTyA->getAs<VectorType>(); |
905 | SourceLocation BuiltinLoc = TheCall->getBeginLoc(); |
906 | |
907 | for (unsigned i = 1; i < TheCall->getNumArgs(); ++i) { |
908 | ExprResult B = TheCall->getArg(Arg: i); |
909 | QualType ArgTyB = B.get()->getType(); |
910 | auto *VecTyB = ArgTyB->getAs<VectorType>(); |
911 | if (VecTyA == nullptr && VecTyB == nullptr) |
912 | return false; |
913 | |
914 | if (VecTyA && VecTyB) { |
915 | bool retValue = false; |
916 | if (VecTyA->getElementType() != VecTyB->getElementType()) { |
917 | // Note: type promotion is intended to be handeled via the intrinsics |
918 | // and not the builtin itself. |
919 | S->Diag(Loc: TheCall->getBeginLoc(), |
920 | DiagID: diag::err_vec_builtin_incompatible_vector) |
921 | << TheCall->getDirectCallee() << /*useAllTerminology*/ true |
922 | << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc()); |
923 | retValue = true; |
924 | } |
925 | if (VecTyA->getNumElements() != VecTyB->getNumElements()) { |
926 | // You should only be hitting this case if you are calling the builtin |
927 | // directly. HLSL intrinsics should avoid this case via a |
928 | // HLSLVectorTruncation. |
929 | S->Diag(Loc: BuiltinLoc, DiagID: diag::err_vec_builtin_incompatible_vector) |
930 | << TheCall->getDirectCallee() << /*useAllTerminology*/ true |
931 | << SourceRange(TheCall->getArg(Arg: 0)->getBeginLoc(), |
932 | TheCall->getArg(Arg: 1)->getEndLoc()); |
933 | retValue = true; |
934 | } |
935 | return retValue; |
936 | } |
937 | } |
938 | |
939 | // Note: if we get here one of the args is a scalar which |
940 | // requires a VectorSplat on Arg0 or Arg1 |
941 | S->Diag(Loc: BuiltinLoc, DiagID: diag::err_vec_builtin_non_vector) |
942 | << TheCall->getDirectCallee() << /*useAllTerminology*/ true |
943 | << SourceRange(TheCall->getArg(Arg: 0)->getBeginLoc(), |
944 | TheCall->getArg(Arg: 1)->getEndLoc()); |
945 | return true; |
946 | } |
947 | |
948 | bool CheckArgsTypesAreCorrect( |
949 | Sema *S, CallExpr *TheCall, QualType ExpectedType, |
950 | llvm::function_ref<bool(clang::QualType PassedType)> Check) { |
951 | for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) { |
952 | QualType PassedType = TheCall->getArg(Arg: i)->getType(); |
953 | if (Check(PassedType)) { |
954 | if (auto *VecTyA = PassedType->getAs<VectorType>()) |
955 | ExpectedType = S->Context.getVectorType( |
956 | VectorType: ExpectedType, NumElts: VecTyA->getNumElements(), VecKind: VecTyA->getVectorKind()); |
957 | S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(), |
958 | DiagID: diag::err_typecheck_convert_incompatible) |
959 | << PassedType << ExpectedType << 1 << 0 << 0; |
960 | return true; |
961 | } |
962 | } |
963 | return false; |
964 | } |
965 | |
966 | bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) { |
967 | auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool { |
968 | return !PassedType->hasFloatingRepresentation(); |
969 | }; |
970 | return CheckArgsTypesAreCorrect(S, TheCall, ExpectedType: S->Context.FloatTy, |
971 | Check: checkAllFloatTypes); |
972 | } |
973 | |
974 | bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) { |
975 | auto checkFloatorHalf = [](clang::QualType PassedType) -> bool { |
976 | clang::QualType BaseType = |
977 | PassedType->isVectorType() |
978 | ? PassedType->getAs<clang::VectorType>()->getElementType() |
979 | : PassedType; |
980 | return !BaseType->isHalfType() && !BaseType->isFloat32Type(); |
981 | }; |
982 | return CheckArgsTypesAreCorrect(S, TheCall, ExpectedType: S->Context.FloatTy, |
983 | Check: checkFloatorHalf); |
984 | } |
985 | |
986 | bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) { |
987 | auto checkDoubleVector = [](clang::QualType PassedType) -> bool { |
988 | if (const auto *VecTy = PassedType->getAs<VectorType>()) |
989 | return VecTy->getElementType()->isDoubleType(); |
990 | return false; |
991 | }; |
992 | return CheckArgsTypesAreCorrect(S, TheCall, ExpectedType: S->Context.FloatTy, |
993 | Check: checkDoubleVector); |
994 | } |
995 | |
996 | bool CheckUnsignedIntRepresentation(Sema *S, CallExpr *TheCall) { |
997 | auto checkAllUnsignedTypes = [](clang::QualType PassedType) -> bool { |
998 | return !PassedType->hasUnsignedIntegerRepresentation(); |
999 | }; |
1000 | return CheckArgsTypesAreCorrect(S, TheCall, ExpectedType: S->Context.UnsignedIntTy, |
1001 | Check: checkAllUnsignedTypes); |
1002 | } |
1003 | |
1004 | void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall, |
1005 | QualType ReturnType) { |
1006 | auto *VecTyA = TheCall->getArg(Arg: 0)->getType()->getAs<VectorType>(); |
1007 | if (VecTyA) |
1008 | ReturnType = S->Context.getVectorType(VectorType: ReturnType, NumElts: VecTyA->getNumElements(), |
1009 | VecKind: VectorKind::Generic); |
1010 | TheCall->setType(ReturnType); |
1011 | } |
1012 | |
1013 | // Note: returning true in this case results in CheckBuiltinFunctionCall |
1014 | // returning an ExprError |
1015 | bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { |
1016 | switch (BuiltinID) { |
1017 | case Builtin::BI__builtin_hlsl_elementwise_all: |
1018 | case Builtin::BI__builtin_hlsl_elementwise_any: { |
1019 | if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1)) |
1020 | return true; |
1021 | break; |
1022 | } |
1023 | case Builtin::BI__builtin_hlsl_elementwise_clamp: { |
1024 | if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3)) |
1025 | return true; |
1026 | if (CheckVectorElementCallArgs(S: &SemaRef, TheCall)) |
1027 | return true; |
1028 | if (SemaRef.BuiltinElementwiseTernaryMath( |
1029 | TheCall, /*CheckForFloatArgs*/ |
1030 | TheCall->getArg(Arg: 0)->getType()->hasFloatingRepresentation())) |
1031 | return true; |
1032 | break; |
1033 | } |
1034 | case Builtin::BI__builtin_hlsl_dot: { |
1035 | if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2)) |
1036 | return true; |
1037 | if (CheckVectorElementCallArgs(S: &SemaRef, TheCall)) |
1038 | return true; |
1039 | if (SemaRef.BuiltinVectorToScalarMath(TheCall)) |
1040 | return true; |
1041 | if (CheckNoDoubleVectors(S: &SemaRef, TheCall)) |
1042 | return true; |
1043 | break; |
1044 | } |
1045 | case Builtin::BI__builtin_hlsl_elementwise_rcp: { |
1046 | if (CheckAllArgsHaveFloatRepresentation(S: &SemaRef, TheCall)) |
1047 | return true; |
1048 | if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall)) |
1049 | return true; |
1050 | break; |
1051 | } |
1052 | case Builtin::BI__builtin_hlsl_elementwise_rsqrt: |
1053 | case Builtin::BI__builtin_hlsl_elementwise_frac: { |
1054 | if (CheckFloatOrHalfRepresentations(S: &SemaRef, TheCall)) |
1055 | return true; |
1056 | if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall)) |
1057 | return true; |
1058 | break; |
1059 | } |
1060 | case Builtin::BI__builtin_hlsl_elementwise_isinf: { |
1061 | if (CheckFloatOrHalfRepresentations(S: &SemaRef, TheCall)) |
1062 | return true; |
1063 | if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall)) |
1064 | return true; |
1065 | SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().BoolTy); |
1066 | break; |
1067 | } |
1068 | case Builtin::BI__builtin_hlsl_lerp: { |
1069 | if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3)) |
1070 | return true; |
1071 | if (CheckVectorElementCallArgs(S: &SemaRef, TheCall)) |
1072 | return true; |
1073 | if (SemaRef.BuiltinElementwiseTernaryMath(TheCall)) |
1074 | return true; |
1075 | if (CheckFloatOrHalfRepresentations(S: &SemaRef, TheCall)) |
1076 | return true; |
1077 | break; |
1078 | } |
1079 | case Builtin::BI__builtin_hlsl_mad: { |
1080 | if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3)) |
1081 | return true; |
1082 | if (CheckVectorElementCallArgs(S: &SemaRef, TheCall)) |
1083 | return true; |
1084 | if (SemaRef.BuiltinElementwiseTernaryMath( |
1085 | TheCall, /*CheckForFloatArgs*/ |
1086 | TheCall->getArg(Arg: 0)->getType()->hasFloatingRepresentation())) |
1087 | return true; |
1088 | break; |
1089 | } |
1090 | // Note these are llvm builtins that we want to catch invalid intrinsic |
1091 | // generation. Normal handling of these builitns will occur elsewhere. |
1092 | case Builtin::BI__builtin_elementwise_bitreverse: { |
1093 | if (CheckUnsignedIntRepresentation(S: &SemaRef, TheCall)) |
1094 | return true; |
1095 | break; |
1096 | } |
1097 | case Builtin::BI__builtin_elementwise_acos: |
1098 | case Builtin::BI__builtin_elementwise_asin: |
1099 | case Builtin::BI__builtin_elementwise_atan: |
1100 | case Builtin::BI__builtin_elementwise_ceil: |
1101 | case Builtin::BI__builtin_elementwise_cos: |
1102 | case Builtin::BI__builtin_elementwise_cosh: |
1103 | case Builtin::BI__builtin_elementwise_exp: |
1104 | case Builtin::BI__builtin_elementwise_exp2: |
1105 | case Builtin::BI__builtin_elementwise_floor: |
1106 | case Builtin::BI__builtin_elementwise_log: |
1107 | case Builtin::BI__builtin_elementwise_log2: |
1108 | case Builtin::BI__builtin_elementwise_log10: |
1109 | case Builtin::BI__builtin_elementwise_pow: |
1110 | case Builtin::BI__builtin_elementwise_roundeven: |
1111 | case Builtin::BI__builtin_elementwise_sin: |
1112 | case Builtin::BI__builtin_elementwise_sinh: |
1113 | case Builtin::BI__builtin_elementwise_sqrt: |
1114 | case Builtin::BI__builtin_elementwise_tan: |
1115 | case Builtin::BI__builtin_elementwise_tanh: |
1116 | case Builtin::BI__builtin_elementwise_trunc: { |
1117 | if (CheckFloatOrHalfRepresentations(S: &SemaRef, TheCall)) |
1118 | return true; |
1119 | break; |
1120 | } |
1121 | } |
1122 | return false; |
1123 | } |
1124 | |