Skip to content

Commit

Permalink
Compute det in inv
Browse files Browse the repository at this point in the history
  • Loading branch information
adamant-pwn committed Oct 28, 2024
1 parent 1604673 commit b925fb8
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 16 deletions.
8 changes: 5 additions & 3 deletions cp-algo/linalg/matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,16 +182,18 @@ namespace cp_algo::linalg {
return res;
}

std::optional<matrix> inv() const {
std::pair<base, matrix> inv() const {
assert(n() == m());
matrix b = *this | eye(n());
if(size(b.echelonize<reverse>(n())[0]) < n()) {
return std::nullopt;
return {0, {}};
}
base det = 1;
for(size_t i = 0; i < n(); i++) {
det *= b[i][i];
b[i] *= base(1) / b[i][i];
}
return b.submatrix(std::slice(0, n(), 1), std::slice(n(), n(), 1));
return {det, b.submatrix(std::slice(0, n(), 1), std::slice(n(), n(), 1))};
}

// Can also just run gauss on T() | eye(m)
Expand Down
5 changes: 2 additions & 3 deletions verify/linalg/adj.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,11 @@ void solve() {
A[i][n] = cp_algo::random::rng();
A[n][i] = cp_algo::random::rng();
}
auto Ai = A.inv();
auto D = A.det();
auto [D, Ai] = A.inv();
for(int i: views::iota(0, n)) {
for(int j: views::iota(0, n)) {
if(D != 0) {
auto res = (*Ai)[n][n] * (*Ai)[i][j] - (*Ai)[i][n] * (*Ai)[n][j];
auto res = Ai[n][n] * Ai[i][j] - Ai[i][n] * Ai[n][j];
cout << res * D << " \n"[j + 1 == n];
} else {
cout << 0 << " \n"[j + 1 == n];
Expand Down
6 changes: 3 additions & 3 deletions verify/linalg/inv.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ void solve() {
cin >> n;
matrix<modint<mod>> a(n, n);
a.read();
auto ai = a.inv();
if(!ai) {
auto [d, ai] = a.inv();
if(d == 0) {
cout << -1 << "\n";
} else {
ai->print();
ai.print();
}
}

Expand Down
14 changes: 7 additions & 7 deletions verify/linalg/tutte.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,17 @@ void solve() {
B[i][j] = T[pivots[i]][pivots[j]];
}
}
B = *B.inv();
auto [d, Bi] = B.inv();
vector<pair<int, int>> ans;
for(size_t i = 0; i < size(pivots); i++) {
for(size_t j = 0; j < size(pivots); j++) {
if(T[pivots[i]][pivots[j]] != 0 && B[i][j] != 0) {
if(T[pivots[i]][pivots[j]] != 0 && Bi[i][j] != 0) {
ans.emplace_back(pivots[i], pivots[j]);
B.eliminate<gauss_mode::reverse>(i, j);
B.eliminate<gauss_mode::reverse>(j, i);
B.normalize();
B[i] *= 0;
B[j] *= 0;
Bi.eliminate<gauss_mode::reverse>(i, j);
Bi.eliminate<gauss_mode::reverse>(j, i);
Bi.normalize();
Bi[i] *= 0;
Bi[j] *= 0;
}
}
}
Expand Down

0 comments on commit b925fb8

Please sign in to comment.