add some train configs (allenai#21)
* add some default configs

* clean up

* shh, mypy

* add gantry example to README

* skip GPU tests when no beaker token

* revert

* update

* ensure lm head tied to wte

* rename

* force wider terminal
epwalsh authored Mar 7, 2023
1 parent 35d3325 commit edf21d7
Showing 8 changed files with 218 additions and 18 deletions.
4 changes: 4 additions & 0 deletions Makefile
@@ -39,6 +39,10 @@ test-image :
show-test-image :
@echo $(TEST_IMAGE)

.PHONY : show-gantry-image
show-gantry-image :
@echo $(GANTRY_IMAGE)

.PHONY : show-beaker-workspace
show-beaker-workspace :
@echo $(BEAKER_WORKSPACE)
20 changes: 20 additions & 0 deletions README.md
@@ -7,3 +7,23 @@ After cloning this repository, first install the latest [PyTorch](https://pytorc
```
pip install -e .[dev] --config-settings editable_mode=compat
```

## Running experiments

### Using [beaker-gantry](https://github.com/allenai/beaker-gantry)

Train a model on c4 with gantry:

```bash
gantry run \
--workspace ai2/llm-testing \
--env-secret WANDB_API_KEY=WANDB_API_KEY \
--venv base \
--nfs \
--priority preemptible \
--gpus 8 \
--beaker-image dolma-gantry \
--cluster 'ai2/*-cirrascale' \
--allow-dirty \
-- composer scripts/train.py configs/1.2b-c4.yaml
```
86 changes: 86 additions & 0 deletions configs/1.2b-c4.yaml
@@ -0,0 +1,86 @@
run_name: 1.2b-c4-run-001
seed: 6198
dry_run: false

model:
d_model: 2048
n_heads: 16
n_layers: 24
mlp_ratio: 4
alibi: true
alibi_bias_max: 8.0
attention_dropout: 0.1
attention_layer_norm: true
residual_dropout: 0.1
embedding_dropout: 0.1
max_sequence_length: 1024
vocab_size: 50257
eos_token_id: 50256
pad_token_id: 50256
init_device: meta
init_std: 0.02

optimizer:
learning_rate: 2.0e-4
weight_decay: 0.01
betas:
- 0.9
- 0.95
eps: 1.0e-08

scheduler:
name: cosine_with_warmup
t_warmup: 100ba
alpha_f: 0.1

algorithms:
gradient_clipping:
clipping_type: norm
clipping_threshold: 1.0

data:
paths:
- /net/nfs.cirrascale/allennlp/llm-data/c4/en/c4-train.*.npy
pad_direction: right
num_workers: 2
drop_last: true
pin_memory: true
prefetch_factor: 2
persistent_workers: true
timeout: 0

tokenizer:
identifier: gpt2
truncate_direction: right

max_duration: 24800ba # ~ 26B tokens

save_folder: /results
save_interval: 1000ba
save_num_checkpoints_to_keep: 2
save_overwrite: false

load_path: null
load_weights_only: false

global_train_batch_size: 512
device_train_batch_size: auto
device_train_microbatch_size: auto
device_train_grad_accum: auto
device_eval_batch_size: null

n_gpus: null

precision: null

fsdp_config:
sharding_strategy: FULL_SHARD
mixed_precision: DEFAULT
activation_checkpointing: false
activation_cpu_offload: false
verbose: false

wandb:
project: petew-benchmarks
entity: ai2-llm
name: ${run_name}
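
The config above is consumed by `scripts/train.py` (see the gantry command in the README). As a quick sanity check it can also be loaded directly through `TrainConfig.load`, which merges optional dot-list overrides on top of the YAML. A minimal sketch, assuming the `dolma` package is installed and the `data.paths` globs resolve on your machine (since `DataConfig` expands them eagerly):

```python
from dolma import TrainConfig

# Load configs/1.2b-c4.yaml and override a couple of fields via an OmegaConf dot-list.
cfg = TrainConfig.load(
    "configs/1.2b-c4.yaml",
    overrides=["run_name=1.2b-c4-run-002", "device_train_microbatch_size=8"],
)
print(cfg.model.d_model, cfg.max_duration)  # 2048 24800ba
```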
51 changes: 38 additions & 13 deletions dolma/config.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from dataclasses import asdict, dataclass
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast

@@ -38,6 +38,19 @@ def __repr__(self) -> str:


class BaseConfig:
@classmethod
def new(cls: Type[C], overrides: Optional[List[str]] = None) -> C:
from omegaconf import OmegaConf
from omegaconf.errors import ConfigKeyError

conf = OmegaConf.structured(cls)
if overrides:
try:
conf = OmegaConf.merge(conf, OmegaConf.from_dotlist(overrides))
except ConfigKeyError as e:
raise DolmaConfigurationError(str(e))
return cast(C, OmegaConf.to_object(conf))

@classmethod
def load(cls: Type[C], path: PathOrStr, overrides: Optional[List[str]] = None) -> C:
"""Load from a YAML file."""
@@ -185,7 +198,7 @@ class PaddingDirection(StrEnum):

@dataclass
class DataConfig(BaseConfig):
paths: List[str]
paths: List[str] = field(default_factory=lambda: [])
pad_direction: PaddingDirection = PaddingDirection.right
num_workers: int = 0
drop_last: bool = True
@@ -194,6 +207,17 @@ class DataConfig(BaseConfig):
persistent_workers: bool = True
timeout: int = 0

def __post_init__(self):
from glob import glob

final_paths = []
for path in self.paths:
matching_paths = glob(path, recursive=True)
if not matching_paths:
raise FileNotFoundError(f"'{path}' did not match any files or directories")
final_paths.extend(matching_paths)
self.paths = final_paths


class TruncationDirection(StrEnum):
right = "right"
@@ -202,7 +226,7 @@ class TruncationDirection(StrEnum):

@dataclass
class TokenizerConfig(BaseConfig):
identifier: str
identifier: str = "gpt2"
truncate_direction: TruncationDirection = TruncationDirection.right


@@ -224,29 +248,30 @@ class TrainConfig(BaseConfig):
DOLMA training configuration.
"""

model: ModelConfig
optimizer: OptimizerConfig
scheduler: SchedulerConfig
data: DataConfig
tokenizer: TokenizerConfig
save_folder: str
run_name: Optional[str] = None
seed: int = 6198
dry_run: bool = False
model: ModelConfig = field(default_factory=ModelConfig)
optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)
algorithms: Optional[Dict[str, Dict[str, Any]]] = None
data: DataConfig = field(default_factory=DataConfig)
tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig)
save_folder: str = "./"
save_interval: Union[str, int] = "1ep"
save_num_checkpoints_to_keep: int = -1
save_overwrite: bool = False
load_path: Optional[str] = None
load_weights_only: bool = False
seed: int = 6198
run_name: Optional[str] = None
max_duration: Union[str, int] = "10ep"
global_train_batch_size: int = 512
device_train_batch_size: Union[str, int] = "auto"
device_train_microbatch_size: Union[str, int] = "auto"
device_train_grad_accum: Union[str, int] = "auto"
device_eval_batch_size: Optional[int] = None
n_gpus: Optional[int] = None
max_duration: Union[str, int] = "10ep"
precision: Optional[str] = None
fsdp_config: Optional[Dict[str, Any]] = None
dry_run: bool = False
wandb: Optional[WandbConfig] = None

@property
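
Taken together, the config changes give every `TrainConfig` field a default (so a config can be built without a YAML file), add a `new` classmethod that merges OmegaConf dot-list overrides onto those defaults, and make `DataConfig.__post_init__` expand glob patterns in `paths` eagerly. A minimal sketch of the glob behavior, using throwaway files instead of the real C4 shards:

```python
from pathlib import Path
from tempfile import mkdtemp

from dolma.config import DataConfig

# Create a few fake shards so the glob has something to match (illustrative paths only).
tmp = Path(mkdtemp())
for i in range(3):
    (tmp / f"c4-train.{i:05d}.npy").touch()

# __post_init__ replaces each pattern with its concrete matches and raises
# FileNotFoundError if a pattern matches nothing.
cfg = DataConfig(paths=[str(tmp / "c4-train.*.npy")])
print(sorted(cfg.paths))  # three concrete .npy paths instead of the pattern
```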
15 changes: 13 additions & 2 deletions dolma/model.py
@@ -142,7 +142,6 @@ def __init__(self, config: ModelConfig, init_params: bool = True):
self.transformer.update(
{"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
)
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False, device=config.init_device)
if init_params and self.config.init_device != "meta":
self.apply(self.param_init_fn)
self.__num_fwd_flops = None
@@ -275,7 +274,7 @@ def forward(

# Get logits.
# shape: (batch_size, seq_len, vocab_size)
logits = self.lm_head(x) # type: ignore
logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore

return DolmaGPTOutput(logits=cast(torch.FloatTensor, logits))

@@ -378,6 +377,18 @@ def param_init_fn(self, module):
torch.nn.init.zeros_(module.bias)
torch.nn.init.ones_(module.weight)

def num_params(self, include_embedding: bool = True) -> int:
"""
Get the total number of parameters.
"""
params = (np for np in self.named_parameters())
if not include_embedding:
params = filter( # type: ignore
lambda np: ".wte." not in np[0] and ".wpe." not in np[0],
params,
)
return sum(p.numel() for _, p in params)

@property
def num_fwd_flops(self):
if self.__num_fwd_flops:
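
Dropping `self.lm_head` and computing logits with `F.linear` against `wte.weight` is the "ensure lm head tied to wte" item from the commit message: the output projection now reuses the input embedding matrix. A standalone sketch of the tying mechanics (dimensions follow the 1.2b-c4 config; this is not the repo's module):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

d_model, vocab_size = 2048, 50257
wte = nn.Embedding(vocab_size, d_model)  # token embeddings, weight shape (vocab_size, d_model)
x = torch.randn(2, 8, d_model)           # hidden states: (batch, seq_len, d_model)

# Tied output projection: reuse wte.weight instead of a separate lm_head.
logits = F.linear(x, wte.weight, None)   # shape (batch, seq_len, vocab_size)
assert logits.shape == (2, 8, vocab_size)
```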
6 changes: 4 additions & 2 deletions dolma/util.py
@@ -50,16 +50,18 @@ def set_env_variables():


def prepare_cli_environment():
rich.reconfigure(width=max(rich.get_console().width, 180))
install_excepthook()
filter_warnings()
set_env_variables()


def clean_opt(arg: str) -> str:
arg = arg.strip("-").replace("-", "_")
if "=" not in arg:
arg = f"{arg}=True"
return arg
name, val = arg.split("=", 1)
name = name.strip("-").replace("-", "_")
return f"{name}={val}"


def calculate_batch_size_info(
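
The `clean_opt` rewrite only normalizes dashes in the option name, so values such as run names keep their hyphens instead of being converted to underscores. A few illustrative cases (the option names are just examples):

```python
from dolma.util import clean_opt

assert clean_opt("--model.n-layers=12") == "model.n_layers=12"
assert clean_opt("--dry-run") == "dry_run=True"  # bare flags become key=True
assert clean_opt("--run-name=1.2b-c4-run-001") == "run_name=1.2b-c4-run-001"  # value keeps its hyphens
```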
29 changes: 29 additions & 0 deletions scripts/init_config.py
@@ -0,0 +1,29 @@
"""
Run this to initialize a new training config to a file.
"""

import sys
from pathlib import Path
from typing import List

from dolma import TrainConfig
from dolma.exceptions import DolmaCliError
from dolma.util import clean_opt, echo, prepare_cli_environment


def main(save_path: Path, args_list: List[str]) -> None:
cfg = TrainConfig.new(overrides=args_list)
echo.info("Configuration:", cfg)
cfg.save(save_path)
echo.success(f"Config saved to {save_path}")


if __name__ == "__main__":
prepare_cli_environment()

try:
save_path, args_list = sys.argv[1], sys.argv[2:]
except IndexError:
raise DolmaCliError(f"Usage: {sys.argv[0]} [SAVE_PATH] [OPTIONS]")

main(Path(save_path), [clean_opt(s) for s in args_list])
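
The script ties `clean_opt`, `TrainConfig.new`, and `save` together so a default config can be dumped to YAML, e.g. `python scripts/init_config.py configs/tiny.yaml --run-name=tiny`. The same flow programmatically (the file name and override are illustrative):

```python
from pathlib import Path

from dolma import TrainConfig
from dolma.util import clean_opt

overrides = [clean_opt("--run-name=tiny")]  # -> ["run_name=tiny"]
cfg = TrainConfig.new(overrides=overrides)  # defaults merged with the overrides
cfg.save(Path("configs/tiny.yaml"))
```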
25 changes: 24 additions & 1 deletion scripts/train.py
@@ -20,6 +20,7 @@

import os
import sys
from typing import Any, Dict

from dolma import SchedulerConfig, TrainConfig
from dolma.data import build_dataloader
@@ -44,6 +45,23 @@ def build_scheduler(cfg: SchedulerConfig):
raise DolmaConfigurationError(f"Not sure how to build scheduler: {cfg.name}")


def build_algorithm(name: str, kwargs: Dict[str, Any]):
from composer import algorithms

if name == "gradient_clipping":
return algorithms.GradientClipping(**kwargs)
# elif name == 'alibi':
# return algorithms.Alibi(**kwargs)
elif name == "fused_layernorm":
return algorithms.FusedLayerNorm(**kwargs)
elif name == "gated_linear_units":
return algorithms.GatedLinearUnits(**kwargs)
elif name == "low_precision_layernorm":
return algorithms.LowPrecisionLayerNorm(**kwargs)
else:
raise ValueError(f"Not sure how to build algorithm: {name}")


def main(cfg: TrainConfig) -> None:
from composer import Trainer
from composer.loggers import WandBLogger
@@ -70,6 +88,8 @@ def main(cfg: TrainConfig) -> None:

# Model.
model = ComposerDolmaGPT(cfg.model)
echo.info(f"Total number of parameters: {model.model.num_params():,d}")
echo.info(f"Number of non-embedding parameters: {model.model.num_params(include_embedding=False):,d}")

# Optimizer.
optimizer = model.model.configure_optimizer(**cfg.optimizer.asdict())
@@ -80,6 +100,9 @@ def main(cfg: TrainConfig) -> None:
# Dataset / data loader.
train_loader = build_dataloader(cfg, cfg.device_train_batch_size)

# Algorithms
algorithms = [build_algorithm(name, algorithm_cfg) for name, algorithm_cfg in (cfg.algorithms or {}).items()]

# Trainer.
trainer = Trainer(
run_name=cfg.run_name,
@@ -104,7 +127,7 @@ def main(cfg: TrainConfig) -> None:
load_weights_only=cfg.load_weights_only,
callbacks=[SpeedMonitorMFU()],
loggers=[WandBLogger(**cfg.wandb.asdict())] if cfg.wandb is not None else [],
# algorithms=algorithms,
algorithms=algorithms,
)

if not cfg.dry_run:
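
With `build_algorithm` in place and the `algorithms=` argument un-commented, the `algorithms:` mapping from the YAML is turned into composer algorithm objects and passed to the `Trainer`. A rough, self-contained sketch of that dispatch for the one entry used in `configs/1.2b-c4.yaml` (the other branches work the same way):

```python
from typing import Any, Dict, List

from composer import Algorithm, algorithms


def build_algorithms(cfg: Dict[str, Dict[str, Any]]) -> List[Algorithm]:
    built: List[Algorithm] = []
    for name, kwargs in cfg.items():
        if name == "gradient_clipping":
            built.append(algorithms.GradientClipping(**kwargs))
        else:
            raise ValueError(f"Not sure how to build algorithm: {name}")
    return built


# Mirrors the `algorithms` section of configs/1.2b-c4.yaml.
print(build_algorithms({"gradient_clipping": {"clipping_type": "norm", "clipping_threshold": 1.0}}))
```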
