Add checkpoint unit test (#35)
sfc-gh-mwyatt authored Jan 30, 2025
1 parent 5e67280 commit 1516176
Showing 11 changed files with 248 additions and 42 deletions.
11 changes: 7 additions & 4 deletions arctic_training/checkpoint/ds_engine.py
@@ -39,13 +39,15 @@ def checkpoint_tag(self) -> str:

@property
def client_state(self) -> Dict[str, Any]:
return {
state = {
"end_of_epoch": self.trainer.epoch_idx,
"torch_random_state": torch.get_rng_state(),
"torch_cuda_random_state": torch.cuda.get_rng_state(),
"np_random_state": np.random.get_state(),
"python_random_state": random.getstate(),
}
if self.device != torch.device("cpu"):
state["torch_cuda_random_state"] = torch.cuda.get_rng_state()
return state

def save(self, model) -> None:
model.save_checkpoint(
@@ -59,9 +61,10 @@ def load(self, model) -> None:
return
_, client_states = model.load_checkpoint(self.checkpoint_dir)

self.trainer.global_step = model.global_step
self.trainer.global_step = model.global_steps
self.trainer.epoch_idx = client_states["end_of_epoch"] + 1
torch.set_rng_state(client_states["torch_random_state"])
torch.cuda.set_rng_state(client_states["torch_cuda_random_state"])
np.random.set_state(client_states["np_random_state"])
random.setstate(client_states["python_random_state"])
if self.device != torch.device("cpu"):
torch.cuda.set_rng_state(client_states["torch_cuda_random_state"])
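
The change above makes the CUDA RNG state optional so checkpoint save and load also work on CPU-only runs, where torch.cuda.get_rng_state() would fail. A minimal sketch of the same guarded pattern outside the engine class (function names here are illustrative, not part of the repository):

    import random

    import numpy as np
    import torch


    def capture_rng_state(device: torch.device) -> dict:
        # Host-side RNG states are always collected; the CUDA state is added
        # only when not running on CPU, mirroring the guard in client_state.
        state = {
            "torch_random_state": torch.get_rng_state(),
            "np_random_state": np.random.get_state(),
            "python_random_state": random.getstate(),
        }
        if device != torch.device("cpu"):
            state["torch_cuda_random_state"] = torch.cuda.get_rng_state()
        return state


    def restore_rng_state(state: dict, device: torch.device) -> None:
        torch.set_rng_state(state["torch_random_state"])
        np.random.set_state(state["np_random_state"])
        random.setstate(state["python_random_state"])
        if device != torch.device("cpu"):
            torch.cuda.set_rng_state(state["torch_cuda_random_state"])
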
5 changes: 5 additions & 0 deletions arctic_training/cli.py
@@ -46,6 +46,8 @@ def main():


def run_script():
import deepspeed.comm as dist

from arctic_training.config.trainer import get_config

parser = argparse.ArgumentParser()
@@ -69,3 +71,6 @@ def run_script():
config = get_config(args.config)
trainer = config.trainer
trainer.train()
if dist.is_initialized():
dist.barrier()
dist.destroy_process_group()
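
With the trainer no longer force-exiting mid-step (see trainer.py below), the CLI now tears down the distributed process group explicitly once train() returns. The equivalent cleanup pattern with plain torch.distributed looks like this (a sketch, not the project's code path):

    import torch.distributed as dist


    def run_and_cleanup(train_fn) -> None:
        try:
            train_fn()
        finally:
            # Synchronize all ranks before tearing down the process group so
            # no rank exits while its peers are still communicating.
            if dist.is_initialized():
                dist.barrier()
                dist.destroy_process_group()
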
13 changes: 8 additions & 5 deletions arctic_training/config/trainer.py
@@ -281,11 +281,14 @@ def validate_single_checkpoint_resume(self) -> Self:
return self


def get_config(config_file: Path) -> BaseConfig:
with open(config_file, "r") as f:
config_dict = yaml.safe_load(f)

config_dir = config_file.parent
def get_config(config_file_or_dict: Union[Path, Dict]) -> BaseConfig:
if isinstance(config_file_or_dict, dict):
config_dict = config_file_or_dict
config_dir = Path.cwd()
else:
with open(config_file_or_dict, "r") as f:
config_dict = yaml.safe_load(f)
config_dir = config_file_or_dict.parent

trainer_type = config_dict.get("type", TRAINER_DEFAULT)
config_dict["type"] = trainer_type
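
get_config now accepts an in-memory dict as well as a YAML path, which is what the new checkpoint tests below rely on. A usage sketch (abbreviated and illustrative; a real config also needs the data, optimizer, and scheduler sections shown in the tests):

    from pathlib import Path

    from arctic_training.config.trainer import get_config

    # From a YAML file on disk, as before; relative paths resolve against
    # the config file's parent directory.
    config = get_config(Path("configs/sft.yaml"))

    # Or directly from a dict, skipping the YAML round-trip; relative paths
    # then resolve against the current working directory.
    config = get_config({"type": "sft", "model": {"name_or_path": "my-model"}})
    trainer = config.trainer
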
2 changes: 1 addition & 1 deletion arctic_training/data/factory.py
@@ -52,7 +52,7 @@ class DataFactory(ABC, CallbackMixin):
specify the DataFactory to use.
"""

config_type: Type[DataConfig]
config_type: Type[DataConfig] = DataConfig
"""
The type of the DataConfig object that this DataFactory uses. Any
DataFactory-specific options should be specified in this class.
10 changes: 6 additions & 4 deletions arctic_training/trainer/trainer.py
@@ -159,6 +159,10 @@ def __init__(self, config: "TrainerConfig") -> None:
engine(self) for engine in self.config.checkpoint_engines
]

for engine in self.checkpoint_engines:
if engine.config.auto_resume:
engine.load(self.model)

def _set_seeds(self, seed: int) -> None:
logger.info(f"Setting random seeds to {seed}")
torch.manual_seed(seed)
@@ -241,10 +245,8 @@ def step(self, batch: Dict[str, torch.Tensor]) -> None:
self.config.exit_iteration > 0
and self.config.exit_iteration == self.global_step
):
logger.info(f"Hit exit iteration of {self.global_step}, forcing exit")
torch.distributed.barrier()
torch.distributed.destroy_process_group()
exit()
self.early_stop = True
logger.info(f"Hit exit iteration of {self.global_step}, ending training")

@callback_wrapper("epoch")
def epoch(self) -> None:
Expand Down
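
Two behaviors change here: the constructor now loads a checkpoint for any engine configured with auto_resume (which test_ds_engine exercises), and hitting exit_iteration sets an early-stop flag instead of synchronizing and calling exit() from inside a training step, leaving process-group teardown to the CLI. A simplified stand-in for the resulting control flow (not the trainer's actual loop, which also handles epochs and checkpointing):

    from typing import Any, Callable, Iterable


    def train_loop(
        batches: Iterable[Any],
        train_step: Callable[[Any], None],
        exit_iteration: int,
    ) -> int:
        # The step marks early_stop when exit_iteration is reached and the
        # loop breaks on the next check, so cleanup happens exactly once,
        # after the loop, rather than via a hard exit() mid-step.
        early_stop = False
        steps_run = 0
        for steps_run, batch in enumerate(batches, start=1):
            train_step(batch)
            if exit_iteration > 0 and steps_run == exit_iteration:
                early_stop = True
            if early_stop:
                break
        return steps_run
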
52 changes: 52 additions & 0 deletions tests/checkpoint/test_ds_engine.py
@@ -0,0 +1,52 @@
import pytest
from utils import models_are_equal

from arctic_training.config.trainer import get_config


@pytest.mark.cpu
def test_ds_engine(tmp_path):
config_dict = {
"type": "sft",
"exit_iteration": 2,
"model": {
"type": "random-weight-hf",
"name_or_path": "HuggingFaceTB/SmolLM-135M-Instruct",
"attn_implementation": "eager",
"dtype": "float32",
},
"data": {
"type": "noop",
"sources": [],
},
"optimizer": {
"type": "cpu-adam",
},
"scheduler": {
"type": "noop",
},
"checkpoint": {
"type": "deepspeed",
"auto_resume": True,
"output_dir": str(tmp_path / "checkpoints"),
"save_end_of_training": True,
},
}

config = get_config(config_dict)
trainer = config.trainer

# Force checkpoint to be saved despite no training happening
trainer.training_finished = True
trainer.checkpoint()

# Store original model for comparison later
original_model = trainer.model

config_dict["seed"] = 0 # Make sure newly initialized model is different
config = get_config(config_dict)
trainer = config.trainer

loaded_model = trainer.model
assert models_are_equal(original_model, loaded_model), "Models are not equal"
# TODO: Add assertion on optimizer state
51 changes: 51 additions & 0 deletions tests/checkpoint/test_hf_engine.py
@@ -0,0 +1,51 @@
import pytest
from utils import models_are_equal

from arctic_training.config.trainer import get_config


@pytest.mark.cpu
def test_hf_engine(tmp_path):
config_dict = {
"type": "sft",
"model": {
"type": "random-weight-hf",
"name_or_path": "HuggingFaceTB/SmolLM-135M-Instruct",
"attn_implementation": "eager",
"dtype": "float32",
},
"data": {
"type": "noop",
"sources": [],
},
"optimizer": {
"type": "cpu-adam",
},
"scheduler": {
"type": "noop",
},
"checkpoint": {
"type": "huggingface",
"output_dir": str(tmp_path / "checkpoints"),
"save_end_of_training": True,
},
}

config = get_config(config_dict)
trainer = config.trainer

# Force checkpoint to be saved despite no training happening
trainer.training_finished = True
trainer.checkpoint()

# Store original model for comparison later
original_model = trainer.model

config_dict["model"]["name_or_path"] = str(
trainer.checkpoint_engines[0].checkpoint_dir
)
config = get_config(config_dict)
trainer = config.trainer

loaded_model = trainer.model
assert models_are_equal(original_model, loaded_model), "Models are not equal"
9 changes: 9 additions & 0 deletions tests/checkpoint/utils.py
@@ -0,0 +1,9 @@
import torch


def models_are_equal(model_a: torch.nn.Module, model_b: torch.nn.Module) -> bool:
for param_a, param_b in zip(model_a.parameters(), model_b.parameters()):
if not param_a.data.eq(param_b.data).all():
return False

return True
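
models_are_equal zips the two parameter lists positionally, which is enough for the round-trip checks here since both models share an architecture. A stricter variant (illustrative, not part of the commit) would compare state dicts by key so architecture mismatches fail loudly:

    import torch


    def models_are_equal_strict(
        model_a: torch.nn.Module, model_b: torch.nn.Module
    ) -> bool:
        # Comparing by state-dict key catches missing or renamed parameters
        # that a positional zip over parameters() would silently skip.
        sd_a, sd_b = model_a.state_dict(), model_b.state_dict()
        if sd_a.keys() != sd_b.keys():
            return False
        return all(torch.equal(sd_a[k], sd_b[k]) for k in sd_a)
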
65 changes: 65 additions & 0 deletions tests/conftest.py
@@ -12,3 +12,68 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest


def pytest_configure(config):
# TODO: Make it so that cpu and gpu tests can be run with a single command.
# This requires some work with tearing down/setting up dist environments
# that have not been worked out yet.
if not config.option.markexpr:
config.option.markexpr = "cpu"
if "gpu" in config.option.markexpr and "cpu" in config.option.markexpr:
pytest.fail("Cannot run tests with both 'gpu' and 'cpu' marks")
if "cpu" in config.option.markexpr:
_setup_cpu_dist()
if "gpu" in config.option.markexpr:
_setup_gpu_dist()


def _setup_cpu_dist():
os.environ["DS_ACCELERATOR"] = "cpu"
os.environ["LOCAL_RANK"] = "0"
os.environ["RANK"] = "0"
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_SIZE"] = "1"

from deepspeed.comm import init_distributed

init_distributed(auto_mpi_discovery=False)


def _setup_gpu_dist():
os.environ["DS_ACCELERATOR"] = "cuda"
os.environ["LOCAL_RANK"] = "0"
os.environ["RANK"] = "0"
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_SIZE"] = "1"

from deepspeed.comm import init_distributed

init_distributed(auto_mpi_discovery=False)


# Eventually when we can run cpu + gpu tests, we will want to order the tests to
# avoid thrashing back and forth between the distributed environments.
"""
@pytest.hookimpl(trylast=True)
def pytest_collection_modifyitems(config, items):
cpu_tests = [item for item in items if "cpu" in item.keywords]
gpu_tests = [item for item in items if "gpu" in item.keywords]
# Reorder tests: all 'cpu' tests first, then 'gpu' tests
items[:] = cpu_tests + gpu_tests
"""


# Load helper functions automatically for all tests
@pytest.fixture(scope="session", autouse=True)
def helpers_code_path() -> None:
from . import test_helpers # noqa: F401
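
With this hook, running pytest with no -m option selects the cpu-marked tests and initializes a single-rank CPU DeepSpeed environment; pytest -m gpu does the same against CUDA. Combining both markers in one invocation is rejected until the TODO about tearing down and rebuilding the distributed environment is resolved.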
35 changes: 34 additions & 1 deletion tests/trainer/trainer_test_helpers.py → tests/test_helpers.py
@@ -5,9 +5,12 @@
from transformers import PreTrainedModel

from arctic_training import register
from arctic_training.data.factory import DataFactory
from arctic_training.data.source import DataSource
from arctic_training.model.hf_factory import HFModelFactory
from arctic_training.optimizer import FusedAdamOptimizerFactory
from arctic_training.optimizer.adam_factory import FusedAdamOptimizerFactory
from arctic_training.optimizer.factory import OptimizerFactory
from arctic_training.scheduler.factory import SchedulerFactory


@register
@@ -52,3 +55,33 @@ def create_optimizer(self, model, optimizer_config):
lr=optimizer_config.learning_rate,
betas=optimizer_config.betas,
)


@register
class NoOpOptimizerFactory(OptimizerFactory):
name = "noop"

def create_optimizer(self, model, optimizer_config):
return None


@register
class NoOpDataFactory(DataFactory):
name = "noop"

def __call__(self):
return None, None

def tokenize_fn(self):
pass

def collate_fn(self):
pass


@register
class NoOpSchedulerFactory(SchedulerFactory):
name = "noop"

def create_scheduler(self, optimizer):
return None
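
The "noop" data and scheduler entries in the checkpoint tests above resolve to these registered factories, so a trainer can be constructed and a checkpoint round-trip exercised without loading a dataset or building a real scheduler.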
37 changes: 10 additions & 27 deletions tests/trainer/test_sft_trainer.py
@@ -1,16 +1,13 @@
import os
import subprocess
from pathlib import Path

import pytest
import yaml

from arctic_training.config.trainer import get_config


@pytest.mark.gpu
def test_sft_trainer(tmp_path):
config_dict = {
"type": "sft",
"code": str(Path(__file__).parent / "trainer_test_helpers.py"),
"exit_iteration": 2,
"micro_batch_size": 1,
"model": {
@@ -27,20 +24,16 @@ def test_sft_trainer(tmp_path):
with open(config_path, "w") as f:
f.write(yaml.dump(config_dict))

result = subprocess.run(
f"arctic_training {config_path}", shell=True, text=True, capture_output=True
)

if result.returncode != 0:
print(result.stderr)
pytest.fail("Training failed")
config = get_config(config_path)
trainer = config.trainer
trainer.train()
assert trainer.global_step > 0, "Training did not run"


@pytest.mark.cpu
def test_sft_trainer_cpu(tmp_path):
config_dict = {
"type": "sft",
"code": str(Path(__file__).parent / "trainer_test_helpers.py"),
"exit_iteration": 2,
"micro_batch_size": 1,
"model": {
@@ -66,17 +59,7 @@ def test_sft_trainer_cpu(tmp_path):
with open(config_path, "w") as f:
f.write(yaml.dump(config_dict))

env = os.environ.copy()
env["DS_ACCELERATOR"] = "cpu"

result = subprocess.run(
f"arctic_training {config_path}",
shell=True,
text=True,
capture_output=True,
env=env,
)

if result.returncode != 0:
print(result.stderr)
pytest.fail("Training failed")
config = get_config(config_path)
trainer = config.trainer
trainer.train()
assert trainer.global_step > 0, "Training did not run"
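
Running the trainer in-process via get_config, instead of shelling out to the arctic_training CLI with subprocess, keeps failures inside pytest as ordinary assertions and tracebacks, and leaves the distributed environment to be set up once per session by conftest.py rather than per test.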
