MLCompilerBridge
Tools for streamlining communication with ML models for compiler optimizations.
ONNXModelRunner.h
//=== MLModelRunner/ONNXModelRunner/ONNXModelRunner.h - C++ --------------===//
//
// Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
// Exceptions. See the LICENSE file for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//

#ifndef ONNX_MODELRUNNER_H
#define ONNX_MODELRUNNER_H

namespace MLBridge {

class ONNXModelRunner : public MLModelRunner {
public:
  ONNXModelRunner(MLBridge::Environment *env,
                  std::map<std::string, Agent *> agents,
                  llvm::LLVMContext *Ctx = nullptr);

  void setEnvironment(MLBridge::Environment *_env);
  MLBridge::Environment *getEnvironment();

  void addAgent(Agent *agent, std::string name);

  void requestExit() override {}

private:
  MLBridge::Environment *env;
  std::map<std::string, Agent *> agents;
  void computeAction(Observation &obs);
  void *evaluateUntyped() override;
};
} // namespace MLBridge
#endif // ONNX_MODELRUNNER_H
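The header declares a name-to-agent map that is populated through the constructor and through addAgent(). The following is a minimal sketch of that registration pattern only, not the library's implementation. Everything in the registry_sketch namespace is a hypothetical stand-in: Agent, Runner, findAgent(), numAgents(), and the toy argmax policy are assumptions made for illustration, and std::vector replaces the LLVM-based Observation type so the sketch compiles without LLVM or ONNX.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

namespace registry_sketch {

// Hypothetical stand-in for MLBridge::Observation; std::vector keeps the
// sketch free of LLVM dependencies.
using Observation = std::vector<float>;

// Hypothetical stand-in for MLBridge::Agent. The real class wraps an ONNX
// model; this one applies a toy policy instead.
struct Agent {
  int bias = 0;

  // Toy policy: index of the largest observation entry, shifted by `bias`.
  int computeAction(const Observation &obs) const {
    assert(!obs.empty() && "observation must not be empty");
    std::size_t best = 0;
    for (std::size_t i = 1; i < obs.size(); ++i)
      if (obs[i] > obs[best])
        best = i;
    return static_cast<int>(best) + bias;
  }
};

// Mirrors only the shape of the declarations in the header: agents are held
// by name as raw, non-owning pointers.
class Runner {
public:
  explicit Runner(std::map<std::string, Agent *> initial)
      : agents(std::move(initial)) {}

  // Registers `agent` under `name`, replacing any agent already stored there.
  void addAgent(Agent *agent, std::string name) {
    agents[std::move(name)] = agent;
  }

  // Hypothetical helper: returns the named agent, or nullptr if absent.
  Agent *findAgent(const std::string &name) const {
    auto it = agents.find(name);
    return it == agents.end() ? nullptr : it->second;
  }

  // Hypothetical helper: number of registered agents.
  std::size_t numAgents() const { return agents.size(); }

private:
  std::map<std::string, Agent *> agents;
};

} // namespace registry_sketch
```

Because the map stores raw pointers, the caller in this sketch must keep each Agent alive for as long as the Runner may use it. Whether the real ONNXModelRunner takes ownership of its agents is not stated in the listing.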
MLModelRunner - The main interface for interacting with the ML models.
Agent class to support ML/RL model inference via ONNX.
Agent is a wrapper around the ONNXModel class, interfaces with the Environment class to support ML/RL... (Definition: agent.h:30)
llvm::LLVMContext *Ctx
ONNXModelRunner is the main user-facing class that interfaces with the Environment and Agent classes ...
void *evaluateUntyped() override: Should be implemented by the derived class to call the model and get the result.
void setEnvironment(MLBridge::Environment *_env)
MLBridge::Environment *env
ONNXModelRunner(MLBridge::Environment *env, std::map<std::string, Agent *> agents, llvm::LLVMContext *Ctx = nullptr)
MLBridge::Environment *getEnvironment()
void addAgent(Agent *agent, std::string name)
std::map<std::string, Agent *> agents
void computeAction(Observation &obs)
Base Environment class to support ONNX-based inference of RL models.
llvm::SmallVector<float, 300> Observation (Definition: utils.h:15)
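The note on evaluateUntyped() says a derived class implements it to call the model and return the result through a void pointer. One common way to consume such a hook is a typed wrapper in the base class. The sketch below illustrates that pattern under stated assumptions: ModelRunnerBase, ArgmaxRunner, and the evaluate<T>() wrapper are hypothetical names that do not appear in the listing, and the "model" is a toy argmax over a std::vector, not an ONNX inference call.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

namespace evaluate_sketch {

// Hypothetical stand-in for MLBridge::Observation.
using Observation = std::vector<float>;

// Hypothetical base class. Only requestExit() and evaluateUntyped() appear in
// the listing; evaluate<T>() is an assumed convenience wrapper.
class ModelRunnerBase {
public:
  virtual ~ModelRunnerBase() = default;
  virtual void requestExit() = 0;

  // Runs the model and reads the untyped result back as a T. The caller is
  // responsible for requesting the type the derived class actually produced.
  template <typename T> T evaluate() {
    return *static_cast<T *>(evaluateUntyped());
  }

protected:
  virtual void *evaluateUntyped() = 0;
};

// Toy derived runner: its "model" picks the index of the largest entry.
class ArgmaxRunner : public ModelRunnerBase {
public:
  explicit ArgmaxRunner(Observation obs) : obs(std::move(obs)) {}

  void requestExit() override {}

private:
  void *evaluateUntyped() override {
    std::size_t best = 0;
    for (std::size_t i = 1; i < obs.size(); ++i)
      if (obs[i] > obs[best])
        best = i;
    // Store the result in a member so the returned pointer stays valid after
    // this function returns. An empty observation yields action 0.
    lastAction = static_cast<int>(best);
    return &lastAction;
  }

  Observation obs;
  int lastAction = 0;
};

} // namespace evaluate_sketch
```

The design point in this sketch is lifetime: the pointer returned from evaluateUntyped() must refer to storage that outlives the call, which is why the result lives in a member and not in a local variable.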