diff --git a/torchrecipes/paved_path/README.md b/torchrecipes/paved_path/README.md
index 34a04b9..b09da37 100644
--- a/torchrecipes/paved_path/README.md
+++ b/torchrecipes/paved_path/README.md
@@ -8,14 +8,29 @@
 pip install -r requirements.txt
 ```

-2. Train a model
+2. Model training
+* Train a model
 ```bash
 python charnn/main.py
 ```
+You will get output like the following. The snapshot path can be used for inference or to restore training.
+```
+0: epoch 0 iter 100: train loss 3.01963
+0: epoch 0 iter 200: train loss 2.69831
+0: epoch 0 iter 0: test loss 2.67962
+0: epoch 0 iter 100: test loss 2.69960
+0: epoch 0 iter 200: test loss 2.70585
+...
+[2022-08-30 20:07:33,842][trainer][INFO] - Saving snapshot to /tmp/charnn/run-bc6565c7/snapshots/epoch-1
+```
+* Restore from a snapshot and train for more epochs
+```bash
+python charnn/main.py trainer.max_epochs=3 trainer.snapshot_path=/tmp/charnn/run-1f7abaed/snapshots/epoch-1
+```

 3. Generate text from a model
 ```bash
-python charnn/main.py charnn.task="generate" charnn.phrase="hello world"
+python charnn/main.py charnn.task="generate" charnn.phrase="hello world" trainer.snapshot_path=/tmp/charnn/run-1f7abaed/snapshots/epoch-1
 ```

 4. [Optional] train a model with torchx
diff --git a/torchrecipes/paved_path/charnn/main.py b/torchrecipes/paved_path/charnn/main.py
index 051aff5..0652c68 100644
--- a/torchrecipes/paved_path/charnn/main.py
+++ b/torchrecipes/paved_path/charnn/main.py
@@ -9,18 +9,19 @@
 import random
 import socket
 import uuid
-from typing import Optional, Tuple
+from typing import Tuple

 import hydra
 import torch
 import torch.distributed as dist
+import torchsnapshot
 from char_dataset import CharDataset, get_dataset
 from model import GPT, GPTConfig, OptimizerConfig
 from omegaconf import DictConfig
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import random_split
-from trainer import Checkpoint, load_checkpoint, Trainer, TrainerConfig
+from trainer import Trainer, TrainerConfig
 from utils import get_realpath, sample

 logger = logging.getLogger(__name__)
@@ -43,7 +44,7 @@ def set_env() -> None:

 def get_job_name() -> str:
     uid = os.environ["TORCHELASTIC_RUN_ID"]
-    return f"test-job-{uid}"
+    return f"run-{uid}"


 def get_device() -> torch.device:
@@ -57,13 +58,10 @@ def get_device() -> torch.device:


 def get_ddp_model_and_optimizer(
-    gpt_config: GPTConfig, opt_config: OptimizerConfig, checkpoint: Optional[Checkpoint]
+    gpt_config: GPTConfig, opt_config: OptimizerConfig
 ) -> Tuple[torch.nn.Module, torch.optim.Optimizer]:
     # Create new GPT Model on CPU
     model = GPT(gpt_config)
-    # Load GPT model from checkpoint if present
-    if checkpoint:
-        model.load_state_dict(checkpoint.model_state)
     device = get_device()
     device_ids = None
     if device.type == "cuda":
@@ -83,11 +81,9 @@ def get_model_and_optimizer(
     type: str,
     gpt_config: GPTConfig,
     opt_config: OptimizerConfig,
-    checkpoint: Optional[Checkpoint],
 ) -> Tuple[torch.nn.Module, torch.optim.Optimizer]:
     if type == "ddp":
-        return get_ddp_model_and_optimizer(gpt_config, opt_config, checkpoint)
-
+        return get_ddp_model_and_optimizer(gpt_config, opt_config)
     raise RuntimeError(f"Unknown type: {type}. Allowed values: [ddp]")


@@ -149,23 +145,34 @@ def main(cfg: DictConfig) -> None:
     )

     train_cfg = cfg["trainer"]
+    train_cfg["work_dir"] = os.path.join(train_cfg.get("work_dir", ""), job_name)
     tconf = TrainerConfig(
+        work_dir=train_cfg["work_dir"],
         job_name=job_name,
         max_epochs=train_cfg["max_epochs"],
         batch_size=train_cfg["batch_size"],
         data_loader_workers=train_cfg["data_loader_workers"],
         enable_profile=train_cfg["enable_profile"],
-        log_dir=train_cfg.get("log_dir"),
-        checkpoint_path=train_cfg.get("checkpoint_path"),
+        # TODO: @stevenliu remove log_dir. infer it from work_dir
+        log_dir=os.path.join(train_cfg["work_dir"], "logs"),
     )
-
-    checkpoint = load_checkpoint(tconf.checkpoint_path)
     opt_conf = OptimizerConfig(
         lr=cfg["opt"]["lr"], weight_decay=cfg["opt"]["weight_decay"]
     )
-    model, optimizer = get_model_and_optimizer(
-        cfg["charnn"]["dist"], mconf, opt_conf, checkpoint
-    )
+    model, optimizer = get_model_and_optimizer(cfg["charnn"]["dist"], mconf, opt_conf)
+
+    # app_state will be saved or restored for checkpointing
+    progress = torchsnapshot.StateDict(current_epoch=0)
+    app_state = {
+        "model": model,
+        "optimizer": optimizer,
+        "progress": progress,
+    }
+
+    if train_cfg.get("snapshot_path", None):
+        snapshot = torchsnapshot.Snapshot(train_cfg["snapshot_path"])
+        print(f"Restoring snapshot from path: {train_cfg['snapshot_path']}")
+        snapshot.restore(app_state=app_state)

     if cfg["charnn"]["task"] == "train":
         trainer = Trainer(
@@ -175,9 +182,9 @@ def main(cfg: DictConfig) -> None:
             test_dataset,
             tconf,
             device,
-            checkpoint.finished_epoch + 1 if checkpoint else 0,
+            progress["current_epoch"],
         )
-        trainer.fit(cfg.get("max_iter", -1))
+        trainer.fit(app_state, max_iter=cfg.get("max_iter", -1))
     elif cfg["charnn"]["task"] == "generate":
         generate_seq(cfg, model, train_dataset.dataset)
     else:
diff --git a/torchrecipes/paved_path/charnn/trainer.py b/torchrecipes/paved_path/charnn/trainer.py
index 3593d82..11861f9 100644
--- a/torchrecipes/paved_path/charnn/trainer.py
+++ b/torchrecipes/paved_path/charnn/trainer.py
@@ -8,73 +8,36 @@
 import logging
 import os
-from collections import OrderedDict
-from dataclasses import asdict, dataclass
-from typing import Any, Dict, Optional
+import uuid
+from dataclasses import dataclass
+from typing import Dict, Optional

 import torch
-import torch.distributed as dist
 import torch.optim as optim
 from torch.utils.data import Dataset
 from torch.utils.data.dataloader import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.tensorboard import SummaryWriter
+from torchsnapshot import Snapshot, Stateful

 logger = logging.getLogger(__name__)


 @dataclass
 class TrainerConfig:
+    work_dir: str
     job_name: str
     max_epochs: int = 10
     batch_size: int = 64
-    checkpoint_path: Optional[str] = None
     data_loader_workers: int = 0
     enable_profile: bool = False
     log_dir: Optional[str] = None


-@dataclass
-class Checkpoint:
-    model_state: "OrderedDict[str, torch.Tensor]"
-    optimizer_state: Dict[str, Any]
-    finished_epoch: int
-
-
 def get_raw_model(model: torch.nn.Module) -> torch.nn.Module:
     return model.module if hasattr(model, "module") else model


-def save_checkpoint(
-    checkpoint_path: Optional[str],
-    model: torch.nn.Module,
-    optimizer: optim.Optimizer,
-    epoch: int,
-) -> None:
-    if checkpoint_path and dist.get_rank() == 0:
-        model = get_raw_model(model)
-        checkpoint = Checkpoint(
-            finished_epoch=epoch,
-            model_state=model.state_dict(),
-            optimizer_state=optimizer.state_dict(),
-        )
-        torch.save(asdict(checkpoint), checkpoint_path)
-
-
-def load_checkpoint(checkpoint_path: Optional[str]) -> Optional[Checkpoint]:
-    if checkpoint_path and os.path.exists(checkpoint_path):
-        # Load checkpoint on CPU. For big models the sequence is:
-        # 1. Create model
-        # 2. Load model state from checkpoint
-        # 3. Move model to GPU
-        # 4. Create optimizer
-        # 5. Load optimizer state from checkpoint
-        checkpoint_data = torch.load(checkpoint_path, map_location="cpu")
-        return Checkpoint(**checkpoint_data)
-    else:
-        return None
-
-
 class Trainer:
     def __init__(
         self,
@@ -174,7 +137,7 @@ def run_epoch(self, epoch: int, max_iter: int = -1) -> None:
                 )
                 if it % 100 == 0:
                     print(
-                        f"{self.rank}: epoch {epoch + 1} iter {it}: train loss {train_batch_loss:.5f}"
+                        f"{self.rank}: epoch {epoch} iter {it}: train loss {train_batch_loss:.5f}"
                     )
                 if max_iter > 0 and it >= max_iter:
                     break
@@ -187,13 +150,10 @@ def run_epoch(self, epoch: int, max_iter: int = -1) -> None:
                     self.tb_writer.add_scalar(f"test_loss_{epoch}", test_batch_loss, it)
                 if it % 100 == 0:
                     print(
-                        f"{self.rank}: epoch {epoch + 1} iter {it}: test loss {test_batch_loss:.5f}"
+                        f"{self.rank}: epoch {epoch} iter {it}: test loss {test_batch_loss:.5f}"
                     )
                 if max_iter > 0 and it >= max_iter:
                     break
-            save_checkpoint(
-                self.config.checkpoint_path, self.model, self.optimizer, epoch
-            )

         finally:
             if prof:
@@ -201,6 +161,17 @@ def run_epoch(self, epoch: int, max_iter: int = -1) -> None:
             if self.tb_writer:
                 self.tb_writer.flush()

-    def fit(self, max_iter: int = -1) -> None:
-        for epoch in range(self.start_epoch, self.config.max_epochs):
+    def fit(self, app_state: Dict[str, Stateful], max_iter: int = -1) -> None:
+        progress = app_state["progress"]
+        for epoch in range(progress["current_epoch"], self.config.max_epochs):
             self.run_epoch(epoch, max_iter)
+            progress["current_epoch"] += 1
+
+            # save a snapshot per epoch
+            snapshot = Snapshot.take(
+                path=os.path.join(
+                    self.config.work_dir, f"snapshots/epoch-{progress['current_epoch']}"
+                ),
+                app_state=app_state,
+            )
+            logger.info(f"Saving snapshot to {snapshot.path}")
diff --git a/torchrecipes/paved_path/charnn/trainer_config.yaml b/torchrecipes/paved_path/charnn/trainer_config.yaml
index 649edd8..a27779a 100644
--- a/torchrecipes/paved_path/charnn/trainer_config.yaml
+++ b/torchrecipes/paved_path/charnn/trainer_config.yaml
@@ -1,4 +1,4 @@
-experiment_name: charnn-test
+
 opt:
   lr: 0.0006
   weight_decay: 0.1
@@ -6,12 +6,13 @@ dataset:
   path: data/input.txt
 max_iter: 200
 trainer:
+  work_dir: "/tmp/charnn"
   max_epochs: 1
   lr: 0.0006
   batch_size: 128
   data_loader_workers: 1
   enable_profile: False
-  log_dir: "/tmp/charnn-test"
+  snapshot_path: "" # specify your snapshot path to restore training state
 model:
   n_layer: 2 # 8
   n_head: 2 # 8
diff --git a/torchrecipes/paved_path/requirements.txt b/torchrecipes/paved_path/requirements.txt
index e5ef67c..3657918 100644
--- a/torchrecipes/paved_path/requirements.txt
+++ b/torchrecipes/paved_path/requirements.txt
@@ -2,3 +2,5 @@ torch
 tensorboard
 hydra-core
 fsspec
+--pre
+torchsnapshot-nightly
\ No newline at end of file
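
For reference, the snippet below is a minimal, self-contained sketch of the torchsnapshot save/restore pattern this patch wires into `Trainer.fit` and `main.py`. It uses only calls that appear in the diff (`torchsnapshot.StateDict`, `Snapshot.take`, `Snapshot(...).restore`, `snapshot.path`); the toy `nn.Linear` model, `AdamW` optimizer, and temporary work directory are illustrative stand-ins for charnn's GPT model and `work_dir`, not part of the patch.

```python
import os
import tempfile

import torch
import torchsnapshot

# Stand-ins for the model/optimizer built in get_model_and_optimizer().
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# app_state groups every stateful object a snapshot should capture,
# mirroring the dict assembled in main().
progress = torchsnapshot.StateDict(current_epoch=0)
app_state = {"model": model, "optimizer": optimizer, "progress": progress}

work_dir = tempfile.mkdtemp(prefix="charnn-sketch-")

# ... one epoch of training would run here ...
progress["current_epoch"] += 1

# Persist a snapshot per epoch, as Trainer.fit now does.
snapshot = torchsnapshot.Snapshot.take(
    path=os.path.join(work_dir, f"snapshots/epoch-{progress['current_epoch']}"),
    app_state=app_state,
)

# Later (or in a fresh process): rebuild the same objects, then restore
# in place, mirroring the trainer.snapshot_path branch in main().
new_model = torch.nn.Linear(4, 2)
new_optimizer = torch.optim.AdamW(new_model.parameters(), lr=1e-3)
new_progress = torchsnapshot.StateDict(current_epoch=0)
torchsnapshot.Snapshot(snapshot.path).restore(
    app_state={"model": new_model, "optimizer": new_optimizer, "progress": new_progress}
)
assert new_progress["current_epoch"] == 1
```

Restoring into freshly constructed objects is the same flow `main.py` follows when `trainer.snapshot_path` is set, after which `Trainer.fit` resumes from `progress["current_epoch"]`.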