Skip to content

Commit

Permalink
Initial working ipu_tridiag_solve()
Browse files Browse the repository at this point in the history
  • Loading branch information
paolot-gc committed Oct 11, 2023
1 parent 3cb79a8 commit c0e27ff
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 22 deletions.
27 changes: 19 additions & 8 deletions examples/tridiag_solver_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,32 @@
jax.config.update("jax_enable_x64", False)


N = int(sys.argv[1])
N = int(sys.argv[2])
M = int(sys.argv[1])
np.random.seed(42)

np.set_printoptions(precision=3, linewidth=120, suppress=True)


diag = np.arange(N).reshape(1, -1).astype(jnp.float32)
ldiag = np.random.rand(N - 1).reshape(1, -1).astype(jnp.float32)
rhs = np.random.rand(N).reshape(1, -1).astype(jnp.float32)
diag = np.random.rand(M, N).astype(jnp.float32)
udiag = np.random.rand(M, N).astype(jnp.float32)
rhs = np.random.rand(M, N).astype(jnp.float32)

x_ = jax.jit(ipu_tridiag_solve, backend="ipu")(diag, ldiag, rhs)
x_ = jax.jit(ipu_tridiag_solve, backend="ipu")(diag, udiag, np.roll(udiag, 1, axis=1), rhs)

x = np.array(x_.array)

T = spdiags([np.concatenate([np.array([0]), ldiag]), diag, np.concatenate([ldiag, [0]])], (1, 0, -1), N, N)
print(x.shape)

delta = T @ x - rhs
print(np.max(np.abs(delta)))
# Check every system independently: rebuild the dense tridiagonal matrix of
# row i and record the residual T @ x[i] - rhs[i].
deltas = []
for i in range(M):
    # np.roll aligns udiag with the column-indexed layout spdiags uses for the
    # super-diagonal (offset +1); the sub-diagonal (offset -1) takes udiag as is.
    data = np.vstack(
        [np.roll(udiag[i].flat, 1, axis=0), diag[i].flat, udiag[i].flat],
    )
    T = spdiags(data, (1, 0, -1), N, N).toarray()

    delta = T @ x[i].reshape(N, 1) - rhs[i].reshape(N, 1)

    deltas.append(delta)

# Report the worst residual over ALL systems; using `delta` here would only
# look at the last row and hide errors in the earlier ones.
print("Max abs delta:", np.max(np.abs(np.array(deltas))))
14 changes: 7 additions & 7 deletions tessellate_ipu/core/vertex/tile_tridiagonal_solver_vertex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,25 @@ class TridiagonalSolverVertex : public Vertex {
Input<Vector<float, poplar::VectorLayout::ONE_PTR, 8>> tls; // a
Input<Vector<float, poplar::VectorLayout::ONE_PTR, 8>> b; // d

Output<Vector<float, poplar::VectorLayout::ONE_PTR, 8>> tmp; // temporary

TridiagonalSolverVertex();

bool compute() {
Vector<float, poplar::VectorLayout::ONE_PTR, 8> b_;
for (int i=0; i<ts.size(); i++)
b_[i] = b[i];

int n = ts.size();

tmp[0] = b[0];
for (int i=1; i<n; i++){
float w;
w = tls[i] / ts[i-1]; // CHECK div-by-0 or OVFL
ts[i] -= w * tus[i-1];
b_[i] -= w * b_[i-1];
tmp[i] = b[i] - w * tmp[i-1];
}

ts[n-1] = b_[n-1] / ts[n-1];
for (int i=n-2; i>0; i--) {
ts[i] = (b_[i] - tus[i] * ts[i+1]) / ts[i]; // We put x into ts?
ts[n-1] = tmp[n-1] / ts[n-1];
for (int i=n-2; i>=0; i--) {
ts[i] = (tmp[i] - tus[i] * ts[i+1]) / ts[i]; // We put x into ts?
}

// Maybe we should compute the norm of the delta between x and ts?
Expand Down
22 changes: 15 additions & 7 deletions tessellate_ipu/linalg/tile_linalg_tridiagonal_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import jax

from tessellate_ipu import create_ipu_tile_primitive, tile_map, tile_put_replicated
from tessellate_ipu import create_ipu_tile_primitive, tile_map, tile_put_sharded

jax.config.FLAGS.jax_platform_name = "cpu"

Expand All @@ -17,22 +17,30 @@
"TridiagonalSolverVertex",
inputs=["ts", "tus", "tls", "b"],
outputs={"ts": 0},
tmp_space=3,
gp_filename=vertex_filename,
perf_estimate=100,
)


def ipu_tridiag_solve(diag: Array, ldiag: Array, rhs: Array):
def ipu_tridiag_solve(diag: Array, udiag: Array, ldiag: Array, rhs: Array):
"""
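diag: main diagonal, (M,N), one tridiagonal system per row (row m is placed on tile m)
udiag: upper diagonal, (M,N), the last element of each row is not used
ldiag: lower diagonal, (M,N), the first element of each row is not used, i.e. for each system A[1,0] == ldiag[1]
rhs: right hand side, (M,N)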
Note the logic is different from that of scipy.sparse.spdiags()
"""

tiles = [100]
tiles = list(range(diag.shape[0]))

ts = tile_put_replicated(diag, tiles=tiles)
ts = tile_put_sharded(diag, tiles=tiles)

tls = tile_put_replicated(ldiag, tiles=tiles)
tls = tile_put_sharded(ldiag, tiles=tiles)

tus = tls
tus = tile_put_sharded(udiag, tiles=tiles)

b = tile_put_replicated(rhs, tiles=tiles)
b = tile_put_sharded(rhs, tiles=tiles)

x_ = tile_map(tridiagonal_solver_p, ts, tus, tls, b)
return x_

0 comments on commit c0e27ff

Please sign in to comment.