
Comparing changes

base repository: EleutherAI/sparsify
base: main
head repository: CERC-AAI/sae
compare: main
Can’t automatically merge.
  • 1 commit
  • 2 files changed
  • 1 contributor

Commits on Aug 30, 2024

  1. frontier adaptations

    george-adams1 committed Aug 30, 2024
    Commit 29d4639
Showing with 50 additions and 69 deletions.
  1. +33 −69 sae/__main__.py
  2. +17 −0 sae/launch.sh
102 changes: 33 additions & 69 deletions sae/__main__.py
@@ -1,41 +1,31 @@
 import os
 from contextlib import nullcontext, redirect_stdout
 from dataclasses import dataclass
 from multiprocessing import cpu_count

 import torch
 import torch.distributed as dist
-from datasets import Dataset, load_dataset
+from datasets import Dataset, load_from_disk
 from simple_parsing import field, parse
-from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig, PreTrainedModel
+from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel

 from .data import chunk_and_tokenize, MemmapDataset
 from .trainer import SaeTrainer, TrainConfig


 @dataclass
 class RunConfig(TrainConfig):
-    model: str = field(
-        default="EleutherAI/pythia-160m",
-        positional=True,
-    )
+    model: str = field(default="EleutherAI/pythia-160m")
     """Name of the model to train."""

-    dataset: str = field(
-        default="togethercomputer/RedPajama-Data-1T-Sample",
-        positional=True,
-    )
+    dataset: str = field(default="test.hf")
     """Path to the dataset to use for training."""

-    split: str = "train"
-    """Dataset split to use for training."""
+    cache_dir: str = field(default=None)
+    """Directory to use for caching the model."""

     ctx_len: int = 2048
     """Context length to use for training."""

     hf_token: str | None = None
     """Huggingface API token for downloading models."""

     load_in_8bit: bool = False
     """Load the model in 8-bit mode."""
@@ -55,62 +45,39 @@ class RunConfig(TrainConfig):


 def load_artifacts(args: RunConfig, rank: int) -> tuple[PreTrainedModel, Dataset | MemmapDataset]:
+    os.environ['HF_DATASETS_OFFLINE'] = "1"
+
     if args.load_in_8bit:
         dtype = torch.float16
     elif torch.cuda.is_bf16_supported():
         dtype = torch.bfloat16
     else:
         dtype = "auto"

-    model = AutoModel.from_pretrained(
+    print('loading model...')
+    model = AutoModelForCausalLM.from_pretrained(
         args.model,
         device_map={"": f"cuda:{rank}"},
-        quantization_config=(
-            BitsAndBytesConfig(load_in_8bit=args.load_in_8bit)
-            if args.load_in_8bit
-            else None
-        ),
         torch_dtype=dtype,
         token=args.hf_token,
+        cache_dir=args.cache_dir
     )
+    print('model loaded')

+    print('loading datasets...')
+    dataset = load_from_disk(args.dataset)
+    dataset = dataset['train']
+    print('datasets loaded')
+
+    tokenizer = AutoTokenizer.from_pretrained(args.model)
+    dataset = chunk_and_tokenize(
+        dataset,
+        tokenizer,
+        max_seq_len=args.ctx_len,
+        num_proc=args.data_preprocessing_num_proc,
+    )

-    # For memmap-style datasets
-    if args.dataset.endswith(".bin"):
-        dataset = MemmapDataset(args.dataset, args.ctx_len, args.max_examples)
-    else:
-        # For Huggingface datasets
-        try:
-            dataset = load_dataset(
-                args.dataset,
-                split=args.split,
-                # TODO: Maybe set this to False by default? But RPJ requires it.
-                trust_remote_code=True,
-            )
-        except ValueError as e:
-            # Automatically use load_from_disk if appropriate
-            if "load_from_disk" in str(e):
-                dataset = Dataset.load_from_disk(args.dataset, keep_in_memory=False)
-            else:
-                raise e
-
-        assert isinstance(dataset, Dataset)
-        if "input_ids" not in dataset.column_names:
-            tokenizer = AutoTokenizer.from_pretrained(args.model, token=args.hf_token)
-            dataset = chunk_and_tokenize(
-                dataset,
-                tokenizer,
-                max_seq_len=args.ctx_len,
-                num_proc=args.data_preprocessing_num_proc,
-            )
-        else:
-            print("Dataset already tokenized; skipping tokenization.")

     print(f"Shuffling dataset with seed {args.seed}")
     dataset = dataset.shuffle(args.seed)

     dataset = dataset.with_format("torch")
-    if limit := args.max_examples:
-        dataset = dataset.select(range(limit))
+    if args.max_examples:
+        dataset = dataset.select(range(args.max_examples))

     return model, dataset

@@ -129,7 +96,6 @@ def run():

     args = parse(RunConfig)

-    # Awkward hack to prevent other ranks from duplicating data preprocessing
     if not ddp or rank == 0:
         model, dataset = load_artifacts(args, rank)
     if ddp:
@@ -138,16 +104,14 @@ def run():
             model, dataset = load_artifacts(args, rank)
         dataset = dataset.shard(dist.get_world_size(), rank)

-    # Prevent ranks other than 0 from printing
-    with nullcontext() if rank == 0 else redirect_stdout(None):
-        print(f"Training on '{args.dataset}' (split '{args.split}')")
-        print(f"Storing model weights in {model.dtype}")
+    print(f"Training on '{args.dataset}'")
+    print(f"Storing model weights in {model.dtype}")

-        trainer = SaeTrainer(args, dataset, model)
-        if args.resume:
-            trainer.load_state(args.run_name or "sae-ckpts")
+    trainer = SaeTrainer(args, dataset, model)
+    if args.resume:
+        trainer.load_state(args.run_name or "sae-ckpts")

-        trainer.fit()
+    trainer.fit()


 if __name__ == "__main__":
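
Note on the new load_from_disk path: the fork expects --dataset to point at a dataset directory saved with the datasets library and containing a "train" split. A minimal preparation sketch under those assumptions follows; it would be run once on a machine with internet access, and the dataset name is only illustrative (it mirrors the old default, not necessarily what was used here).

# Sketch: build the on-disk directory that load_from_disk(args.dataset)["train"] expects.
# Assumptions: run where downloads are allowed; the dataset name is illustrative.
from datasets import load_dataset

ds = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", trust_remote_code=True)
ds.save_to_disk("test.hf")  # this directory is later passed via --dataset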
17 changes: 17 additions & 0 deletions sae/launch.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+#SBATCH -A CSC605
+#SBATCH -J sae-test
+#SBATCH -o %x-%j.out
+#SBATCH -t 00:30:00
+#SBATCH -p batch
+#SBATCH -N 1
+
+module load rocm/6.1.3
+
+cd /lustre/orion/csc605/scratch/george-adams/sae
+source activate /lustre/orion/csc605/scratch/george-adams/conda_envs/transformers_env
+
+export MASTER_IP=`ip -f inet addr show hsn0 | sed -En -e 's/.*inet ([0-9.]+).*/\1/p' | head -1`
+
+torchrun --nproc_per_node=8 -m sae --model EleutherAI/pythia-160m --dataset /lustre/orion/csc605/scratch/george-adams/sae/test.hf \
+    --batch_size 16 --ctx_len 2048 --cache_dir /lustre/orion/csc605/scratch/george-adams/cache_dir_test
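
Presumably the job is submitted from the repository root with sbatch sae/launch.sh: the SBATCH directives request a single node on the batch partition for 30 minutes under project CSC605, and torchrun then launches eight ranks on that node, one per GPU.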