//===- TrainingLogger.cpp - mlgo feature/reward logging -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logging infrastructure for extracting features and
// rewards for mlgo policy training.
//
//===----------------------------------------------------------------------===//
| 13 | #include "llvm/Analysis/TensorSpec.h" |
| 14 | #include "llvm/Config/config.h" |
| 15 | |
| 16 | #include "llvm/ADT/Twine.h" |
| 17 | #include "llvm/Analysis/Utils/TrainingLogger.h" |
| 18 | #include "llvm/Support/CommandLine.h" |
| 19 | #include "llvm/Support/Debug.h" |
| 20 | #include "llvm/Support/JSON.h" |
| 21 | #include "llvm/Support/MemoryBuffer.h" |
| 22 | #include "llvm/Support/Path.h" |
| 23 | #include "llvm/Support/raw_ostream.h" |
| 24 | |
| 25 | #include <cassert> |
| 26 | #include <numeric> |
| 27 | |
| 28 | using namespace llvm; |
| 29 | |
| 30 | void Logger::(std::optional<TensorSpec> AdviceSpec) { |
| 31 | json::OStream JOS(*OS); |
| 32 | JOS.object(Contents: [&]() { |
| 33 | JOS.attributeArray(Key: "features" , Contents: [&]() { |
| 34 | for (const auto &TS : FeatureSpecs) |
| 35 | TS.toJSON(OS&: JOS); |
| 36 | }); |
| 37 | if (IncludeReward) { |
| 38 | JOS.attributeBegin(Key: "score" ); |
| 39 | RewardSpec.toJSON(OS&: JOS); |
| 40 | JOS.attributeEnd(); |
| 41 | } |
| 42 | if (AdviceSpec.has_value()) { |
| 43 | JOS.attributeBegin(Key: "advice" ); |
| 44 | AdviceSpec->toJSON(OS&: JOS); |
| 45 | JOS.attributeEnd(); |
| 46 | } |
| 47 | }); |
| 48 | *OS << "\n" ; |
| 49 | } |
| 50 | |
| 51 | void Logger::switchContext(StringRef Name) { |
| 52 | CurrentContext = Name.str(); |
| 53 | json::OStream JOS(*OS); |
| 54 | JOS.object(Contents: [&]() { JOS.attribute(Key: "context" , Contents: Name); }); |
| 55 | *OS << "\n" ; |
| 56 | } |
| 57 | |
| 58 | void Logger::startObservation() { |
| 59 | auto I = ObservationIDs.insert(KV: {CurrentContext, 0}); |
| 60 | size_t NewObservationID = I.second ? 0 : ++I.first->second; |
| 61 | json::OStream JOS(*OS); |
| 62 | JOS.object(Contents: [&]() { |
| 63 | JOS.attribute(Key: "observation" , Contents: static_cast<int64_t>(NewObservationID)); |
| 64 | }); |
| 65 | *OS << "\n" ; |
| 66 | } |
| 67 | |
| 68 | void Logger::endObservation() { *OS << "\n" ; } |
| 69 | |
| 70 | void Logger::logRewardImpl(const char *RawData) { |
| 71 | assert(IncludeReward); |
| 72 | json::OStream JOS(*OS); |
| 73 | JOS.object(Contents: [&]() { |
| 74 | JOS.attribute(Key: "outcome" , Contents: static_cast<int64_t>( |
| 75 | ObservationIDs.find(Key: CurrentContext)->second)); |
| 76 | }); |
| 77 | *OS << "\n" ; |
| 78 | writeTensor(Spec: RewardSpec, RawData); |
| 79 | *OS << "\n" ; |
| 80 | } |
| 81 | |
| 82 | Logger::Logger(std::unique_ptr<raw_ostream> OS, |
| 83 | const std::vector<TensorSpec> &FeatureSpecs, |
| 84 | const TensorSpec &RewardSpec, bool IncludeReward, |
| 85 | std::optional<TensorSpec> AdviceSpec) |
| 86 | : OS(std::move(OS)), FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec), |
| 87 | IncludeReward(IncludeReward) { |
| 88 | writeHeader(AdviceSpec); |
| 89 | } |
| 90 | |