MLCompilerBridge
Tools for streamlining communication with ML models for compiler optimizations.
Loading...
Searching...
No Matches
protobufSerDes.cpp
Go to the documentation of this file.
1//===- protobufSerDes.cpp - Protobuf Serializer for gRPC -------*- C++ -*-===//
2//
3// Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
15//===----------------------------------------------------------------------===//
16
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <string>
#include <type_traits>
#include <vector>
25
26namespace MLBridge {
27inline void ProtobufSerDes::setFeature(const std::string &name,
28 const int value) {
29 Request->GetReflection()->SetInt32(
30 Request, Request->GetDescriptor()->FindFieldByName(name), value);
31}
32
33inline void ProtobufSerDes::setFeature(const std::string &name,
34 const long value) {
35 Request->GetReflection()->SetInt64(
36 Request, Request->GetDescriptor()->FindFieldByName(name), value);
37}
38
39inline void ProtobufSerDes::setFeature(const std::string &name,
40 const float value) {
41 Request->GetReflection()->SetFloat(
42 Request, Request->GetDescriptor()->FindFieldByName(name), value);
43}
44
45inline void ProtobufSerDes::setFeature(const std::string &name,
46 const double value) {
47 Request->GetReflection()->SetDouble(
48 Request, Request->GetDescriptor()->FindFieldByName(name), value);
49}
50
51inline void ProtobufSerDes::setFeature(const std::string &name,
52 const std::string value) {
53 Request->GetReflection()->SetString(
54 Request, Request->GetDescriptor()->FindFieldByName(name), value);
55}
56
57inline void ProtobufSerDes::setFeature(const std::string &name,
58 const bool value) {
59 Request->GetReflection()->SetBool(
60 Request, Request->GetDescriptor()->FindFieldByName(name), value);
61}
62
63inline void ProtobufSerDes::setFeature(const std::string &name,
64 const std::vector<int> &value) {
65 auto ref = Request->GetReflection()->MutableRepeatedField<int>(
66 Request, Request->GetDescriptor()->FindFieldByName(name));
67 ref->Add(value.begin(), value.end());
68}
69
70inline void ProtobufSerDes::setFeature(const std::string &name,
71 const std::vector<long> &value) {
72 auto ref = Request->GetReflection()->MutableRepeatedField<long>(
73 Request, Request->GetDescriptor()->FindFieldByName(name));
74 ref->Add(value.begin(), value.end());
75}
76
77inline void ProtobufSerDes::setFeature(const std::string &name,
78 const std::vector<float> &value) {
79 auto ref = Request->GetReflection()->MutableRepeatedField<float>(
80 Request, Request->GetDescriptor()->FindFieldByName(name));
81 ref->Add(value.begin(), value.end());
82}
83
84inline void ProtobufSerDes::setFeature(const std::string &name,
85 const std::vector<double> &value) {
86 auto ref = Request->GetReflection()->MutableRepeatedField<double>(
87 Request, Request->GetDescriptor()->FindFieldByName(name));
88 ref->Add(value.begin(), value.end());
89}
90
91void ProtobufSerDes::setFeature(const std::string &name,
92 const std::vector<std::string> &value) {
93 auto reflection = Request->GetReflection();
94 auto descriptor = Request->GetDescriptor();
95 auto field = descriptor->FindFieldByName(name);
96 for (auto &v : value) {
97 reflection->AddString(Request, field, v);
98 }
99}
100
101inline void ProtobufSerDes::setFeature(const std::string &name,
102 const std::vector<bool> &value) {
103 auto ref = Request->GetReflection()->MutableRepeatedField<bool>(
104 Request, Request->GetDescriptor()->FindFieldByName(name));
105 ref->Add(value.begin(), value.end());
106}
107
109 std::string *data = new std::string();
110 Request->SerializeToString(data);
112 return data;
113}
114
115void ProtobufSerDes::setFeature(const std::string &name,
116 const google::protobuf::Message *value) {
117 auto reflection = Request->GetReflection();
118 auto descriptor = Request->GetDescriptor();
119 auto field = descriptor->FindFieldByName(name);
120 reflection->MutableMessage(Request, field)->CopyFrom(*value);
121}
122
124 const std::string &name,
125 const std::vector<google::protobuf::Message *> &value) {
126 // set repeated field of messages in this->Request
127 auto reflection = Request->GetReflection();
128 auto descriptor = Request->GetDescriptor();
129 auto field = descriptor->FindFieldByName(name);
130 for (auto &v : value) {
131 reflection->AddMessage(Request, field)->CopyFrom(*v);
132 }
133}
134
135inline void ProtobufSerDes::setRequest(void *Request) {
136 this->Request = reinterpret_cast<Message *>(Request);
137}
138
139inline void ProtobufSerDes::setResponse(void *Response) {
140 this->Response = reinterpret_cast<Message *>(Response);
141}
142
144 Request->Clear(); // todo: find correct place to clear request for protobuf
145 // serdes
146 Response = reinterpret_cast<Message *>(data);
147
148 const Descriptor *descriptor = Response->GetDescriptor();
149 const Reflection *reflection = Response->GetReflection();
150 const FieldDescriptor *field = descriptor->field(0);
151
152 if (field->label() == FieldDescriptor::LABEL_REPEATED) {
153 if (field->type() == FieldDescriptor::Type::TYPE_INT32) {
154 auto &ref = reflection->GetRepeatedField<int32_t>(*Response, field);
155 std::vector<int> *ret = new std::vector<int>(ref.begin(), ref.end());
156 this->MessageLength = ref.size() * sizeof(int32_t);
157 return ret->data();
158 }
159 if (field->type() == FieldDescriptor::Type::TYPE_INT64) {
160 auto &ref = reflection->GetRepeatedField<int64_t>(*Response, field);
161 std::vector<int64_t> *ret =
162 new std::vector<int64_t>(ref.begin(), ref.end());
163 this->MessageLength = ref.size() * sizeof(int64_t);
164 return ret->data();
165 }
166 if (field->type() == FieldDescriptor::Type::TYPE_FLOAT) {
167 auto ref = reflection->GetRepeatedField<float>(*Response, field);
168 std::vector<float> *ret = new std::vector<float>(ref.begin(), ref.end());
169 this->MessageLength = ref.size() * sizeof(float);
170 return ret->data();
171 }
172 if (field->type() == FieldDescriptor::Type::TYPE_DOUBLE) {
173 auto ref = reflection->GetRepeatedField<double>(*Response, field);
174 std::vector<double> *ret =
175 new std::vector<double>(ref.begin(), ref.end());
176 this->MessageLength = ref.size() * sizeof(double);
177 return ret->data();
178 }
179 if (field->type() == FieldDescriptor::Type::TYPE_STRING) {
180 // yet to be tested
181 std::cerr << "vector<string> deserialization yet to be done\n";
182 exit(1);
183 std::vector<std::string> *ptr = new std::vector<std::string>();
184
185 /*
186 ISSUE: error: static assertion failed: We only support non-string scalars
187 in RepeatedField. FIX: ??
188 */
189 // auto ref = reflection->GetRepeatedField<std::string>(*Response, field);
190 // for (auto &v : ref) {
191 // ptr->push_back(v);
192 // }
193 return ptr;
194 }
195 if (field->type() == FieldDescriptor::Type::TYPE_BOOL) {
196 auto ref = reflection->GetRepeatedField<bool>(*Response, field);
197 std::vector<bool> *ptr = new std::vector<bool>(
198 ref.mutable_data(), ref.mutable_data() + ref.size());
199 return ptr;
200 }
201 }
202
203 if (field->type() == FieldDescriptor::Type::TYPE_INT32) {
204 int32_t value = reflection->GetInt32(*Response, field);
205 int32_t *ptr = new int32_t(value);
206 this->MessageLength = sizeof(int32_t);
207 return ptr;
208 }
209 if (field->type() == FieldDescriptor::Type::TYPE_INT64) {
210 int64_t value = reflection->GetInt64(*Response, field);
211 int64_t *ptr = new int64_t(value);
212 this->MessageLength = sizeof(int64_t);
213 return ptr;
214 }
215 if (field->type() == FieldDescriptor::Type::TYPE_FLOAT) {
216 float value = reflection->GetFloat(*Response, field);
217 float *ptr = new float(value);
218 this->MessageLength = sizeof(float);
219 return ptr;
220 }
221 if (field->type() == FieldDescriptor::Type::TYPE_DOUBLE) {
222 double value = reflection->GetDouble(*Response, field);
223 double *ptr = new double(value);
224 this->MessageLength = sizeof(double);
225 return ptr;
226 }
227 if (field->type() == FieldDescriptor::Type::TYPE_STRING) {
228 std::string value = reflection->GetString(*Response, field);
229 std::string *ptr = new std::string(value);
230 this->MessageLength = value.length();
231 return ptr;
232 }
233 if (field->type() == FieldDescriptor::Type::TYPE_BOOL) {
234 bool value = reflection->GetBool(*Response, field);
235 bool *ptr = new bool(value);
236 this->MessageLength = sizeof(bool);
237 return ptr;
238 }
239
240 std::cerr << "Unknown type in protobuf serializer\n";
241 exit(1);
242}
243
245 Request->Clear();
246 Response->Clear();
247}
248} // namespace MLBridge
void setFeature(const std::string &name, const google::protobuf::Message *value) override
void setResponse(void *Response) override
void setRequest(void *Request) override
void cleanDataStructures() override
void * deserializeUntyped(void *data) override
void * getSerializedData() override
Protobuf Serialization/Deserialization to support gRPC communication.