1//===- UnsafeBufferUsageExtractor.cpp -------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "clang/ScalableStaticAnalysisFramework/Analyses/UnsafeBufferUsage/UnsafeBufferUsageExtractor.h"
10#include "clang/AST/ASTConsumer.h"
11#include "clang/AST/ASTContext.h"
12#include "clang/AST/Decl.h"
13#include "clang/AST/DynamicRecursiveASTVisitor.h"
14#include "clang/AST/StmtVisitor.h"
15#include "clang/Analysis/Analyses/UnsafeBufferUsage.h"
16#include "clang/ScalableStaticAnalysisFramework/Analyses/UnsafeBufferUsage/UnsafeBufferUsage.h"
17#include "clang/ScalableStaticAnalysisFramework/Core/ASTEntityMapping.h"
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/ADT/iterator_range.h"
20#include "llvm/Support/Error.h"
21#include <memory>
22
23namespace {
24using namespace clang;
25using namespace ssaf;
26
27static bool hasPointerType(const Expr *E) {
28 auto Ty = E->getType();
29 return !Ty.isNull() && !Ty->isFunctionPointerType() &&
30 (Ty->isPointerType() || Ty->isArrayType());
31}
32
33constexpr inline auto buildEntityPointerLevel =
34 UnsafeBufferUsageTUSummaryExtractor::buildEntityPointerLevel;
35
36static llvm::Error makeUnsupportedStmtKindError(const Stmt *Unsupported) {
37 return llvm::createStringError(
38 Fmt: "unsupported expression kind for translation to "
39 "EntityPointerLevel: %s",
40 Vals: Unsupported->getStmtClassName());
41}
42
43static llvm::Error makeCreateEntityNameError(const NamedDecl *FailedDecl,
44 ASTContext &Ctx) {
45 std::string LocStr = FailedDecl->getSourceRange().getBegin().printToString(
46 SM: Ctx.getSourceManager());
47 return llvm::createStringError(
48 Fmt: "failed to create entity name for %s declared at %s",
49 Vals: FailedDecl->getNameAsString().c_str(), Vals: LocStr.c_str());
50}
51
52static llvm::Error makeAddEntitySummaryError(const NamedDecl *FailedContributor,
53 ASTContext &Ctx) {
54 std::string LocStr =
55 FailedContributor->getSourceRange().getBegin().printToString(
56 SM: Ctx.getSourceManager());
57 return llvm::createStringError(
58 Fmt: "failed to add entity summary for contributor %s declared at %s",
59 Vals: FailedContributor->getNameAsString().c_str(), Vals: LocStr.c_str());
60}
61
62// Translate a pointer type expression 'E' to a (set of) EntityPointerLevel(s)
63// associated with the declared type of the base address of `E`. If the base
64// address of `E` is not associated with an entity, the translation result is an
65// empty set.
66//
67// The translation is a process of traversing into the pointer 'E' until its
68// base address can be represented by an entity, with the number of dereferences
69// tracked by incrementing the pointer level. Naturally, taking address of, as
70// the inverse operation of dereference, is tracked by decrementing the pointer
71// level.
72//
73// For example, suppose there are pointers and arrays declared as
74// int *ptr, **p1, **p2;
75// int arr[10][10];
76// , the translation of expressions involving these base addresses will be:
77// Translate(ptr + 5) -> {(ptr, 1)}
78// Translate(arr[5]) -> {(arr, 2)}
79// Translate(cond ? p1[5] : p2) -> {(p1, 2), (p2, 1)}
80// Translate(&arr[5]) -> {(arr, 1)}
81class EntityPointerLevelTranslator
82 : ConstStmtVisitor<EntityPointerLevelTranslator,
83 Expected<EntityPointerLevelSet>> {
84 friend class StmtVisitorBase;
85
86 // Fallback method for all unsupported expression kind:
87 llvm::Error fallback(const Stmt *E) {
88 return makeUnsupportedStmtKindError(Unsupported: E);
89 }
90
91 static EntityPointerLevel incrementPointerLevel(const EntityPointerLevel &E) {
92 return buildEntityPointerLevel(E.getEntity(), E.getPointerLevel() + 1);
93 }
94
95 static EntityPointerLevel decrementPointerLevel(const EntityPointerLevel &E) {
96 assert(E.getPointerLevel() > 0);
97 return buildEntityPointerLevel(E.getEntity(), E.getPointerLevel() - 1);
98 }
99
100 EntityPointerLevel createEntityPointerLevelFor(const EntityName &Name) {
101 return buildEntityPointerLevel(Extractor.addEntity(EN: Name), 1);
102 }
103
104 // The common helper function for Translate(*base):
105 // Translate(*base) -> Translate(base) with .pointerLevel + 1
106 Expected<EntityPointerLevelSet> translateDereferencePointer(const Expr *Ptr) {
107 assert(hasPointerType(Ptr));
108
109 Expected<EntityPointerLevelSet> SubResult = Visit(S: Ptr);
110 if (!SubResult)
111 return SubResult.takeError();
112
113 auto Incremented = llvm::map_range(C&: *SubResult, F: incrementPointerLevel);
114 return EntityPointerLevelSet{Incremented.begin(), Incremented.end()};
115 }
116
117 UnsafeBufferUsageTUSummaryExtractor &Extractor;
118 ASTContext &Ctx;
119
120public:
121 EntityPointerLevelTranslator(UnsafeBufferUsageTUSummaryExtractor &Extractor,
122 ASTContext &Ctx)
123 : Extractor(Extractor), Ctx(Ctx) {}
124
125 Expected<EntityPointerLevelSet> translate(const Expr *E) { return Visit(S: E); }
126
127private:
128 Expected<EntityPointerLevelSet> VisitStmt(const Stmt *E) {
129 return fallback(E);
130 }
131
132 // Translate(base + x) -> Translate(base)
133 // Translate(x + base) -> Translate(base)
134 // Translate(base - x) -> Translate(base)
135 // Translate(base {+=, -=, =} x) -> Translate(base)
136 // Translate(x, base) -> Translate(base)
137 Expected<EntityPointerLevelSet> VisitBinaryOperator(const BinaryOperator *E) {
138 switch (E->getOpcode()) {
139 case clang::BO_Add:
140 if (hasPointerType(E: E->getLHS()))
141 return Visit(S: E->getLHS());
142 return Visit(S: E->getRHS());
143 case clang::BO_Sub:
144 case clang::BO_AddAssign:
145 case clang::BO_SubAssign:
146 case clang::BO_Assign:
147 return Visit(S: E->getLHS());
148 case clang::BO_Comma:
149 return Visit(S: E->getRHS());
150 default:
151 return fallback(E);
152 }
153 }
154
155 // Translate({++, --}base) -> Translate(base)
156 // Translate(base{++, --}) -> Translate(base)
157 // Translate(*base) -> Translate(base) with .pointerLevel += 1
158 // Translate(&base) -> {}, if Translate(base) is {}
159 // -> Translate(base) with .pointerLevel -= 1
160 Expected<EntityPointerLevelSet> VisitUnaryOperator(const UnaryOperator *E) {
161 switch (E->getOpcode()) {
162 case clang::UO_PostInc:
163 case clang::UO_PostDec:
164 case clang::UO_PreInc:
165 case clang::UO_PreDec:
166 return Visit(S: E->getSubExpr());
167 case clang::UO_AddrOf: {
168 Expected<EntityPointerLevelSet> SubResult = Visit(S: E->getSubExpr());
169 if (!SubResult)
170 return SubResult.takeError();
171
172 auto Decremented = llvm::map_range(C&: *SubResult, F: decrementPointerLevel);
173 return EntityPointerLevelSet{Decremented.begin(), Decremented.end()};
174 }
175 case clang::UO_Deref:
176 return translateDereferencePointer(Ptr: E->getSubExpr());
177 default:
178 return fallback(E);
179 }
180 }
181
182 // Translate((T*)base) -> Translate(p) if p has pointer type
183 // -> {} otherwise
184 Expected<EntityPointerLevelSet> VisitCastExpr(const CastExpr *E) {
185 if (hasPointerType(E: E->getSubExpr()))
186 return Visit(S: E->getSubExpr());
187 return EntityPointerLevelSet{};
188 }
189
190 // Translate(f(...)) -> {} if it is an indirect call
191 // -> {(f_return, 1)}, otherwise
192 Expected<EntityPointerLevelSet> VisitCallExpr(const CallExpr *E) {
193 if (auto *FD = E->getDirectCallee())
194 if (auto FDEntityName = getEntityNameForReturn(FD))
195 return EntityPointerLevelSet{
196 createEntityPointerLevelFor(Name: *FDEntityName)};
197 return EntityPointerLevelSet{};
198 }
199
200 // Translate(base[x]) -> Translate(*base)
201 Expected<EntityPointerLevelSet>
202 VisitArraySubscriptExpr(const ArraySubscriptExpr *E) {
203 return translateDereferencePointer(Ptr: E->getBase());
204 }
205
206 // Translate(cond ? base1 : base2) := Translate(base1) U Translate(base2)
207 Expected<EntityPointerLevelSet>
208 VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
209 Expected<EntityPointerLevelSet> ReT = Visit(S: E->getTrueExpr());
210 Expected<EntityPointerLevelSet> ReF = Visit(S: E->getFalseExpr());
211
212 if (ReT && ReF) {
213 ReT->insert(first: ReF->begin(), last: ReF->end());
214 return ReT;
215 }
216 if (!ReF && !ReT)
217 return llvm::joinErrors(E1: ReT.takeError(), E2: ReF.takeError());
218 if (!ReF)
219 return ReF.takeError();
220 return ReT.takeError();
221 }
222
223 Expected<EntityPointerLevelSet> VisitParenExpr(const ParenExpr *E) {
224 return Visit(S: E->getSubExpr());
225 }
226
227 // Translate("string-literal") -> {}
228 // Buffer accesses on string literals are unsafe, but string literals are not
229 // entities so there is no EntityPointerLevel associated with it.
230 Expected<EntityPointerLevelSet> VisitStringLiteral(const StringLiteral *E) {
231 return EntityPointerLevelSet{};
232 }
233
234 // Translate(DRE) -> {(Decl, 1)}
235 Expected<EntityPointerLevelSet> VisitDeclRefExpr(const DeclRefExpr *E) {
236 if (auto EntityName = getEntityName(D: E->getDecl()))
237 return EntityPointerLevelSet{createEntityPointerLevelFor(Name: *EntityName)};
238 return makeCreateEntityNameError(FailedDecl: E->getDecl(), Ctx);
239 }
240
241 // Translate({., ->}f) -> {(MemberDecl, 1)}
242 Expected<EntityPointerLevelSet> VisitMemberExpr(const MemberExpr *E) {
243 if (auto EntityName = getEntityName(D: E->getMemberDecl()))
244 return EntityPointerLevelSet{createEntityPointerLevelFor(Name: *EntityName)};
245 return makeCreateEntityNameError(FailedDecl: E->getMemberDecl(), Ctx);
246 }
247
248 Expected<EntityPointerLevelSet>
249 VisitOpaqueValueExpr(const OpaqueValueExpr *S) {
250 return Visit(S: S->getSourceExpr());
251 }
252};
253
254Expected<EntityPointerLevelSet>
255buildEntityPointerLevels(std::set<const Expr *> &&UnsafePointers,
256 UnsafeBufferUsageTUSummaryExtractor &Extractor,
257 ASTContext &Ctx) {
258 EntityPointerLevelSet Result{};
259 EntityPointerLevelTranslator Translator{Extractor, Ctx};
260 llvm::Error AllErrors = llvm::ErrorSuccess();
261
262 for (const Expr *Ptr : UnsafePointers) {
263 Expected<EntityPointerLevelSet> Translation = Translator.translate(E: Ptr);
264
265 if (Translation) {
266 // Filter out those temporary invalid EntityPointerLevels associated with
267 // `&E` pointers:
268 auto FilteredTranslation = llvm::make_filter_range(
269 Range&: *Translation, Pred: [](const EntityPointerLevel &E) -> bool {
270 return E.getPointerLevel() > 0;
271 });
272 Result.insert(first: FilteredTranslation.begin(), last: FilteredTranslation.end());
273 continue;
274 }
275 AllErrors = llvm::joinErrors(E1: std::move(AllErrors), E2: Translation.takeError());
276 }
277 if (AllErrors)
278 return AllErrors;
279 return Result;
280}
281} // namespace
282
283static std::set<const Expr *> findUnsafePointersInContributor(const Decl *D) {
284 if (isa<FunctionDecl>(Val: D) || isa<VarDecl>(Val: D))
285 return findUnsafePointers(D);
286 if (auto *RD = dyn_cast<RecordDecl>(Val: D)) {
287 std::set<const Expr *> Result;
288
289 for (const FieldDecl *FD : RD->fields()) {
290 Result.merge(source: findUnsafePointers(D: FD));
291 }
292 return Result;
293 }
294 return {};
295}
296
297std::unique_ptr<UnsafeBufferUsageEntitySummary>
298UnsafeBufferUsageTUSummaryExtractor::extractEntitySummary(
299 const Decl *Contributor, ASTContext &Ctx, llvm::Error &Error) {
300 Expected<EntityPointerLevelSet> EPLs = buildEntityPointerLevels(
301 UnsafePointers: findUnsafePointersInContributor(D: Contributor), Extractor&: *this, Ctx);
302
303 if (EPLs)
304 return std::make_unique<UnsafeBufferUsageEntitySummary>(
305 args: UnsafeBufferUsageEntitySummary(std::move(*EPLs)));
306 Error = EPLs.takeError();
307 return nullptr;
308}
309
310void UnsafeBufferUsageTUSummaryExtractor::HandleTranslationUnit(
311 ASTContext &Ctx) {
312
313 // FIXME: I suppose finding contributor Decls is commonly needed by all/many
314 // extractors
315 class ContributorFinder : public DynamicRecursiveASTVisitor {
316 public:
317 std::vector<const NamedDecl *> Contributors;
318
319 bool VisitFunctionDecl(FunctionDecl *D) override {
320 Contributors.push_back(x: D);
321 return true;
322 }
323
324 bool VisitRecordDecl(RecordDecl *D) override {
325 Contributors.push_back(x: D);
326 return true;
327 }
328
329 bool VisitVarDecl(VarDecl *D) override {
330 DeclContext *DC = D->getDeclContext();
331
332 if (DC->isFileContext() || DC->isNamespace())
333 Contributors.push_back(x: D);
334 return false;
335 }
336 } ContributorFinder;
337
338 ContributorFinder.VisitTranslationUnitDecl(D: Ctx.getTranslationUnitDecl());
339
340 llvm::Error Errors = llvm::ErrorSuccess();
341 auto addError = [&Errors](llvm::Error Err) {
342 Errors = llvm::joinErrors(E1: std::move(Errors), E2: std::move(Err));
343 };
344
345 for (auto *CD : ContributorFinder.Contributors) {
346 llvm::Error Error = llvm::ErrorSuccess();
347 auto EntitySummary = extractEntitySummary(Contributor: CD, Ctx, Error);
348
349 if (Error)
350 addError(std::move(Error));
351 if (EntitySummary->empty())
352 continue;
353
354 auto ContributorName = getEntityName(D: CD);
355
356 if (!ContributorName) {
357 addError(makeCreateEntityNameError(FailedDecl: CD, Ctx));
358 continue;
359 }
360
361 auto [EntitySummaryPtr, Success] = SummaryBuilder.addSummary(
362 Entity: addEntity(EN: *ContributorName), Data: std::move(EntitySummary));
363
364 if (!Success)
365 addError(makeAddEntitySummaryError(FailedContributor: CD, Ctx));
366 }
367 // FIXME: handle errors!
368 llvm::consumeError(Err: std::move(Errors));
369}
370