MLCompilerBridge
Tools for streamlining communication with ML models for compiler optimizations.
Loading...
Searching...
No Matches
MLModelRunner.h
Go to the documentation of this file.
1//===- MLModelRunner.h ---- ML model runner interface -----------*- C++ -*-===//
2//
3// Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
// (Preliminary version adapted from MLModelRunner.h of LLVM 17.X)
8//
9//===----------------------------------------------------------------------===//
34//===----------------------------------------------------------------------===//
35
36#ifndef ML_MODEL_RUNNER_H
37#define ML_MODEL_RUNNER_H
38
39#include "SerDes/baseSerDes.h"
41#include "SerDes/jsonSerDes.h"
42
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <future>
#include <memory>
#include <string>
#include <type_traits>
48
49#ifndef C_LIBRARY
52#endif
53namespace MLBridge {
54
57public:
58 // Disallows copy and assign.
59 MLModelRunner(const MLModelRunner &) = delete;
61 virtual ~MLModelRunner() = default;
62
64 template <typename T>
65 typename std::enable_if<std::is_fundamental<T>::value, T>::type evaluate() {
66 return *reinterpret_cast<T *>(evaluateUntyped());
67 }
68
70 template <typename T>
71 typename std::enable_if<
72 std::is_fundamental<typename std::remove_pointer<T>::type>::value,
73 void>::type
74 evaluate(T &data, size_t &dataSize) {
75 using BaseType = typename std::remove_pointer<T>::type;
76 void *res = evaluateUntyped();
77 T ret = static_cast<T>(malloc(SerDes->getMessageLength()));
78 memcpy(ret, res, SerDes->getMessageLength());
79 dataSize = SerDes->getMessageLength() / sizeof(BaseType);
80 data = ret;
81 }
82
84 enum class Kind : int { Unknown, Pipe, gRPC, ONNX, TFAOT };
85
86 Kind getKind() const { return Type; }
88
89 virtual void requestExit() = 0;
90
95
96 template <typename U, typename T, typename... Types>
97 void populateFeatures(const std::pair<U, T> &var1,
98 const std::pair<U, Types> &...var2) {
99 SerDes->setFeature(var1.first, var1.second);
100 populateFeatures(var2...);
101 }
102
103 template <typename U, typename T, typename... Types>
104 void populateFeatures(const std::pair<U, T> &&var1,
105 const std::pair<U, Types> &&...var2) {
106 SerDes->setFeature(var1.first, var1.second);
107 populateFeatures(var2...);
108 }
109
111
114 void setRequest(void *request) { SerDes->setRequest(request); }
115
118 void setResponse(void *response) { SerDes->setResponse(response); }
119
120protected:
122 llvm::LLVMContext *Ctx = nullptr)
124 assert(Type != Kind::Unknown);
125 initSerDes();
126 }
127
128 MLModelRunner(Kind Type, llvm::LLVMContext *Ctx = nullptr)
130 SerDes = nullptr;
131 };
132
135 virtual void *evaluateUntyped() = 0;
136
137 llvm::LLVMContext *Ctx;
138 const Kind Type;
140
141protected:
142 std::unique_ptr<BaseSerDes> SerDes;
143
144private:
145 void initSerDes() {
146 switch (SerDesType) {
147 case SerDesKind::Json:
148 SerDes = std::make_unique<JsonSerDes>();
149 break;
151 SerDes = std::make_unique<BitstreamSerDes>();
152 break;
153#ifndef C_LIBRARY
155 SerDes = std::make_unique<ProtobufSerDes>();
156 break;
158 SerDes = std::make_unique<TensorflowSerDes>();
159 break;
160#endif
162 SerDes = nullptr;
163 break;
164 }
165 }
166};
167} // namespace MLBridge
168
#endif // ML_MODEL_RUNNER_H
Supporting new SerDes:
Bitstream Serialization/Deserialization which sends header information followed by the raw data.
MLModelRunner - The main interface for interacting with the ML models.
const SerDesKind SerDesType
std::unique_ptr< BaseSerDes > SerDes
SerDesKind getSerDesKind() const
Kind
Type of the MLModelRunner.
void setRequest(void *request)
Mainly used in the case of gRPC where the request object is not known explicitly.
void populateFeatures(const std::pair< U, T > &var1, const std::pair< U, Types > &...var2)
User-facing interface for setting the features to be sent to the model.
MLModelRunner(Kind Type, llvm::LLVMContext *Ctx=nullptr)
std::enable_if< std::is_fundamental< T >::value, T >::type evaluate()
Main user-facing method for interacting with the ML models.
MLModelRunner(Kind Type, SerDesKind SerDesType, llvm::LLVMContext *Ctx=nullptr)
MLModelRunner(const MLModelRunner &)=delete
void populateFeatures(const std::pair< U, T > &&var1, const std::pair< U, Types > &&...var2)
llvm::LLVMContext * Ctx
virtual ~MLModelRunner()=default
std::enable_if< std::is_fundamental< typename std::remove_pointer< T >::type >::value, void >::type evaluate(T &data, size_t &dataSize)
Main user-facing method for interacting with the ML models.
MLModelRunner & operator=(const MLModelRunner &)=delete
virtual void * evaluateUntyped()=0
Should be implemented by the derived class to call the model and get the result.
void setResponse(void *response)
Mainly used in the case of gRPC where the response object is not known explicitly.
virtual void requestExit()=0
Json Serialization/Deserialization using LLVM's json library.
SerDesKind
This is the base class for SerDes.
Definition baseSerDes.h:46
Protobuf Serialization/Deserialization to support gRPC communication.
Serialization/Deserialization to support TF AOT models.