1 | //===--- RefactoringCallbacks.cpp - Structural query framework ------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // |
10 | //===----------------------------------------------------------------------===// |
11 | #include "clang/Tooling/RefactoringCallbacks.h" |
12 | #include "clang/ASTMatchers/ASTMatchFinder.h" |
13 | #include "clang/Basic/SourceLocation.h" |
14 | #include "clang/Lex/Lexer.h" |
15 | |
16 | using llvm::StringError; |
17 | using llvm::make_error; |
18 | |
19 | namespace clang { |
20 | namespace tooling { |
21 | |
22 | RefactoringCallback::RefactoringCallback() {} |
23 | tooling::Replacements &RefactoringCallback::getReplacements() { |
24 | return Replace; |
25 | } |
26 | |
27 | ASTMatchRefactorer::ASTMatchRefactorer( |
28 | std::map<std::string, Replacements> &FileToReplaces) |
29 | : FileToReplaces(FileToReplaces) {} |
30 | |
31 | void ASTMatchRefactorer::addDynamicMatcher( |
32 | const ast_matchers::internal::DynTypedMatcher &Matcher, |
33 | RefactoringCallback *Callback) { |
34 | MatchFinder.addDynamicMatcher(NodeMatch: Matcher, Action: Callback); |
35 | Callbacks.push_back(x: Callback); |
36 | } |
37 | |
38 | class RefactoringASTConsumer : public ASTConsumer { |
39 | public: |
40 | explicit RefactoringASTConsumer(ASTMatchRefactorer &Refactoring) |
41 | : Refactoring(Refactoring) {} |
42 | |
43 | void HandleTranslationUnit(ASTContext &Context) override { |
44 | // The ASTMatchRefactorer is re-used between translation units. |
45 | // Clear the matchers so that each Replacement is only emitted once. |
46 | for (const auto &Callback : Refactoring.Callbacks) { |
47 | Callback->getReplacements().clear(); |
48 | } |
49 | Refactoring.MatchFinder.matchAST(Context); |
50 | for (const auto &Callback : Refactoring.Callbacks) { |
51 | for (const auto &Replacement : Callback->getReplacements()) { |
52 | llvm::Error Err = |
53 | Refactoring.FileToReplaces[std::string(Replacement.getFilePath())] |
54 | .add(R: Replacement); |
55 | if (Err) { |
56 | llvm::errs() << "Skipping replacement " << Replacement.toString() |
57 | << " due to this error:\n" |
58 | << toString(E: std::move(Err)) << "\n" ; |
59 | } |
60 | } |
61 | } |
62 | } |
63 | |
64 | private: |
65 | ASTMatchRefactorer &Refactoring; |
66 | }; |
67 | |
68 | std::unique_ptr<ASTConsumer> ASTMatchRefactorer::newASTConsumer() { |
69 | return std::make_unique<RefactoringASTConsumer>(args&: *this); |
70 | } |
71 | |
72 | static Replacement replaceStmtWithText(SourceManager &Sources, const Stmt &From, |
73 | StringRef Text) { |
74 | return tooling::Replacement( |
75 | Sources, CharSourceRange::getTokenRange(R: From.getSourceRange()), Text); |
76 | } |
77 | static Replacement replaceStmtWithStmt(SourceManager &Sources, const Stmt &From, |
78 | const Stmt &To) { |
79 | return replaceStmtWithText( |
80 | Sources, From, |
81 | Text: Lexer::getSourceText(Range: CharSourceRange::getTokenRange(R: To.getSourceRange()), |
82 | SM: Sources, LangOpts: LangOptions())); |
83 | } |
84 | |
85 | ReplaceStmtWithText::ReplaceStmtWithText(StringRef FromId, StringRef ToText) |
86 | : FromId(std::string(FromId)), ToText(std::string(ToText)) {} |
87 | |
88 | void ReplaceStmtWithText::run( |
89 | const ast_matchers::MatchFinder::MatchResult &Result) { |
90 | if (const Stmt *FromMatch = Result.Nodes.getNodeAs<Stmt>(ID: FromId)) { |
91 | auto Err = Replace.add(R: tooling::Replacement( |
92 | *Result.SourceManager, |
93 | CharSourceRange::getTokenRange(R: FromMatch->getSourceRange()), ToText)); |
94 | // FIXME: better error handling. For now, just print error message in the |
95 | // release version. |
96 | if (Err) { |
97 | llvm::errs() << llvm::toString(E: std::move(Err)) << "\n" ; |
98 | assert(false); |
99 | } |
100 | } |
101 | } |
102 | |
103 | ReplaceStmtWithStmt::ReplaceStmtWithStmt(StringRef FromId, StringRef ToId) |
104 | : FromId(std::string(FromId)), ToId(std::string(ToId)) {} |
105 | |
106 | void ReplaceStmtWithStmt::run( |
107 | const ast_matchers::MatchFinder::MatchResult &Result) { |
108 | const Stmt *FromMatch = Result.Nodes.getNodeAs<Stmt>(ID: FromId); |
109 | const Stmt *ToMatch = Result.Nodes.getNodeAs<Stmt>(ID: ToId); |
110 | if (FromMatch && ToMatch) { |
111 | auto Err = Replace.add( |
112 | R: replaceStmtWithStmt(Sources&: *Result.SourceManager, From: *FromMatch, To: *ToMatch)); |
113 | // FIXME: better error handling. For now, just print error message in the |
114 | // release version. |
115 | if (Err) { |
116 | llvm::errs() << llvm::toString(E: std::move(Err)) << "\n" ; |
117 | assert(false); |
118 | } |
119 | } |
120 | } |
121 | |
122 | ReplaceIfStmtWithItsBody::ReplaceIfStmtWithItsBody(StringRef Id, |
123 | bool PickTrueBranch) |
124 | : Id(std::string(Id)), PickTrueBranch(PickTrueBranch) {} |
125 | |
126 | void ReplaceIfStmtWithItsBody::run( |
127 | const ast_matchers::MatchFinder::MatchResult &Result) { |
128 | if (const IfStmt *Node = Result.Nodes.getNodeAs<IfStmt>(ID: Id)) { |
129 | const Stmt *Body = PickTrueBranch ? Node->getThen() : Node->getElse(); |
130 | if (Body) { |
131 | auto Err = |
132 | Replace.add(R: replaceStmtWithStmt(Sources&: *Result.SourceManager, From: *Node, To: *Body)); |
133 | // FIXME: better error handling. For now, just print error message in the |
134 | // release version. |
135 | if (Err) { |
136 | llvm::errs() << llvm::toString(E: std::move(Err)) << "\n" ; |
137 | assert(false); |
138 | } |
139 | } else if (!PickTrueBranch) { |
140 | // If we want to use the 'else'-branch, but it doesn't exist, delete |
141 | // the whole 'if'. |
142 | auto Err = |
143 | Replace.add(R: replaceStmtWithText(Sources&: *Result.SourceManager, From: *Node, Text: "" )); |
144 | // FIXME: better error handling. For now, just print error message in the |
145 | // release version. |
146 | if (Err) { |
147 | llvm::errs() << llvm::toString(E: std::move(Err)) << "\n" ; |
148 | assert(false); |
149 | } |
150 | } |
151 | } |
152 | } |
153 | |
154 | ReplaceNodeWithTemplate::ReplaceNodeWithTemplate( |
155 | llvm::StringRef FromId, std::vector<TemplateElement> Template) |
156 | : FromId(std::string(FromId)), Template(std::move(Template)) {} |
157 | |
158 | llvm::Expected<std::unique_ptr<ReplaceNodeWithTemplate>> |
159 | ReplaceNodeWithTemplate::create(StringRef FromId, StringRef ToTemplate) { |
160 | std::vector<TemplateElement> ParsedTemplate; |
161 | for (size_t Index = 0; Index < ToTemplate.size();) { |
162 | if (ToTemplate[Index] == '$') { |
163 | if (ToTemplate.substr(Start: Index, N: 2) == "$$" ) { |
164 | Index += 2; |
165 | ParsedTemplate.push_back( |
166 | x: TemplateElement{.Type: TemplateElement::Literal, .Value: "$" }); |
167 | } else if (ToTemplate.substr(Start: Index, N: 2) == "${" ) { |
168 | size_t EndOfIdentifier = ToTemplate.find(Str: "}" , From: Index); |
169 | if (EndOfIdentifier == std::string::npos) { |
170 | return make_error<StringError>( |
171 | Args: "Unterminated ${...} in replacement template near " + |
172 | ToTemplate.substr(Start: Index), |
173 | Args: llvm::inconvertibleErrorCode()); |
174 | } |
175 | std::string SourceNodeName = std::string( |
176 | ToTemplate.substr(Start: Index + 2, N: EndOfIdentifier - Index - 2)); |
177 | ParsedTemplate.push_back( |
178 | x: TemplateElement{.Type: TemplateElement::Identifier, .Value: SourceNodeName}); |
179 | Index = EndOfIdentifier + 1; |
180 | } else { |
181 | return make_error<StringError>( |
182 | Args: "Invalid $ in replacement template near " + |
183 | ToTemplate.substr(Start: Index), |
184 | Args: llvm::inconvertibleErrorCode()); |
185 | } |
186 | } else { |
187 | size_t NextIndex = ToTemplate.find(C: '$', From: Index + 1); |
188 | ParsedTemplate.push_back(x: TemplateElement{ |
189 | .Type: TemplateElement::Literal, |
190 | .Value: std::string(ToTemplate.substr(Start: Index, N: NextIndex - Index))}); |
191 | Index = NextIndex; |
192 | } |
193 | } |
194 | return std::unique_ptr<ReplaceNodeWithTemplate>( |
195 | new ReplaceNodeWithTemplate(FromId, std::move(ParsedTemplate))); |
196 | } |
197 | |
198 | void ReplaceNodeWithTemplate::run( |
199 | const ast_matchers::MatchFinder::MatchResult &Result) { |
200 | const auto &NodeMap = Result.Nodes.getMap(); |
201 | |
202 | std::string ToText; |
203 | for (const auto &Element : Template) { |
204 | switch (Element.Type) { |
205 | case TemplateElement::Literal: |
206 | ToText += Element.Value; |
207 | break; |
208 | case TemplateElement::Identifier: { |
209 | auto NodeIter = NodeMap.find(x: Element.Value); |
210 | if (NodeIter == NodeMap.end()) { |
211 | llvm::errs() << "Node " << Element.Value |
212 | << " used in replacement template not bound in Matcher \n" ; |
213 | llvm::report_fatal_error(reason: "Unbound node in replacement template." ); |
214 | } |
215 | CharSourceRange Source = |
216 | CharSourceRange::getTokenRange(R: NodeIter->second.getSourceRange()); |
217 | ToText += Lexer::getSourceText(Range: Source, SM: *Result.SourceManager, |
218 | LangOpts: Result.Context->getLangOpts()); |
219 | break; |
220 | } |
221 | } |
222 | } |
223 | if (NodeMap.count(x: FromId) == 0) { |
224 | llvm::errs() << "Node to be replaced " << FromId |
225 | << " not bound in query.\n" ; |
226 | llvm::report_fatal_error(reason: "FromId node not bound in MatchResult" ); |
227 | } |
228 | auto Replacement = |
229 | tooling::Replacement(*Result.SourceManager, &NodeMap.at(k: FromId), ToText, |
230 | Result.Context->getLangOpts()); |
231 | llvm::Error Err = Replace.add(R: Replacement); |
232 | if (Err) { |
233 | llvm::errs() << "Query and replace failed in " << Replacement.getFilePath() |
234 | << "! " << llvm::toString(E: std::move(Err)) << "\n" ; |
235 | llvm::report_fatal_error(reason: "Replacement failed" ); |
236 | } |
237 | } |
238 | |
239 | } // end namespace tooling |
240 | } // end namespace clang |
241 | |