Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Oct 6, 2023
1 parent 491275a commit 6bc0668
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
10 changes: 7 additions & 3 deletions tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,9 @@ void jacob_update_first_step(const T* pcol, const T* qcol, T* pcol_updated,
* See: Gene H. Golub, Charles F. Van Loan, MATRIX COMPUTATIONS, 3rd edition,
* Johns Hopkins Chapter 8.
*/
class [[poplar::constraint("elem(*pcol) != elem(*qcol)")]] JacobiUpdateFirstStep
// class [[poplar::constraint("elem(*pcol) != elem(*qcol)")]] JacobiUpdateFirstStep
// : public MultiVertex {
class JacobiUpdateFirstStep
: public MultiVertex {
public:
using T = float;
Expand Down Expand Up @@ -312,8 +314,10 @@ void jacob_update_eigenvectors(const T* vpcol, const T* vqcol, T* vpcol_updated,
* See: Gene H. Golub, Charles F. Van Loan, MATRIX COMPUTATIONS, 3rd edition,
* Johns Hopkins Chapter 8.
*/
class [[poplar::constraint(
"elem(*vpcol) != elem(*vqcol)")]] JacobiUpdateEigenvectors
// class [[poplar::constraint(
// "elem(*vpcol) != elem(*vqcol)")]] JacobiUpdateEigenvectors
// : public MultiVertex {
class JacobiUpdateEigenvectors
: public MultiVertex {
public:
using T = float;
Expand Down
11 changes: 11 additions & 0 deletions tessellate_ipu/linalg/tile_linalg_jacobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,11 @@ def tile_rotate_columns(pcols: TileShardedArray, qcols: TileShardedArray) -> Tup
qcols_indices_new = np.concatenate([qcols_indices[1:], pcols_indices[-1:]])
# Move columns around!
all_indices = np.concatenate([pcols_indices_new, qcols_indices_new])

pcols_updated = tile_gather(all_cols, pcols_indices_new.tolist(), pcols.tiles)
qcols_updated = tile_gather(all_cols, qcols_indices_new.tolist(), qcols.tiles)
return pcols_updated, qcols_updated

all_cols_updated = tile_gather(all_cols, all_indices.tolist(), all_cols.tiles)
return all_cols_updated[:halfN], all_cols_updated[halfN:]

Expand All @@ -211,6 +216,12 @@ def ipu_jacobi_eigh_body(idx: Array, inputs: Tuple[TileShardedArray, ...]) -> Tu
halfN = Apcols.shape[0]

with jax.named_scope("jacobi_eigh"):
# with jax.named_scope("Apqcols_rotation"):
# Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols)
# with jax.named_scope("Vpqcols_rotation"):
# Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols)
# Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols)

# Sharded constant with p/q indices to ignore in second update stage.
with jax.named_scope("rotset_index_ignored"):
rotset_index_ignored = tile_constant_sharded(np.arange(0, halfN, dtype=np.uint32), tiles=Atiles)
Expand Down

0 comments on commit 6bc0668

Please sign in to comment.