MLCompilerBridge
Tools for streamlining communication with ML models for compiler optimizations.
onnx.cpp
//===- onnx.cpp - ONNX Interface with CPP Runtime --------------*- C++ -*-===//
//
// Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
// Exceptions. See the LICENSE file for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file defines the ONNXModel class, which is a wrapper around the ONNX
/// C++ interface.
///
//===----------------------------------------------------------------------===//

#include "MLModelRunner/ONNXModelRunner/onnx.h" // declares ONNXModel
#include "onnxruntime_cxx_api.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <numeric>

// Create the ONNX Runtime environment and load the model at `model_path`
// into an inference session.
ONNXModel::ONNXModel(const char *model_path) : model_path(model_path) {
  env = new Ort::Env(ORT_LOGGING_LEVEL_WARNING, "test");
  session = new Ort::Session(*env, model_path, Ort::SessionOptions{nullptr});
}

// Builds an Ort::Value tensor for input `inputIdx` that wraps the data in
// `input` without copying it. Dynamic (negative) dimensions in the model's
// input shape are treated as 1.
Ort::Value ONNXModel::getInputValue(llvm::SmallVector<float, 100> &input,
                                    int inputIdx) {
  auto typeInfo = session->GetInputTypeInfo(inputIdx);
  auto tensorInfo = typeInfo.GetTensorTypeAndShapeInfo();
  auto inputDims = tensorInfo.GetShape();
  std::replace_if(
      inputDims.begin(), inputDims.end(), [](int64_t i) { return i < 0; }, 1);

  size_t inputTensorSize =
      std::accumulate(inputDims.begin(), inputDims.end(), size_t{1},
                      std::multiplies<size_t>());
  assert(inputTensorSize == input.size());

  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  auto inputTensor = Ort::Value::CreateTensor<float>(
      memory_info, input.data(), inputTensorSize, inputDims.data(),
      inputDims.size());
  assert(inputTensor.IsTensor());
  return inputTensor;
}

// Runs the ONNX model on the input and returns the output.
void ONNXModel::run(llvm::SmallVector<float, 100> &input,
                    llvm::SmallVector<float, 100> &output) {
  Ort::AllocatorWithDefaultOptions allocator;

  // The real feature vector feeds the first model input; any additional
  // inputs are padded with a single-element dummy vector.
  int input_count = session->GetInputCount();
  llvm::SmallVector<llvm::SmallVector<float, 100>, 10> inputList;
  inputList.push_back(input);
  llvm::SmallVector<float, 100> dummy_input;
  dummy_input.push_back(1.0f);
  for (int i = 1; i < input_count; i++) {
    inputList.push_back(dummy_input);
  }

  // Collect the model's input names; the std::string copies keep the storage
  // behind the c_str() pointers used below alive for the duration of Run().
  llvm::SmallVector<std::string, 10> inputNameList;
  for (int i = 0; i < input_count; i++) {
    auto inputName = session->GetInputNameAllocated(i, allocator);
    inputNameList.push_back(inputName.get());
  }

  std::vector<Ort::Value> input_final;
  llvm::SmallVector<const char *, 10> inputNameStr_final;

  for (int i = 0; i < input_count; i++) {
    input_final.push_back(getInputValue(inputList[i], i));
    inputNameStr_final.push_back(inputNameList[i].c_str());
  }

  // Only the first output of the model is read back.
  auto outputName = session->GetOutputNameAllocated(0, allocator);
  auto outputNameStr = outputName.get();

  auto outputTensors =
      session->Run(Ort::RunOptions{nullptr}, inputNameStr_final.data(),
                   input_final.data(), input_count, &outputNameStr, 1);

  assert(outputTensors.size() == 1 && outputTensors.front().IsTensor());

  // The number of output elements is taken from dimension 1 of the output
  // tensor's shape.
  auto outputSize =
      outputTensors.front().GetTensorTypeAndShapeInfo().GetShape()[1];

  auto outVal = outputTensors.front().GetTensorMutableData<float>();

  output = llvm::SmallVector<float, 100>(outVal, outVal + outputSize);
  // Replace any NaNs in the output with a large negative sentinel value.
  std::replace_if(
      output.begin(), output.end(), [](float x) { return std::isnan(x); },
      -1.17549e+038);
}
This file defines the ONNXModel class, which is a wrapper around the ONNX C++ interface.
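For reference, here is a minimal sketch of how the class above might be driven. It is not part of onnx.cpp: the header path, the model file name, and the input size are assumptions for illustration, and the input vector must match the size of the model's first input tensor.

#include "MLModelRunner/ONNXModelRunner/onnx.h" // assumed location of the ONNXModel declaration
#include "llvm/ADT/SmallVector.h"

#include <iostream>

int main() {
  // Hypothetical model path; the model's first input is assumed to take 10 floats.
  ONNXModel model("path/to/model.onnx");

  llvm::SmallVector<float, 100> input(10, 0.0f);
  llvm::SmallVector<float, 100> output;

  // run() evaluates the model and fills `output` from its first output tensor.
  model.run(input, output);

  for (float v : output)
    std::cout << v << " ";
  std::cout << "\n";
  return 0;
}

Because run() pads any inputs beyond the first with a one-element dummy vector, only the first model input carries real features in this sketch.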