1//===--- SourceExtraction.cpp - Clang refactoring library -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "clang/Tooling/Refactoring/Extract/SourceExtraction.h"
10#include "clang/AST/Stmt.h"
11#include "clang/AST/StmtCXX.h"
12#include "clang/AST/StmtObjC.h"
13#include "clang/Basic/SourceManager.h"
14#include "clang/Lex/Lexer.h"
15#include <optional>
16
17using namespace clang;
18
19namespace {
20
21/// Returns true if the token at the given location is a semicolon.
22bool isSemicolonAtLocation(SourceLocation TokenLoc, const SourceManager &SM,
23 const LangOptions &LangOpts) {
24 return Lexer::getSourceText(
25 Range: CharSourceRange::getTokenRange(B: TokenLoc, E: TokenLoc), SM,
26 LangOpts) == ";";
27}
28
29/// Returns true if there should be a semicolon after the given statement.
30bool isSemicolonRequiredAfter(const Stmt *S) {
31 if (isa<CompoundStmt>(Val: S))
32 return false;
33 if (const auto *If = dyn_cast<IfStmt>(Val: S))
34 return isSemicolonRequiredAfter(S: If->getElse() ? If->getElse()
35 : If->getThen());
36 if (const auto *While = dyn_cast<WhileStmt>(Val: S))
37 return isSemicolonRequiredAfter(S: While->getBody());
38 if (const auto *For = dyn_cast<ForStmt>(Val: S))
39 return isSemicolonRequiredAfter(S: For->getBody());
40 if (const auto *CXXFor = dyn_cast<CXXForRangeStmt>(Val: S))
41 return isSemicolonRequiredAfter(S: CXXFor->getBody());
42 if (const auto *ObjCFor = dyn_cast<ObjCForCollectionStmt>(Val: S))
43 return isSemicolonRequiredAfter(S: ObjCFor->getBody());
44 if(const auto *Switch = dyn_cast<SwitchStmt>(Val: S))
45 return isSemicolonRequiredAfter(S: Switch->getBody());
46 if(const auto *Case = dyn_cast<SwitchCase>(Val: S))
47 return isSemicolonRequiredAfter(S: Case->getSubStmt());
48 switch (S->getStmtClass()) {
49 case Stmt::DeclStmtClass:
50 case Stmt::CXXTryStmtClass:
51 case Stmt::ObjCAtSynchronizedStmtClass:
52 case Stmt::ObjCAutoreleasePoolStmtClass:
53 case Stmt::ObjCAtTryStmtClass:
54 return false;
55 default:
56 return true;
57 }
58}
59
60/// Returns true if the two source locations are on the same line.
61bool areOnSameLine(SourceLocation Loc1, SourceLocation Loc2,
62 const SourceManager &SM) {
63 return !Loc1.isMacroID() && !Loc2.isMacroID() &&
64 SM.getSpellingLineNumber(Loc: Loc1) == SM.getSpellingLineNumber(Loc: Loc2);
65}
66
67} // end anonymous namespace
68
69namespace clang {
70namespace tooling {
71
72ExtractionSemicolonPolicy
73ExtractionSemicolonPolicy::compute(const Stmt *S, SourceRange &ExtractedRange,
74 const SourceManager &SM,
75 const LangOptions &LangOpts) {
76 auto neededInExtractedFunction = []() {
77 return ExtractionSemicolonPolicy(true, false);
78 };
79 auto neededInOriginalFunction = []() {
80 return ExtractionSemicolonPolicy(false, true);
81 };
82
83 /// The extracted expression should be terminated with a ';'. The call to
84 /// the extracted function will replace this expression, so it won't need
85 /// a terminating ';'.
86 if (isa<Expr>(Val: S))
87 return neededInExtractedFunction();
88
89 /// Some statements don't need to be terminated with ';'. The call to the
90 /// extracted function will be a standalone statement, so it should be
91 /// terminated with a ';'.
92 bool NeedsSemi = isSemicolonRequiredAfter(S);
93 if (!NeedsSemi)
94 return neededInOriginalFunction();
95
96 /// Some statements might end at ';'. The extraction will move that ';', so
97 /// the call to the extracted function should be terminated with a ';'.
98 SourceLocation End = ExtractedRange.getEnd();
99 if (isSemicolonAtLocation(TokenLoc: End, SM, LangOpts))
100 return neededInOriginalFunction();
101
102 /// Other statements should generally have a trailing ';'. We can try to find
103 /// it and move it together it with the extracted code.
104 std::optional<Token> NextToken = Lexer::findNextToken(Loc: End, SM, LangOpts);
105 if (NextToken && NextToken->is(K: tok::semi) &&
106 areOnSameLine(Loc1: NextToken->getLocation(), Loc2: End, SM)) {
107 ExtractedRange.setEnd(NextToken->getLocation());
108 return neededInOriginalFunction();
109 }
110
111 /// Otherwise insert semicolons in both places.
112 return ExtractionSemicolonPolicy(true, true);
113}
114
115} // end namespace tooling
116} // end namespace clang
117