1#include "llvm/ProfileData/DataAccessProf.h"
2#include "llvm/ADT/STLExtras.h"
3#include "llvm/ProfileData/InstrProf.h"
4#include "llvm/Support/Compression.h"
5#include "llvm/Support/Endian.h"
6#include "llvm/Support/Errc.h"
7#include "llvm/Support/Error.h"
8#include "llvm/Support/StringSaver.h"
9#include "llvm/Support/raw_ostream.h"
10
11namespace llvm {
12namespace memprof {
13
14// If `Map` has an entry keyed by `Str`, returns the entry iterator. Otherwise,
15// creates an owned copy of `Str`, adds a map entry for it and returns the
16// iterator.
17static std::pair<StringRef, uint64_t>
18saveStringToMap(DataAccessProfData::StringToIndexMap &Map,
19 llvm::UniqueStringSaver &Saver, StringRef Str) {
20 auto [Iter, Inserted] = Map.try_emplace(Key: Saver.save(S: Str), Args: Map.size());
21 return *Iter;
22}
23
24// Returns the canonical name or error.
25static Expected<StringRef> getCanonicalName(StringRef Name) {
26 if (Name.empty())
27 return make_error<StringError>(Args: "Empty symbol name",
28 Args: llvm::errc::invalid_argument);
29 return InstrProfSymtab::getCanonicalName(PGOName: Name);
30}
31
32std::optional<DataAccessProfRecord>
33DataAccessProfData::getProfileRecord(const SymbolHandleRef SymbolID) const {
34 auto Key = SymbolID;
35 if (std::holds_alternative<StringRef>(v: SymbolID)) {
36 auto NameOrErr = getCanonicalName(Name: std::get<StringRef>(v: SymbolID));
37 // If name canonicalization fails, suppress the error inside.
38 if (!NameOrErr) {
39 assert(
40 std::get<StringRef>(SymbolID).empty() &&
41 "Name canonicalization only fails when stringified string is empty.");
42 return std::nullopt;
43 }
44 Key = *NameOrErr;
45 }
46
47 auto It = Records.find(Key);
48 if (It != Records.end()) {
49 return DataAccessProfRecord(Key, It->second.AccessCount,
50 It->second.Locations);
51 }
52
53 return std::nullopt;
54}
55
56bool DataAccessProfData::isKnownColdSymbol(const SymbolHandleRef SymID) const {
57 if (std::holds_alternative<uint64_t>(v: SymID))
58 return KnownColdHashes.contains(key: std::get<uint64_t>(v: SymID));
59 return KnownColdSymbols.contains(key: std::get<StringRef>(v: SymID));
60}
61
62Error DataAccessProfData::setDataAccessProfile(SymbolHandleRef Symbol,
63 uint64_t AccessCount) {
64 uint64_t RecordID = -1;
65 const bool IsStringLiteral = std::holds_alternative<uint64_t>(v: Symbol);
66 SymbolHandleRef Key;
67 if (IsStringLiteral) {
68 RecordID = std::get<uint64_t>(v&: Symbol);
69 Key = RecordID;
70 } else {
71 auto CanonicalName = getCanonicalName(Name: std::get<StringRef>(v&: Symbol));
72 if (!CanonicalName)
73 return CanonicalName.takeError();
74 std::tie(args&: Key, args&: RecordID) =
75 saveStringToMap(Map&: StrToIndexMap, Saver, Str: *CanonicalName);
76 }
77
78 auto [Iter, Inserted] =
79 Records.try_emplace(Key, Args&: RecordID, Args&: AccessCount, Args: IsStringLiteral);
80 if (!Inserted)
81 return make_error<StringError>(Args: "Duplicate symbol or string literal added. "
82 "User of DataAccessProfData should "
83 "aggregate count for the same symbol. ",
84 Args: llvm::errc::invalid_argument);
85
86 return Error::success();
87}
88
89Error DataAccessProfData::setDataAccessProfile(
90 SymbolHandleRef SymbolID, uint64_t AccessCount,
91 ArrayRef<SourceLocation> Locations) {
92 if (Error E = setDataAccessProfile(Symbol: SymbolID, AccessCount))
93 return E;
94
95 auto &Record = Records.back().second;
96 for (const auto &Location : Locations)
97 Record.Locations.push_back(
98 Elt: {saveStringToMap(Map&: StrToIndexMap, Saver, Str: Location.FileName).first,
99 Location.Line});
100
101 return Error::success();
102}
103
104Error DataAccessProfData::addKnownSymbolWithoutSamples(
105 SymbolHandleRef SymbolID) {
106 if (std::holds_alternative<uint64_t>(v: SymbolID)) {
107 KnownColdHashes.insert(X: std::get<uint64_t>(v&: SymbolID));
108 return Error::success();
109 }
110 auto CanonicalName = getCanonicalName(Name: std::get<StringRef>(v&: SymbolID));
111 if (!CanonicalName)
112 return CanonicalName.takeError();
113 KnownColdSymbols.insert(
114 X: saveStringToMap(Map&: StrToIndexMap, Saver, Str: *CanonicalName).first);
115 return Error::success();
116}
117
118Error DataAccessProfData::deserialize(const unsigned char *&Ptr) {
119 uint64_t NumSampledSymbols =
120 support::endian::readNext<uint64_t, llvm::endianness::little>(memory&: Ptr);
121 uint64_t NumColdKnownSymbols =
122 support::endian::readNext<uint64_t, llvm::endianness::little>(memory&: Ptr);
123 if (Error E = deserializeSymbolsAndFilenames(Ptr, NumSampledSymbols,
124 NumColdKnownSymbols))
125 return E;
126
127 uint64_t Num =
128 support::endian::readNext<uint64_t, llvm::endianness::little>(memory&: Ptr);
129 for (uint64_t I = 0; I < Num; ++I)
130 KnownColdHashes.insert(
131 X: support::endian::readNext<uint64_t, llvm::endianness::little>(memory&: Ptr));
132
133 return deserializeRecords(Ptr);
134}
135
136Error DataAccessProfData::serializeSymbolsAndFilenames(ProfOStream &OS) const {
137 OS.write(V: StrToIndexMap.size());
138 OS.write(V: KnownColdSymbols.size());
139
140 std::vector<std::string> Strs;
141 Strs.reserve(n: StrToIndexMap.size() + KnownColdSymbols.size());
142 for (const auto &Str : StrToIndexMap)
143 Strs.push_back(x: Str.first.str());
144 for (const auto &Str : KnownColdSymbols)
145 Strs.push_back(x: Str.str());
146
147 std::string CompressedStrings;
148 if (!Strs.empty())
149 if (Error E = collectGlobalObjectNameStrings(
150 NameStrs: Strs, doCompression: compression::zlib::isAvailable(), Result&: CompressedStrings))
151 return E;
152 const uint64_t CompressedStringLen = CompressedStrings.length();
153 // Record the length of compressed string.
154 OS.write(V: CompressedStringLen);
155 // Write the chars in compressed strings.
156 for (char C : CompressedStrings)
157 OS.writeByte(V: static_cast<uint8_t>(C));
158 // Pad up to a multiple of 8.
159 // InstrProfReader could read bytes according to 'CompressedStringLen'.
160 const uint64_t PaddedLength = alignTo(Value: CompressedStringLen, Align: 8);
161 for (uint64_t K = CompressedStringLen; K < PaddedLength; K++)
162 OS.writeByte(V: 0);
163 return Error::success();
164}
165
166uint64_t
167DataAccessProfData::getEncodedIndex(const SymbolHandleRef SymbolID) const {
168 if (std::holds_alternative<uint64_t>(v: SymbolID))
169 return std::get<uint64_t>(v: SymbolID);
170
171 auto Iter = StrToIndexMap.find(Key: std::get<StringRef>(v: SymbolID));
172 assert(Iter != StrToIndexMap.end() &&
173 "String literals not found in StrToIndexMap");
174 return Iter->second;
175}
176
177Error DataAccessProfData::serialize(ProfOStream &OS) const {
178 if (Error E = serializeSymbolsAndFilenames(OS))
179 return E;
180 OS.write(V: KnownColdHashes.size());
181 for (const auto &Hash : KnownColdHashes)
182 OS.write(V: Hash);
183 OS.write(V: (uint64_t)(Records.size()));
184 for (const auto &[Key, Rec] : Records) {
185 OS.write(V: getEncodedIndex(SymbolID: Rec.SymbolID));
186 OS.writeByte(V: Rec.IsStringLiteral);
187 OS.write(V: Rec.AccessCount);
188 OS.write(V: Rec.Locations.size());
189 for (const auto &Loc : Rec.Locations) {
190 OS.write(V: getEncodedIndex(SymbolID: Loc.FileName));
191 OS.write32(V: Loc.Line);
192 }
193 }
194 return Error::success();
195}
196
197Error DataAccessProfData::deserializeSymbolsAndFilenames(
198 const unsigned char *&Ptr, const uint64_t NumSampledSymbols,
199 const uint64_t NumColdKnownSymbols) {
200 uint64_t Len =
201 support::endian::readNext<uint64_t, llvm::endianness::little>(memory&: Ptr);
202
203 // The first NumSampledSymbols strings are symbols with samples, and next
204 // NumColdKnownSymbols strings are known cold symbols.
205 uint64_t StringCnt = 0;
206 std::function<Error(StringRef)> addName = [&](StringRef Name) {
207 if (StringCnt < NumSampledSymbols)
208 saveStringToMap(Map&: StrToIndexMap, Saver, Str: Name);
209 else
210 KnownColdSymbols.insert(X: Saver.save(S: Name));
211 ++StringCnt;
212 return Error::success();
213 };
214 if (Error E =
215 readAndDecodeStrings(NameStrings: StringRef((const char *)Ptr, Len), NameCallback: addName))
216 return E;
217
218 Ptr += alignTo(Value: Len, Align: 8);
219 return Error::success();
220}
221
222Error DataAccessProfData::deserializeRecords(const unsigned char *&Ptr) {
223 SmallVector<StringRef> Strings =
224 llvm::to_vector(Range: llvm::make_first_range(c: getStrToIndexMapRef()));
225
226 uint64_t NumRecords =
227 support::endian::readNext<uint64_t, llvm::endianness::little>(memory&: Ptr);
228
229 for (uint64_t I = 0; I < NumRecords; ++I) {
230 uint64_t ID =
231 support::endian::readNext<uint64_t, llvm::endianness::little>(memory&: Ptr);
232
233 bool IsStringLiteral =
234 support::endian::readNext<uint8_t, llvm::endianness::little>(memory&: Ptr);
235
236 uint64_t AccessCount =
237 support::endian::readNext<uint64_t, llvm::endianness::little>(memory&: Ptr);
238
239 SymbolHandleRef SymbolID;
240 if (IsStringLiteral)
241 SymbolID = ID;
242 else
243 SymbolID = Strings[ID];
244 if (Error E = setDataAccessProfile(Symbol: SymbolID, AccessCount))
245 return E;
246
247 auto &Record = Records.back().second;
248
249 uint64_t NumLocations =
250 support::endian::readNext<uint64_t, llvm::endianness::little>(memory&: Ptr);
251
252 Record.Locations.reserve(N: NumLocations);
253 for (uint64_t J = 0; J < NumLocations; ++J) {
254 uint64_t FileNameIndex =
255 support::endian::readNext<uint64_t, llvm::endianness::little>(memory&: Ptr);
256 uint32_t Line =
257 support::endian::readNext<uint32_t, llvm::endianness::little>(memory&: Ptr);
258 Record.Locations.push_back(Elt: {Strings[FileNameIndex], Line});
259 }
260 }
261 return Error::success();
262}
263} // namespace memprof
264} // namespace llvm
265