1//===- SemaSYCL.cpp - Semantic Analysis for SYCL constructs ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This implements Semantic Analysis for SYCL constructs.
9//===----------------------------------------------------------------------===//
10
11#include "clang/Sema/SemaSYCL.h"
12#include "clang/AST/Mangle.h"
13#include "clang/Sema/Attr.h"
14#include "clang/Sema/ParsedAttr.h"
15#include "clang/Sema/Sema.h"
16#include "clang/Sema/SemaDiagnostic.h"
17
18using namespace clang;
19
20// -----------------------------------------------------------------------------
21// SYCL device specific diagnostics implementation
22// -----------------------------------------------------------------------------
23
24SemaSYCL::SemaSYCL(Sema &S) : SemaBase(S) {}
25
26Sema::SemaDiagnosticBuilder SemaSYCL::DiagIfDeviceCode(SourceLocation Loc,
27 unsigned DiagID) {
28 assert(getLangOpts().SYCLIsDevice &&
29 "Should only be called during SYCL compilation");
30 FunctionDecl *FD = dyn_cast<FunctionDecl>(Val: SemaRef.getCurLexicalContext());
31 SemaDiagnosticBuilder::Kind DiagKind = [this, FD] {
32 if (!FD)
33 return SemaDiagnosticBuilder::K_Nop;
34 if (SemaRef.getEmissionStatus(Decl: FD) == Sema::FunctionEmissionStatus::Emitted)
35 return SemaDiagnosticBuilder::K_ImmediateWithCallStack;
36 return SemaDiagnosticBuilder::K_Deferred;
37 }();
38 return SemaDiagnosticBuilder(DiagKind, Loc, DiagID, FD, SemaRef);
39}
40
41static bool isZeroSizedArray(SemaSYCL &S, QualType Ty) {
42 if (const auto *CAT = S.getASTContext().getAsConstantArrayType(T: Ty))
43 return CAT->isZeroSize();
44 return false;
45}
46
47void SemaSYCL::deepTypeCheckForDevice(SourceLocation UsedAt,
48 llvm::DenseSet<QualType> Visited,
49 ValueDecl *DeclToCheck) {
50 assert(getLangOpts().SYCLIsDevice &&
51 "Should only be called during SYCL compilation");
52 // Emit notes only for the first discovered declaration of unsupported type
53 // to avoid mess of notes. This flag is to track that error already happened.
54 bool NeedToEmitNotes = true;
55
56 auto Check = [&](QualType TypeToCheck, const ValueDecl *D) {
57 bool ErrorFound = false;
58 if (isZeroSizedArray(S&: *this, Ty: TypeToCheck)) {
59 DiagIfDeviceCode(Loc: UsedAt, DiagID: diag::err_typecheck_zero_array_size) << 1;
60 ErrorFound = true;
61 }
62 // Checks for other types can also be done here.
63 if (ErrorFound) {
64 if (NeedToEmitNotes) {
65 if (auto *FD = dyn_cast<FieldDecl>(Val: D))
66 DiagIfDeviceCode(Loc: FD->getLocation(),
67 DiagID: diag::note_illegal_field_declared_here)
68 << FD->getType()->isPointerType() << FD->getType();
69 else
70 DiagIfDeviceCode(Loc: D->getLocation(), DiagID: diag::note_declared_at);
71 }
72 }
73
74 return ErrorFound;
75 };
76
77 // In case we have a Record used do the DFS for a bad field.
78 SmallVector<const ValueDecl *, 4> StackForRecursion;
79 StackForRecursion.push_back(Elt: DeclToCheck);
80
81 // While doing DFS save how we get there to emit a nice set of notes.
82 SmallVector<const FieldDecl *, 4> History;
83 History.push_back(Elt: nullptr);
84
85 do {
86 const ValueDecl *Next = StackForRecursion.pop_back_val();
87 if (!Next) {
88 assert(!History.empty());
89 // Found a marker, we have gone up a level.
90 History.pop_back();
91 continue;
92 }
93 QualType NextTy = Next->getType();
94
95 if (!Visited.insert(V: NextTy).second)
96 continue;
97
98 auto EmitHistory = [&]() {
99 // The first element is always nullptr.
100 for (uint64_t Index = 1; Index < History.size(); ++Index) {
101 DiagIfDeviceCode(Loc: History[Index]->getLocation(),
102 DiagID: diag::note_within_field_of_type)
103 << History[Index]->getType();
104 }
105 };
106
107 if (Check(NextTy, Next)) {
108 if (NeedToEmitNotes)
109 EmitHistory();
110 NeedToEmitNotes = false;
111 }
112
113 // In case pointer/array/reference type is met get pointee type, then
114 // proceed with that type.
115 while (NextTy->isAnyPointerType() || NextTy->isArrayType() ||
116 NextTy->isReferenceType()) {
117 if (NextTy->isArrayType())
118 NextTy = QualType{NextTy->getArrayElementTypeNoTypeQual(), 0};
119 else
120 NextTy = NextTy->getPointeeType();
121 if (Check(NextTy, Next)) {
122 if (NeedToEmitNotes)
123 EmitHistory();
124 NeedToEmitNotes = false;
125 }
126 }
127
128 if (const auto *RecDecl = NextTy->getAsRecordDecl()) {
129 if (auto *NextFD = dyn_cast<FieldDecl>(Val: Next))
130 History.push_back(Elt: NextFD);
131 // When nullptr is discovered, this means we've gone back up a level, so
132 // the history should be cleaned.
133 StackForRecursion.push_back(Elt: nullptr);
134 llvm::copy(Range: RecDecl->fields(), Out: std::back_inserter(x&: StackForRecursion));
135 }
136 } while (!StackForRecursion.empty());
137}
138
139ExprResult SemaSYCL::BuildUniqueStableNameExpr(SourceLocation OpLoc,
140 SourceLocation LParen,
141 SourceLocation RParen,
142 TypeSourceInfo *TSI) {
143 return SYCLUniqueStableNameExpr::Create(Ctx: getASTContext(), OpLoc, LParen,
144 RParen, TSI);
145}
146
147ExprResult SemaSYCL::ActOnUniqueStableNameExpr(SourceLocation OpLoc,
148 SourceLocation LParen,
149 SourceLocation RParen,
150 ParsedType ParsedTy) {
151 TypeSourceInfo *TSI = nullptr;
152 QualType Ty = SemaRef.GetTypeFromParser(Ty: ParsedTy, TInfo: &TSI);
153
154 if (Ty.isNull())
155 return ExprError();
156 if (!TSI)
157 TSI = getASTContext().getTrivialTypeSourceInfo(T: Ty, Loc: LParen);
158
159 return BuildUniqueStableNameExpr(OpLoc, LParen, RParen, TSI);
160}
161
162void SemaSYCL::handleKernelAttr(Decl *D, const ParsedAttr &AL) {
163 // The 'sycl_kernel' attribute applies only to function templates.
164 const auto *FD = cast<FunctionDecl>(Val: D);
165 const FunctionTemplateDecl *FT = FD->getDescribedFunctionTemplate();
166 assert(FT && "Function template is expected");
167
168 // Function template must have at least two template parameters.
169 const TemplateParameterList *TL = FT->getTemplateParameters();
170 if (TL->size() < 2) {
171 Diag(Loc: FT->getLocation(), DiagID: diag::warn_sycl_kernel_num_of_template_params);
172 return;
173 }
174
175 // Template parameters must be typenames.
176 for (unsigned I = 0; I < 2; ++I) {
177 const NamedDecl *TParam = TL->getParam(Idx: I);
178 if (isa<NonTypeTemplateParmDecl>(Val: TParam)) {
179 Diag(Loc: FT->getLocation(),
180 DiagID: diag::warn_sycl_kernel_invalid_template_param_type);
181 return;
182 }
183 }
184
185 // Function must have at least one argument.
186 if (getFunctionOrMethodNumParams(D) != 1) {
187 Diag(Loc: FT->getLocation(), DiagID: diag::warn_sycl_kernel_num_of_function_params);
188 return;
189 }
190
191 // Function must return void.
192 QualType RetTy = getFunctionOrMethodResultType(D);
193 if (!RetTy->isVoidType()) {
194 Diag(Loc: FT->getLocation(), DiagID: diag::warn_sycl_kernel_return_type);
195 return;
196 }
197
198 handleSimpleAttribute<SYCLKernelAttr>(S&: *this, D, CI: AL);
199}
200