
Commit

Internal change
PiperOrigin-RevId: 457107616
jpuigcerver authored and copybara-github committed Jun 24, 2022
1 parent 6ed0f10 commit 219e5bd
Showing 2 changed files with 122 additions and 30 deletions.
90 changes: 79 additions & 11 deletions vmoe/checkpoints/partitioned.py
@@ -14,16 +14,18 @@

"""Functions for checkpointing partitioned models."""
import collections
import enum
import functools
import itertools
import os
from typing import Any, Iterator, Iterable, Mapping, Optional, Sequence, Tuple, Union

import jax
import jax.experimental.maps as maps
import jax.experimental.pjit as pjit
from jax.experimental import maps
from jax.experimental import pjit
import numpy as np
import vmoe.checkpoints.base
import vmoe.checkpoints.serialization
import vmoe.checkpoints.types
import vmoe.multihost_utils
import vmoe.utils
@@ -45,10 +47,17 @@
SliceNdArray = vmoe.checkpoints.types.SliceNdArray
ThreadPool = vmoe.checkpoints.base.ThreadPool

from_state_dict = vmoe.checkpoints.serialization.from_state_dict
to_state_dict = vmoe.checkpoints.serialization.to_state_dict
safe_map = vmoe.utils.safe_map
safe_zip = vmoe.utils.safe_zip


class Version(enum.Enum):
UNKNOWN = None
V1 = '20220622'


def restore_checkpoint(*,
prefix: str,
tree: Optional[PyTree],
@@ -77,12 +86,55 @@ def restore_checkpoint(*,
raise ValueError("You must pass a non-empty mesh. If you didn't pass any, "
"check that you called restore_checkpoint from a "
"maps.mesh context.")
index = vmoe.checkpoints.base.restore_checkpoint(prefix + '.index', {
'shard_count': 0,
'index': tree if tree is not None else axis_resources,
})
# Restore index as a state dict.
index = vmoe.checkpoints.base.restore_checkpoint(prefix + '.index')
version = Version(index.get('version', Version.UNKNOWN))
shard_count = index['shard_count']
index = index['index']
if version == Version.UNKNOWN:
if tree is None and axis_resources is None:
raise ValueError(
'You must specify the tree and/or axis_resources arguments when '
f'restoring checkpoints from {Version.UNKNOWN}.')

# The index has the same structure as the input tree and the axis_resources.
# Recover that structure before calling _restore_checkpoint_from_index.
index = from_state_dict(target=tree or axis_resources, state=index)
return _restore_checkpoint_from_index(
prefix=prefix,
shard_count=shard_count,
index=index,
axis_resources=axis_resources,
mesh=mesh,
thread_pool=thread_pool)
if version == Version.V1:
if axis_resources is None:
axis_resources_state_dict = None
else:
axis_resources_state_dict = to_state_dict(axis_resources)
state_dict = _restore_checkpoint_from_index(
prefix=prefix,
shard_count=shard_count,
index=index,
axis_resources=axis_resources_state_dict,
mesh=mesh,
thread_pool=thread_pool)
if (tree or axis_resources) is not None:
return from_state_dict(target=tree or axis_resources, state=state_dict)
else:
return state_dict
raise ValueError(f'Unsupported checkpoint version: {version!r}')
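
A hedged usage sketch of the new restore path: for a V1 checkpoint, tree and axis_resources may both be passed as None and the result comes back as a plain Flax-style state dict. The prefix and single-axis mesh below are placeholders, and the module-level Mesh alias (jax.experimental.maps.Mesh) is assumed:

# Placeholder prefix and mesh; assumes a V1 checkpoint already on disk.
with Mesh(np.asarray(jax.local_devices()), ('d',)):
  state_dict = restore_checkpoint(
      prefix='/tmp/ckpt', tree=None, axis_resources=None)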


def _restore_checkpoint_from_index(
*,
prefix: str,
shard_count: int,
index: PyTree,
axis_resources: Optional[PyTree],
mesh: Mesh,
thread_pool: Optional[ThreadPool] = None) -> PyTree:
"""Restores a PyTree of partitioned arrays from an index."""
# axis_resources indicates how the data to be loaded will be partitioned.
# If no axis_resources is given, assume that we don't want any partitioning.
# This implies that all devices will store a copy of all the parameters, thus
@@ -158,7 +210,8 @@ def save_checkpoint(*,
num_shards: int = 0,
overwrite: bool = True,
makedirs: bool = True,
thread_pool: Optional[ThreadPool] = None) -> AsyncResult:
thread_pool: Optional[ThreadPool] = None,
version: Version = Version.V1) -> AsyncResult:
"""Saves a PyTree of partitioned arrays into a sharded checkpoint.
Args:
@@ -177,19 +230,32 @@
If False, the existence of the base dir is assumed.
thread_pool: ThreadPool used to write the checkpoint files asynchronously.
If None, a new pool will be created.
version: Write checkpoints using this version. DO NOT CHANGE UNLESS YOU KNOW
WHAT YOU ARE DOING.
Returns:
An AsyncResult object.
"""
# We convert the tree (and axis_resources) to a pure nested dictionary here.
# Before Version.V1, this was done implicitly in the subsequent calls to the
# function base.save_checkpoint(). However, the leaves of the PyTree were
# given by jax.tree_leaves(tree), which may yield a different order than
# jax.tree_leaves(to_state_dict(tree)). When that happened, it prevented us
# from restoring the checkpoint without passing any tree/axis_resources. Yet,
# this is a useful feature when we want to restore the checkpoint as a pure
# Python dictionary (i.e. a Flax state dict) because we don't know the
# original structure of the PyTree in the checkpoint. (See the short sketch
# of this ordering mismatch after this hunk.)
if version is not Version.UNKNOWN:
tree = to_state_dict(tree)
axis_resources = to_state_dict(axis_resources)
if mesh is None:
mesh = _get_current_mesh()
if mesh.empty:
raise ValueError("You must pass a non-empty mesh. If you didn't pass any, "
"check that you called save_checkpoint from a maps.mesh "
"context.")
filepath_map = _make_save_checkpoint_filepath_map(prefix, tree,
axis_resources, mesh,
num_shards)
filepath_map = _make_save_checkpoint_filepath_map(
prefix, tree, axis_resources, mesh, num_shards, version)
if makedirs:
# Process 0 creates the workdir if it doesn't exist. All processes wait
# until it's done.
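
A minimal sketch of the leaf-ordering mismatch described in the save_checkpoint comment above, using flax.struct and flax.serialization directly; the State class is made up for illustration:

import flax.struct
import jax
from flax.serialization import to_state_dict

@flax.struct.dataclass
class State:  # Hypothetical container with non-alphabetical field order.
  step: int
  params: int

s = State(step=1, params=2)
assert jax.tree_leaves(s) == [1, 2]                 # flattened in field order
assert jax.tree_leaves(to_state_dict(s)) == [2, 1]  # dict keys flattened in sorted order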
@@ -309,7 +375,7 @@ def _intersect_slice_nd(

def _make_save_checkpoint_filepath_map(
prefix: str, tree: PyTree, axis_resources: PyTree, mesh: Mesh,
num_shards: int = 0):
num_shards: int = 0, version: Version = Version.V1):
"""Makes a dictionary of filepaths mapping to the content that must be serialized."""
filepath_map = {} # Result.
tree_leaves, struct = jax.tree_flatten(tree)
@@ -373,6 +439,8 @@ def _make_save_checkpoint_filepath_map(
'shard_count': shard_count,
'index': struct.unflatten(index_leaves),
}
if version is not Version.UNKNOWN:
filepath_map[prefix + '.index']['version'] = version.value
# Assign the LazyArrayChunks objects to the corresponding shard filepaths.
shard_fpath_fn = functools.partial(
vmoe.checkpoints.base.add_shard_suffix,
62 changes: 43 additions & 19 deletions vmoe/checkpoints/partitioned_test.py
@@ -32,6 +32,7 @@
Slice = partitioned.Slice
SliceNd = partitioned.SliceNd
SliceNdArray = partitioned.SliceNdArray
Version = partitioned.Version


class MakeSliceNdArrayTest(absltest.TestCase):
@@ -265,9 +266,15 @@ def test_intersect_slice_nd(self, ckpt_slice_nd, global_slice_nd,


class RestoreAndSaveCheckpointTest(parameterized.TestCase):
# Note: when restoring a checkpoint from an UNKNOWN version, we now require
# that either the tree or the axis_resources are given. This is because the
# order of the leaves in jax.tree_structure(foo) might differ from that in
# jax.tree_structure(serialization.to_state_dict(foo)), which can cause an
# exception or load the wrong values into the state dict.
# See comment in partitioned.save_checkpoint().

@parameterized.named_parameters(
('process_0_of_2', 0, None,
('process_0_of_2_ver_v1', 0, None,
[[0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
@@ -277,8 +284,8 @@
[2, 2, 2, 2, 2, 3, 3, 3, 3, 3],
[2, 2, 2, 2, 2, 3, 3, 3, 3, 3],
[2, 2, 2, 2, 2, 3, 3, 3, 3, 3],
[2, 2, 2, 2, 2, 3, 3, 3, 3, 3]]),
('process_1_of_2', 1, None,
[2, 2, 2, 2, 2, 3, 3, 3, 3, 3]], Version.V1),
('process_1_of_2_ver_v1', 1, None,
[[0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
@@ -288,32 +295,32 @@
[2, 2, 2, 2, 2, 3, 3, 3, 3, 3],
[2, 2, 2, 2, 2, 3, 3, 3, 3, 3],
[2, 2, 2, 2, 2, 3, 3, 3, 3, 3],
[2, 2, 2, 2, 2, 3, 3, 3, 3, 3]]),
('process_0_of_2_axis_resources',
[2, 2, 2, 2, 2, 3, 3, 3, 3, 3]], Version.V1),
('process_0_of_2_axis_resources_ver_v1',
0, {'x': None, 'y': None, 'z': PartitionSpec('a')},
[[0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]]),
('process_1_of_2_axis_resources',
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]], Version.V1),
('process_1_of_2_axis_resources_ver_unknown',
1, {'x': None, 'y': None, 'z': PartitionSpec('a')},
[[2, 2, 2, 2, 2, 3, 3, 3, 3, 3],
[2, 2, 2, 2, 2, 3, 3, 3, 3, 3],
[2, 2, 2, 2, 2, 3, 3, 3, 3, 3],
[2, 2, 2, 2, 2, 3, 3, 3, 3, 3],
[2, 2, 2, 2, 2, 3, 3, 3, 3, 3]]),
[2, 2, 2, 2, 2, 3, 3, 3, 3, 3]], Version.UNKNOWN),
)
@mock.patch.object(partitioned.jax, 'process_count', return_value=2)
def test_restore_checkpoint(self, process_index, axis_resources, expected_z,
_):
version, _):
devices = np.asarray(
[_make_device(process_index=i // 2, id=i) for i in range(4)])
mesh = partitioned.Mesh(devices.reshape((2, 2)), ('a', 'b'))
prefix = self.create_tempfile().full_path
def side_effect(filepath, *unused_args, **unused_kwargs):
return {
prefix + '.index': self._get_expected_index(),
prefix + '.index': self._get_expected_index(version),
prefix + '.data-00000-of-00004': self._get_expected_shard_content(0),
prefix + '.data-00001-of-00004': self._get_expected_shard_content(1),
prefix + '.data-00002-of-00004': self._get_expected_shard_content(2),
@@ -344,23 +351,36 @@ def side_effect(filepath, *unused_args, **unused_kwargs):
[2, 2, 2, 2, 2]])
np.testing.assert_array_almost_equal(restored['z'], expected_z)

def test_restore_checkpoint_no_tree_nor_axis_resources_unknown_ver(self):
devices = np.asarray([_make_device(process_index=0, id=0)])
mesh = partitioned.Mesh(devices, ('a',))
prefix = self.create_tempfile().full_path
index = self._get_expected_index(version=Version.UNKNOWN)
with mock.patch.object(partitioned.vmoe.checkpoints.base,
'restore_checkpoint', return_value=index):
with self.assertRaisesRegex(
ValueError,
'You must specify the tree and/or axis_resources arguments when'):
partitioned.restore_checkpoint(
prefix=prefix, tree=None, axis_resources=None, mesh=mesh)

def test_restore_checkpoint_empty_mesh(self):
prefix = self.create_tempfile().full_path
with self.assertRaisesRegex(ValueError, 'You must pass a non-empty mesh'):
partitioned.restore_checkpoint(
prefix=prefix, tree=None, axis_resources=None)

@parameterized.named_parameters(
('process_0', 0, 2, 0),
('process_1', 1, 1, 2),
('process_2', 2, 1, 1),
('process_3', 3, 1, 3),
('process_0_ver_unk', 0, 2, 0, Version.UNKNOWN),
('process_1_ver_unk', 1, 1, 2, Version.UNKNOWN),
('process_2_ver_v1', 2, 1, 1, Version.V1),
('process_3_ver_v1', 3, 1, 3, Version.V1),
)
@mock.patch.object(partitioned.jax, 'process_count', return_value=4)
@mock.patch.object(
partitioned.vmoe.multihost_utils, 'sync_devices', return_value=None)
def test_save_checkpoint(self, process_index, num_written_files, shard,
unused_1, unused_2):
version, unused_1, unused_2):
devices = np.asarray(
[_make_device(process_index=i, id=i) for i in range(4)]).reshape((2, 2))
mesh = partitioned.Mesh(devices, ('a', 'b'))
@@ -388,14 +408,15 @@ def test_save_checkpoint(self, process_index, num_written_files, shard,
with mock.patch.object(jax._src.lib.xla_bridge, 'process_index',
return_value=process_index):
async_result = partitioned.save_checkpoint(
prefix=prefix, tree=tree, axis_resources=axis_resources, mesh=mesh)
prefix=prefix, tree=tree, axis_resources=axis_resources, mesh=mesh,
version=version)
written_files = async_result.get()
# Check that the process writes the expected number of files.
self.assertLen(written_files, num_written_files)
# If the process writes the index, load the index and check its content.
if num_written_files == 2:
index_content = base.restore_checkpoint(prefix + '.index')
expected_index_content = self._get_expected_index()
expected_index_content = self._get_expected_index(version)
chex.assert_trees_all_equal_comparator(
lambda x, y: x == y,
lambda x, y: f'IndexInfos do not match:\n{x}\n{y}',
@@ -427,8 +448,8 @@ def _compare_array_chunks(self, a, b):
return False
return all(map(lambda x, y: (x == y).all(), a, b))

def _get_expected_index(self):
return {
def _get_expected_index(self, version: Version):
index = {
'shard_count': 4,
'index': {
'x': partitioned.IndexInfo(
@@ -449,6 +470,9 @@
shards=(0, 2, 1, 3)),
},
}
if version != Version.UNKNOWN:
index['version'] = version.value
return index

def _get_expected_shard_content(self, shard):
"""Returns the ArrayChunks data stored in each shard."""
