//===--- LoopUnrolling.cpp - Unroll loops -----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// This file contains functions which are used to decide whether a loop is
/// worth unrolling. Moreover, these functions manage the stack of loops which
/// is tracked by the ProgramState.
///
//===----------------------------------------------------------------------===//
14
15#include "clang/ASTMatchers/ASTMatchers.h"
16#include "clang/ASTMatchers/ASTMatchFinder.h"
17#include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
18#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
19#include "clang/StaticAnalyzer/Core/PathSensitive/LoopUnrolling.h"
20#include <optional>
21
22using namespace clang;
23using namespace ento;
24using namespace clang::ast_matchers;
25
26using ast_matchers::internal::Matcher;
27
/// Largest estimated step count for which updateLoopStack() still marks a loop
/// as unrolled; loops whose estimate exceeds this are tracked as normal loops.
static constexpr int MAXIMUM_STEP_UNROLLED = 128;
29
30namespace {
31struct LoopState {
32private:
33 enum Kind { Normal, Unrolled } K;
34 const Stmt *LoopStmt;
35 const LocationContext *LCtx;
36 unsigned maxStep;
37 LoopState(Kind InK, const Stmt *S, const LocationContext *L, unsigned N)
38 : K(InK), LoopStmt(S), LCtx(L), maxStep(N) {}
39
40public:
41 static LoopState getNormal(const Stmt *S, const LocationContext *L,
42 unsigned N) {
43 return LoopState(Normal, S, L, N);
44 }
45 static LoopState getUnrolled(const Stmt *S, const LocationContext *L,
46 unsigned N) {
47 return LoopState(Unrolled, S, L, N);
48 }
49 bool isUnrolled() const { return K == Unrolled; }
50 unsigned getMaxStep() const { return maxStep; }
51 const Stmt *getLoopStmt() const { return LoopStmt; }
52 const LocationContext *getLocationContext() const { return LCtx; }
53 bool operator==(const LoopState &X) const {
54 return K == X.K && LoopStmt == X.LoopStmt;
55 }
56 void Profile(llvm::FoldingSetNodeID &ID) const {
57 ID.AddInteger(I: K);
58 ID.AddPointer(Ptr: LoopStmt);
59 ID.AddPointer(Ptr: LCtx);
60 ID.AddInteger(I: maxStep);
61 }
62};
63} // namespace
64
65// The tracked stack of loops. The stack indicates that which loops the
66// simulated element contained by. The loops are marked depending if we decided
67// to unroll them.
68// TODO: The loop stack should not need to be in the program state since it is
69// lexical in nature. Instead, the stack of loops should be tracked in the
70// LocationContext.
71REGISTER_LIST_WITH_PROGRAMSTATE(LoopStack, LoopState)
72
73namespace clang {
74namespace {
75AST_MATCHER(QualType, isIntegralOrEnumerationType) {
76 return Node->isIntegralOrEnumerationType();
77}
78} // namespace
79namespace ento {
80
81static bool isLoopStmt(const Stmt *S) {
82 return isa_and_nonnull<ForStmt, WhileStmt, DoStmt>(Val: S);
83}
84
85ProgramStateRef processLoopEnd(const Stmt *LoopStmt, ProgramStateRef State) {
86 auto LS = State->get<LoopStack>();
87 if (!LS.isEmpty() && LS.getHead().getLoopStmt() == LoopStmt)
88 State = State->set<LoopStack>(LS.getTail());
89 return State;
90}
91
92static Matcher<Stmt> simpleCondition(StringRef BindName, StringRef RefName) {
93 auto LoopVariable = ignoringParenImpCasts(
94 InnerMatcher: declRefExpr(to(InnerMatcher: varDecl(hasType(InnerMatcher: isInteger())).bind(ID: BindName)))
95 .bind(ID: RefName));
96 auto UpperBound = ignoringParenImpCasts(
97 InnerMatcher: expr(hasType(InnerMatcher: isIntegralOrEnumerationType())).bind(ID: "boundNum"));
98
99 return binaryOperator(
100 anyOf(hasOperatorName(Name: "<"), hasOperatorName(Name: ">"),
101 hasOperatorName(Name: "<="), hasOperatorName(Name: ">="),
102 hasOperatorName(Name: "!=")),
103 anyOf(binaryOperator(hasLHS(InnerMatcher: LoopVariable), hasRHS(InnerMatcher: UpperBound)),
104 binaryOperator(hasRHS(InnerMatcher: LoopVariable), hasLHS(InnerMatcher: UpperBound))))
105 .bind(ID: "conditionOperator");
106}
107
108static Matcher<Stmt> changeIntBoundNode(Matcher<Decl> VarNodeMatcher) {
109 return anyOf(
110 unaryOperator(anyOf(hasOperatorName(Name: "--"), hasOperatorName(Name: "++")),
111 hasUnaryOperand(InnerMatcher: ignoringParenImpCasts(
112 InnerMatcher: declRefExpr(to(InnerMatcher: varDecl(VarNodeMatcher)))))),
113 binaryOperator(isAssignmentOperator(),
114 hasLHS(InnerMatcher: ignoringParenImpCasts(
115 InnerMatcher: declRefExpr(to(InnerMatcher: varDecl(VarNodeMatcher)))))));
116}
117
118static Matcher<Stmt> callByRef(Matcher<Decl> VarNodeMatcher) {
119 return callExpr(forEachArgumentWithParam(
120 ArgMatcher: declRefExpr(to(InnerMatcher: varDecl(VarNodeMatcher))),
121 ParamMatcher: parmVarDecl(hasType(InnerMatcher: references(InnerMatcher: qualType(unless(isConstQualified())))))));
122}
123
124static Matcher<Stmt> assignedToRef(Matcher<Decl> VarNodeMatcher) {
125 return declStmt(hasDescendant(varDecl(
126 allOf(hasType(InnerMatcher: referenceType()),
127 hasInitializer(InnerMatcher: anyOf(
128 initListExpr(has(declRefExpr(to(InnerMatcher: varDecl(VarNodeMatcher))))),
129 declRefExpr(to(InnerMatcher: varDecl(VarNodeMatcher)))))))));
130}
131
132static Matcher<Stmt> getAddrTo(Matcher<Decl> VarNodeMatcher) {
133 return unaryOperator(
134 hasOperatorName(Name: "&"),
135 hasUnaryOperand(InnerMatcher: declRefExpr(hasDeclaration(InnerMatcher: VarNodeMatcher))));
136}
137
138static Matcher<Stmt> hasSuspiciousStmt(StringRef NodeName) {
139 return hasDescendant(stmt(
140 anyOf(gotoStmt(), switchStmt(), returnStmt(),
141 // Escaping and not known mutation of the loop counter is handled
142 // by exclusion of assigning and address-of operators and
143 // pass-by-ref function calls on the loop counter from the body.
144 changeIntBoundNode(VarNodeMatcher: equalsBoundNode(ID: std::string(NodeName))),
145 callByRef(VarNodeMatcher: equalsBoundNode(ID: std::string(NodeName))),
146 getAddrTo(VarNodeMatcher: equalsBoundNode(ID: std::string(NodeName))),
147 assignedToRef(VarNodeMatcher: equalsBoundNode(ID: std::string(NodeName))))));
148}
149
150static Matcher<Stmt> forLoopMatcher() {
151 return forStmt(
152 hasCondition(InnerMatcher: simpleCondition(BindName: "initVarName", RefName: "initVarRef")),
153 // Initialization should match the form: 'int i = 6' or 'i = 42'.
154 hasLoopInit(
155 InnerMatcher: anyOf(declStmt(hasSingleDecl(
156 InnerMatcher: varDecl(allOf(hasInitializer(InnerMatcher: ignoringParenImpCasts(
157 InnerMatcher: integerLiteral().bind(ID: "initNum"))),
158 equalsBoundNode(ID: "initVarName"))))),
159 binaryOperator(hasLHS(InnerMatcher: declRefExpr(to(InnerMatcher: varDecl(
160 equalsBoundNode(ID: "initVarName"))))),
161 hasRHS(InnerMatcher: ignoringParenImpCasts(
162 InnerMatcher: integerLiteral().bind(ID: "initNum")))))),
163 // Incrementation should be a simple increment or decrement
164 // operator call.
165 hasIncrement(InnerMatcher: unaryOperator(
166 anyOf(hasOperatorName(Name: "++"), hasOperatorName(Name: "--")),
167 hasUnaryOperand(InnerMatcher: declRefExpr(
168 to(InnerMatcher: varDecl(allOf(equalsBoundNode(ID: "initVarName"),
169 hasType(InnerMatcher: isInteger())))))))),
170 unless(hasBody(InnerMatcher: hasSuspiciousStmt(NodeName: "initVarName"))))
171 .bind(ID: "forLoop");
172}
173
174static bool isCapturedByReference(ExplodedNode *N, const DeclRefExpr *DR) {
175
176 // Get the lambda CXXRecordDecl
177 assert(DR->refersToEnclosingVariableOrCapture());
178 const LocationContext *LocCtxt = N->getLocationContext();
179 const Decl *D = LocCtxt->getDecl();
180 const auto *MD = cast<CXXMethodDecl>(Val: D);
181 assert(MD && MD->getParent()->isLambda() &&
182 "Captured variable should only be seen while evaluating a lambda");
183 const CXXRecordDecl *LambdaCXXRec = MD->getParent();
184
185 // Lookup the fields of the lambda
186 llvm::DenseMap<const ValueDecl *, FieldDecl *> LambdaCaptureFields;
187 FieldDecl *LambdaThisCaptureField;
188 LambdaCXXRec->getCaptureFields(Captures&: LambdaCaptureFields, ThisCapture&: LambdaThisCaptureField);
189
190 // Check if the counter is captured by reference
191 const VarDecl *VD = cast<VarDecl>(Val: DR->getDecl()->getCanonicalDecl());
192 assert(VD);
193 const FieldDecl *FD = LambdaCaptureFields[VD];
194 assert(FD && "Captured variable without a corresponding field");
195 return FD->getType()->isReferenceType();
196}
197
198static bool isFoundInStmt(const Stmt *S, const VarDecl *VD) {
199 if (const DeclStmt *DS = dyn_cast<DeclStmt>(Val: S)) {
200 for (const Decl *D : DS->decls()) {
201 // Once we reach the declaration of the VD we can return.
202 if (D->getCanonicalDecl() == VD)
203 return true;
204 }
205 }
206 return false;
207}
208
209// A loop counter is considered escaped if:
210// case 1: It is a global variable.
211// case 2: It is a reference parameter or a reference capture.
212// case 3: It is assigned to a non-const reference variable or parameter.
213// case 4: Has its address taken.
214static bool isPossiblyEscaped(ExplodedNode *N, const DeclRefExpr *DR) {
215 const VarDecl *VD = cast<VarDecl>(Val: DR->getDecl()->getCanonicalDecl());
216 assert(VD);
217 // Case 1:
218 if (VD->hasGlobalStorage())
219 return true;
220
221 const bool IsRefParamOrCapture =
222 isa<ParmVarDecl>(Val: VD) || DR->refersToEnclosingVariableOrCapture();
223 // Case 2:
224 if ((DR->refersToEnclosingVariableOrCapture() &&
225 isCapturedByReference(N, DR)) ||
226 (IsRefParamOrCapture && VD->getType()->isReferenceType()))
227 return true;
228
229 while (!N->pred_empty()) {
230 // FIXME: getStmtForDiagnostics() does nasty things in order to provide
231 // a valid statement for body farms, do we need this behavior here?
232 const Stmt *S = N->getStmtForDiagnostics();
233 if (!S) {
234 N = N->getFirstPred();
235 continue;
236 }
237
238 if (isFoundInStmt(S, VD)) {
239 return false;
240 }
241
242 if (const auto *SS = dyn_cast<SwitchStmt>(Val: S)) {
243 if (const auto *CST = dyn_cast<CompoundStmt>(Val: SS->getBody())) {
244 for (const Stmt *CB : CST->body()) {
245 if (isFoundInStmt(S: CB, VD))
246 return false;
247 }
248 }
249 }
250
251 // Check the usage of the pass-by-ref function calls and adress-of operator
252 // on VD and reference initialized by VD.
253 ASTContext &ASTCtx =
254 N->getLocationContext()->getAnalysisDeclContext()->getASTContext();
255 // Case 3 and 4:
256 auto Match =
257 match(Matcher: stmt(anyOf(callByRef(VarNodeMatcher: equalsNode(Other: VD)), getAddrTo(VarNodeMatcher: equalsNode(Other: VD)),
258 assignedToRef(VarNodeMatcher: equalsNode(Other: VD)))),
259 Node: *S, Context&: ASTCtx);
260 if (!Match.empty())
261 return true;
262
263 N = N->getFirstPred();
264 }
265
266 // Reference parameter and reference capture will not be found.
267 if (IsRefParamOrCapture)
268 return false;
269
270 llvm_unreachable("Reached root without finding the declaration of VD");
271}
272
273static bool shouldCompletelyUnroll(const Stmt *LoopStmt, ASTContext &ASTCtx,
274 ExplodedNode *Pred, unsigned &maxStep) {
275
276 if (!isLoopStmt(S: LoopStmt))
277 return false;
278
279 auto Matches = match(Matcher: forLoopMatcher(), Node: *LoopStmt, Context&: ASTCtx);
280 if (Matches.empty())
281 return false;
282
283 const auto *CounterVarRef = Matches[0].getNodeAs<DeclRefExpr>(ID: "initVarRef");
284 const Expr *BoundNumExpr = Matches[0].getNodeAs<Expr>(ID: "boundNum");
285
286 Expr::EvalResult BoundNumResult;
287 if (!BoundNumExpr || !BoundNumExpr->EvaluateAsInt(Result&: BoundNumResult, Ctx: ASTCtx,
288 AllowSideEffects: Expr::SE_NoSideEffects)) {
289 return false;
290 }
291 llvm::APInt InitNum =
292 Matches[0].getNodeAs<IntegerLiteral>(ID: "initNum")->getValue();
293 auto CondOp = Matches[0].getNodeAs<BinaryOperator>(ID: "conditionOperator");
294 unsigned MaxWidth = std::max(a: InitNum.getBitWidth(),
295 b: BoundNumResult.Val.getInt().getBitWidth());
296
297 InitNum = InitNum.zext(width: MaxWidth);
298 llvm::APInt BoundNum = BoundNumResult.Val.getInt().zext(width: MaxWidth);
299 if (CondOp->getOpcode() == BO_GE || CondOp->getOpcode() == BO_LE)
300 maxStep = (BoundNum - InitNum + 1).abs().getZExtValue();
301 else
302 maxStep = (BoundNum - InitNum).abs().getZExtValue();
303
304 // Check if the counter of the loop is not escaped before.
305 return !isPossiblyEscaped(N: Pred, DR: CounterVarRef);
306}
307
308static bool madeNewBranch(ExplodedNode *N, const Stmt *LoopStmt) {
309 const Stmt *S = nullptr;
310 while (!N->pred_empty()) {
311 if (N->succ_size() > 1)
312 return true;
313
314 ProgramPoint P = N->getLocation();
315 if (std::optional<BlockEntrance> BE = P.getAs<BlockEntrance>())
316 S = BE->getBlock()->getTerminatorStmt();
317
318 if (S == LoopStmt)
319 return false;
320
321 N = N->getFirstPred();
322 }
323
324 llvm_unreachable("Reached root without encountering the previous step");
325}
326
327// updateLoopStack is called on every basic block, therefore it needs to be fast
328ProgramStateRef updateLoopStack(const Stmt *LoopStmt, ASTContext &ASTCtx,
329 ExplodedNode *Pred, unsigned maxVisitOnPath) {
330 auto State = Pred->getState();
331 auto LCtx = Pred->getLocationContext();
332
333 if (!isLoopStmt(S: LoopStmt))
334 return State;
335
336 auto LS = State->get<LoopStack>();
337 if (!LS.isEmpty() && LoopStmt == LS.getHead().getLoopStmt() &&
338 LCtx == LS.getHead().getLocationContext()) {
339 if (LS.getHead().isUnrolled() && madeNewBranch(N: Pred, LoopStmt)) {
340 State = State->set<LoopStack>(LS.getTail());
341 State = State->add<LoopStack>(
342 K: LoopState::getNormal(S: LoopStmt, L: LCtx, N: maxVisitOnPath));
343 }
344 return State;
345 }
346 unsigned maxStep;
347 if (!shouldCompletelyUnroll(LoopStmt, ASTCtx, Pred, maxStep)) {
348 State = State->add<LoopStack>(
349 K: LoopState::getNormal(S: LoopStmt, L: LCtx, N: maxVisitOnPath));
350 return State;
351 }
352
353 unsigned outerStep = (LS.isEmpty() ? 1 : LS.getHead().getMaxStep());
354
355 unsigned innerMaxStep = maxStep * outerStep;
356 if (innerMaxStep > MAXIMUM_STEP_UNROLLED)
357 State = State->add<LoopStack>(
358 K: LoopState::getNormal(S: LoopStmt, L: LCtx, N: maxVisitOnPath));
359 else
360 State = State->add<LoopStack>(
361 K: LoopState::getUnrolled(S: LoopStmt, L: LCtx, N: innerMaxStep));
362 return State;
363}
364
365bool isUnrolledState(ProgramStateRef State) {
366 auto LS = State->get<LoopStack>();
367 if (LS.isEmpty() || !LS.getHead().isUnrolled())
368 return false;
369 return true;
370}
371}
372}
373