MLCompilerBridge
Tools for streamlining communication with ML models for compiler optimizations.
Loading...
Searching...
No Matches
tensorflowSerDes.h
Go to the documentation of this file.
1//=== SerDes/tensorflowSerDes.h - SerDes for TF support ---*- C++ ---------===//
2//
3// Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
12//===----------------------------------------------------------------------===//
13
14#ifndef TENSORFLOW_SERIALIZER_H
15#define TENSORFLOW_SERIALIZER_H
16
17#include "SerDes/baseSerDes.h"
18#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
19
20namespace MLBridge {
22class TensorflowSerDes : public BaseSerDes {
23public:
25
// Kind check: returns true when S->getKind() equals SerDesKind::Tensorflow.
// S is dereferenced without a null check, so callers must pass a non-null
// pointer.
26 static bool classof(const BaseSerDes *S) {
27 return S->getKind() == SerDesKind::Tensorflow;
28 }
29
// SET_FEATURE(TYPE, _) declares two setFeature() overrides for one TYPE: a
// scalar overload and a std::vector<TYPE> overload. The second macro argument
// is unused. The macro is handed to SUPPORTED_TYPES (from baseSerDes.h), which
// presumably expands it once per supported feature type (confirm against
// baseSerDes.h), and is undefined again right after so it does not leak out of
// this header.
30#define SET_FEATURE(TYPE, _) \
31 void setFeature(const std::string &, const TYPE) override; \
32 void setFeature(const std::string &, const std::vector<TYPE> &) override;
33 SUPPORTED_TYPES(SET_FEATURE)
34#undef SET_FEATURE
35
// Stores the caller-supplied request in CompiledModel. `request` arrives as an
// opaque void * and is reinterpreted as tensorflow::XlaCompiledCpuFunction *;
// no type or null check is performed here, so the caller must pass a pointer
// of that type.
36 void setRequest(void *request) override {
37 CompiledModel =
38 reinterpret_cast<tensorflow::XlaCompiledCpuFunction *>(request);
39 }
40
// Override that always returns nullptr; this class produces no serialized
// buffer here.
41 void *getSerializedData() override { return nullptr; };
// No-op override: the body is empty.
42 void cleanDataStructures() override{};
43
44private:
// Override that ignores `data` and always returns nullptr.
45 void *deserializeUntyped(void *data) override { return nullptr; };
// Raw pointer to a tensorflow::XlaCompiledCpuFunction (type from the tf2xla
// header included above). It has no in-class initializer.
// NOTE(review): presumably the model handle supplied through setRequest() and
// not owned by this class -- confirm.
46 tensorflow::XlaCompiledCpuFunction *CompiledModel;
47};
48} // namespace MLBridge
49
50#endif
Supporting new SerDes:
#define SUPPORTED_TYPES(M)
Definition baseSerDes.h:32
#define SET_FEATURE(TYPE, _)
setFeature() is used to set the features of the data structure used for communication.
Definition baseSerDes.h:54
SerDesKind getKind() const
Definition baseSerDes.h:49
TensorflowSerDes - Serialization/Deserialization to support TF AOT models.
void setRequest(void *request) override
static bool classof(const BaseSerDes *S)
void cleanDataStructures() override
void * deserializeUntyped(void *data) override
tensorflow::XlaCompiledCpuFunction * CompiledModel
void * getSerializedData() override
SerDesKind
This is the base class for SerDes.
Definition baseSerDes.h:46