1//=======- RawPtrRefCallArgsChecker.cpp --------------------------*- C++ -*-==//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "ASTUtils.h"
10#include "DiagOutputUtils.h"
11#include "PtrTypesSemantics.h"
12#include "clang/AST/Decl.h"
13#include "clang/AST/DeclCXX.h"
14#include "clang/AST/DynamicRecursiveASTVisitor.h"
15#include "clang/Analysis/DomainSpecific/CocoaConventions.h"
16#include "clang/Basic/SourceLocation.h"
17#include "clang/StaticAnalyzer/Checkers/BuiltinCheckerRegistration.h"
18#include "clang/StaticAnalyzer/Core/BugReporter/BugReporter.h"
19#include "clang/StaticAnalyzer/Core/BugReporter/BugType.h"
20#include "clang/StaticAnalyzer/Core/Checker.h"
21#include "llvm/Support/SaveAndRestore.h"
22#include <optional>
23
24using namespace clang;
25using namespace ento;
26
27namespace {
28
29class RawPtrRefCallArgsChecker
30 : public Checker<check::ASTDecl<TranslationUnitDecl>> {
31 BugType Bug;
32
33 TrivialFunctionAnalysis TFA;
34 EnsureFunctionAnalysis EFA;
35
36protected:
37 mutable BugReporter *BR;
38 mutable std::optional<RetainTypeChecker> RTC;
39
40public:
41 RawPtrRefCallArgsChecker(const char *description)
42 : Bug(this, description, "WebKit coding guidelines") {}
43
44 virtual std::optional<bool> isUnsafeType(QualType) const = 0;
45 virtual std::optional<bool> isUnsafePtr(QualType) const = 0;
46 virtual bool isSafePtr(const CXXRecordDecl *Record) const = 0;
47 virtual bool isSafePtrType(const QualType type) const = 0;
48 virtual bool isSafeExpr(const Expr *) const { return false; }
49 virtual bool isSafeDecl(const Decl *) const { return false; }
50 virtual const char *ptrKind() const = 0;
51
52 void checkASTDecl(const TranslationUnitDecl *TUD, AnalysisManager &MGR,
53 BugReporter &BRArg) const {
54 BR = &BRArg;
55
56 // The calls to checkAST* from AnalysisConsumer don't
57 // visit template instantiations or lambda classes. We
58 // want to visit those, so we make our own RecursiveASTVisitor.
59 struct LocalVisitor : DynamicRecursiveASTVisitor {
60 const RawPtrRefCallArgsChecker *Checker;
61 Decl *DeclWithIssue{nullptr};
62
63 explicit LocalVisitor(const RawPtrRefCallArgsChecker *Checker)
64 : Checker(Checker) {
65 assert(Checker);
66 ShouldVisitTemplateInstantiations = true;
67 ShouldVisitImplicitCode = false;
68 }
69
70 bool TraverseClassTemplateDecl(ClassTemplateDecl *Decl) override {
71 if (isSmartPtrClass(Name: safeGetName(ASTNode: Decl)))
72 return true;
73 return DynamicRecursiveASTVisitor::TraverseClassTemplateDecl(D: Decl);
74 }
75
76 bool TraverseDecl(Decl *D) override {
77 llvm::SaveAndRestore SavedDecl(DeclWithIssue);
78 if (D && (isa<FunctionDecl>(Val: D) || isa<ObjCMethodDecl>(Val: D)))
79 DeclWithIssue = D;
80 return DynamicRecursiveASTVisitor::TraverseDecl(D);
81 }
82
83 bool VisitCallExpr(CallExpr *CE) override {
84 Checker->visitCallExpr(CE, D: DeclWithIssue);
85 return true;
86 }
87
88 bool VisitTypedefDecl(TypedefDecl *TD) override {
89 if (Checker->RTC)
90 Checker->RTC->visitTypedef(TD);
91 return true;
92 }
93
94 bool VisitObjCMessageExpr(ObjCMessageExpr *ObjCMsgExpr) override {
95 Checker->visitObjCMessageExpr(E: ObjCMsgExpr, D: DeclWithIssue);
96 return true;
97 }
98 };
99
100 LocalVisitor visitor(this);
101 if (RTC)
102 RTC->visitTranslationUnitDecl(TUD);
103 visitor.TraverseDecl(D: const_cast<TranslationUnitDecl *>(TUD));
104 }
105
106 void visitCallExpr(const CallExpr *CE, const Decl *D) const {
107 if (shouldSkipCall(CE))
108 return;
109
110 if (auto *F = CE->getDirectCallee()) {
111 // Skip the first argument for overloaded member operators (e. g. lambda
112 // or std::function call operator).
113 unsigned ArgIdx =
114 isa<CXXOperatorCallExpr>(Val: CE) && isa_and_nonnull<CXXMethodDecl>(Val: F);
115
116 if (auto *MemberCallExpr = dyn_cast<CXXMemberCallExpr>(Val: CE)) {
117 if (auto *MD = MemberCallExpr->getMethodDecl()) {
118 auto name = safeGetName(ASTNode: MD);
119 if (name == "ref" || name == "deref")
120 return;
121 if (name == "incrementCheckedPtrCount" ||
122 name == "decrementCheckedPtrCount")
123 return;
124 }
125 auto *E = MemberCallExpr->getImplicitObjectArgument();
126 QualType ArgType = MemberCallExpr->getObjectType().getCanonicalType();
127 std::optional<bool> IsUnsafe = isUnsafeType(ArgType);
128 if (IsUnsafe && *IsUnsafe && !isPtrOriginSafe(Arg: E))
129 reportBugOnThis(CallArg: E, DeclWithIssue: D);
130 }
131
132 for (auto P = F->param_begin();
133 // FIXME: Also check variadic function parameters.
134 // FIXME: Also check default function arguments. Probably a different
135 // checker. In case there are default arguments the call can have
136 // fewer arguments than the callee has parameters.
137 P < F->param_end() && ArgIdx < CE->getNumArgs(); ++P, ++ArgIdx) {
138 // TODO: attributes.
139 // if ((*P)->hasAttr<SafeRefCntblRawPtrAttr>())
140 // continue;
141
142 QualType ArgType = (*P)->getType();
143 // FIXME: more complex types (arrays, references to raw pointers, etc)
144 std::optional<bool> IsUncounted = isUnsafePtr(ArgType);
145 if (!IsUncounted || !(*IsUncounted))
146 continue;
147
148 const auto *Arg = CE->getArg(Arg: ArgIdx);
149
150 if (auto *defaultArg = dyn_cast<CXXDefaultArgExpr>(Val: Arg))
151 Arg = defaultArg->getExpr();
152
153 if (isPtrOriginSafe(Arg))
154 continue;
155
156 reportBug(CallArg: Arg, Param: *P, DeclWithIssue: D);
157 }
158 for (; ArgIdx < CE->getNumArgs(); ++ArgIdx) {
159 const auto *Arg = CE->getArg(Arg: ArgIdx);
160 auto ArgType = Arg->getType();
161 std::optional<bool> IsUncounted = isUnsafePtr(ArgType);
162 if (!IsUncounted || !(*IsUncounted))
163 continue;
164
165 if (auto *defaultArg = dyn_cast<CXXDefaultArgExpr>(Val: Arg))
166 Arg = defaultArg->getExpr();
167
168 if (isPtrOriginSafe(Arg))
169 continue;
170
171 reportBug(CallArg: Arg, Param: nullptr, DeclWithIssue: D);
172 }
173 }
174 }
175
176 void visitObjCMessageExpr(const ObjCMessageExpr *E, const Decl *D) const {
177 if (BR->getSourceManager().isInSystemHeader(Loc: E->getExprLoc()))
178 return;
179
180 if (auto *Receiver = E->getInstanceReceiver()) {
181 std::optional<bool> IsUnsafe = isUnsafePtr(E->getReceiverType());
182 if (IsUnsafe && *IsUnsafe && !isPtrOriginSafe(Arg: Receiver)) {
183 if (isAllocInit(E))
184 return;
185 reportBugOnReceiver(CallArg: Receiver, DeclWithIssue: D);
186 }
187 }
188
189 auto *MethodDecl = E->getMethodDecl();
190 if (!MethodDecl)
191 return;
192
193 auto ArgCount = E->getNumArgs();
194 for (unsigned i = 0; i < ArgCount; ++i) {
195 auto *Arg = E->getArg(Arg: i);
196 bool hasParam = i < MethodDecl->param_size();
197 auto *Param = hasParam ? MethodDecl->getParamDecl(Idx: i) : nullptr;
198 auto ArgType = Arg->getType();
199 std::optional<bool> IsUnsafe = isUnsafePtr(ArgType);
200 if (!IsUnsafe || !(*IsUnsafe))
201 continue;
202 if (isPtrOriginSafe(Arg))
203 continue;
204 reportBug(CallArg: Arg, Param, DeclWithIssue: D);
205 }
206 }
207
208 bool isPtrOriginSafe(const Expr *Arg) const {
209 return tryToFindPtrOrigin(
210 E: Arg, /*StopAtFirstRefCountedObj=*/true,
211 isSafePtr: [&](const clang::CXXRecordDecl *Record) { return isSafePtr(Record); },
212 isSafePtrType: [&](const clang::QualType T) { return isSafePtrType(type: T); },
213 isSafeGlobalDecl: [&](const clang::Decl *D) { return isSafeDecl(D); },
214 callback: [&](const clang::Expr *ArgOrigin, bool IsSafe) {
215 if (IsSafe)
216 return true;
217 if (isNullPtr(E: ArgOrigin))
218 return true;
219 if (isa<IntegerLiteral>(Val: ArgOrigin)) {
220 // FIXME: Check the value.
221 // foo(123)
222 return true;
223 }
224 if (isa<CXXBoolLiteralExpr>(Val: ArgOrigin))
225 return true;
226 if (isa<ObjCStringLiteral>(Val: ArgOrigin))
227 return true;
228 if (isASafeCallArg(E: ArgOrigin))
229 return true;
230 if (EFA.isACallToEnsureFn(E: ArgOrigin))
231 return true;
232 if (isSafeExpr(ArgOrigin))
233 return true;
234 return false;
235 });
236 }
237
238 bool shouldSkipCall(const CallExpr *CE) const {
239 const auto *Callee = CE->getDirectCallee();
240
241 if (BR->getSourceManager().isInSystemHeader(Loc: CE->getExprLoc()))
242 return true;
243
244 if (Callee && TFA.isTrivial(D: Callee))
245 return true;
246
247 if (isTrivialBuiltinFunction(F: Callee))
248 return true;
249
250 if (CE->getNumArgs() == 0)
251 return false;
252
253 // If an assignment is problematic we should warn about the sole existence
254 // of object on LHS.
255 if (auto *MemberOp = dyn_cast<CXXOperatorCallExpr>(Val: CE)) {
256 // Note: assignemnt to built-in type isn't derived from CallExpr.
257 if (MemberOp->getOperator() ==
258 OO_Equal) { // Ignore assignment to Ref/RefPtr.
259 auto *callee = MemberOp->getDirectCallee();
260 if (auto *calleeDecl = dyn_cast<CXXMethodDecl>(Val: callee)) {
261 if (const CXXRecordDecl *classDecl = calleeDecl->getParent()) {
262 if (isSafePtr(Record: classDecl))
263 return true;
264 }
265 }
266 }
267 if (MemberOp->isAssignmentOp())
268 return false;
269 }
270
271 if (!Callee)
272 return false;
273
274 if (isMethodOnWTFContainerType(Decl: Callee))
275 return true;
276
277 auto overloadedOperatorType = Callee->getOverloadedOperator();
278 if (overloadedOperatorType == OO_EqualEqual ||
279 overloadedOperatorType == OO_ExclaimEqual ||
280 overloadedOperatorType == OO_LessEqual ||
281 overloadedOperatorType == OO_GreaterEqual ||
282 overloadedOperatorType == OO_Spaceship ||
283 overloadedOperatorType == OO_AmpAmp ||
284 overloadedOperatorType == OO_PipePipe)
285 return true;
286
287 if (isCtorOfSafePtr(F: Callee) || isPtrConversion(F: Callee))
288 return true;
289
290 auto name = safeGetName(ASTNode: Callee);
291 if (name == "adoptRef" || name == "getPtr" || name == "WeakPtr" ||
292 name == "is" || name == "equal" || name == "hash" || name == "isType" ||
293 // FIXME: Most/all of these should be implemented via attributes.
294 name == "CFEqual" || name == "equalIgnoringASCIICase" ||
295 name == "equalIgnoringASCIICaseCommon" ||
296 name == "equalIgnoringNullity" || name == "toString")
297 return true;
298
299 return false;
300 }
301
302 bool isMethodOnWTFContainerType(const FunctionDecl *Decl) const {
303 if (!isa<CXXMethodDecl>(Val: Decl))
304 return false;
305 auto *ClassDecl = Decl->getParent();
306 if (!ClassDecl || !isa<CXXRecordDecl>(Val: ClassDecl))
307 return false;
308
309 auto *NsDecl = ClassDecl->getParent();
310 if (!NsDecl || !isa<NamespaceDecl>(Val: NsDecl))
311 return false;
312
313 auto MethodName = safeGetName(ASTNode: Decl);
314 auto ClsNameStr = safeGetName(ASTNode: ClassDecl);
315 StringRef ClsName = ClsNameStr; // FIXME: Make safeGetName return StringRef.
316 auto NamespaceName = safeGetName(ASTNode: NsDecl);
317 // FIXME: These should be implemented via attributes.
318 return NamespaceName == "WTF" &&
319 (MethodName == "find" || MethodName == "findIf" ||
320 MethodName == "reverseFind" || MethodName == "reverseFindIf" ||
321 MethodName == "findIgnoringASCIICase" || MethodName == "get" ||
322 MethodName == "inlineGet" || MethodName == "contains" ||
323 MethodName == "containsIf" ||
324 MethodName == "containsIgnoringASCIICase" ||
325 MethodName == "startsWith" || MethodName == "endsWith" ||
326 MethodName == "startsWithIgnoringASCIICase" ||
327 MethodName == "endsWithIgnoringASCIICase" ||
328 MethodName == "substring") &&
329 (ClsName.ends_with(Suffix: "Vector") || ClsName.ends_with(Suffix: "Set") ||
330 ClsName.ends_with(Suffix: "Map") || ClsName == "StringImpl" ||
331 ClsName.ends_with(Suffix: "String"));
332 }
333
334 void reportBug(const Expr *CallArg, const ParmVarDecl *Param,
335 const Decl *DeclWithIssue) const {
336 assert(CallArg);
337
338 SmallString<100> Buf;
339 llvm::raw_svector_ostream Os(Buf);
340
341 const std::string paramName = safeGetName(ASTNode: Param);
342 Os << "Call argument";
343 if (!paramName.empty()) {
344 Os << " for parameter ";
345 printQuotedQualifiedName(Os, D: Param);
346 }
347 Os << " is " << ptrKind() << " and unsafe.";
348
349 bool usesDefaultArgValue = isa<CXXDefaultArgExpr>(Val: CallArg) && Param;
350 const SourceLocation SrcLocToReport =
351 usesDefaultArgValue ? Param->getDefaultArg()->getExprLoc()
352 : CallArg->getSourceRange().getBegin();
353
354 PathDiagnosticLocation BSLoc(SrcLocToReport, BR->getSourceManager());
355 auto Report = std::make_unique<BasicBugReport>(args: Bug, args: Os.str(), args&: BSLoc);
356 Report->addRange(R: CallArg->getSourceRange());
357 Report->setDeclWithIssue(DeclWithIssue);
358 BR->emitReport(R: std::move(Report));
359 }
360
361 void reportBugOnThis(const Expr *CallArg, const Decl *DeclWithIssue) const {
362 assert(CallArg);
363
364 const SourceLocation SrcLocToReport = CallArg->getSourceRange().getBegin();
365
366 SmallString<100> Buf;
367 llvm::raw_svector_ostream Os(Buf);
368 Os << "Call argument for 'this' parameter is " << ptrKind();
369 Os << " and unsafe.";
370
371 PathDiagnosticLocation BSLoc(SrcLocToReport, BR->getSourceManager());
372 auto Report = std::make_unique<BasicBugReport>(args: Bug, args: Os.str(), args&: BSLoc);
373 Report->addRange(R: CallArg->getSourceRange());
374 Report->setDeclWithIssue(DeclWithIssue);
375 BR->emitReport(R: std::move(Report));
376 }
377
378 void reportBugOnReceiver(const Expr *CallArg,
379 const Decl *DeclWithIssue) const {
380 assert(CallArg);
381
382 const SourceLocation SrcLocToReport = CallArg->getSourceRange().getBegin();
383
384 SmallString<100> Buf;
385 llvm::raw_svector_ostream Os(Buf);
386 Os << "Receiver is " << ptrKind() << " and unsafe.";
387
388 PathDiagnosticLocation BSLoc(SrcLocToReport, BR->getSourceManager());
389 auto Report = std::make_unique<BasicBugReport>(args: Bug, args: Os.str(), args&: BSLoc);
390 Report->addRange(R: CallArg->getSourceRange());
391 Report->setDeclWithIssue(DeclWithIssue);
392 BR->emitReport(R: std::move(Report));
393 }
394};
395
396class UncountedCallArgsChecker final : public RawPtrRefCallArgsChecker {
397public:
398 UncountedCallArgsChecker()
399 : RawPtrRefCallArgsChecker("Uncounted call argument for a raw "
400 "pointer/reference parameter") {}
401
402 std::optional<bool> isUnsafeType(QualType QT) const final {
403 return isUncounted(T: QT);
404 }
405
406 std::optional<bool> isUnsafePtr(QualType QT) const final {
407 return isUncountedPtr(T: QT.getCanonicalType());
408 }
409
410 bool isSafePtr(const CXXRecordDecl *Record) const final {
411 return isRefCounted(Class: Record) || isCheckedPtr(Class: Record);
412 }
413
414 bool isSafePtrType(const QualType type) const final {
415 return isRefOrCheckedPtrType(T: type);
416 }
417
418 const char *ptrKind() const final { return "uncounted"; }
419};
420
421class UncheckedCallArgsChecker final : public RawPtrRefCallArgsChecker {
422public:
423 UncheckedCallArgsChecker()
424 : RawPtrRefCallArgsChecker("Unchecked call argument for a raw "
425 "pointer/reference parameter") {}
426
427 std::optional<bool> isUnsafeType(QualType QT) const final {
428 return isUnchecked(T: QT);
429 }
430
431 std::optional<bool> isUnsafePtr(QualType QT) const final {
432 return isUncheckedPtr(T: QT.getCanonicalType());
433 }
434
435 bool isSafePtr(const CXXRecordDecl *Record) const final {
436 return isRefCounted(Class: Record) || isCheckedPtr(Class: Record);
437 }
438
439 bool isSafePtrType(const QualType type) const final {
440 return isRefOrCheckedPtrType(T: type);
441 }
442
443 bool isSafeExpr(const Expr *E) const final {
444 return isExprToGetCheckedPtrCapableMember(E);
445 }
446
447 const char *ptrKind() const final { return "unchecked"; }
448};
449
450class UnretainedCallArgsChecker final : public RawPtrRefCallArgsChecker {
451public:
452 UnretainedCallArgsChecker()
453 : RawPtrRefCallArgsChecker("Unretained call argument for a raw "
454 "pointer/reference parameter") {
455 RTC = RetainTypeChecker();
456 }
457
458 std::optional<bool> isUnsafeType(QualType QT) const final {
459 return RTC->isUnretained(QT);
460 }
461
462 std::optional<bool> isUnsafePtr(QualType QT) const final {
463 return RTC->isUnretained(QT);
464 }
465
466 bool isSafePtr(const CXXRecordDecl *Record) const final {
467 return isRetainPtrOrOSPtr(Class: Record);
468 }
469
470 bool isSafePtrType(const QualType type) const final {
471 return isRetainPtrOrOSPtrType(T: type);
472 }
473
474 bool isSafeExpr(const Expr *E) const final {
475 return ento::cocoa::isCocoaObjectRef(T: E->getType()) &&
476 isa<ObjCMessageExpr>(Val: E);
477 }
478
479 bool isSafeDecl(const Decl *D) const final {
480 // Treat NS/CF globals in system header as immortal.
481 return BR->getSourceManager().isInSystemHeader(Loc: D->getLocation());
482 }
483
484 const char *ptrKind() const final { return "unretained"; }
485};
486
487} // namespace
488
489void ento::registerUncountedCallArgsChecker(CheckerManager &Mgr) {
490 Mgr.registerChecker<UncountedCallArgsChecker>();
491}
492
493bool ento::shouldRegisterUncountedCallArgsChecker(const CheckerManager &) {
494 return true;
495}
496
497void ento::registerUncheckedCallArgsChecker(CheckerManager &Mgr) {
498 Mgr.registerChecker<UncheckedCallArgsChecker>();
499}
500
501bool ento::shouldRegisterUncheckedCallArgsChecker(const CheckerManager &) {
502 return true;
503}
504
505void ento::registerUnretainedCallArgsChecker(CheckerManager &Mgr) {
506 Mgr.registerChecker<UnretainedCallArgsChecker>();
507}
508
509bool ento::shouldRegisterUnretainedCallArgsChecker(const CheckerManager &) {
510 return true;
511}
512