MLCompilerBridge
Tools for streamlining communication with ML models for compiler optimizations.
Loading...
Searching...
No Matches
gRPCModelRunner.h
Go to the documentation of this file.
1//=== MLModelRunner/gRPCModelRunner.h - gRPCModelRunner class definition -*- C++ -*-===//
2//
3// Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===---------------------------------------------------------------------===//
61// ===----------------------------------------------------------------------===//
62
63#ifndef GRPC_MODELRUNNER_H
64#define GRPC_MODELRUNNER_H
65
67
68#include <future>
69#include <google/protobuf/text_format.h>
70#include <grpcpp/grpcpp.h>
71#include <grpcpp/health_check_service_interface.h>
72#include <memory>
73#include <thread>
74
75namespace MLBridge {
79template <class Client, class Stub, class Request, class Response>
81public:
// Server-mode constructor. NOTE(review): listing line 85 (presumably the
// start of the member-initializer list) is missing from this page.
// Leaves `request` and `response` null, sets `server_mode`, and calls
// RunService(s), which blocks while the server runs, so construction does
// not complete until RunService() returns.
83 gRPCModelRunner(std::string server_address, grpc::Service *s,
84 llvm::LLVMContext *Ctx = nullptr)
86 server_address(server_address), request(nullptr), response(nullptr),
87 server_mode(true) {
88 RunService(s);
89 }
90
// Client-mode constructor. NOTE(review): listing lines 94-95 are missing
// from this page; presumably they initialize the base class, the address
// and the request/response pointers. Confirm against the header.
// Sets `server_mode` to false and creates the gRPC stub via SetStub().
92 gRPCModelRunner(std::string server_address, Request *request,
93 Response *response, llvm::LLVMContext *Ctx = nullptr)
96 server_mode(false) {
97 SetStub();
98 }
99
100 void requestExit() override {
101 std::string input;
102 std::cin >> input;
103 if (input == "Terminate") {
104 this->exit_requested->set_value();
105 } else {
106 std::cout << "Problem while closing server\n";
107 }
108 }
109
110private:
/// Checks whether the port named in `addr` is free on this machine.
///
/// The port is the text after the last ':' in `addr` (e.g. "localhost:50051"
/// or "[::1]:50051"). Availability is probed by running `lsof -i :<port>`;
/// empty output means nothing is using the port. While the port is busy,
/// the probe is repeated up to 30 times. The first sleep is 0.2 s, and each
/// later sleep is 1.2x longer than the previous one.
///
/// @param addr server address of the form `host:port`.
/// @return true as soon as a probe reports the port free. Returns false if
///         `addr` has no valid numeric port, if the probe command cannot be
///         started, or if the port is still busy after the last attempt.
bool isPortAvailable(std::string addr) {
  const int max_retries = 30;
  const double backoff_exp = 1.2;
  double wait_seconds = 0.2;

  // Validate the port up front: std::stoi would throw on an address without
  // a numeric port, and splitting at the first ':' breaks IPv6 literals.
  const std::string::size_type colon = addr.rfind(':');
  const std::string port_text =
      colon == std::string::npos ? std::string() : addr.substr(colon + 1);
  bool valid_port = !port_text.empty() && port_text.size() <= 5;
  for (char c : port_text)
    valid_port = valid_port && c >= '0' && c <= '9';
  const int port = valid_port ? std::stoi(port_text) : -1;
  if (port < 0 || port > 65535) {
    std::cerr << "Invalid server address (expected host:port): " << addr
              << std::endl;
    return false;
  }

  // The probe command does not change between attempts.
  const std::string command = "lsof -i :" + std::to_string(port);
  for (int attempts = 0; attempts < max_retries;) {
    FILE *pipe = popen(command.c_str(), "r");
    if (!pipe) {
      std::cerr << "Error executing command: " << std::strerror(errno)
                << std::endl;
      return false;
    }

    // Loop on fgets() itself. A `while (!feof(pipe))` loop never terminates
    // after a read error, because the end-of-file indicator is never set.
    std::string result;
    char buffer[256];
    while (fgets(buffer, sizeof(buffer), pipe) != nullptr)
      result += buffer;
    pclose(pipe);

    if (result.empty())
      return true;

    ++attempts;
    std::cout << "Port is unavailable retrying! attempt: " << attempts
              << std::endl;
    std::this_thread::sleep_for(std::chrono::duration<double>(wait_seconds));
    wait_seconds *= backoff_exp;
  }

  std::cout << "Port is unavailable now!" << std::endl;
  return false;
}
149
// Points at the promise RunService() waits on; requestExit() fulfils it to
// trigger server shutdown. NOTE(review): no initialization of this pointer
// is visible for client mode; confirm before calling requestExit() there.
150 std::promise<void> *exit_requested;
151
154 void *evaluateUntyped() override {
155 assert(!server_mode &&
156 "evaluateUntyped not implemented for gRPCModelRunner; "
157 "Override gRPC method instead");
158 assert(request != nullptr && "Request cannot be null");
159
160 int max_retries = 30, attempts = 0;
161 double retries_wait_secs = 0.2;
162 int deadline_time = 10000;
163 int deadline_max_retries = 30, deadline_attpts = 0;
164 double retry_wait_backoff_exponent = 1.5;
165
166 // setting a deadline
167 auto deadline = std::chrono::system_clock::now() +
168 std::chrono::milliseconds(deadline_time);
169
170 while (attempts < max_retries && deadline_attpts < deadline_max_retries) {
171 grpc::ClientContext grpcCtx;
173 grpc::Status status;
174 grpcCtx.set_deadline(deadline);
175
176 status = stub_->getAdvice(&grpcCtx, *request, response);
177
178 if (status.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
179 deadline_attpts++;
180 int ext_deadline = 2 * deadline_time;
181 deadline_time = ext_deadline;
182 std::cout << "Deadline Exceeded for Request! sending the message again "
183 "with extended deadline : "
184 << deadline_time << "\n";
185 deadline = std::chrono::system_clock::now() +
186 std::chrono::milliseconds(deadline_time);
187 } else if (status.error_code() == grpc::StatusCode::UNAVAILABLE) {
188 attempts++;
189 std::cout << "Server is unavailable retrying! attempt: " << attempts
190 << "\n";
191 std::this_thread::sleep_for(
192 std::chrono::duration<double>(retries_wait_secs));
193 retries_wait_secs *= retry_wait_backoff_exponent;
194 } else {
195 request->Clear();
196 if (!status.ok()) {
197 if (Ctx)
198 Ctx->emitError("gRPC failed: " + status.error_message());
199 else
200 std::cerr << "gRPC failed: " << status.error_message() << std::endl;
201 }
202 // auto *action = new int(); // Hard wired for PosetRL case, should be
203 // fixed *action = response->action(); return action;
204 return SerDes->deserializeUntyped(response);
205 }
206 }
207
208 std::cout << "Server is unavailable now!!!\n";
209 return new int(-1);
210 }
211
// gRPC stub used by evaluateUntyped(); set by SetStub() from the pointer
// released out of the unique_ptr returned by Client::NewStub().
// NOTE(review): nothing in the visible code deletes it.
212 Stub *stub_;
// Address passed to AddListeningPort() (server mode) or CreateChannel()
// (client mode).
213 std::string server_address;
// Messages sent and received by evaluateUntyped(); null in server mode.
214 Request *request;
215 Response *response;
217
220 int RunService(grpc::Service *s) {
221 exit_requested = new std::promise<void>();
222 grpc::ServerBuilder builder;
223 // if (!this->isPortAvailable(server_address)) return -1;
224 builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
225 builder.RegisterService(s);
226 std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
227 std::cout << "Server Listening on " << server_address << std::endl;
228 auto serveFn = [&]() { server->Wait(); };
229 std::thread serving_thread(serveFn);
230 auto f = exit_requested->get_future();
231 this->requestExit();
232 f.wait();
233 server->Shutdown();
234 serving_thread.join();
235 std::cout << "Server Shutdowns Successfully" << std::endl;
236 return 0;
237 }
238
240 int SetStub() {
241 std::shared_ptr<grpc::Channel> channel =
242 grpc::CreateChannel(server_address, grpc::InsecureChannelCredentials());
243 auto Stub_temp = Client::NewStub(channel);
244 stub_ = Stub_temp.release();
245 return 0;
246 }
247
// Returns the SerDes-held request object cast to Request *.
// NOTE(review): C-style cast, so there is no type check; presumably SerDes
// is configured for this Request type. Confirm.
248 Request *getRequest() { return (Request *)SerDes->getRequest(); }
249
// Returns the SerDes-held response object cast to Response *.
// NOTE(review): C-style cast, so there is no type check; presumably SerDes
// is configured for this Response type. Confirm.
250 Response *getResponse() { return (Response *)SerDes->getResponse(); }
251
252 void printMessage(const google::protobuf::Message *message) {
253 std::string s;
254 if (google::protobuf::TextFormat::PrintToString(*message, &s)) {
255 std::cout << "Your message: " << s << std::endl;
256 } else {
257 std::cerr << "Message not valid (partial content: "
258 << request->ShortDebugString() << ")\n";
259 }
260 }
261};
262} // namespace MLBridge
263
264#endif // GRPC_MODELRUNNER_H
The MLModelRunner class is the main interface for interacting with the ML models.
MLModelRunner - The main interface for interacting with the ML models.
std::unique_ptr< BaseSerDes > SerDes
Kind
Type of the MLModelRunner.
llvm::LLVMContext * Ctx
This class is used to create the gRPC model runner object.
int RunService(grpc::Service *s)
This method is used to create the server and start listening.
int SetStub()
This method is used to create the stub. Used in client mode.
gRPCModelRunner(std::string server_address, grpc::Service *s, llvm::LLVMContext *Ctx=nullptr)
For server mode.
gRPCModelRunner(std::string server_address, Request *request, Response *response, llvm::LLVMContext *Ctx=nullptr)
For client mode.
std::promise< void > * exit_requested
void printMessage(const google::protobuf::Message *message)
void * evaluateUntyped() override
This method is used to send the request to the model and get the result.
bool isPortAvailable(std::string addr)
Checks whether the port in the given address is available (not in use).
SerDesKind
This is the base class for SerDes.
Definition baseSerDes.h:46