feat: disassemble IFRT array into shards for serialization #1136

Open · wants to merge 2 commits into main
18 changes: 18 additions & 0 deletions src/Sharding.jl
@@ -1118,4 +1118,22 @@ Fine-grained control over the sharding propagation pipeline.
enable_insert_explicit_collectives::Bool = false
end

"""
disassemble_into_single_device_arrays(x)

Disassembles a sharded array into a vector of single device shards. Each element in the
returned vector is a pair mapping the array slices to the single device shard. To transfer
the data to host, call `Array` on the single device shards.

!!! note

For distributed arrays, each process only returns the shards that are addressable to
that process.

!!! warning

Only supported for IFRT runtime.
"""
function disassemble_into_single_device_arrays end

end
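
A minimal usage sketch for the new API described in the docstring above; the variable `x` is an assumption for illustration (an already-sharded `ConcreteIFRTArray` on the IFRT runtime), not part of this PR:

```julia
# Sketch only: `x` is assumed to be a sharded ConcreteIFRTArray.
shards = Sharding.disassemble_into_single_device_arrays(x)

for (slice, shard) in shards
    # `slice` is a tuple of index ranges into the global array;
    # `shard` is an unsharded single-device array.
    host = Array(shard)  # copy this shard to host memory
    @show slice size(host)
end
```
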
65 changes: 65 additions & 0 deletions src/Types.jl
@@ -434,6 +434,71 @@ function ConcreteIFRTArray{T,N}(x::AnyConcreteIFRTArray; kwargs...) where {T,N}
)
end

function Sharding.disassemble_into_single_device_arrays(
x::ConcreteIFRTArray{T,N,<:Sharding.ShardInfo{<:Sharding.NoSharding}}
) where {T,N}
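    # No sharding: the whole array is a single shard, so index it with Colons.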
return [ntuple(Returns(Colon()), N) => x]
end

function Sharding.disassemble_into_single_device_arrays(
x::ConcreteIFRTArray{T,N}
) where {T,N}
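    # Ask IFRT to split the array into its per-device buffers.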
single_device_shards = XLA.IFRT.disassemble_into_single_device_arrays(x.data, true)

padded_size = size(x) .+ get_padding(x)

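    # Extract the HLO sharding, constructing it from the padded shape when the
    # stored sharding is not already an HloSharding.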
if x.sharding.sharding isa Sharding.HloSharding
(; hlo_sharding) = x.sharding.sharding
else
(; hlo_sharding) = Sharding.HloSharding(x.sharding.sharding, padded_size)
end

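    # For every device in the mesh, compute the slice of the (padded) global array it holds.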
all_devices = XLA.get_device.((XLA.client(x),), x.sharding.mesh.device_ids)
array_slices, _ = XLA.sharding_to_concrete_array_indices(
convert(XLA.CondensedOpSharding, hlo_sharding),
padded_size,
x.sharding.mesh.logical_device_ids,
)

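    # Wrap each buffer as an unsharded ConcreteIFRTArray, keeping only the shards whose
    # device is addressable from this process.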
mapping = [
slice => ConcreteIFRTArray{T,N}(
XLA.IFRT.AsyncArray(shard, nothing),
map(length, slice),
Sharding.NoShardInfo(),
) for
(slice, shard, device) in zip(array_slices, single_device_shards, all_devices) if
XLA.is_addressable(device)
]

has_padding(x) || return mapping

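    # The array is padded: shrink each slice so it stops at the logical array bounds and
    # drop any shard that lies entirely within the padding.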
mapping_unpadded = Vector{eltype(mapping)}(undef, length(mapping))
for (i, (slice, shard)) in enumerate(mapping)
        chop_ends = map(enumerate(slice)) do (d, idx_range)
            last(idx_range) > size(x, d) && return last(idx_range) - size(x, d)
return 0
end

if all(iszero, chop_ends)
mapping_unpadded[i] = mapping[i]
else
new_slice = map(zip(slice, chop_ends)) do (idx_range, chop)
chop == 0 && return idx_range
return first(idx_range):(last(idx_range) - chop)
end
if !any(iszero ∘ length, new_slice)
mapping_unpadded[i] =
Tuple(new_slice) => shard[map(Base.OneTo ∘ length, new_slice)...]
end
end
end

return [
mapping_unpadded[i] for
i in 1:length(mapping_unpadded) if isassigned(mapping_unpadded, i)
]

Collaborator:

In the typical case where we have one array per process, we'll use something like:

adev = Sharding.disassemble_into_single_device_arrays(a)
acpu = Array(adev[1])

Is that right?

Collaborator (Author):

I don't think you will have one array per process. With our current run configurations, if you have 4 GPUs per node, you will have 4 arrays per process.

To get the host arrays you can do `[s => Array(x) for (s, x) in ret_val]`. `s` tells you which slice of the global array the shard corresponds to.
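
A hedged sketch of that pattern, copying the addressable shards into a host buffer; the names `a` and `global_size` and the `Float32` element type are illustrative assumptions:

```julia
shards = Sharding.disassemble_into_single_device_arrays(a)

# Allocate a host buffer for the full global array and fill in the slices
# returned to this process; on a multi-process run only the addressable
# shards are returned, so the rest of the buffer is left untouched.
host = Array{Float32}(undef, global_size...)
for (slice, shard) in shards
    host[slice...] = Array(shard)
end
```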

Collaborator:

Ah right, ok, thanks. Per device, not per process.

end

## ConcreteRNG
mutable struct ConcreteRNG{S<:AbstractConcreteArray} <: Random.AbstractRNG
seed::S
2 changes: 1 addition & 1 deletion src/xla/IFRT/Array.jl
@@ -236,7 +236,7 @@ function XLA.to_host(buffer::Array, data, reactant_sharding)
client = XLA.client(buffer)
all_devices = XLA.get_device.((client,), reactant_sharding.mesh.device_ids)

-    if any(XLA.is_addressable, all_devices)
+    if all(XLA.is_addressable, all_devices)
# Take a fast path if all devices are addressable
array_slices, _ = XLA.sharding_to_concrete_array_indices(
convert(XLA.CondensedOpSharding, hlo_sharding),