1 | //===--- TransProtectedScope.cpp - Transformations to ARC mode ------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// Adds braces around case statements that "contain" the initialization of a
// retaining variable, which would otherwise trigger the "switch case is in
// protected scope" error.
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "Internals.h" |
15 | #include "Transforms.h" |
16 | #include "clang/AST/ASTContext.h" |
17 | #include "clang/Basic/SourceManager.h" |
18 | #include "clang/Sema/SemaDiagnostic.h" |
19 | |
20 | using namespace clang; |
21 | using namespace arcmt; |
22 | using namespace trans; |
23 | |
24 | namespace { |
25 | |
26 | class LocalRefsCollector : public RecursiveASTVisitor<LocalRefsCollector> { |
27 | SmallVectorImpl<DeclRefExpr *> &Refs; |
28 | |
29 | public: |
30 | LocalRefsCollector(SmallVectorImpl<DeclRefExpr *> &refs) |
31 | : Refs(refs) { } |
32 | |
33 | bool VisitDeclRefExpr(DeclRefExpr *E) { |
34 | if (ValueDecl *D = E->getDecl()) |
35 | if (D->getDeclContext()->getRedeclContext()->isFunctionOrMethod()) |
36 | Refs.push_back(Elt: E); |
37 | return true; |
38 | } |
39 | }; |
40 | |
41 | struct CaseInfo { |
42 | SwitchCase *SC; |
43 | SourceRange Range; |
44 | enum { |
45 | St_Unchecked, |
46 | St_CannotFix, |
47 | St_Fixed |
48 | } State; |
49 | |
50 | CaseInfo() : SC(nullptr), State(St_Unchecked) {} |
51 | CaseInfo(SwitchCase *S, SourceRange Range) |
52 | : SC(S), Range(Range), State(St_Unchecked) {} |
53 | }; |
54 | |
55 | class CaseCollector : public RecursiveASTVisitor<CaseCollector> { |
56 | ParentMap &PMap; |
57 | SmallVectorImpl<CaseInfo> &Cases; |
58 | |
59 | public: |
60 | CaseCollector(ParentMap &PMap, SmallVectorImpl<CaseInfo> &Cases) |
61 | : PMap(PMap), Cases(Cases) { } |
62 | |
63 | bool VisitSwitchStmt(SwitchStmt *S) { |
64 | SwitchCase *Curr = S->getSwitchCaseList(); |
65 | if (!Curr) |
66 | return true; |
67 | Stmt *Parent = getCaseParent(S: Curr); |
68 | Curr = Curr->getNextSwitchCase(); |
69 | // Make sure all case statements are in the same scope. |
70 | while (Curr) { |
71 | if (getCaseParent(S: Curr) != Parent) |
72 | return true; |
73 | Curr = Curr->getNextSwitchCase(); |
74 | } |
75 | |
76 | SourceLocation NextLoc = S->getEndLoc(); |
77 | Curr = S->getSwitchCaseList(); |
78 | // We iterate over case statements in reverse source-order. |
79 | while (Curr) { |
80 | Cases.push_back( |
81 | Elt: CaseInfo(Curr, SourceRange(Curr->getBeginLoc(), NextLoc))); |
82 | NextLoc = Curr->getBeginLoc(); |
83 | Curr = Curr->getNextSwitchCase(); |
84 | } |
85 | return true; |
86 | } |
87 | |
88 | Stmt *getCaseParent(SwitchCase *S) { |
89 | Stmt *Parent = PMap.getParent(S); |
90 | while (Parent && (isa<SwitchCase>(Val: Parent) || isa<LabelStmt>(Val: Parent))) |
91 | Parent = PMap.getParent(Parent); |
92 | return Parent; |
93 | } |
94 | }; |
95 | |
96 | class ProtectedScopeFixer { |
97 | MigrationPass &Pass; |
98 | SourceManager &SM; |
99 | SmallVector<CaseInfo, 16> Cases; |
100 | SmallVector<DeclRefExpr *, 16> LocalRefs; |
101 | |
102 | public: |
103 | ProtectedScopeFixer(BodyContext &BodyCtx) |
104 | : Pass(BodyCtx.getMigrationContext().Pass), |
105 | SM(Pass.Ctx.getSourceManager()) { |
106 | |
107 | CaseCollector(BodyCtx.getParentMap(), Cases) |
108 | .TraverseStmt(S: BodyCtx.getTopStmt()); |
109 | LocalRefsCollector(LocalRefs).TraverseStmt(S: BodyCtx.getTopStmt()); |
110 | |
111 | SourceRange BodyRange = BodyCtx.getTopStmt()->getSourceRange(); |
112 | const CapturedDiagList &DiagList = Pass.getDiags(); |
113 | // Copy the diagnostics so we don't have to worry about invaliding iterators |
114 | // from the diagnostic list. |
115 | SmallVector<StoredDiagnostic, 16> StoredDiags; |
116 | StoredDiags.append(in_start: DiagList.begin(), in_end: DiagList.end()); |
117 | SmallVectorImpl<StoredDiagnostic>::iterator |
118 | I = StoredDiags.begin(), E = StoredDiags.end(); |
119 | while (I != E) { |
120 | if (I->getID() == diag::err_switch_into_protected_scope && |
121 | isInRange(Loc: I->getLocation(), R: BodyRange)) { |
122 | handleProtectedScopeError(DiagI&: I, DiagE: E); |
123 | continue; |
124 | } |
125 | ++I; |
126 | } |
127 | } |
128 | |
129 | void handleProtectedScopeError( |
130 | SmallVectorImpl<StoredDiagnostic>::iterator &DiagI, |
131 | SmallVectorImpl<StoredDiagnostic>::iterator DiagE){ |
132 | Transaction Trans(Pass.TA); |
133 | assert(DiagI->getID() == diag::err_switch_into_protected_scope); |
134 | SourceLocation ErrLoc = DiagI->getLocation(); |
135 | bool handledAllNotes = true; |
136 | ++DiagI; |
137 | for (; DiagI != DiagE && DiagI->getLevel() == DiagnosticsEngine::Note; |
138 | ++DiagI) { |
139 | if (!handleProtectedNote(Diag: *DiagI)) |
140 | handledAllNotes = false; |
141 | } |
142 | |
143 | if (handledAllNotes) |
144 | Pass.TA.clearDiagnostic(IDs: diag::err_switch_into_protected_scope, range: ErrLoc); |
145 | } |
146 | |
147 | bool handleProtectedNote(const StoredDiagnostic &Diag) { |
148 | assert(Diag.getLevel() == DiagnosticsEngine::Note); |
149 | |
150 | for (unsigned i = 0; i != Cases.size(); i++) { |
151 | CaseInfo &info = Cases[i]; |
152 | if (isInRange(Loc: Diag.getLocation(), R: info.Range)) { |
153 | |
154 | if (info.State == CaseInfo::St_Unchecked) |
155 | tryFixing(info); |
156 | assert(info.State != CaseInfo::St_Unchecked); |
157 | |
158 | if (info.State == CaseInfo::St_Fixed) { |
159 | Pass.TA.clearDiagnostic(IDs: Diag.getID(), range: Diag.getLocation()); |
160 | return true; |
161 | } |
162 | return false; |
163 | } |
164 | } |
165 | |
166 | return false; |
167 | } |
168 | |
169 | void tryFixing(CaseInfo &info) { |
170 | assert(info.State == CaseInfo::St_Unchecked); |
171 | if (hasVarReferencedOutside(info)) { |
172 | info.State = CaseInfo::St_CannotFix; |
173 | return; |
174 | } |
175 | |
176 | Pass.TA.insertAfterToken(loc: info.SC->getColonLoc(), text: " {" ); |
177 | Pass.TA.insert(loc: info.Range.getEnd(), text: "}\n" ); |
178 | info.State = CaseInfo::St_Fixed; |
179 | } |
180 | |
181 | bool hasVarReferencedOutside(CaseInfo &info) { |
182 | for (unsigned i = 0, e = LocalRefs.size(); i != e; ++i) { |
183 | DeclRefExpr *DRE = LocalRefs[i]; |
184 | if (isInRange(Loc: DRE->getDecl()->getLocation(), R: info.Range) && |
185 | !isInRange(Loc: DRE->getLocation(), R: info.Range)) |
186 | return true; |
187 | } |
188 | return false; |
189 | } |
190 | |
191 | bool isInRange(SourceLocation Loc, SourceRange R) { |
192 | if (Loc.isInvalid()) |
193 | return false; |
194 | return !SM.isBeforeInTranslationUnit(LHS: Loc, RHS: R.getBegin()) && |
195 | SM.isBeforeInTranslationUnit(LHS: Loc, RHS: R.getEnd()); |
196 | } |
197 | }; |
198 | |
199 | } // anonymous namespace |
200 | |
201 | void ProtectedScopeTraverser::traverseBody(BodyContext &BodyCtx) { |
202 | ProtectedScopeFixer Fix(BodyCtx); |
203 | } |
204 | |