MLCompilerBridge
Tools for streamlining communication with ML models for compiler optimizations.
Loading...
Searching...
No Matches
ONNXModelRunner.cpp
Go to the documentation of this file.
1//===- ONNXModelRunner.cpp - ONNX Runner ------------------------*- C++ -*-===//
2//
3// Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
16#include "SerDes/baseSerDes.h"
17
18using namespace llvm;
19namespace MLBridge {
20
22 std::map<std::string, Agent *> agents,
23 LLVMContext *Ctx)
24 : MLModelRunner(Kind::ONNX, Ctx), env(env), agents(agents) {}
25
26void ONNXModelRunner::addAgent(Agent *agent, std::string name) {
27 if (agents.find(name) == agents.end()) {
28 agents[name] = agent;
29 } else {
30 // throw error
31 std::cerr << "ERROR: Agent with the name " << name
32 << " already exists. Please give a different name!\n";
33 exit(1);
34 }
35}
36
37void ONNXModelRunner::computeAction(Observation &obs) {
38 while (true) {
39 Action action;
40 // current agent
41 auto current_agent = this->agents[this->env->getNextAgent()];
42 action = current_agent->computeAction(obs);
43 this->env->step(action);
44 if (this->env->checkDone()) {
45 std::cout << "Done🎉\n";
46 break;
47 }
48 }
49}
50
51void *ONNXModelRunner::evaluateUntyped() {
52 Observation &obs = env->reset();
53 computeAction(obs);
54 return new int(0);
55}
56
57} // namespace MLBridge
Action: type alias for `signed`.
The ONNXModelRunner class supports communication via the ONNX C++ Runtime.
Supporting new SerDes:
Agent is a wrapper around the ONNXModel class that interfaces with the Environment class to support ML/RL...
Definition agent.h:30
bool checkDone()
CheckDone returns true if the termination condition is met at the end of the episode.
Definition environment.h:66
virtual Observation & step(Action action)=0
The step function takes an action as input and returns the observation corresponding to the next state.
virtual Observation & reset()=0
Reset function returns the initial observation.
std::string getNextAgent()
GetNextAgent returns the name/ID of the next agent to interact with.
Definition environment.h:73
llvm::SmallVector< float, 300 > Observation
Definition utils.h:15
std::map< std::string, Agent * > agents
ONNXModelRunner(Environment *env, std::map< std::string, Agent * > &&agents)