1//===- SemaSYCL.cpp - Semantic Analysis for SYCL constructs ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This implements Semantic Analysis for SYCL constructs.
9//===----------------------------------------------------------------------===//
10
11#include "clang/Sema/SemaSYCL.h"
12#include "TreeTransform.h"
13#include "clang/AST/Mangle.h"
14#include "clang/AST/SYCLKernelInfo.h"
15#include "clang/AST/StmtSYCL.h"
16#include "clang/AST/TypeOrdering.h"
17#include "clang/Basic/Diagnostic.h"
18#include "clang/Sema/Attr.h"
19#include "clang/Sema/ParsedAttr.h"
20#include "clang/Sema/Sema.h"
21
22using namespace clang;
23
24// -----------------------------------------------------------------------------
25// SYCL device specific diagnostics implementation
26// -----------------------------------------------------------------------------
27
28SemaSYCL::SemaSYCL(Sema &S) : SemaBase(S) {}
29
30Sema::SemaDiagnosticBuilder SemaSYCL::DiagIfDeviceCode(SourceLocation Loc,
31 unsigned DiagID) {
32 assert(getLangOpts().SYCLIsDevice &&
33 "Should only be called during SYCL compilation");
34 FunctionDecl *FD = dyn_cast<FunctionDecl>(Val: SemaRef.getCurLexicalContext());
35 SemaDiagnosticBuilder::Kind DiagKind = [this, FD] {
36 if (!FD)
37 return SemaDiagnosticBuilder::K_Nop;
38 if (SemaRef.getEmissionStatus(Decl: FD) == Sema::FunctionEmissionStatus::Emitted)
39 return SemaDiagnosticBuilder::K_ImmediateWithCallStack;
40 return SemaDiagnosticBuilder::K_Deferred;
41 }();
42 return SemaDiagnosticBuilder(DiagKind, Loc, DiagID, FD, SemaRef);
43}
44
45static bool isZeroSizedArray(SemaSYCL &S, QualType Ty) {
46 if (const auto *CAT = S.getASTContext().getAsConstantArrayType(T: Ty))
47 return CAT->isZeroSize();
48 return false;
49}
50
// Walk the type of DeclToCheck looking for types that are not supported in
// SYCL device code. The walk peels pointer, array and reference types, and
// for any record type reached it descends (depth-first) into the record's
// fields. Currently the only unsupported type detected here is a zero-sized
// constant array (see isZeroSizedArray). All diagnostics go through
// DiagIfDeviceCode, so they follow its device-code emission rules.
//
// UsedAt   - location at which each error is reported.
// Visited  - types that are skipped if encountered. It is taken by value, so
//            types recorded during this walk are not propagated back to the
//            caller.
// DeclToCheck - the declaration whose type is checked.
void SemaSYCL::deepTypeCheckForDevice(SourceLocation UsedAt,
                                      llvm::DenseSet<QualType> Visited,
                                      ValueDecl *DeclToCheck) {
  assert(getLangOpts().SYCLIsDevice &&
         "Should only be called during SYCL compilation");
  // Emit notes only for the first discovered declaration of unsupported type
  // to avoid mess of notes. This flag is to track that error already happened.
  bool NeedToEmitNotes = true;

  // Returns true if TypeToCheck is unsupported. The error itself is emitted
  // every time; the note pointing at D (a field-specific note when D is a
  // FieldDecl) is emitted only while NeedToEmitNotes is still true.
  auto Check = [&](QualType TypeToCheck, const ValueDecl *D) {
    bool ErrorFound = false;
    if (isZeroSizedArray(S&: *this, Ty: TypeToCheck)) {
      DiagIfDeviceCode(Loc: UsedAt, DiagID: diag::err_typecheck_zero_array_size) << 1;
      ErrorFound = true;
    }
    // Checks for other types can also be done here.
    if (ErrorFound) {
      if (NeedToEmitNotes) {
        if (auto *FD = dyn_cast<FieldDecl>(Val: D))
          DiagIfDeviceCode(Loc: FD->getLocation(),
                           DiagID: diag::note_illegal_field_declared_here)
              << FD->getType()->isPointerType() << FD->getType();
        else
          DiagIfDeviceCode(Loc: D->getLocation(), DiagID: diag::note_declared_at);
      }
    }

    return ErrorFound;
  };

  // In case we have a Record used do the DFS for a bad field.
  // A nullptr entry on this stack is a level marker: it is pushed before the
  // fields of a record, so popping it means all of that record's fields have
  // been processed.
  SmallVector<const ValueDecl *, 4> StackForRecursion;
  StackForRecursion.push_back(Elt: DeclToCheck);

  // While doing DFS save how we get there to emit a nice set of notes.
  // History holds the chain of FieldDecls leading to the record currently
  // being walked. The leading nullptr is a sentinel: EmitHistory starts at
  // index 1 so it is never dereferenced, and it absorbs the History pop for
  // the top-level declaration's level marker when DeclToCheck is not itself a
  // FieldDecl (in that case no History entry is pushed for it below).
  SmallVector<const FieldDecl *, 4> History;
  History.push_back(Elt: nullptr);

  do {
    const ValueDecl *Next = StackForRecursion.pop_back_val();
    if (!Next) {
      assert(!History.empty());
      // Found a marker, we have gone up a level.
      History.pop_back();
      continue;
    }
    QualType NextTy = Next->getType();

    // Each distinct type is examined at most once.
    if (!Visited.insert(V: NextTy).second)
      continue;

    // Emit one "within field of type" note per field on the current path.
    auto EmitHistory = [&]() {
      // The first element is always nullptr.
      for (uint64_t Index = 1; Index < History.size(); ++Index) {
        DiagIfDeviceCode(Loc: History[Index]->getLocation(),
                         DiagID: diag::note_within_field_of_type)
            << History[Index]->getType();
      }
    };

    if (Check(NextTy, Next)) {
      if (NeedToEmitNotes)
        EmitHistory();
      NeedToEmitNotes = false;
    }

    // In case pointer/array/reference type is met get pointee type, then
    // proceed with that type.
    // Every peeled level is checked as well, so e.g. a pointer to a
    // zero-sized array is also diagnosed.
    while (NextTy->isAnyPointerType() || NextTy->isArrayType() ||
           NextTy->isReferenceType()) {
      if (NextTy->isArrayType())
        NextTy = QualType{NextTy->getArrayElementTypeNoTypeQual(), 0};
      else
        NextTy = NextTy->getPointeeType();
      if (Check(NextTy, Next)) {
        if (NeedToEmitNotes)
          EmitHistory();
        NeedToEmitNotes = false;
      }
    }

    // Descend into the record's fields (after all pointer/array/reference
    // layers were peeled). The fields are pushed above a level marker, so
    // they are popped (last field first) before the marker.
    if (const auto *RecDecl = NextTy->getAsRecordDecl()) {
      if (auto *NextFD = dyn_cast<FieldDecl>(Val: Next))
        History.push_back(Elt: NextFD);
      // When nullptr is discovered, this means we've gone back up a level, so
      // the history should be cleaned.
      StackForRecursion.push_back(Elt: nullptr);
      llvm::append_range(C&: StackForRecursion, R: RecDecl->fields());
    }
  } while (!StackForRecursion.empty());
}
142
143ExprResult SemaSYCL::BuildUniqueStableNameExpr(SourceLocation OpLoc,
144 SourceLocation LParen,
145 SourceLocation RParen,
146 TypeSourceInfo *TSI) {
147 return SYCLUniqueStableNameExpr::Create(Ctx: getASTContext(), OpLoc, LParen,
148 RParen, TSI);
149}
150
151ExprResult SemaSYCL::ActOnUniqueStableNameExpr(SourceLocation OpLoc,
152 SourceLocation LParen,
153 SourceLocation RParen,
154 ParsedType ParsedTy) {
155 TypeSourceInfo *TSI = nullptr;
156 QualType Ty = SemaRef.GetTypeFromParser(Ty: ParsedTy, TInfo: &TSI);
157
158 if (Ty.isNull())
159 return ExprError();
160 if (!TSI)
161 TSI = getASTContext().getTrivialTypeSourceInfo(T: Ty, Loc: LParen);
162
163 return BuildUniqueStableNameExpr(OpLoc, LParen, RParen, TSI);
164}
165
166void SemaSYCL::handleKernelAttr(Decl *D, const ParsedAttr &AL) {
167 // The 'sycl_kernel' attribute applies only to function templates.
168 const auto *FD = cast<FunctionDecl>(Val: D);
169 const FunctionTemplateDecl *FT = FD->getDescribedFunctionTemplate();
170 assert(FT && "Function template is expected");
171
172 // Function template must have at least two template parameters.
173 const TemplateParameterList *TL = FT->getTemplateParameters();
174 if (TL->size() < 2) {
175 Diag(Loc: FT->getLocation(), DiagID: diag::warn_sycl_kernel_num_of_template_params);
176 return;
177 }
178
179 // Template parameters must be typenames.
180 for (unsigned I = 0; I < 2; ++I) {
181 const NamedDecl *TParam = TL->getParam(Idx: I);
182 if (isa<NonTypeTemplateParmDecl>(Val: TParam)) {
183 Diag(Loc: FT->getLocation(),
184 DiagID: diag::warn_sycl_kernel_invalid_template_param_type);
185 return;
186 }
187 }
188
189 // Function must have at least one argument.
190 if (getFunctionOrMethodNumParams(D) != 1) {
191 Diag(Loc: FT->getLocation(), DiagID: diag::warn_sycl_kernel_num_of_function_params);
192 return;
193 }
194
195 // Function must return void.
196 QualType RetTy = getFunctionOrMethodResultType(D);
197 if (!RetTy->isVoidType()) {
198 Diag(Loc: FT->getLocation(), DiagID: diag::warn_sycl_kernel_return_type);
199 return;
200 }
201
202 handleSimpleAttribute<SYCLKernelAttr>(S&: *this, D, CI: AL);
203}
204
205void SemaSYCL::handleKernelEntryPointAttr(Decl *D, const ParsedAttr &AL) {
206 ParsedType PT = AL.getTypeArg();
207 TypeSourceInfo *TSI = nullptr;
208 (void)SemaRef.GetTypeFromParser(Ty: PT, TInfo: &TSI);
209 assert(TSI && "no type source info for attribute argument");
210 D->addAttr(A: ::new (SemaRef.Context)
211 SYCLKernelEntryPointAttr(SemaRef.Context, AL, TSI));
212}
213
214// Given a potentially qualified type, SourceLocationForUserDeclaredType()
215// returns the source location of the canonical declaration of the unqualified
216// desugared user declared type, if any. For non-user declared types, an
217// invalid source location is returned. The intended usage of this function
218// is to identify an appropriate source location, if any, for a
219// "entity declared here" diagnostic note.
220static SourceLocation SourceLocationForUserDeclaredType(QualType QT) {
221 SourceLocation Loc;
222 const Type *T = QT->getUnqualifiedDesugaredType();
223 if (const TagType *TT = dyn_cast<TagType>(Val: T))
224 Loc = TT->getDecl()->getLocation();
225 else if (const auto *ObjCIT = dyn_cast<ObjCInterfaceType>(Val: T))
226 Loc = ObjCIT->getDecl()->getLocation();
227 return Loc;
228}
229
230static bool CheckSYCLKernelName(Sema &S, SourceLocation Loc,
231 QualType KernelName) {
232 assert(!KernelName->isDependentType());
233
234 if (!KernelName->isStructureOrClassType()) {
235 // SYCL 2020 section 5.2, "Naming of kernels", only requires that the
236 // kernel name be a C++ typename. However, the definition of "kernel name"
237 // in the glossary states that a kernel name is a class type. Neither
238 // section explicitly states whether the kernel name type can be
239 // cv-qualified. For now, kernel name types are required to be class types
240 // and that they may be cv-qualified. The following issue requests
241 // clarification from the SYCL WG.
242 // https://github.com/KhronosGroup/SYCL-Docs/issues/568
243 S.Diag(Loc, DiagID: diag::warn_sycl_kernel_name_not_a_class_type) << KernelName;
244 SourceLocation DeclTypeLoc = SourceLocationForUserDeclaredType(QT: KernelName);
245 if (DeclTypeLoc.isValid())
246 S.Diag(Loc: DeclTypeLoc, DiagID: diag::note_entity_declared_at) << KernelName;
247 return true;
248 }
249
250 return false;
251}
252
253void SemaSYCL::CheckSYCLExternalFunctionDecl(FunctionDecl *FD) {
254 const auto *SEAttr = FD->getAttr<SYCLExternalAttr>();
255 assert(SEAttr && "Missing sycl_external attribute");
256 if (!FD->isInvalidDecl() && !FD->isTemplated()) {
257 if (!FD->isExternallyVisible())
258 if (!FD->isFunctionTemplateSpecialization() ||
259 FD->getTemplateSpecializationInfo()->isExplicitSpecialization())
260 Diag(Loc: SEAttr->getLocation(), DiagID: diag::err_sycl_external_invalid_linkage)
261 << SEAttr;
262 }
263 if (FD->isDeletedAsWritten()) {
264 Diag(Loc: SEAttr->getLocation(),
265 DiagID: diag::err_sycl_external_invalid_deleted_function)
266 << SEAttr;
267 }
268}
269
// Validate a function declaration that carries one or more
// sycl_kernel_entry_point attributes. The first such attribute (SKEPAttr)
// governs the declaration. Every problem found is diagnosed and recorded by
// calling setInvalidAttr() on it. The checks do not return early, so all
// applicable diagnostics are issued, in the order written below. Finally, if
// the declaration is valid, non-templated, and SKEPAttr is still valid, the
// kernel name is registered with the ASTContext. If a different entity is
// already registered under the same kernel name, a conflict is diagnosed.
void SemaSYCL::CheckSYCLEntryPointFunctionDecl(FunctionDecl *FD) {
  // Ensure that all attributes present on the declaration are consistent
  // and warn about any redundant ones.
  SYCLKernelEntryPointAttr *SKEPAttr = nullptr;
  for (auto *SAI : FD->specific_attrs<SYCLKernelEntryPointAttr>()) {
    if (!SKEPAttr) {
      SKEPAttr = SAI;
      continue;
    }
    // A later attribute naming a different kernel is an error (and that later
    // attribute is invalidated); one naming the same kernel is only redundant.
    if (!getASTContext().hasSameType(T1: SAI->getKernelName(),
                                     T2: SKEPAttr->getKernelName())) {
      Diag(Loc: SAI->getLocation(), DiagID: diag::err_sycl_entry_point_invalid_redeclaration)
          << SKEPAttr << SAI->getKernelName() << SKEPAttr->getKernelName();
      Diag(Loc: SKEPAttr->getLocation(), DiagID: diag::note_previous_attribute);
      SAI->setInvalidAttr();
    } else {
      Diag(Loc: SAI->getLocation(),
           DiagID: diag::warn_sycl_entry_point_redundant_declaration)
          << SAI;
      Diag(Loc: SKEPAttr->getLocation(), DiagID: diag::note_previous_attribute);
    }
  }
  assert(SKEPAttr && "Missing sycl_kernel_entry_point attribute");

  // Ensure the kernel name type is valid.
  // Dependent kernel names cannot be checked yet.
  if (!SKEPAttr->getKernelName()->isDependentType() &&
      CheckSYCLKernelName(S&: SemaRef, Loc: SKEPAttr->getLocation(),
                          KernelName: SKEPAttr->getKernelName()))
    SKEPAttr->setInvalidAttr();

  // Ensure that an attribute present on the previous declaration
  // matches the one on this declaration.
  FunctionDecl *PrevFD = FD->getPreviousDecl();
  if (PrevFD && !PrevFD->isInvalidDecl()) {
    const auto *PrevSKEPAttr = PrevFD->getAttr<SYCLKernelEntryPointAttr>();
    if (PrevSKEPAttr && !PrevSKEPAttr->isInvalidAttr()) {
      if (!getASTContext().hasSameType(T1: SKEPAttr->getKernelName(),
                                       T2: PrevSKEPAttr->getKernelName())) {
        Diag(Loc: SKEPAttr->getLocation(),
             DiagID: diag::err_sycl_entry_point_invalid_redeclaration)
            << SKEPAttr << SKEPAttr->getKernelName()
            << PrevSKEPAttr->getKernelName();
        Diag(Loc: PrevSKEPAttr->getLocation(), DiagID: diag::note_previous_decl) << PrevFD;
        SKEPAttr->setInvalidAttr();
      }
    }
  }

  // Reject non-static member functions.
  if (const auto *MD = dyn_cast<CXXMethodDecl>(Val: FD)) {
    if (!MD->isStatic()) {
      Diag(Loc: SKEPAttr->getLocation(), DiagID: diag::err_sycl_entry_point_invalid)
          << SKEPAttr << diag::InvalidSKEPReason::NonStaticMemberFn;
      SKEPAttr->setInvalidAttr();
    }
  }

  // Reject variadic functions.
  if (FD->isVariadic()) {
    Diag(Loc: SKEPAttr->getLocation(), DiagID: diag::err_sycl_entry_point_invalid)
        << SKEPAttr << diag::InvalidSKEPReason::VariadicFn;
    SKEPAttr->setInvalidAttr();
  }

  // Reject defaulted functions, or else deleted ones (at most one diagnostic).
  if (FD->isDefaulted()) {
    Diag(Loc: SKEPAttr->getLocation(), DiagID: diag::err_sycl_entry_point_invalid)
        << SKEPAttr << diag::InvalidSKEPReason::DefaultedFn;
    SKEPAttr->setInvalidAttr();
  } else if (FD->isDeleted()) {
    Diag(Loc: SKEPAttr->getLocation(), DiagID: diag::err_sycl_entry_point_invalid)
        << SKEPAttr << diag::InvalidSKEPReason::DeletedFn;
    SKEPAttr->setInvalidAttr();
  }

  // Reject consteval functions, or else constexpr ones. consteval is tested
  // first so it gets its own, more specific reason.
  if (FD->isConsteval()) {
    Diag(Loc: SKEPAttr->getLocation(), DiagID: diag::err_sycl_entry_point_invalid)
        << SKEPAttr << diag::InvalidSKEPReason::ConstevalFn;
    SKEPAttr->setInvalidAttr();
  } else if (FD->isConstexpr()) {
    Diag(Loc: SKEPAttr->getLocation(), DiagID: diag::err_sycl_entry_point_invalid)
        << SKEPAttr << diag::InvalidSKEPReason::ConstexprFn;
    SKEPAttr->setInvalidAttr();
  }

  // Reject noreturn functions.
  if (FD->isNoReturn()) {
    Diag(Loc: SKEPAttr->getLocation(), DiagID: diag::err_sycl_entry_point_invalid)
        << SKEPAttr << diag::InvalidSKEPReason::NoreturnFn;
    SKEPAttr->setInvalidAttr();
  }

  // The return type must not be deduced; when it is known and non-dependent
  // it must be void.
  if (FD->getReturnType()->isUndeducedType()) {
    Diag(Loc: SKEPAttr->getLocation(),
         DiagID: diag::err_sycl_entry_point_deduced_return_type)
        << SKEPAttr;
    SKEPAttr->setInvalidAttr();
  } else if (!FD->getReturnType()->isDependentType() &&
             !FD->getReturnType()->isVoidType()) {
    Diag(Loc: SKEPAttr->getLocation(), DiagID: diag::err_sycl_entry_point_return_type)
        << SKEPAttr;
    SKEPAttr->setInvalidAttr();
  }

  // Register the kernel name, or diagnose a conflict with a different entity
  // already registered under it. Redeclarations of the same function
  // (declaresSameEntity) are not conflicts.
  if (!FD->isInvalidDecl() && !FD->isTemplated() &&
      !SKEPAttr->isInvalidAttr()) {
    const SYCLKernelInfo *SKI =
        getASTContext().findSYCLKernelInfo(T: SKEPAttr->getKernelName());
    if (SKI) {
      if (!declaresSameEntity(D1: FD, D2: SKI->getKernelEntryPointDecl())) {
        // FIXME: This diagnostic should include the origin of the kernel
        // FIXME: names; not just the locations of the conflicting declarations.
        Diag(Loc: FD->getLocation(), DiagID: diag::err_sycl_kernel_name_conflict)
            << SKEPAttr;
        Diag(Loc: SKI->getKernelEntryPointDecl()->getLocation(),
             DiagID: diag::note_previous_declaration);
        SKEPAttr->setInvalidAttr();
      }
    } else {
      getASTContext().registerSYCLEntryPointFunction(FD);
    }
  }
}
389
390namespace {
391
392// The body of a function declared with the [[sycl_kernel_entry_point]]
393// attribute is cloned and transformed to substitute references to the original
394// function parameters with references to replacement variables that stand in
395// for SYCL kernel parameters or local variables that reconstitute a decomposed
396// SYCL kernel argument.
397class OutlinedFunctionDeclBodyInstantiator
398 : public TreeTransform<OutlinedFunctionDeclBodyInstantiator> {
399public:
400 using ParmDeclMap = llvm::DenseMap<ParmVarDecl *, VarDecl *>;
401
402 OutlinedFunctionDeclBodyInstantiator(Sema &S, ParmDeclMap &M)
403 : TreeTransform<OutlinedFunctionDeclBodyInstantiator>(S), SemaRef(S),
404 MapRef(M) {}
405
406 // A new set of AST nodes is always required.
407 bool AlwaysRebuild() { return true; }
408
409 // Transform ParmVarDecl references to the supplied replacement variables.
410 ExprResult TransformDeclRefExpr(DeclRefExpr *DRE) {
411 const ParmVarDecl *PVD = dyn_cast<ParmVarDecl>(Val: DRE->getDecl());
412 if (PVD) {
413 ParmDeclMap::iterator I = MapRef.find(Val: PVD);
414 if (I != MapRef.end()) {
415 VarDecl *VD = I->second;
416 assert(SemaRef.getASTContext().hasSameUnqualifiedType(PVD->getType(),
417 VD->getType()));
418 assert(!VD->getType().isMoreQualifiedThan(PVD->getType(),
419 SemaRef.getASTContext()));
420 VD->setIsUsed();
421 return DeclRefExpr::Create(
422 Context: SemaRef.getASTContext(), QualifierLoc: DRE->getQualifierLoc(),
423 TemplateKWLoc: DRE->getTemplateKeywordLoc(), D: VD, RefersToEnclosingVariableOrCapture: false, NameInfo: DRE->getNameInfo(),
424 T: DRE->getType(), VK: DRE->getValueKind());
425 }
426 }
427 return DRE;
428 }
429
430private:
431 Sema &SemaRef;
432 ParmDeclMap &MapRef;
433};
434
435} // unnamed namespace
436
437StmtResult SemaSYCL::BuildSYCLKernelCallStmt(FunctionDecl *FD,
438 CompoundStmt *Body) {
439 assert(!FD->isInvalidDecl());
440 assert(!FD->isTemplated());
441 assert(FD->hasPrototype());
442
443 const auto *SKEPAttr = FD->getAttr<SYCLKernelEntryPointAttr>();
444 assert(SKEPAttr && "Missing sycl_kernel_entry_point attribute");
445 assert(!SKEPAttr->isInvalidAttr() &&
446 "sycl_kernel_entry_point attribute is invalid");
447
448 // Ensure that the kernel name was previously registered and that the
449 // stored declaration matches.
450 const SYCLKernelInfo &SKI =
451 getASTContext().getSYCLKernelInfo(T: SKEPAttr->getKernelName());
452 assert(declaresSameEntity(SKI.getKernelEntryPointDecl(), FD) &&
453 "SYCL kernel name conflict");
454 (void)SKI;
455
456 using ParmDeclMap = OutlinedFunctionDeclBodyInstantiator::ParmDeclMap;
457 ParmDeclMap ParmMap;
458
459 assert(SemaRef.CurContext == FD);
460 OutlinedFunctionDecl *OFD =
461 OutlinedFunctionDecl::Create(C&: getASTContext(), DC: FD, NumParams: FD->getNumParams());
462 unsigned i = 0;
463 for (ParmVarDecl *PVD : FD->parameters()) {
464 ImplicitParamDecl *IPD = ImplicitParamDecl::Create(
465 C&: getASTContext(), DC: OFD, IdLoc: SourceLocation(), Id: PVD->getIdentifier(),
466 T: PVD->getType(), ParamKind: ImplicitParamKind::Other);
467 OFD->setParam(i, P: IPD);
468 ParmMap[PVD] = IPD;
469 ++i;
470 }
471
472 OutlinedFunctionDeclBodyInstantiator OFDBodyInstantiator(SemaRef, ParmMap);
473 Stmt *OFDBody = OFDBodyInstantiator.TransformStmt(S: Body).get();
474 OFD->setBody(OFDBody);
475 OFD->setNothrow();
476 Stmt *NewBody = new (getASTContext()) SYCLKernelCallStmt(Body, OFD);
477
478 return NewBody;
479}
480