1//===--- TestSupport.cpp - Clang-based refactoring tool -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file implements routines that provide refactoring testing
11/// utilities.
12///
13//===----------------------------------------------------------------------===//
14
15#include "TestSupport.h"
16#include "clang/Basic/DiagnosticError.h"
17#include "clang/Basic/FileManager.h"
18#include "clang/Basic/SourceManager.h"
19#include "clang/Lex/Lexer.h"
20#include "llvm/ADT/STLExtras.h"
21#include "llvm/Support/Error.h"
22#include "llvm/Support/ErrorOr.h"
23#include "llvm/Support/LineIterator.h"
24#include "llvm/Support/MemoryBuffer.h"
25#include "llvm/Support/Regex.h"
26#include "llvm/Support/raw_ostream.h"
27#include <optional>
28
29using namespace llvm;
30
31namespace clang {
32namespace refactor {
33
34void TestSelectionRangesInFile::dump(raw_ostream &OS) const {
35 for (const auto &Group : GroupedRanges) {
36 OS << "Test selection group '" << Group.Name << "':\n";
37 for (const auto &Range : Group.Ranges) {
38 OS << " " << Range.Begin << "-" << Range.End << "\n";
39 }
40 }
41}
42
43bool TestSelectionRangesInFile::foreachRange(
44 const SourceManager &SM,
45 llvm::function_ref<void(SourceRange)> Callback) const {
46 auto FE = SM.getFileManager().getFile(Filename);
47 FileID FID = FE ? SM.translateFile(SourceFile: *FE) : FileID();
48 if (!FE || FID.isInvalid()) {
49 llvm::errs() << "error: -selection=test:" << Filename
50 << " : given file is not in the target TU";
51 return true;
52 }
53 SourceLocation FileLoc = SM.getLocForStartOfFile(FID);
54 for (const auto &Group : GroupedRanges) {
55 for (const TestSelectionRange &Range : Group.Ranges) {
56 // Translate the offset pair to a true source range.
57 SourceLocation Start =
58 SM.getMacroArgExpandedLocation(Loc: FileLoc.getLocWithOffset(Offset: Range.Begin));
59 SourceLocation End =
60 SM.getMacroArgExpandedLocation(Loc: FileLoc.getLocWithOffset(Offset: Range.End));
61 assert(Start.isValid() && End.isValid() && "unexpected invalid range");
62 Callback(SourceRange(Start, End));
63 }
64 }
65 return false;
66}
67
68namespace {
69
70void dumpChanges(const tooling::AtomicChanges &Changes, raw_ostream &OS) {
71 for (const auto &Change : Changes)
72 OS << const_cast<tooling::AtomicChange &>(Change).toYAMLString() << "\n";
73}
74
75bool areChangesSame(const tooling::AtomicChanges &LHS,
76 const tooling::AtomicChanges &RHS) {
77 if (LHS.size() != RHS.size())
78 return false;
79 for (auto I : llvm::zip(t: LHS, u: RHS)) {
80 if (!(std::get<0>(t&: I) == std::get<1>(t&: I)))
81 return false;
82 }
83 return true;
84}
85
86bool printRewrittenSources(const tooling::AtomicChanges &Changes,
87 raw_ostream &OS) {
88 std::set<std::string> Files;
89 for (const auto &Change : Changes)
90 Files.insert(x: Change.getFilePath());
91 tooling::ApplyChangesSpec Spec;
92 Spec.Cleanup = false;
93 for (const auto &File : Files) {
94 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> BufferErr =
95 llvm::MemoryBuffer::getFile(Filename: File);
96 if (!BufferErr) {
97 llvm::errs() << "failed to open" << File << "\n";
98 return true;
99 }
100 auto Result = tooling::applyAtomicChanges(FilePath: File, Code: (*BufferErr)->getBuffer(),
101 Changes, Spec);
102 if (!Result) {
103 llvm::errs() << toString(E: Result.takeError());
104 return true;
105 }
106 OS << *Result;
107 }
108 return false;
109}
110
111class TestRefactoringResultConsumer final
112 : public ClangRefactorToolConsumerInterface {
113public:
114 TestRefactoringResultConsumer(const TestSelectionRangesInFile &TestRanges)
115 : TestRanges(TestRanges) {
116 Results.push_back(x: {});
117 }
118
119 ~TestRefactoringResultConsumer() {
120 // Ensure all results are checked.
121 for (auto &Group : Results) {
122 for (auto &Result : Group) {
123 if (!Result) {
124 (void)llvm::toString(E: Result.takeError());
125 }
126 }
127 }
128 }
129
130 void handleError(llvm::Error Err) override { handleResult(Result: std::move(Err)); }
131
132 void handle(tooling::AtomicChanges Changes) override {
133 handleResult(Result: std::move(Changes));
134 }
135
136 void handle(tooling::SymbolOccurrences Occurrences) override {
137 tooling::RefactoringResultConsumer::handle(Occurrences: std::move(Occurrences));
138 }
139
140private:
141 bool handleAllResults();
142
143 void handleResult(Expected<tooling::AtomicChanges> Result) {
144 Results.back().push_back(x: std::move(Result));
145 size_t GroupIndex = Results.size() - 1;
146 if (Results.back().size() >=
147 TestRanges.GroupedRanges[GroupIndex].Ranges.size()) {
148 ++GroupIndex;
149 if (GroupIndex >= TestRanges.GroupedRanges.size()) {
150 if (handleAllResults())
151 exit(status: 1); // error has occurred.
152 return;
153 }
154 Results.push_back(x: {});
155 }
156 }
157
158 const TestSelectionRangesInFile &TestRanges;
159 std::vector<std::vector<Expected<tooling::AtomicChanges>>> Results;
160};
161
162std::pair<unsigned, unsigned> getLineColumn(StringRef Filename,
163 unsigned Offset) {
164 ErrorOr<std::unique_ptr<MemoryBuffer>> ErrOrFile =
165 MemoryBuffer::getFile(Filename);
166 if (!ErrOrFile)
167 return {0, 0};
168 StringRef Source = ErrOrFile.get()->getBuffer();
169 Source = Source.take_front(N: Offset);
170 size_t LastLine = Source.find_last_of(Chars: "\r\n");
171 return {Source.count(C: '\n') + 1,
172 (LastLine == StringRef::npos ? Offset : Offset - LastLine) + 1};
173}
174
175} // end anonymous namespace
176
177bool TestRefactoringResultConsumer::handleAllResults() {
178 bool Failed = false;
179 for (const auto &Group : llvm::enumerate(First&: Results)) {
180 // All ranges in the group must produce the same result.
181 std::optional<tooling::AtomicChanges> CanonicalResult;
182 std::optional<std::string> CanonicalErrorMessage;
183 for (const auto &I : llvm::enumerate(First&: Group.value())) {
184 Expected<tooling::AtomicChanges> &Result = I.value();
185 std::string ErrorMessage;
186 bool HasResult = !!Result;
187 if (!HasResult) {
188 handleAllErrors(
189 E: Result.takeError(),
190 Handlers: [&](StringError &Err) { ErrorMessage = Err.getMessage(); },
191 Handlers: [&](DiagnosticError &Err) {
192 const PartialDiagnosticAt &Diag = Err.getDiagnostic();
193 llvm::SmallString<100> DiagText;
194 Diag.second.EmitToString(Diags&: getDiags(), Buf&: DiagText);
195 ErrorMessage = std::string(DiagText);
196 });
197 }
198 if (!CanonicalResult && !CanonicalErrorMessage) {
199 if (HasResult)
200 CanonicalResult = std::move(*Result);
201 else
202 CanonicalErrorMessage = std::move(ErrorMessage);
203 continue;
204 }
205
206 // Verify that this result corresponds to the canonical result.
207 if (CanonicalErrorMessage) {
208 // The error messages must match.
209 if (!HasResult && ErrorMessage == *CanonicalErrorMessage)
210 continue;
211 } else {
212 assert(CanonicalResult && "missing canonical result");
213 // The results must match.
214 if (HasResult && areChangesSame(LHS: *Result, RHS: *CanonicalResult))
215 continue;
216 }
217 Failed = true;
218 // Report the mismatch.
219 std::pair<unsigned, unsigned> LineColumn = getLineColumn(
220 Filename: TestRanges.Filename,
221 Offset: TestRanges.GroupedRanges[Group.index()].Ranges[I.index()].Begin);
222 llvm::errs()
223 << "error: unexpected refactoring result for range starting at "
224 << LineColumn.first << ':' << LineColumn.second << " in group '"
225 << TestRanges.GroupedRanges[Group.index()].Name << "':\n ";
226 if (HasResult)
227 llvm::errs() << "valid result";
228 else
229 llvm::errs() << "error '" << ErrorMessage << "'";
230 llvm::errs() << " does not match initial ";
231 if (CanonicalErrorMessage)
232 llvm::errs() << "error '" << *CanonicalErrorMessage << "'\n";
233 else
234 llvm::errs() << "valid result\n";
235 if (HasResult && !CanonicalErrorMessage) {
236 llvm::errs() << " Expected to Produce:\n";
237 dumpChanges(Changes: *CanonicalResult, OS&: llvm::errs());
238 llvm::errs() << " Produced:\n";
239 dumpChanges(Changes: *Result, OS&: llvm::errs());
240 }
241 }
242
243 // Dump the results:
244 const auto &TestGroup = TestRanges.GroupedRanges[Group.index()];
245 if (!CanonicalResult) {
246 llvm::outs() << TestGroup.Ranges.size() << " '" << TestGroup.Name
247 << "' results:\n";
248 llvm::outs() << *CanonicalErrorMessage << "\n";
249 } else {
250 llvm::outs() << TestGroup.Ranges.size() << " '" << TestGroup.Name
251 << "' results:\n";
252 if (printRewrittenSources(Changes: *CanonicalResult, OS&: llvm::outs()))
253 return true;
254 }
255 }
256 return Failed;
257}
258
259std::unique_ptr<ClangRefactorToolConsumerInterface>
260TestSelectionRangesInFile::createConsumer() const {
261 return std::make_unique<TestRefactoringResultConsumer>(args: *this);
262}
263
264/// Adds the \p ColumnOffset to file offset \p Offset, without going past a
265/// newline.
266static unsigned addColumnOffset(StringRef Source, unsigned Offset,
267 unsigned ColumnOffset) {
268 if (!ColumnOffset)
269 return Offset;
270 StringRef Substr = Source.drop_front(N: Offset).take_front(N: ColumnOffset);
271 size_t NewlinePos = Substr.find_first_of(Chars: "\r\n");
272 return Offset +
273 (NewlinePos == StringRef::npos ? ColumnOffset : (unsigned)NewlinePos);
274}
275
276static unsigned addEndLineOffsetAndEndColumn(StringRef Source, unsigned Offset,
277 unsigned LineNumberOffset,
278 unsigned Column) {
279 StringRef Line = Source.drop_front(N: Offset);
280 unsigned LineOffset = 0;
281 for (; LineNumberOffset != 0; --LineNumberOffset) {
282 size_t NewlinePos = Line.find_first_of(Chars: "\r\n");
283 // Line offset goes out of bounds.
284 if (NewlinePos == StringRef::npos)
285 break;
286 LineOffset += NewlinePos + 1;
287 Line = Line.drop_front(N: NewlinePos + 1);
288 }
289 // Source now points to the line at +lineOffset;
290 size_t LineStart = Source.find_last_of(Chars: "\r\n", /*From=*/Offset + LineOffset);
291 return addColumnOffset(
292 Source, Offset: LineStart == StringRef::npos ? 0 : LineStart + 1, ColumnOffset: Column - 1);
293}
294
295std::optional<TestSelectionRangesInFile>
296findTestSelectionRanges(StringRef Filename) {
297 ErrorOr<std::unique_ptr<MemoryBuffer>> ErrOrFile =
298 MemoryBuffer::getFile(Filename);
299 if (!ErrOrFile) {
300 llvm::errs() << "error: -selection=test:" << Filename
301 << " : could not open the given file";
302 return std::nullopt;
303 }
304 StringRef Source = ErrOrFile.get()->getBuffer();
305
306 // See the doc comment for this function for the explanation of this
307 // syntax.
308 static const Regex RangeRegex(
309 "range[[:blank:]]*([[:alpha:]_]*)?[[:blank:]]*=[[:"
310 "blank:]]*(\\+[[:digit:]]+)?[[:blank:]]*(->[[:blank:]"
311 "]*[\\+\\:[:digit:]]+)?");
312
313 std::map<std::string, SmallVector<TestSelectionRange, 8>> GroupedRanges;
314
315 LangOptions LangOpts;
316 LangOpts.CPlusPlus = 1;
317 LangOpts.CPlusPlus11 = 1;
318 Lexer Lex(SourceLocation::getFromRawEncoding(Encoding: 0), LangOpts, Source.begin(),
319 Source.begin(), Source.end());
320 Lex.SetCommentRetentionState(true);
321 Token Tok;
322 for (Lex.LexFromRawLexer(Result&: Tok); Tok.isNot(K: tok::eof);
323 Lex.LexFromRawLexer(Result&: Tok)) {
324 if (Tok.isNot(K: tok::comment))
325 continue;
326 StringRef Comment =
327 Source.substr(Start: Tok.getLocation().getRawEncoding(), N: Tok.getLength());
328 SmallVector<StringRef, 4> Matches;
329 // Try to detect mistyped 'range:' comments to ensure tests don't miss
330 // anything.
331 auto DetectMistypedCommand = [&]() -> bool {
332 if (Comment.contains_insensitive(Other: "range") && Comment.contains(Other: "=") &&
333 !Comment.contains_insensitive(Other: "run") && !Comment.contains(Other: "CHECK")) {
334 llvm::errs() << "error: suspicious comment '" << Comment
335 << "' that "
336 "resembles the range command found\n";
337 llvm::errs() << "note: please reword if this isn't a range command\n";
338 }
339 return false;
340 };
341 // Allow CHECK: comments to contain range= commands.
342 if (!RangeRegex.match(String: Comment, Matches: &Matches) || Comment.contains(Other: "CHECK")) {
343 if (DetectMistypedCommand())
344 return std::nullopt;
345 continue;
346 }
347 unsigned Offset = Tok.getEndLoc().getRawEncoding();
348 unsigned ColumnOffset = 0;
349 if (!Matches[2].empty()) {
350 // Don't forget to drop the '+'!
351 if (Matches[2].drop_front().getAsInteger(Radix: 10, Result&: ColumnOffset))
352 assert(false && "regex should have produced a number");
353 }
354 Offset = addColumnOffset(Source, Offset, ColumnOffset);
355 unsigned EndOffset;
356
357 if (!Matches[3].empty()) {
358 static const Regex EndLocRegex(
359 "->[[:blank:]]*(\\+[[:digit:]]+):([[:digit:]]+)");
360 SmallVector<StringRef, 4> EndLocMatches;
361 if (!EndLocRegex.match(String: Matches[3], Matches: &EndLocMatches)) {
362 if (DetectMistypedCommand())
363 return std::nullopt;
364 continue;
365 }
366 unsigned EndLineOffset = 0, EndColumn = 0;
367 if (EndLocMatches[1].drop_front().getAsInteger(Radix: 10, Result&: EndLineOffset) ||
368 EndLocMatches[2].getAsInteger(Radix: 10, Result&: EndColumn))
369 assert(false && "regex should have produced a number");
370 EndOffset = addEndLineOffsetAndEndColumn(Source, Offset, LineNumberOffset: EndLineOffset,
371 Column: EndColumn);
372 } else {
373 EndOffset = Offset;
374 }
375 TestSelectionRange Range = {.Begin: Offset, .End: EndOffset};
376 auto It = GroupedRanges.insert(x: std::make_pair(
377 x: Matches[1].str(), y: SmallVector<TestSelectionRange, 8>{Range}));
378 if (!It.second)
379 It.first->second.push_back(Elt: Range);
380 }
381 if (GroupedRanges.empty()) {
382 llvm::errs() << "error: -selection=test:" << Filename
383 << ": no 'range' commands";
384 return std::nullopt;
385 }
386
387 TestSelectionRangesInFile TestRanges = {.Filename: Filename.str(), .GroupedRanges: {}};
388 for (auto &Group : GroupedRanges)
389 TestRanges.GroupedRanges.push_back(x: {.Name: Group.first, .Ranges: std::move(Group.second)});
390 return std::move(TestRanges);
391}
392
393} // end namespace refactor
394} // end namespace clang
395