Merge pull request #350 from allenai/mitchish
Mitchish mosaic run on its own branch
epwalsh authored Jan 5, 2024
2 parents df19554 + f3a73dd commit 5a735dd
Showing 6 changed files with 212 additions and 26 deletions.
19 changes: 16 additions & 3 deletions configs/mcli/v1_5-mix-medium-mitch-ish.yaml
@@ -6,7 +6,8 @@ gpu_type: a100_40gb
integrations:
- integration_type: git_repo
git_repo: allenai/LLM
git_branch: main # make sure to update this!
# git_branch: mitchish
git_commit: 148ca062e7f1f7667d7fc0f4346e97467e66ce87
pip_install: -e .
ssh_clone: true
command: |-
@@ -28,6 +29,18 @@ command: |-
--nproc_per_node 8 \
scripts/train.py configs/v1_5-mix-medium-mitch-ish-s3.yaml \
--run_name=v1_5-mix-mitch-ish \
--wandb.name=v1_5-mix-mitch-ish-mcli \
--wandb.name=v1_5-mix-mitch-ish-mcli-final \
--global_train_batch_size=2160 \
--load_path='${path.last_checkpoint:s3://ai2-llm/checkpoints/7b/v1_5-mix-mitch-ish}'
--time_limit=169200
# We added these flags in order to get a final checkpoint where we decayed the LR down to 0.
# --eval_interval=100 \
# --save_interval=500 \
# --load_path=s3://ai2-llm/checkpoints/7b/v1_5-mix-mitch-ish/step556000 \
# --remote_save_folder=s3://ai2-llm/checkpoints/7b/v1_5-mix-mitch-ish-final \
# --epoch=1 \
# --optimizer.learning_rate=0.000023 \
# --scheduler.t_warmup=556000 \
# --scheduler.t_max=557000 \
# --scheduler.alpha_f=0.001 \
# --stop_at=557001
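
The commented-out overrides above restart from the step-556000 checkpoint with a schedule whose warmup ends exactly at that step, so the remaining ~1,000 steps walk the learning rate from its peak down to alpha_f times the peak. A rough sketch of that shape in Python, using the flag values as parameters (an illustration only, not the repository's scheduler implementation):

def linear_with_warmup_lr(step, peak_lr=0.000023, t_warmup=556_000, t_max=557_000, alpha_f=0.001):
    # Warmup phase: ramp linearly up to the peak learning rate.
    if step < t_warmup:
        return peak_lr * step / t_warmup
    # Decay phase: move linearly from peak_lr to alpha_f * peak_lr at t_max.
    frac = min(1.0, (step - t_warmup) / (t_max - t_warmup))
    return peak_lr * (1 - (1 - alpha_f) * frac)

print(linear_with_warmup_lr(556_000))  # 2.3e-05, the peak, at the restart step
print(linear_with_warmup_lr(557_000))  # 2.3e-08, effectively zero, at t_max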
2 changes: 1 addition & 1 deletion configs/v1_5-mix-medium-mitch-ish.yaml
@@ -49,7 +49,7 @@ optimizer:
metrics_log_interval: 10

scheduler:
name: cosine_with_warmup
name: linear_with_warmup
t_warmup: 5000
alpha_f: 0.1

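For reference on the scheduler swap above: both names ramp the learning rate up over t_warmup steps and end the decay at alpha_f times the peak; only the decay shape differs. A generic sketch of the two decay curves under those assumptions (the repository's own scheduler code is authoritative):

import math

def cosine_decay(peak_lr, alpha_f, frac):
    # frac is progress through the decay phase, from 0.0 to 1.0.
    return peak_lr * (alpha_f + (1 - alpha_f) * 0.5 * (1 + math.cos(math.pi * frac)))

def linear_decay(peak_lr, alpha_f, frac):
    return peak_lr * (alpha_f + (1 - alpha_f) * (1 - frac))

# Both curves agree at the start, midpoint, and end of the decay; the cosine
# curve stays higher in the first half and falls below the line in the second.
for frac in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(frac, cosine_decay(1.0, 0.1, frac), linear_decay(1.0, 0.1, frac))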
75 changes: 53 additions & 22 deletions olmo/model.py
@@ -45,7 +45,6 @@
from .exceptions import OlmoConfigurationError
from .initialization import ModuleType, init_weights
from .torch_util import ensure_finite_
from .util import pass_through_fn

__all__ = [
"LayerNormBase",
@@ -430,7 +429,7 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
self.__cache = cache
assert config.d_model % config.n_heads == 0

self._activation_checkpoint_fn = pass_through_fn
self._activation_checkpoint_fn = None

# Dropout.
self.dropout = Dropout(config.residual_dropout)
@@ -492,7 +491,7 @@ def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointin
if strategy == ActivationCheckpointingStrategy.fine_grained:
self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
else:
self._activation_checkpoint_fn = pass_through_fn
self._activation_checkpoint_fn = None

@classmethod
def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
@@ -673,12 +672,20 @@ def forward(
# - for regular attn q, k, v: (batch_size, seq_len, d_model)
# - for multi-query attn q: (batch_size, seq_len, d_model)
# k, v: (batch_size, seq_len, d_model // n_heads)
q, k, v = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)).split(self.fused_dims, dim=-1)
if self._activation_checkpoint_fn is not None:
q, k, v = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)).split(
self.fused_dims, dim=-1
)
else:
q, k, v = self.att_proj(self.attn_norm(x)).split(self.fused_dims, dim=-1)

# Get attention scores.
att, cache = self._activation_checkpoint_fn(
self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
)
if self._activation_checkpoint_fn is not None:
att, cache = self._activation_checkpoint_fn( # type: ignore
self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
)
else:
att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)

# Add attention scores.
# shape: (B, T, C)
@@ -687,9 +694,15 @@ def forward(
# Add feed-forward projection.
# shape: (batch_size, seq_len, d_model)
og_x = x
x = self._activation_checkpoint_fn(self.ff_norm, x)
if self._activation_checkpoint_fn is not None:
x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
else:
x = self.ff_norm(x)
x = self.ff_proj(x)
x = self._activation_checkpoint_fn(self.act, x)
if self._activation_checkpoint_fn is not None:
x = self._activation_checkpoint_fn(self.act, x) # type: ignore
else:
x = self.act(x)
x = self.ff_out(x)
x = self.dropout(x)
x = og_x + x
@@ -753,23 +766,35 @@ def forward(
# - for multi-query attn q: (batch_size, seq_len, d_model)
# k, v: (batch_size, seq_len, d_model // n_heads)
# shape of ff: (batch_size, seq_len, hidden_size)
q, k, v, ff = self.fused_attn_ff_proj(self._activation_checkpoint_fn(self.norm, x)).split(
self.fused_dims, dim=-1
)
if self._activation_checkpoint_fn is not None:
q, k, v, ff = self.fused_attn_ff_proj(self._activation_checkpoint_fn(self.norm, x)).split(
self.fused_dims, dim=-1
)
else:
q, k, v, ff = self.fused_attn_ff_proj(self.norm(x)).split(self.fused_dims, dim=-1)

# Get attention scores.
# shape: (B, T, C)
att, cache = self._activation_checkpoint_fn(
self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
)
if self._activation_checkpoint_fn is not None:
att, cache = self._activation_checkpoint_fn( # type: ignore
self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
)
else:
att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)

# Apply output projections (and activation function) and sum the results.
# We keep these projections separate because we found that we got better throughput this
# way compared to fusing them.
return (
x + self.dropout(self.ff_out(self._activation_checkpoint_fn(self.act, ff))) + self.dropout(att),
cache,
)
if self._activation_checkpoint_fn is not None:
return (
x + self.dropout(self.ff_out(self._activation_checkpoint_fn(self.act, ff))) + self.dropout(att),
cache,
)
else:
return (
x + self.dropout(self.ff_out(self.act(ff))) + self.dropout(att),
cache,
)


class OlmoLlamaBlock(OlmoBlock):
@@ -874,9 +899,15 @@ def forward(
# Add feed-forward projection.
# shape: (batch_size, seq_len, d_model)
og_x = x
x = self._activation_checkpoint_fn(self.ff_norm, x)
if self._activation_checkpoint_fn is not None:
x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
else:
x = self.ff_norm(x)
x = self.ff_proj(x)
x = self._activation_checkpoint_fn(self.act, x)
if self._activation_checkpoint_fn is not None:
x = self._activation_checkpoint_fn(self.act, x) # type: ignore
else:
x = self.act(x)
x = self.ff_out(x)
x = self.dropout(x)
x = og_x + x
@@ -945,7 +976,7 @@ def forward(
)
):
# shape: (batch_size, seq_len, d_model)
x, cache = self._activation_checkpoint_fn(
x, cache = self._activation_checkpoint_fn( # type: ignore
block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
)
else:
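The olmo/model.py changes above replace the pass_through_fn placeholder with None and branch explicitly, so the non-checkpointed path calls each sub-module directly (hence the # type: ignore comments on the wrapped calls). A minimal, hypothetical sketch of the same dispatch pattern, assuming the checkpoint function ultimately wraps torch.utils.checkpoint.checkpoint (in the repository it comes from activation_checkpoint_function):

from typing import Callable, Optional

import torch
from torch.utils.checkpoint import checkpoint


class BlockSketch(torch.nn.Module):
    """Toy stand-in for a transformer block showing the None-vs-wrapped dispatch."""

    def __init__(self, d_model: int = 16):
        super().__init__()
        self.ff_norm = torch.nn.LayerNorm(d_model)
        self.ff_proj = torch.nn.Linear(d_model, d_model)
        # No activation checkpointing until a strategy is set.
        self._activation_checkpoint_fn: Optional[Callable] = None

    def set_fine_grained_checkpointing(self) -> None:
        # Recompute the wrapped sub-module in the backward pass instead of
        # storing its activations (trades extra compute for lower memory).
        self._activation_checkpoint_fn = lambda module, *args: checkpoint(
            module, *args, use_reentrant=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._activation_checkpoint_fn is not None:
            x_normed = self._activation_checkpoint_fn(self.ff_norm, x)
        else:
            x_normed = self.ff_norm(x)
        return x + self.ff_proj(x_normed)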
30 changes: 30 additions & 0 deletions scripts/beaker/mitch-ish-7b.sh
@@ -0,0 +1,30 @@
#!/usr/bin/env bash

set -ex

CONFIG_PATH=configs/v1_5-mix-medium-mitch-ish-s3.yaml
NUM_NODES=4
ARGS='--activation_checkpointing=fine_grained wandb.name=v1_5-mix-mitch-ish-mcli-final --epoch=1 --optimizer.learning_rate=0.000023 --scheduler.t_warmup=556000 --scheduler.t_max=557000 --scheduler.alpha_f=0.001 --stop_at=557000'

gantry run \
--allow-dirty \
--workspace ai2/llm-testing \
--task-name mitchish-mcli-final \
--description mitchish-mcli-final \
--priority high \
--beaker-image olmo-torch2-gantry \
--cluster ai2/general-cirrascale-a100-80g-ib \
--gpus 8 \
--replicas "${NUM_NODES}" \
--nfs \
--mount /net/nfs.cirrascale/allennlp/petew/cache:/root/.cache \
--env LOG_FILTER_TYPE=local_rank0_only \
--env OMP_NUM_THREADS=8 \
--env OLMO_TASK=model \
--env-secret WANDB_API_KEY=WANDB_API_KEY \
--env-secret AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID \
--env-secret AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY \
--shared-memory 10GiB \
--venv base \
--yes \
-- /bin/bash -c "torchrun --nnodes ${NUM_NODES}:${NUM_NODES} --nproc-per-node 8 --rdzv_id=101 --rdzv_backend=c10d --rdzv_endpoint=\$BEAKER_LEADER_REPLICA_HOSTNAME:29400 scripts/train.py ${CONFIG_PATH} ${ARGS}"
52 changes: 52 additions & 0 deletions scripts/kempner/mitch-ish-7b.sh
@@ -0,0 +1,52 @@
#!/bin/bash
#SBATCH --job-name=v1.5-mix-medium-mitch-ish
#SBATCH --account=kempner_lab
#SBATCH --output=/n/holyscratch01/kempner_lab/Lab/logs-petew/%j.log
#SBATCH --nodes=8 # Total number of nodes
#SBATCH --ntasks-per-node=4
#SBATCH --gpus-per-node=4 # Allocate one gpu per MPI rank
#SBATCH --cpus-per-task=16
#SBATCH --time=167:00:00
#SBATCH --mem=0 # All memory on the node
#SBATCH --partition=kempner_project

export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK
export MPICH_GPU_SUPPORT_ENABLED=1
export MIOPEN_USER_DB_PATH=/tmp/${USER}-miopen-cache-${SLURM_JOB_ID}
export MIOPEN_CUSTOM_CACHE_DIR=${MIOPEN_USER_DB_PATH}

export PYTHONPATH=.:${PYTHONPATH}

# Try playing with max_split_size_mb if you run into OOM errors.
# export PYTORCH_HIP_ALLOC_CONF=max_split_size_mb:512

export DATA_PATH=/n/home06/dgroeneveld/data/preprocessed/olmo-mix
export EVAL_DATA_PATH=/n/home06/dgroeneveld/data/eval-data
export CHECKPOINTS_PATH=/n/home06/dgroeneveld/checkpoints

export PYTORCH_KERNEL_CACHE_PATH=/tmp/pytorch_kernel_cache/
mkdir -p $PYTORCH_KERNEL_CACHE_PATH

LOAD_PATH=s3://ai2-llm/checkpoints/7b/v1_5-mix-mitch-ish/step556000-unsharded
# SAVE_PATH=s3://ai2-llm/checkpoints/7b/v1_5-mix-mitch-ish-final-tulu

srun \
"--cpus-per-task=$SLURM_CPUS_PER_TASK" \
--distribution=block:block \
--kill-on-bad-exit \
scripts/run_with_environment.sh \
$HOME/miniconda3/envs/LLM/bin/python -u scripts/train.py configs/v1_5-mix-medium-mitch-ish-s3.yaml \
"--run_name=kempner_${SLURM_JOB_ID}" \
--wandb.name=v1_5-mix-mitch-ish-final-tulu \
'--data.paths=[s3://ai2-llm/preprocessed/tulu-v2-sft-mixture/gpt-neox-20b-pii-special/data.npy,s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample-9B/gpt-neox-20b-pii-special/data.npy]' \
--eval_interval=100 \
--save_interval=500 \
"--load_path=${LOAD_PATH}" \
--restore_dataloader=false \
--optimizer.learning_rate=0.000023 \
--scheduler.t_warmup=556000 \
--scheduler.alpha_f=0.001 \
--scheduler.t_max=558223 \
--stop_at=558223 \
--time_limit=$((167 * 60 * 60)) \
"--save_folder=/n/holyscratch01/kempner_lab/Lab/checkpoints/${SLURM_JOB_ID}/"
60 changes: 60 additions & 0 deletions scripts/lumi/mitch-ish-7b.sh
@@ -0,0 +1,60 @@
#!/bin/bash
#SBATCH --job-name=v1.5-mix-medium-mitch-ish
#SBATCH --account=project_462000229
#SBATCH --output=/pfs/lustref1/flash/project_462000229/logs/%j.log
#SBATCH --nodes=256 # Total number of nodes
#SBATCH --ntasks-per-node=8
#SBATCH --gpus-per-node=8 # Allocate one gpu per MPI rank
#SBATCH --cpus-per-task=6
#SBATCH --time=48:00:00
#SBATCH --time-min=24:00:00
#SBATCH --mem=0 # All memory on the node
#SBATCH --partition=standard-g

module load LUMI/22.08 partition/G

# export OLMO_CONTAINER=llm-lumi_latest.sif
export OLMO_CONTAINER=llm-lumi-torch21_latest.sif

export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK
export MPICH_GPU_SUPPORT_ENABLED=1
export NCCL_SOCKET_IFNAME=hsn
export NCCL_NET_GDR_LEVEL=3
export MIOPEN_USER_DB_PATH=/tmp/${USER}-miopen-cache-${SLURM_JOB_ID}
export MIOPEN_CUSTOM_CACHE_DIR=${MIOPEN_USER_DB_PATH}
export CXI_FORK_SAFE=1
export CXI_FORK_SAFE_HP=1
export FI_CXI_DISABLE_CQ_HUGETLB=1

# We need to set this to avoid "Cassini Event Queue overflow detected." errors.
export FI_CXI_DEFAULT_CQ_SIZE=131072

#export NCCL_DEBUG=INFO
export PYTHONPATH=.:${PYTHONPATH}
export ROCM_PATH=/opt/rocm
export SINGULARITYENV_LD_LIBRARY_PATH=/usr/local/lib:/opt/cray/libfabric/1.15.2.0/lib64

# Try playing with max_split_size_mb if you run into OOM errors.
#export PYTORCH_HIP_ALLOC_CONF=max_split_size_mb:128

export DATA_PATH=$FLASH_DIR/preprocessed/olmo-mix
export CHECKPOINTS_PATH=$FLASH_DIR/checkpoints
export EVAL_DATA_PATH=$SCRATCH_DIR/eval-data

srun \
--cpus-per-task=$SLURM_CPUS_PER_TASK \
--distribution=block:block \
--kill-on-bad-exit \
scripts/run_with_environment.sh \
singularity exec \
-B"$PROJECT_DIR:$PROJECT_DIR" \
-B"$FLASH_DIR:$FLASH_DIR" \
-B"$SCRATCH_DIR:$SCRATCH_DIR" \
-B /opt/cray:/opt/cray \
-B /usr/lib64/libcxi.so.1:/usr/lib64/libcxi.so.1 \
-B /usr/lib64/libjson-c.so.3:/usr/lib64/libjson-c.so.3 \
$PROJECT_DIR/containers/$OLMO_CONTAINER \
python scripts/train.py configs/v1_5-mix-medium-mitch-ish.yaml ${@} \
--run_name=${SLURM_JOB_ID} \
--global_train_batch_size=4096 \
--max_duration=238418
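
A quick check on max_duration: with the global batch size set here and a 2048-token sequence length (an assumption; the sequence length is set in the base config, which this diff does not show), the run covers roughly two trillion tokens:

global_train_batch_size = 4096  # sequences per step, from the srun command above
sequence_length = 2048          # assumed from the base v1_5-mix config, not shown in this diff
max_duration = 238_418          # training steps
tokens = global_train_batch_size * sequence_length * max_duration
print(f"{tokens:,}")            # 1,999,995,142,144, i.e. about 2e12 tokens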
