MLCompilerBridge
Tools for streamlining communication with ML models for compiler optimizations.
Loading...
Searching...
No Matches
ONNXModelRunnerCWrapper.cpp
Go to the documentation of this file.
//=== MLModelRunner/C/ONNXModelRunnerCWrapper.cpp - C API for ONNXModelRunner ===//
2//
3// Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
12//===----------------------------------------------------------------------===//
13
18#include "llvm/ADT/SmallVector.h"
19#include <cassert>
20#include <iostream>
21#include <map>
22#include <stdarg.h>
23#include <vector>
24
25using namespace MLBridge;
29 std::map<std::string, Agent *> agents;
30 ONNXModelRunner(Environment *env, std::map<std::string, Agent *> &&agents)
31 : env(env), agents(agents) {}
32
33 ONNXModelRunner(Agent *agent) : agent(agent) {}
34};
35
37private:
38 // Function pointer to the step and reset functions
39 float *(*stepFunc)(Action action);
40 float *(*resetFunc)();
42 std::string nextAgent;
43 bool done;
44
45public:
46 Environment() : stepFunc(nullptr), resetFunc(nullptr) {}
47 // EnvironmentImpl(float *(*stepFunc)(Action action), float *(*resetFunc)())
48 // : stepFunc(stepFunc), resetFunc(resetFunc) {}
49
50 void setNumFeatures(int numFeatures) { this->numFeatures = numFeatures; }
51
52 void setStepFunc(float *(*stepFunc)(Action action)) {
53 this->stepFunc = stepFunc;
54 }
55
56 void setResetFunc(float *(*resetFunc)()) { this->resetFunc = resetFunc; }
57
58 void setNextAgent(char *agentName) { nextAgent = agentName; }
59
60 std::string getNextAgent() { return nextAgent; }
61
63 assert(stepFunc != nullptr &&
64 "Step function is null! Define step function on env");
65 float *stepRes = stepFunc(action);
66 return llvm::SmallVector<float, 100>(stepRes, stepRes + numFeatures);
67 }
68
70 assert(resetFunc != nullptr &&
71 "Reset function is null! Define reset function on env");
72 float *resetRes = resetFunc();
73 return llvm::SmallVector<float, 100>(resetRes, resetRes + numFeatures);
74 }
75
76 bool checkDone() { return done; }
77 void setDone() { done = true; }
78 void resetDone() { done = false; }
79};
80
82
83void env_setDone(Environment *env) { env->setDone(); }
84
85void env_resetDone(Environment *env) { env->resetDone(); }
86
87bool env_checkDone(Environment *env) { return env->checkDone(); }
88
89void env_setNumFeatures(Environment *env, int numFeatures) {
90 env->setNumFeatures(numFeatures);
91}
92
93void env_setStepFunc(Environment *env, float *(*stepFunc)(Action action)) {
94 env->setStepFunc(stepFunc);
95}
96
97void env_setResetFunc(Environment *env, float *(*resetFunc)()) {
98 env->setResetFunc(resetFunc);
99}
100
101void env_setNextAgent(Environment *env, char *agentName) {
102 env->setNextAgent(agentName);
103}
104
106 assert(env != nullptr && "Environment is null!");
107
108 va_list args;
109 va_start(args, numAgents);
110 std::map<std::string, Agent *> agents;
111
112 for (int i = 0; i < numAgents; i += 2) {
113 char *agentName = va_arg(args, char *);
114 char *agentPath = va_arg(args, char *);
115 agents[agentName] = new Agent(agentPath);
116 }
117
118 va_end(args);
119
120 ONNXModelRunner *obj = new ONNXModelRunner(env, std::move(agents));
121 return obj;
122}
123
125 Agent *agent = new Agent(agent_path);
126 ONNXModelRunner *obj = new ONNXModelRunner(agent);
127 return obj;
128}
129
131 auto x = omr->env->reset();
132
133 while (true) {
134 Action action;
135 // current agent
136 // auto current_agent = omr->agents[omr->env->getNextAgent()];
137 Agent *current_agent = omr->agent;
138 action = current_agent->computeAction(x);
139 MLBRIDGE_DEBUG(std::cout << "Action: " << action << "\n");
140 x = omr->env->step(action);
141 if (omr->env->checkDone()) {
142 MLBRIDGE_DEBUG(std::cout << "Done🎉\n");
143 break;
144 }
145 }
146}
147
148int singleAgentEvaluate(ONNXModelRunner *obj, float *inp, int inp_size) {
149 Observation obs(inp, inp + inp_size);
150 Action action = obj->agent->computeAction(obs);
151 MLBRIDGE_DEBUG(std::cout << "action :: " << action << "\n");
152 return action;
153}
154
155void destroyEnvironment(Environment *env) { delete env; }
156
157void destroyONNXModelRunner(ONNXModelRunner *omr) { delete omr; }
This file defines the C APIs for ONNXModelRunner.
struct Environment Environment
struct ONNXModelRunner ONNXModelRunner
signed Action
This file defines the debug macros for the MLCompilerBridge.
#define MLBRIDGE_DEBUG(X)
Definition Debug.h:25
ONNXModelRunner * createSingleAgentOMR(char *agent_path)
ONNXModelRunner * createONNXModelRunner(Environment *env, int numAgents,...)
void env_setStepFunc(Environment *env, float *(*stepFunc)(Action action))
void destroyONNXModelRunner(ONNXModelRunner *omr)
Environment * createEnvironment()
void evaluate(ONNXModelRunner *omr)
void env_setResetFunc(Environment *env, float *(*resetFunc)())
void env_setNextAgent(Environment *env, char *agentName)
bool env_checkDone(Environment *env)
void env_setDone(Environment *env)
void env_resetDone(Environment *env)
void env_setNumFeatures(Environment *env, int numFeatures)
int singleAgentEvaluate(ONNXModelRunner *obj, float *inp, int inp_size)
void destroyEnvironment(Environment *env)
Agent class to support ML/RL model inference via ONNX.
Agent is a wrapper around the ONNXModel class, interfaces with the Environment class to support ML/RL...
Definition agent.h:30
unsigned computeAction(Observation &obs)
Runs the ONNX model on the input Observation and returns the output.
Definition agent.cpp:28
void setDone()
SetDone sets the termination condition to true.
Definition environment.h:69
bool checkDone()
CheckDone returns true if the termination condition is met at the end of the episode.
Definition environment.h:66
void setNextAgent(std::string name)
SetNextAgent sets the name of the next agent to interact with.
Definition environment.h:76
virtual Observation & step(Action action)=0
Step function takes an action as input and returns the observation corresponding to the next state.
virtual Observation & reset()=0
Reset function returns the initial observation.
std::string nextAgent
Definition environment.h:58
ONNXModelRunner is the main user facing class that interfaces with the Environment and Agent classes ...
MLBridge::Environment * env
std::map< std::string, Agent * > agents
void computeAction(Observation &obs)
llvm::SmallVector< float, 300 > Observation
Definition utils.h:15
float *(* stepFunc)(Action action)
Observation step(Action action)
void setNextAgent(char *agentName)
void setStepFunc(float *(*stepFunc)(Action action))
void setResetFunc(float *(*resetFunc)())
void setNumFeatures(int numFeatures)
std::string getNextAgent()
std::map< std::string, Agent * > agents
ONNXModelRunner(Environment *env, std::map< std::string, Agent * > &&agents)