Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve (slightly!) tile_gather Popvision profile. #41

Merged
merged 1 commit into from
Sep 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions tessellate_ipu/lib/tile_array_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ poplar::program::Program lowerTilePutShardedToPoplar(
auto output = createShardedVariable(
graph, input.elementType(), input[0].shape(), tile_array, debug_context);
// Copy data tensor into the output.
auto prog = poplar::program::Copy(input, output);
auto prog = poplar::program::Copy(input, output, false, debug_context);
outputs.push_back(output);
return prog;
}
Expand All @@ -112,7 +112,8 @@ poplar::program::Program lowerTilePutReplicatedToPoplar(
auto output = createShardedVariable(graph, input.elementType(), input.shape(),
tile_array, debug_context);
// Copy data tensor into the output.
auto prog = poplar::program::Copy(input_broadcasted, output, false);
auto prog =
poplar::program::Copy(input_broadcasted, output, false, debug_context);
outputs.push_back(output);
return prog;
}
Expand All @@ -128,10 +129,15 @@ poplar::program::Program lowerTileGatherToPoplar(
const auto& input = inputs[0];
const auto item_shape = input[0].shape();
const auto item_type = input.elementType();
const size_t num_tiles = params.tiles.size();

// Create the output tensor per gather index, then concat.
auto seq = poplar::program::Sequence();
// All output slices
std::vector<poplar::Tensor> output_slices;
// Slices requiring copying.
std::vector<poplar::Tensor> input_copy_slices;
std::vector<poplar::Tensor> output_copy_slices;
for (std::size_t idx = 0; idx < params.tiles.size(); ++idx) {
const auto gather_idx = params.indices[idx];
// Get the proper item at the gather index.
Expand All @@ -146,10 +152,16 @@ poplar::program::Program lowerTileGatherToPoplar(
auto output_item =
graph.addVariable(item_type, item_shape, debug_context);
graph.setTileMapping(output_item, output_tile);
seq.add(poplar::program::Copy(input_item, output_item));
input_copy_slices.push_back(input_item.expand({0}));
output_copy_slices.push_back(output_item.expand({0}));
output_slices.push_back(output_item.expand({0}));
}
}
// Copy input to output.
auto input_copy = poplar::concat(input_copy_slices);
auto output_copy = poplar::concat(output_copy_slices);
seq.add(poplar::program::Copy(input_copy, output_copy, false, debug_context));
// Full gather output tensor.
auto output = poplar::concat(output_slices);
outputs.push_back(output);
return seq;
Expand Down
19 changes: 12 additions & 7 deletions tessellate_ipu/linalg/tile_linalg_jacobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,17 +143,20 @@ def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtile
# Sorted rotation set: p < q indices.
rotset_sorted = jacobi_sort_rotation_set(rotset)
# On tile constant rotation set tensor building.
rotset_replicated = tile_constant_replicated(rotset_sorted, tiles=Atiles)
rotset_sharded = tile_constant_sharded(rotset_sorted, tiles=Atiles)
with jax.named_scope("rotset"):
rotset_replicated = tile_constant_replicated(rotset_sorted, tiles=Atiles)
rotset_sharded = tile_constant_sharded(rotset_sorted, tiles=Atiles)

# Compute Schur decomposition + on-tile update of columns.
cs_per_tile, Apcols, Aqcols = tile_map( # type:ignore
jacobi_update_first_step_p, rotset_sharded, Apcols, Aqcols, N=N
)
# Replicate Schur decomposition across all A tiles: (2*N//2) comms.
cs_replicated = tile_put_replicated(cs_per_tile.array, tiles=Atiles)
# Just copy Schur decomposition to associated V tiles.
cs_Vtiles = tile_put_sharded(cs_per_tile.array, tiles=Vtiles)
with jax.named_scope("cs_replicated_sharded"):
cs_replicated = tile_put_replicated(cs_per_tile.array, tiles=Atiles)
# Just copy Schur decomposition to associated V tiles.
cs_Vtiles = tile_put_sharded(cs_per_tile.array, tiles=Vtiles)
cs_replicated, cs_Vtiles = tile_data_barrier(cs_replicated, cs_Vtiles)

# Second Jacobi update step.
cs_replicated, Apcols, Aqcols = tile_map( # type:ignore
Expand All @@ -177,8 +180,10 @@ def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtile
Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols)
# Move columns between tiles. 2*N comms per tile.
# NOTE: this inter-tile comm is keeping the p < q property on A and V columns.
Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols, rotset)
Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols, rotset)
with jax.named_scope("Apqcols_rotation"):
Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols, rotset)
with jax.named_scope("Vpqcols_rotation"):
Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols, rotset)
# Next rotation set.
rotset = jacobi_next_rotation_set(rotset)

Expand Down