IR2Vec
Loading...
Searching...
No Matches
VectorSolverEigen.h
1//===- VectorSolverEigen.h - Solver flow using Eigen -----------*- C++ -*-===//
2//
3// Part of the IR2Vec Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef __VECTOR_SOLVER_H__
10#define __VECTOR_SOLVER_H__
11#define EIGEN_MPL2_ONLY
12
#include "Eigen/LU"
#include "Eigen/QR"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include <cassert>
#include <vector>
17
18using namespace Eigen;
19using namespace llvm;
20
21typedef std::vector<std::vector<double>> matrix;
22
23MatrixXd calculate(MatrixXd A, MatrixXd B) {
24 if (A.determinant() != 0) {
25 return A.fullPivHouseholderQr().solve(B);
26 } else {
27 // To-Do: perturb probabilities
28 llvm_unreachable("inconsistent/infinitely many solutions");
29 }
30}
31
32MatrixXd formMatrix(std::vector<std::vector<double>> a, int r, int l) {
33 MatrixXd M(r, l);
34 for (int i = 0; i < r; i++)
35 M.row(i) = VectorXd::Map(&a[i][0], a[i].size());
36
37 return M;
38}
39
40matrix solve(matrix A, matrix B) {
41 int r = A.size();
42 int c = A[0].size();
43 MatrixXd mA(r, c);
44 mA = formMatrix(A, r, c);
45
46 r = B.size();
47 c = B[0].size();
48 MatrixXd mB(r, c);
49 mB = formMatrix(B, r, c);
50
51 r = A.size();
52 MatrixXd x(r, c);
53 x = calculate(mA, mB);
54 std::vector<std::vector<double>> raw_data;
55 // raw_data.resize(x.rows());
56 for (unsigned i = 0; i < x.rows(); i++) {
57 std::vector<double> tmp;
58 tmp.resize(x.cols());
59 VectorXd::Map(&tmp[0], x.cols()) = x.row(i);
60 raw_data.push_back(tmp);
61 }
62 return raw_data;
63}
64
65#endif