MLCompilerBridge
Tools for streamlining communication with ML models for compiler optimizations.
Loading...
Searching...
No Matches
tensorflowSerDes.cpp
Go to the documentation of this file.
1//===- tensorflowSerDes.cpp - Serializer support for TF ---------*- C++ -*-===//
2//
3// Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
16#include "SerDes/baseSerDes.h"
17
18// #define EXCEPT_LONG(M) M(int) M(float) M(double) M(std::string) M(bool)
19namespace MLBridge {
20#define SET_FEATURE(TYPE, _) \
21 void TensorflowSerDes::setFeature(const std::string &Name, \
22 const TYPE Value) { \
23 std::string prefix = "feed_"; \
24 const int Index = CompiledModel->LookupArgIndex(prefix + Name); \
25 if (Index >= 0) \
26 *reinterpret_cast<TYPE *>(CompiledModel->arg_data(Index)) = Value; \
27 }
29#undef SET_FEATURE
30
31void TensorflowSerDes::setFeature(const std::string &Name,
32 const std::vector<int64_t> &Value) {
33 std::string prefix = "feed_";
34 const int Index = CompiledModel->LookupArgIndex(prefix + Name);
35 std::copy(Value.begin(), Value.end(),
36 static_cast<int64_t *>(CompiledModel->arg_data(Index)));
37}
38
39void TensorflowSerDes::setFeature(const std::string &Name,
40 const std::vector<float> &Value) {
41 std::string prefix = "feed_";
42 const int Index = CompiledModel->LookupArgIndex(prefix + Name);
43 std::copy(Value.begin(), Value.end(),
44 static_cast<float *>(CompiledModel->arg_data(Index)));
45}
46
47void TensorflowSerDes::setFeature(const std::string &Name,
48 const std::vector<double> &Value) {
49 std::string prefix = "feed_";
50 const int Index = CompiledModel->LookupArgIndex(prefix + Name);
51 std::copy(Value.begin(), Value.end(),
52 static_cast<double *>(CompiledModel->arg_data(Index)));
53}
54
55void TensorflowSerDes::setFeature(const std::string &Name,
56 const std::vector<std::string> &Value) {
57 std::string prefix = "feed_";
58 const int Index = CompiledModel->LookupArgIndex(prefix + Name);
59 std::copy(Value.begin(), Value.end(),
60 static_cast<std::string *>(CompiledModel->arg_data(Index)));
61}
62
63void TensorflowSerDes::setFeature(const std::string &Name,
64 const std::vector<bool> &Value) {
65 std::string prefix = "feed_";
66 const int Index = CompiledModel->LookupArgIndex(prefix + Name);
67 std::copy(Value.begin(), Value.end(),
68 static_cast<bool *>(CompiledModel->arg_data(Index)));
69}
70
71void TensorflowSerDes::setFeature(const std::string &Name,
72 const std::vector<int> &Value) {
73 std::string prefix = "feed_";
74 const int Index = CompiledModel->LookupArgIndex(prefix + Name);
75 std::copy(Value.begin(), Value.end(),
76 static_cast<int *>(CompiledModel->arg_data(Index)));
77}
78} // namespace MLBridge
Supporting new SerDes:
#define SUPPORTED_TYPES(M)
Definition baseSerDes.h:32
#define SET_FEATURE(TYPE, _)
setFeature() is used to set the features of the data structure used for communication.
Definition baseSerDes.h:54
virtual void setFeature(const std::string &name, const google::protobuf::Message *value)
Definition baseSerDes.h:60
tensorflow::XlaCompiledCpuFunction * CompiledModel
Serialization/Deserialization to support TF AOT models.