-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added real_symmetric_eigh example (+ Alex's eigh)
- Loading branch information
Showing
4 changed files
with
301 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import sys | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
from icecream import ic | ||
|
||
from tessellate_ipu.linalg import ipu_hessenberg, ipu_tridiag_solve | ||
from tessellate_ipu.linalg.tile_linalg_tridiagonal_eigh import ipu_eigh_tridiagonal | ||
|
||
jax.config.FLAGS.jax_platform_name = "cpu" | ||
jax.config.update("jax_enable_x64", False) | ||
|
||
if len(sys.argv) != 3: | ||
print(sys.argv[0] + " <size> <num eigenvectors>") | ||
sys.exit(1) | ||
|
||
seed = 42 | ||
np.random.seed(seed) | ||
|
||
np.set_printoptions(precision=3, linewidth=120, suppress=True) | ||
|
||
|
||
S = int(sys.argv[1]) | ||
N = int(sys.argv[2]) | ||
|
||
mat = np.random.rand(S, S).astype(np.float32) | ||
mat = (mat + mat.T) / 2 | ||
|
||
|
||
def real_symmetric_eigh(M): | ||
|
||
Q, M_tri_ = ipu_hessenberg(M) | ||
M_tri = M_tri_.array | ||
|
||
d, e = jnp.diag(M_tri), jnp.diag(M_tri, k=1) | ||
eig = ipu_eigh_tridiagonal(d, e)[:N] | ||
|
||
diag = jnp.tile(jnp.diag(M_tri, k=0).reshape(1, -1), (N, 1)) - eig[:, jnp.newaxis] | ||
|
||
udiag = jnp.concatenate([jnp.diag(M_tri, k=1), jnp.array([0], dtype=jnp.float32)]).reshape(1, -1) | ||
udiag = jnp.tile(udiag, (N, 1)) | ||
|
||
ldiag = jnp.concatenate([jnp.array([0], dtype=jnp.float32), jnp.diag(M_tri, k=-1)]).reshape(1, -1) | ||
ldiag = jnp.tile(ldiag, (N, 1)) | ||
|
||
prng_key = jax.random.PRNGKey(seed) | ||
x = jax.random.normal(prng_key, shape=(N, diag.shape[1]), dtype=jnp.float32) | ||
x /= jnp.linalg.norm(x, axis=1)[:, jnp.newaxis] | ||
|
||
def inverse_iteration(i, x): | ||
x = ipu_tridiag_solve(diag, udiag, ldiag, x) | ||
x /= jnp.linalg.norm(x, axis=1)[:, jnp.newaxis] | ||
return x | ||
|
||
x = jax.lax.fori_loop(0, 2, inverse_iteration, x) | ||
|
||
return x @ Q.array.T, eig | ||
|
||
|
||
x, eig = jax.jit(real_symmetric_eigh, backend="ipu")(mat) | ||
|
||
# ic(x) | ||
|
||
# ic(mat @ eigv - eig * eigv) | ||
# ic(mat @ x.T - eig * x.T) | ||
|
||
ic(np.max(np.abs(mat @ x.T - eig * x.T))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
#include <poplar/Vertex.hpp> | ||
#include <algorithm> | ||
#include <cassert> | ||
#include <cmath> | ||
#include <poplar/HalfFloat.hpp> | ||
#include <poplar/Vertex.hpp> | ||
#include "poplar/TileConstants.hpp" | ||
#include <print.h> | ||
|
||
using namespace poplar; | ||
|
||
#ifdef __IPU__ | ||
// Use the IPU intrinsics | ||
#include <ipu_memory_intrinsics> | ||
#include <ipu_vector_math> | ||
#define NAMESPACE ipu | ||
#else | ||
// Use the std functions | ||
#include <cmath> | ||
#define NAMESPACE std | ||
#endif | ||
|
||
class Sturm : public Vertex { | ||
public: | ||
Input<Vector<float>> alpha; | ||
Input<Vector<float>> beta_sq; | ||
Input<Vector<float>> pivmin; | ||
Input<Vector<float>> alpha0_pertubation; | ||
Input<Vector<float>> x; | ||
Input<Vector<int>> id; | ||
Input<Vector<float>> out_shape; | ||
Input<Vector<float>> lower; | ||
Input<Vector<float>> mid; | ||
Input<Vector<float>> upper; | ||
|
||
Output<Vector<float>> lower_out; | ||
Output<Vector<float>> mid_out; | ||
Output<Vector<float>> upper_out; | ||
|
||
bool compute() { | ||
int tile_id = id.data()[0]; | ||
|
||
//q = alpha[0] - x | ||
//count = jnp.where(q < 0, ones, zeros) | ||
//q = jnp.where(alpha[0] == x, alpha0_perturbation, q) | ||
float q = alpha[0] - x.data()[tile_id]; | ||
int count = q < 0; | ||
if (alpha[0] == x.data()[tile_id]) {q = alpha0_pertubation.data()[tile_id]; } | ||
|
||
//for i in range(1, n): | ||
// q = alpha[i] - beta_sq[i - 1] / q - x | ||
// count = jnp.where(q <= pivmin, count + 1, count) | ||
// q = jnp.where(q <= pivmin, jnp.minimum(q, -pivmin), q)*/ | ||
|
||
int n = x.size(); | ||
float x_tile_id = x.data()[tile_id]; | ||
float pivmin_tile_id = pivmin.data()[tile_id]; | ||
float minus_pivmin_tile_id = -pivmin_tile_id; | ||
|
||
// main bulk: takes ~ 87k cycles for 1024 matrix => ~ 80 cycles per iteration. | ||
// obs: we can precompute (alpha.data()[i]-x_tile_id) using all 6 threads in parallel. | ||
for (unsigned int i = 1; i < n; i++){ | ||
//q = alpha[i] - beta_sq[i - 1] / q - x | ||
//q = alpha.data()[i] - x.data()[tile_id] - beta_sq.data()[i - 1] / q ; | ||
q = alpha.data()[i] - x_tile_id - beta_sq.data()[i - 1] / q ; | ||
|
||
//count = jnp.where(q <= pivmin, count + 1, count) | ||
//q = jnp.where(q <= pivmin, jnp.minimum(q, -pivmin), q) | ||
if (q <= pivmin_tile_id){ | ||
count ++; | ||
q = fmin(q, minus_pivmin_tile_id); | ||
} | ||
} | ||
|
||
//lower = jnp.where(counts <= target_counts, mid, lower) | ||
//upper = jnp.where(counts > target_counts, mid, upper) | ||
//mid = 0.5 * (lower + upper) | ||
int target_count = tile_id; // they are the same | ||
if (count <= target_count) lower_out[0] = mid.data()[0]; | ||
else {lower_out[0] = lower.data()[0]; } | ||
|
||
if (count > target_count) upper_out[0] = mid.data()[0]; | ||
else {upper_out[0] = upper.data()[0];} | ||
|
||
mid_out[0] = (lower_out.data()[0] + upper_out.data()[0])/2; | ||
|
||
return true; | ||
} | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
import os.path as osp | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
import scipy.linalg | ||
|
||
from tessellate_ipu import create_ipu_tile_primitive, tile_map, tile_put_replicated, tile_put_sharded | ||
|
||
jax.config.FLAGS.jax_platform_name = "cpu" | ||
|
||
|
||
def ipu_eigh_tridiagonal(d, e, *, select="a", select_range=None, tol=None): | ||
alpha, beta = jnp.asarray(d), jnp.asarray(e) | ||
n = alpha.shape[0] | ||
if n <= 1: | ||
return jnp.real(alpha) | ||
|
||
beta_abs = jnp.abs(beta) | ||
beta_sq = jnp.square(beta) | ||
|
||
# Estimate the largest and smallest eigenvalues of T using the Gershgorin circle theorem. | ||
off_diag_abs_row_sum = jnp.concatenate([beta_abs[:1], beta_abs[:-1] + beta_abs[1:], beta_abs[-1:]], axis=0) | ||
lambda_est_max = jnp.amax(alpha + off_diag_abs_row_sum) | ||
lambda_est_min = jnp.amin(alpha - off_diag_abs_row_sum) | ||
|
||
# Upper bound on 2-norm of T. | ||
t_norm = jnp.maximum(jnp.abs(lambda_est_min), jnp.abs(lambda_est_max)) | ||
|
||
# Compute the smallest allowed pivot in the Sturm sequence to avoid | ||
# overflow. | ||
finfo = np.finfo(alpha.dtype) | ||
one = np.ones([], dtype=alpha.dtype) | ||
safemin = np.maximum(one / finfo.max, (one + finfo.eps) * finfo.tiny) | ||
pivmin = safemin * jnp.maximum(1, jnp.amax(beta_sq)) | ||
alpha0_perturbation = jnp.square(finfo.eps * beta_abs[0]) | ||
abs_tol = finfo.eps * t_norm | ||
if tol is not None: | ||
abs_tol = jnp.maximum(tol, abs_tol) | ||
|
||
# In the worst case, when the absolute tolerance is eps*lambda_est_max and | ||
# lambda_est_max = -lambda_est_min, we have to take as many bisection steps | ||
# as there are bits in the mantissa plus 1. | ||
# The proof is left as an exercise to the reader. | ||
max_it = finfo.nmant + 1 | ||
|
||
# Might be useful to only compute the "top k electrons//2" eigenvalues. | ||
target_counts = jnp.arange(n, dtype=jnp.int32) | ||
|
||
# Run binary search for all desired eigenvalues in parallel, starting from | ||
# the interval lightly wider than the estimated | ||
# [lambda_est_min, lambda_est_max]. | ||
fudge = 2.1 # We widen starting interval the Gershgorin interval a bit. | ||
norm_slack = jnp.array(n, alpha.dtype) * fudge * finfo.eps * t_norm | ||
lower = lambda_est_min - norm_slack - 2 * fudge * pivmin | ||
upper = lambda_est_max + norm_slack + fudge * pivmin | ||
|
||
# Pre-broadcast the scalars used in the Sturm sequence for improved | ||
# performance. | ||
target_shape = jnp.shape(target_counts) | ||
lower = jnp.broadcast_to(lower, shape=target_shape) | ||
upper = jnp.broadcast_to(upper, shape=target_shape) | ||
mid = 0.5 * (upper + lower) | ||
pivmin = jnp.broadcast_to(pivmin, target_shape) | ||
alpha0_perturbation = jnp.broadcast_to(alpha0_perturbation, target_shape) | ||
|
||
vertex_filename = osp.join(osp.dirname(__file__), "../core", "vertex", "tile_tridiagonal_eigh.cpp") | ||
grad = create_ipu_tile_primitive( | ||
"Sturm", | ||
"Sturm", | ||
inputs=["alpha", "beta_sq", "pivmin", "alpha0_pertubation", "x", "id", "out_shape", "lower", "mid", "upper"], | ||
outputs={"lower_out": 7, "mid_out": 8, "upper_out": 9}, | ||
gp_filename=vertex_filename, | ||
perf_estimate=100, | ||
) | ||
|
||
x = mid | ||
n = x.shape[0] | ||
tiles = tuple(range(n)) | ||
_alpha = tile_put_replicated(jnp.array(alpha, dtype=jnp.float32), tiles) | ||
_beta_sq = tile_put_replicated(jnp.array(beta_sq, dtype=jnp.float32), tiles) | ||
_pivmin = tile_put_replicated(jnp.array(pivmin, dtype=jnp.float32), tiles) | ||
_alpha0_perturbation = tile_put_replicated(jnp.array(alpha0_perturbation, dtype=jnp.float32), tiles) | ||
_id = tile_put_sharded(jnp.arange(len(tiles)), tiles) | ||
_out_shape = tile_put_sharded(jnp.arange(len(tiles)).astype(np.float32), tiles) | ||
|
||
_lower = tile_put_sharded(lower, tiles) | ||
_mid = tile_put_sharded(mid, tiles) | ||
_upper = tile_put_sharded(upper, tiles) | ||
|
||
def body(j, args): | ||
i, lower, mid, upper = args | ||
_x = tile_put_replicated(jnp.array(mid.array, dtype=jnp.float32), tiles) | ||
lower, mid, upper = tile_map( | ||
grad, _alpha, _beta_sq, _pivmin, _alpha0_perturbation, _x, _id, _out_shape, lower, mid, upper | ||
) # type: ignore | ||
return i + 1, lower, mid, upper | ||
|
||
vals = (0, _lower, _mid, _upper) | ||
vals = jax.lax.fori_loop(0, max_it, body, vals) | ||
|
||
return vals[2].array | ||
|
||
|
||
if __name__ == "__main__": | ||
import jax | ||
import scipy | ||
|
||
np.random.seed(42) | ||
jax.config.FLAGS.jax_platform_name = "cpu" | ||
np.random.seed(42) | ||
|
||
dim = 1024 | ||
A = np.random.normal(0, 1, (dim, dim)) | ||
A = (A + A.T) / 2 | ||
print(A) | ||
|
||
D, Q = np.linalg.eigh(A) | ||
print(np.max(np.abs(Q @ np.diag(D) @ Q.T - A))) | ||
|
||
# compute eigh using hessenberg then eigh_tridiagonal. | ||
|
||
T, Q = scipy.linalg.hessenberg(A, calc_q=True) | ||
print(np.around(T, 2)) | ||
|
||
d, e = np.diag(T), np.diag(T, k=1) | ||
print(d, e) | ||
w, v = scipy.linalg.eigh_tridiagonal(d, e) | ||
print(w.shape, v.shape) | ||
|
||
assert np.allclose(T @ v - v @ np.diag(w), np.zeros((dim, dim))) | ||
Q_ = Q @ v | ||
assert np.allclose(Q_ @ np.diag(w) @ Q_.T, A) | ||
assert np.allclose(Q_ @ Q_.T, np.eye(dim)) | ||
assert np.allclose(w, D) | ||
print("PASSED! (scipy)") | ||
|
||
jitted_ipu_eigh_tridiagonal = jax.jit(ipu_eigh_tridiagonal, backend="ipu") | ||
|
||
_w = np.array(jitted_ipu_eigh_tridiagonal(d, e)) | ||
print(w.reshape(-1)[::128]) | ||
print(_w.reshape(-1)[::128]) | ||
print(np.max(np.abs(w - _w))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters