#ifndef MLBRIDGE_TENSORSPEC_H
#define MLBRIDGE_TENSORSPEC_H

#include "llvm/ADT/Optional.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/JSON.h"

// Known tensor element types, declared as an X-macro: one M(<C type>, <Name>)
// entry per supported type, expanded below to build the TensorType enum.
#define SUPPORTED_TENSOR_TYPES(M) \
  M(float, Float) /* ... one M(<C type>, <Name>) entry per supported type */

enum class TensorType {
  Invalid,
#define _TENSOR_TYPE_ENUM_MEMBERS(_, Name) Name,
  SUPPORTED_TENSOR_TYPES(_TENSOR_TYPE_ENUM_MEMBERS)
#undef _TENSOR_TYPE_ENUM_MEMBERS
};
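Assuming the type list matches LLVM's TensorSpec.h, which this header closely mirrors (same TFUTILS_GETDATATYPE_DEF and _TENSOR_TYPE_ENUM_MEMBERS machinery), the supported element types are float, double, and the fixed-width integer types, so the enum above would expand to something equivalent to:

// Illustrative expansion only; the actual enumerators come from whatever
// entries SUPPORTED_TENSOR_TYPES lists (assumed here to match LLVM's list).
enum class TensorType {
  Invalid,
  Float, Double,
  Int8, UInt8, Int16, UInt16,
  Int32, UInt32, Int64, UInt64,
};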
class TensorSpec {
public:
  template <typename T>
  static TensorSpec createSpec(const std::string &Name,
                               const std::vector<int64_t> &Shape,
                               int Port = 6020);

  const std::string &name() const { return Name; }
  const std::vector<int64_t> &shape() const { return Shape; }

  // The element count is the product of the shape's dimensions, computed
  // with std::accumulate and std::multiplies<int64_t>.
  size_t getElementCount() const;
  void setShape(const std::vector<int64_t> &NewShape);

  void toJSON(llvm::json::OStream &OS) const;

  // The remaining size queries, element-type check, comparison operators,
  // and renaming constructor are listed in the member summary below.
};
// Construct a TensorSpec from a JSON dictionary (format described below).
llvm::Optional<TensorSpec> getTensorSpecFromJSON(llvm::LLVMContext &Ctx,
                                                 const llvm::json::Value &Value);

#define TFUTILS_GETDATATYPE_DEF(T, Name) \
  template <> TensorType TensorSpec::getDataType<T>();
SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_DEF)
#undef TFUTILS_GETDATATYPE_DEF
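A minimal usage sketch follows. The include path ("TensorSpec.h") and the lack of a namespace qualifier are assumptions, and both createSpec and isElementType are taken to be templated on the element type, in line with the getDataType<T>() specializations declared above:

#include "TensorSpec.h" // assumed include path for this header

#include "llvm/Support/JSON.h"
#include "llvm/Support/raw_ostream.h"

#include <string>

int main() {
  // Describe a 2x3 float tensor named "input".
  TensorSpec Spec = TensorSpec::createSpec<float>("input", {2, 3});

  llvm::errs() << Spec.name() << "\n";                     // input
  llvm::errs() << Spec.getElementCount() << "\n";          // 2 * 3 = 6
  llvm::errs() << Spec.getElementByteSize() << "\n";       // sizeof(float) = 4
  llvm::errs() << Spec.getTotalTensorBufferSize() << "\n"; // 6 * 4 = 24 bytes

  // Element-type check, then serialize the spec as JSON.
  if (Spec.isElementType<float>()) {
    std::string S;
    llvm::raw_string_ostream SS(S);
    llvm::json::OStream OS(SS);
    Spec.toJSON(OS);
    llvm::errs() << SS.str() << "\n";
  }
  return 0;
}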
TensorSpec encapsulates the specification of a tensor: its dimensions, or "shape" (row-major), its element type, its name, and its port.

bool isElementType() const
void toJSON(llvm::json::OStream &OS) const
bool operator==(const TensorSpec &Other) const
TensorSpec(const std::string &NewName, const TensorSpec &Other)
size_t getElementCount() const
    Get the number of elements in a tensor with this shape.
static TensorSpec createSpec(const std::string &Name, const std::vector<int64_t> &Shape, int Port = 6020)
size_t getElementByteSize() const
    Get the size, in bytes, of one element.
size_t getTotalTensorBufferSize() const
    Get the total size of a memory buffer needed to store the whole tensor.
const std::string &name() const
const std::vector<int64_t> &shape() const
static TensorType getDataType()
void setShape(const std::vector<int64_t> &NewShape)
bool operator!=(const TensorSpec &Other) const
std::vector<int64_t> Shape
std::string tensorValueToString(const char *Buffer, const TensorSpec &Spec)
    For debugging.
llvm::Optional<TensorSpec> getTensorSpecFromJSON(llvm::LLVMContext &Ctx, const llvm::json::Value &Value)
    Construct a TensorSpec from a JSON dictionary of the form:
    { "name": <string>, "port": <int>, "type": <string>, "shape": <array of ints> }
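A sketch of constructing a TensorSpec from JSON in the dictionary form described above. The include path, the unqualified function name, and the concrete field values (including the "float" type string) are assumptions:

#include "TensorSpec.h" // assumed include path for this header

#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  llvm::LLVMContext Ctx; // getTensorSpecFromJSON's declaration takes an LLVMContext

  // A spec in the documented dictionary form; the field values are illustrative.
  llvm::Expected<llvm::json::Value> Parsed = llvm::json::parse(
      R"({"name": "input", "port": 0, "type": "float", "shape": [2, 3]})");
  if (!Parsed) {
    llvm::errs() << llvm::toString(Parsed.takeError()) << "\n";
    return 1;
  }

  if (llvm::Optional<TensorSpec> Spec = getTensorSpecFromJSON(Ctx, *Parsed))
    llvm::errs() << Spec->name() << " has " << Spec->getElementCount()
                 << " elements\n";
  else
    llvm::errs() << "not a valid tensor spec\n";
  return 0;
}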