@@ -42,16 +42,16 @@ def get_jacobi_vertex_gp_filename() -> str:
42
42
jacobi_update_first_step_p = create_ipu_tile_primitive (
43
43
"jacobi_update_first_step" ,
44
44
"JacobiUpdateFirstStep" ,
45
- inputs = ["rotset" , "pcol" , "qcol" ],
45
+ inputs = ["rotset" , "pindex" , "qindex" , " pcol" , "qcol" ],
46
46
outputs = {
47
47
"rotset_sorted" : ShapedArray ((2 ,), dtype = np .uint32 ),
48
48
"cs" : ShapedArray ((2 ,), dtype = np .float32 ),
49
- "pcol_updated" : 1 ,
50
- "qcol_updated" : 2 ,
49
+ "pcol_updated" : 3 ,
50
+ "qcol_updated" : 4 ,
51
51
},
52
52
constants = {
53
53
"worker_offsets" : lambda inavals , * _ : make_ipu_vector1d_worker_offsets (
54
- inavals [1 ].size , vector_size = 2 , wdtype = np .uint16
54
+ inavals [3 ].size , vector_size = 2 , wdtype = np .uint16
55
55
)
56
56
},
57
57
gp_filename = get_jacobi_vertex_gp_filename (),
@@ -178,16 +178,25 @@ def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtile
178
178
rotset = jacobi_initial_rotation_set (N )
179
179
rotset_sharded = tile_constant_sharded (rotset , tiles = Atiles )
180
180
181
+ pindices_sharded = rotset_sharded [:, :1 ]
182
+ qindices_sharded = rotset_sharded [:, 1 :]
183
+
181
184
# Iterate over all the distinct partitions of the columns into pairs.
182
185
for _ in range (1 , N ):
183
186
# Build the on-tile constant rotation set tensor.
184
- with jax .named_scope ("rotset" ):
185
- rotset_sharded = tile_constant_sharded (rotset , tiles = Atiles )
187
+ # with jax.named_scope("rotset"):
188
+ # # rotset_sharded = tile_constant_sharded(rotset, tiles=Atiles)
189
+ # pindices_sharded = tile_constant_sharded(np.copy(rotset[:, :1]), tiles=Atiles)
190
+ # qindices_sharded = tile_constant_sharded(np.copy(rotset[:, 1:]), tiles=Atiles)
191
+
192
+ # pindices_sharded = rotset_sharded[:, :1]
193
+ # qindices_sharded = rotset_sharded[:, 1:]
186
194
187
195
# Compute Schur decomposition + on-tile update of columns.
188
196
rotset_sorted_sharded , cs_per_tile , Apcols , Aqcols = tile_map ( # type:ignore
189
- jacobi_update_first_step_p , rotset_sharded , Apcols , Aqcols , N = N
197
+ jacobi_update_first_step_p , rotset_sharded , pindices_sharded , qindices_sharded , Apcols , Aqcols , N = N
190
198
)
199
+ # print(pindices_sharded.shape, qindices_sharded.shape)
191
200
# Replicate Schur decomposition across all A tiles: (2*N//2) comms.
192
201
with jax .named_scope ("cs_replicated_sharded" ):
193
202
cs_replicated = tile_put_replicated (cs_per_tile .array , tiles = Atiles )
@@ -197,6 +206,7 @@ def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtile
197
206
cs_replicated , cs_Vtiles = tile_data_barrier (cs_replicated , cs_Vtiles )
198
207
199
208
# Second Jacobi update step.
209
+ # Note: does not require sorting of pcols and qcols.
200
210
cs_replicated , Apcols , Aqcols = tile_map ( # type:ignore
201
211
jacobi_update_second_step_p ,
202
212
cs_replicated ,
@@ -221,10 +231,11 @@ def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtile
221
231
with jax .named_scope ("Apqcols_rotation" ):
222
232
# Apcols, Aqcols = tile_rotate_sorted_columns(Apcols, Aqcols, rotset)
223
233
Apcols , Aqcols = tile_rotate_columns (Apcols , Aqcols )
224
-
225
234
with jax .named_scope ("Vpqcols_rotation" ):
226
235
# Vpcols, Vqcols = tile_rotate_sorted_columns(Vpcols, Vqcols, rotset)
227
236
Vpcols , Vqcols = tile_rotate_columns (Vpcols , Vqcols )
237
+ with jax .named_scope ("pqindices_rotations" ):
238
+ pindices_sharded , qindices_sharded = tile_rotate_columns (pindices_sharded , qindices_sharded )
228
239
229
240
# Next rotation set.
230
241
rotset = jacobi_next_rotation_set (rotset )
0 commit comments