1//===- OffloadBundle.cpp - Utilities for offload bundles---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------===//
8
#include "llvm/Object/OffloadBundle.h"
#include "llvm/BinaryFormat/Magic.h"
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/MC/StringTableBuilder.h"
#include "llvm/Object/Archive.h"
#include "llvm/Object/Binary.h"
#include "llvm/Object/COFF.h"
#include "llvm/Object/ELFObjectFile.h"
#include "llvm/Object/Error.h"
#include "llvm/Object/IRObjectFile.h"
#include "llvm/Object/ObjectFile.h"
#include "llvm/Support/BinaryStreamReader.h"
#include "llvm/Support/Compression.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/Timer.h"

#include <limits>
24
25using namespace llvm;
26using namespace llvm::object;
27
28static llvm::TimerGroup
29 OffloadBundlerTimerGroup("Offload Bundler Timer Group",
30 "Timer group for offload bundler");
31
32// Extract an Offload bundle (usually a Offload Bundle) from a fat_bin
33// section
34Error extractOffloadBundle(MemoryBufferRef Contents, uint64_t SectionOffset,
35 StringRef FileName,
36 SmallVectorImpl<OffloadBundleFatBin> &Bundles) {
37
38 size_t Offset = 0;
39 size_t NextbundleStart = 0;
40
41 // There could be multiple offloading bundles stored at this section.
42 while (NextbundleStart != StringRef::npos) {
43 std::unique_ptr<MemoryBuffer> Buffer =
44 MemoryBuffer::getMemBuffer(InputData: Contents.getBuffer().drop_front(N: Offset), BufferName: "",
45 /*RequiresNullTerminator=*/false);
46
47 // Create the FatBinBindle object. This will also create the Bundle Entry
48 // list info.
49 auto FatBundleOrErr =
50 OffloadBundleFatBin::create(*Buffer, SectionOffset: SectionOffset + Offset, FileName);
51 if (!FatBundleOrErr)
52 return FatBundleOrErr.takeError();
53
54 // Add current Bundle to list.
55 Bundles.emplace_back(Args: std::move(**FatBundleOrErr));
56
57 // Find the next bundle by searching for the magic string
58 StringRef Str = Buffer->getBuffer();
59 NextbundleStart = Str.find(Str: StringRef("__CLANG_OFFLOAD_BUNDLE__"), From: 24);
60
61 if (NextbundleStart != StringRef::npos)
62 Offset += NextbundleStart;
63 }
64
65 return Error::success();
66}
67
68Error OffloadBundleFatBin::readEntries(StringRef Buffer,
69 uint64_t SectionOffset) {
70 uint64_t NumOfEntries = 0;
71
72 BinaryStreamReader Reader(Buffer, llvm::endianness::little);
73
74 // Read the Magic String first.
75 StringRef Magic;
76 if (auto EC = Reader.readFixedString(Dest&: Magic, Length: 24))
77 return errorCodeToError(EC: object_error::parse_failed);
78
79 // Read the number of Code Objects (Entries) in the current Bundle.
80 if (auto EC = Reader.readInteger(Dest&: NumOfEntries))
81 return errorCodeToError(EC: object_error::parse_failed);
82
83 NumberOfEntries = NumOfEntries;
84
85 // For each Bundle Entry (code object)
86 for (uint64_t I = 0; I < NumOfEntries; I++) {
87 uint64_t EntrySize;
88 uint64_t EntryOffset;
89 uint64_t EntryIDSize;
90 StringRef EntryID;
91
92 if (auto EC = Reader.readInteger(Dest&: EntryOffset))
93 return errorCodeToError(EC: object_error::parse_failed);
94
95 if (auto EC = Reader.readInteger(Dest&: EntrySize))
96 return errorCodeToError(EC: object_error::parse_failed);
97
98 if (auto EC = Reader.readInteger(Dest&: EntryIDSize))
99 return errorCodeToError(EC: object_error::parse_failed);
100
101 if (auto EC = Reader.readFixedString(Dest&: EntryID, Length: EntryIDSize))
102 return errorCodeToError(EC: object_error::parse_failed);
103
104 auto Entry = std::make_unique<OffloadBundleEntry>(
105 args: EntryOffset + SectionOffset, args&: EntrySize, args&: EntryIDSize, args&: EntryID);
106
107 Entries.push_back(Elt: *Entry);
108 }
109
110 return Error::success();
111}
112
113Expected<std::unique_ptr<OffloadBundleFatBin>>
114OffloadBundleFatBin::create(MemoryBufferRef Buf, uint64_t SectionOffset,
115 StringRef FileName) {
116 if (Buf.getBufferSize() < 24)
117 return errorCodeToError(EC: object_error::parse_failed);
118
119 // Check for magic bytes.
120 if (identify_magic(magic: Buf.getBuffer()) != file_magic::offload_bundle)
121 return errorCodeToError(EC: object_error::parse_failed);
122
123 OffloadBundleFatBin *TheBundle = new OffloadBundleFatBin(Buf, FileName);
124
125 // Read the Bundle Entries
126 Error Err = TheBundle->readEntries(Buffer: Buf.getBuffer(), SectionOffset);
127 if (Err)
128 return errorCodeToError(EC: object_error::parse_failed);
129
130 return std::unique_ptr<OffloadBundleFatBin>(TheBundle);
131}
132
133Error OffloadBundleFatBin::extractBundle(const ObjectFile &Source) {
134 // This will extract all entries in the Bundle
135 for (OffloadBundleEntry &Entry : Entries) {
136
137 if (Entry.Size == 0)
138 continue;
139
140 // create output file name. Which should be
141 // <fileName>-offset<Offset>-size<Size>.co"
142 std::string Str = getFileName().str() + "-offset" + itostr(X: Entry.Offset) +
143 "-size" + itostr(X: Entry.Size) + ".co";
144 if (Error Err = object::extractCodeObject(Source, Offset: Entry.Offset, Size: Entry.Size,
145 OutputFileName: StringRef(Str)))
146 return Err;
147 }
148
149 return Error::success();
150}
151
152Error object::extractOffloadBundleFatBinary(
153 const ObjectFile &Obj, SmallVectorImpl<OffloadBundleFatBin> &Bundles) {
154 assert((Obj.isELF() || Obj.isCOFF()) && "Invalid file type");
155
156 // Iterate through Sections until we find an offload_bundle section.
157 for (SectionRef Sec : Obj.sections()) {
158 Expected<StringRef> Buffer = Sec.getContents();
159 if (!Buffer)
160 return Buffer.takeError();
161
162 // If it does not start with the reserved suffix, just skip this section.
163 if ((llvm::identify_magic(magic: *Buffer) == llvm::file_magic::offload_bundle) ||
164 (llvm::identify_magic(magic: *Buffer) ==
165 llvm::file_magic::offload_bundle_compressed)) {
166
167 uint64_t SectionOffset = 0;
168 if (Obj.isELF()) {
169 SectionOffset = ELFSectionRef(Sec).getOffset();
170 } else if (Obj.isCOFF()) // TODO: add COFF Support
171 return createStringError(EC: object_error::parse_failed,
172 S: "COFF object files not supported.\n");
173
174 MemoryBufferRef Contents(*Buffer, Obj.getFileName());
175
176 if (llvm::identify_magic(magic: *Buffer) ==
177 llvm::file_magic::offload_bundle_compressed) {
178 // Decompress the input if necessary.
179 Expected<std::unique_ptr<MemoryBuffer>> DecompressedBufferOrErr =
180 CompressedOffloadBundle::decompress(Input&: Contents, Verbose: false);
181
182 if (!DecompressedBufferOrErr)
183 return createStringError(
184 EC: inconvertibleErrorCode(),
185 S: "Failed to decompress input: " +
186 llvm::toString(E: DecompressedBufferOrErr.takeError()));
187
188 MemoryBuffer &DecompressedInput = **DecompressedBufferOrErr;
189 if (Error Err = extractOffloadBundle(Contents: DecompressedInput, SectionOffset,
190 FileName: Obj.getFileName(), Bundles))
191 return Err;
192 } else {
193 if (Error Err = extractOffloadBundle(Contents, SectionOffset,
194 FileName: Obj.getFileName(), Bundles))
195 return Err;
196 }
197 }
198 }
199 return Error::success();
200}
201
202Error object::extractCodeObject(const ObjectFile &Source, int64_t Offset,
203 int64_t Size, StringRef OutputFileName) {
204 Expected<std::unique_ptr<FileOutputBuffer>> BufferOrErr =
205 FileOutputBuffer::create(FilePath: OutputFileName, Size);
206
207 if (!BufferOrErr)
208 return BufferOrErr.takeError();
209
210 Expected<MemoryBufferRef> InputBuffOrErr = Source.getMemoryBufferRef();
211 if (Error Err = InputBuffOrErr.takeError())
212 return Err;
213
214 std::unique_ptr<FileOutputBuffer> Buf = std::move(*BufferOrErr);
215 std::copy(first: InputBuffOrErr->getBufferStart() + Offset,
216 last: InputBuffOrErr->getBufferStart() + Offset + Size,
217 result: Buf->getBufferStart());
218 if (Error E = Buf->commit())
219 return E;
220
221 return Error::success();
222}
223
224// given a file name, offset, and size, extract data into a code object file,
225// into file <SourceFile>-offset<Offset>-size<Size>.co
226Error object::extractOffloadBundleByURI(StringRef URIstr) {
227 // create a URI object
228 Expected<std::unique_ptr<OffloadBundleURI>> UriOrErr(
229 OffloadBundleURI::createOffloadBundleURI(Str: URIstr, Type: FILE_URI));
230 if (!UriOrErr)
231 return UriOrErr.takeError();
232
233 OffloadBundleURI &Uri = **UriOrErr;
234 std::string OutputFile = Uri.FileName.str();
235 OutputFile +=
236 "-offset" + itostr(X: Uri.Offset) + "-size" + itostr(X: Uri.Size) + ".co";
237
238 // Create an ObjectFile object from uri.file_uri
239 auto ObjOrErr = ObjectFile::createObjectFile(ObjectPath: Uri.FileName);
240 if (!ObjOrErr)
241 return ObjOrErr.takeError();
242
243 auto Obj = ObjOrErr->getBinary();
244 if (Error Err =
245 object::extractCodeObject(Source: *Obj, Offset: Uri.Offset, Size: Uri.Size, OutputFileName: OutputFile))
246 return Err;
247
248 return Error::success();
249}
250
// Render \p Value in decimal, with a comma between each group of three digits
// counted from the right (e.g. 1234567 -> "1,234,567").
static std::string formatWithCommas(unsigned long long Value) {
  const std::string Digits = std::to_string(Value);
  std::string Result;
  Result.reserve(Digits.size() + Digits.size() / 3);
  for (size_t I = 0, E = Digits.size(); I != E; ++I) {
    // A separator goes before every digit that opens a new three-digit group
    // (counting from the right), but never before the leading digit.
    if (I != 0 && (E - I) % 3 == 0)
      Result += ',';
    Result += Digits[I];
  }
  return Result;
}
261
262llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
263CompressedOffloadBundle::decompress(llvm::MemoryBufferRef &Input,
264 bool Verbose) {
265 StringRef Blob = Input.getBuffer();
266
267 if (Blob.size() < V1HeaderSize)
268 return llvm::MemoryBuffer::getMemBufferCopy(InputData: Blob);
269
270 if (llvm::identify_magic(magic: Blob) !=
271 llvm::file_magic::offload_bundle_compressed) {
272 if (Verbose)
273 llvm::errs() << "Uncompressed bundle.\n";
274 return llvm::MemoryBuffer::getMemBufferCopy(InputData: Blob);
275 }
276
277 size_t CurrentOffset = MagicSize;
278
279 uint16_t ThisVersion;
280 memcpy(dest: &ThisVersion, src: Blob.data() + CurrentOffset, n: sizeof(uint16_t));
281 CurrentOffset += VersionFieldSize;
282
283 uint16_t CompressionMethod;
284 memcpy(dest: &CompressionMethod, src: Blob.data() + CurrentOffset, n: sizeof(uint16_t));
285 CurrentOffset += MethodFieldSize;
286
287 uint32_t TotalFileSize;
288 if (ThisVersion >= 2) {
289 if (Blob.size() < V2HeaderSize)
290 return createStringError(EC: inconvertibleErrorCode(),
291 S: "Compressed bundle header size too small");
292 memcpy(dest: &TotalFileSize, src: Blob.data() + CurrentOffset, n: sizeof(uint32_t));
293 CurrentOffset += FileSizeFieldSize;
294 }
295
296 uint32_t UncompressedSize;
297 memcpy(dest: &UncompressedSize, src: Blob.data() + CurrentOffset, n: sizeof(uint32_t));
298 CurrentOffset += UncompressedSizeFieldSize;
299
300 uint64_t StoredHash;
301 memcpy(dest: &StoredHash, src: Blob.data() + CurrentOffset, n: sizeof(uint64_t));
302 CurrentOffset += HashFieldSize;
303
304 llvm::compression::Format CompressionFormat;
305 if (CompressionMethod ==
306 static_cast<uint16_t>(llvm::compression::Format::Zlib))
307 CompressionFormat = llvm::compression::Format::Zlib;
308 else if (CompressionMethod ==
309 static_cast<uint16_t>(llvm::compression::Format::Zstd))
310 CompressionFormat = llvm::compression::Format::Zstd;
311 else
312 return createStringError(EC: inconvertibleErrorCode(),
313 S: "Unknown compressing method");
314
315 llvm::Timer DecompressTimer("Decompression Timer", "Decompression time",
316 OffloadBundlerTimerGroup);
317 if (Verbose)
318 DecompressTimer.startTimer();
319
320 SmallVector<uint8_t, 0> DecompressedData;
321 StringRef CompressedData = Blob.substr(Start: CurrentOffset);
322 if (llvm::Error DecompressionError = llvm::compression::decompress(
323 F: CompressionFormat, Input: llvm::arrayRefFromStringRef(Input: CompressedData),
324 Output&: DecompressedData, UncompressedSize))
325 return createStringError(EC: inconvertibleErrorCode(),
326 S: "Could not decompress embedded file contents: " +
327 llvm::toString(E: std::move(DecompressionError)));
328
329 if (Verbose) {
330 DecompressTimer.stopTimer();
331
332 double DecompressionTimeSeconds =
333 DecompressTimer.getTotalTime().getWallTime();
334
335 // Recalculate MD5 hash for integrity check.
336 llvm::Timer HashRecalcTimer("Hash Recalculation Timer",
337 "Hash recalculation time",
338 OffloadBundlerTimerGroup);
339 HashRecalcTimer.startTimer();
340 llvm::MD5 Hash;
341 llvm::MD5::MD5Result Result;
342 Hash.update(Data: llvm::ArrayRef<uint8_t>(DecompressedData));
343 Hash.final(Result);
344 uint64_t RecalculatedHash = Result.low();
345 HashRecalcTimer.stopTimer();
346 bool HashMatch = (StoredHash == RecalculatedHash);
347
348 double CompressionRate =
349 static_cast<double>(UncompressedSize) / CompressedData.size();
350 double DecompressionSpeedMBs =
351 (UncompressedSize / (1024.0 * 1024.0)) / DecompressionTimeSeconds;
352
353 llvm::errs() << "Compressed bundle format version: " << ThisVersion << "\n";
354 if (ThisVersion >= 2)
355 llvm::errs() << "Total file size (from header): "
356 << formatWithCommas(Value: TotalFileSize) << " bytes\n";
357 llvm::errs() << "Decompression method: "
358 << (CompressionFormat == llvm::compression::Format::Zlib
359 ? "zlib"
360 : "zstd")
361 << "\n"
362 << "Size before decompression: "
363 << formatWithCommas(Value: CompressedData.size()) << " bytes\n"
364 << "Size after decompression: "
365 << formatWithCommas(Value: UncompressedSize) << " bytes\n"
366 << "Compression rate: "
367 << llvm::format(Fmt: "%.2lf", Vals: CompressionRate) << "\n"
368 << "Compression ratio: "
369 << llvm::format(Fmt: "%.2lf%%", Vals: 100.0 / CompressionRate) << "\n"
370 << "Decompression speed: "
371 << llvm::format(Fmt: "%.2lf MB/s", Vals: DecompressionSpeedMBs) << "\n"
372 << "Stored hash: " << llvm::format_hex(N: StoredHash, Width: 16) << "\n"
373 << "Recalculated hash: "
374 << llvm::format_hex(N: RecalculatedHash, Width: 16) << "\n"
375 << "Hashes match: " << (HashMatch ? "Yes" : "No") << "\n";
376 }
377
378 return llvm::MemoryBuffer::getMemBufferCopy(
379 InputData: llvm::toStringRef(Input: DecompressedData));
380}
381
382llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
383CompressedOffloadBundle::compress(llvm::compression::Params P,
384 const llvm::MemoryBuffer &Input,
385 bool Verbose) {
386 if (!llvm::compression::zstd::isAvailable() &&
387 !llvm::compression::zlib::isAvailable())
388 return createStringError(EC: llvm::inconvertibleErrorCode(),
389 S: "Compression not supported");
390
391 llvm::Timer HashTimer("Hash Calculation Timer", "Hash calculation time",
392 OffloadBundlerTimerGroup);
393 if (Verbose)
394 HashTimer.startTimer();
395 llvm::MD5 Hash;
396 llvm::MD5::MD5Result Result;
397 Hash.update(Str: Input.getBuffer());
398 Hash.final(Result);
399 uint64_t TruncatedHash = Result.low();
400 if (Verbose)
401 HashTimer.stopTimer();
402
403 SmallVector<uint8_t, 0> CompressedBuffer;
404 auto BufferUint8 = llvm::ArrayRef<uint8_t>(
405 reinterpret_cast<const uint8_t *>(Input.getBuffer().data()),
406 Input.getBuffer().size());
407
408 llvm::Timer CompressTimer("Compression Timer", "Compression time",
409 OffloadBundlerTimerGroup);
410 if (Verbose)
411 CompressTimer.startTimer();
412 llvm::compression::compress(P, Input: BufferUint8, Output&: CompressedBuffer);
413 if (Verbose)
414 CompressTimer.stopTimer();
415
416 uint16_t CompressionMethod = static_cast<uint16_t>(P.format);
417 uint32_t UncompressedSize = Input.getBuffer().size();
418 uint32_t TotalFileSize = MagicNumber.size() + sizeof(TotalFileSize) +
419 sizeof(Version) + sizeof(CompressionMethod) +
420 sizeof(UncompressedSize) + sizeof(TruncatedHash) +
421 CompressedBuffer.size();
422
423 SmallVector<char, 0> FinalBuffer;
424 llvm::raw_svector_ostream OS(FinalBuffer);
425 OS << MagicNumber;
426 OS.write(Ptr: reinterpret_cast<const char *>(&Version), Size: sizeof(Version));
427 OS.write(Ptr: reinterpret_cast<const char *>(&CompressionMethod),
428 Size: sizeof(CompressionMethod));
429 OS.write(Ptr: reinterpret_cast<const char *>(&TotalFileSize),
430 Size: sizeof(TotalFileSize));
431 OS.write(Ptr: reinterpret_cast<const char *>(&UncompressedSize),
432 Size: sizeof(UncompressedSize));
433 OS.write(Ptr: reinterpret_cast<const char *>(&TruncatedHash),
434 Size: sizeof(TruncatedHash));
435 OS.write(Ptr: reinterpret_cast<const char *>(CompressedBuffer.data()),
436 Size: CompressedBuffer.size());
437
438 if (Verbose) {
439 auto MethodUsed =
440 P.format == llvm::compression::Format::Zstd ? "zstd" : "zlib";
441 double CompressionRate =
442 static_cast<double>(UncompressedSize) / CompressedBuffer.size();
443 double CompressionTimeSeconds = CompressTimer.getTotalTime().getWallTime();
444 double CompressionSpeedMBs =
445 (UncompressedSize / (1024.0 * 1024.0)) / CompressionTimeSeconds;
446
447 llvm::errs() << "Compressed bundle format version: " << Version << "\n"
448 << "Total file size (including headers): "
449 << formatWithCommas(Value: TotalFileSize) << " bytes\n"
450 << "Compression method used: " << MethodUsed << "\n"
451 << "Compression level: " << P.level << "\n"
452 << "Binary size before compression: "
453 << formatWithCommas(Value: UncompressedSize) << " bytes\n"
454 << "Binary size after compression: "
455 << formatWithCommas(Value: CompressedBuffer.size()) << " bytes\n"
456 << "Compression rate: "
457 << llvm::format(Fmt: "%.2lf", Vals: CompressionRate) << "\n"
458 << "Compression ratio: "
459 << llvm::format(Fmt: "%.2lf%%", Vals: 100.0 / CompressionRate) << "\n"
460 << "Compression speed: "
461 << llvm::format(Fmt: "%.2lf MB/s", Vals: CompressionSpeedMBs) << "\n"
462 << "Truncated MD5 hash: "
463 << llvm::format_hex(N: TruncatedHash, Width: 16) << "\n";
464 }
465 return llvm::MemoryBuffer::getMemBufferCopy(
466 InputData: llvm::StringRef(FinalBuffer.data(), FinalBuffer.size()));
467}
468