Skip to content

Commit 7304ddc

Browse files
committed
wip
1 parent c8f891d commit 7304ddc

File tree

3 files changed

+32
-14
lines changed

3 files changed

+32
-14
lines changed

tessellate_ipu/core/tile_interpreter_vertex_utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def make_ipu_vector1d_worker_offsets(
3939
Returns:
4040
(6,) number of data vectors per thread.
4141
"""
42+
print("SIZE:", size)
4243

4344
def make_offsets_fn(sizes):
4445
sizes = [0] + sizes

tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp

+12-6
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,10 @@ class [[poplar::constraint("elem(*pcol) != elem(*qcol)")]] JacobiUpdateFirstStep
149149

150150
Input<Vector<unsigned, poplar::VectorLayout::ONE_PTR, 8>>
151151
rotset; // (2,) rotation index p and q. p < q
152+
Input<Vector<unsigned, poplar::VectorLayout::ONE_PTR>>
153+
pindex; // (1,) rotation index p and q. p < q
154+
Input<Vector<unsigned, poplar::VectorLayout::ONE_PTR>>
155+
qindex; // (1,) rotation index p and q. p < q
152156
Input<Vector<T, poplar::VectorLayout::ONE_PTR, 8>> pcol; // (N,) p column
153157
Input<Vector<T, poplar::VectorLayout::ONE_PTR, 8>> qcol; // (N,) q column
154158

@@ -170,26 +174,28 @@ class [[poplar::constraint("elem(*pcol) != elem(*qcol)")]] JacobiUpdateFirstStep
170174
JacobiUpdateFirstStep();
171175

172176
bool compute(unsigned wid) {
173-
const unsigned p = rotset[0];
174-
const unsigned q = rotset[1];
177+
const unsigned p = pindex[0];
178+
const unsigned q = qindex[0];
179+
180+
175181
const IndexType wstart = worker_offsets[wid];
176182
const IndexType wend = worker_offsets[wid + 1];
177183

178184
if (p <= q) {
179185
// Proper ordering of p and q already.
180186
jacob_update_first_step(pcol.data(), qcol.data(), pcol_updated.data(),
181187
qcol_updated.data(), cs.data(), p, q, wstart, wend);
182-
rotset_sorted[0] = rotset[0];
183-
rotset_sorted[1] = rotset[1];
188+
rotset_sorted[0] = p;
189+
rotset_sorted[1] = q;
184190
}
185191
else {
186192
// Swap p and q columns as q < p
187193
jacob_update_first_step(qcol.data(), pcol.data(), qcol_updated.data(),
188194
pcol_updated.data(), cs.data(), q, p, wstart, wend);
189195
// jacob_update_first_step(pcol.data(), qcol.data(), pcol_updated.data(),
190196
// qcol_updated.data(), cs.data(), q, p, wstart, wend);
191-
rotset_sorted[0] = rotset[1];
192-
rotset_sorted[1] = rotset[0];
197+
rotset_sorted[0] = q;
198+
rotset_sorted[1] = p;
193199
}
194200

195201
return true;

tessellate_ipu/linalg/tile_linalg_jacobi.py

+19-8
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,16 @@ def get_jacobi_vertex_gp_filename() -> str:
4242
jacobi_update_first_step_p = create_ipu_tile_primitive(
4343
"jacobi_update_first_step",
4444
"JacobiUpdateFirstStep",
45-
inputs=["rotset", "pcol", "qcol"],
45+
inputs=["rotset", "pindex", "qindex", "pcol", "qcol"],
4646
outputs={
4747
"rotset_sorted": ShapedArray((2,), dtype=np.uint32),
4848
"cs": ShapedArray((2,), dtype=np.float32),
49-
"pcol_updated": 1,
50-
"qcol_updated": 2,
49+
"pcol_updated": 3,
50+
"qcol_updated": 4,
5151
},
5252
constants={
5353
"worker_offsets": lambda inavals, *_: make_ipu_vector1d_worker_offsets(
54-
inavals[1].size, vector_size=2, wdtype=np.uint16
54+
inavals[3].size, vector_size=2, wdtype=np.uint16
5555
)
5656
},
5757
gp_filename=get_jacobi_vertex_gp_filename(),
@@ -178,16 +178,25 @@ def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtile
178178
rotset = jacobi_initial_rotation_set(N)
179179
rotset_sharded = tile_constant_sharded(rotset, tiles=Atiles)
180180

181+
pindices_sharded = rotset_sharded[:, :1]
182+
qindices_sharded = rotset_sharded[:, 1:]
183+
181184
# All different size 2 partitions on columns.
182185
for _ in range(1, N):
183186
# On tile constant rotation set tensor building.
184-
with jax.named_scope("rotset"):
185-
rotset_sharded = tile_constant_sharded(rotset, tiles=Atiles)
187+
# with jax.named_scope("rotset"):
188+
# # rotset_sharded = tile_constant_sharded(rotset, tiles=Atiles)
189+
# pindices_sharded = tile_constant_sharded(np.copy(rotset[:, :1]), tiles=Atiles)
190+
# qindices_sharded = tile_constant_sharded(np.copy(rotset[:, 1:]), tiles=Atiles)
191+
192+
# pindices_sharded = rotset_sharded[:, :1]
193+
# qindices_sharded = rotset_sharded[:, 1:]
186194

187195
# Compute Schur decomposition + on-tile update of columns.
188196
rotset_sorted_sharded, cs_per_tile, Apcols, Aqcols = tile_map( # type:ignore
189-
jacobi_update_first_step_p, rotset_sharded, Apcols, Aqcols, N=N
197+
jacobi_update_first_step_p, rotset_sharded, pindices_sharded, qindices_sharded, Apcols, Aqcols, N=N
190198
)
199+
# print(pindices_sharded.shape, qindices_sharded.shape)
191200
# Replicate Schur decomposition across all A tiles: (2*N//2) comms.
192201
with jax.named_scope("cs_replicated_sharded"):
193202
cs_replicated = tile_put_replicated(cs_per_tile.array, tiles=Atiles)
@@ -197,6 +206,7 @@ def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtile
197206
cs_replicated, cs_Vtiles = tile_data_barrier(cs_replicated, cs_Vtiles)
198207

199208
# Second Jacobi update step.
209+
# Note: does not require sorting of pcols and qcols.
200210
cs_replicated, Apcols, Aqcols = tile_map( # type:ignore
201211
jacobi_update_second_step_p,
202212
cs_replicated,
@@ -221,10 +231,11 @@ def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtile
221231
with jax.named_scope("Apqcols_rotation"):
222232
# Apcols, Aqcols = tile_rotate_sorted_columns(Apcols, Aqcols, rotset)
223233
Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols)
224-
225234
with jax.named_scope("Vpqcols_rotation"):
226235
# Vpcols, Vqcols = tile_rotate_sorted_columns(Vpcols, Vqcols, rotset)
227236
Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols)
237+
with jax.named_scope("pqindices_rotations"):
238+
pindices_sharded, qindices_sharded = tile_rotate_columns(pindices_sharded, qindices_sharded)
228239

229240
# Next rotation set.
230241
rotset = jacobi_next_rotation_set(rotset)

0 commit comments

Comments
 (0)