1//===--- InferAlloc.cpp - Allocation type inference -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements allocation-related type inference.
10//
11//===----------------------------------------------------------------------===//
12
13#include "clang/AST/InferAlloc.h"
14#include "clang/AST/ASTContext.h"
15#include "clang/AST/Decl.h"
16#include "clang/AST/DeclCXX.h"
17#include "clang/AST/Expr.h"
18#include "clang/AST/Type.h"
19#include "clang/Basic/IdentifierTable.h"
20#include "llvm/ADT/SmallPtrSet.h"
21
22using namespace clang;
23using namespace infer_alloc;
24
25static bool
26typeContainsPointer(QualType T,
27 llvm::SmallPtrSet<const RecordDecl *, 4> &VisitedRD,
28 bool &IncompleteType) {
29 QualType CanonicalType = T.getCanonicalType();
30 if (CanonicalType->isPointerType())
31 return true; // base case
32
33 // Look through typedef chain to check for special types.
34 for (QualType CurrentT = T; const auto *TT = CurrentT->getAs<TypedefType>();
35 CurrentT = TT->getDecl()->getUnderlyingType()) {
36 const IdentifierInfo *II = TT->getDecl()->getIdentifier();
37 // Special Case: Syntactically uintptr_t is not a pointer; semantically,
38 // however, very likely used as such. Therefore, classify uintptr_t as a
39 // pointer, too.
40 if (II && II->isStr(Str: "uintptr_t"))
41 return true;
42 }
43
44 // The type is an array; check the element type.
45 if (const ArrayType *AT = dyn_cast<ArrayType>(Val&: CanonicalType))
46 return typeContainsPointer(T: AT->getElementType(), VisitedRD, IncompleteType);
47 // The type is a struct, class, or union.
48 if (const RecordDecl *RD = CanonicalType->getAsRecordDecl()) {
49 if (!RD->isCompleteDefinition()) {
50 IncompleteType = true;
51 return false;
52 }
53 if (!VisitedRD.insert(Ptr: RD).second)
54 return false; // already visited
55 // Check all fields.
56 for (const FieldDecl *Field : RD->fields()) {
57 if (typeContainsPointer(T: Field->getType(), VisitedRD, IncompleteType))
58 return true;
59 }
60 // For C++ classes, also check base classes.
61 if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(Val: RD)) {
62 // Polymorphic types require a vptr.
63 if (CXXRD->isDynamicClass())
64 return true;
65 for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
66 if (typeContainsPointer(T: Base.getType(), VisitedRD, IncompleteType))
67 return true;
68 }
69 }
70 }
71 return false;
72}
73
74/// Infer type from a simple sizeof expression.
75static QualType inferTypeFromSizeofExpr(const Expr *E) {
76 const Expr *Arg = E->IgnoreParenImpCasts();
77 if (const auto *UET = dyn_cast<UnaryExprOrTypeTraitExpr>(Val: Arg)) {
78 if (UET->getKind() == UETT_SizeOf) {
79 if (UET->isArgumentType())
80 return UET->getArgumentTypeInfo()->getType();
81 else
82 return UET->getArgumentExpr()->getType();
83 }
84 }
85 return QualType();
86}
87
88/// Infer type from an arithmetic expression involving a sizeof. For example:
89///
90/// malloc(sizeof(MyType) + padding); // infers 'MyType'
91/// malloc(sizeof(MyType) * 32); // infers 'MyType'
92/// malloc(32 * sizeof(MyType)); // infers 'MyType'
93/// malloc(sizeof(MyType) << 1); // infers 'MyType'
94/// ...
95///
96/// More complex arithmetic expressions are supported, but are a heuristic, e.g.
97/// when considering allocations for structs with flexible array members:
98///
99/// malloc(sizeof(HasFlexArray) + sizeof(int) * 32); // infers 'HasFlexArray'
100///
101static QualType inferPossibleTypeFromArithSizeofExpr(const Expr *E) {
102 const Expr *Arg = E->IgnoreParenImpCasts();
103 // The argument is a lone sizeof expression.
104 if (QualType T = inferTypeFromSizeofExpr(E: Arg); !T.isNull())
105 return T;
106 if (const auto *BO = dyn_cast<BinaryOperator>(Val: Arg)) {
107 // Argument is an arithmetic expression. Cover common arithmetic patterns
108 // involving sizeof.
109 switch (BO->getOpcode()) {
110 case BO_Add:
111 case BO_Div:
112 case BO_Mul:
113 case BO_Shl:
114 case BO_Shr:
115 case BO_Sub:
116 if (QualType T = inferPossibleTypeFromArithSizeofExpr(E: BO->getLHS());
117 !T.isNull())
118 return T;
119 if (QualType T = inferPossibleTypeFromArithSizeofExpr(E: BO->getRHS());
120 !T.isNull())
121 return T;
122 break;
123 default:
124 break;
125 }
126 }
127 return QualType();
128}
129
130/// If the expression E is a reference to a variable, infer the type from a
131/// variable's initializer if it contains a sizeof. Beware, this is a heuristic
132/// and ignores if a variable is later reassigned. For example:
133///
134/// size_t my_size = sizeof(MyType);
135/// void *x = malloc(my_size); // infers 'MyType'
136///
137static QualType inferPossibleTypeFromVarInitSizeofExpr(const Expr *E) {
138 const Expr *Arg = E->IgnoreParenImpCasts();
139 if (const auto *DRE = dyn_cast<DeclRefExpr>(Val: Arg)) {
140 if (const auto *VD = dyn_cast<VarDecl>(Val: DRE->getDecl())) {
141 if (const Expr *Init = VD->getInit())
142 return inferPossibleTypeFromArithSizeofExpr(E: Init);
143 }
144 }
145 return QualType();
146}
147
148/// Deduces the allocated type by checking if the allocation call's result
149/// is immediately used in a cast expression. For example:
150///
151/// MyType *x = (MyType *)malloc(4096); // infers 'MyType'
152///
153static QualType inferPossibleTypeFromCastExpr(const CallExpr *CallE,
154 const CastExpr *CastE) {
155 if (!CastE)
156 return QualType();
157 QualType PtrType = CastE->getType();
158 if (PtrType->isPointerType())
159 return PtrType->getPointeeType();
160 return QualType();
161}
162
163QualType infer_alloc::inferPossibleType(const CallExpr *E,
164 const ASTContext &Ctx,
165 const CastExpr *CastE) {
166 QualType AllocType;
167 // First check arguments.
168 for (const Expr *Arg : E->arguments()) {
169 AllocType = inferPossibleTypeFromArithSizeofExpr(E: Arg);
170 if (AllocType.isNull())
171 AllocType = inferPossibleTypeFromVarInitSizeofExpr(E: Arg);
172 if (!AllocType.isNull())
173 break;
174 }
175 // Then check later casts.
176 if (AllocType.isNull())
177 AllocType = inferPossibleTypeFromCastExpr(CallE: E, CastE);
178 return AllocType;
179}
180
181std::optional<llvm::AllocTokenMetadata>
182infer_alloc::getAllocTokenMetadata(QualType T, const ASTContext &Ctx) {
183 llvm::AllocTokenMetadata ATMD;
184
185 // Get unique type name.
186 PrintingPolicy Policy(Ctx.getLangOpts());
187 Policy.SuppressTagKeyword = true;
188 Policy.FullyQualifiedName = true;
189 llvm::raw_svector_ostream TypeNameOS(ATMD.TypeName);
190 T.getCanonicalType().print(OS&: TypeNameOS, Policy);
191
192 // Check if QualType contains a pointer. Implements a simple DFS to
193 // recursively check if a type contains a pointer type.
194 llvm::SmallPtrSet<const RecordDecl *, 4> VisitedRD;
195 bool IncompleteType = false;
196 ATMD.ContainsPointer = typeContainsPointer(T, VisitedRD, IncompleteType);
197 if (!ATMD.ContainsPointer && IncompleteType)
198 return std::nullopt;
199
200 return ATMD;
201}
202