diff --git a/arctic_training/checkpoint/ds_engine.py b/arctic_training/checkpoint/ds_engine.py
index 6d2f4ba..edd906e 100644
--- a/arctic_training/checkpoint/ds_engine.py
+++ b/arctic_training/checkpoint/ds_engine.py
@@ -39,13 +39,15 @@ def checkpoint_tag(self) -> str:
 
     @property
     def client_state(self) -> Dict[str, Any]:
-        return {
+        state = {
             "end_of_epoch": self.trainer.epoch_idx,
             "torch_random_state": torch.get_rng_state(),
-            "torch_cuda_random_state": torch.cuda.get_rng_state(),
             "np_random_state": np.random.get_state(),
             "python_random_state": random.getstate(),
         }
+        if self.device != torch.device("cpu"):
+            state["torch_cuda_random_state"] = torch.cuda.get_rng_state()
+        return state
 
     def save(self, model) -> None:
         model.save_checkpoint(
@@ -59,9 +61,10 @@ def load(self, model) -> None:
             return
         _, client_states = model.load_checkpoint(self.checkpoint_dir)
-        self.trainer.global_step = model.global_step
+        self.trainer.global_step = model.global_steps
         self.trainer.epoch_idx = client_states["end_of_epoch"] + 1
         torch.set_rng_state(client_states["torch_random_state"])
-        torch.cuda.set_rng_state(client_states["torch_cuda_random_state"])
         np.random.set_state(client_states["np_random_state"])
         random.setstate(client_states["python_random_state"])
+        if self.device != torch.device("cpu"):
+            torch.cuda.set_rng_state(client_states["torch_cuda_random_state"])
 
diff --git a/arctic_training/cli.py b/arctic_training/cli.py
index dad3a04..abd1115 100644
--- a/arctic_training/cli.py
+++ b/arctic_training/cli.py
@@ -46,6 +46,8 @@ def main():
 
 
 def run_script():
+    import deepspeed.comm as dist
+
     from arctic_training.config.trainer import get_config
 
     parser = argparse.ArgumentParser()
@@ -69,3 +71,6 @@ def run_script():
     config = get_config(args.config)
     trainer = config.trainer
     trainer.train()
+    if dist.is_initialized():
+        dist.barrier()
+        dist.destroy_process_group()
diff --git a/arctic_training/config/trainer.py b/arctic_training/config/trainer.py
index b16a522..63e897a 100644
--- a/arctic_training/config/trainer.py
+++ b/arctic_training/config/trainer.py
@@ -281,11 +281,14 @@ def validate_single_checkpoint_resume(self) -> Self:
         return self
 
 
-def get_config(config_file: Path) -> BaseConfig:
-    with open(config_file, "r") as f:
-        config_dict = yaml.safe_load(f)
-
-    config_dir = config_file.parent
+def get_config(config_file_or_dict: Union[Path, Dict]) -> BaseConfig:
+    if isinstance(config_file_or_dict, dict):
+        config_dict = config_file_or_dict
+        config_dir = Path.cwd()
+    else:
+        with open(config_file_or_dict, "r") as f:
+            config_dict = yaml.safe_load(f)
+        config_dir = config_file_or_dict.parent
 
     trainer_type = config_dict.get("type", TRAINER_DEFAULT)
     config_dict["type"] = trainer_type
diff --git a/arctic_training/data/factory.py b/arctic_training/data/factory.py
index 241da41..3662dee 100644
--- a/arctic_training/data/factory.py
+++ b/arctic_training/data/factory.py
@@ -52,7 +52,7 @@ class DataFactory(ABC, CallbackMixin):
     specify the DataFactory to use.
     """
 
-    config_type: Type[DataConfig]
+    config_type: Type[DataConfig] = DataConfig
     """
     The type of the DataConfig object that this DataFactory uses. Any
    DataFactory-specific options should be specified in this class.
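
Note on the ds_engine.py hunks above: `torch.cuda.get_rng_state()` raises on hosts without CUDA, so both the save and load paths now guard on the engine's device. A minimal standalone sketch of the same save/restore pattern (the function names here are illustrative, not part of the patch):

```python
import random
from typing import Any, Dict

import numpy as np
import torch


def snapshot_rng(device: torch.device) -> Dict[str, Any]:
    # Always capture CPU-side RNG state; only touch CUDA state when the
    # engine is actually running on a CUDA device.
    state = {
        "torch_random_state": torch.get_rng_state(),
        "np_random_state": np.random.get_state(),
        "python_random_state": random.getstate(),
    }
    if device != torch.device("cpu"):
        state["torch_cuda_random_state"] = torch.cuda.get_rng_state()
    return state


def restore_rng(state: Dict[str, Any], device: torch.device) -> None:
    torch.set_rng_state(state["torch_random_state"])
    np.random.set_state(state["np_random_state"])
    random.setstate(state["python_random_state"])
    if device != torch.device("cpu"):
        torch.cuda.set_rng_state(state["torch_cuda_random_state"])
```
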
diff --git a/arctic_training/trainer/trainer.py b/arctic_training/trainer/trainer.py
index cd6e4a2..e1bfc4d 100644
--- a/arctic_training/trainer/trainer.py
+++ b/arctic_training/trainer/trainer.py
@@ -159,6 +159,10 @@ def __init__(self, config: "TrainerConfig") -> None:
             engine(self) for engine in self.config.checkpoint_engines
         ]
 
+        for engine in self.checkpoint_engines:
+            if engine.config.auto_resume:
+                engine.load(self.model)
+
     def _set_seeds(self, seed: int) -> None:
         logger.info(f"Setting random seeds to {seed}")
         torch.manual_seed(seed)
@@ -241,10 +245,8 @@ def step(self, batch: Dict[str, torch.Tensor]) -> None:
             self.config.exit_iteration > 0
             and self.config.exit_iteration == self.global_step
         ):
-            logger.info(f"Hit exit iteration of {self.global_step}, forcing exit")
-            torch.distributed.barrier()
-            torch.distributed.destroy_process_group()
-            exit()
+            self.early_stop = True
+            logger.info(f"Hit exit iteration of {self.global_step}, ending training")
 
     @callback_wrapper("epoch")
     def epoch(self) -> None:
diff --git a/tests/checkpoint/test_ds_engine.py b/tests/checkpoint/test_ds_engine.py
new file mode 100644
index 0000000..2d5c03b
--- /dev/null
+++ b/tests/checkpoint/test_ds_engine.py
@@ -0,0 +1,52 @@
+import pytest
+from utils import models_are_equal
+
+from arctic_training.config.trainer import get_config
+
+
+@pytest.mark.cpu
+def test_ds_engine(tmp_path):
+    config_dict = {
+        "type": "sft",
+        "exit_iteration": 2,
+        "model": {
+            "type": "random-weight-hf",
+            "name_or_path": "HuggingFaceTB/SmolLM-135M-Instruct",
+            "attn_implementation": "eager",
+            "dtype": "float32",
+        },
+        "data": {
+            "type": "noop",
+            "sources": [],
+        },
+        "optimizer": {
+            "type": "cpu-adam",
+        },
+        "scheduler": {
+            "type": "noop",
+        },
+        "checkpoint": {
+            "type": "deepspeed",
+            "auto_resume": True,
+            "output_dir": str(tmp_path / "checkpoints"),
+            "save_end_of_training": True,
+        },
+    }
+
+    config = get_config(config_dict)
+    trainer = config.trainer
+
+    # Force checkpoint to be saved despite no training happening
+    trainer.training_finished = True
+    trainer.checkpoint()
+
+    # Store original model for comparison later
+    original_model = trainer.model
+
+    config_dict["seed"] = 0  # Make sure newly initialized model is different
+    config = get_config(config_dict)
+    trainer = config.trainer
+
+    loaded_model = trainer.model
+    assert models_are_equal(original_model, loaded_model), "Models are not equal"
+    # TODO: Add assertion on optimizer state
diff --git a/tests/checkpoint/test_hf_engine.py b/tests/checkpoint/test_hf_engine.py
new file mode 100644
index 0000000..af50cb6
--- /dev/null
+++ b/tests/checkpoint/test_hf_engine.py
@@ -0,0 +1,51 @@
+import pytest
+from utils import models_are_equal
+
+from arctic_training.config.trainer import get_config
+
+
+@pytest.mark.cpu
+def test_hf_engine(tmp_path):
+    config_dict = {
+        "type": "sft",
+        "model": {
+            "type": "random-weight-hf",
+            "name_or_path": "HuggingFaceTB/SmolLM-135M-Instruct",
+            "attn_implementation": "eager",
+            "dtype": "float32",
+        },
+        "data": {
+            "type": "noop",
+            "sources": [],
+        },
+        "optimizer": {
+            "type": "cpu-adam",
+        },
+        "scheduler": {
+            "type": "noop",
+        },
+        "checkpoint": {
+            "type": "huggingface",
+            "output_dir": str(tmp_path / "checkpoints"),
+            "save_end_of_training": True,
+        },
+    }
+
+    config = get_config(config_dict)
+    trainer = config.trainer
+
+    # Force checkpoint to be saved despite no training happening
+    trainer.training_finished = True
+    trainer.checkpoint()
+
+    # Store original model for comparison later
+    original_model = trainer.model
+
config_dict["model"]["name_or_path"] = str( + trainer.checkpoint_engines[0].checkpoint_dir + ) + config = get_config(config_dict) + trainer = config.trainer + + loaded_model = trainer.model + assert models_are_equal(original_model, loaded_model), "Models are not equal" diff --git a/tests/checkpoint/utils.py b/tests/checkpoint/utils.py new file mode 100644 index 0000000..8dc1a3e --- /dev/null +++ b/tests/checkpoint/utils.py @@ -0,0 +1,9 @@ +import torch + + +def models_are_equal(model_a: torch.nn.Module, model_b: torch.nn.Module) -> bool: + for param_a, param_b in zip(model_a.parameters(), model_b.parameters()): + if not param_a.data.eq(param_b.data).all(): + return False + + return True diff --git a/tests/conftest.py b/tests/conftest.py index 3e86bce..d8453a8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,3 +12,68 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import os + +import pytest + + +def pytest_configure(config): + # TODO: Make it so that cpu and gpu tests can be run with a single command. + # This requires some work with tearing down/setting up dist environments + # that have not been worked out yet. + if not config.option.markexpr: + config.option.markexpr = "cpu" + if "gpu" in config.option.markexpr and "cpu" in config.option.markexpr: + pytest.fail("Cannot run tests with both 'gpu' and 'cpu' marks") + if "cpu" in config.option.markexpr: + _setup_cpu_dist() + if "gpu" in config.option.markexpr: + _setup_gpu_dist() + + +def _setup_cpu_dist(): + os.environ["DS_ACCELERATOR"] = "cpu" + os.environ["LOCAL_RANK"] = "0" + os.environ["RANK"] = "0" + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "29500" + os.environ["WORLD_SIZE"] = "1" + os.environ["LOCAL_SIZE"] = "1" + + from deepspeed.comm import init_distributed + + init_distributed(auto_mpi_discovery=False) + + +def _setup_gpu_dist(): + os.environ["DS_ACCELERATOR"] = "cuda" + os.environ["LOCAL_RANK"] = "0" + os.environ["RANK"] = "0" + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "29500" + os.environ["WORLD_SIZE"] = "1" + os.environ["LOCAL_SIZE"] = "1" + + from deepspeed.comm import init_distributed + + init_distributed(auto_mpi_discovery=False) + + +# Eventually when we can run cpu + gpu tests, we will want to order the tests to +# avoid thrashing back and forth between the distirbuted environments. +""" +@pytest.hookimpl(trylast=True) +def pytest_collection_modifyitems(config, items): + t scpu_tests = [item for item in items if "cpu" in item.keywords] + gpu_tests = [item for item in items if "gpu" in item.keywords] + + # Reorder tests: all 'cpu' tests first, then 'gpu' tests + items[:] = cpu_tests + gpu_tests +""" + + +# Load helper functions automatically for all tests +@pytest.fixture(scope="session", autouse=True) +def helpers_code_path() -> None: + from . 
diff --git a/tests/trainer/trainer_test_helpers.py b/tests/test_helpers.py
similarity index 69%
rename from tests/trainer/trainer_test_helpers.py
rename to tests/test_helpers.py
index eb9677f..f76f34d 100644
--- a/tests/trainer/trainer_test_helpers.py
+++ b/tests/test_helpers.py
@@ -5,9 +5,12 @@
 from transformers import PreTrainedModel
 
 from arctic_training import register
+from arctic_training.data.factory import DataFactory
 from arctic_training.data.source import DataSource
 from arctic_training.model.hf_factory import HFModelFactory
-from arctic_training.optimizer import FusedAdamOptimizerFactory
+from arctic_training.optimizer.adam_factory import FusedAdamOptimizerFactory
+from arctic_training.optimizer.factory import OptimizerFactory
+from arctic_training.scheduler.factory import SchedulerFactory
 
 
 @register
@@ -52,3 +55,33 @@ def create_optimizer(self, model, optimizer_config):
             lr=optimizer_config.learning_rate,
             betas=optimizer_config.betas,
         )
+
+
+@register
+class NoOpOptimizerFactory(OptimizerFactory):
+    name = "noop"
+
+    def create_optimizer(self, model, optimizer_config):
+        return None
+
+
+@register
+class NoOpDataFactory(DataFactory):
+    name = "noop"
+
+    def __call__(self):
+        return None, None
+
+    def tokenize_fn(self):
+        pass
+
+    def collate_fn(self):
+        pass
+
+
+@register
+class NoOpSchedulerFactory(SchedulerFactory):
+    name = "noop"
+
+    def create_scheduler(self, optimizer):
+        return None
diff --git a/tests/trainer/test_sft_trainer.py b/tests/trainer/test_sft_trainer.py
index 0f6adde..f613105 100644
--- a/tests/trainer/test_sft_trainer.py
+++ b/tests/trainer/test_sft_trainer.py
@@ -1,16 +1,13 @@
-import os
-import subprocess
-from pathlib import Path
-
 import pytest
 import yaml
 
+from arctic_training.config.trainer import get_config
+
 
 @pytest.mark.gpu
 def test_sft_trainer(tmp_path):
     config_dict = {
         "type": "sft",
-        "code": str(Path(__file__).parent / "trainer_test_helpers.py"),
         "exit_iteration": 2,
         "micro_batch_size": 1,
         "model": {
@@ -27,20 +24,16 @@ def test_sft_trainer(tmp_path):
     with open(config_path, "w") as f:
         f.write(yaml.dump(config_dict))
 
-    result = subprocess.run(
-        f"arctic_training {config_path}", shell=True, text=True, capture_output=True
-    )
-
-    if result.returncode != 0:
-        print(result.stderr)
-        pytest.fail("Training failed")
+    config = get_config(config_path)
+    trainer = config.trainer
+    trainer.train()
+    assert trainer.global_step > 0, "Training did not run"
 
 
 @pytest.mark.cpu
 def test_sft_trainer_cpu(tmp_path):
     config_dict = {
         "type": "sft",
-        "code": str(Path(__file__).parent / "trainer_test_helpers.py"),
         "exit_iteration": 2,
         "micro_batch_size": 1,
         "model": {
@@ -66,17 +59,7 @@ def test_sft_trainer_cpu(tmp_path):
     with open(config_path, "w") as f:
         f.write(yaml.dump(config_dict))
 
-    env = os.environ.copy()
-    env["DS_ACCELERATOR"] = "cpu"
-
-    result = subprocess.run(
-        f"arctic_training {config_path}",
-        shell=True,
-        text=True,
-        capture_output=True,
-        env=env,
-    )
-
-    if result.returncode != 0:
-        print(result.stderr)
-        pytest.fail("Training failed")
+    config = get_config(config_path)
+    trainer = config.trainer
+    trainer.train()
+    assert trainer.global_step > 0, "Training did not run"
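
One possible follow-up for tests/checkpoint/utils.py: `zip()` stops at the shorter iterable, so `models_are_equal` as written could report equality for models with different parameter counts. A sketch of a stricter, name-aware comparison (a suggestion only, not part of the patch; `models_are_equal_strict` is a hypothetical name):

```python
import torch


def models_are_equal_strict(model_a: torch.nn.Module, model_b: torch.nn.Module) -> bool:
    # Compare state_dicts by key so missing or extra parameters are caught,
    # unlike zip() over .parameters(), which silently truncates.
    sd_a, sd_b = model_a.state_dict(), model_b.state_dict()
    if sd_a.keys() != sd_b.keys():
        return False
    return all(torch.equal(sd_a[key], sd_b[key]) for key in sd_a)
```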