IR2Vec
Loading...
Searching...
No Matches
VectorSolver.h
1//===- VectorSolver.h - Hand-written Solver flow ---------------*- C++ -*-===//
2//
3// Part of the IR2Vec Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef __VECTOR_SOLVER_H__
10#define __VECTOR_SOLVER_H__
11
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <utility>
#include <vector>
// NOTE(review): a using-directive at header scope pulls all of std into every
// file that includes this header; removing it requires qualifying every std
// name used here and by includers.
using namespace std;

// Dense matrix stored as a vector of rows, each row a vector of doubles;
// elements are addressed as m[row][col].
using matrix = std::vector<std::vector<double>>;
19// Function to swap rows in a matrix
20void swapRows(std::vector<double> &row1, std::vector<double> &row2) {
21 std::swap(row1, row2);
22}
23
// Absolute tolerance: pivot candidates and residuals whose magnitude falls
// below this value are treated as zero.
const double EPS = 1e-9;

/// @brief Solves the linear system(s) held in an augmented matrix by
/// Gauss-Jordan elimination with partial pivoting.
///
/// @param a   Augmented matrix [A | B] with n rows. The first m = cols - k
///            columns are the coefficients A and the last k columns are the
///            right-hand sides B. Taken by value because it is reduced in
///            place. Rows are assumed to all have the same length.
/// @param k   Number of right-hand-side columns; must satisfy 0 <= k <= cols.
/// @param ans Output, resized to m x k. Column j holds a solution of
///            A * x = B[:, j]. Unknowns whose column has no pivot (free
///            variables) are set to 0.
/// @return true if the computed @p ans satisfies every equation of the
///         reduced system within EPS. Returns false if the system is
///         inconsistent, in which case @p ans still holds the free-vars-zero
///         result. Also returns false, with @p ans cleared, if @p k is out of
///         range. An empty @p a clears @p ans and returns true.
inline bool gaussJordan(std::vector<std::vector<double>> a, int k,
                        std::vector<std::vector<double>> &ans) {
  if (a.empty()) {
    // No equations: nothing to solve, and the unknown count is undefined.
    ans.clear();
    return true;
  }
  const int cols = static_cast<int>(a[0].size());
  if (k < 0 || k > cols) {
    ans.clear();
    return false;
  }

  const int n = static_cast<int>(a.size());
  const int m = cols - k;

  // where[col] is the row holding the pivot for column col, or -1 if the
  // column has no usable pivot (free variable).
  std::vector<int> where(m, -1);
  for (int col = 0, row = 0; col < m && row < n; ++col) {
    // Partial pivoting: choose the remaining row with the largest magnitude in
    // this column to limit round-off growth.
    int sel = row;
    for (int i = row; i < n; ++i)
      if (std::fabs(a[i][col]) > std::fabs(a[sel][col]))
        sel = i;
    if (std::fabs(a[sel][col]) < EPS)
      continue; // Column is (numerically) all zero below `row`: leave it free.

    std::swap(a[sel], a[row]); // Swaps row buffers in O(1).
    where[col] = row;

    // Eliminate this column from every other row (above and below), which is
    // what makes this Gauss-Jordan rather than plain Gaussian elimination.
    for (int i = 0; i < n; ++i) {
      if (i == row)
        continue;
      const double c = a[i][col] / a[row][col];
      for (int j = col; j < m + k; ++j)
        a[i][j] -= a[row][j] * c;
    }
    ++row;
  }

  ans.assign(m, std::vector<double>(k, 0.0));
  for (int i = 0; i < m; ++i)
    if (where[i] != -1)
      for (int j = 0; j < k; ++j)
        ans[i][j] = a[where[i]][m + j] / a[where[i]][i];

  // Substitute back into the reduced system. A mismatch means some row reads
  // "0 = nonzero", i.e. the system has no solution.
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < k; ++j) {
      double sum = 0.0;
      for (int l = 0; l < m; ++l)
        sum += ans[l][j] * a[i][l];
      if (std::fabs(sum - a[i][m + j]) > EPS)
        return false;
    }
  }
  return true;
}
67matrix solve(matrix &A, matrix &B) {
68 int m = A.size();
69 int n = B[0].size();
70
71 // Check if dimensions are compatible (m rows in A, same m rows in B)
72 if (m != B.size()) {
73 throw std::invalid_argument(
74 "Matrix dimensions are not compatible for solving AX=B");
75 }
76
77 matrix augmented(m, std::vector<double>(m + n));
78 for (int i = 0; i < m; ++i) {
79 for (int j = 0; j < m; ++j) {
80 augmented[i][j] = A[i][j];
81 }
82 for (int j = 0; j < n; ++j) {
83 augmented[i][m + j] = B[i][j];
84 }
85 }
86 gaussJordan(augmented, B[0].size(), B);
87 matrix X(m, std::vector<double>(n));
88 for (int i = 0; i < m; ++i) {
89 for (int j = 0; j < n; ++j) {
90 X[i][j] = B[i][j];
91 }
92 }
93
94 return X;
95}
96
97#endif