1//===--- USRFinder.cpp - Clang refactoring library ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file Implements a recursive AST visitor that finds the USR of a symbol at a
10/// point.
11///
12//===----------------------------------------------------------------------===//
13
14#include "clang/Tooling/Refactoring/Rename/USRFinder.h"
15#include "clang/AST/ASTContext.h"
16#include "clang/AST/RecursiveASTVisitor.h"
17#include "clang/Basic/SourceManager.h"
18#include "clang/Index/USRGeneration.h"
19#include "clang/Lex/Lexer.h"
20#include "clang/Tooling/Refactoring/RecursiveSymbolVisitor.h"
21
22using namespace llvm;
23
24namespace clang {
25namespace tooling {
26
27namespace {
28
29/// Recursively visits each AST node to find the symbol underneath the cursor.
30class NamedDeclOccurrenceFindingVisitor
31 : public RecursiveSymbolVisitor<NamedDeclOccurrenceFindingVisitor> {
32public:
33 // Finds the NamedDecl at a point in the source.
34 // \param Point the location in the source to search for the NamedDecl.
35 explicit NamedDeclOccurrenceFindingVisitor(const SourceLocation Point,
36 const ASTContext &Context)
37 : RecursiveSymbolVisitor(Context.getSourceManager(),
38 Context.getLangOpts()),
39 Point(Point), Context(Context) {}
40
41 bool visitSymbolOccurrence(const NamedDecl *ND,
42 ArrayRef<SourceRange> NameRanges) {
43 if (!ND)
44 return true;
45 for (const auto &Range : NameRanges) {
46 SourceLocation Start = Range.getBegin();
47 SourceLocation End = Range.getEnd();
48 if (!Start.isValid() || !Start.isFileID() || !End.isValid() ||
49 !End.isFileID() || !isPointWithin(Start, End))
50 return true;
51 }
52 Result = ND;
53 return false;
54 }
55
56 const NamedDecl *getNamedDecl() const { return Result; }
57
58private:
59 // Determines if the Point is within Start and End.
60 bool isPointWithin(const SourceLocation Start, const SourceLocation End) {
61 // FIXME: Add tests for Point == End.
62 return Point == Start || Point == End ||
63 (Context.getSourceManager().isBeforeInTranslationUnit(LHS: Start,
64 RHS: Point) &&
65 Context.getSourceManager().isBeforeInTranslationUnit(LHS: Point, RHS: End));
66 }
67
68 const NamedDecl *Result = nullptr;
69 const SourceLocation Point; // The location to find the NamedDecl.
70 const ASTContext &Context;
71};
72
73} // end anonymous namespace
74
75const NamedDecl *getNamedDeclAt(const ASTContext &Context,
76 const SourceLocation Point) {
77 const SourceManager &SM = Context.getSourceManager();
78 NamedDeclOccurrenceFindingVisitor Visitor(Point, Context);
79
80 // Try to be clever about pruning down the number of top-level declarations we
81 // see. If both start and end is either before or after the point we're
82 // looking for the point cannot be inside of this decl. Don't even look at it.
83 for (auto *CurrDecl : Context.getTranslationUnitDecl()->decls()) {
84 SourceLocation StartLoc = CurrDecl->getBeginLoc();
85 SourceLocation EndLoc = CurrDecl->getEndLoc();
86 if (StartLoc.isValid() && EndLoc.isValid() &&
87 SM.isBeforeInTranslationUnit(LHS: StartLoc, RHS: Point) !=
88 SM.isBeforeInTranslationUnit(LHS: EndLoc, RHS: Point))
89 Visitor.TraverseDecl(D: CurrDecl);
90 }
91
92 return Visitor.getNamedDecl();
93}
94
95namespace {
96
97/// Recursively visits each NamedDecl node to find the declaration with a
98/// specific name.
99class NamedDeclFindingVisitor
100 : public RecursiveASTVisitor<NamedDeclFindingVisitor> {
101public:
102 explicit NamedDeclFindingVisitor(StringRef Name) : Name(Name) {}
103
104 // We don't have to traverse the uses to find some declaration with a
105 // specific name, so just visit the named declarations.
106 bool VisitNamedDecl(const NamedDecl *ND) {
107 if (!ND)
108 return true;
109 // Fully qualified name is used to find the declaration.
110 if (Name != ND->getQualifiedNameAsString() &&
111 Name != "::" + ND->getQualifiedNameAsString())
112 return true;
113 Result = ND;
114 return false;
115 }
116
117 const NamedDecl *getNamedDecl() const { return Result; }
118
119private:
120 const NamedDecl *Result = nullptr;
121 StringRef Name;
122};
123
124} // end anonymous namespace
125
126const NamedDecl *getNamedDeclFor(const ASTContext &Context,
127 const std::string &Name) {
128 NamedDeclFindingVisitor Visitor(Name);
129 Visitor.TraverseDecl(D: Context.getTranslationUnitDecl());
130 return Visitor.getNamedDecl();
131}
132
133std::string getUSRForDecl(const Decl *Decl) {
134 llvm::SmallString<128> Buff;
135
136 // FIXME: Add test for the nullptr case.
137 if (Decl == nullptr || index::generateUSRForDecl(D: Decl, Buf&: Buff))
138 return "";
139
140 return std::string(Buff);
141}
142
143} // end namespace tooling
144} // end namespace clang
145