1 | //===- TrainingLogger.cpp - mlgo feature/reward logging -------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements logging infrastructure for extracting features and |
10 | // rewards for mlgo policy training. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | #include "llvm/Analysis/TensorSpec.h" |
14 | #include "llvm/Config/config.h" |
15 | |
16 | #include "llvm/ADT/Twine.h" |
17 | #include "llvm/Analysis/Utils/TrainingLogger.h" |
18 | #include "llvm/Support/CommandLine.h" |
19 | #include "llvm/Support/Debug.h" |
20 | #include "llvm/Support/JSON.h" |
21 | #include "llvm/Support/MemoryBuffer.h" |
22 | #include "llvm/Support/Path.h" |
23 | #include "llvm/Support/raw_ostream.h" |
24 | |
25 | #include <cassert> |
26 | #include <numeric> |
27 | |
28 | using namespace llvm; |
29 | |
30 | void Logger::(std::optional<TensorSpec> AdviceSpec) { |
31 | json::OStream JOS(*OS); |
32 | JOS.object(Contents: [&]() { |
33 | JOS.attributeArray(Key: "features" , Contents: [&]() { |
34 | for (const auto &TS : FeatureSpecs) |
35 | TS.toJSON(OS&: JOS); |
36 | }); |
37 | if (IncludeReward) { |
38 | JOS.attributeBegin(Key: "score" ); |
39 | RewardSpec.toJSON(OS&: JOS); |
40 | JOS.attributeEnd(); |
41 | } |
42 | if (AdviceSpec.has_value()) { |
43 | JOS.attributeBegin(Key: "advice" ); |
44 | AdviceSpec->toJSON(OS&: JOS); |
45 | JOS.attributeEnd(); |
46 | } |
47 | }); |
48 | *OS << "\n" ; |
49 | } |
50 | |
51 | void Logger::switchContext(StringRef Name) { |
52 | CurrentContext = Name.str(); |
53 | json::OStream JOS(*OS); |
54 | JOS.object(Contents: [&]() { JOS.attribute(Key: "context" , Contents: Name); }); |
55 | *OS << "\n" ; |
56 | } |
57 | |
58 | void Logger::startObservation() { |
59 | auto I = ObservationIDs.insert(KV: {CurrentContext, 0}); |
60 | size_t NewObservationID = I.second ? 0 : ++I.first->second; |
61 | json::OStream JOS(*OS); |
62 | JOS.object(Contents: [&]() { |
63 | JOS.attribute(Key: "observation" , Contents: static_cast<int64_t>(NewObservationID)); |
64 | }); |
65 | *OS << "\n" ; |
66 | } |
67 | |
68 | void Logger::endObservation() { *OS << "\n" ; } |
69 | |
70 | void Logger::logRewardImpl(const char *RawData) { |
71 | assert(IncludeReward); |
72 | json::OStream JOS(*OS); |
73 | JOS.object(Contents: [&]() { |
74 | JOS.attribute(Key: "outcome" , Contents: static_cast<int64_t>( |
75 | ObservationIDs.find(Key: CurrentContext)->second)); |
76 | }); |
77 | *OS << "\n" ; |
78 | writeTensor(Spec: RewardSpec, RawData); |
79 | *OS << "\n" ; |
80 | } |
81 | |
82 | Logger::Logger(std::unique_ptr<raw_ostream> OS, |
83 | const std::vector<TensorSpec> &FeatureSpecs, |
84 | const TensorSpec &RewardSpec, bool IncludeReward, |
85 | std::optional<TensorSpec> AdviceSpec) |
86 | : OS(std::move(OS)), FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec), |
87 | IncludeReward(IncludeReward) { |
88 | writeHeader(AdviceSpec); |
89 | } |
90 | |