diff --git a/src/Sharding.jl b/src/Sharding.jl index 0574afc312..73bdd24f2e 100644 --- a/src/Sharding.jl +++ b/src/Sharding.jl @@ -1118,4 +1118,22 @@ Fine-grained control over the sharding propagation pipeline. enable_insert_explicit_collectives::Bool = false end +""" + disassemble_into_single_device_arrays(x) + +Disassembles a sharded array into a vector of single device shards. Each element in the +returned vector is a pair mapping the array slices to the single device shard. To transfer +the data to host, call `Array` on the single device shards. + +!!! note + + For distributed arrays, each process only returns the shards that are addressable to + that process. + +!!! warning + + Only supported for IFRT runtime. +""" +function disassemble_into_single_device_arrays end + end diff --git a/src/Types.jl b/src/Types.jl index a55bedcbd8..d989785007 100644 --- a/src/Types.jl +++ b/src/Types.jl @@ -434,6 +434,71 @@ function ConcreteIFRTArray{T,N}(x::AnyConcreteIFRTArray; kwargs...) where {T,N} ) end +function Sharding.disassemble_into_single_device_arrays( + x::ConcreteIFRTArray{T,N,<:Sharding.ShardInfo{<:Sharding.NoSharding}} +) where {T,N} + return [ntuple(Returns(Colon()), N) => x] +end + +function Sharding.disassemble_into_single_device_arrays( + x::ConcreteIFRTArray{T,N} +) where {T,N} + single_device_shards = XLA.IFRT.disassemble_into_single_device_arrays(x.data, true) + + padded_size = size(x) .+ get_padding(x) + + if x.sharding.sharding isa Sharding.HloSharding + (; hlo_sharding) = x.sharding.sharding + else + (; hlo_sharding) = Sharding.HloSharding(x.sharding.sharding, padded_size) + end + + all_devices = XLA.get_device.((XLA.client(x),), x.sharding.mesh.device_ids) + array_slices, _ = XLA.sharding_to_concrete_array_indices( + convert(XLA.CondensedOpSharding, hlo_sharding), + padded_size, + x.sharding.mesh.logical_device_ids, + ) + + mapping = [ + slice => ConcreteIFRTArray{T,N}( + XLA.IFRT.AsyncArray(shard, nothing), + map(length, slice), + Sharding.NoShardInfo(), + ) for + (slice, shard, device) in zip(array_slices, single_device_shards, all_devices) if + XLA.is_addressable(device) + ] + + has_padding(x) || return mapping + + mapping_unpadded = Vector{eltype(mapping)}(undef, length(mapping)) + for (i, (slice, shard)) in enumerate(mapping) + chop_ends = map(enumerate(slice)) do (i, idx_range) + last(idx_range) > size(x, i) && return last(idx_range) - size(x, i) + return 0 + end + + if all(iszero, chop_ends) + mapping_unpadded[i] = mapping[i] + else + new_slice = map(zip(slice, chop_ends)) do (idx_range, chop) + chop == 0 && return idx_range + return first(idx_range):(last(idx_range) - chop) + end + if !any(iszero ∘ length, new_slice) + mapping_unpadded[i] = + Tuple(new_slice) => shard[map(Base.OneTo ∘ length, new_slice)...] + end + end + end + + return [ + mapping_unpadded[i] for + i in 1:length(mapping_unpadded) if isassigned(mapping_unpadded, i) + ] +end + ## ConcreteRNG mutable struct ConcreteRNG{S<:AbstractConcreteArray} <: Random.AbstractRNG seed::S diff --git a/src/xla/IFRT/Array.jl b/src/xla/IFRT/Array.jl index 6da86d9c35..b3f989c7a3 100644 --- a/src/xla/IFRT/Array.jl +++ b/src/xla/IFRT/Array.jl @@ -236,7 +236,7 @@ function XLA.to_host(buffer::Array, data, reactant_sharding) client = XLA.client(buffer) all_devices = XLA.get_device.((client,), reactant_sharding.mesh.device_ids) - if any(XLA.is_addressable, all_devices) + if all(XLA.is_addressable, all_devices) # Take a fast path if all devices are addressable array_slices, _ = XLA.sharding_to_concrete_array_indices( convert(XLA.CondensedOpSharding, hlo_sharding),