//===- OffloadBundle.cpp - Utilities for offload bundles---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------===//
8 | |
#include "llvm/Object/OffloadBundle.h"
#include "llvm/BinaryFormat/Magic.h"
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/MC/StringTableBuilder.h"
#include "llvm/Object/Archive.h"
#include "llvm/Object/Binary.h"
#include "llvm/Object/COFF.h"
#include "llvm/Object/ELFObjectFile.h"
#include "llvm/Object/Error.h"
#include "llvm/Object/IRObjectFile.h"
#include "llvm/Object/ObjectFile.h"
#include "llvm/Support/BinaryStreamReader.h"
#include "llvm/Support/Compression.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/Timer.h"
24 | |
25 | using namespace llvm; |
26 | using namespace llvm::object; |
27 | |
28 | static llvm::TimerGroup |
29 | OffloadBundlerTimerGroup("Offload Bundler Timer Group" , |
30 | "Timer group for offload bundler" ); |
31 | |
32 | // Extract an Offload bundle (usually a Offload Bundle) from a fat_bin |
33 | // section |
34 | Error (MemoryBufferRef Contents, uint64_t SectionOffset, |
35 | StringRef FileName, |
36 | SmallVectorImpl<OffloadBundleFatBin> &Bundles) { |
37 | |
38 | size_t Offset = 0; |
39 | size_t NextbundleStart = 0; |
40 | |
41 | // There could be multiple offloading bundles stored at this section. |
42 | while (NextbundleStart != StringRef::npos) { |
43 | std::unique_ptr<MemoryBuffer> Buffer = |
44 | MemoryBuffer::getMemBuffer(InputData: Contents.getBuffer().drop_front(N: Offset), BufferName: "" , |
45 | /*RequiresNullTerminator=*/false); |
46 | |
47 | // Create the FatBinBindle object. This will also create the Bundle Entry |
48 | // list info. |
49 | auto FatBundleOrErr = |
50 | OffloadBundleFatBin::create(*Buffer, SectionOffset: SectionOffset + Offset, FileName); |
51 | if (!FatBundleOrErr) |
52 | return FatBundleOrErr.takeError(); |
53 | |
54 | // Add current Bundle to list. |
55 | Bundles.emplace_back(Args: std::move(**FatBundleOrErr)); |
56 | |
57 | // Find the next bundle by searching for the magic string |
58 | StringRef Str = Buffer->getBuffer(); |
59 | NextbundleStart = Str.find(Str: StringRef("__CLANG_OFFLOAD_BUNDLE__" ), From: 24); |
60 | |
61 | if (NextbundleStart != StringRef::npos) |
62 | Offset += NextbundleStart; |
63 | } |
64 | |
65 | return Error::success(); |
66 | } |
67 | |
68 | Error OffloadBundleFatBin::readEntries(StringRef Buffer, |
69 | uint64_t SectionOffset) { |
70 | uint64_t NumOfEntries = 0; |
71 | |
72 | BinaryStreamReader Reader(Buffer, llvm::endianness::little); |
73 | |
74 | // Read the Magic String first. |
75 | StringRef Magic; |
76 | if (auto EC = Reader.readFixedString(Dest&: Magic, Length: 24)) |
77 | return errorCodeToError(EC: object_error::parse_failed); |
78 | |
79 | // Read the number of Code Objects (Entries) in the current Bundle. |
80 | if (auto EC = Reader.readInteger(Dest&: NumOfEntries)) |
81 | return errorCodeToError(EC: object_error::parse_failed); |
82 | |
83 | NumberOfEntries = NumOfEntries; |
84 | |
85 | // For each Bundle Entry (code object) |
86 | for (uint64_t I = 0; I < NumOfEntries; I++) { |
87 | uint64_t EntrySize; |
88 | uint64_t EntryOffset; |
89 | uint64_t EntryIDSize; |
90 | StringRef EntryID; |
91 | |
92 | if (auto EC = Reader.readInteger(Dest&: EntryOffset)) |
93 | return errorCodeToError(EC: object_error::parse_failed); |
94 | |
95 | if (auto EC = Reader.readInteger(Dest&: EntrySize)) |
96 | return errorCodeToError(EC: object_error::parse_failed); |
97 | |
98 | if (auto EC = Reader.readInteger(Dest&: EntryIDSize)) |
99 | return errorCodeToError(EC: object_error::parse_failed); |
100 | |
101 | if (auto EC = Reader.readFixedString(Dest&: EntryID, Length: EntryIDSize)) |
102 | return errorCodeToError(EC: object_error::parse_failed); |
103 | |
104 | auto Entry = std::make_unique<OffloadBundleEntry>( |
105 | args: EntryOffset + SectionOffset, args&: EntrySize, args&: EntryIDSize, args&: EntryID); |
106 | |
107 | Entries.push_back(Elt: *Entry); |
108 | } |
109 | |
110 | return Error::success(); |
111 | } |
112 | |
113 | Expected<std::unique_ptr<OffloadBundleFatBin>> |
114 | OffloadBundleFatBin::create(MemoryBufferRef Buf, uint64_t SectionOffset, |
115 | StringRef FileName) { |
116 | if (Buf.getBufferSize() < 24) |
117 | return errorCodeToError(EC: object_error::parse_failed); |
118 | |
119 | // Check for magic bytes. |
120 | if (identify_magic(magic: Buf.getBuffer()) != file_magic::offload_bundle) |
121 | return errorCodeToError(EC: object_error::parse_failed); |
122 | |
123 | OffloadBundleFatBin *TheBundle = new OffloadBundleFatBin(Buf, FileName); |
124 | |
125 | // Read the Bundle Entries |
126 | Error Err = TheBundle->readEntries(Buffer: Buf.getBuffer(), SectionOffset); |
127 | if (Err) |
128 | return errorCodeToError(EC: object_error::parse_failed); |
129 | |
130 | return std::unique_ptr<OffloadBundleFatBin>(TheBundle); |
131 | } |
132 | |
133 | Error OffloadBundleFatBin::(const ObjectFile &Source) { |
134 | // This will extract all entries in the Bundle |
135 | for (OffloadBundleEntry &Entry : Entries) { |
136 | |
137 | if (Entry.Size == 0) |
138 | continue; |
139 | |
140 | // create output file name. Which should be |
141 | // <fileName>-offset<Offset>-size<Size>.co" |
142 | std::string Str = getFileName().str() + "-offset" + itostr(X: Entry.Offset) + |
143 | "-size" + itostr(X: Entry.Size) + ".co" ; |
144 | if (Error Err = object::extractCodeObject(Source, Offset: Entry.Offset, Size: Entry.Size, |
145 | OutputFileName: StringRef(Str))) |
146 | return Err; |
147 | } |
148 | |
149 | return Error::success(); |
150 | } |
151 | |
152 | Error object::( |
153 | const ObjectFile &Obj, SmallVectorImpl<OffloadBundleFatBin> &Bundles) { |
154 | assert((Obj.isELF() || Obj.isCOFF()) && "Invalid file type" ); |
155 | |
156 | // Iterate through Sections until we find an offload_bundle section. |
157 | for (SectionRef Sec : Obj.sections()) { |
158 | Expected<StringRef> Buffer = Sec.getContents(); |
159 | if (!Buffer) |
160 | return Buffer.takeError(); |
161 | |
162 | // If it does not start with the reserved suffix, just skip this section. |
163 | if ((llvm::identify_magic(magic: *Buffer) == llvm::file_magic::offload_bundle) || |
164 | (llvm::identify_magic(magic: *Buffer) == |
165 | llvm::file_magic::offload_bundle_compressed)) { |
166 | |
167 | uint64_t SectionOffset = 0; |
168 | if (Obj.isELF()) { |
169 | SectionOffset = ELFSectionRef(Sec).getOffset(); |
170 | } else if (Obj.isCOFF()) // TODO: add COFF Support |
171 | return createStringError(EC: object_error::parse_failed, |
172 | S: "COFF object files not supported.\n" ); |
173 | |
174 | MemoryBufferRef Contents(*Buffer, Obj.getFileName()); |
175 | |
176 | if (llvm::identify_magic(magic: *Buffer) == |
177 | llvm::file_magic::offload_bundle_compressed) { |
178 | // Decompress the input if necessary. |
179 | Expected<std::unique_ptr<MemoryBuffer>> DecompressedBufferOrErr = |
180 | CompressedOffloadBundle::decompress(Input&: Contents, Verbose: false); |
181 | |
182 | if (!DecompressedBufferOrErr) |
183 | return createStringError( |
184 | EC: inconvertibleErrorCode(), |
185 | S: "Failed to decompress input: " + |
186 | llvm::toString(E: DecompressedBufferOrErr.takeError())); |
187 | |
188 | MemoryBuffer &DecompressedInput = **DecompressedBufferOrErr; |
189 | if (Error Err = extractOffloadBundle(Contents: DecompressedInput, SectionOffset, |
190 | FileName: Obj.getFileName(), Bundles)) |
191 | return Err; |
192 | } else { |
193 | if (Error Err = extractOffloadBundle(Contents, SectionOffset, |
194 | FileName: Obj.getFileName(), Bundles)) |
195 | return Err; |
196 | } |
197 | } |
198 | } |
199 | return Error::success(); |
200 | } |
201 | |
202 | Error object::(const ObjectFile &Source, int64_t Offset, |
203 | int64_t Size, StringRef OutputFileName) { |
204 | Expected<std::unique_ptr<FileOutputBuffer>> BufferOrErr = |
205 | FileOutputBuffer::create(FilePath: OutputFileName, Size); |
206 | |
207 | if (!BufferOrErr) |
208 | return BufferOrErr.takeError(); |
209 | |
210 | Expected<MemoryBufferRef> InputBuffOrErr = Source.getMemoryBufferRef(); |
211 | if (Error Err = InputBuffOrErr.takeError()) |
212 | return Err; |
213 | |
214 | std::unique_ptr<FileOutputBuffer> Buf = std::move(*BufferOrErr); |
215 | std::copy(first: InputBuffOrErr->getBufferStart() + Offset, |
216 | last: InputBuffOrErr->getBufferStart() + Offset + Size, |
217 | result: Buf->getBufferStart()); |
218 | if (Error E = Buf->commit()) |
219 | return E; |
220 | |
221 | return Error::success(); |
222 | } |
223 | |
224 | // given a file name, offset, and size, extract data into a code object file, |
225 | // into file <SourceFile>-offset<Offset>-size<Size>.co |
226 | Error object::(StringRef URIstr) { |
227 | // create a URI object |
228 | Expected<std::unique_ptr<OffloadBundleURI>> UriOrErr( |
229 | OffloadBundleURI::createOffloadBundleURI(Str: URIstr, Type: FILE_URI)); |
230 | if (!UriOrErr) |
231 | return UriOrErr.takeError(); |
232 | |
233 | OffloadBundleURI &Uri = **UriOrErr; |
234 | std::string OutputFile = Uri.FileName.str(); |
235 | OutputFile += |
236 | "-offset" + itostr(X: Uri.Offset) + "-size" + itostr(X: Uri.Size) + ".co" ; |
237 | |
238 | // Create an ObjectFile object from uri.file_uri |
239 | auto ObjOrErr = ObjectFile::createObjectFile(ObjectPath: Uri.FileName); |
240 | if (!ObjOrErr) |
241 | return ObjOrErr.takeError(); |
242 | |
243 | auto Obj = ObjOrErr->getBinary(); |
244 | if (Error Err = |
245 | object::extractCodeObject(Source: *Obj, Offset: Uri.Offset, Size: Uri.Size, OutputFileName: OutputFile)) |
246 | return Err; |
247 | |
248 | return Error::success(); |
249 | } |
250 | |
/// Format \p Value in decimal with a comma between every group of three
/// digits, counted from the right (e.g. 1234567 -> "1,234,567").
static std::string formatWithCommas(unsigned long long Value) {
  std::string Num = std::to_string(Value);
  // Walk from the least significant group towards the front. The position is
  // kept signed so that the loop ends once fewer than four digits remain.
  for (int InsertPosition = static_cast<int>(Num.length()) - 3;
       InsertPosition > 0; InsertPosition -= 3)
    Num.insert(InsertPosition, ",");
  return Num;
}
261 | |
262 | llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> |
263 | CompressedOffloadBundle::decompress(llvm::MemoryBufferRef &Input, |
264 | bool Verbose) { |
265 | StringRef Blob = Input.getBuffer(); |
266 | |
267 | if (Blob.size() < V1HeaderSize) |
268 | return llvm::MemoryBuffer::getMemBufferCopy(InputData: Blob); |
269 | |
270 | if (llvm::identify_magic(magic: Blob) != |
271 | llvm::file_magic::offload_bundle_compressed) { |
272 | if (Verbose) |
273 | llvm::errs() << "Uncompressed bundle.\n" ; |
274 | return llvm::MemoryBuffer::getMemBufferCopy(InputData: Blob); |
275 | } |
276 | |
277 | size_t CurrentOffset = MagicSize; |
278 | |
279 | uint16_t ThisVersion; |
280 | memcpy(dest: &ThisVersion, src: Blob.data() + CurrentOffset, n: sizeof(uint16_t)); |
281 | CurrentOffset += VersionFieldSize; |
282 | |
283 | uint16_t CompressionMethod; |
284 | memcpy(dest: &CompressionMethod, src: Blob.data() + CurrentOffset, n: sizeof(uint16_t)); |
285 | CurrentOffset += MethodFieldSize; |
286 | |
287 | uint32_t TotalFileSize; |
288 | if (ThisVersion >= 2) { |
289 | if (Blob.size() < V2HeaderSize) |
290 | return createStringError(EC: inconvertibleErrorCode(), |
291 | S: "Compressed bundle header size too small" ); |
292 | memcpy(dest: &TotalFileSize, src: Blob.data() + CurrentOffset, n: sizeof(uint32_t)); |
293 | CurrentOffset += FileSizeFieldSize; |
294 | } |
295 | |
296 | uint32_t UncompressedSize; |
297 | memcpy(dest: &UncompressedSize, src: Blob.data() + CurrentOffset, n: sizeof(uint32_t)); |
298 | CurrentOffset += UncompressedSizeFieldSize; |
299 | |
300 | uint64_t StoredHash; |
301 | memcpy(dest: &StoredHash, src: Blob.data() + CurrentOffset, n: sizeof(uint64_t)); |
302 | CurrentOffset += HashFieldSize; |
303 | |
304 | llvm::compression::Format CompressionFormat; |
305 | if (CompressionMethod == |
306 | static_cast<uint16_t>(llvm::compression::Format::Zlib)) |
307 | CompressionFormat = llvm::compression::Format::Zlib; |
308 | else if (CompressionMethod == |
309 | static_cast<uint16_t>(llvm::compression::Format::Zstd)) |
310 | CompressionFormat = llvm::compression::Format::Zstd; |
311 | else |
312 | return createStringError(EC: inconvertibleErrorCode(), |
313 | S: "Unknown compressing method" ); |
314 | |
315 | llvm::Timer DecompressTimer("Decompression Timer" , "Decompression time" , |
316 | OffloadBundlerTimerGroup); |
317 | if (Verbose) |
318 | DecompressTimer.startTimer(); |
319 | |
320 | SmallVector<uint8_t, 0> DecompressedData; |
321 | StringRef CompressedData = Blob.substr(Start: CurrentOffset); |
322 | if (llvm::Error DecompressionError = llvm::compression::decompress( |
323 | F: CompressionFormat, Input: llvm::arrayRefFromStringRef(Input: CompressedData), |
324 | Output&: DecompressedData, UncompressedSize)) |
325 | return createStringError(EC: inconvertibleErrorCode(), |
326 | S: "Could not decompress embedded file contents: " + |
327 | llvm::toString(E: std::move(DecompressionError))); |
328 | |
329 | if (Verbose) { |
330 | DecompressTimer.stopTimer(); |
331 | |
332 | double DecompressionTimeSeconds = |
333 | DecompressTimer.getTotalTime().getWallTime(); |
334 | |
335 | // Recalculate MD5 hash for integrity check. |
336 | llvm::Timer HashRecalcTimer("Hash Recalculation Timer" , |
337 | "Hash recalculation time" , |
338 | OffloadBundlerTimerGroup); |
339 | HashRecalcTimer.startTimer(); |
340 | llvm::MD5 Hash; |
341 | llvm::MD5::MD5Result Result; |
342 | Hash.update(Data: llvm::ArrayRef<uint8_t>(DecompressedData)); |
343 | Hash.final(Result); |
344 | uint64_t RecalculatedHash = Result.low(); |
345 | HashRecalcTimer.stopTimer(); |
346 | bool HashMatch = (StoredHash == RecalculatedHash); |
347 | |
348 | double CompressionRate = |
349 | static_cast<double>(UncompressedSize) / CompressedData.size(); |
350 | double DecompressionSpeedMBs = |
351 | (UncompressedSize / (1024.0 * 1024.0)) / DecompressionTimeSeconds; |
352 | |
353 | llvm::errs() << "Compressed bundle format version: " << ThisVersion << "\n" ; |
354 | if (ThisVersion >= 2) |
355 | llvm::errs() << "Total file size (from header): " |
356 | << formatWithCommas(Value: TotalFileSize) << " bytes\n" ; |
357 | llvm::errs() << "Decompression method: " |
358 | << (CompressionFormat == llvm::compression::Format::Zlib |
359 | ? "zlib" |
360 | : "zstd" ) |
361 | << "\n" |
362 | << "Size before decompression: " |
363 | << formatWithCommas(Value: CompressedData.size()) << " bytes\n" |
364 | << "Size after decompression: " |
365 | << formatWithCommas(Value: UncompressedSize) << " bytes\n" |
366 | << "Compression rate: " |
367 | << llvm::format(Fmt: "%.2lf" , Vals: CompressionRate) << "\n" |
368 | << "Compression ratio: " |
369 | << llvm::format(Fmt: "%.2lf%%" , Vals: 100.0 / CompressionRate) << "\n" |
370 | << "Decompression speed: " |
371 | << llvm::format(Fmt: "%.2lf MB/s" , Vals: DecompressionSpeedMBs) << "\n" |
372 | << "Stored hash: " << llvm::format_hex(N: StoredHash, Width: 16) << "\n" |
373 | << "Recalculated hash: " |
374 | << llvm::format_hex(N: RecalculatedHash, Width: 16) << "\n" |
375 | << "Hashes match: " << (HashMatch ? "Yes" : "No" ) << "\n" ; |
376 | } |
377 | |
378 | return llvm::MemoryBuffer::getMemBufferCopy( |
379 | InputData: llvm::toStringRef(Input: DecompressedData)); |
380 | } |
381 | |
382 | llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> |
383 | CompressedOffloadBundle::compress(llvm::compression::Params P, |
384 | const llvm::MemoryBuffer &Input, |
385 | bool Verbose) { |
386 | if (!llvm::compression::zstd::isAvailable() && |
387 | !llvm::compression::zlib::isAvailable()) |
388 | return createStringError(EC: llvm::inconvertibleErrorCode(), |
389 | S: "Compression not supported" ); |
390 | |
391 | llvm::Timer HashTimer("Hash Calculation Timer" , "Hash calculation time" , |
392 | OffloadBundlerTimerGroup); |
393 | if (Verbose) |
394 | HashTimer.startTimer(); |
395 | llvm::MD5 Hash; |
396 | llvm::MD5::MD5Result Result; |
397 | Hash.update(Str: Input.getBuffer()); |
398 | Hash.final(Result); |
399 | uint64_t TruncatedHash = Result.low(); |
400 | if (Verbose) |
401 | HashTimer.stopTimer(); |
402 | |
403 | SmallVector<uint8_t, 0> CompressedBuffer; |
404 | auto BufferUint8 = llvm::ArrayRef<uint8_t>( |
405 | reinterpret_cast<const uint8_t *>(Input.getBuffer().data()), |
406 | Input.getBuffer().size()); |
407 | |
408 | llvm::Timer CompressTimer("Compression Timer" , "Compression time" , |
409 | OffloadBundlerTimerGroup); |
410 | if (Verbose) |
411 | CompressTimer.startTimer(); |
412 | llvm::compression::compress(P, Input: BufferUint8, Output&: CompressedBuffer); |
413 | if (Verbose) |
414 | CompressTimer.stopTimer(); |
415 | |
416 | uint16_t CompressionMethod = static_cast<uint16_t>(P.format); |
417 | uint32_t UncompressedSize = Input.getBuffer().size(); |
418 | uint32_t TotalFileSize = MagicNumber.size() + sizeof(TotalFileSize) + |
419 | sizeof(Version) + sizeof(CompressionMethod) + |
420 | sizeof(UncompressedSize) + sizeof(TruncatedHash) + |
421 | CompressedBuffer.size(); |
422 | |
423 | SmallVector<char, 0> FinalBuffer; |
424 | llvm::raw_svector_ostream OS(FinalBuffer); |
425 | OS << MagicNumber; |
426 | OS.write(Ptr: reinterpret_cast<const char *>(&Version), Size: sizeof(Version)); |
427 | OS.write(Ptr: reinterpret_cast<const char *>(&CompressionMethod), |
428 | Size: sizeof(CompressionMethod)); |
429 | OS.write(Ptr: reinterpret_cast<const char *>(&TotalFileSize), |
430 | Size: sizeof(TotalFileSize)); |
431 | OS.write(Ptr: reinterpret_cast<const char *>(&UncompressedSize), |
432 | Size: sizeof(UncompressedSize)); |
433 | OS.write(Ptr: reinterpret_cast<const char *>(&TruncatedHash), |
434 | Size: sizeof(TruncatedHash)); |
435 | OS.write(Ptr: reinterpret_cast<const char *>(CompressedBuffer.data()), |
436 | Size: CompressedBuffer.size()); |
437 | |
438 | if (Verbose) { |
439 | auto MethodUsed = |
440 | P.format == llvm::compression::Format::Zstd ? "zstd" : "zlib" ; |
441 | double CompressionRate = |
442 | static_cast<double>(UncompressedSize) / CompressedBuffer.size(); |
443 | double CompressionTimeSeconds = CompressTimer.getTotalTime().getWallTime(); |
444 | double CompressionSpeedMBs = |
445 | (UncompressedSize / (1024.0 * 1024.0)) / CompressionTimeSeconds; |
446 | |
447 | llvm::errs() << "Compressed bundle format version: " << Version << "\n" |
448 | << "Total file size (including headers): " |
449 | << formatWithCommas(Value: TotalFileSize) << " bytes\n" |
450 | << "Compression method used: " << MethodUsed << "\n" |
451 | << "Compression level: " << P.level << "\n" |
452 | << "Binary size before compression: " |
453 | << formatWithCommas(Value: UncompressedSize) << " bytes\n" |
454 | << "Binary size after compression: " |
455 | << formatWithCommas(Value: CompressedBuffer.size()) << " bytes\n" |
456 | << "Compression rate: " |
457 | << llvm::format(Fmt: "%.2lf" , Vals: CompressionRate) << "\n" |
458 | << "Compression ratio: " |
459 | << llvm::format(Fmt: "%.2lf%%" , Vals: 100.0 / CompressionRate) << "\n" |
460 | << "Compression speed: " |
461 | << llvm::format(Fmt: "%.2lf MB/s" , Vals: CompressionSpeedMBs) << "\n" |
462 | << "Truncated MD5 hash: " |
463 | << llvm::format_hex(N: TruncatedHash, Width: 16) << "\n" ; |
464 | } |
465 | return llvm::MemoryBuffer::getMemBufferCopy( |
466 | InputData: llvm::StringRef(FinalBuffer.data(), FinalBuffer.size())); |
467 | } |
468 | |