19 changes: 17 additions & 2 deletions torchrecipes/paved_path/README.md
@@ -8,14 +8,29 @@
pip install -r requirements.txt
```

2. Train a model
2. Model training
* Train a model
```bash
python charnn/main.py
```
You will get output like the example below. The snapshot path can be used for inference or to restore training:
```
0: epoch 0 iter 100: train loss 3.01963
0: epoch 0 iter 200: train loss 2.69831
0: epoch 0 iter 0: test loss 2.67962
0: epoch 0 iter 100: test loss 2.69960
0: epoch 0 iter 200: test loss 2.70585
...
[2022-08-30 20:07:33,842][trainer][INFO] - Saving snapshot to /tmp/charnn/run-bc6565c7/snapshots/epoch-1
```
* Restore from a snapshot and continue training for more epochs
```bash
python charnn/main.py trainer.max_epochs=3 trainer.snapshot_path=/tmp/charnn/run-1f7abaed/snapshots/epoch-1
```

3. Generate text from a model
```bash
python charnn/main.py charnn.task="generate" charnn.phrase="hello world"
python charnn/main.py charnn.task="generate" charnn.phrase="hello world" trainer.snapshot_path=/tmp/charnn/run-1f7abaed/snapshots/epoch-1
```

4. [Optional] train a model with torchx
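For readers following the README steps: the snapshot path printed at the end of training is assembled from the trainer `work_dir`, the torchelastic run id, and the number of finished epochs. A minimal sketch of how those pieces combine (illustration only; the default `work_dir` comes from `trainer_config.yaml` and the run-id fallback is taken from the sample output above):

```python
import os

# Mirrors get_job_name() in charnn/main.py and the Snapshot.take() path in charnn/trainer.py.
work_dir = "/tmp/charnn"                                    # trainer.work_dir in trainer_config.yaml
run_id = os.environ.get("TORCHELASTIC_RUN_ID", "bc6565c7")  # set by torchelastic; fallback is illustrative
job_name = f"run-{run_id}"
finished_epochs = 1                                         # one snapshot is taken per finished epoch
snapshot_path = os.path.join(work_dir, job_name, f"snapshots/epoch-{finished_epochs}")
print(snapshot_path)  # e.g. /tmp/charnn/run-bc6565c7/snapshots/epoch-1
```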
45 changes: 26 additions & 19 deletions torchrecipes/paved_path/charnn/main.py
@@ -9,18 +9,19 @@
import random
import socket
import uuid
from typing import Optional, Tuple
from typing import Tuple

import hydra
import torch
import torch.distributed as dist
import torchsnapshot

from char_dataset import CharDataset, get_dataset
from model import GPT, GPTConfig, OptimizerConfig
from omegaconf import DictConfig
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import random_split
from trainer import Checkpoint, load_checkpoint, Trainer, TrainerConfig
from trainer import Trainer, TrainerConfig
from utils import get_realpath, sample

logger = logging.getLogger(__name__)
@@ -43,7 +44,7 @@ def set_env() -> None:

def get_job_name() -> str:
uid = os.environ["TORCHELASTIC_RUN_ID"]
return f"test-job-{uid}"
return f"run-{uid}"


def get_device() -> torch.device:
@@ -57,13 +58,10 @@ def get_device() -> torch.device:


def get_ddp_model_and_optimizer(
gpt_config: GPTConfig, opt_config: OptimizerConfig, checkpoint: Optional[Checkpoint]
gpt_config: GPTConfig, opt_config: OptimizerConfig
) -> Tuple[torch.nn.Module, torch.optim.Optimizer]:
# Create new GPT Model on CPU
model = GPT(gpt_config)
# Load GPT model from checkpoint if present
if checkpoint:
model.load_state_dict(checkpoint.model_state)
device = get_device()
device_ids = None
if device.type == "cuda":
@@ -83,11 +81,9 @@ def get_model_and_optimizer(
type: str,
gpt_config: GPTConfig,
opt_config: OptimizerConfig,
checkpoint: Optional[Checkpoint],
) -> Tuple[torch.nn.Module, torch.optim.Optimizer]:
if type == "ddp":
return get_ddp_model_and_optimizer(gpt_config, opt_config, checkpoint)

return get_ddp_model_and_optimizer(gpt_config, opt_config)
raise RuntimeError(f"Unknown type: {type}. Allowed values: [ddp]")


@@ -149,23 +145,34 @@ def main(cfg: DictConfig) -> None:
)

train_cfg = cfg["trainer"]
train_cfg["work_dir"] = os.path.join(train_cfg.get("work_dir", ""), job_name)
tconf = TrainerConfig(
work_dir=train_cfg["work_dir"],
job_name=job_name,
max_epochs=train_cfg["max_epochs"],
batch_size=train_cfg["batch_size"],
data_loader_workers=train_cfg["data_loader_workers"],
enable_profile=train_cfg["enable_profile"],
log_dir=train_cfg.get("log_dir"),
checkpoint_path=train_cfg.get("checkpoint_path"),
# TODO: @stevenliu remove log_dir. infer it from work_dir
log_dir=os.path.join(train_cfg["work_dir"], "logs"),
)

checkpoint = load_checkpoint(tconf.checkpoint_path)
opt_conf = OptimizerConfig(
lr=cfg["opt"]["lr"], weight_decay=cfg["opt"]["weight_decay"]
)
model, optimizer = get_model_and_optimizer(
cfg["charnn"]["dist"], mconf, opt_conf, checkpoint
)
model, optimizer = get_model_and_optimizer(cfg["charnn"]["dist"], mconf, opt_conf)

# app_state will be saved or restored for checkpointing
progress = torchsnapshot.StateDict(current_epoch=0)
app_state = {
"model": model,
"optimizer": optimizer,
"progress": progress,
}

if train_cfg.get("snapshot_path", None):
snapshot = torchsnapshot.Snapshot(train_cfg["snapshot_path"])
print(f"Restoring snapshot from path: {train_cfg['snapshot_path']}")
snapshot.restore(app_state=app_state)

if cfg["charnn"]["task"] == "train":
trainer = Trainer(
@@ -175,9 +182,9 @@ def main(cfg: DictConfig) -> None:
test_dataset,
tconf,
device,
checkpoint.finished_epoch + 1 if checkpoint else 0,
progress["current_epoch"],
)
trainer.fit(cfg.get("max_iter", -1))
trainer.fit(app_state, max_iter=cfg.get("max_iter", -1))
elif cfg["charnn"]["task"] == "generate":
generate_seq(cfg, model, train_dataset.dataset)
else:
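Stripped of the Hydra and DDP scaffolding, the restore path that `main.py` now follows boils down to the torchsnapshot `app_state` pattern. A condensed sketch (the toy `Linear` model, the `AdamW` optimizer, and the hard-coded path are placeholders, not code from the repo):

```python
import torch
import torchsnapshot

# app_state maps names to stateful objects: model and optimizer expose state_dict(),
# and the plain epoch counter is wrapped in torchsnapshot.StateDict.
model = torch.nn.Linear(16, 16)                    # stand-in for GPT(gpt_config)
optimizer = torch.optim.AdamW(model.parameters())  # stand-in for the configured optimizer
progress = torchsnapshot.StateDict(current_epoch=0)
app_state = {"model": model, "optimizer": optimizer, "progress": progress}

snapshot_path = "/tmp/charnn/run-1f7abaed/snapshots/epoch-1"  # illustrative path from the README
if snapshot_path:
    snapshot = torchsnapshot.Snapshot(snapshot_path)
    snapshot.restore(app_state=app_state)  # loads the saved state back into these objects

# Training resumes from the restored counter instead of a Checkpoint dataclass field.
start_epoch = progress["current_epoch"]
print(f"resuming at epoch {start_epoch}")
```

Compared with the removed `torch.save` checkpoint, the snapshot captures model, optimizer, and progress together, so resuming needs only the snapshot directory path.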
69 changes: 20 additions & 49 deletions torchrecipes/paved_path/charnn/trainer.py
@@ -8,73 +8,36 @@

import logging
import os
from collections import OrderedDict
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional
import uuid
from dataclasses import dataclass
from typing import Dict, Optional

import torch
import torch.distributed as dist
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from torchsnapshot import Snapshot, Stateful

logger = logging.getLogger(__name__)


@dataclass
class TrainerConfig:
work_dir: str
job_name: str
max_epochs: int = 10
batch_size: int = 64
checkpoint_path: Optional[str] = None
data_loader_workers: int = 0
enable_profile: bool = False
log_dir: Optional[str] = None


@dataclass
class Checkpoint:
model_state: "OrderedDict[str, torch.Tensor]"
optimizer_state: Dict[str, Any]
finished_epoch: int


def get_raw_model(model: torch.nn.Module) -> torch.nn.Module:
return model.module if hasattr(model, "module") else model


def save_checkpoint(
checkpoint_path: Optional[str],
model: torch.nn.Module,
optimizer: optim.Optimizer,
epoch: int,
) -> None:
if checkpoint_path and dist.get_rank() == 0:
model = get_raw_model(model)
checkpoint = Checkpoint(
finished_epoch=epoch,
model_state=model.state_dict(),
optimizer_state=optimizer.state_dict(),
)
torch.save(asdict(checkpoint), checkpoint_path)


def load_checkpoint(checkpoint_path: Optional[str]) -> Optional[Checkpoint]:
if checkpoint_path and os.path.exists(checkpoint_path):
# Load checkpoint on CPU. For big models the sequence is:
# 1. Create model
# 2. Load model state from checkpoint
# 3. Move model to GPU
# 4. Create optimizer
# 5. Load optimizer state from checkpoint
checkpoint_data = torch.load(checkpoint_path, map_location="cpu")
return Checkpoint(**checkpoint_data)
else:
return None


class Trainer:
def __init__(
self,
@@ -174,7 +137,7 @@ def run_epoch(self, epoch: int, max_iter: int = -1) -> None:
)
if it % 100 == 0:
print(
f"{self.rank}: epoch {epoch + 1} iter {it}: train loss {train_batch_loss:.5f}"
f"{self.rank}: epoch {epoch} iter {it}: train loss {train_batch_loss:.5f}"
)
if max_iter > 0 and it >= max_iter:
break
@@ -187,20 +150,28 @@ def run_epoch(self, epoch: int, max_iter: int = -1) -> None:
self.tb_writer.add_scalar(f"test_loss_{epoch}", test_batch_loss, it)
if it % 100 == 0:
print(
f"{self.rank}: epoch {epoch + 1} iter {it}: test loss {test_batch_loss:.5f}"
f"{self.rank}: epoch {epoch} iter {it}: test loss {test_batch_loss:.5f}"
)
if max_iter > 0 and it >= max_iter:
break
save_checkpoint(
self.config.checkpoint_path, self.model, self.optimizer, epoch
)

finally:
if prof:
prof.stop()
if self.tb_writer:
self.tb_writer.flush()

def fit(self, max_iter: int = -1) -> None:
for epoch in range(self.start_epoch, self.config.max_epochs):
def fit(self, app_state: Dict[str, Stateful], max_iter: int = -1) -> None:
progress = app_state["progress"]
for epoch in range(progress["current_epoch"], self.config.max_epochs):
self.run_epoch(epoch, max_iter)
progress["current_epoch"] += 1

# save a snapshot per epoch
snapshot = Snapshot.take(
path=os.path.join(
self.config.work_dir, f"snapshots/epoch-{progress['current_epoch']}"
),
app_state=app_state,
)
logger.info(f"Saving snapshot to {snapshot.path}")
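The save side in `trainer.py` is symmetric: `fit()` advances the progress counter and takes one snapshot per finished epoch. A stripped-down sketch of that loop (`run_epoch` and `work_dir` are placeholders standing in for the `Trainer` internals):

```python
import os
from typing import Callable, Dict

import torchsnapshot

def fit_sketch(
    app_state: Dict[str, object],
    work_dir: str,
    max_epochs: int,
    run_epoch: Callable[[int], None],
) -> None:
    # progress is the torchsnapshot.StateDict created in main.py.
    progress = app_state["progress"]
    for epoch in range(progress["current_epoch"], max_epochs):
        run_epoch(epoch)
        # Count the finished epoch before saving, so a restore resumes at the next one.
        progress["current_epoch"] += 1
        snapshot = torchsnapshot.Snapshot.take(
            path=os.path.join(work_dir, f"snapshots/epoch-{progress['current_epoch']}"),
            app_state=app_state,
        )
        print(f"saved snapshot to {snapshot.path}")
```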
5 changes: 3 additions & 2 deletions torchrecipes/paved_path/charnn/trainer_config.yaml
@@ -1,17 +1,18 @@
experiment_name: charnn-test

opt:
lr: 0.0006
weight_decay: 0.1
dataset:
path: data/input.txt
max_iter: 200
trainer:
work_dir: "/tmp/charnn"
max_epochs: 1
lr: 0.0006
batch_size: 128
data_loader_workers: 1
enable_profile: False
log_dir: "/tmp/charnn-test"
snapshot_path: "" # specify your snapshot path to restore training state
model:
n_layer: 2 # 8
n_head: 2 # 8
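A small note on the new `snapshot_path: ""` default: `main.py` guards the restore with `train_cfg.get("snapshot_path", None)`, and an empty string is falsy, so leaving the field at `""` simply skips restoration. The same check in isolation:

```python
# The guard main.py relies on: an empty snapshot_path (the YAML default) means "start fresh".
for snapshot_path in ["", "/tmp/charnn/run-1f7abaed/snapshots/epoch-1"]:
    if snapshot_path:
        print(f"restore training state from {snapshot_path}")
    else:
        print("no snapshot_path configured; training starts from scratch")
```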
2 changes: 2 additions & 0 deletions torchrecipes/paved_path/requirements.txt
@@ -2,3 +2,5 @@ torch
tensorboard
hydra-core
fsspec
--pre
torchsnapshot-nightly
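On the requirements change: `--pre` is a pip option that is accepted inside a requirements file and lets pre-release versions be considered during resolution; it is presumably included so that the `torchsnapshot-nightly` builds install cleanly via `pip install -r requirements.txt`.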