//===- ModelUnderTrainingRunner.cpp - 'development' mode runner ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implementation of an MLModelRunner for 'development' mode, i.e. evaluation
// happens off a model that's provided from the command line and is interpreted.
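//
// A typical entry point is the createAndEnsureValid() factory below; an
// illustrative (not prescriptive) call site, using the decision name the
// inlining advisor logs under:
//
//   auto Runner = ModelUnderTrainingRunner::createAndEnsureValid(
//       Ctx, TrainingModelPath, "inlining_decision", InputSpecs);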
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/STLExtras.h"
#include "llvm/Config/config.h"
#if defined(LLVM_HAVE_TFLITE)
#include "llvm/Analysis/ModelUnderTrainingRunner.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Path.h"
#include <optional>

using namespace llvm;
namespace {
// Pairs a tensor's shape/type description with the name under which it should
// appear in the training log, when that differs from the tensor's own name.
struct LoggedFeatureSpec {
  TensorSpec Spec;
  std::optional<std::string> LoggingName;
};

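// Load the output specs from a JSON file next to the model (or from
// SpecFileOverride). The expected content is an array of {tensor_spec,
// logging_name} objects; an illustrative example, assuming the TensorSpec
// JSON schema of name/port/type/shape:
//
//   [ {
//       "logging_name": "inlining_decision",
//       "tensor_spec": {
//         "name": "StatefulPartitionedCall", "port": 0,
//         "type": "int64_t", "shape": [1]
//       }
//   } ]
//
// The first entry must describe the decision tensor and carry the
// logging_name given by ExpectedDecisionName.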
std::optional<std::vector<LoggedFeatureSpec>>
loadOutputSpecs(LLVMContext &Ctx, StringRef ExpectedDecisionName,
                StringRef ModelPath, StringRef SpecFileOverride) {
  SmallVector<char, 128> OutputSpecsPath;
  StringRef FileName = SpecFileOverride;
  if (FileName.empty()) {
    llvm::sys::path::append(OutputSpecsPath, ModelPath, "output_spec.json");
    FileName = {OutputSpecsPath.data(), OutputSpecsPath.size()};
  }

  auto BufferOrError = MemoryBuffer::getFileOrSTDIN(FileName);
  if (!BufferOrError) {
    Ctx.emitError("Error opening output specs file: " + FileName + " : " +
                  BufferOrError.getError().message());
    return std::nullopt;
  }
  auto ParsedJSONValues = json::parse(BufferOrError.get()->getBuffer());
  if (!ParsedJSONValues) {
    Ctx.emitError("Could not parse specs file: " + FileName);
    return std::nullopt;
  }
  auto ValuesArray = ParsedJSONValues->getAsArray();
  if (!ValuesArray) {
    Ctx.emitError("Expected an array of {tensor_spec:<TensorSpec>, "
                  "logging_name:<name>} dictionaries");
    return std::nullopt;
  }
  std::vector<LoggedFeatureSpec> Ret;
  for (const auto &Value : *ValuesArray)
    if (const auto *Obj = Value.getAsObject())
      if (const auto *SpecPart = Obj->get("tensor_spec"))
        if (auto TensorSpec = getTensorSpecFromJSON(Ctx, *SpecPart))
          if (auto LoggingName = Obj->getString("logging_name")) {
            if (!TensorSpec->isElementType<int64_t>() &&
                !TensorSpec->isElementType<int32_t>() &&
                !TensorSpec->isElementType<float>()) {
              Ctx.emitError(
                  "Only int64, int32, and float tensors are supported. "
                  "Found unsupported type for tensor named " +
                  TensorSpec->name());
              return std::nullopt;
            }
            Ret.push_back({*TensorSpec, LoggingName->str()});
          }

  if (ValuesArray->size() != Ret.size()) {
    Ctx.emitError(
        "Unable to parse output spec. It should be a json file containing an "
        "array of dictionaries. Each dictionary must have a 'tensor_spec' key, "
        "with a json object describing a TensorSpec; and a 'logging_name' key, "
        "which is a string to use as name when logging this tensor in the "
        "training log.");
    return std::nullopt;
  }
  if (Ret.empty() || *Ret[0].LoggingName != ExpectedDecisionName) {
    Ctx.emitError("The first output spec must describe the decision tensor, "
                  "and must have the logging_name " +
                  StringRef(ExpectedDecisionName));
    return std::nullopt;
  }
  return Ret;
}
} // namespace

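// The constructor wires each input spec to the matching input buffer of the
// underlying TFModelEvaluator, so features written through the generic
// MLModelRunner interface land directly in the evaluator's input tensors.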
ModelUnderTrainingRunner::ModelUnderTrainingRunner(
    LLVMContext &Ctx, const std::string &ModelPath,
    const std::vector<TensorSpec> &InputSpecs,
    const std::vector<TensorSpec> &OutputSpecs,
    const std::vector<TensorSpec> &ExtraOutputsForLogging)
    : MLModelRunner(Ctx, MLModelRunner::Kind::Development, InputSpecs.size()),
      OutputSpecs(OutputSpecs), ExtraOutputsForLogging(ExtraOutputsForLogging) {
  Evaluator =
      std::make_unique<TFModelEvaluator>(ModelPath, InputSpecs, OutputSpecs);
  if (!Evaluator || !Evaluator->isValid()) {
    Ctx.emitError("Failed to create saved model evaluator");
    Evaluator.reset();
    return;
  }

  for (size_t I = 0, E = InputSpecs.size(); I < E; ++I) {
    setUpBufferForTensor(I, InputSpecs[I], Evaluator->getUntypedInput(I));
  }
}

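// Evaluation hands back the decision tensor, i.e. output index 0; the
// remaining outputs stay available in LastEvaluationResult for training-log
// consumers.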
void *ModelUnderTrainingRunner::evaluateUntyped() {
  LastEvaluationResult = Evaluator->evaluate();
  if (!LastEvaluationResult.has_value()) {
    Ctx.emitError("Error evaluating model.");
    return nullptr;
  }
  return LastEvaluationResult->getUntypedTensorValue(0);
}

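// Factory: load the output specs, use all of them as the evaluator's outputs,
// and re-expose every output past the decision tensor under its logging_name
// for the training log. Emits an error and returns nullptr on any failure.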
std::unique_ptr<ModelUnderTrainingRunner>
ModelUnderTrainingRunner::createAndEnsureValid(
    LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName,
    const std::vector<TensorSpec> &InputSpecs,
    StringRef OutputSpecsPathOverride) {
  if (auto MaybeOutputSpecs = loadOutputSpecs(Ctx, DecisionName, ModelPath,
                                              OutputSpecsPathOverride)) {
    std::unique_ptr<ModelUnderTrainingRunner> MUTR;
    std::vector<TensorSpec> OutputSpecs;
    std::vector<TensorSpec> ExtraOutputsForLogging;
    append_range(OutputSpecs,
                 map_range(*MaybeOutputSpecs, [](const LoggedFeatureSpec &LFS) {
                   return LFS.Spec;
                 }));
    append_range(ExtraOutputsForLogging,
                 map_range(drop_begin(*MaybeOutputSpecs),
                           [](const LoggedFeatureSpec &LFS) {
                             return TensorSpec(LFS.LoggingName
                                                   ? *LFS.LoggingName
                                                   : LFS.Spec.name(),
                                               LFS.Spec);
                           }));

    MUTR.reset(new ModelUnderTrainingRunner(
        Ctx, ModelPath, InputSpecs, OutputSpecs, ExtraOutputsForLogging));
    if (MUTR && MUTR->isValid())
      return MUTR;

    Ctx.emitError("Could not load or create model evaluator.");
    return nullptr;
  }
  Ctx.emitError("Could not load the policy model from the provided path");
  return nullptr;
}

#endif // defined(LLVM_HAVE_TFLITE)