MLCompilerBridge
Tools for streamlining communication with ML models for compiler optimizations.
Loading...
Searching...
No Matches
SerDes.py
Go to the documentation of this file.
1# ------------------------------------------------------------------------------
2#
3# Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4# Exceptions. See the LICENSE file for license information.
5# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6#
7# ------------------------------------------------------------------------------
8
13
14import json
15from . import log_reader
16import ctypes
17import struct
18
19
class NpEncoder(json.JSONEncoder):
    """JSON encoder that additionally accepts ctypes scalar wrappers.

    Instances of ``ctypes.c_long`` and ``ctypes.c_double`` are encoded as their
    underlying ``.value``; any other type unknown to ``json`` is delegated to
    the base encoder, which raises ``TypeError``.
    """

    def default(self, obj):
        # Unwrap the two supported ctypes scalars to plain Python numbers.
        if isinstance(obj, (ctypes.c_long, ctypes.c_double)):
            return obj.value
        return super().default(obj)
27
28
29
30class SerDes:
31
33 def __init__(self, data_format):
34 self.buffer = None
35 self.data_format = data_format
36 self.read_stream_iter = None
37
38 self.serMap = {
42 }
43 self.desMap = {
47 }
48
49
52 def deserializeData(self, rawdata):
53 return self.desMap[self.data_format](rawdata)
54
55
def deserializeJson(self, datastream):
    """Read one length-prefixed JSON message from ``datastream`` and decode it.

    The wire format is an 8-byte little-endian unsigned length followed by
    that many bytes of JSON text (the same framing ``serializeJson`` writes).

    Args:
        datastream: binary file-like object exposing ``read(n)``.

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: if the payload is not valid JSON, including a
            payload cut short by end-of-stream.
    """

    def _read(count):
        # A single read() on an unbuffered stream (raw pipe/socket) may return
        # fewer bytes than requested, so keep reading until all ``count``
        # bytes have arrived or the stream signals EOF with an empty read.
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = datastream.read(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    size = int.from_bytes(_read(8), "little")
    return json.loads(_read(size))
61
62
63 def deserializeBytes(self, datastream):
64 if self.read_stream_iter is None:
65 self.read_stream_iter = log_reader.read_stream2(
66 datastream
67 ) # try to make it indep
68 hdr = datastream.read(8)
69 context, observation_id, features, score = next(self.read_stream_iter)
70 return features
71
def deserializeProtobuf(self, datastream):
    """Protobuf deserialization is not implemented.

    Raises:
        NotImplementedError: on every call.
    """
    raise NotImplementedError
75
76
77 def serializeData(self, data):
78 self.serMap[self.data_format](data)
79
80
def serializeJson(self, data):
    """Encode ``{"out": data}`` as JSON and stage it in ``self.buffer``.

    The staged message is an 8-byte little-endian length prefix followed by
    the UTF-8 JSON text. ``NpEncoder`` is used so ctypes ``c_long`` /
    ``c_double`` values are accepted.
    """
    payload = json.dumps({"out": data}, cls=NpEncoder).encode("utf-8")
    prefix = len(payload).to_bytes(8, "little")
    self.buffer = b"".join((prefix, payload))
85
86
87 def serializeBytes(self, data):
88 def _pack(data):
89 if isinstance(data, int):
90 return struct.pack("i", data)
91 elif isinstance(data, float):
92 return struct.pack("f", data)
93 elif isinstance(data, str) and len(data) == 1:
94 return struct.pack("c", data)
95 elif isinstance(data, ctypes.c_double):
96 return struct.pack("d", data.value)
97 elif isinstance(data, ctypes.c_long):
98 return struct.pack("l", data.value)
99 elif isinstance(data, list):
100 return b"".join([_pack(x) for x in data])
101
102 msg = _pack(data)
103 hdr = len(msg).to_bytes(8, "little")
104 self.buffer = hdr + msg
105
106 # Implemented outside for now
107 def serializeProtobuf(self, data):
108 self.buffer = data
109
110
113 out = self.buffer
114 self.buffer = None
115 return out
Class for serialization and deserialization in various formats for communication.
Definition SerDes.py:30
getOutputBuffer(self)
Returns value in buffer and empties it.
Definition SerDes.py:112
deserializeBytes(self, datastream)
Deserializes and returns bitstream data.
Definition SerDes.py:63
deserializeJson(self, datastream)
Deserializes and returns JSON data.
Definition SerDes.py:56
deserializeProtobuf(self, datastream)
Definition SerDes.py:73
deserializeData(self, rawdata)
Deserializes data for specified data format.
Definition SerDes.py:52
serializeBytes(self, data)
Serializes data to bitstream.
Definition SerDes.py:87
serializeData(self, data)
Serializes data and places it in a buffer.
Definition SerDes.py:77
__init__(self, data_format)
Constructor for SerDes object.
Definition SerDes.py:33
serializeJson(self, data)
Serializes data to JSON.
Definition SerDes.py:81