1//===--- LoopUnrolling.cpp - Unroll loops -----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
/// This file contains functions which are used to decide whether a loop is
/// worth unrolling. Moreover, these functions manage the stack of loops which
/// is tracked by the ProgramState.
12///
13//===----------------------------------------------------------------------===//
14
#include "clang/ASTMatchers/ASTMatchers.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
#include "clang/StaticAnalyzer/Core/PathSensitive/LoopUnrolling.h"
#include <cstdint>
#include <limits>
#include <optional>
21
22using namespace clang;
23using namespace ento;
24using namespace clang::ast_matchers;
25
26using ast_matchers::internal::Matcher;
27
/// Upper bound on the combined step count (this loop's iterations times the
/// enclosing loop entry's step count) for which a loop is still marked as
/// completely unrolled.
static constexpr int MAXIMUM_STEP_UNROLLED = 128;
29
30namespace {
31struct LoopState {
32private:
33 enum Kind { Normal, Unrolled } K;
34 const Stmt *LoopStmt;
35 const StackFrame *SF;
36 unsigned maxStep;
37 LoopState(Kind InK, const Stmt *S, const StackFrame *SF, unsigned N)
38 : K(InK), LoopStmt(S), SF(SF), maxStep(N) {}
39
40public:
41 static LoopState getNormal(const Stmt *S, const StackFrame *SF, unsigned N) {
42 return LoopState(Normal, S, SF, N);
43 }
44 static LoopState getUnrolled(const Stmt *S, const StackFrame *SF,
45 unsigned N) {
46 return LoopState(Unrolled, S, SF, N);
47 }
48 bool isUnrolled() const { return K == Unrolled; }
49 unsigned getMaxStep() const { return maxStep; }
50 const Stmt *getLoopStmt() const { return LoopStmt; }
51 const StackFrame *getStackFrame() const { return SF; }
52 bool operator==(const LoopState &X) const {
53 return K == X.K && LoopStmt == X.LoopStmt;
54 }
55 void Profile(llvm::FoldingSetNodeID &ID) const {
56 ID.AddInteger(I: K);
57 ID.AddPointer(Ptr: LoopStmt);
58 ID.AddPointer(Ptr: SF);
59 ID.AddInteger(I: maxStep);
60 }
61};
62} // namespace
63
// The tracked stack of loops that enclose the currently simulated program
// point. The innermost loop is at the head of the list: updateLoopStack()
// pushes entries and processLoopEnd() pops them. Each entry records whether
// we decided to unroll the corresponding loop.
// TODO: The loop stack should not need to be in the program state since it is
// lexical in nature. Instead, the stack of loops should be tracked in the
// StackFrame.
REGISTER_LIST_WITH_PROGRAMSTATE(LoopStack, LoopState)
71
72namespace clang {
73namespace {
74AST_MATCHER(QualType, isIntegralOrEnumerationType) {
75 return Node->isIntegralOrEnumerationType();
76}
77} // namespace
78namespace ento {
79
80static bool isLoopStmt(const Stmt *S) {
81 return isa_and_nonnull<ForStmt, WhileStmt, DoStmt>(Val: S);
82}
83
84ProgramStateRef processLoopEnd(const Stmt *LoopStmt, ProgramStateRef State) {
85 auto LS = State->get<LoopStack>();
86 if (!LS.isEmpty() && LS.getHead().getLoopStmt() == LoopStmt)
87 State = State->set<LoopStack>(LS.getTail());
88 return State;
89}
90
91static Matcher<Stmt> simpleCondition(StringRef BindName, StringRef RefName) {
92 auto LoopVariable = ignoringParenImpCasts(
93 InnerMatcher: declRefExpr(to(InnerMatcher: varDecl(hasType(InnerMatcher: isInteger())).bind(ID: BindName)))
94 .bind(ID: RefName));
95 auto UpperBound = ignoringParenImpCasts(
96 InnerMatcher: expr(hasType(InnerMatcher: isIntegralOrEnumerationType())).bind(ID: "boundNum"));
97
98 return binaryOperator(
99 anyOf(hasOperatorName(Name: "<"), hasOperatorName(Name: ">"),
100 hasOperatorName(Name: "<="), hasOperatorName(Name: ">="),
101 hasOperatorName(Name: "!=")),
102 anyOf(binaryOperator(hasLHS(InnerMatcher: LoopVariable), hasRHS(InnerMatcher: UpperBound)),
103 binaryOperator(hasRHS(InnerMatcher: LoopVariable), hasLHS(InnerMatcher: UpperBound))))
104 .bind(ID: "conditionOperator");
105}
106
107static Matcher<Stmt> changeIntBoundNode(Matcher<Decl> VarNodeMatcher) {
108 return anyOf(
109 unaryOperator(anyOf(hasOperatorName(Name: "--"), hasOperatorName(Name: "++")),
110 hasUnaryOperand(InnerMatcher: ignoringParenImpCasts(
111 InnerMatcher: declRefExpr(to(InnerMatcher: varDecl(VarNodeMatcher)))))),
112 binaryOperator(isAssignmentOperator(),
113 hasLHS(InnerMatcher: ignoringParenImpCasts(
114 InnerMatcher: declRefExpr(to(InnerMatcher: varDecl(VarNodeMatcher)))))));
115}
116
117static Matcher<Stmt> callByRef(Matcher<Decl> VarNodeMatcher) {
118 return callExpr(forEachArgumentWithParam(
119 ArgMatcher: declRefExpr(to(InnerMatcher: varDecl(VarNodeMatcher))),
120 ParamMatcher: parmVarDecl(hasType(InnerMatcher: references(InnerMatcher: qualType(unless(isConstQualified())))))));
121}
122
123static Matcher<Stmt> assignedToRef(Matcher<Decl> VarNodeMatcher) {
124 return declStmt(hasDescendant(varDecl(
125 allOf(hasType(InnerMatcher: referenceType()),
126 hasInitializer(InnerMatcher: anyOf(
127 initListExpr(has(declRefExpr(to(InnerMatcher: varDecl(VarNodeMatcher))))),
128 declRefExpr(to(InnerMatcher: varDecl(VarNodeMatcher)))))))));
129}
130
131static Matcher<Stmt> getAddrTo(Matcher<Decl> VarNodeMatcher) {
132 return unaryOperator(
133 hasOperatorName(Name: "&"),
134 hasUnaryOperand(InnerMatcher: declRefExpr(hasDeclaration(InnerMatcher: VarNodeMatcher))));
135}
136
137static Matcher<Stmt> hasSuspiciousStmt(StringRef NodeName) {
138 return hasDescendant(stmt(
139 anyOf(gotoStmt(), switchStmt(), returnStmt(),
140 // Escaping and not known mutation of the loop counter is handled
141 // by exclusion of assigning and address-of operators and
142 // pass-by-ref function calls on the loop counter from the body.
143 changeIntBoundNode(VarNodeMatcher: equalsBoundNode(ID: std::string(NodeName))),
144 callByRef(VarNodeMatcher: equalsBoundNode(ID: std::string(NodeName))),
145 getAddrTo(VarNodeMatcher: equalsBoundNode(ID: std::string(NodeName))),
146 assignedToRef(VarNodeMatcher: equalsBoundNode(ID: std::string(NodeName))))));
147}
148
149static Matcher<Stmt> forLoopMatcher() {
150 return forStmt(
151 hasCondition(InnerMatcher: simpleCondition(BindName: "initVarName", RefName: "initVarRef")),
152 // Initialization should match the form: 'int i = 6' or 'i = 42'.
153 hasLoopInit(
154 InnerMatcher: anyOf(declStmt(hasSingleDecl(
155 InnerMatcher: varDecl(allOf(hasInitializer(InnerMatcher: ignoringParenImpCasts(
156 InnerMatcher: integerLiteral().bind(ID: "initNum"))),
157 equalsBoundNode(ID: "initVarName"))))),
158 binaryOperator(hasLHS(InnerMatcher: declRefExpr(to(InnerMatcher: varDecl(
159 equalsBoundNode(ID: "initVarName"))))),
160 hasRHS(InnerMatcher: ignoringParenImpCasts(
161 InnerMatcher: integerLiteral().bind(ID: "initNum")))))),
162 // Incrementation should be a simple increment or decrement
163 // operator call.
164 hasIncrement(InnerMatcher: unaryOperator(
165 anyOf(hasOperatorName(Name: "++"), hasOperatorName(Name: "--")),
166 hasUnaryOperand(InnerMatcher: declRefExpr(
167 to(InnerMatcher: varDecl(allOf(equalsBoundNode(ID: "initVarName"),
168 hasType(InnerMatcher: isInteger())))))))),
169 unless(hasBody(InnerMatcher: hasSuspiciousStmt(NodeName: "initVarName"))))
170 .bind(ID: "forLoop");
171}
172
173static bool isCapturedByReference(ExplodedNode *N, const DeclRefExpr *DR) {
174
175 // Get the lambda CXXRecordDecl
176 assert(DR->refersToEnclosingVariableOrCapture());
177 const Decl *D = N->getStackFrame()->getDecl();
178 const auto *MD = cast<CXXMethodDecl>(Val: D);
179 assert(MD && MD->getParent()->isLambda() &&
180 "Captured variable should only be seen while evaluating a lambda");
181 const CXXRecordDecl *LambdaCXXRec = MD->getParent();
182
183 // Lookup the fields of the lambda
184 llvm::DenseMap<const ValueDecl *, FieldDecl *> LambdaCaptureFields;
185 FieldDecl *LambdaThisCaptureField;
186 LambdaCXXRec->getCaptureFields(Captures&: LambdaCaptureFields, ThisCapture&: LambdaThisCaptureField);
187
188 // Check if the counter is captured by reference
189 const VarDecl *VD = cast<VarDecl>(Val: DR->getDecl()->getCanonicalDecl());
190 assert(VD);
191 const FieldDecl *FD = LambdaCaptureFields[VD];
192 assert(FD && "Captured variable without a corresponding field");
193 return FD->getType()->isReferenceType();
194}
195
196static bool isFoundInStmt(const Stmt *S, const VarDecl *VD) {
197 if (const DeclStmt *DS = dyn_cast<DeclStmt>(Val: S)) {
198 for (const Decl *D : DS->decls()) {
199 // Once we reach the declaration of the VD we can return.
200 if (D->getCanonicalDecl() == VD)
201 return true;
202 }
203 }
204 return false;
205}
206
207// A loop counter is considered escaped if:
208// case 1: It is a global variable.
209// case 2: It is a reference parameter or a reference capture.
210// case 3: It is assigned to a non-const reference variable or parameter.
211// case 4: Has its address taken.
212static bool isPossiblyEscaped(ExplodedNode *N, const DeclRefExpr *DR) {
213 const VarDecl *VD = cast<VarDecl>(Val: DR->getDecl()->getCanonicalDecl());
214 assert(VD);
215 // Case 1:
216 if (VD->hasGlobalStorage())
217 return true;
218
219 const bool IsRefParamOrCapture =
220 isa<ParmVarDecl>(Val: VD) || DR->refersToEnclosingVariableOrCapture();
221 // Case 2:
222 if ((DR->refersToEnclosingVariableOrCapture() &&
223 isCapturedByReference(N, DR)) ||
224 (IsRefParamOrCapture && VD->getType()->isReferenceType()))
225 return true;
226
227 while (!N->pred_empty()) {
228 // FIXME: getStmtForDiagnostics() does nasty things in order to provide
229 // a valid statement for body farms, do we need this behavior here?
230 const Stmt *S = N->getStmtForDiagnostics();
231 if (!S) {
232 N = N->getFirstPred();
233 continue;
234 }
235
236 if (isFoundInStmt(S, VD)) {
237 return false;
238 }
239
240 if (const auto *SS = dyn_cast<SwitchStmt>(Val: S)) {
241 if (const auto *CST = dyn_cast<CompoundStmt>(Val: SS->getBody())) {
242 for (const Stmt *CB : CST->body()) {
243 if (isFoundInStmt(S: CB, VD))
244 return false;
245 }
246 }
247 }
248
249 // Check the usage of the pass-by-ref function calls and adress-of operator
250 // on VD and reference initialized by VD.
251 ASTContext &ASTCtx =
252 N->getStackFrame()->getAnalysisDeclContext()->getASTContext();
253 // Case 3 and 4:
254 auto Match =
255 match(Matcher: stmt(anyOf(callByRef(VarNodeMatcher: equalsNode(Other: VD)), getAddrTo(VarNodeMatcher: equalsNode(Other: VD)),
256 assignedToRef(VarNodeMatcher: equalsNode(Other: VD)))),
257 Node: *S, Context&: ASTCtx);
258 if (!Match.empty())
259 return true;
260
261 N = N->getFirstPred();
262 }
263
264 // Reference parameter and reference capture will not be found.
265 if (IsRefParamOrCapture)
266 return false;
267
268 llvm_unreachable("Reached root without finding the declaration of VD");
269}
270
271static bool shouldCompletelyUnroll(const Stmt *LoopStmt, ASTContext &ASTCtx,
272 ExplodedNode *Pred, unsigned &maxStep) {
273
274 if (!isLoopStmt(S: LoopStmt))
275 return false;
276
277 auto Matches = match(Matcher: forLoopMatcher(), Node: *LoopStmt, Context&: ASTCtx);
278 if (Matches.empty())
279 return false;
280
281 const auto *CounterVarRef = Matches[0].getNodeAs<DeclRefExpr>(ID: "initVarRef");
282 const Expr *BoundNumExpr = Matches[0].getNodeAs<Expr>(ID: "boundNum");
283
284 Expr::EvalResult BoundNumResult;
285 if (!BoundNumExpr || !BoundNumExpr->EvaluateAsInt(Result&: BoundNumResult, Ctx: ASTCtx,
286 AllowSideEffects: Expr::SE_NoSideEffects)) {
287 return false;
288 }
289 llvm::APInt InitNum =
290 Matches[0].getNodeAs<IntegerLiteral>(ID: "initNum")->getValue();
291 auto CondOp = Matches[0].getNodeAs<BinaryOperator>(ID: "conditionOperator");
292 unsigned MaxWidth = std::max(a: InitNum.getBitWidth(),
293 b: BoundNumResult.Val.getInt().getBitWidth());
294
295 InitNum = InitNum.zext(width: MaxWidth);
296 llvm::APInt BoundNum = BoundNumResult.Val.getInt().zext(width: MaxWidth);
297 if (CondOp->getOpcode() == BO_GE || CondOp->getOpcode() == BO_LE)
298 maxStep = (BoundNum - InitNum + 1).abs().getZExtValue();
299 else
300 maxStep = (BoundNum - InitNum).abs().getZExtValue();
301
302 // Check if the counter of the loop is not escaped before.
303 return !isPossiblyEscaped(N: Pred, DR: CounterVarRef);
304}
305
306static bool madeNewBranch(ExplodedNode *N, const Stmt *LoopStmt) {
307 const Stmt *S = nullptr;
308 while (!N->pred_empty()) {
309 if (N->succ_size() > 1)
310 return true;
311
312 ProgramPoint P = N->getLocation();
313 if (std::optional<BlockEntrance> BE = P.getAs<BlockEntrance>())
314 S = BE->getBlock()->getTerminatorStmt();
315
316 if (S == LoopStmt)
317 return false;
318
319 N = N->getFirstPred();
320 }
321
322 llvm_unreachable("Reached root without encountering the previous step");
323}
324
325// updateLoopStack is called on every basic block, therefore it needs to be fast
326ProgramStateRef updateLoopStack(const Stmt *LoopStmt, ASTContext &ASTCtx,
327 ExplodedNode *Pred, unsigned maxVisitOnPath) {
328 auto State = Pred->getState();
329 auto SF = Pred->getStackFrame();
330
331 if (!isLoopStmt(S: LoopStmt))
332 return State;
333
334 auto LS = State->get<LoopStack>();
335 if (!LS.isEmpty() && LoopStmt == LS.getHead().getLoopStmt() &&
336 SF == LS.getHead().getStackFrame()) {
337 if (LS.getHead().isUnrolled() && madeNewBranch(N: Pred, LoopStmt)) {
338 State = State->set<LoopStack>(LS.getTail());
339 State = State->add<LoopStack>(
340 K: LoopState::getNormal(S: LoopStmt, SF, N: maxVisitOnPath));
341 }
342 return State;
343 }
344 unsigned maxStep;
345 if (!shouldCompletelyUnroll(LoopStmt, ASTCtx, Pred, maxStep)) {
346 State = State->add<LoopStack>(
347 K: LoopState::getNormal(S: LoopStmt, SF, N: maxVisitOnPath));
348 return State;
349 }
350
351 unsigned outerStep = (LS.isEmpty() ? 1 : LS.getHead().getMaxStep());
352
353 unsigned innerMaxStep = maxStep * outerStep;
354 if (innerMaxStep > MAXIMUM_STEP_UNROLLED)
355 State = State->add<LoopStack>(
356 K: LoopState::getNormal(S: LoopStmt, SF, N: maxVisitOnPath));
357 else
358 State = State->add<LoopStack>(
359 K: LoopState::getUnrolled(S: LoopStmt, SF, N: innerMaxStep));
360 return State;
361}
362
363bool isUnrolledState(ProgramStateRef State) {
364 auto LS = State->get<LoopStack>();
365 if (LS.isEmpty() || !LS.getHead().isUnrolled())
366 return false;
367 return true;
368}
369}
370}
371