diff --git a/configs/mcli/v1_5-mix-medium-mitch-ish.yaml b/configs/mcli/v1_5-mix-medium-mitch-ish.yaml
index 8f8e5c493..167356153 100644
--- a/configs/mcli/v1_5-mix-medium-mitch-ish.yaml
+++ b/configs/mcli/v1_5-mix-medium-mitch-ish.yaml
@@ -6,7 +6,8 @@ gpu_type: a100_40gb
 integrations:
   - integration_type: git_repo
     git_repo: allenai/LLM
-    git_branch: main  # make sure to update this!
+    # git_branch: mitchish
+    git_commit: 148ca062e7f1f7667d7fc0f4346e97467e66ce87
     pip_install: -e .
     ssh_clone: true
 command: |-
@@ -28,6 +29,18 @@ command: |-
     --nproc_per_node 8 \
     scripts/train.py configs/v1_5-mix-medium-mitch-ish-s3.yaml \
       --run_name=v1_5-mix-mitch-ish \
-      --wandb.name=v1_5-mix-mitch-ish-mcli \
+      --wandb.name=v1_5-mix-mitch-ish-mcli-final \
       --global_train_batch_size=2160 \
-      --load_path='${path.last_checkpoint:s3://ai2-llm/checkpoints/7b/v1_5-mix-mitch-ish}'
+      --time_limit=169200
+
+# We added these flags in order to get a final checkpoint where we decayed the LR down to ~0.
+#      --eval_interval=100 \
+#      --save_interval=500 \
+#      --load_path=s3://ai2-llm/checkpoints/7b/v1_5-mix-mitch-ish/step556000 \
+#      --remote_save_folder=s3://ai2-llm/checkpoints/7b/v1_5-mix-mitch-ish-final \
+#      --epoch=1 \
+#      --optimizer.learning_rate=0.000023 \
+#      --scheduler.t_warmup=556000 \
+#      --scheduler.t_max=557000 \
+#      --scheduler.alpha_f=0.001 \
+#      --stop_at=557001
diff --git a/configs/v1_5-mix-medium-mitch-ish.yaml b/configs/v1_5-mix-medium-mitch-ish.yaml
index 6518cbb00..59acb6bb5 100644
--- a/configs/v1_5-mix-medium-mitch-ish.yaml
+++ b/configs/v1_5-mix-medium-mitch-ish.yaml
@@ -49,7 +49,7 @@ optimizer:
   metrics_log_interval: 10
 
 scheduler:
-  name: cosine_with_warmup
+  name: linear_with_warmup
   t_warmup: 5000
   alpha_f: 0.1
 
diff --git a/olmo/model.py b/olmo/model.py
index b635518cb..4fdc3a4b4 100644
--- a/olmo/model.py
+++ b/olmo/model.py
@@ -45,7 +45,6 @@ from .exceptions import OlmoConfigurationError
 from .initialization import ModuleType, init_weights
 from .torch_util import ensure_finite_
-from .util import pass_through_fn
 
 __all__ = [
     "LayerNormBase",
@@ -430,7 +429,7 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
         self.__cache = cache
         assert config.d_model % config.n_heads == 0
 
-        self._activation_checkpoint_fn = pass_through_fn
+        self._activation_checkpoint_fn = None
 
         # Dropout.
         self.dropout = Dropout(config.residual_dropout)
@@ -492,7 +491,7 @@ def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
         if strategy == ActivationCheckpointingStrategy.fine_grained:
             self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
         else:
-            self._activation_checkpoint_fn = pass_through_fn
+            self._activation_checkpoint_fn = None
 
     @classmethod
     def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
@@ -673,12 +672,20 @@ def forward(
         #  - for regular attn q, k, v: (batch_size, seq_len, d_model)
         #  - for multi-query attn q: (batch_size, seq_len, d_model)
         #                      k, v: (batch_size, seq_len, d_model // n_heads)
-        q, k, v = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)).split(self.fused_dims, dim=-1)
+        if self._activation_checkpoint_fn is not None:
+            q, k, v = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)).split(
+                self.fused_dims, dim=-1
+            )
+        else:
+            q, k, v = self.att_proj(self.attn_norm(x)).split(self.fused_dims, dim=-1)
 
         # Get attention scores.
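+        # When fine-grained activation checkpointing is enabled, `_activation_checkpoint_fn`
+        # wraps the call in torch.utils.checkpoint, so these activations are recomputed during
+        # the backward pass instead of being held in memory; otherwise call the module directly.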
-        att, cache = self._activation_checkpoint_fn(
-            self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
-        )
+        if self._activation_checkpoint_fn is not None:
+            att, cache = self._activation_checkpoint_fn(  # type: ignore
+                self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
+            )
+        else:
+            att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
 
         # Add attention scores.
         # shape: (B, T, C)
@@ -687,9 +694,15 @@ def forward(
         # Add feed-forward projection.
         # shape: (batch_size, seq_len, d_model)
         og_x = x
-        x = self._activation_checkpoint_fn(self.ff_norm, x)
+        if self._activation_checkpoint_fn is not None:
+            x = self._activation_checkpoint_fn(self.ff_norm, x)  # type: ignore
+        else:
+            x = self.ff_norm(x)
         x = self.ff_proj(x)
-        x = self._activation_checkpoint_fn(self.act, x)
+        if self._activation_checkpoint_fn is not None:
+            x = self._activation_checkpoint_fn(self.act, x)  # type: ignore
+        else:
+            x = self.act(x)
         x = self.ff_out(x)
         x = self.dropout(x)
         x = og_x + x
@@ -753,23 +766,35 @@ def forward(
         #  - for multi-query attn q: (batch_size, seq_len, d_model)
         #                      k, v: (batch_size, seq_len, d_model // n_heads)
         # shape of ff: (batch_size, seq_len, hidden_size)
-        q, k, v, ff = self.fused_attn_ff_proj(self._activation_checkpoint_fn(self.norm, x)).split(
-            self.fused_dims, dim=-1
-        )
+        if self._activation_checkpoint_fn is not None:
+            q, k, v, ff = self.fused_attn_ff_proj(self._activation_checkpoint_fn(self.norm, x)).split(
+                self.fused_dims, dim=-1
+            )
+        else:
+            q, k, v, ff = self.fused_attn_ff_proj(self.norm(x)).split(self.fused_dims, dim=-1)
 
         # Get attention scores.
         # shape: (B, T, C)
-        att, cache = self._activation_checkpoint_fn(
-            self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
-        )
+        if self._activation_checkpoint_fn is not None:
+            att, cache = self._activation_checkpoint_fn(  # type: ignore
+                self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
+            )
+        else:
+            att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
 
         # Apply output projections (and activation function) and sum the results.
         # We keep these projections separate because we found that we got better throughput this
         # way compared to fusing them.
-        return (
-            x + self.dropout(self.ff_out(self._activation_checkpoint_fn(self.act, ff))) + self.dropout(att),
-            cache,
-        )
+        if self._activation_checkpoint_fn is not None:
+            return (
+                x + self.dropout(self.ff_out(self._activation_checkpoint_fn(self.act, ff))) + self.dropout(att),
+                cache,
+            )
+        else:
+            return (
+                x + self.dropout(self.ff_out(self.act(ff))) + self.dropout(att),
+                cache,
+            )
 
 
 class OlmoLlamaBlock(OlmoBlock):
@@ -874,9 +899,15 @@ def forward(
         # Add feed-forward projection.
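+        # As in OlmoBlock above, route the norm and activation through the checkpoint wrapper
+        # only when fine-grained activation checkpointing is enabled; otherwise call directly.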
         # shape: (batch_size, seq_len, d_model)
         og_x = x
-        x = self._activation_checkpoint_fn(self.ff_norm, x)
+        if self._activation_checkpoint_fn is not None:
+            x = self._activation_checkpoint_fn(self.ff_norm, x)  # type: ignore
+        else:
+            x = self.ff_norm(x)
         x = self.ff_proj(x)
-        x = self._activation_checkpoint_fn(self.act, x)
+        if self._activation_checkpoint_fn is not None:
+            x = self._activation_checkpoint_fn(self.act, x)  # type: ignore
+        else:
+            x = self.act(x)
         x = self.ff_out(x)
         x = self.dropout(x)
         x = og_x + x
@@ -945,7 +976,7 @@ def forward(
                 )
             ):
                 # shape: (batch_size, seq_len, d_model)
-                x, cache = self._activation_checkpoint_fn(
+                x, cache = self._activation_checkpoint_fn(  # type: ignore
                     block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
                 )
             else:
diff --git a/scripts/beaker/mitch-ish-7b.sh b/scripts/beaker/mitch-ish-7b.sh
new file mode 100755
index 000000000..3fd81cade
--- /dev/null
+++ b/scripts/beaker/mitch-ish-7b.sh
@@ -0,0 +1,30 @@
+#!/usr/bin/env bash
+
+set -ex
+
+CONFIG_PATH=configs/v1_5-mix-medium-mitch-ish-s3.yaml
+NUM_NODES=4
+ARGS='--activation_checkpointing=fine_grained --wandb.name=v1_5-mix-mitch-ish-mcli-final --epoch=1 --optimizer.learning_rate=0.000023 --scheduler.t_warmup=556000 --scheduler.t_max=557000 --scheduler.alpha_f=0.001 --stop_at=557000'
+
+gantry run \
+  --allow-dirty \
+  --workspace ai2/llm-testing \
+  --task-name mitchish-mcli-final \
+  --description mitchish-mcli-final \
+  --priority high \
+  --beaker-image olmo-torch2-gantry \
+  --cluster ai2/general-cirrascale-a100-80g-ib \
+  --gpus 8 \
+  --replicas "${NUM_NODES}" \
+  --nfs \
+  --mount /net/nfs.cirrascale/allennlp/petew/cache:/root/.cache \
+  --env LOG_FILTER_TYPE=local_rank0_only \
+  --env OMP_NUM_THREADS=8 \
+  --env OLMO_TASK=model \
+  --env-secret WANDB_API_KEY=WANDB_API_KEY \
+  --env-secret AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID \
+  --env-secret AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY \
+  --shared-memory 10GiB \
+  --venv base \
+  --yes \
+  -- /bin/bash -c "torchrun --nnodes ${NUM_NODES}:${NUM_NODES} --nproc-per-node 8 --rdzv_id=101 --rdzv_backend=c10d --rdzv_endpoint=\$BEAKER_LEADER_REPLICA_HOSTNAME:29400 scripts/train.py ${CONFIG_PATH} ${ARGS}"
diff --git a/scripts/kempner/mitch-ish-7b.sh b/scripts/kempner/mitch-ish-7b.sh
new file mode 100644
index 000000000..64624ec4e
--- /dev/null
+++ b/scripts/kempner/mitch-ish-7b.sh
@@ -0,0 +1,52 @@
+#!/bin/bash
+#SBATCH --job-name=v1.5-mix-medium-mitch-ish
+#SBATCH --account=kempner_lab
+#SBATCH --output=/n/holyscratch01/kempner_lab/Lab/logs-petew/%j.log
+#SBATCH --nodes=8               # Total number of nodes
+#SBATCH --ntasks-per-node=4
+#SBATCH --gpus-per-node=4       # Allocate one GPU per MPI rank
+#SBATCH --cpus-per-task=16
+#SBATCH --time=167:00:00
+#SBATCH --mem=0                 # All memory on the node
+#SBATCH --partition=kempner_project
+
+export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK
+export MPICH_GPU_SUPPORT_ENABLED=1
+export MIOPEN_USER_DB_PATH=/tmp/${USER}-miopen-cache-${SLURM_JOB_ID}
+export MIOPEN_CUSTOM_CACHE_DIR=${MIOPEN_USER_DB_PATH}
+
+export PYTHONPATH=.:${PYTHONPATH}
+
+# Try playing with max_split_size_mb if you run into OOM errors.
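+# (max_split_size_mb caps the size of blocks the caching allocator is willing to split, which
+# can reduce fragmentation; on NVIDIA GPUs the same knob lives under PYTORCH_CUDA_ALLOC_CONF.)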
+# export PYTORCH_HIP_ALLOC_CONF=max_split_size_mb:512
+
+export DATA_PATH=/n/home06/dgroeneveld/data/preprocessed/olmo-mix
+export EVAL_DATA_PATH=/n/home06/dgroeneveld/data/eval-data
+export CHECKPOINTS_PATH=/n/home06/dgroeneveld/checkpoints
+
+export PYTORCH_KERNEL_CACHE_PATH=/tmp/pytorch_kernel_cache/
+mkdir -p $PYTORCH_KERNEL_CACHE_PATH
+
+LOAD_PATH=s3://ai2-llm/checkpoints/7b/v1_5-mix-mitch-ish/step556000-unsharded
+# SAVE_PATH=s3://ai2-llm/checkpoints/7b/v1_5-mix-mitch-ish-final-tulu
+
+srun \
+  "--cpus-per-task=$SLURM_CPUS_PER_TASK" \
+  --distribution=block:block \
+  --kill-on-bad-exit \
+  scripts/run_with_environment.sh \
+  $HOME/miniconda3/envs/LLM/bin/python -u scripts/train.py configs/v1_5-mix-medium-mitch-ish-s3.yaml \
+    "--run_name=kempner_${SLURM_JOB_ID}" \
+    --wandb.name=v1_5-mix-mitch-ish-final-tulu \
+    '--data.paths=[s3://ai2-llm/preprocessed/tulu-v2-sft-mixture/gpt-neox-20b-pii-special/data.npy,s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample-9B/gpt-neox-20b-pii-special/data.npy]' \
+    --eval_interval=100 \
+    --save_interval=500 \
+    "--load_path=${LOAD_PATH}" \
+    --restore_dataloader=false \
+    --optimizer.learning_rate=0.000023 \
+    --scheduler.t_warmup=556000 \
+    --scheduler.alpha_f=0.001 \
+    --scheduler.t_max=558223 \
+    --stop_at=558223 \
+    --time_limit=$((167 * 60 * 60)) \
+    "--save_folder=/n/holyscratch01/kempner_lab/Lab/checkpoints/${SLURM_JOB_ID}/"
diff --git a/scripts/lumi/mitch-ish-7b.sh b/scripts/lumi/mitch-ish-7b.sh
new file mode 100644
index 000000000..b7deed4d6
--- /dev/null
+++ b/scripts/lumi/mitch-ish-7b.sh
@@ -0,0 +1,60 @@
+#!/bin/bash
+#SBATCH --job-name=v1.5-mix-medium-mitch-ish
+#SBATCH --account=project_462000229
+#SBATCH --output=/pfs/lustref1/flash/project_462000229/logs/%j.log
+#SBATCH --nodes=256             # Total number of nodes
+#SBATCH --ntasks-per-node=8
+#SBATCH --gpus-per-node=8       # Allocate one GPU per MPI rank
+#SBATCH --cpus-per-task=6
+#SBATCH --time=48:00:00
+#SBATCH --time-min=24:00:00
+#SBATCH --mem=0                 # All memory on the node
+#SBATCH --partition=standard-g
+
+module load LUMI/22.08 partition/G
+
+# export OLMO_CONTAINER=llm-lumi_latest.sif
+export OLMO_CONTAINER=llm-lumi-torch21_latest.sif
+
+export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK
+export MPICH_GPU_SUPPORT_ENABLED=1
+export NCCL_SOCKET_IFNAME=hsn
+export NCCL_NET_GDR_LEVEL=3
+export MIOPEN_USER_DB_PATH=/tmp/${USER}-miopen-cache-${SLURM_JOB_ID}
+export MIOPEN_CUSTOM_CACHE_DIR=${MIOPEN_USER_DB_PATH}
+export CXI_FORK_SAFE=1
+export CXI_FORK_SAFE_HP=1
+export FI_CXI_DISABLE_CQ_HUGETLB=1
+
+# We need to set this to avoid "Cassini Event Queue overflow detected." errors.
+export FI_CXI_DEFAULT_CQ_SIZE=131072
+
+# export NCCL_DEBUG=INFO
+export PYTHONPATH=.:${PYTHONPATH}
+export ROCM_PATH=/opt/rocm
+export SINGULARITYENV_LD_LIBRARY_PATH=/usr/local/lib:/opt/cray/libfabric/1.15.2.0/lib64
+
+# Try playing with max_split_size_mb if you run into OOM errors.
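+# (PYTORCH_HIP_ALLOC_CONF is the ROCm spelling of PYTORCH_CUDA_ALLOC_CONF; smaller
+# max_split_size_mb values trade a little allocator throughput for less fragmentation.)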
+#export PYTORCH_HIP_ALLOC_CONF=max_split_size_mb:128
+
+export DATA_PATH=$FLASH_DIR/preprocessed/olmo-mix
+export CHECKPOINTS_PATH=$FLASH_DIR/checkpoints
+export EVAL_DATA_PATH=$SCRATCH_DIR/eval-data
+
+srun \
+  --cpus-per-task=$SLURM_CPUS_PER_TASK \
+  --distribution=block:block \
+  --kill-on-bad-exit \
+  scripts/run_with_environment.sh \
+  singularity exec \
+    -B"$PROJECT_DIR:$PROJECT_DIR" \
+    -B"$FLASH_DIR:$FLASH_DIR" \
+    -B"$SCRATCH_DIR:$SCRATCH_DIR" \
+    -B /opt/cray:/opt/cray \
+    -B /usr/lib64/libcxi.so.1:/usr/lib64/libcxi.so.1 \
+    -B /usr/lib64/libjson-c.so.3:/usr/lib64/libjson-c.so.3 \
+    $PROJECT_DIR/containers/$OLMO_CONTAINER \
+    python scripts/train.py configs/v1_5-mix-medium-mitch-ish.yaml ${@} \
+      --run_name=${SLURM_JOB_ID} \
+      --global_train_batch_size=4096 \
+      --max_duration=238418