1//===--- InferAlloc.cpp - Allocation type inference -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements allocation-related type inference.
10//
11//===----------------------------------------------------------------------===//
12
13#include "clang/AST/InferAlloc.h"
14#include "clang/AST/ASTContext.h"
15#include "clang/AST/Decl.h"
16#include "clang/AST/DeclCXX.h"
17#include "clang/AST/Expr.h"
18#include "clang/AST/Type.h"
19#include "clang/Basic/IdentifierTable.h"
20#include "llvm/ADT/SmallPtrSet.h"
21
22using namespace clang;
23using namespace infer_alloc;
24
25static bool
26typeContainsPointer(QualType T,
27 llvm::SmallPtrSet<const RecordDecl *, 4> &VisitedRD,
28 bool &IncompleteType) {
29 QualType CanonicalType = T.getCanonicalType();
30 if (CanonicalType->isPointerType())
31 return true; // base case
32
33 // Look through typedef chain to check for special types.
34 for (QualType CurrentT = T; const auto *TT = CurrentT->getAs<TypedefType>();
35 CurrentT = TT->getDecl()->getUnderlyingType()) {
36 const IdentifierInfo *II = TT->getDecl()->getIdentifier();
37 // Special Case: Syntactically uintptr_t is not a pointer; semantically,
38 // however, very likely used as such. Therefore, classify uintptr_t as a
39 // pointer, too.
40 if (II && II->isStr(Str: "uintptr_t"))
41 return true;
42 }
43
44 // The type is an array; check the element type.
45 if (const ArrayType *AT = dyn_cast<ArrayType>(Val&: CanonicalType))
46 return typeContainsPointer(T: AT->getElementType(), VisitedRD, IncompleteType);
47
48 // The type is an atomic type.
49 if (const AtomicType *AT = dyn_cast<AtomicType>(Val&: CanonicalType))
50 return typeContainsPointer(T: AT->getValueType(), VisitedRD, IncompleteType);
51
52 // The type is a struct, class, or union.
53 if (const RecordDecl *RD = CanonicalType->getAsRecordDecl()) {
54 if (!RD->isCompleteDefinition()) {
55 IncompleteType = true;
56 return false;
57 }
58 if (!VisitedRD.insert(Ptr: RD).second)
59 return false; // already visited
60 // Check all fields.
61 for (const FieldDecl *Field : RD->fields()) {
62 if (typeContainsPointer(T: Field->getType(), VisitedRD, IncompleteType))
63 return true;
64 }
65 // For C++ classes, also check base classes.
66 if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(Val: RD)) {
67 // Polymorphic types require a vptr.
68 if (CXXRD->isDynamicClass())
69 return true;
70 for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
71 if (typeContainsPointer(T: Base.getType(), VisitedRD, IncompleteType))
72 return true;
73 }
74 }
75 }
76 return false;
77}
78
79/// Infer type from a simple sizeof expression.
80static QualType inferTypeFromSizeofExpr(const Expr *E) {
81 const Expr *Arg = E->IgnoreParenImpCasts();
82 if (const auto *UET = dyn_cast<UnaryExprOrTypeTraitExpr>(Val: Arg)) {
83 if (UET->getKind() == UETT_SizeOf) {
84 if (UET->isArgumentType())
85 return UET->getArgumentTypeInfo()->getType();
86 else
87 return UET->getArgumentExpr()->getType();
88 }
89 }
90 return QualType();
91}
92
93/// Infer type from an arithmetic expression involving a sizeof. For example:
94///
95/// malloc(sizeof(MyType) + padding); // infers 'MyType'
96/// malloc(sizeof(MyType) * 32); // infers 'MyType'
97/// malloc(32 * sizeof(MyType)); // infers 'MyType'
98/// malloc(sizeof(MyType) << 1); // infers 'MyType'
99/// ...
100///
101/// More complex arithmetic expressions are supported, but are a heuristic, e.g.
102/// when considering allocations for structs with flexible array members:
103///
104/// malloc(sizeof(HasFlexArray) + sizeof(int) * 32); // infers 'HasFlexArray'
105///
106static QualType inferPossibleTypeFromArithSizeofExpr(const Expr *E) {
107 const Expr *Arg = E->IgnoreParenImpCasts();
108 // The argument is a lone sizeof expression.
109 if (QualType T = inferTypeFromSizeofExpr(E: Arg); !T.isNull())
110 return T;
111 if (const auto *BO = dyn_cast<BinaryOperator>(Val: Arg)) {
112 // Argument is an arithmetic expression. Cover common arithmetic patterns
113 // involving sizeof.
114 switch (BO->getOpcode()) {
115 case BO_Add:
116 case BO_Div:
117 case BO_Mul:
118 case BO_Shl:
119 case BO_Shr:
120 case BO_Sub:
121 if (QualType T = inferPossibleTypeFromArithSizeofExpr(E: BO->getLHS());
122 !T.isNull())
123 return T;
124 if (QualType T = inferPossibleTypeFromArithSizeofExpr(E: BO->getRHS());
125 !T.isNull())
126 return T;
127 break;
128 default:
129 break;
130 }
131 }
132 return QualType();
133}
134
135/// If the expression E is a reference to a variable, infer the type from a
136/// variable's initializer if it contains a sizeof. Beware, this is a heuristic
137/// and ignores if a variable is later reassigned. For example:
138///
139/// size_t my_size = sizeof(MyType);
140/// void *x = malloc(my_size); // infers 'MyType'
141///
142static QualType inferPossibleTypeFromVarInitSizeofExpr(const Expr *E) {
143 const Expr *Arg = E->IgnoreParenImpCasts();
144 if (const auto *DRE = dyn_cast<DeclRefExpr>(Val: Arg)) {
145 if (const auto *VD = dyn_cast<VarDecl>(Val: DRE->getDecl())) {
146 if (const Expr *Init = VD->getInit())
147 return inferPossibleTypeFromArithSizeofExpr(E: Init);
148 }
149 }
150 return QualType();
151}
152
153/// Deduces the allocated type by checking if the allocation call's result
154/// is immediately used in a cast expression. For example:
155///
156/// MyType *x = (MyType *)malloc(4096); // infers 'MyType'
157///
158static QualType inferPossibleTypeFromCastExpr(const CallExpr *CallE,
159 const CastExpr *CastE) {
160 if (!CastE)
161 return QualType();
162 QualType PtrType = CastE->getType();
163 if (PtrType->isPointerType())
164 return PtrType->getPointeeType();
165 return QualType();
166}
167
168QualType infer_alloc::inferPossibleType(const CallExpr *E,
169 const ASTContext &Ctx,
170 const CastExpr *CastE) {
171 QualType AllocType;
172 // First check arguments.
173 for (const Expr *Arg : E->arguments()) {
174 AllocType = inferPossibleTypeFromArithSizeofExpr(E: Arg);
175 if (AllocType.isNull())
176 AllocType = inferPossibleTypeFromVarInitSizeofExpr(E: Arg);
177 if (!AllocType.isNull())
178 break;
179 }
180 // Then check later casts.
181 if (AllocType.isNull())
182 AllocType = inferPossibleTypeFromCastExpr(CallE: E, CastE);
183 return AllocType;
184}
185
186std::optional<llvm::AllocTokenMetadata>
187infer_alloc::getAllocTokenMetadata(QualType T, const ASTContext &Ctx) {
188 llvm::AllocTokenMetadata ATMD;
189
190 // Get unique type name.
191 PrintingPolicy Policy(Ctx.getLangOpts());
192 Policy.SuppressTagKeyword = true;
193 Policy.FullyQualifiedName = true;
194 llvm::raw_svector_ostream TypeNameOS(ATMD.TypeName);
195 T.getCanonicalType().print(OS&: TypeNameOS, Policy);
196
197 // Check if QualType contains a pointer. Implements a simple DFS to
198 // recursively check if a type contains a pointer type.
199 llvm::SmallPtrSet<const RecordDecl *, 4> VisitedRD;
200 bool IncompleteType = false;
201 ATMD.ContainsPointer = typeContainsPointer(T, VisitedRD, IncompleteType);
202 if (!ATMD.ContainsPointer && IncompleteType)
203 return std::nullopt;
204
205 return ATMD;
206}
207