diff --git a/jax/experimental/global_device_array.py b/jax/experimental/global_device_array.py
index fbcaabd98fff..e854a36a50dc 100644
--- a/jax/experimental/global_device_array.py
+++ b/jax/experimental/global_device_array.py
@@ -337,16 +337,19 @@ def _create_local_shards(self) -> Sequence[Shard]:
     for db in self._device_buffers:
       device = db.device()
       index, rid = global_indices_rid[device]
-      if db.aval is None:
-        db.aval = core.ShapedArray(db.shape, db.dtype)
       out.append(Shard(device, index, rid, db))
     return out
 
-  @pxla.maybe_cached_property
+  @property
   def local_shards(self) -> Sequence[Shard]:
+    for s in self._local_shards:
+      # Ignore the type because mypy thinks data is None but local_shards
+      # cannot have data=None which is checked in `_create_local_shards`.
+      if s.data.aval is None:  # type: ignore
+        s.data.aval = core.ShapedArray(s.data.shape, s.data.dtype)  # type: ignore
     return self._local_shards
 
-  @pxla.maybe_cached_property
+  @property
   def global_shards(self) -> Sequence[Shard]:
     # Populating global_shards lazily (i.e. when requested) because populating
     # them eagerly leads to a performance regression when training on large
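
For context, a minimal sketch of the pattern the hunk adopts: rather than filling in `aval` eagerly when the shards are created, a plain `@property` fills it in lazily when `local_shards` is read. The `_FakeBuffer` and `_FakeGDA` names below are illustrative stand-ins, not the real JAX classes.

# Hedged sketch only; `_FakeBuffer`/`_FakeGDA` are made-up stand-ins for a
# device buffer and GlobalDeviceArray, and the tuple assigned to `aval`
# stands in for core.ShapedArray(shape, dtype).
from dataclasses import dataclass
from typing import Any, List, Optional

@dataclass
class _FakeBuffer:
  shape: tuple
  dtype: str
  aval: Optional[Any] = None  # populated lazily, mirroring `db.aval`

class _FakeGDA:
  def __init__(self, buffers: List[_FakeBuffer]):
    # Shards are created eagerly, but `aval` is left untouched here.
    self._local_shards = list(buffers)

  @property  # a plain property, so the lazy fill-in runs on each access
  def local_shards(self) -> List[_FakeBuffer]:
    for b in self._local_shards:
      if b.aval is None:  # fill in the abstract value only when first needed
        b.aval = (b.shape, b.dtype)
    return self._local_shards

gda = _FakeGDA([_FakeBuffer((8, 2), "float32")])
assert gda.local_shards[0].aval == ((8, 2), "float32")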