Skip to content

Commit

Permalink
Fails with 'Unresolved relocation'
Browse files Browse the repository at this point in the history
  • Loading branch information
paolot-gc committed Oct 11, 2023
1 parent feb639f commit 3cb79a8
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 0 deletions.
31 changes: 31 additions & 0 deletions examples/tridiag_solver_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import sys

import jax
import jax.numpy as jnp
import numpy as np
from scipy.sparse import spdiags

from tessellate_ipu.linalg import ipu_tridiag_solve

jax.config.FLAGS.jax_platform_name = "cpu"
jax.config.update("jax_enable_x64", False)


N = int(sys.argv[1])
np.random.seed(42)

np.set_printoptions(precision=3, linewidth=120, suppress=True)


diag = np.arange(N).reshape(1, -1).astype(jnp.float32)
ldiag = np.random.rand(N - 1).reshape(1, -1).astype(jnp.float32)
rhs = np.random.rand(N).reshape(1, -1).astype(jnp.float32)

x_ = jax.jit(ipu_tridiag_solve, backend="ipu")(diag, ldiag, rhs)

x = np.array(x_.array)

T = spdiags([np.concatenate([np.array([0]), ldiag]), diag, np.concatenate([ldiag, [0]])], (1, 0, -1), N, N)

delta = T @ x - rhs
print(np.max(np.abs(delta)))
46 changes: 46 additions & 0 deletions tessellate_ipu/core/vertex/tile_tridiagonal_solver_vertex.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (c) 2023 Graphcore Ltd. All rights reserved.
#include <poplar/Vertex.hpp>
#include <algorithm>
#include <cassert>
#include <cmath>
#include <poplar/HalfFloat.hpp>
#include <poplar/Vertex.hpp>
#include "poplar/TileConstants.hpp"
#include <print.h>

#include "intrinsics_utils.hpp"

using namespace poplar;

class TridiagonalSolverVertex : public Vertex {
public:
InOut<Vector<float, poplar::VectorLayout::SPAN, 8>> ts; // b contains x
Input<Vector<float, poplar::VectorLayout::ONE_PTR, 8>> tus; // c
Input<Vector<float, poplar::VectorLayout::ONE_PTR, 8>> tls; // a
Input<Vector<float, poplar::VectorLayout::ONE_PTR, 8>> b; // d

TridiagonalSolverVertex();

bool compute() {
Vector<float, poplar::VectorLayout::ONE_PTR, 8> b_;
for (int i=0; i<ts.size(); i++)
b_[i] = b[i];

int n = ts.size();

for (int i=1; i<n; i++){
float w;
w = tls[i] / ts[i-1]; // CHECK div-by-0 or OVFL
ts[i] -= w * tus[i-1];
b_[i] -= w * b_[i-1];
}

ts[n-1] = b_[n-1] / ts[n-1];
for (int i=n-2; i>0; i--) {
ts[i] = (b_[i] - tus[i] * ts[i+1]) / ts[i]; // We put x into ts?
}

// Maybe we should compute the norm of the delta between x and ts?
return true;
}
};
1 change: 1 addition & 0 deletions tessellate_ipu/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from .tile_linalg_hessenberg import ipu_hessenberg
from .tile_linalg_jacobi import ipu_eigh
from .tile_linalg_qr import ipu_qr
from .tile_linalg_tridiagonal_solver import ipu_tridiag_solve
38 changes: 38 additions & 0 deletions tessellate_ipu/linalg/tile_linalg_tridiagonal_solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import os.path as osp
from typing import Any

import jax

from tessellate_ipu import create_ipu_tile_primitive, tile_map, tile_put_replicated

jax.config.FLAGS.jax_platform_name = "cpu"

Array = Any


vertex_filename = osp.join(osp.dirname(__file__), "../core", "vertex", "tile_tridiagonal_solver_vertex.cpp")

tridiagonal_solver_p = create_ipu_tile_primitive(
"tridiagonal_solver",
"TridiagonalSolverVertex",
inputs=["ts", "tus", "tls", "b"],
outputs={"ts": 0},
gp_filename=vertex_filename,
perf_estimate=100,
)


def ipu_tridiag_solve(diag: Array, ldiag: Array, rhs: Array):

tiles = [100]

ts = tile_put_replicated(diag, tiles=tiles)

tls = tile_put_replicated(ldiag, tiles=tiles)

tus = tls

b = tile_put_replicated(rhs, tiles=tiles)

x_ = tile_map(tridiagonal_solver_p, ts, tus, tls, b)
return x_

0 comments on commit 3cb79a8

Please sign in to comment.