Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Oct 2, 2023
1 parent c8f891d commit 7304ddc
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 14 deletions.
1 change: 1 addition & 0 deletions tessellate_ipu/core/tile_interpreter_vertex_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def make_ipu_vector1d_worker_offsets(
Returns:
(6,) number of data vectors per thread.
"""
print("SIZE:", size)

def make_offsets_fn(sizes):
sizes = [0] + sizes
Expand Down
18 changes: 12 additions & 6 deletions tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ class [[poplar::constraint("elem(*pcol) != elem(*qcol)")]] JacobiUpdateFirstStep

Input<Vector<unsigned, poplar::VectorLayout::ONE_PTR, 8>>
rotset; // (2,) rotation index p and q. p < q
Input<Vector<unsigned, poplar::VectorLayout::ONE_PTR>>
pindex; // (1,) rotation index p only. Not necessarily < q: compute() swaps when q < p.
Input<Vector<unsigned, poplar::VectorLayout::ONE_PTR>>
qindex; // (1,) rotation index q only. Not necessarily > p: compute() swaps when q < p.
Input<Vector<T, poplar::VectorLayout::ONE_PTR, 8>> pcol; // (N,) p column
Input<Vector<T, poplar::VectorLayout::ONE_PTR, 8>> qcol; // (N,) q column

Expand All @@ -170,26 +174,28 @@ class [[poplar::constraint("elem(*pcol) != elem(*qcol)")]] JacobiUpdateFirstStep
JacobiUpdateFirstStep();

bool compute(unsigned wid) {
const unsigned p = rotset[0];
const unsigned q = rotset[1];
const unsigned p = pindex[0];
const unsigned q = qindex[0];


const IndexType wstart = worker_offsets[wid];
const IndexType wend = worker_offsets[wid + 1];

if (p <= q) {
// Proper ordering of p and q already.
jacob_update_first_step(pcol.data(), qcol.data(), pcol_updated.data(),
qcol_updated.data(), cs.data(), p, q, wstart, wend);
rotset_sorted[0] = rotset[0];
rotset_sorted[1] = rotset[1];
rotset_sorted[0] = p;
rotset_sorted[1] = q;
}
else {
// Swap p and q columns as q < p
jacob_update_first_step(qcol.data(), pcol.data(), qcol_updated.data(),
pcol_updated.data(), cs.data(), q, p, wstart, wend);
// jacob_update_first_step(pcol.data(), qcol.data(), pcol_updated.data(),
// qcol_updated.data(), cs.data(), q, p, wstart, wend);
rotset_sorted[0] = rotset[1];
rotset_sorted[1] = rotset[0];
rotset_sorted[0] = q;
rotset_sorted[1] = p;
}

return true;
Expand Down
27 changes: 19 additions & 8 deletions tessellate_ipu/linalg/tile_linalg_jacobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,16 @@ def get_jacobi_vertex_gp_filename() -> str:
jacobi_update_first_step_p = create_ipu_tile_primitive(
"jacobi_update_first_step",
"JacobiUpdateFirstStep",
inputs=["rotset", "pcol", "qcol"],
inputs=["rotset", "pindex", "qindex", "pcol", "qcol"],
outputs={
"rotset_sorted": ShapedArray((2,), dtype=np.uint32),
"cs": ShapedArray((2,), dtype=np.float32),
"pcol_updated": 1,
"qcol_updated": 2,
"pcol_updated": 3,
"qcol_updated": 4,
},
constants={
"worker_offsets": lambda inavals, *_: make_ipu_vector1d_worker_offsets(
inavals[1].size, vector_size=2, wdtype=np.uint16
inavals[3].size, vector_size=2, wdtype=np.uint16
)
},
gp_filename=get_jacobi_vertex_gp_filename(),
Expand Down Expand Up @@ -178,16 +178,25 @@ def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtile
rotset = jacobi_initial_rotation_set(N)
rotset_sharded = tile_constant_sharded(rotset, tiles=Atiles)

pindices_sharded = rotset_sharded[:, :1]
qindices_sharded = rotset_sharded[:, 1:]

# All different size 2 partitions on columns.
for _ in range(1, N):
# On tile constant rotation set tensor building.
with jax.named_scope("rotset"):
rotset_sharded = tile_constant_sharded(rotset, tiles=Atiles)
# with jax.named_scope("rotset"):
# # rotset_sharded = tile_constant_sharded(rotset, tiles=Atiles)
# pindices_sharded = tile_constant_sharded(np.copy(rotset[:, :1]), tiles=Atiles)
# qindices_sharded = tile_constant_sharded(np.copy(rotset[:, 1:]), tiles=Atiles)

# pindices_sharded = rotset_sharded[:, :1]
# qindices_sharded = rotset_sharded[:, 1:]

# Compute Schur decomposition + on-tile update of columns.
rotset_sorted_sharded, cs_per_tile, Apcols, Aqcols = tile_map( # type:ignore
jacobi_update_first_step_p, rotset_sharded, Apcols, Aqcols, N=N
jacobi_update_first_step_p, rotset_sharded, pindices_sharded, qindices_sharded, Apcols, Aqcols, N=N
)
# print(pindices_sharded.shape, qindices_sharded.shape)
# Replicate Schur decomposition across all A tiles: (2*N//2) comms.
with jax.named_scope("cs_replicated_sharded"):
cs_replicated = tile_put_replicated(cs_per_tile.array, tiles=Atiles)
Expand All @@ -197,6 +206,7 @@ def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtile
cs_replicated, cs_Vtiles = tile_data_barrier(cs_replicated, cs_Vtiles)

# Second Jacobi update step.
# Note: does not require sorting of pcols and qcols.
cs_replicated, Apcols, Aqcols = tile_map( # type:ignore
jacobi_update_second_step_p,
cs_replicated,
Expand All @@ -221,10 +231,11 @@ def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtile
with jax.named_scope("Apqcols_rotation"):
# Apcols, Aqcols = tile_rotate_sorted_columns(Apcols, Aqcols, rotset)
Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols)

with jax.named_scope("Vpqcols_rotation"):
# Vpcols, Vqcols = tile_rotate_sorted_columns(Vpcols, Vqcols, rotset)
Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols)
with jax.named_scope("pqindices_rotations"):
pindices_sharded, qindices_sharded = tile_rotate_columns(pindices_sharded, qindices_sharded)

# Next rotation set.
rotset = jacobi_next_rotation_set(rotset)
Expand Down

0 comments on commit 7304ddc

Please sign in to comment.