MLCompilerBridge
Tools for streamlining communication with ML models for compiler optimizations.
Loading...
Searching...
No Matches
TFModelRunner.h
Go to the documentation of this file.
1//===- TFModelRunner.h ---- TF precompiled model runner ------*- C++-*----===//
2//
3// Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (Preliminary version adopted from ReleaseModeModelRunner.h of LLVM 17.X)
8//
9//===----------------------------------------------------------------------===//
15//===----------------------------------------------------------------------===//
16
17#ifndef TFMODELRUNNER_H
18#define TFMODELRUNNER_H
19
21#include "SerDes/TensorSpec.h"
22
23#include <memory>
24#include <vector>
25
26namespace MLBridge {
27
/// TFModelRunner - TF Compiled model implementation of the MLModelRunner.
///
/// TGen is the model class to evaluate. The code below calls its
/// LookupResultIndex(std::string), Run() and result_data(int) members;
/// NoopSavedModelImpl further down in this header is a stub with that shape.
30template <class TGen> class TFModelRunner final : public MLModelRunner {
31public:
 /// Default-constructs a TGen model and resolves the model output named
 /// FetchPrefix + DecisionName, storing its index in ResultIndex.
 /// \p Ctx is presumably forwarded to the MLModelRunner base (the initializer
 /// that would use it is not visible here) -- confirm.
 /// NOTE(review): FeedPrefix is not referenced in the lines visible here.
34 TFModelRunner(llvm::StringRef DecisionName, llvm::LLVMContext &Ctx,
35 llvm::StringRef FeedPrefix = "feed_",
36 llvm::StringRef FetchPrefix = "fetch_")
 // NOTE(review): no base-class initializer (a leading ": MLModelRunner(...)")
 // appears before CompiledModel(...); that line seems to be missing from this
 // listing -- confirm against the real header.
38 CompiledModel(std::make_unique<TGen>()) {
39
41
42 assert(CompiledModel && "The CompiledModel should be valid");
43
 // Both checks in this constructor are assert()s and compile out under
 // NDEBUG: a failed lookup then leaves the (negative) value returned by
 // LookupResultIndex in ResultIndex, and evaluateUntyped() passes it to
 // result_data() unchecked.
 // NOTE(review): "inlining model" in the message below looks inherited from
 // LLVM's inliner model runner; the lookup itself is generic.
44 ResultIndex = CompiledModel->LookupResultIndex(FetchPrefix.str() +
45 DecisionName.str());
46 assert(ResultIndex >= 0 && "Cannot find DecisionName in inlining model");
47 }
 /// Same visible body as the constructor above, but without an
 /// llvm::LLVMContext argument.
 /// NOTE(review): the two constructor bodies are duplicated; consider
 /// delegating or sharing a private init helper.
48 TFModelRunner(llvm::StringRef DecisionName,
49 llvm::StringRef FeedPrefix = "feed_",
50 llvm::StringRef FetchPrefix = "fetch_")
 // NOTE(review): as above, the base-class initializer line is not visible
 // in this listing.
52 CompiledModel(std::make_unique<TGen>()) {
53
55
56 assert(CompiledModel && "The CompiledModel should be valid");
57
58 ResultIndex = CompiledModel->LookupResultIndex(FetchPrefix.str() +
59 DecisionName.str());
60 assert(ResultIndex >= 0 && "Cannot find DecisionName in inlining model");
61 }
62 virtual ~TFModelRunner() = default;
63
 /// Not supported by this runner: always aborts via llvm_unreachable.
64 virtual void requestExit() override {
65 llvm_unreachable("requestExit() is not supported in TFModelRunner");
66 }
67
 /// LLVM-style RTTI hook taking an MLModelRunner pointer.
 /// NOTE(review): the return statement of this function is missing from this
 /// listing -- confirm against the real header.
68 static bool classof(const MLModelRunner *R) {
70 }
71
72private:
 /// Runs the compiled model once and returns result_data(ResultIndex) as an
 /// untyped pointer; no bounds check on ResultIndex happens here.
73 void *evaluateUntyped() override {
74 CompiledModel->Run();
75 return CompiledModel->result_data(ResultIndex);
76 }
77
 // Index of the fetched model output; -1 until a constructor assigns the
 // result of LookupResultIndex.
78 int32_t ResultIndex = -1;
 // Owning pointer to the model instance, created in the constructors.
79 std::unique_ptr<TGen> CompiledModel;
80};
81
85class NoopSavedModelImpl final {
86#define NOOP_MODEL_ERRMSG \
87 "The mock AOT-ed saved model is a compile-time stub and should not be " \
88 "called."
89
90public:
91 NoopSavedModelImpl() = default;
92 int LookupArgIndex(const std::string &) {
93 llvm_unreachable(NOOP_MODEL_ERRMSG);
94 }
95 int LookupResultIndex(const std::string &) {
96 llvm_unreachable(NOOP_MODEL_ERRMSG);
97 }
98 void Run() { llvm_unreachable(NOOP_MODEL_ERRMSG); }
99 void *result_data(int) { llvm_unreachable(NOOP_MODEL_ERRMSG); }
100 void *arg_data(int) { llvm_unreachable(NOOP_MODEL_ERRMSG); }
101#undef NOOP_MODEL_ERRMSG
102};
103
104template <class T> bool isEmbeddedModelEvaluatorValid() { return true; }
105
107 return false;
108}
109} // namespace MLBridge
110
111#endif // TFMODELRUNNER_H
The MLModelRunner class is the main interface for interacting with the ML models.
#define NOOP_MODEL_ERRMSG
MLModelRunner - The main interface for interacting with the ML models.
std::unique_ptr< BaseSerDes > SerDes
Kind
Type of the MLModelRunner.
void setRequest(void *request)
Mainly used in the case of gRPC where the request object is not known explicitly.
llvm::LLVMContext * Ctx
A mock class satisfying the interface that TFModelRunner (adapted from LLVM's ReleaseModeModelRunner) expects of its TGen parameter.
int LookupResultIndex(const std::string &)
int LookupArgIndex(const std::string &)
TFModelRunner - TF Compiled model implementation of the MLModelRunner.
void * evaluateUntyped() override
Should be implemented by the derived class to call the model and get the result.
TFModelRunner(llvm::StringRef DecisionName, llvm::StringRef FeedPrefix="feed_", llvm::StringRef FetchPrefix="fetch_")
TFModelRunner(llvm::StringRef DecisionName, llvm::LLVMContext &Ctx, llvm::StringRef FeedPrefix="feed_", llvm::StringRef FetchPrefix="fetch_")
FeatureNames' type should be an indexed collection of std::string, like std::array or std::vector,...
virtual void requestExit() override
virtual ~TFModelRunner()=default
static bool classof(const MLModelRunner *R)
std::unique_ptr< TGen > CompiledModel
bool isEmbeddedModelEvaluatorValid()
bool isEmbeddedModelEvaluatorValid< NoopSavedModelImpl >()
SerDesKind
This is the base class for SerDes.
Definition baseSerDes.h:46