Skip to content

Commit

Permalink
Added real_symmetric_eigh example (+ Alex's eigh)
Browse files Browse the repository at this point in the history
  • Loading branch information
paolot-gc committed Oct 17, 2023
1 parent 448e535 commit 85b4ebd
Show file tree
Hide file tree
Showing 4 changed files with 301 additions and 1 deletion.
68 changes: 68 additions & 0 deletions examples/real_symmetric_eigh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import sys

import jax
import jax.numpy as jnp
import numpy as np
from icecream import ic

from tessellate_ipu.linalg import ipu_hessenberg, ipu_tridiag_solve
from tessellate_ipu.linalg.tile_linalg_tridiagonal_eigh import ipu_eigh_tridiagonal

jax.config.FLAGS.jax_platform_name = "cpu"
jax.config.update("jax_enable_x64", False)

if len(sys.argv) != 3:
print(sys.argv[0] + " <size> <num eigenvectors>")
sys.exit(1)

seed = 42
np.random.seed(seed)

np.set_printoptions(precision=3, linewidth=120, suppress=True)


S = int(sys.argv[1])
N = int(sys.argv[2])

mat = np.random.rand(S, S).astype(np.float32)
mat = (mat + mat.T) / 2


def real_symmetric_eigh(M):

Q, M_tri_ = ipu_hessenberg(M)
M_tri = M_tri_.array

d, e = jnp.diag(M_tri), jnp.diag(M_tri, k=1)
eig = ipu_eigh_tridiagonal(d, e)[:N]

diag = jnp.tile(jnp.diag(M_tri, k=0).reshape(1, -1), (N, 1)) - eig[:, jnp.newaxis]

udiag = jnp.concatenate([jnp.diag(M_tri, k=1), jnp.array([0], dtype=jnp.float32)]).reshape(1, -1)
udiag = jnp.tile(udiag, (N, 1))

ldiag = jnp.concatenate([jnp.array([0], dtype=jnp.float32), jnp.diag(M_tri, k=-1)]).reshape(1, -1)
ldiag = jnp.tile(ldiag, (N, 1))

prng_key = jax.random.PRNGKey(seed)
x = jax.random.normal(prng_key, shape=(N, diag.shape[1]), dtype=jnp.float32)
x /= jnp.linalg.norm(x, axis=1)[:, jnp.newaxis]

def inverse_iteration(i, x):
x = ipu_tridiag_solve(diag, udiag, ldiag, x)
x /= jnp.linalg.norm(x, axis=1)[:, jnp.newaxis]
return x

x = jax.lax.fori_loop(0, 2, inverse_iteration, x)

return x @ Q.array.T, eig


x, eig = jax.jit(real_symmetric_eigh, backend="ipu")(mat)

# ic(x)

# ic(mat @ eigv - eig * eigv)
# ic(mat @ x.T - eig * x.T)

ic(np.max(np.abs(mat @ x.T - eig * x.T)))
89 changes: 89 additions & 0 deletions tessellate_ipu/core/vertex/tile_tridiagonal_eigh.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#include <poplar/Vertex.hpp>
#include <algorithm>
#include <cassert>
#include <cmath>
#include <poplar/HalfFloat.hpp>
#include <poplar/Vertex.hpp>
#include "poplar/TileConstants.hpp"
#include <print.h>

using namespace poplar;

#ifdef __IPU__
// Use the IPU intrinsics
#include <ipu_memory_intrinsics>
#include <ipu_vector_math>
#define NAMESPACE ipu
#else
// Use the std functions
#include <cmath>
#define NAMESPACE std
#endif

class Sturm : public Vertex {
public:
Input<Vector<float>> alpha;
Input<Vector<float>> beta_sq;
Input<Vector<float>> pivmin;
Input<Vector<float>> alpha0_pertubation;
Input<Vector<float>> x;
Input<Vector<int>> id;
Input<Vector<float>> out_shape;
Input<Vector<float>> lower;
Input<Vector<float>> mid;
Input<Vector<float>> upper;

Output<Vector<float>> lower_out;
Output<Vector<float>> mid_out;
Output<Vector<float>> upper_out;

bool compute() {
int tile_id = id.data()[0];

//q = alpha[0] - x
//count = jnp.where(q < 0, ones, zeros)
//q = jnp.where(alpha[0] == x, alpha0_perturbation, q)
float q = alpha[0] - x.data()[tile_id];
int count = q < 0;
if (alpha[0] == x.data()[tile_id]) {q = alpha0_pertubation.data()[tile_id]; }

//for i in range(1, n):
// q = alpha[i] - beta_sq[i - 1] / q - x
// count = jnp.where(q <= pivmin, count + 1, count)
// q = jnp.where(q <= pivmin, jnp.minimum(q, -pivmin), q)*/

int n = x.size();
float x_tile_id = x.data()[tile_id];
float pivmin_tile_id = pivmin.data()[tile_id];
float minus_pivmin_tile_id = -pivmin_tile_id;

// main bulk: takes ~ 87k cycles for 1024 matrix => ~ 80 cycles per iteration.
// obs: we can precompute (alpha.data()[i]-x_tile_id) using all 6 threads in parallel.
for (unsigned int i = 1; i < n; i++){
//q = alpha[i] - beta_sq[i - 1] / q - x
//q = alpha.data()[i] - x.data()[tile_id] - beta_sq.data()[i - 1] / q ;
q = alpha.data()[i] - x_tile_id - beta_sq.data()[i - 1] / q ;

//count = jnp.where(q <= pivmin, count + 1, count)
//q = jnp.where(q <= pivmin, jnp.minimum(q, -pivmin), q)
if (q <= pivmin_tile_id){
count ++;
q = fmin(q, minus_pivmin_tile_id);
}
}

//lower = jnp.where(counts <= target_counts, mid, lower)
//upper = jnp.where(counts > target_counts, mid, upper)
//mid = 0.5 * (lower + upper)
int target_count = tile_id; // they are the same
if (count <= target_count) lower_out[0] = mid.data()[0];
else {lower_out[0] = lower.data()[0]; }

if (count > target_count) upper_out[0] = mid.data()[0];
else {upper_out[0] = upper.data()[0];}

mid_out[0] = (lower_out.data()[0] + upper_out.data()[0])/2;

return true;
}
};
143 changes: 143 additions & 0 deletions tessellate_ipu/linalg/tile_linalg_tridiagonal_eigh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import os.path as osp

import jax
import jax.numpy as jnp
import numpy as np
import scipy.linalg

from tessellate_ipu import create_ipu_tile_primitive, tile_map, tile_put_replicated, tile_put_sharded

jax.config.FLAGS.jax_platform_name = "cpu"


def ipu_eigh_tridiagonal(d, e, *, select="a", select_range=None, tol=None):
alpha, beta = jnp.asarray(d), jnp.asarray(e)
n = alpha.shape[0]
if n <= 1:
return jnp.real(alpha)

beta_abs = jnp.abs(beta)
beta_sq = jnp.square(beta)

# Estimate the largest and smallest eigenvalues of T using the Gershgorin circle theorem.
off_diag_abs_row_sum = jnp.concatenate([beta_abs[:1], beta_abs[:-1] + beta_abs[1:], beta_abs[-1:]], axis=0)
lambda_est_max = jnp.amax(alpha + off_diag_abs_row_sum)
lambda_est_min = jnp.amin(alpha - off_diag_abs_row_sum)

# Upper bound on 2-norm of T.
t_norm = jnp.maximum(jnp.abs(lambda_est_min), jnp.abs(lambda_est_max))

# Compute the smallest allowed pivot in the Sturm sequence to avoid
# overflow.
finfo = np.finfo(alpha.dtype)
one = np.ones([], dtype=alpha.dtype)
safemin = np.maximum(one / finfo.max, (one + finfo.eps) * finfo.tiny)
pivmin = safemin * jnp.maximum(1, jnp.amax(beta_sq))
alpha0_perturbation = jnp.square(finfo.eps * beta_abs[0])
abs_tol = finfo.eps * t_norm
if tol is not None:
abs_tol = jnp.maximum(tol, abs_tol)

# In the worst case, when the absolute tolerance is eps*lambda_est_max and
# lambda_est_max = -lambda_est_min, we have to take as many bisection steps
# as there are bits in the mantissa plus 1.
# The proof is left as an exercise to the reader.
max_it = finfo.nmant + 1

# Might be useful to only compute the "top k electrons//2" eigenvalues.
target_counts = jnp.arange(n, dtype=jnp.int32)

# Run binary search for all desired eigenvalues in parallel, starting from
# the interval lightly wider than the estimated
# [lambda_est_min, lambda_est_max].
fudge = 2.1 # We widen starting interval the Gershgorin interval a bit.
norm_slack = jnp.array(n, alpha.dtype) * fudge * finfo.eps * t_norm
lower = lambda_est_min - norm_slack - 2 * fudge * pivmin
upper = lambda_est_max + norm_slack + fudge * pivmin

# Pre-broadcast the scalars used in the Sturm sequence for improved
# performance.
target_shape = jnp.shape(target_counts)
lower = jnp.broadcast_to(lower, shape=target_shape)
upper = jnp.broadcast_to(upper, shape=target_shape)
mid = 0.5 * (upper + lower)
pivmin = jnp.broadcast_to(pivmin, target_shape)
alpha0_perturbation = jnp.broadcast_to(alpha0_perturbation, target_shape)

vertex_filename = osp.join(osp.dirname(__file__), "../core", "vertex", "tile_tridiagonal_eigh.cpp")
grad = create_ipu_tile_primitive(
"Sturm",
"Sturm",
inputs=["alpha", "beta_sq", "pivmin", "alpha0_pertubation", "x", "id", "out_shape", "lower", "mid", "upper"],
outputs={"lower_out": 7, "mid_out": 8, "upper_out": 9},
gp_filename=vertex_filename,
perf_estimate=100,
)

x = mid
n = x.shape[0]
tiles = tuple(range(n))
_alpha = tile_put_replicated(jnp.array(alpha, dtype=jnp.float32), tiles)
_beta_sq = tile_put_replicated(jnp.array(beta_sq, dtype=jnp.float32), tiles)
_pivmin = tile_put_replicated(jnp.array(pivmin, dtype=jnp.float32), tiles)
_alpha0_perturbation = tile_put_replicated(jnp.array(alpha0_perturbation, dtype=jnp.float32), tiles)
_id = tile_put_sharded(jnp.arange(len(tiles)), tiles)
_out_shape = tile_put_sharded(jnp.arange(len(tiles)).astype(np.float32), tiles)

_lower = tile_put_sharded(lower, tiles)
_mid = tile_put_sharded(mid, tiles)
_upper = tile_put_sharded(upper, tiles)

def body(j, args):
i, lower, mid, upper = args
_x = tile_put_replicated(jnp.array(mid.array, dtype=jnp.float32), tiles)
lower, mid, upper = tile_map(
grad, _alpha, _beta_sq, _pivmin, _alpha0_perturbation, _x, _id, _out_shape, lower, mid, upper
) # type: ignore
return i + 1, lower, mid, upper

vals = (0, _lower, _mid, _upper)
vals = jax.lax.fori_loop(0, max_it, body, vals)

return vals[2].array


if __name__ == "__main__":
import jax
import scipy

np.random.seed(42)
jax.config.FLAGS.jax_platform_name = "cpu"
np.random.seed(42)

dim = 1024
A = np.random.normal(0, 1, (dim, dim))
A = (A + A.T) / 2
print(A)

D, Q = np.linalg.eigh(A)
print(np.max(np.abs(Q @ np.diag(D) @ Q.T - A)))

# compute eigh using hessenberg then eigh_tridiagonal.

T, Q = scipy.linalg.hessenberg(A, calc_q=True)
print(np.around(T, 2))

d, e = np.diag(T), np.diag(T, k=1)
print(d, e)
w, v = scipy.linalg.eigh_tridiagonal(d, e)
print(w.shape, v.shape)

assert np.allclose(T @ v - v @ np.diag(w), np.zeros((dim, dim)))
Q_ = Q @ v
assert np.allclose(Q_ @ np.diag(w) @ Q_.T, A)
assert np.allclose(Q_ @ Q_.T, np.eye(dim))
assert np.allclose(w, D)
print("PASSED! (scipy)")

jitted_ipu_eigh_tridiagonal = jax.jit(ipu_eigh_tridiagonal, backend="ipu")

_w = np.array(jitted_ipu_eigh_tridiagonal(d, e))
print(w.reshape(-1)[::128])
print(_w.reshape(-1)[::128])
print(np.max(np.abs(w - _w)))
2 changes: 1 addition & 1 deletion tessellate_ipu/linalg/tile_linalg_tridiagonal_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,4 @@ def ipu_tridiag_solve(diag: Array, udiag: Array, ldiag: Array, rhs: Array):
d, u, l, b = ipu_tridiag_solve_shard_inputs(diag, udiag, ldiag, rhs)

x, _, _, _ = tile_map(tridiagonal_solver_p, d, u, l, b) # type: ignore
return x
return x.array

0 comments on commit 85b4ebd

Please sign in to comment.