MLCompilerBridge
Tools for streamlining communication with ML models for compiler optimizations.
MLBridge::ONNXModelRunner Class Reference

ONNXModelRunner is the main user-facing class that interfaces with the Environment and Agent classes to support ML/RL model inference via ONNXModel.

#include <ONNXModelRunner.h>

Inherits MLBridge::MLModelRunner.

Public Member Functions

 ONNXModelRunner (MLBridge::Environment *env, std::map< std::string, Agent * > agents, llvm::LLVMContext *Ctx=nullptr)
 
void setEnvironment (MLBridge::Environment *_env)
 
MLBridge::Environment * getEnvironment ()
 
void addAgent (Agent *agent, std::string name)
 
void requestExit () override
 
- Public Member Functions inherited from MLBridge::MLModelRunner
 MLModelRunner (const MLModelRunner &)=delete
 
MLModelRunner & operator= (const MLModelRunner &)=delete
 
virtual ~MLModelRunner ()=default
 
template<typename T >
std::enable_if< std::is_fundamental< T >::value, T >::type evaluate ()
 Main user-facing method for interacting with the ML models.
 
template<typename T >
std::enable_if< std::is_fundamental< typename std::remove_pointer< T >::type >::value, void >::type evaluate (T &data, size_t &dataSize)
 Main user-facing method for interacting with the ML models.
 
Kind getKind () const
 
SerDesKind getSerDesKind () const
 
template<typename U , typename T , typename... Types>
void populateFeatures (const std::pair< U, T > &var1, const std::pair< U, Types > &...var2)
 User-facing interface for setting the features to be sent to the model.
 
template<typename U , typename T , typename... Types>
void populateFeatures (const std::pair< U, T > &&var1, const std::pair< U, Types > &&...var2)
 
void populateFeatures ()
 
void setRequest (void *request)
 Mainly used in the case of gRPC where the request object is not known explicitly.
 
void setResponse (void *response)
 Mainly used in the case of gRPC where the response object is not known explicitly.
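The variadic populateFeatures overloads take any number of (name, value) pairs that share the key type U, while the value types may differ. A minimal sketch of that recursion pattern follows; FeatureSink and its recorded member are hypothetical, and the sketch simply stringifies each value instead of handing it to a SerDes backend as the real runner presumably does.

```cpp
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a runner's feature buffer: it records each
// (name, value) pair as strings instead of serializing it for a model.
class FeatureSink {
public:
  // Base case: no pairs left, so the recursion stops here.
  void populateFeatures() {}

  // Recursive case: handle the first pair, then recurse on the rest.
  // Every pair shares the key type U; the value types can all differ.
  template <typename U, typename T, typename... Types>
  void populateFeatures(const std::pair<U, T> &var1,
                        const std::pair<U, Types> &...var2) {
    std::ostringstream value;
    value << var1.second;
    recorded.emplace_back(std::string(var1.first), value.str());
    populateFeatures(var2...);
  }

  std::vector<std::pair<std::string, std::string>> recorded;
};
```

For example, sink.populateFeatures(std::make_pair("loops", 3), std::make_pair("ratio", 0.5)) records two entries. In this sketch the nullary overload terminates the recursion, which appears to be the role of the documented populateFeatures() with no arguments.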
 

Private Member Functions

void computeAction (Observation &obs)
 
void * evaluateUntyped () override
 Should be implemented by the derived class to call the model and get the result.
 

Private Attributes

MLBridge::Environment * env
 
std::map< std::string, Agent * > agents
 

Additional Inherited Members

- Public Types inherited from MLBridge::MLModelRunner
enum class Kind : int { Unknown, Pipe, gRPC, ONNX, TFAOT }
 Type of the MLModelRunner.
 
- Protected Member Functions inherited from MLBridge::MLModelRunner
 MLModelRunner (Kind Type, SerDesKind SerDesType, llvm::LLVMContext *Ctx=nullptr)
 
 MLModelRunner (Kind Type, llvm::LLVMContext *Ctx=nullptr)
 
- Protected Attributes inherited from MLBridge::MLModelRunner
llvm::LLVMContext * Ctx
 
const Kind Type
 
const SerDesKind SerDesType
 
std::unique_ptr< BaseSerDes > SerDes
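Each runner fixes its Kind through the protected MLModelRunner constructors, and getKind() lets callers branch on the backend without RTTI. A minimal sketch of that tagging pattern, assuming simplified stand-ins: RunnerBase, OnnxRunnerSketch and kindName are hypothetical names, and only the Kind enumerators are taken from this reference.

```cpp
#include <string>

// Same enumerators as MLBridge::MLModelRunner::Kind.
enum class Kind : int { Unknown, Pipe, gRPC, ONNX, TFAOT };

// Hypothetical simplified base: the derived class fixes the kind at
// construction time through a protected constructor.
class RunnerBase {
public:
  virtual ~RunnerBase() = default;
  Kind getKind() const { return Type; }

protected:
  explicit RunnerBase(Kind Type) : Type(Type) {}
  const Kind Type;
};

// Hypothetical derived runner that tags itself as an ONNX runner.
class OnnxRunnerSketch : public RunnerBase {
public:
  OnnxRunnerSketch() : RunnerBase(Kind::ONNX) {}
};

// Branch on the backend using only the stored tag.
std::string kindName(const RunnerBase &R) {
  switch (R.getKind()) {
  case Kind::Pipe:
    return "pipe";
  case Kind::gRPC:
    return "grpc";
  case Kind::ONNX:
    return "onnx";
  case Kind::TFAOT:
    return "tfaot";
  case Kind::Unknown:
    break;
  }
  return "unknown";
}
```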
 

Detailed Description

ONNXModelRunner is the main user-facing class that interfaces with the Environment and Agent classes to support ML/RL model inference via ONNXModel.

Definition at line 42 of file ONNXModelRunner.h.

Constructor & Destructor Documentation

◆ ONNXModelRunner()

MLBridge::ONNXModelRunner::ONNXModelRunner ( MLBridge::Environment * env,
std::map< std::string, Agent * > agents,
llvm::LLVMContext * Ctx = nullptr )
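The constructor takes the environment plus a map from agent name to Agent *. Because both are raw pointers, the caller presumably keeps the environment and agents alive for as long as the runner is used; the reference does not state ownership, so treat that as an assumption. The sketch below uses hypothetical stub types (Environment, Agent, LLVMContext, RunnerSketch) with the same constructor shape, including the defaulted Ctx argument.

```cpp
#include <cstddef>
#include <map>
#include <string>
#include <utility>

// Hypothetical stubs for MLBridge::Environment, Agent and llvm::LLVMContext.
struct Environment {};
struct Agent {};
struct LLVMContext {};

// Hypothetical runner with the same constructor shape as ONNXModelRunner.
class RunnerSketch {
public:
  RunnerSketch(Environment *env, std::map<std::string, Agent *> agents,
               LLVMContext *Ctx = nullptr)
      : env(env), agents(std::move(agents)), Ctx(Ctx) {}

  Environment *getEnvironment() { return env; }
  LLVMContext *getContext() const { return Ctx; }
  std::size_t agentCount() const { return agents.size(); }

private:
  Environment *env;                       // borrowed, not owned (assumed)
  std::map<std::string, Agent *> agents;  // name -> borrowed agent (assumed)
  LLVMContext *Ctx;                       // optional; nullptr by default
};
```

For example, RunnerSketch runner(&env, {{"inliner", &inlinerAgent}}); leaves Ctx as nullptr, where inlinerAgent is a hypothetical Agent object.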

Member Function Documentation

◆ addAgent()

void ONNXModelRunner::addAgent ( Agent * agent,
std::string name )

Definition at line 26 of file ONNXModelRunner.cpp.
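addAgent registers an agent under a name after construction. This reference does not say what happens when the name is already present, so the sketch below shows both plausible std::map behaviors: insert_or_assign replaces the earlier agent, while insert keeps it. Agent, AgentMap and both helper functions are hypothetical.

```cpp
#include <map>
#include <string>

struct Agent {
  int id;  // hypothetical stand-in for MLBridge::Agent
};

using AgentMap = std::map<std::string, Agent *>;

// Replace-on-duplicate: a second agent under the same name overwrites
// the first (requires C++17 for insert_or_assign).
void addAgentReplace(AgentMap &agents, Agent *agent, const std::string &name) {
  agents.insert_or_assign(name, agent);
}

// Keep-first: insert() leaves an existing entry untouched.
// Returns true only if the agent was actually added.
bool addAgentKeepFirst(AgentMap &agents, Agent *agent, const std::string &name) {
  return agents.insert({name, agent}).second;
}
```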

◆ computeAction()

void ONNXModelRunner::computeAction ( Observation & obs)
private

Definition at line 37 of file ONNXModelRunner.cpp.
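computeAction is private and has no description here. Judging from the class summary, it likely drives the Environment/Agent interaction for an observation; the sketch below shows one such multi-agent loop under that assumption. The environment methods getNextAgent, step and isDone, the Observation and Action aliases, and runAgents are all hypothetical, not the library's API.

```cpp
#include <map>
#include <string>
#include <vector>

using Observation = std::vector<float>;  // hypothetical feature vector
using Action = int;                      // hypothetical discrete action

// Hypothetical agent: maps an observation to an action.
struct Agent {
  virtual ~Agent() = default;
  virtual Action computeAction(const Observation &obs) = 0;
};

// Hypothetical environment: names the agent that acts next, applies an
// action, and reports when the episode is finished.
struct Environment {
  virtual ~Environment() = default;
  virtual std::string getNextAgent() = 0;
  virtual Observation step(Action action) = 0;
  virtual bool isDone() const = 0;
};

// Let whichever agent the environment selects act on the latest
// observation until the environment reports completion.
void runAgents(Environment &env, const std::map<std::string, Agent *> &agents,
               Observation obs) {
  while (!env.isDone()) {
    // at() throws std::out_of_range if the environment names an agent
    // that was never registered.
    Agent *agent = agents.at(env.getNextAgent());
    obs = env.step(agent->computeAction(obs));
  }
}
```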

◆ evaluateUntyped()

void * ONNXModelRunner::evaluateUntyped ( )
override private virtual

Should be implemented by the derived class to call the model and get the result.

Implements MLBridge::MLModelRunner.

Definition at line 51 of file ONNXModelRunner.cpp.
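The typed evaluate<T>() in MLModelRunner presumably forwards to this virtual hook, which returns the model's result as void *. A minimal sketch of that type-erasure pattern, assuming the base casts the pointer back to T and copies the value out; UntypedRunnerBase and ConstantRunner are hypothetical, and who owns the pointer returned by the real ONNXModelRunner is not stated in this reference.

```cpp
#include <type_traits>

// Hypothetical base: the public typed entry point forwards to a private
// untyped virtual hook and reinterprets its result.
class UntypedRunnerBase {
public:
  virtual ~UntypedRunnerBase() = default;

  template <typename T>
  typename std::enable_if<std::is_fundamental<T>::value, T>::type evaluate() {
    return *static_cast<T *>(evaluateUntyped());
  }

private:
  // Derived classes run the model and return a pointer to the result.
  virtual void *evaluateUntyped() = 0;
};

// Hypothetical derived runner that always "predicts" 0.75.
class ConstantRunner : public UntypedRunnerBase {
private:
  void *evaluateUntyped() override {
    result = 0.75f;
    return &result;  // points into the runner, so no heap allocation
  }
  float result = 0.0f;
};
```

The caller must request the same T that the derived class stored: calling evaluate<int>() on ConstantRunner would read a float through an int pointer, which is undefined behavior.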

◆ getEnvironment()

MLBridge::Environment * MLBridge::ONNXModelRunner::getEnvironment ( )
inline

Definition at line 49 of file ONNXModelRunner.h.

◆ requestExit()

void MLBridge::ONNXModelRunner::requestExit ( )
inline override virtual

Implements MLBridge::MLModelRunner.

Definition at line 53 of file ONNXModelRunner.h.

◆ setEnvironment()

void MLBridge::ONNXModelRunner::setEnvironment ( MLBridge::Environment * _env)
inline

Definition at line 48 of file ONNXModelRunner.h.

Member Data Documentation

◆ agents

std::map<std::string, Agent *> MLBridge::ONNXModelRunner::agents
private

Definition at line 57 of file ONNXModelRunner.h.

◆ env

MLBridge::Environment* MLBridge::ONNXModelRunner::env
private

Definition at line 56 of file ONNXModelRunner.h.


The documentation for this class was generated from the following files:
ONNXModelRunner.h
ONNXModelRunner.cpp