| 1 | //===- TensorSpec.cpp - tensor type abstraction ---------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // Implementation file for the abstraction of a tensor type, and JSON loading |
| 10 | // utils. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | #include "llvm/ADT/STLExtras.h" |
| 14 | #include "llvm/Config/config.h" |
| 15 | |
| 16 | #include "llvm/ADT/StringExtras.h" |
| 17 | #include "llvm/ADT/Twine.h" |
| 18 | #include "llvm/Analysis/TensorSpec.h" |
| 19 | #include "llvm/Support/CommandLine.h" |
| 20 | #include "llvm/Support/Debug.h" |
| 21 | #include "llvm/Support/JSON.h" |
| 22 | #include "llvm/Support/ManagedStatic.h" |
| 23 | #include "llvm/Support/raw_ostream.h" |
| 24 | #include <array> |
| 25 | #include <cassert> |
| 26 | #include <numeric> |
| 27 | |
| 28 | using namespace llvm; |
| 29 | |
| 30 | namespace llvm { |
| 31 | |
| 32 | #define TFUTILS_GETDATATYPE_IMPL(T, E) \ |
| 33 | template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; } |
| 34 | |
| 35 | SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL) |
| 36 | |
| 37 | #undef TFUTILS_GETDATATYPE_IMPL |
| 38 | |
| 39 | static std::array<std::string, static_cast<size_t>(TensorType::Total)> |
| 40 | TensorTypeNames{"INVALID" , |
| 41 | #define TFUTILS_GETNAME_IMPL(T, _) #T, |
| 42 | SUPPORTED_TENSOR_TYPES(TFUTILS_GETNAME_IMPL) |
| 43 | #undef TFUTILS_GETNAME_IMPL |
| 44 | }; |
| 45 | |
| 46 | StringRef toString(TensorType TT) { |
| 47 | return TensorTypeNames[static_cast<size_t>(TT)]; |
| 48 | } |
| 49 | |
| 50 | void TensorSpec::toJSON(json::OStream &OS) const { |
| 51 | OS.object(Contents: [&]() { |
| 52 | OS.attribute(Key: "name" , Contents: name()); |
| 53 | OS.attribute(Key: "type" , Contents: toString(TT: type())); |
| 54 | OS.attribute(Key: "port" , Contents: port()); |
| 55 | OS.attributeArray(Key: "shape" , Contents: [&]() { |
| 56 | for (size_t D : shape()) |
| 57 | OS.value(V: static_cast<int64_t>(D)); |
| 58 | }); |
| 59 | }); |
| 60 | } |
| 61 | |
| 62 | TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type, |
| 63 | size_t ElementSize, const std::vector<int64_t> &Shape) |
| 64 | : Name(Name), Port(Port), Type(Type), Shape(Shape), |
| 65 | ElementCount(std::accumulate(first: Shape.begin(), last: Shape.end(), init: 1, |
| 66 | binary_op: std::multiplies<int64_t>())), |
| 67 | ElementSize(ElementSize) {} |
| 68 | |
| 69 | std::optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx, |
| 70 | const json::Value &Value) { |
| 71 | auto EmitError = |
| 72 | [&](const llvm::Twine &Message) -> std::optional<TensorSpec> { |
| 73 | std::string S; |
| 74 | llvm::raw_string_ostream OS(S); |
| 75 | OS << Value; |
| 76 | Ctx.emitError(ErrorStr: "Unable to parse JSON Value as spec (" + Message + "): " + S); |
| 77 | return std::nullopt; |
| 78 | }; |
| 79 | // FIXME: accept a Path as a parameter, and use it for error reporting. |
| 80 | json::Path::Root Root("tensor_spec" ); |
| 81 | json::ObjectMapper Mapper(Value, Root); |
| 82 | if (!Mapper) |
| 83 | return EmitError("Value is not a dict" ); |
| 84 | |
| 85 | std::string TensorName; |
| 86 | int TensorPort = -1; |
| 87 | std::string TensorType; |
| 88 | std::vector<int64_t> TensorShape; |
| 89 | |
| 90 | if (!Mapper.map<std::string>(Prop: "name" , Out&: TensorName)) |
| 91 | return EmitError("'name' property not present or not a string" ); |
| 92 | if (!Mapper.map<std::string>(Prop: "type" , Out&: TensorType)) |
| 93 | return EmitError("'type' property not present or not a string" ); |
| 94 | if (!Mapper.map<int>(Prop: "port" , Out&: TensorPort)) |
| 95 | return EmitError("'port' property not present or not an int" ); |
| 96 | if (!Mapper.map<std::vector<int64_t>>(Prop: "shape" , Out&: TensorShape)) |
| 97 | return EmitError("'shape' property not present or not an int array" ); |
| 98 | |
| 99 | #define PARSE_TYPE(T, E) \ |
| 100 | if (TensorType == #T) \ |
| 101 | return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort); |
| 102 | SUPPORTED_TENSOR_TYPES(PARSE_TYPE) |
| 103 | #undef PARSE_TYPE |
| 104 | return std::nullopt; |
| 105 | } |
| 106 | |
| 107 | std::string tensorValueToString(const char *Buffer, const TensorSpec &Spec) { |
| 108 | switch (Spec.type()) { |
| 109 | #define _IMR_DBG_PRINTER(T, N) \ |
| 110 | case TensorType::N: { \ |
| 111 | const T *TypedBuff = reinterpret_cast<const T *>(Buffer); \ |
| 112 | auto R = llvm::make_range(TypedBuff, TypedBuff + Spec.getElementCount()); \ |
| 113 | return llvm::join( \ |
| 114 | llvm::map_range(R, [](T V) { return std::to_string(V); }), ","); \ |
| 115 | } |
| 116 | SUPPORTED_TENSOR_TYPES(_IMR_DBG_PRINTER) |
| 117 | #undef _IMR_DBG_PRINTER |
| 118 | case TensorType::Total: |
| 119 | case TensorType::Invalid: |
| 120 | llvm_unreachable("invalid tensor type" ); |
| 121 | } |
| 122 | // To appease warnings about not all control paths returning a value. |
| 123 | return "" ; |
| 124 | } |
| 125 | |
| 126 | } // namespace llvm |
| 127 | |