From 319fe5b64703af415cd14966e733de3611248fbc Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Wed, 17 Jan 2024 11:14:55 -0800
Subject: [PATCH 1/9] allow specifying LR schedule in terms of tokens

---
 olmo/config.py |  6 +++++
 olmo/train.py  | 67 +++++++++++++++++++++++++++++++++++++++++++-------
 2 files changed, 64 insertions(+), 9 deletions(-)

diff --git a/olmo/config.py b/olmo/config.py
index b4b0576f9..e530754b4 100644
--- a/olmo/config.py
+++ b/olmo/config.py
@@ -480,9 +480,15 @@ class SchedulerType(StrEnum):
     constant = "constant"
 
 
+class SchedulerUnits(StrEnum):
+    steps = "steps"
+    tokens = "tokens"
+
+
 @dataclass
 class SchedulerConfig(BaseConfig):
     name: SchedulerType = SchedulerType.cosine_with_warmup
+    units: SchedulerUnits = SchedulerUnits.steps
     t_warmup: int = 100
     t_max: Optional[int] = None
     alpha_f: float = 0.1
diff --git a/olmo/train.py b/olmo/train.py
index 2207b0552..42250f919 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -26,6 +26,7 @@
 from .checkpoint import Checkpointer, FullCheckpointer, build_sharded_checkpointer
 from .config import (
     CheckpointType,
+    SchedulerUnits,
     ShardedCheckpointerType,
     SpeedMonitorConfig,
     TrainConfig,
@@ -122,6 +123,14 @@ def dataset(self) -> IterableDataset:
         assert isinstance(self.train_loader.dataset, IterableDataset)
         return self.train_loader.dataset
 
+    @property
+    def tokens_per_batch(self) -> int:
+        return self.cfg.global_train_batch_size * self.cfg.model.max_sequence_length
+
+    @property
+    def batches_per_epoch(self) -> int:
+        return self.dataset.total_size // self.cfg.global_train_batch_size
+
     @property
     def max_epochs(self) -> int:
         if isinstance(self.cfg.max_duration, str) and self.cfg.max_duration.endswith("ep"):
@@ -138,20 +147,58 @@ def max_steps(self) -> int:
             # convert to float *first* to handle scientific notation
             max_tokens = int(float(self.cfg.max_duration[:-1].strip()))
             tokens_remaining = max_tokens - self.global_train_tokens_seen
-            tokens_per_batch = self.cfg.global_train_batch_size * self.cfg.model.max_sequence_length
-            steps_remaining = tokens_remaining // tokens_per_batch
+            steps_remaining = tokens_remaining // self.tokens_per_batch
             return self.global_step + steps_remaining
         elif self.cfg.max_duration.endswith("ep"):
             max_epochs = int(self.cfg.max_duration[:-2].strip())
-            examples_per_epoch = self.dataset.total_size
-            steps_per_epoch = examples_per_epoch // self.cfg.global_train_batch_size
-            return max_epochs * steps_per_epoch
+            return max_epochs * self.batches_per_epoch
         else:
             # convert to float *first* to handle scientific notation
             return int(float(self.cfg.max_duration))
     else:
         raise TypeError(f"expected int or str for 'max_duration', found {type(self.cfg.max_duration)}")
 
+    @property
+    def max_tokens(self) -> int:
+        if isinstance(self.cfg.max_duration, int):
+            return (
+                self.global_train_tokens_seen
+                + min(self.cfg.max_duration - self.global_step, 0) * self.tokens_per_batch
+            )
+        elif isinstance(self.cfg.max_duration, str):
+            if self.cfg.max_duration.endswith("T"):
+                # convert to float *first* to handle scientific notation
+                return int(float(self.cfg.max_duration[:-1].strip()))
+            elif self.cfg.max_duration.endswith("ep"):
+                max_epochs = int(self.cfg.max_duration[:-2].strip())
+                return max_epochs * self.batches_per_epoch * self.tokens_per_batch
+            else:
+                # convert to float *first* to handle scientific notation
+                return (
+                    self.global_train_tokens_seen
+                    + min(int(float(self.cfg.max_duration)) - self.global_step, 0) * self.tokens_per_batch
+                )
+        else:
+            raise TypeError(f"expected int or str for 'max_duration', found {type(self.cfg.max_duration)}")
+
+    @property
+    def scheduler_current(self) -> int:
+        if self.cfg.scheduler.units == SchedulerUnits.steps:
+            return self.global_step
+        elif self.cfg.scheduler.units == SchedulerUnits.tokens:
+            return self.global_train_tokens_seen
+        else:
+            raise NotImplementedError(self.cfg.scheduler.units)
+
+    @property
+    def scheduler_max(self) -> int:
+        if self.cfg.scheduler.units == SchedulerUnits.steps:
+            return self.max_steps
+        elif self.cfg.scheduler.units == SchedulerUnits.tokens:
+            return self.max_tokens
+        else:
+            raise NotImplementedError(self.cfg.scheduler.units)
+
     def trainer_state_dict(self) -> Dict[str, Any]:
         return {
             "epoch": self.epoch,
@@ -233,7 +280,7 @@ def load_trainer_state_dict(self, state_dict: Dict[str, Any]) -> None:
             # Reset learning rate and weight decay to the values from the config, not the checkpoint.
             log.info("Resetting learning rate...")
             new_learning_rate = self.scheduler.get_lr(
-                self.cfg.optimizer.learning_rate, self.global_step, self.max_steps
+                self.cfg.optimizer.learning_rate, self.scheduler_current, self.scheduler_max
             )
             for group in self.optim.param_groups:
                 group["lr"] = new_learning_rate
@@ -572,12 +619,14 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
             # TODO (epwalsh): if we want to enable different LRs or gradient clipping settings per group
             # we should pass `group["initial_lr"]` or `group["initial_max_grad_norm"]` here instead of
             # the corresponding values from `self.cfg`.
-            group["lr"] = self.scheduler.get_lr(self.cfg.optimizer.learning_rate, self.global_step, self.max_steps)
+            group["lr"] = self.scheduler.get_lr(
+                self.cfg.optimizer.learning_rate, self.scheduler_current, self.scheduler_max
+            )
             group["max_grad_norm"] = self.scheduler.get_max_grad_norm(
-                self.cfg.max_grad_norm, self.global_step, self.max_steps
+                self.cfg.max_grad_norm, self.scheduler_current, self.scheduler_max
             )
             group["max_grad_norm_ratio"] = self.scheduler.get_max_grad_norm(
-                self.cfg.max_grad_norm_ratio, self.global_step, self.max_steps
+                self.cfg.max_grad_norm_ratio, self.scheduler_current, self.scheduler_max
             )
 
             # Optimizer step.
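Note: the heart of PATCH 1 is that the scheduler now consumes a generic (current, max) progress pair whose unit system is selected by `scheduler.units`, rather than raw step counts. A minimal standalone sketch of that dispatch, for readers skimming the diff (the function name and numbers below are illustrative, not part of the repo):

    from typing import Tuple

    def scheduler_progress(
        units: str, global_step: int, tokens_seen: int, max_steps: int, max_tokens: int
    ) -> Tuple[int, int]:
        # Mirrors the scheduler_current / scheduler_max properties above: the
        # scheduler always sees (current, max) expressed in a single unit system.
        if units == "steps":
            return global_step, max_steps
        if units == "tokens":
            return tokens_seen, max_tokens
        raise NotImplementedError(units)

    # With global_train_batch_size=2048 and max_sequence_length=2048, each
    # optimizer step advances token progress by 2048 * 2048 = 4,194,304 tokens.
    current, maximum = scheduler_progress("tokens", 10, 10 * 2048 * 2048, 5000, int(2e10))

One practical payoff: with `units: tokens`, the LR schedule stays consistent across a restart that changes the global batch size, which appears to be exactly what the test configs and PATCH 7 below exercise.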
From 9d9e5a7885e9d2e202b60a6b6a5f85cb48353005 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Wed, 17 Jan 2024 11:26:28 -0800
Subject: [PATCH 2/9] Fix

---
 olmo/train.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/olmo/train.py b/olmo/train.py
index 42250f919..d9710f453 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -146,7 +146,7 @@ def max_steps(self) -> int:
         if self.cfg.max_duration.endswith("T"):
             # convert to float *first* to handle scientific notation
             max_tokens = int(float(self.cfg.max_duration[:-1].strip()))
-            tokens_remaining = max_tokens - self.global_train_tokens_seen
+            tokens_remaining = max(max_tokens - self.global_train_tokens_seen, 0)
             steps_remaining = tokens_remaining // self.tokens_per_batch
             return self.global_step + steps_remaining
         elif self.cfg.max_duration.endswith("ep"):
@@ -163,7 +163,7 @@ def max_tokens(self) -> int:
         if isinstance(self.cfg.max_duration, int):
             return (
                 self.global_train_tokens_seen
-                + min(self.cfg.max_duration - self.global_step, 0) * self.tokens_per_batch
+                + max(self.cfg.max_duration - self.global_step, 0) * self.tokens_per_batch
             )
         elif isinstance(self.cfg.max_duration, str):
             if self.cfg.max_duration.endswith("T"):
@@ -176,7 +176,7 @@ def max_tokens(self) -> int:
                 # convert to float *first* to handle scientific notation
                 return (
                     self.global_train_tokens_seen
-                    + min(int(float(self.cfg.max_duration)) - self.global_step, 0) * self.tokens_per_batch
+                    + max(int(float(self.cfg.max_duration)) - self.global_step, 0) * self.tokens_per_batch
                 )
         else:
             raise TypeError(f"expected int or str for 'max_duration', found {type(self.cfg.max_duration)}")

From a8f3f82390d1e37f8c7a035dba4e435a6a26f76e Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Wed, 17 Jan 2024 15:55:10 -0800
Subject: [PATCH 3/9] add configs for test

---
 configs/mcli/small-test.yaml |  31 +++++++
 configs/small-test-s3.yaml   | 171 +++++++++++++++++++++++++++++++++++
 2 files changed, 202 insertions(+)
 create mode 100644 configs/mcli/small-test.yaml
 create mode 100644 configs/small-test-s3.yaml

diff --git a/configs/mcli/small-test.yaml b/configs/mcli/small-test.yaml
new file mode 100644
index 000000000..f4bed4267
--- /dev/null
+++ b/configs/mcli/small-test.yaml
@@ -0,0 +1,31 @@
+name: olmo-small-test
+image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
+compute:
+  #cluster: r12z3
+  cluster: r7z2
+  gpus: 8
+  gpu_type: a100_40gb
+integrations:
+  - integration_type: git_repo
+    git_repo: allenai/LLM
+    git_branch: epwalsh/lr-schedule-tokens
+    pip_install: -e .
+    ssh_clone: true
+command: |-
+  run_name=mitchish-small-test
+
+  cd LLM
+
+  pip freeze
+
+  # Prepare environment including AWS config files for both S3 and R2 access.
+  mkdir -p /root/.cache/torch
+
+  torchrun \
+    --master_addr "$MASTER_ADDR" \
+    --master_port "$MASTER_PORT" \
+    --nnodes "$NUM_NODES" \
+    --node_rank "$NODE_RANK" \
+    --nproc_per_node 8 \
+    scripts/train.py configs/small-test-s3.yaml \
+      --run_name=${run_name}
diff --git a/configs/small-test-s3.yaml b/configs/small-test-s3.yaml
new file mode 100644
index 000000000..2c37d4805
--- /dev/null
+++ b/configs/small-test-s3.yaml
@@ -0,0 +1,171 @@
+run_name: small-test
+seed: 6198
+dry_run: false
+
+wandb:
+  name: ${run_name}
+  project: olmo-small-test
+
+model:
+  d_model: 1024
+  n_heads: 16
+  n_layers: 12
+  mlp_ratio: 8
+  weight_tying: false
+  alibi: false
+  rope: true
+  flash_attention: true
+  attention_dropout: 0.0
+  attention_layer_norm: false
+  multi_query_attention: false
+  include_bias: false
+  block_type: sequential
+  layer_norm_type: default
+  layer_norm_with_affine: false
+  bias_for_layer_norm: false
+  attention_layer_norm_with_affine: false
+  activation_type: swiglu
+  residual_dropout: 0.0
+  embedding_dropout: 0.0
+  max_sequence_length: 2048
+  vocab_size: 50280
+  embedding_size: 50304
+  eos_token_id: 0
+  pad_token_id: 1
+  init_device: meta
+  init_fn: mitchell
+
+compile: null
+
+optimizer:
+  name: adamw
+  learning_rate: 1.0e-3
+  weight_decay: 0.1
+  betas:
+    - 0.9
+    - 0.95
+  metrics_log_interval: 10
+
+scheduler:
+  name: linear_with_warmup
+  units: tokens
+  t_warmup: 1e9
+  alpha_f: 0.1
+  grad_clip_warmup_steps: 1e9
+  grad_clip_warmup_factor: 10.0
+
+tokenizer:
+  identifier: tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json
+  truncate_direction: right
+
+save_folder: runs/${run_name}
+remote_save_folder: s3://ai2-llm/checkpoints/small-test/${run_name}
+save_overwrite: true
+# Sharded checkpoints (best for restarts)
+save_interval: 200
+save_num_checkpoints_to_keep: -1
+# Unsharded checkpoints (for final storage)
+save_interval_unsharded: null # getting errors on LUMI right now
+save_num_unsharded_checkpoints_to_keep: -1
+
+load_path: null
+
+max_duration: 2e10T # 20B tokens
+global_train_batch_size: 2048
+device_train_microbatch_size: 16
+time_limit: null
+
+precision: amp_bf16
+
+fsdp:
+  wrapping_strategy: by_block
+  precision: mixed
+
+max_grad_norm: 1.0
+max_grad_norm_ratio: null
+
+speed_monitor:
+  window_size: 20
+
+eval_interval: ${save_interval}
+eval_subset_num_batches: -1
+device_eval_batch_size: ${device_train_microbatch_size}
+evaluators: []
+
+data:
+  pad_direction: right
+  num_workers: 16
+  drop_last: true
+  pin_memory: true
+  prefetch_factor: 1
+  persistent_workers: true
+  timeout: 0
+  paths:
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-000-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-000-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-001-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-001-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-002-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-002-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-003-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-003-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-004-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-004-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-005-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-005-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-006-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-006-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-006-00002.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-007-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-007-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-008-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-008-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-008-00002.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-009-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-009-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-010-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-010-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-010-00002.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-011-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-011-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-012-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-012-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-013-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-013-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-013-00002.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-014-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-014-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-014-00002.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-015-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-015-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-016-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-016-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-017-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-017-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-018-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-018-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-019-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-019-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-020-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-020-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-021-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-021-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-022-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-022-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-023-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-023-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-024-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-024-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-025-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-025-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-025-00002.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-026-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-026-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-027-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-027-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-027-00002.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-028-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-028-00001.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-028-00002.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-029-00000.npy
+    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-029-00001.npy

From 5b39657f51d89535425774bc904c1bb00901290a Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Wed, 17 Jan 2024 16:12:30 -0800
Subject: [PATCH 4/9] allow float types

---
 configs/mcli/small-test.yaml | 12 ++++--------
 olmo/config.py               |  6 +++---
 2 files changed, 7 insertions(+), 11 deletions(-)

diff --git a/configs/mcli/small-test.yaml b/configs/mcli/small-test.yaml
index f4bed4267..225089cf5 100644
--- a/configs/mcli/small-test.yaml
+++ b/configs/mcli/small-test.yaml
@@ -18,14 +18,10 @@ command: |-
 
   pip freeze
 
-  # Prepare environment including AWS config files for both S3 and R2 access.
+  # Prepare environment.
   mkdir -p /root/.cache/torch
+  export OMP_NUM_THREADS=8
+  export LOG_FILTER_TYPE=local_rank0_only
 
-  torchrun \
-    --master_addr "$MASTER_ADDR" \
-    --master_port "$MASTER_PORT" \
-    --nnodes "$NUM_NODES" \
-    --node_rank "$NODE_RANK" \
-    --nproc_per_node 8 \
-    scripts/train.py configs/small-test-s3.yaml \
+  torchrun --nproc_per_node 8 scripts/train.py configs/small-test-s3.yaml \
       --run_name=${run_name}
diff --git a/olmo/config.py b/olmo/config.py
index e530754b4..f3768d18d 100644
--- a/olmo/config.py
+++ b/olmo/config.py
@@ -489,11 +489,11 @@ class SchedulerUnits(StrEnum):
 class SchedulerConfig(BaseConfig):
     name: SchedulerType = SchedulerType.cosine_with_warmup
     units: SchedulerUnits = SchedulerUnits.steps
-    t_warmup: int = 100
-    t_max: Optional[int] = None
+    t_warmup: Union[int, float] = 100
+    t_max: Optional[Union[int, float]] = None
     alpha_f: float = 0.1
 
-    grad_clip_warmup_steps: Optional[int] = None
+    grad_clip_warmup_steps: Optional[Union[int, float]] = None
     """
     The warmup period for which the max grad norm (or norm ratio) will be set to its
     warmup value of `max_grad_norm * grad_clip_warmup_factor`.

From e2825f15337b5ad8acac94a1feea2f8bc81d1d73 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Wed, 17 Jan 2024 16:17:02 -0800
Subject: [PATCH 5/9] update config

---
 configs/small-test-s3.yaml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/configs/small-test-s3.yaml b/configs/small-test-s3.yaml
index 2c37d4805..82bf57ee4 100644
--- a/configs/small-test-s3.yaml
+++ b/configs/small-test-s3.yaml
@@ -7,9 +7,9 @@ wandb:
   project: olmo-small-test
 
 model:
-  d_model: 1024
+  d_model: 512
   n_heads: 16
-  n_layers: 12
+  n_layers: 8
   mlp_ratio: 8
   weight_tying: false
   alibi: false

From 409ff321d5124f762372872a605d18a8c7100eae Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Wed, 17 Jan 2024 16:25:15 -0800
Subject: [PATCH 6/9] update

---
 configs/small-test-s3.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/configs/small-test-s3.yaml b/configs/small-test-s3.yaml
index 82bf57ee4..d471c3361 100644
--- a/configs/small-test-s3.yaml
+++ b/configs/small-test-s3.yaml
@@ -10,7 +10,7 @@ model:
   d_model: 512
   n_heads: 16
   n_layers: 8
-  mlp_ratio: 8
+  mlp_ratio: 6
   weight_tying: false
   alibi: false
   rope: true

From 0ed8dd25abccadb57b55e0394bd1e32ff24ae5d8 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Wed, 17 Jan 2024 20:08:26 -0800
Subject: [PATCH 7/9] update

---
 configs/mcli/small-test.yaml | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/configs/mcli/small-test.yaml b/configs/mcli/small-test.yaml
index 225089cf5..02d8f5e50 100644
--- a/configs/mcli/small-test.yaml
+++ b/configs/mcli/small-test.yaml
@@ -24,4 +24,6 @@ command: |-
   export LOG_FILTER_TYPE=local_rank0_only
 
   torchrun --nproc_per_node 8 scripts/train.py configs/small-test-s3.yaml \
-    --run_name=${run_name}
+    --run_name=${run_name} \
+    --load_path=s3://ai2-llm/checkpoints/small-test/${run_name}/step1000 \
+    --global_train_batch_size=4096

From 9da6dbdf6dd52932d07d88b9f7041039726c85f5 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Wed, 17 Jan 2024 20:08:45 -0800
Subject: [PATCH 8/9] clean up

---
 configs/mcli/small-test.yaml |  29 ------
 configs/small-test-s3.yaml   | 171 -----------------------------------
 2 files changed, 200 deletions(-)
 delete mode 100644 configs/mcli/small-test.yaml
 delete mode 100644 configs/small-test-s3.yaml

diff --git a/configs/mcli/small-test.yaml b/configs/mcli/small-test.yaml
deleted file mode 100644
index 02d8f5e50..000000000
--- a/configs/mcli/small-test.yaml
+++ /dev/null
@@ -1,29 +0,0 @@
-name: olmo-small-test
-image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
-compute:
-  #cluster: r12z3
-  cluster: r7z2
-  gpus: 8
-  gpu_type: a100_40gb
-integrations:
-  - integration_type: git_repo
-    git_repo: allenai/LLM
-    git_branch: epwalsh/lr-schedule-tokens
-    pip_install: -e .
-    ssh_clone: true
-command: |-
-  run_name=mitchish-small-test
-
-  cd LLM
-
-  pip freeze
-
-  # Prepare environment.
-  mkdir -p /root/.cache/torch
-  export OMP_NUM_THREADS=8
-  export LOG_FILTER_TYPE=local_rank0_only
-
-  torchrun --nproc_per_node 8 scripts/train.py configs/small-test-s3.yaml \
-    --run_name=${run_name} \
-    --load_path=s3://ai2-llm/checkpoints/small-test/${run_name}/step1000 \
-    --global_train_batch_size=4096
diff --git a/configs/small-test-s3.yaml b/configs/small-test-s3.yaml
deleted file mode 100644
index d471c3361..000000000
--- a/configs/small-test-s3.yaml
+++ /dev/null
@@ -1,171 +0,0 @@
-run_name: small-test
-seed: 6198
-dry_run: false
-
-wandb:
-  name: ${run_name}
-  project: olmo-small-test
-
-model:
-  d_model: 512
-  n_heads: 16
-  n_layers: 8
-  mlp_ratio: 6
-  weight_tying: false
-  alibi: false
-  rope: true
-  flash_attention: true
-  attention_dropout: 0.0
-  attention_layer_norm: false
-  multi_query_attention: false
-  include_bias: false
-  block_type: sequential
-  layer_norm_type: default
-  layer_norm_with_affine: false
-  bias_for_layer_norm: false
-  attention_layer_norm_with_affine: false
-  activation_type: swiglu
-  residual_dropout: 0.0
-  embedding_dropout: 0.0
-  max_sequence_length: 2048
-  vocab_size: 50280
-  embedding_size: 50304
-  eos_token_id: 0
-  pad_token_id: 1
-  init_device: meta
-  init_fn: mitchell
-
-compile: null
-
-optimizer:
-  name: adamw
-  learning_rate: 1.0e-3
-  weight_decay: 0.1
-  betas:
-    - 0.9
-    - 0.95
-  metrics_log_interval: 10
-
-scheduler:
-  name: linear_with_warmup
-  units: tokens
-  t_warmup: 1e9
-  alpha_f: 0.1
-  grad_clip_warmup_steps: 1e9
-  grad_clip_warmup_factor: 10.0
-
-tokenizer:
-  identifier: tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json
-  truncate_direction: right
-
-save_folder: runs/${run_name}
-remote_save_folder: s3://ai2-llm/checkpoints/small-test/${run_name}
-save_overwrite: true
-# Sharded checkpoints (best for restarts)
-save_interval: 200
-save_num_checkpoints_to_keep: -1
-# Unsharded checkpoints (for final storage)
-save_interval_unsharded: null # getting errors on LUMI right now
-save_num_unsharded_checkpoints_to_keep: -1
-
-load_path: null
-
-max_duration: 2e10T # 20B tokens
-global_train_batch_size: 2048
-device_train_microbatch_size: 16
-time_limit: null
-
-precision: amp_bf16
-
-fsdp:
-  wrapping_strategy: by_block
-  precision: mixed
-
-max_grad_norm: 1.0
-max_grad_norm_ratio: null
-
-speed_monitor:
-  window_size: 20
-
-eval_interval: ${save_interval}
-eval_subset_num_batches: -1
-device_eval_batch_size: ${device_train_microbatch_size}
-evaluators: []
-
-data:
-  pad_direction: right
-  num_workers: 16
-  drop_last: true
-  pin_memory: true
-  prefetch_factor: 1
-  persistent_workers: true
-  timeout: 0
-  paths:
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-000-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-000-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-001-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-001-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-002-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-002-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-003-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-003-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-004-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-004-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-005-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-005-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-006-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-006-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-006-00002.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-007-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-007-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-008-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-008-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-008-00002.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-009-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-009-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-010-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-010-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-010-00002.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-011-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-011-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-012-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-012-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-013-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-013-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-013-00002.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-014-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-014-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-014-00002.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-015-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-015-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-016-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-016-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-017-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-017-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-018-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-018-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-019-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-019-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-020-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-020-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-021-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-021-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-022-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-022-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-023-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-023-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-024-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-024-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-025-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-025-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-025-00002.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-026-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-026-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-027-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-027-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-027-00002.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-028-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-028-00001.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-028-00002.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-029-00000.npy
-    - s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-029-00001.npy

From 9477cfa0a46241d9d772b6d2da3311a0e9ad5ba9 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Wed, 17 Jan 2024 20:13:09 -0800
Subject: [PATCH 9/9] fixes for mypy

---
 olmo/optim.py    | 30 ++++++++++++++++++++----------
 scripts/train.py |  2 +-
 2 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/olmo/optim.py b/olmo/optim.py
index 711e6d889..91d535e72 100644
--- a/olmo/optim.py
+++ b/olmo/optim.py
@@ -720,36 +720,46 @@ def build_scheduler(cfg: TrainConfig, sched_cfg: Optional[SchedulerConfig] = Non
     sched_cfg = sched_cfg if sched_cfg is not None else cfg.scheduler
     if sched_cfg.name == SchedulerType.cosine_with_warmup:
         return CosWithWarmup(
-            grad_clip_warmup_steps=sched_cfg.grad_clip_warmup_steps,
+            grad_clip_warmup_steps=None
+            if sched_cfg.grad_clip_warmup_steps is None
+            else int(sched_cfg.grad_clip_warmup_steps),
             grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
-            warmup_steps=sched_cfg.t_warmup,
+            warmup_steps=int(sched_cfg.t_warmup),
             alpha_f=sched_cfg.alpha_f,
-            t_max=sched_cfg.t_max,
+            t_max=None if sched_cfg.t_max is None else int(sched_cfg.t_max),
         )
elif sched_cfg.name == SchedulerType.linear_with_warmup: return LinearWithWarmup( - grad_clip_warmup_steps=sched_cfg.grad_clip_warmup_steps, + grad_clip_warmup_steps=None + if sched_cfg.grad_clip_warmup_steps is None + else int(sched_cfg.grad_clip_warmup_steps), grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor, - warmup_steps=sched_cfg.t_warmup, + warmup_steps=int(sched_cfg.t_warmup), alpha_f=sched_cfg.alpha_f, - t_max=sched_cfg.t_max, + t_max=None if sched_cfg.t_max is None else int(sched_cfg.t_max), ) elif sched_cfg.name == SchedulerType.inverse_sqrt_with_warmup: return InvSqrtWithWarmup( - grad_clip_warmup_steps=sched_cfg.grad_clip_warmup_steps, + grad_clip_warmup_steps=None + if sched_cfg.grad_clip_warmup_steps is None + else int(sched_cfg.grad_clip_warmup_steps), grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor, - warmup_steps=sched_cfg.t_warmup, + warmup_steps=int(sched_cfg.t_warmup), ) elif sched_cfg.name == SchedulerType.max_scheduler: return MaxScheduler( - grad_clip_warmup_steps=sched_cfg.grad_clip_warmup_steps, + grad_clip_warmup_steps=None + if sched_cfg.grad_clip_warmup_steps is None + else int(sched_cfg.grad_clip_warmup_steps), grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor, sched1=build_scheduler(cfg, replace(sched_cfg, name=SchedulerType.cosine_with_warmup)), sched2=build_scheduler(cfg, replace(sched_cfg, name=SchedulerType.inverse_sqrt_with_warmup)), ) elif sched_cfg.name == SchedulerType.constant: return ConstantScheduler( - grad_clip_warmup_steps=sched_cfg.grad_clip_warmup_steps, + grad_clip_warmup_steps=None + if sched_cfg.grad_clip_warmup_steps is None + else int(sched_cfg.grad_clip_warmup_steps), grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor, ) else: diff --git a/scripts/train.py b/scripts/train.py index 710cf0255..de97e31be 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -212,7 +212,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None: trainer.scheduler = BoltOnWarmupScheduler.wrap( trainer.scheduler, trainer.global_step, - trainer.global_step + cfg.scheduler.t_warmup, + int(trainer.global_step + cfg.scheduler.t_warmup), ) if cfg.force_save_unsharded:
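Note: the mypy fixes in this final patch all repeat one pattern. `t_warmup`, `t_max`, and `grad_clip_warmup_steps` may now be floats, so YAML scientific notation such as `t_warmup: 1e9` parses cleanly, but the scheduler classes still take ints, so every value is coerced at the call boundary. A sketch of that pattern factored into a helper (illustrative only; the patch inlines the conditionals instead):

    from typing import Optional, Union

    def as_int(value: Optional[Union[int, float]]) -> Optional[int]:
        # YAML scientific notation (e.g. 1e9) parses as a float; the scheduler
        # classes expect ints, so coerce at the boundary while preserving None.
        return None if value is None else int(value)

    # The Optional knobs (grad_clip_warmup_steps, t_max) need the None-preserving
    # form; required ones like t_warmup are coerced directly with int(...).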