MLCompilerBridge
Tools for streamlining communication with ML models for compiler optimizations.
Loading...
Searching...
No Matches
PipeModelRunner.cpp
Go to the documentation of this file.
1//===- PipeModelRunner.cpp - Pipe based Model Runner ------------*- C++ -*-===//
2//
3// Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
// (Preliminary version adapted from InteractiveModelRunner.cpp of LLVM 17.X)
8//
9//===----------------------------------------------------------------------===//
17//===----------------------------------------------------------------------===//
18
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <thread>
27
28#define DEBUG_TYPE "pipe-model-runner"
29
30using namespace llvm;
31
32namespace MLBridge {
33PipeModelRunner::PipeModelRunner(StringRef OutboundName, StringRef InboundName,
34 SerDesKind SerDesType, LLVMContext *Ctx)
35 : MLModelRunner(Kind::Pipe, SerDesType, Ctx),
36 InEC(sys::fs::openFileForRead(InboundName, Inbound)) {
37 this->InboundName = InboundName.str();
38 if (InEC) {
39 int max_retries = 30, attempts = 0;
40 double wait_seconds = 0.2, backoff_exp = 1.2;
41
42 while (attempts < max_retries) {
43 InEC = sys::fs::openFileForRead(InboundName, Inbound);
44 if (InEC) {
45 attempts++;
46 std::cout << "Cannot open inbound file retrying! attempt: " << attempts
47 << std::endl;
48 std::this_thread::sleep_for(
49 std::chrono::duration<double>(wait_seconds));
50 wait_seconds *= backoff_exp;
51 } else {
52 break;
53 }
54 }
55 if (InEC) {
56 auto message = "Cannot open inbound file: " + InEC.message();
57 if (this->Ctx)
58 this->Ctx->emitError(message);
59 else
60 std::cerr << message << std::endl;
61 return;
62 }
63 }
64 {
65 OutStream = std::make_unique<raw_fd_ostream>(OutboundName, OutEC);
66 if (OutEC) {
67 auto message = "Cannot open outbound file: " + OutEC.message();
68 if (this->Ctx)
69 this->Ctx->emitError(message);
70 else
71 std::cerr << message << std::endl;
72 return;
73 }
74 }
75}
76
78 // close the file descriptors
79 sys::fs::file_t FDAsOSHandle = sys::fs::convertFDToNativeFile(Inbound);
80 sys::fs::closeFile(FDAsOSHandle);
81
82 OutStream->close();
83}
84
85std::string PipeModelRunner::readNBytes(size_t N) {
86 std::string OutputBuffer(N, '\0');
87 char *Buff = OutputBuffer.data();
88 size_t InsPoint = 0;
89 const size_t Limit = N;
90 while (InsPoint < Limit) {
91 auto ReadOrErr = ::sys::fs::readNativeFile(
92 sys::fs::convertFDToNativeFile(Inbound),
93 {Buff + InsPoint, OutputBuffer.size() - InsPoint});
94 if (ReadOrErr.takeError()) {
95 if (this->Ctx)
96 this->Ctx->emitError("Failed reading from inbound file");
97 else
98 std::cerr << "Failed reading from inbound file" << std::endl;
99 break;
100 }
101 InsPoint += *ReadOrErr;
102 }
103 return OutputBuffer;
104}
105
106void PipeModelRunner::send(void *data) {
107 // TODO: send data size first (a hack to send protbuf data properly)
108 auto dataString = reinterpret_cast<std::string *>(data);
109 size_t message_length = dataString->size();
110 const char *message_length_ptr =
111 reinterpret_cast<const char *>(&message_length);
112 MLBRIDGE_DEBUG(std::cout << "Message length: " << message_length << "\n");
113 MLBRIDGE_DEBUG(std::cout << "DataString.size(): " << dataString->size()
114 << "\n");
115 OutStream->write(message_length_ptr, sizeof(size_t));
116 OutStream->write(dataString->data(), dataString->size());
117 OutStream->flush();
118}
119
121 MLBRIDGE_DEBUG(std::cout << "In PipeModelRunner receive...\n");
122 auto hdr = readNBytes(8);
123 MLBRIDGE_DEBUG(std::cout << "Read header...\n");
124 size_t MessageLength = 0;
125 memcpy(&MessageLength, hdr.data(), sizeof(MessageLength));
126 // Read message
127 auto OutputBuffer = new std::string(readNBytes(MessageLength));
128 MLBRIDGE_DEBUG(std::cout << "OutputBuffer size: " << OutputBuffer->size()
129 << "\n";
130 std::cout << "OutputBuffer: " << *OutputBuffer << "\n");
131 return OutputBuffer;
132}
133
  // One round trip with the external model: serialize the current inputs via
  // SerDes, write them to the outbound pipe with send(), read the reply from
  // the inbound pipe with receive(), and return the SerDes-deserialized reply.
  MLBRIDGE_DEBUG(std::cout << "In PipeModelRunner evaluateUntyped...\n");
  // presumably getSerializedData() returns a std::string* (send() casts it
  // to one) — confirm against BaseSerDes.
  auto *data = SerDes->getSerializedData();
  send(data);
  auto *reply = receive();
  std::cout << "In PipeModelRunner::evaluateUntyped() received data...\n");
  // NOTE(review): `reply` is a std::string allocated with new in receive();
  // confirm deserializeUntyped() takes ownership and frees it.
  return SerDes->deserializeUntyped(reply);
}
143
144} // namespace MLBridge
This file defines the debug macros for the MLCompilerBridge.
#define MLBRIDGE_DEBUG(X)
Definition Debug.h:25
The MLModelRunner class is the main interface for interacting with the ML models.
PipeModelRunner class supporting communication via OS pipes between the compiler and an external ML a...
MLModelRunner - The main interface for interacting with the ML models.
std::unique_ptr< BaseSerDes > SerDes
Kind
Type of the MLModelRunner.
llvm::LLVMContext * Ctx
std::string readNBytes(size_t N)
std::unique_ptr< llvm::raw_fd_ostream > OutStream
PipeModelRunner(llvm::StringRef OutboundName, llvm::StringRef InboundName, SerDesKind Kind, llvm::LLVMContext *Ctx=nullptr)
void * evaluateUntyped() override
Should be implemented by the derived class to call the model and get the result.
SerDesKind
This is the base class for SerDes.
Definition baseSerDes.h:46