1//===- SemaSPIRV.cpp - Semantic Analysis for SPIRV constructs--------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This implements Semantic Analysis for SPIRV constructs.
9//===----------------------------------------------------------------------===//
10
11#include "clang/Sema/SemaSPIRV.h"
12#include "clang/Basic/TargetBuiltins.h"
13#include "clang/Basic/TargetInfo.h"
14#include "clang/Sema/Sema.h"
15
// SPIR-V enumerants used by this file. Only the entries that are actually
// needed are listed; consult the SPIR-V specification for the complete set
// and their numeric values.
// FIXME: either use the SPIRV-Headers or generate a custom header using the
// grammar (like done with MLIR).
namespace spirv {
/// SPIR-V storage classes understood by the generic-to-pointer cast check.
enum class StorageClass : int {
  Workgroup = 4,      // Lowered to the local address space.
  CrossWorkgroup = 5, // Lowered to the global address space.
  Function = 7,       // Lowered to the private address space.
};
} // namespace spirv
27
28namespace clang {
29
30SemaSPIRV::SemaSPIRV(Sema &S) : SemaBase(S) {}
31
32static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
33 assert(TheCall->getNumArgs() > 1);
34 QualType ArgTy0 = TheCall->getArg(Arg: 0)->getType();
35
36 for (unsigned I = 1, N = TheCall->getNumArgs(); I < N; ++I) {
37 if (!S->getASTContext().hasSameUnqualifiedType(
38 T1: ArgTy0, T2: TheCall->getArg(Arg: I)->getType())) {
39 S->Diag(Loc: TheCall->getBeginLoc(), DiagID: diag::err_vec_builtin_incompatible_vector)
40 << TheCall->getDirectCallee() << /*useAllTerminology*/ true
41 << SourceRange(TheCall->getArg(Arg: 0)->getBeginLoc(),
42 TheCall->getArg(Arg: N - 1)->getEndLoc());
43 return true;
44 }
45 }
46 return false;
47}
48
49static bool CheckAllArgTypesAreCorrect(
50 Sema *S, CallExpr *TheCall,
51 llvm::ArrayRef<
52 llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
53 Checks) {
54 unsigned NumArgs = TheCall->getNumArgs();
55 assert(Checks.size() == NumArgs &&
56 "Wrong number of checks for Number of args.");
57 // Apply each check to the corresponding argument
58 for (unsigned I = 0; I < NumArgs; ++I) {
59 Expr *Arg = TheCall->getArg(Arg: I);
60 if (Checks[I](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
61 return true;
62 }
63 return false;
64}
65
66static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
67 int ArgOrdinal,
68 clang::QualType PassedType) {
69 clang::QualType BaseType =
70 PassedType->isVectorType()
71 ? PassedType->castAs<clang::VectorType>()->getElementType()
72 : PassedType;
73 if (!BaseType->isHalfType() && !BaseType->isFloat16Type() &&
74 !BaseType->isFloat32Type())
75 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
76 << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
77 << /* half or float */ 2 << PassedType;
78 return false;
79}
80
81static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
82 int ArgOrdinal,
83 clang::QualType PassedType) {
84 if (!PassedType->isHalfType() && !PassedType->isFloat16Type() &&
85 !PassedType->isFloat32Type())
86 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
87 << ArgOrdinal << /* scalar */ 1 << /* no int */ 0
88 << /* half or float */ 2 << PassedType;
89 return false;
90}
91
92static std::optional<int>
93processConstant32BitIntArgument(Sema &SemaRef, CallExpr *Call, int Argument) {
94 ExprResult Arg =
95 SemaRef.DefaultFunctionArrayLvalueConversion(E: Call->getArg(Arg: Argument));
96 if (Arg.isInvalid())
97 return true;
98 Call->setArg(Arg: Argument, ArgExpr: Arg.get());
99
100 const Expr *IntArg = Arg.get();
101 SmallVector<PartialDiagnosticAt, 8> Notes;
102 Expr::EvalResult Eval;
103 Eval.Diag = &Notes;
104 if ((!IntArg->EvaluateAsConstantExpr(Result&: Eval, Ctx: SemaRef.getASTContext())) ||
105 !Eval.Val.isInt() || Eval.Val.getInt().getBitWidth() > 32) {
106 SemaRef.Diag(Loc: IntArg->getBeginLoc(), DiagID: diag::err_spirv_enum_not_int)
107 << 0 << IntArg->getSourceRange();
108 for (const PartialDiagnosticAt &PDiag : Notes)
109 SemaRef.Diag(Loc: PDiag.first, PD: PDiag.second);
110 return true;
111 }
112 return {Eval.Val.getInt().getZExtValue()};
113}
114
115static bool checkGenericCastToPtr(Sema &SemaRef, CallExpr *Call) {
116 if (SemaRef.checkArgCount(Call, DesiredArgCount: 2))
117 return true;
118
119 {
120 ExprResult Arg =
121 SemaRef.DefaultFunctionArrayLvalueConversion(E: Call->getArg(Arg: 0));
122 if (Arg.isInvalid())
123 return true;
124 Call->setArg(Arg: 0, ArgExpr: Arg.get());
125
126 QualType Ty = Arg.get()->getType();
127 const auto *PtrTy = Ty->getAs<PointerType>();
128 auto AddressSpaceNotInGeneric = [&](LangAS AS) {
129 if (SemaRef.LangOpts.OpenCL)
130 return AS != LangAS::opencl_generic;
131 return AS != LangAS::Default;
132 };
133 if (!PtrTy ||
134 AddressSpaceNotInGeneric(PtrTy->getPointeeType().getAddressSpace())) {
135 SemaRef.Diag(Loc: Arg.get()->getBeginLoc(),
136 DiagID: diag::err_spirv_builtin_generic_cast_invalid_arg)
137 << Call->getSourceRange();
138 return true;
139 }
140 }
141
142 spirv::StorageClass StorageClass;
143 if (std::optional<int> SCInt =
144 processConstant32BitIntArgument(SemaRef, Call, Argument: 1);
145 SCInt.has_value()) {
146 StorageClass = static_cast<spirv::StorageClass>(SCInt.value());
147 if (StorageClass != spirv::StorageClass::CrossWorkgroup &&
148 StorageClass != spirv::StorageClass::Workgroup &&
149 StorageClass != spirv::StorageClass::Function) {
150 SemaRef.Diag(Loc: Call->getArg(Arg: 1)->getBeginLoc(),
151 DiagID: diag::err_spirv_enum_not_valid)
152 << 0 << Call->getArg(Arg: 1)->getSourceRange();
153 return true;
154 }
155 } else {
156 return true;
157 }
158 auto RT = Call->getArg(Arg: 0)->getType();
159 RT = RT->getPointeeType();
160 auto Qual = RT.getQualifiers();
161 LangAS AddrSpace;
162 switch (StorageClass) {
163 case spirv::StorageClass::CrossWorkgroup:
164 AddrSpace =
165 SemaRef.LangOpts.isSYCL() ? LangAS::sycl_global : LangAS::opencl_global;
166 break;
167 case spirv::StorageClass::Workgroup:
168 AddrSpace =
169 SemaRef.LangOpts.isSYCL() ? LangAS::sycl_local : LangAS::opencl_local;
170 break;
171 case spirv::StorageClass::Function:
172 AddrSpace = SemaRef.LangOpts.isSYCL() ? LangAS::sycl_private
173 : LangAS::opencl_private;
174 break;
175 }
176 Qual.setAddressSpace(AddrSpace);
177 Call->setType(SemaRef.getASTContext().getPointerType(
178 T: SemaRef.getASTContext().getQualifiedType(T: RT.getUnqualifiedType(), Qs: Qual)));
179
180 return false;
181}
182
183bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
184 unsigned BuiltinID,
185 CallExpr *TheCall) {
186 if (BuiltinID >= SPIRV::FirstVKBuiltin && BuiltinID <= SPIRV::LastVKBuiltin &&
187 TI.getTriple().getArch() != llvm::Triple::spirv) {
188 SemaRef.Diag(Loc: TheCall->getBeginLoc(), DiagID: diag::err_spirv_invalid_target) << 0;
189 return true;
190 }
191 if (BuiltinID >= SPIRV::FirstCLBuiltin && BuiltinID <= SPIRV::LastTSBuiltin &&
192 TI.getTriple().getArch() != llvm::Triple::spirv32 &&
193 TI.getTriple().getArch() != llvm::Triple::spirv64) {
194 SemaRef.Diag(Loc: TheCall->getBeginLoc(), DiagID: diag::err_spirv_invalid_target) << 1;
195 return true;
196 }
197
198 switch (BuiltinID) {
199 case SPIRV::BI__builtin_spirv_distance: {
200 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
201 return true;
202
203 ExprResult A = TheCall->getArg(Arg: 0);
204 QualType ArgTyA = A.get()->getType();
205 auto *VTyA = ArgTyA->getAs<VectorType>();
206 if (VTyA == nullptr) {
207 SemaRef.Diag(Loc: A.get()->getBeginLoc(),
208 DiagID: diag::err_typecheck_convert_incompatible)
209 << ArgTyA
210 << SemaRef.Context.getVectorType(VectorType: ArgTyA, NumElts: 2, VecKind: VectorKind::Generic) << 1
211 << 0 << 0;
212 return true;
213 }
214
215 ExprResult B = TheCall->getArg(Arg: 1);
216 QualType ArgTyB = B.get()->getType();
217 auto *VTyB = ArgTyB->getAs<VectorType>();
218 if (VTyB == nullptr) {
219 SemaRef.Diag(Loc: A.get()->getBeginLoc(),
220 DiagID: diag::err_typecheck_convert_incompatible)
221 << ArgTyB
222 << SemaRef.Context.getVectorType(VectorType: ArgTyB, NumElts: 2, VecKind: VectorKind::Generic) << 1
223 << 0 << 0;
224 return true;
225 }
226
227 QualType RetTy = VTyA->getElementType();
228 TheCall->setType(RetTy);
229 break;
230 }
231 case SPIRV::BI__builtin_spirv_length: {
232 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
233 return true;
234 ExprResult A = TheCall->getArg(Arg: 0);
235 QualType ArgTyA = A.get()->getType();
236 auto *VTy = ArgTyA->getAs<VectorType>();
237 if (VTy == nullptr) {
238 SemaRef.Diag(Loc: A.get()->getBeginLoc(),
239 DiagID: diag::err_typecheck_convert_incompatible)
240 << ArgTyA
241 << SemaRef.Context.getVectorType(VectorType: ArgTyA, NumElts: 2, VecKind: VectorKind::Generic) << 1
242 << 0 << 0;
243 return true;
244 }
245 QualType RetTy = VTy->getElementType();
246 TheCall->setType(RetTy);
247 break;
248 }
249 case SPIRV::BI__builtin_spirv_reflect: {
250 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
251 return true;
252
253 ExprResult A = TheCall->getArg(Arg: 0);
254 QualType ArgTyA = A.get()->getType();
255 auto *VTyA = ArgTyA->getAs<VectorType>();
256 if (VTyA == nullptr) {
257 SemaRef.Diag(Loc: A.get()->getBeginLoc(),
258 DiagID: diag::err_typecheck_convert_incompatible)
259 << ArgTyA
260 << SemaRef.Context.getVectorType(VectorType: ArgTyA, NumElts: 2, VecKind: VectorKind::Generic) << 1
261 << 0 << 0;
262 return true;
263 }
264
265 ExprResult B = TheCall->getArg(Arg: 1);
266 QualType ArgTyB = B.get()->getType();
267 auto *VTyB = ArgTyB->getAs<VectorType>();
268 if (VTyB == nullptr) {
269 SemaRef.Diag(Loc: A.get()->getBeginLoc(),
270 DiagID: diag::err_typecheck_convert_incompatible)
271 << ArgTyB
272 << SemaRef.Context.getVectorType(VectorType: ArgTyB, NumElts: 2, VecKind: VectorKind::Generic) << 1
273 << 0 << 0;
274 return true;
275 }
276
277 QualType RetTy = ArgTyA;
278 TheCall->setType(RetTy);
279 break;
280 }
281 case SPIRV::BI__builtin_spirv_refract: {
282 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3))
283 return true;
284
285 llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>
286 ChecksArr[] = {CheckFloatOrHalfRepresentation,
287 CheckFloatOrHalfRepresentation,
288 CheckFloatOrHalfScalarRepresentation};
289 if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
290 Checks: llvm::ArrayRef(ChecksArr)))
291 return true;
292 // Check that first two arguments are vectors/scalars of the same type
293 QualType Arg0Type = TheCall->getArg(Arg: 0)->getType();
294 if (!SemaRef.getASTContext().hasSameUnqualifiedType(
295 T1: Arg0Type, T2: TheCall->getArg(Arg: 1)->getType()))
296 return SemaRef.Diag(Loc: TheCall->getBeginLoc(),
297 DiagID: diag::err_vec_builtin_incompatible_vector)
298 << TheCall->getDirectCallee() << /* first two */ 0
299 << SourceRange(TheCall->getArg(Arg: 0)->getBeginLoc(),
300 TheCall->getArg(Arg: 1)->getEndLoc());
301
302 // Check that scalar type of 3rd arg is same as base type of first two args
303 clang::QualType BaseType =
304 Arg0Type->isVectorType()
305 ? Arg0Type->castAs<clang::VectorType>()->getElementType()
306 : Arg0Type;
307 if (!SemaRef.getASTContext().hasSameUnqualifiedType(
308 T1: BaseType, T2: TheCall->getArg(Arg: 2)->getType()))
309 return SemaRef.Diag(Loc: TheCall->getBeginLoc(),
310 DiagID: diag::err_hlsl_builtin_scalar_vector_mismatch)
311 << /* all */ 0 << TheCall->getDirectCallee() << Arg0Type
312 << TheCall->getArg(Arg: 2)->getType();
313
314 QualType RetTy = TheCall->getArg(Arg: 0)->getType();
315 TheCall->setType(RetTy);
316 break;
317 }
318 case SPIRV::BI__builtin_spirv_smoothstep: {
319 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3))
320 return true;
321
322 // Check if first argument has floating representation
323 ExprResult A = TheCall->getArg(Arg: 0);
324 QualType ArgTyA = A.get()->getType();
325 if (!ArgTyA->hasFloatingRepresentation()) {
326 SemaRef.Diag(Loc: A.get()->getBeginLoc(), DiagID: diag::err_builtin_invalid_arg_type)
327 << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
328 << /* fp */ 1 << ArgTyA;
329 return true;
330 }
331
332 if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
333 return true;
334
335 QualType RetTy = ArgTyA;
336 TheCall->setType(RetTy);
337 break;
338 }
339 case SPIRV::BI__builtin_spirv_faceforward: {
340 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3))
341 return true;
342
343 // Check if first argument has floating representation
344 ExprResult A = TheCall->getArg(Arg: 0);
345 QualType ArgTyA = A.get()->getType();
346 if (!ArgTyA->hasFloatingRepresentation()) {
347 SemaRef.Diag(Loc: A.get()->getBeginLoc(), DiagID: diag::err_builtin_invalid_arg_type)
348 << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
349 << /* fp */ 1 << ArgTyA;
350 return true;
351 }
352
353 if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
354 return true;
355
356 QualType RetTy = ArgTyA;
357 TheCall->setType(RetTy);
358 break;
359 }
360 case SPIRV::BI__builtin_spirv_generic_cast_to_ptr_explicit: {
361 return checkGenericCastToPtr(SemaRef, Call: TheCall);
362 }
363 case SPIRV::BI__builtin_spirv_ddx:
364 case SPIRV::BI__builtin_spirv_ddy:
365 case SPIRV::BI__builtin_spirv_fwidth: {
366 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
367 return true;
368
369 // Check if first argument has floating representation
370 ExprResult A = TheCall->getArg(Arg: 0);
371 QualType ArgTyA = A.get()->getType();
372 if (!ArgTyA->hasFloatingRepresentation()) {
373 SemaRef.Diag(Loc: A.get()->getBeginLoc(), DiagID: diag::err_builtin_invalid_arg_type)
374 << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
375 << /* fp */ 1 << ArgTyA;
376 return true;
377 }
378
379 QualType RetTy = ArgTyA;
380 TheCall->setType(RetTy);
381 break;
382 }
383 case SPIRV::BI__builtin_spirv_subgroup_shuffle: {
384 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
385 return true;
386
387 ExprResult A =
388 SemaRef.DefaultFunctionArrayLvalueConversion(E: TheCall->getArg(Arg: 0));
389 if (A.isInvalid())
390 return true;
391 TheCall->setArg(Arg: 0, ArgExpr: A.get());
392
393 QualType ArgTyA = A.get()->getType();
394 if (!ArgTyA->isIntegerType() && !ArgTyA->isFloatingType()) {
395 SemaRef.Diag(Loc: A.get()->getBeginLoc(), DiagID: diag::err_builtin_invalid_arg_type)
396 << /* ordinal */ 1 << /* scalar */ 1 << /* no int */ 0
397 << /* no fp */ 0 << ArgTyA;
398 return true;
399 }
400
401 ExprResult B =
402 SemaRef.DefaultFunctionArrayLvalueConversion(E: TheCall->getArg(Arg: 1));
403 if (B.isInvalid())
404 return true;
405
406 QualType Uint32Ty =
407 SemaRef.getASTContext().getIntTypeForBitwidth(DestWidth: 32,
408 /*Signed=*/false);
409 ExprResult ResB = SemaRef.PerformImplicitConversion(
410 From: B.get(), ToType: Uint32Ty, Action: AssignmentAction::Passing);
411 if (ResB.isInvalid())
412 return true;
413 TheCall->setArg(Arg: 1, ArgExpr: ResB.get());
414
415 QualType RetTy = ArgTyA;
416 TheCall->setType(RetTy);
417 break;
418 }
419 }
420 return false;
421}
422} // namespace clang
423