Merge pull request #25 from YerevaNN/model_loading
Model loading
tigranfah authored Sep 24, 2024
2 parents 379be76 + 21970dc commit 7eb6c33
Showing 17 changed files with 251 additions and 64 deletions.
13 changes: 8 additions & 5 deletions submitit_train.py
@@ -6,19 +6,22 @@

if __name__ == "__main__":
executor = submitit.AutoExecutor(folder="~/slurm_jobs/titan/job_%j")
n_gpus = 4
n_gpus = 8
executor.update_parameters(
name="titan", timeout_min=15,
name="titan", timeout_min=3 * 24 * 60,
gpus_per_node=n_gpus,
nodes=1, mem_gb=40, cpus_per_task=n_gpus * 2
nodes=1, mem_gb=80, cpus_per_task=n_gpus * 4,
slurm_additional_parameters={
"partition": "h100"
}
)

jobs = []
with executor.batch():
for _ in range(1):
# train_config = './train_configs/chemlactica_125m.toml'
train_config = './train_configs/chemlactica_125m.toml'
# train_config = './train_configs/chemlactica_1.3b.toml'
train_config = './train_configs/llama3_8b.toml'
# train_config = './train_configs/llama3_8b.toml'
# train_config = './train_configs/debug_model.toml'
function = submitit.helpers.CommandFunction([
'python3', '-m', 'torch.distributed.run',
3 changes: 2 additions & 1 deletion torchtitan/checkpoint.py
@@ -171,6 +171,7 @@ def __init__(
lr_schedulers: List[torch.optim.lr_scheduler.LRScheduler],
states: Dict[str, Any],
job_config: JobConfig,
experiment_hash: str,
) -> None:
ckpt_config = job_config.checkpoint
self.enable_checkpoint = ckpt_config.enable_checkpoint
@@ -235,7 +236,7 @@ def __init__(
for idx, lr_scheduler in enumerate(lr_schedulers):
self.states[f"lr_scheduler_{idx}"] = lr_scheduler

self.save_folder = os.path.join(job_config.job.dump_folder, ckpt_config.save_folder)
self.save_folder = os.path.join(job_config.job.dump_folder, os.path.join(ckpt_config.save_folder, experiment_hash))
self.load_folder = os.path.join(job_config.job.dump_folder, ckpt_config.load_folder)
self.interval_type = (
IntervalType.SECONDS
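For reference, a minimal sketch of how the checkpoint save path is now composed, assuming a hypothetical dump folder of "outputs" and an illustrative Aim run hash; the real values come from the job config and the metric logger:

import os

dump_folder = "outputs"                    # --job.dump_folder (assumed value)
save_folder = "yerevann/chemlactica-1.3b"  # --checkpoint.save_folder (from the TOML below)
experiment_hash = "a1b2c3d4"               # metric_logger.experiment_hash (illustrative)

# Same composition as in CheckpointManager.__init__ above:
save_path = os.path.join(dump_folder, os.path.join(save_folder, experiment_hash))
print(save_path)  # outputs/yerevann/chemlactica-1.3b/a1b2c3d4

Nesting the run hash under the save folder keeps checkpoints from different runs from overwriting each other.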
22 changes: 15 additions & 7 deletions torchtitan/config_manager.py
@@ -392,12 +392,20 @@ def __init__(self):
help="Whether to enable checkpoint",
)
self.parser.add_argument(
"--checkpoint.folder",
"--checkpoint.load_folder",
type=str,
default="",
help="""
The folder to load the checkpoints from.
When enable_checkpoint is set to true, checkpoints will be loaded from {--job.dump_folder}/{--checkpoint.load_folder}.
""",
)
self.parser.add_argument(
"--checkpoint.save_folder",
type=str,
default="checkpoint",
help="""
The folder to store the checkpoints.
When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
When enable_checkpoint is set to true, checkpoints will be saved to {--job.dump_folder}/{--checkpoint.save_folder}.
""",
)
self.parser.add_argument(
@@ -643,13 +651,13 @@ def parse_args(self, args_list: list = sys.argv[1:]):
args, cmd_args = self.parse_args_from_command_line(args_list)
config_file = getattr(args, "job.config_file", None)
# build up a two level dict
args_dict = self._args_to_two_level_dict(args)
self.args_dict = self._args_to_two_level_dict(args)
if config_file is not None:
try:
with open(config_file, "rb") as f:
for k, v in tomllib.load(f).items():
# to prevent overwrite of non-specified keys
args_dict[k] |= v
self.args_dict[k] |= v
except (FileNotFoundError, tomllib.TOMLDecodeError) as e:
logger.exception(
f"Error while loading the configuration file: {config_file}"
@@ -661,9 +669,9 @@ def parse_args(self, args_list: list = sys.argv[1:]):
cmd_args_dict = self._args_to_two_level_dict(cmd_args)
for section, section_args in cmd_args_dict.items():
for k, v in section_args.items():
args_dict[section][k] = v
self.args_dict[section][k] = v

for k, v in args_dict.items():
for k, v in self.args_dict.items():
class_type = type(k.title(), (), v)
setattr(self, k, class_type())
self._validate_config()
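A hedged sketch of why the merged dict is now kept on the instance as self.args_dict: train.py (further down) passes job_config.args_dict to metric_logger.log_hparams, so the two-level dict has to survive parse_args. The config path and key lookup below are illustrative only:

from torchtitan.config_manager import JobConfig

config = JobConfig()
config.parse_args(["--job.config_file", "./train_configs/chemlactica_125m.toml"])

# One nested dict per config section, merged from argparse defaults, the TOML
# file and command-line overrides; train.py logs this whole dict as hparams.
print(config.args_dict["checkpoint"]["save_folder"])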
18 changes: 12 additions & 6 deletions torchtitan/datasets/hf_datasets.py
@@ -6,6 +6,9 @@

import pickle
from typing import Any, Dict, List, Optional
from pathlib import Path
import glob
import os

import numpy as np

@@ -33,7 +36,8 @@
_supported_datasets = {
"c4_test": "test/assets/c4_test",
"c4": "allenai/c4",
"chemlactica_train_mini": "test/assets/chemlactica_train_mini"
"chemlactica_train_mini": "test/assets/chemlactica_train_mini",
"chemlactica_train": "/nfs/dgx/raid/chem/data/rdkit_computed_rel+form/train_rdkit_computed_rel+form"
}

_supported_data_processing_styles = {
@@ -111,13 +115,16 @@ def __init__(
# c4 is huge, and requires both streaming and language selection
# (we default to en)
ds = load_dataset(dataset_path, name="en", split="train", streaming=True)
else:
elif dataset_name == "c4_test":
ds = load_dataset(dataset_path, split="train")

else:
dataset_files = glob.glob(os.path.join(dataset_path, "*.jsonl"))
ds = load_dataset("text", data_files=dataset_files, split="train", streaming=True)
try:
data_processing_fn = _supported_data_processing_styles[data_processing_style]
except KeyError as e:
raise ValueError(f"Unsupported data processing style: {data_processing_style}")
# data_processing_fn = lambda x, e: str(x)

# TODO: support shuffling and checkpointing
self.dataset_name = dataset_name
@@ -217,9 +224,8 @@ class DPAwareDataLoader(StatefulDataLoader, Stateful):
"""
A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank.
"""

def __init__(self, dp_rank: int, hf_ds: IterableDataset, batch_size: int, pin_memory: bool, num_workers: int):
super().__init__(hf_ds, batch_size)
super().__init__(hf_ds, batch_size, num_workers=num_workers)
self._dp_rank = dp_rank
self._rank_id = f"dp_rank_{dp_rank}"

@@ -251,7 +257,7 @@ def build_hf_data_loader(
rank,
infinite: bool = True,
pin_memory: bool = False,
num_workers: int = 0,
num_workers: int = 2,
special_mode = None,
context = "train",
):
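A hedged, standalone sketch of the new fallback branch: any dataset that is not c4 or c4_test is treated as a local directory of *.jsonl shards and streamed through the generic Hugging Face "text" loader, which yields one {"text": raw_line} record per line. The path is the "chemlactica_train" entry added above:

import glob
import os

from datasets import load_dataset

dataset_path = "/nfs/dgx/raid/chem/data/rdkit_computed_rel+form/train_rdkit_computed_rel+form"
dataset_files = glob.glob(os.path.join(dataset_path, "*.jsonl"))

ds = load_dataset("text", data_files=dataset_files, split="train", streaming=True)
sample = next(iter(ds))  # {"text": "<one raw JSON line from a shard>"}

Streaming keeps memory flat regardless of the number of shards; parsing each raw line is left to the data processing function (see dataset_utils.py below).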
6 changes: 6 additions & 0 deletions torchtitan/logging.py
@@ -27,6 +27,9 @@ def init_logger(log_level):
# suppress verbose torch.profiler logging
os.environ["KINETO_LOG_LEVEL"] = "5"

# enable dataloader logging so the type of dataloading being used is visible in the logs
enable_dataloader_logging(log_level)


class LogLevel(Enum):
DEBUG = "DEBUG"
@@ -46,3 +49,6 @@ def from_string(cls, value: str):
def validate_log_level(value):
return LogLevel.from_string(value)


def enable_dataloader_logging(log_level):
logging.getLogger('datasets.iterable_dataset').setLevel(log_level)
8 changes: 7 additions & 1 deletion torchtitan/metrics.py
@@ -116,6 +116,12 @@ def log_hparams(self, config):
if self.writer is not None:
self.writer.experiment['hparams'] = config

@property
def experiment_hash(self):
if self.writer is None:
return "default"
return self.writer._run.hash

def build_metric_logger(
job_config: JobConfig, parallel_dims: ParallelDims
):
@@ -127,7 +133,7 @@
"""
dump_dir = job_config.job.dump_folder
aim_config = job_config.metrics
save_aim_folder = aim_config.save_aim_folder
save_aim_folder = os.path.join(job_config.job.dump_folder, aim_config.save_aim_folder)
# since we don't have run id, use current minute as the identifier
datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
log_dir = os.path.join(dump_dir, datetime_str)
5 changes: 2 additions & 3 deletions torchtitan/models/opt/model.py
@@ -24,18 +24,17 @@ class ModelArgs:
n_heads: int = 12
n_kv_heads: Optional[int] = None
vocab_size: int = -1 # defined later by tokenizer
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
multiple_of: int = 256
ffn_dim_multiplier: Optional[float] = None
norm_eps: float = 1e-5
rope_theta: float = 10000
dropout_p: float = 0.1

max_batch_size: int = 32
max_seq_len: int = 2048
# If `True`, then each transformer block init uses its layer ID, and if
# `False`, each uses the total number of transformer blocks
depth_init: bool = True
norm_type: str = "layersnorm"
norm_type: str = "layernorm_bias"


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
2 changes: 1 addition & 1 deletion torchtitan/models/opt/utils.py
@@ -73,7 +73,7 @@ def export_opt_weights(model: OPT, save_dir: str, token_embedding_size: int):
"""
write docs
"""
hf_model = OPTForCausalLM.from_pretrained(map_n_layers_to_model_name(model.n_layers))
hf_model = OPTForCausalLM.from_pretrained(map_n_layers_to_model_name(model.n_layers), tie_word_embeddings=False)
hf_model.resize_token_embeddings(new_num_tokens=token_embedding_size)
keys_mapping = get_hf_opt_state_dict_keys_mapping(model.n_layers)
state_dict = model.state_dict()
3 changes: 3 additions & 0 deletions torchtitan/tokenizers/tokenizer/custom.py
@@ -7,10 +7,13 @@
# copied and adjusted from https://github.com/facebookresearch/llama/blob/main/llama/tokenizer.py

from typing import List
import os

from torchtitan.logging import logger
from transformers import AutoTokenizer

os.environ["TOKENIZERS_PARALLELISM"] = "true"  # env var read by Hugging Face fast tokenizers


class CustomTokenizer:
"""
3 changes: 2 additions & 1 deletion torchtitan/utils/dataset_utils.py
@@ -32,13 +32,14 @@ def load_jsonl_line(jsonl_line):

def chemlactica_style_data_processing(sample_json, rng):
try:
sample_json = json.loads(sample_json["text"])
compound = delete_empty_tags(sample_json)
sample_json = generate_formatted_string(
compound, rng
)
except Exception as e:
print(e)
sample_json = {}
sample_json = ""
return sample_json


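An illustration of why the json.loads step was added: the streaming "text" loader wraps every shard line as {"text": ...}, so the processing function now parses the JSON itself and falls back to an empty string on failure. The sample content below is made up:

import json

raw_line = '{"SMILES": "CCO", "CID": 702}'  # hypothetical line from a *.jsonl shard
sample = {"text": raw_line}                 # what the streaming "text" dataset yields

compound_json = json.loads(sample["text"])  # the dict handed to delete_empty_tags()

# On any parsing or formatting error the function now returns "" (an empty
# text sample) rather than an empty dict, so downstream code always gets a string.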
20 changes: 11 additions & 9 deletions train.py
@@ -7,6 +7,7 @@
import contextlib
import os
import time
import logging
from datetime import timedelta

import torch
@@ -186,6 +187,9 @@ def loss_fn(pred, labels):

train_state = TrainState()

metric_logger = build_metric_logger(job_config, parallel_dims)
metric_logger.log_hparams(job_config.args_dict)

# load initial checkpoint
checkpoint = CheckpointManager(
dataloader=data_loader,
@@ -194,6 +198,7 @@ def loss_fn(pred, labels):
lr_schedulers=lr_schedulers.schedulers,
states={"train_state": train_state},
job_config=job_config,
experiment_hash=metric_logger.experiment_hash
)

if job_config.model_download_export.to_titan:
Expand All @@ -218,11 +223,6 @@ def loss_fn(pred, labels):
logger.info("Created huggingface checkpoint")
return

metric_logger = build_metric_logger(job_config, parallel_dims)
args, cmd_args = job_config.parse_args_from_command_line(job_config.args_list)
job_config_dict = job_config._args_to_two_level_dict(args)
metric_logger.log_hparams(job_config_dict)

data_iterator = iter(data_loader)

train_context = get_train_context(
@@ -284,12 +284,14 @@ def loss_fn(pred, labels):
# need to free to before bwd to avoid peaking memory
del pred
loss.backward()

for m in model_parts:
torch.nn.utils.clip_grad_norm_(
m.parameters(), job_config.training.max_norm, foreach=True
)

if force_finish_train:
break
for m in model_parts:
torch.nn.utils.clip_grad_norm_(
m.parameters(), job_config.training.max_norm, foreach=True
)

# sync float8 amaxes and scales
float8_handler.sync_float8_amax_and_scale_history(model_parts)
31 changes: 17 additions & 14 deletions train_configs/chemlactica_1.3b.toml
@@ -15,49 +15,52 @@ save_memory_snapshot_folder = "memory_snapshot"
[metrics]
log_freq = 1
enable_color_printing = true
enable_tensorboard = true
save_tb_folder = "tb"
enable_aim = true
save_aim_folder = "aim"

[model]
name = "opt"
flavor = "1.3B"
# norm_type = "layernorm_bias" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
norm_type = "rmsnorm"
norm_type = "layernorm_bias" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
# test tokenizer.model, for debug purpose only
tokenizer_path = "./test/assets/test_tiktoken.model"
# tokenizer_path = "./test/assets/test_tiktoken.model"
tokenizer_path = "./torchtitan/tokenizers/chemlactica-125m"

[optimizer]
name = "AdamW"
lr = 8e-4
lr = 1.0e-4

[training]
batch_size = 10
batch_size = 13
gradient_accumulation_steps = 9
seq_len = 2048
warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
warmup_steps = 500 # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 50
steps = 18000
data_parallel_degree = -1
tensor_parallel_degree = 1
compile = false
dataset = "c4" # supported datasets: c4_test (2K), c4 (177M)
compile = true
# dataset = "c4" # supported datasets: c4_test (2K), c4 (177M)
# dataset = "chemlactica_train_mini" # supported datasets: c4_test (2K), c4 (177M), chemlactica_train_mini (4K)
dataset = "chemlactica_train"
data_process_style="chemlactica_style"

[experimental]
pipeline_parallel_degree = 1
enable_async_tensor_parallel = false

[checkpoint]
enable_checkpoint = true
create_seed_checkpoint = false
load_folder = "facebook/galactica-1.3b"
save_folder = "yerevann/chemlactica-1.3b"
interval_type = "steps"
interval = 100
interval = 2000
model_weights_only = false
export_dtype = "float32"
async_mode = "async_with_pinned_mem" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective' # ['none', 'selective', 'full']
mode = 'none' # ['none', 'selective', 'full']
selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy

[float8]