MLCompilerBridge
Tools for streamlining communication with ML models for compiler optimizations.
Loading...
Searching...
No Matches
jsonSerDes.cpp
Go to the documentation of this file.
1//===- jsonSerDes.cpp - Serializer for JSON ---------------------*- C++ -*-===//
2//
3// Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
15#include "SerDes/jsonSerDes.h"
18#include "SerDes/baseSerDes.h"
19#include "llvm/Support/Debug.h"
20#include "llvm/Support/JSON.h"
21#include <cstdint>
22#include <string>
23
24#define DEBUG_TYPE "json-serdes"
25
26using namespace llvm;
27
28namespace MLBridge {
30 auto tempJO = J;
31 auto data = json::Value(std::move(tempJO));
32 auto *ret = new std::string();
33 llvm::raw_string_ostream OS(*ret);
34 json::OStream(OS).value(data);
36 return ret;
37}
38
40 MLBRIDGE_DEBUG(std::cout << "In JsonSerDes deserializeUntyped...\n");
41 auto dataStr = static_cast<std::string *>(data);
42 MLBRIDGE_DEBUG(std::cout << "dataStr: " << *dataStr << "\n");
43 Expected<json::Value> valueOrErr = json::parse(*dataStr);
44 if (!valueOrErr) {
45 auto *ret = new std::string();
46 llvm::raw_string_ostream SOS(*ret);
47 SOS << "Error parsing JSON: " << valueOrErr.takeError() << "\n";
48 std::cerr << &ret << "\n";
49 exit(1);
50 }
51 json::Object *ret = valueOrErr->getAsObject();
52 auto val = ret->get("out");
53 MLBRIDGE_DEBUG(std::cout << "Got the final array...\n";
54 std::cout << "End JsonSerDes deserializeUntyped...\n");
55 return desJson(val);
56}
57
58void *JsonSerDes::desJson(json::Value *V) {
59 switch (V->kind()) {
60 case json::Value::Kind::Null:
61 return nullptr;
62 case json::Value::Kind::Number: {
63 if (auto x = V->getAsInteger()) {
64 IntegerType *ret = new IntegerType();
65 *ret = x.getValue();
66 this->MessageLength = sizeof(IntegerType);
67 return ret;
68 } else if (auto x = V->getAsNumber()) {
69 RealType *ret = new RealType();
70 *ret = x.getValue();
71 this->MessageLength = sizeof(RealType);
72 return ret;
73 } else {
74 std::cerr << "Error in desJson: Number is not int, or double\n";
75 exit(1);
76 }
77 }
78 case json::Value::Kind::String: {
79 std::string *ret = new std::string();
80 *ret = V->getAsString()->str();
81 this->MessageLength = ret->size() * sizeof(char);
82 return ret->data();
83 }
84 case json::Value::Kind::Boolean: {
85 bool *ret = new bool();
86 *ret = V->getAsBoolean().getValue();
87 this->MessageLength = sizeof(bool);
88 return ret;
89 }
90 case json::Value::Kind::Array: {
91 // iterate over array and find its type
92 // assume all elements are of same type and return vector of that type
93 // if not, return nullptr
94 auto arr = V->getAsArray();
95
96 auto it = arr->begin();
97 auto first = it;
98 switch (first->kind()) {
99 case json::Value::Kind::Number: {
100 if (auto x = first->getAsInteger()) {
101 std::vector<IntegerType> *ret = new std::vector<IntegerType>();
102 for (auto it : *arr) {
103 ret->push_back(it.getAsInteger().getValue());
104 }
105 this->MessageLength = ret->size() * sizeof(IntegerType);
106 return ret->data();
107 } else if (auto x = first->getAsNumber()) {
108 std::vector<RealType> *ret = new std::vector<RealType>();
109 for (auto it : *arr) {
110 ret->push_back(it.getAsNumber().getValue());
111 }
112 this->MessageLength = ret->size() * sizeof(RealType);
113 return ret->data();
114 } else {
115 std::cerr << "Error in desJson: Number is not int, or double\n";
116 exit(1);
117 }
118 }
119 case json::Value::Kind::String: {
120 std::vector<std::string> *ret = new std::vector<std::string>();
121 for (auto it : *arr) {
122 ret->push_back(it.getAsString()->str());
123 }
124 this->MessageLength = ret->size() * sizeof(std::string);
125 return ret->data();
126 }
127 case json::Value::Kind::Boolean: {
128 std::vector<uint8_t> *ret = new std::vector<uint8_t>();
129 for (auto it : *arr) {
130 ret->push_back(it.getAsBoolean().getValue());
131 }
132 this->MessageLength = ret->size() * sizeof(uint8_t);
133 return ret->data();
134 }
135 default: {
136 std::cerr << "Error in desJson: Array is not of supported type\n";
137 exit(1);
138 }
139 }
140 }
141 }
142 return nullptr;
143}
144} // namespace MLBridge
This file defines the bit widths of integral and floating point types supported by the MLCompilerBridge.
This file defines the debug macros for the MLCompilerBridge.
#define MLBRIDGE_DEBUG(X)
Definition Debug.h:25
Supporting new SerDes:
void * desJson(llvm::json::Value *V)
void * deserializeUntyped(void *data) override
void * getSerializedData() override
llvm::json::Object J
Definition jsonSerDes.h:52
void cleanDataStructures() override
Definition jsonSerDes.h:45
Json Serialization/Deserialization using LLVM's json library.
float RealType
Definition DataTypes.h:28
int IntegerType
Definition DataTypes.h:29