MLCompilerBridge
Tools for streamlining communication with ML models for compiler optimizations.
Loading...
Searching...
No Matches
agent.cpp
Go to the documentation of this file.
1//===- agent.cpp - RL Agent/Model for ONNX Runner --------------*- C++ -*-===//
2//
3// Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
17#include <algorithm>
18#include <cmath>
19#include <iterator>
20
21namespace MLBridge {
22
23Agent::Agent(std::string modelPath) {
24 // Create model object here
25 this->model = new ONNXModel(modelPath.c_str());
26}
27
// NOTE(review): this listing appears to have lost source line 28 (the
// signature of Agent::computeAction, cross-referenced below as
// `unsigned computeAction(Observation &obs)`) and source line 41
// (presumably the opening `MLBRIDGE_DEBUG(` that the `);` on line 45 closes).
//
// Agent::computeAction: runs the ONNX model on the observation `input` and
// returns the index of the largest model output, i.e. the selected action.
// Ties resolve to the lowest index, because std::max_element returns the
// first maximal element.
29 // Call model on input
// The assert is the only empty-input guard and is compiled out under NDEBUG.
30 assert(input.size() > 0);
// Copy the Observation (SmallVector<float, 300>) into the
// SmallVector<float, 100> type that ONNXModel::run takes.
31 llvm::SmallVector<float, 100> model_input(input.begin(), input.end());
32 llvm::SmallVector<float, 100> model_output;
33
34 this->model->run(model_input, model_output);
35
36 // select action from model output
// NOTE(review): if run() can leave model_output empty, `max` equals end():
// the debug dump below would dereference it (undefined behaviour) and the
// returned index would be 0 — consider guarding against an empty output.
37 auto max = std::max_element(model_output.begin(),
38 model_output.end()); // first maximal element; end() if output is empty
39 int argmaxVal = std::distance(model_output.begin(), max);
40
// Debug-only dump of the raw output vector and the selected (value, index).
42 std::cout << "---------------MODEL OUTPUT VECTOR:----------------\n";
43 for (auto e
44 : model_output) { std::cout << e << " "; } std::cout
45 << "\nmax value and index are " << *max << " " << argmaxVal << "\n";);
// argmaxVal is a non-negative distance, so the implicit int -> unsigned
// conversion on return preserves its value.
46 return argmaxVal;
47}
48
49} // namespace MLBridge
This file defines the debug macros for the MLCompilerBridge.
#define MLBRIDGE_DEBUG(X)
Definition Debug.h:25
Agent class to support ML/RL model inference via ONNX.
ONNXModel * model
Definition agent.h:31
Agent(std::string model_path)
Definition agent.cpp:23
unsigned computeAction(Observation &obs)
Runs the ONNX model on the input Observation and returns the index of the largest output value (the selected action).
Definition agent.cpp:28
void run(llvm::SmallVector< float, 100 > &input, llvm::SmallVector< float, 100 > &output)
Runs the ONNX model on the input and writes the result into output.
Definition onnx.cpp:51
llvm::SmallVector< float, 300 > Observation
Definition utils.h:15