Skip to content

Commit

Permalink
Improve (slightly!) tile_gather PopVision profile. (#41)
Browse files Browse the repository at this point in the history
Small improvement only: the poplar Tensor definition could not be completely fused.
Also adds some annotations to `ipu_eigh`.
  • Loading branch information
balancap authored Sep 30, 2023
1 parent c040b7f commit d5b7c46
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 10 deletions.
18 changes: 15 additions & 3 deletions tessellate_ipu/lib/tile_array_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ poplar::program::Program lowerTilePutShardedToPoplar(
auto output = createShardedVariable(
graph, input.elementType(), input[0].shape(), tile_array, debug_context);
// Copy data tensor into the output.
auto prog = poplar::program::Copy(input, output);
auto prog = poplar::program::Copy(input, output, false, debug_context);
outputs.push_back(output);
return prog;
}
Expand All @@ -112,7 +112,8 @@ poplar::program::Program lowerTilePutReplicatedToPoplar(
auto output = createShardedVariable(graph, input.elementType(), input.shape(),
tile_array, debug_context);
// Copy data tensor into the output.
auto prog = poplar::program::Copy(input_broadcasted, output, false);
auto prog =
poplar::program::Copy(input_broadcasted, output, false, debug_context);
outputs.push_back(output);
return prog;
}
Expand All @@ -128,10 +129,15 @@ poplar::program::Program lowerTileGatherToPoplar(
const auto& input = inputs[0];
const auto item_shape = input[0].shape();
const auto item_type = input.elementType();
const size_t num_tiles = params.tiles.size();

// Create the output tensor per gather index, then concat.
auto seq = poplar::program::Sequence();
// All output slices
std::vector<poplar::Tensor> output_slices;
// Slices requiring copying.
std::vector<poplar::Tensor> input_copy_slices;
std::vector<poplar::Tensor> output_copy_slices;
for (std::size_t idx = 0; idx < params.tiles.size(); ++idx) {
const auto gather_idx = params.indices[idx];
// Get the proper item at the gather index.
Expand All @@ -146,10 +152,16 @@ poplar::program::Program lowerTileGatherToPoplar(
auto output_item =
graph.addVariable(item_type, item_shape, debug_context);
graph.setTileMapping(output_item, output_tile);
seq.add(poplar::program::Copy(input_item, output_item));
input_copy_slices.push_back(input_item.expand({0}));
output_copy_slices.push_back(output_item.expand({0}));
output_slices.push_back(output_item.expand({0}));
}
}
// Copy input to output.
auto input_copy = poplar::concat(input_copy_slices);
auto output_copy = poplar::concat(output_copy_slices);
seq.add(poplar::program::Copy(input_copy, output_copy, false, debug_context));
// Full gather output tensor.
auto output = poplar::concat(output_slices);
outputs.push_back(output);
return seq;
Expand Down
19 changes: 12 additions & 7 deletions tessellate_ipu/linalg/tile_linalg_jacobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,17 +143,20 @@ def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtile
# Sorted rotation set: p < q indices.
rotset_sorted = jacobi_sort_rotation_set(rotset)
# On tile constant rotation set tensor building.
rotset_replicated = tile_constant_replicated(rotset_sorted, tiles=Atiles)
rotset_sharded = tile_constant_sharded(rotset_sorted, tiles=Atiles)
with jax.named_scope("rotset"):
rotset_replicated = tile_constant_replicated(rotset_sorted, tiles=Atiles)
rotset_sharded = tile_constant_sharded(rotset_sorted, tiles=Atiles)

# Compute Schur decomposition + on-tile update of columns.
cs_per_tile, Apcols, Aqcols = tile_map( # type:ignore
jacobi_update_first_step_p, rotset_sharded, Apcols, Aqcols, N=N
)
# Replicate Schur decomposition across all A tiles: (2*N//2) comms.
cs_replicated = tile_put_replicated(cs_per_tile.array, tiles=Atiles)
# Just copy Schur decomposition to associated V tiles.
cs_Vtiles = tile_put_sharded(cs_per_tile.array, tiles=Vtiles)
with jax.named_scope("cs_replicated_sharded"):
cs_replicated = tile_put_replicated(cs_per_tile.array, tiles=Atiles)
# Just copy Schur decomposition to associated V tiles.
cs_Vtiles = tile_put_sharded(cs_per_tile.array, tiles=Vtiles)
cs_replicated, cs_Vtiles = tile_data_barrier(cs_replicated, cs_Vtiles)

# Second Jacobi update step.
cs_replicated, Apcols, Aqcols = tile_map( # type:ignore
Expand All @@ -177,8 +180,10 @@ def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtile
Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols)
# Move columns between tiles. 2*N comms per tile.
# NOTE: this inter-tile comm is keeping the p < q property on A and V columns.
Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols, rotset)
Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols, rotset)
with jax.named_scope("Apqcols_rotation"):
Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols, rotset)
with jax.named_scope("Vpqcols_rotation"):
Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols, rotset)
# Next rotation set.
rotset = jacobi_next_rotation_set(rotset)

Expand Down

0 comments on commit d5b7c46

Please sign in to comment.