1//===- StdVariantChecker.cpp -------------------------------------*- C++ -*-==//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "clang/AST/Type.h"
10#include "clang/StaticAnalyzer/Checkers/BuiltinCheckerRegistration.h"
11#include "clang/StaticAnalyzer/Core/BugReporter/BugType.h"
12#include "clang/StaticAnalyzer/Core/Checker.h"
13#include "clang/StaticAnalyzer/Core/CheckerManager.h"
14#include "clang/StaticAnalyzer/Core/PathSensitive/CallDescription.h"
15#include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
16#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
17#include "clang/StaticAnalyzer/Core/PathSensitive/SVals.h"
18#include "llvm/ADT/FoldingSet.h"
19#include "llvm/ADT/StringRef.h"
20#include "llvm/Support/Casting.h"
21#include <optional>
22#include <string_view>
23
24#include "TaggedUnionModeling.h"
25
26using namespace clang;
27using namespace ento;
28using namespace tagged_union_modeling;
29
30REGISTER_MAP_WITH_PROGRAMSTATE(VariantHeldTypeMap, const MemRegion *, QualType)
31
32namespace clang::ento::tagged_union_modeling {
33
34const CXXConstructorDecl *
35getConstructorDeclarationForCall(const CallEvent &Call) {
36 const auto *ConstructorCall = dyn_cast<CXXConstructorCall>(Val: &Call);
37 if (!ConstructorCall)
38 return nullptr;
39
40 return ConstructorCall->getDecl();
41}
42
43bool isCopyConstructorCall(const CallEvent &Call) {
44 if (const CXXConstructorDecl *ConstructorDecl =
45 getConstructorDeclarationForCall(Call))
46 return ConstructorDecl->isCopyConstructor();
47 return false;
48}
49
50bool isCopyAssignmentCall(const CallEvent &Call) {
51 const Decl *CopyAssignmentDecl = Call.getDecl();
52
53 if (const auto *AsMethodDecl =
54 dyn_cast_or_null<CXXMethodDecl>(Val: CopyAssignmentDecl))
55 return AsMethodDecl->isCopyAssignmentOperator();
56 return false;
57}
58
59bool isMoveConstructorCall(const CallEvent &Call) {
60 const CXXConstructorDecl *ConstructorDecl =
61 getConstructorDeclarationForCall(Call);
62 if (!ConstructorDecl)
63 return false;
64
65 return ConstructorDecl->isMoveConstructor();
66}
67
68bool isMoveAssignmentCall(const CallEvent &Call) {
69 const Decl *CopyAssignmentDecl = Call.getDecl();
70
71 const auto *AsMethodDecl =
72 dyn_cast_or_null<CXXMethodDecl>(Val: CopyAssignmentDecl);
73 if (!AsMethodDecl)
74 return false;
75
76 return AsMethodDecl->isMoveAssignmentOperator();
77}
78
79bool isStdType(const Type *Type, llvm::StringRef TypeName) {
80 auto *Decl = Type->getAsRecordDecl();
81 if (!Decl)
82 return false;
83 return (Decl->getName() == TypeName) && Decl->isInStdNamespace();
84}
85
86bool isStdVariant(const Type *Type) {
87 return isStdType(Type, TypeName: llvm::StringLiteral("variant"));
88}
89
90} // end of namespace clang::ento::tagged_union_modeling
91
92static std::optional<ArrayRef<TemplateArgument>>
93getTemplateArgsFromVariant(const Type *VariantType) {
94 const auto *TempSpecType = VariantType->getAs<TemplateSpecializationType>();
95 if (!TempSpecType)
96 return {};
97
98 return TempSpecType->template_arguments();
99}
100
101static std::optional<QualType>
102getNthTemplateTypeArgFromVariant(const Type *varType, unsigned i) {
103 std::optional<ArrayRef<TemplateArgument>> VariantTemplates =
104 getTemplateArgsFromVariant(VariantType: varType);
105 if (!VariantTemplates)
106 return {};
107
108 return (*VariantTemplates)[i].getAsType();
109}
110
// Returns true when `a` is one of the vowel letters a, e, i, o, u, in either
// lower or upper case. Upper-case letters must be accepted too; otherwise a
// type name such as 'Apple' is printed with the article "a" instead of "an".
static bool isVowel(char a) {
  switch (a) {
  case 'a':
  case 'e':
  case 'i':
  case 'o':
  case 'u':
  case 'A':
  case 'E':
  case 'I':
  case 'O':
  case 'U':
    return true;
  default:
    return false;
  }
}
123
124static llvm::StringRef indefiniteArticleBasedOnVowel(char a) {
125 if (isVowel(a))
126 return "an";
127 return "a";
128}
129
130class StdVariantChecker : public Checker<eval::Call, check::RegionChanges> {
131 // Call descriptors to find relevant calls
132 CallDescription VariantConstructor{CDM::CXXMethod,
133 {"std", "variant", "variant"}};
134 CallDescription VariantAssignmentOperator{CDM::CXXMethod,
135 {"std", "variant", "operator="}};
136 CallDescription StdGet{CDM::SimpleFunc, {"std", "get"}, 1, 1};
137
138 BugType BadVariantType{this, "BadVariantType", "BadVariantType"};
139
140public:
141 ProgramStateRef checkRegionChanges(ProgramStateRef State,
142 const InvalidatedSymbols *,
143 ArrayRef<const MemRegion *>,
144 ArrayRef<const MemRegion *> Regions,
145 const LocationContext *,
146 const CallEvent *Call) const {
147 if (!Call)
148 return State;
149
150 return removeInformationStoredForDeadInstances<VariantHeldTypeMap>(
151 Call: *Call, State, Regions);
152 }
153
154 bool evalCall(const CallEvent &Call, CheckerContext &C) const {
155 // Check if the call was not made from a system header. If it was then
156 // we do an early return because it is part of the implementation.
157 if (Call.isCalledFromSystemHeader())
158 return false;
159
160 if (StdGet.matches(Call))
161 return handleStdGetCall(Call, C);
162
163 // First check if a constructor call is happening. If it is a
164 // constructor call, check if it is an std::variant constructor call.
165 bool IsVariantConstructor =
166 isa<CXXConstructorCall>(Val: Call) && VariantConstructor.matches(Call);
167 bool IsVariantAssignmentOperatorCall =
168 isa<CXXMemberOperatorCall>(Val: Call) &&
169 VariantAssignmentOperator.matches(Call);
170
171 if (IsVariantConstructor || IsVariantAssignmentOperatorCall) {
172 if (Call.getNumArgs() == 0 && IsVariantConstructor) {
173 handleDefaultConstructor(ConstructorCall: cast<CXXConstructorCall>(Val: &Call), C);
174 return true;
175 }
176
177 // FIXME Later this checker should be extended to handle constructors
178 // with multiple arguments.
179 if (Call.getNumArgs() != 1)
180 return false;
181
182 SVal ThisSVal;
183 if (IsVariantConstructor) {
184 const auto &AsConstructorCall = cast<CXXConstructorCall>(Val: Call);
185 ThisSVal = AsConstructorCall.getCXXThisVal();
186 } else if (IsVariantAssignmentOperatorCall) {
187 const auto &AsMemberOpCall = cast<CXXMemberOperatorCall>(Val: Call);
188 ThisSVal = AsMemberOpCall.getCXXThisVal();
189 } else {
190 return false;
191 }
192
193 handleConstructorAndAssignment<VariantHeldTypeMap>(Call, C, ThisSVal);
194 return true;
195 }
196 return false;
197 }
198
199private:
200 // The default constructed std::variant must be handled separately
201 // by default the std::variant is going to hold a default constructed instance
202 // of the first type of the possible types
203 void handleDefaultConstructor(const CXXConstructorCall *ConstructorCall,
204 CheckerContext &C) const {
205 SVal ThisSVal = ConstructorCall->getCXXThisVal();
206
207 const auto *const ThisMemRegion = ThisSVal.getAsRegion();
208 if (!ThisMemRegion)
209 return;
210
211 std::optional<QualType> DefaultType = getNthTemplateTypeArgFromVariant(
212 varType: ThisSVal.getType(C.getASTContext())->getPointeeType().getTypePtr(), i: 0);
213 if (!DefaultType)
214 return;
215
216 ProgramStateRef State = ConstructorCall->getState();
217 State = State->set<VariantHeldTypeMap>(K: ThisMemRegion, E: *DefaultType);
218 C.addTransition(State);
219 }
220
221 bool handleStdGetCall(const CallEvent &Call, CheckerContext &C) const {
222 ProgramStateRef State = Call.getState();
223
224 const auto &ArgType = Call.getArgSVal(Index: 0)
225 .getType(C.getASTContext())
226 ->getPointeeType()
227 .getTypePtr();
228 // We have to make sure that the argument is an std::variant.
229 // There is another std::get with std::pair argument
230 if (!isStdVariant(Type: ArgType))
231 return false;
232
233 // Get the mem region of the argument std::variant and look up the type
234 // information that we know about it.
235 const MemRegion *ArgMemRegion = Call.getArgSVal(Index: 0).getAsRegion();
236 const QualType *StoredType = State->get<VariantHeldTypeMap>(key: ArgMemRegion);
237 if (!StoredType)
238 return false;
239
240 const CallExpr *CE = cast<CallExpr>(Val: Call.getOriginExpr());
241 const FunctionDecl *FD = CE->getDirectCallee();
242 if (FD->getTemplateSpecializationArgs()->size() < 1)
243 return false;
244
245 const auto &TypeOut = FD->getTemplateSpecializationArgs()->asArray()[0];
246 // std::get's first template parameter can be the type we want to get
247 // out of the std::variant or a natural number which is the position of
248 // the requested type in the argument type list of the std::variant's
249 // argument.
250 QualType RetrievedType;
251 switch (TypeOut.getKind()) {
252 case TemplateArgument::ArgKind::Type:
253 RetrievedType = TypeOut.getAsType();
254 break;
255 case TemplateArgument::ArgKind::Integral:
256 // In the natural number case we look up which type corresponds to the
257 // number.
258 if (std::optional<QualType> NthTemplate =
259 getNthTemplateTypeArgFromVariant(
260 varType: ArgType, i: TypeOut.getAsIntegral().getSExtValue())) {
261 RetrievedType = *NthTemplate;
262 break;
263 }
264 [[fallthrough]];
265 default:
266 return false;
267 }
268
269 QualType RetrievedCanonicalType = RetrievedType.getCanonicalType();
270 QualType StoredCanonicalType = StoredType->getCanonicalType();
271 if (RetrievedCanonicalType == StoredCanonicalType)
272 return true;
273
274 ExplodedNode *ErrNode = C.generateNonFatalErrorNode();
275 if (!ErrNode)
276 return false;
277 llvm::SmallString<128> Str;
278 llvm::raw_svector_ostream OS(Str);
279 std::string StoredTypeName = StoredType->getAsString();
280 std::string RetrievedTypeName = RetrievedType.getAsString();
281 OS << "std::variant " << ArgMemRegion->getDescriptiveName() << " held "
282 << indefiniteArticleBasedOnVowel(a: StoredTypeName[0]) << " \'"
283 << StoredTypeName << "\', not "
284 << indefiniteArticleBasedOnVowel(a: RetrievedTypeName[0]) << " \'"
285 << RetrievedTypeName << "\'";
286 auto R = std::make_unique<PathSensitiveBugReport>(args: BadVariantType, args: OS.str(),
287 args&: ErrNode);
288 C.emitReport(R: std::move(R));
289 return true;
290 }
291};
292
293bool clang::ento::shouldRegisterStdVariantChecker(
294 clang::ento::CheckerManager const &mgr) {
295 return true;
296}
297
298void clang::ento::registerStdVariantChecker(clang::ento::CheckerManager &mgr) {
299 mgr.registerChecker<StdVariantChecker>();
300}
301