1//===-- TestModuleFileExtension.cpp - Module Extension Tester -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
#include "TestModuleFileExtension.h"
#include "clang/Frontend/FrontendDiagnostic.h"
#include "clang/Serialization/ASTReader.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/Bitstream/BitstreamWriter.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdio>
#include <memory>
15using namespace clang;
16using namespace clang::serialization;
17
18char TestModuleFileExtension::ID = 0;
19
20TestModuleFileExtension::Writer::~Writer() { }
21
22void TestModuleFileExtension::Writer::writeExtensionContents(
23 Sema &SemaRef,
24 llvm::BitstreamWriter &Stream) {
25 using namespace llvm;
26
27 // Write an abbreviation for this record.
28 auto Abv = std::make_shared<llvm::BitCodeAbbrev>();
29 Abv->Add(OpInfo: BitCodeAbbrevOp(FIRST_EXTENSION_RECORD_ID));
30 Abv->Add(OpInfo: BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 6)); // # of characters
31 Abv->Add(OpInfo: BitCodeAbbrevOp(BitCodeAbbrevOp::Blob)); // message
32 auto Abbrev = Stream.EmitAbbrev(Abbv: std::move(Abv));
33
34 // Write a message into the extension block.
35 SmallString<64> Message;
36 {
37 auto Ext = static_cast<TestModuleFileExtension *>(getExtension());
38 raw_svector_ostream OS(Message);
39 OS << "Hello from " << Ext->BlockName << " v" << Ext->MajorVersion << "."
40 << Ext->MinorVersion;
41 }
42 uint64_t Record[] = {FIRST_EXTENSION_RECORD_ID, Message.size()};
43 Stream.EmitRecordWithBlob(Abbrev, Vals: Record, Blob: Message);
44}
45
46TestModuleFileExtension::Reader::Reader(ModuleFileExtension *Ext,
47 const llvm::BitstreamCursor &InStream)
48 : ModuleFileExtensionReader(Ext), Stream(InStream)
49{
50 // Read the extension block.
51 SmallVector<uint64_t, 4> Record;
52 while (true) {
53 llvm::Expected<llvm::BitstreamEntry> MaybeEntry =
54 Stream.advanceSkippingSubblocks();
55 if (!MaybeEntry)
56 (void)MaybeEntry.takeError();
57 llvm::BitstreamEntry Entry = MaybeEntry.get();
58
59 switch (Entry.Kind) {
60 case llvm::BitstreamEntry::SubBlock:
61 case llvm::BitstreamEntry::EndBlock:
62 case llvm::BitstreamEntry::Error:
63 return;
64
65 case llvm::BitstreamEntry::Record:
66 break;
67 }
68
69 Record.clear();
70 StringRef Blob;
71 Expected<unsigned> MaybeRecCode =
72 Stream.readRecord(AbbrevID: Entry.ID, Vals&: Record, Blob: &Blob);
73 if (!MaybeRecCode)
74 fprintf(stderr, format: "Failed reading rec code: %s\n",
75 toString(E: MaybeRecCode.takeError()).c_str());
76 switch (MaybeRecCode.get()) {
77 case FIRST_EXTENSION_RECORD_ID: {
78 StringRef Message = Blob.substr(Start: 0, N: Record[0]);
79 fprintf(stderr, format: "Read extension block message: %s\n",
80 Message.str().c_str());
81 break;
82 }
83 }
84 }
85}
86
87TestModuleFileExtension::Reader::~Reader() { }
88
89TestModuleFileExtension::~TestModuleFileExtension() { }
90
91ModuleFileExtensionMetadata
92TestModuleFileExtension::getExtensionMetadata() const {
93 return { .BlockName: BlockName, .MajorVersion: MajorVersion, .MinorVersion: MinorVersion, .UserInfo: UserInfo };
94}
95
96void TestModuleFileExtension::hashExtension(
97 ExtensionHashBuilder &HBuilder) const {
98 if (Hashed) {
99 HBuilder.add(Value: BlockName);
100 HBuilder.add(Value: MajorVersion);
101 HBuilder.add(Value: MinorVersion);
102 HBuilder.add(Value: UserInfo);
103 }
104}
105
106std::unique_ptr<ModuleFileExtensionWriter>
107TestModuleFileExtension::createExtensionWriter(ASTWriter &) {
108 return std::unique_ptr<ModuleFileExtensionWriter>(new Writer(this));
109}
110
111std::unique_ptr<ModuleFileExtensionReader>
112TestModuleFileExtension::createExtensionReader(
113 const ModuleFileExtensionMetadata &Metadata,
114 ASTReader &Reader, serialization::ModuleFile &Mod,
115 const llvm::BitstreamCursor &Stream)
116{
117 assert(Metadata.BlockName == BlockName && "Wrong block name");
118 if (std::make_pair(x: Metadata.MajorVersion, y: Metadata.MinorVersion) !=
119 std::make_pair(x&: MajorVersion, y&: MinorVersion)) {
120 Reader.getDiags().Report(Loc: Mod.ImportLoc,
121 DiagID: diag::err_test_module_file_extension_version)
122 << BlockName << Metadata.MajorVersion << Metadata.MinorVersion
123 << MajorVersion << MinorVersion;
124 return nullptr;
125 }
126
127 return std::unique_ptr<ModuleFileExtensionReader>(
128 new TestModuleFileExtension::Reader(this, Stream));
129}
130
131std::string TestModuleFileExtension::str() const {
132 std::string Buffer;
133 llvm::raw_string_ostream OS(Buffer);
134 OS << BlockName << ":" << MajorVersion << ":" << MinorVersion << ":" << Hashed
135 << ":" << UserInfo;
136 return Buffer;
137}
138