MLCompilerBridge
Tools for streamlining communication with ML models for compiler optimizations.
Loading...
Searching...
No Matches
TensorSpec.cpp
Go to the documentation of this file.
1//===- TensorSpec.cpp - tensor type abstraction ---------------------------===//
2//
3// Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (Preliminary version adopted from TensorSpec.cpp of LLVM 12.X)
8//
9//===----------------------------------------------------------------------===//
15//===----------------------------------------------------------------------===//
16
17#include "SerDes/TensorSpec.h"
18#include "llvm/ADT/None.h"
19#include "llvm/ADT/StringExtras.h"
20#include "llvm/Support/Debug.h"
21#include "llvm/Support/JSON.h"
22
23#include <array>
24#include <cassert>
25#include <numeric>
26
27using namespace llvm;
28
29namespace MLBridge {
30
// Macro body for one (C++ type T, enumerator E) pair: it defines the explicit
// specialization TensorSpec::getDataType<T>() so that it returns TensorType::E.
// NOTE(review): this listing jumps from line 33 to line 35, so the statement
// that applies the macro is not visible here; presumably it is expanded once
// per supported type through SUPPORTED_TENSOR_TYPES - confirm against the real
// source file.
31#define TFUTILS_GETDATATYPE_IMPL(T, E) \
32 template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; }
33
35
36#undef TFUTILS_GETDATATYPE_IMPL
37
// File-local table of printable tensor type names. It has
// static_cast<size_t>(TensorType::Total) slots and its first entry is
// "INVALID"; toString() indexes it with the numeric value of a TensorType.
// TFUTILS_GETNAME_IMPL stringizes its first argument (followed by a comma) to
// produce further initializers.
// NOTE(review): line 41 of the file is not visible in this listing; presumably
// that is where the macro is expanded for every supported type - confirm
// against the real source file.
38static std::array<std::string, static_cast<size_t>(TensorType::Total)>
39 TensorTypeNames{"INVALID",
40#define TFUTILS_GETNAME_IMPL(T, _) #T,
42#undef TFUTILS_GETNAME_IMPL
43 };
44
45StringRef toString(TensorType TT) {
46 return TensorTypeNames[static_cast<size_t>(TT)];
47}
48
49void TensorSpec::toJSON(json::OStream &OS) const {
50 OS.object([&]() {
51 OS.attribute("name", name());
52 OS.attribute("type", toString(type()));
53 OS.attribute("port", port());
54 OS.attributeArray("shape", [&]() {
55 for (size_t D : shape())
56 OS.value(static_cast<int64_t>(D));
57 });
58 });
59}
60
61TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type,
62 size_t ElementSize, const std::vector<int64_t> &Shape)
63 : Name(Name), Port(Port), Type(Type), Shape(Shape),
64 ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
65 std::multiplies<int64_t>())),
66 ElementSize(ElementSize) {}
67
// Builds a TensorSpec from a JSON value.
//
// `Value` is mapped as a JSON object with the properties "name" (string),
// "type" (string), "port" (int) and "shape" (array of int64). If the value
// cannot be mapped as an object, or any of those properties fails to map, the
// problem is reported through Ctx.emitError() and None is returned.
//
// NOTE(review): this listing skips a line between the #define and the #undef
// of PARSE_TYPE, so the place where the macro is expanded is not visible here;
// presumably it is applied once per supported type through
// SUPPORTED_TENSOR_TYPES - confirm against the real source file. When no
// branch returns a spec, the function falls through to `return None` without
// emitting an error.
68llvm::Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
69 const json::Value &Value) {
// Error helper: renders the offending JSON value to text, emits it together
// with `Message` on the context, and yields None for the caller to return.
70 auto EmitError =
71 [&](const llvm::Twine &Message) -> llvm::Optional<TensorSpec> {
72 std::string S;
73 llvm::raw_string_ostream OS(S);
74 OS << Value;
75 Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
76 return None;
77 };
78// FIXME: accept a Path as a parameter, and use it for error reporting.
// The mapper is constructed with a json::Path::Root only when LLVM_MLBRIDGE is
// defined; otherwise the single-argument constructor is used.
79#ifdef LLVM_MLBRIDGE
80 json::Path::Root Root("tensor_spec");
81 json::ObjectMapper Mapper(Value, Root);
82#else
83 json::ObjectMapper Mapper(Value);
84#endif
85 if (!Mapper)
86 return EmitError("Value is not a dict");
87
// Destinations for the mapped properties; each is filled by the matching
// Mapper.map() call below.
88 std::string TensorName;
89 int TensorPort = -1;
90 std::string TensorType;
91 std::vector<int64_t> TensorShape;
92
93 if (!Mapper.map<std::string>("name", TensorName))
94 return EmitError("'name' property not present or not a string");
95 if (!Mapper.map<std::string>("type", TensorType))
96 return EmitError("'type' property not present or not a string");
97 if (!Mapper.map<int>("port", TensorPort))
98 return EmitError("'port' property not present or not an int");
99 if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
100 return EmitError("'shape' property not present or not an int array");
101
// PARSE_TYPE(T, E): when the parsed "type" string equals the spelling of T,
// return a spec created via TensorSpec::createSpec<T>.
102#define PARSE_TYPE(T, E) \
103 if (TensorType == #T) \
104 return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
106#undef PARSE_TYPE
107 return None;
108}
109
// Renders the contents of a raw tensor buffer as text.
//
// Dispatches on Spec.type(). The _IMR_DBG_PRINTER(T, N) macro describes one
// switch case: for TensorType::N it reinterprets `Buffer` as an array of T,
// takes Spec.getElementCount() elements from it, converts each with
// std::to_string and joins the results with ",".
//
// NOTE(review): the listing skips lines between the macro definition, its
// #undef and the llvm_unreachable call, so neither the macro expansion nor the
// case labels guarding the unreachable are visible here; presumably the macro
// is expanded per supported type through SUPPORTED_TENSOR_TYPES - confirm
// against the real source file.
// NOTE(review): `Buffer` is assumed to hold at least getElementCount()
// elements of the spec's type; nothing in this function checks that.
110std::string tensorValueToString(const char *Buffer, const TensorSpec &Spec) {
111 switch (Spec.type()) {
112#define _IMR_DBG_PRINTER(T, N) \
113 case TensorType::N: { \
114 const T *TypedBuff = reinterpret_cast<const T *>(Buffer); \
115 auto R = llvm::make_range(TypedBuff, TypedBuff + Spec.getElementCount()); \
116 return llvm::join( \
117 llvm::map_range(R, [](T V) { return std::to_string(V); }), ","); \
118 }
120#undef _IMR_DBG_PRINTER
123 llvm_unreachable("invalid tensor type");
124 }
125 // To appease warnings about not all control paths returning a value.
126 return "";
127}
128
129} // namespace MLBridge
#define PARSE_TYPE(T, E)
#define TFUTILS_GETDATATYPE_IMPL(T, E)
#define TFUTILS_GETNAME_IMPL(T, _)
#define _IMR_DBG_PRINTER(T, N)
#define SUPPORTED_TENSOR_TYPES(M)
TensorSpec encapsulates the specification of a tensor: its dimensions, or "shape" (row-major),...
Definition TensorSpec.h:33
void toJSON(llvm::json::OStream &OS) const
TensorType type() const
Definition TensorSpec.h:64
TensorSpec(const std::string &NewName, const TensorSpec &Other)
Definition TensorSpec.h:89
const std::string & name() const
Definition TensorSpec.h:62
const std::vector< int64_t > & shape() const
Definition TensorSpec.h:65
static TensorType getDataType()
std::string tensorValueToString(const char *Buffer, const TensorSpec &Spec)
For debugging.
llvm::Optional< TensorSpec > getTensorSpecFromJSON(llvm::LLVMContext &Ctx, const llvm::json::Value &Value)
Construct a TensorSpec from a JSON dictionary of the form: { "name": <string>, "port": <int>,...
StringRef toString(TensorType TT)