Add checkpoint unit test #35

Merged: 9 commits, Jan 30, 2025
11 changes: 7 additions & 4 deletions arctic_training/checkpoint/ds_engine.py
@@ -39,13 +39,15 @@ def checkpoint_tag(self) -> str:

@property
def client_state(self) -> Dict[str, Any]:
return {
state = {
"end_of_epoch": self.trainer.epoch_idx,
"torch_random_state": torch.get_rng_state(),
"torch_cuda_random_state": torch.cuda.get_rng_state(),
"np_random_state": np.random.get_state(),
"python_random_state": random.getstate(),
}
if self.device != torch.device("cpu"):
state["torch_cuda_random_state"] = torch.cuda.get_rng_state()
return state

def save(self, model) -> None:
model.save_checkpoint(
@@ -59,9 +61,10 @@ def load(self, model) -> None:
return
_, client_states = model.load_checkpoint(self.checkpoint_dir)

self.trainer.global_step = model.global_step
self.trainer.global_step = model.global_steps
self.trainer.epoch_idx = client_states["end_of_epoch"] + 1
torch.set_rng_state(client_states["torch_random_state"])
torch.cuda.set_rng_state(client_states["torch_cuda_random_state"])
np.random.set_state(client_states["np_random_state"])
random.setstate(client_states["python_random_state"])
if self.device != torch.device("cpu"):
torch.cuda.set_rng_state(client_states["torch_cuda_random_state"])
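The hunk above makes both the save and load paths skip the CUDA RNG state on CPU-only runs. A minimal standalone sketch of that round-trip, using torch.cuda.is_available() as a stand-in for the engine's self.device check (an illustration of the pattern, not the engine code itself):

import random

import numpy as np
import torch


def capture_rng_state() -> dict:
    # Host-side RNG states are always captured; the CUDA state only when a GPU exists.
    state = {
        "torch_random_state": torch.get_rng_state(),
        "np_random_state": np.random.get_state(),
        "python_random_state": random.getstate(),
    }
    if torch.cuda.is_available():
        state["torch_cuda_random_state"] = torch.cuda.get_rng_state()
    return state


def restore_rng_state(state: dict) -> None:
    # Restore the host states unconditionally; skip the CUDA state on CPU-only machines.
    torch.set_rng_state(state["torch_random_state"])
    np.random.set_state(state["np_random_state"])
    random.setstate(state["python_random_state"])
    if torch.cuda.is_available() and "torch_cuda_random_state" in state:
        torch.cuda.set_rng_state(state["torch_cuda_random_state"])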
5 changes: 5 additions & 0 deletions arctic_training/cli.py
@@ -46,6 +46,8 @@ def main():


def run_script():
import deepspeed.comm as dist

from arctic_training.config.trainer import get_config

parser = argparse.ArgumentParser()
@@ -69,3 +71,6 @@ def run_script():
config = get_config(args.config)
trainer = config.trainer
trainer.train()
if dist.is_initialized():
dist.barrier()
dist.destroy_process_group()
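The new teardown only fires when a process group actually exists, so arctic_training still exits cleanly in non-distributed runs. A minimal sketch of the same guard using torch.distributed, whose is_initialized/barrier/destroy_process_group calls deepspeed.comm mirrors here (the single-process gloo setup below is purely illustrative):

import os

import torch.distributed as dist


def shutdown_distributed() -> None:
    # Synchronize and tear down only if a process group was created.
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    shutdown_distributed()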
13 changes: 8 additions & 5 deletions arctic_training/config/trainer.py
@@ -281,11 +281,14 @@ def validate_single_checkpoint_resume(self) -> Self:
return self


def get_config(config_file: Path) -> BaseConfig:
with open(config_file, "r") as f:
config_dict = yaml.safe_load(f)

config_dir = config_file.parent
def get_config(config_file_or_dict: Union[Path, Dict]) -> BaseConfig:
if isinstance(config_file_or_dict, dict):
config_dict = config_file_or_dict
config_dir = Path.cwd()
else:
with open(config_file_or_dict, "r") as f:
config_dict = yaml.safe_load(f)
config_dir = config_file_or_dict.parent

trainer_type = config_dict.get("type", TRAINER_DEFAULT)
config_dict["type"] = trainer_type
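get_config now accepts either a YAML path or an already-parsed dict; relative paths inside the config resolve against the file's parent directory in the first case and against the current working directory in the second. A short usage sketch (the file name is illustrative):

import yaml
from pathlib import Path

from arctic_training.config.trainer import get_config

# From a YAML file on disk, as before.
config = get_config(Path("my_sft_config.yaml"))

# From an in-memory dict, as the new checkpoint tests below do.
with open("my_sft_config.yaml") as f:
    config = get_config(yaml.safe_load(f))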
2 changes: 1 addition & 1 deletion arctic_training/data/factory.py
@@ -52,7 +52,7 @@ class DataFactory(ABC, CallbackMixin):
specify the DataFactory to use.
"""

config_type: Type[DataConfig]
config_type: Type[DataConfig] = DataConfig
"""
The type of the DataConfig object that this DataFactory uses. Any
DataFactory-specific options should be specified in this class.
10 changes: 6 additions & 4 deletions arctic_training/trainer/trainer.py
@@ -159,6 +159,10 @@ def __init__(self, config: "TrainerConfig") -> None:
engine(self) for engine in self.config.checkpoint_engines
]

for engine in self.checkpoint_engines:
if engine.config.auto_resume:
engine.load(self.model)

def _set_seeds(self, seed: int) -> None:
logger.info(f"Setting random seeds to {seed}")
torch.manual_seed(seed)
@@ -241,10 +245,8 @@ def step(self, batch: Dict[str, torch.Tensor]) -> None:
self.config.exit_iteration > 0
and self.config.exit_iteration == self.global_step
):
logger.info(f"Hit exit iteration of {self.global_step}, forcing exit")
torch.distributed.barrier()
torch.distributed.destroy_process_group()
exit()
self.early_stop = True
logger.info(f"Hit exit iteration of {self.global_step}, ending training")

@callback_wrapper("epoch")
def epoch(self) -> None:
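Replacing the hard exit() with an early_stop flag leaves the barrier/destroy work to the CLI teardown above; the loop that actually honors the flag is outside this diff. A hedged sketch of how such a flag is typically consumed (trainer.train_batches is a hypothetical iterable, not an attribute shown in this PR):

def run_epochs(trainer, num_epochs: int) -> None:
    # Illustrative only: stop iterating once step() has raised the flag and
    # return control to the caller for checkpointing and shutdown.
    for _ in range(num_epochs):
        for batch in trainer.train_batches:
            trainer.step(batch)
            if trainer.early_stop:
                return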
52 changes: 52 additions & 0 deletions tests/checkpoint/test_ds_engine.py
@@ -0,0 +1,52 @@
import pytest
from utils import models_are_equal

from arctic_training.config.trainer import get_config


@pytest.mark.cpu
def test_ds_engine(tmp_path):
config_dict = {
"type": "sft",
"exit_iteration": 2,
"model": {
"type": "random-weight-hf",
"name_or_path": "HuggingFaceTB/SmolLM-135M-Instruct",
"attn_implementation": "eager",
"dtype": "float32",
},
"data": {
"type": "noop",
"sources": [],
},
"optimizer": {
"type": "cpu-adam",
},
"scheduler": {
"type": "noop",
},
"checkpoint": {
"type": "deepspeed",
"auto_resume": True,
"output_dir": str(tmp_path / "checkpoints"),
"save_end_of_training": True,
},
}

config = get_config(config_dict)
trainer = config.trainer

# Force checkpoint to be saved despite no training happening
trainer.training_finished = True
trainer.checkpoint()

# Store original model for comparison later
original_model = trainer.model

config_dict["seed"] = 0 # Make sure newly initialized model is different
config = get_config(config_dict)
trainer = config.trainer

loaded_model = trainer.model
assert models_are_equal(original_model, loaded_model), "Models are not equal"
# TODO: Add assertion on optimizer state
51 changes: 51 additions & 0 deletions tests/checkpoint/test_hf_engine.py
@@ -0,0 +1,51 @@
import pytest
from utils import models_are_equal

from arctic_training.config.trainer import get_config


@pytest.mark.cpu
def test_hf_engine(tmp_path):
config_dict = {
"type": "sft",
"model": {
"type": "random-weight-hf",
"name_or_path": "HuggingFaceTB/SmolLM-135M-Instruct",
"attn_implementation": "eager",
"dtype": "float32",
},
"data": {
"type": "noop",
"sources": [],
},
"optimizer": {
"type": "cpu-adam",
},
"scheduler": {
"type": "noop",
},
"checkpoint": {
"type": "huggingface",
"output_dir": str(tmp_path / "checkpoints"),
"save_end_of_training": True,
},
}

config = get_config(config_dict)
trainer = config.trainer

# Force checkpoint to be saved despite no training happening
trainer.training_finished = True
trainer.checkpoint()

# Store original model for comparison later
original_model = trainer.model

config_dict["model"]["name_or_path"] = str(
trainer.checkpoint_engines[0].checkpoint_dir
)
config = get_config(config_dict)
trainer = config.trainer

loaded_model = trainer.model
assert models_are_equal(original_model, loaded_model), "Models are not equal"
9 changes: 9 additions & 0 deletions tests/checkpoint/utils.py
@@ -0,0 +1,9 @@
import torch


def models_are_equal(model_a: torch.nn.Module, model_b: torch.nn.Module) -> bool:
for param_a, param_b in zip(model_a.parameters(), model_b.parameters()):
if not param_a.data.eq(param_b.data).all():
return False

return True
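models_are_equal zips the two parameter lists, so a model with extra trailing parameters would still compare equal over the shared prefix. A slightly stricter variant, offered here as an optional hardening rather than part of the PR, compares state_dict entries by name:

import torch


def models_are_equal_strict(model_a: torch.nn.Module, model_b: torch.nn.Module) -> bool:
    # Keying on parameter names catches missing or extra entries; torch.equal
    # requires an exact element-wise match, which suits the float32 test configs.
    state_a = model_a.state_dict()
    state_b = model_b.state_dict()
    if state_a.keys() != state_b.keys():
        return False
    return all(torch.equal(state_a[key], state_b[key]) for key in state_a)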
65 changes: 65 additions & 0 deletions tests/conftest.py
@@ -12,3 +12,68 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest


def pytest_configure(config):
# TODO: Make it so that cpu and gpu tests can be run with a single command.
# This requires some work with tearing down/setting up dist environments
# that have not been worked out yet.
if not config.option.markexpr:
config.option.markexpr = "cpu"
if "gpu" in config.option.markexpr and "cpu" in config.option.markexpr:
pytest.fail("Cannot run tests with both 'gpu' and 'cpu' marks")
if "cpu" in config.option.markexpr:
_setup_cpu_dist()
if "gpu" in config.option.markexpr:
_setup_gpu_dist()


def _setup_cpu_dist():
os.environ["DS_ACCELERATOR"] = "cpu"
os.environ["LOCAL_RANK"] = "0"
os.environ["RANK"] = "0"
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_SIZE"] = "1"

from deepspeed.comm import init_distributed

init_distributed(auto_mpi_discovery=False)


def _setup_gpu_dist():
os.environ["DS_ACCELERATOR"] = "cuda"
os.environ["LOCAL_RANK"] = "0"
os.environ["RANK"] = "0"
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_SIZE"] = "1"

from deepspeed.comm import init_distributed

init_distributed(auto_mpi_discovery=False)


# Eventually when we can run cpu + gpu tests, we will want to order the tests to
# avoid thrashing back and forth between the distributed environments.
"""
@pytest.hookimpl(trylast=True)
def pytest_collection_modifyitems(config, items):
cpu_tests = [item for item in items if "cpu" in item.keywords]
gpu_tests = [item for item in items if "gpu" in item.keywords]

# Reorder tests: all 'cpu' tests first, then 'gpu' tests
items[:] = cpu_tests + gpu_tests
"""


# Load helper functions automatically for all tests
@pytest.fixture(scope="session", autouse=True)
def helpers_code_path() -> None:
from . import test_helpers # noqa: F401
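Because pytest_configure defaults markexpr to cpu, a bare pytest invocation runs only the CPU-marked tests and the GPU suite must be selected explicitly. A programmatic equivalent of running pytest -m gpu (the tests/ path is illustrative):

import sys

import pytest

# Same effect as the command line: pytest -m gpu tests/
sys.exit(pytest.main(["-m", "gpu", "tests/"]))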
@@ -5,9 +5,12 @@
from transformers import PreTrainedModel

from arctic_training import register
from arctic_training.data.factory import DataFactory
from arctic_training.data.source import DataSource
from arctic_training.model.hf_factory import HFModelFactory
from arctic_training.optimizer import FusedAdamOptimizerFactory
from arctic_training.optimizer.adam_factory import FusedAdamOptimizerFactory
from arctic_training.optimizer.factory import OptimizerFactory
from arctic_training.scheduler.factory import SchedulerFactory


@register
@@ -52,3 +55,33 @@ def create_optimizer(self, model, optimizer_config):
lr=optimizer_config.learning_rate,
betas=optimizer_config.betas,
)


@register
class NoOpOptimizerFactory(OptimizerFactory):
name = "noop"

def create_optimizer(self, model, optimizer_config):
return None


@register
class NoOpDataFactory(DataFactory):
name = "noop"

def __call__(self):
return None, None

def tokenize_fn(self):
pass

def collate_fn(self):
pass


@register
class NoOpSchedulerFactory(SchedulerFactory):
name = "noop"

def create_scheduler(self, optimizer):
return None
37 changes: 10 additions & 27 deletions tests/trainer/test_sft_trainer.py
@@ -1,16 +1,13 @@
import os
import subprocess
from pathlib import Path

import pytest
import yaml

from arctic_training.config.trainer import get_config


@pytest.mark.gpu
def test_sft_trainer(tmp_path):
config_dict = {
"type": "sft",
"code": str(Path(__file__).parent / "trainer_test_helpers.py"),
"exit_iteration": 2,
"micro_batch_size": 1,
"model": {
@@ -27,20 +24,16 @@ def test_sft_trainer(tmp_path):
with open(config_path, "w") as f:
f.write(yaml.dump(config_dict))

result = subprocess.run(
f"arctic_training {config_path}", shell=True, text=True, capture_output=True
)

if result.returncode != 0:
print(result.stderr)
pytest.fail("Training failed")
config = get_config(config_path)
trainer = config.trainer
trainer.train()
assert trainer.global_step > 0, "Training did not run"


@pytest.mark.cpu
def test_sft_trainer_cpu(tmp_path):
config_dict = {
"type": "sft",
"code": str(Path(__file__).parent / "trainer_test_helpers.py"),
"exit_iteration": 2,
"micro_batch_size": 1,
"model": {
@@ -66,17 +59,7 @@ def test_sft_trainer_cpu(tmp_path):
with open(config_path, "w") as f:
f.write(yaml.dump(config_dict))

env = os.environ.copy()
env["DS_ACCELERATOR"] = "cpu"

result = subprocess.run(
f"arctic_training {config_path}",
shell=True,
text=True,
capture_output=True,
env=env,
)

if result.returncode != 0:
print(result.stderr)
pytest.fail("Training failed")
config = get_config(config_path)
trainer = config.trainer
trainer.train()
assert trainer.global_step > 0, "Training did not run"