18#include "llvm/ADT/SmallVector.h"
29 std::map<std::string, Agent *>
agents;
40 float *(*resetFunc)();
64 "Step function is null! Define step function on env");
66 return llvm::SmallVector<float, 100>(stepRes, stepRes + numFeatures);
71 "Reset function is null! Define reset function on env");
73 return llvm::SmallVector<float, 100>(resetRes, resetRes + numFeatures);
90 env->setNumFeatures(numFeatures);
94 env->setStepFunc(stepFunc);
98 env->setResetFunc(resetFunc);
106 assert(env !=
nullptr &&
"Environment is null!");
109 va_start(args, numAgents);
110 std::map<std::string, Agent *> agents;
112 for (
int i = 0; i < numAgents; i += 2) {
113 char *agentName = va_arg(args,
char *);
114 char *agentPath = va_arg(args,
char *);
115 agents[agentName] =
new Agent(agentPath);
137 Agent *current_agent = omr->agent;
This file defines the C APIs for ONNXModelRunner.
struct Environment Environment
struct ONNXModelRunner ONNXModelRunner
This file defines the debug macros for the MLCompilerBridge.
#define MLBRIDGE_DEBUG(X)
ONNXModelRunner * createSingleAgentOMR(char *agent_path)
ONNXModelRunner * createONNXModelRunner(Environment *env, int numAgents,...)
void env_setStepFunc(Environment *env, float *(*stepFunc)(Action action))
void destroyONNXModelRunner(ONNXModelRunner *omr)
Environment * createEnvironment()
void evaluate(ONNXModelRunner *omr)
void env_setResetFunc(Environment *env, float *(*resetFunc)())
void env_setNextAgent(Environment *env, char *agentName)
bool env_checkDone(Environment *env)
void env_setDone(Environment *env)
void env_resetDone(Environment *env)
void env_setNumFeatures(Environment *env, int numFeatures)
int singleAgentEvaluate(ONNXModelRunner *obj, float *inp, int inp_size)
void destroyEnvironment(Environment *env)
Agent class to support ML/RL model inference via ONNX.
Agent is a wrapper around the ONNXModel class, interfaces with the Environment class to support ML/RL...
unsigned computeAction(Observation &obs)
Runs the ONNX model on the input Observation and returns the output.
void setDone()
SetDone sets the termination condition to true.
bool checkDone()
CheckDone returns true if the termination condition is met at the end of the episode.
void setNextAgent(std::string name)
SetNextAgent sets the name of the next agent to interact with.
virtual Observation & step(Action action)=0
Step function takes an action as input and returns the observation corresponding to the next state.
virtual Observation & reset()=0
Reset function returns the initial observation.
ONNXModelRunner is the main user facing class that interfaces with the Environment and Agent classes ...
MLBridge::Environment * env
std::map< std::string, Agent * > agents
void computeAction(Observation &obs)
llvm::SmallVector< float, 300 > Observation
float *(* stepFunc)(Action action)
Observation step(Action action)
void setNextAgent(char *agentName)
void setStepFunc(float *(*stepFunc)(Action action))
void setResetFunc(float *(*resetFunc)())
void setNumFeatures(int numFeatures)
std::string getNextAgent()
std::map< std::string, Agent * > agents
ONNXModelRunner(Agent *agent)
ONNXModelRunner(Environment *env, std::map< std::string, Agent * > &&agents)