JAX exp scripts #1

Open · wants to merge 2 commits into base: main

7 changes: 7 additions & 0 deletions examples/flax/language-modeling/gen_model_config.py
@@ -0,0 +1,7 @@
from transformers import RobertaConfig

# config = RobertaConfig.from_pretrained("roberta-base", vocab_size=50265)
config = RobertaConfig.from_pretrained("klue/roberta-small", vocab_size=50265)


config.save_pretrained("./roberta-base")
23 changes: 23 additions & 0 deletions examples/flax/language-modeling/launch_script.sh
@@ -0,0 +1,23 @@
export TRAINING_DIR="/fsx/erincho/examples/flax/language-modeling"
export ARTIFACTS_DIR="/fsx/erincho/examples/flax/language-modeling/roberta-base"
#export TRAINING_DIR="./"

python ${TRAINING_DIR}/run_mlm_flax.py \
--output_dir=$ARTIFACTS_DIR \
--model_type="roberta" \
--config_name=$ARTIFACTS_DIR \
--dataset_name="oscar" \
--dataset_config_name="unshuffled_deduplicated_no" \
--max_seq_length="128" \
--weight_decay="0.01" \
--per_device_train_batch_size="32" \
--per_device_eval_batch_size="32" \
--learning_rate="3e-4" \
--warmup_steps="1000" \
--overwrite_output_dir \
--num_train_epochs="18" \
--adam_beta1="0.9" \
--adam_beta2="0.98" \
--logging_steps="10" \
--save_steps="2500" \
--eval_steps="2500"
27 changes: 27 additions & 0 deletions examples/flax/language-modeling/roberta-base/config.json
@@ -0,0 +1,27 @@
{
"architectures": [
"RobertaForMaskedLM"
],
"attention_probs_dropout_prob": 0.1,
"bos_token_id": 0,
"classifier_dropout": null,
"eos_token_id": 2,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-05,
"max_position_embeddings": 514,
"model_type": "roberta",
"num_attention_heads": 12,
"num_hidden_layers": 6,
"pad_token_id": 1,
"position_embedding_type": "absolute",
"tokenizer_class": "BertTokenizer",
"transformers_version": "4.26.0.dev0",
"type_vocab_size": 1,
"use_cache": true,
"vocab_size": 50265
}
48 changes: 43 additions & 5 deletions examples/flax/language-modeling/run_mlm_flax.py
@@ -66,6 +66,11 @@
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


def compute_num_params(model):
"""Get num params."""
# https://github.com/google/jax/discussions/6153
return sum(x.size for x in jax.tree_leaves(model.params))

@dataclass
class TrainingArguments:
output_dir: str = field(
@@ -507,9 +512,11 @@ def main():
config = CONFIG_MAPPING[model_args.model_type]()
logger.warning("You are instantiating a new config instance from scratch.")

if model_args.tokenizer_name:
if model_args.tokenizer_name or True:
logging.info("use default tokenizer")
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name,
# model_args.tokenizer_name,
"klue/roberta-small",
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast_tokenizer,
use_auth_token=True if model_args.use_auth_token else None,
Expand All @@ -527,7 +534,10 @@ def main():
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
)

# Preprocessing the datasets.
# Only tokenize a subset of the data since we don't need to train to convergence
datasets["train"] = datasets["train"].select(range(10000))
datasets["validation"] = datasets["validation"].select(range(1000))

# First we tokenize all the texts.
if training_args.do_train:
column_names = datasets["train"].column_names
@@ -631,6 +641,7 @@ def group_texts(examples):
# Initialize our training
rng = jax.random.PRNGKey(training_args.seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())
logger.info("local_device_count:", jax.local_device_count())

if model_args.model_name_or_path:
model = FlaxAutoModelForMaskedLM.from_pretrained(
@@ -647,16 +658,20 @@ def group_texts(examples):
dtype=getattr(jnp, model_args.dtype),
)

num_params = compute_num_params(model)

if training_args.gradient_checkpointing:
model.enable_gradient_checkpointing()

# Store some constant
num_epochs = int(training_args.num_train_epochs)
# GBS (global batch size across all devices)
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
eval_batch_size = per_device_eval_batch_size * jax.device_count()

num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
logger.info("per_device_train_batch_size=%s\n, train_batch_size=%s\n, eval_batch_size=%s" % (training_args.per_device_train_batch_size, train_batch_size, eval_batch_size))

# Create learning rate schedule
warmup_fn = optax.linear_schedule(
@@ -773,6 +788,7 @@ def eval_step(params, batch):
# Replicate the train state on each device
state = jax_utils.replicate(state)

start = time.time()
train_time = 0
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
@@ -792,11 +808,25 @@ def eval_step(params, batch):
# Gather the indexes for creating the batch and do a training step
for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
model_inputs = data_collator(samples, pad_to_multiple_of=16)

# (GBS, seq_length)
model_inputs = data_collator(samples, pad_to_multiple_of=16)
step_start = time.time()
# Model forward
# (num_devices, per-device batch size, seq_length)
model_inputs = shard(model_inputs.data)
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)

# state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
# Use this line for benchmark
state, train_metric, dropout_rngs = jax.block_until_ready(p_train_step(state, model_inputs, dropout_rngs))
Review comment (Owner Author): pmapped functions in JAX run asynchronously, so we need to call block_until_ready to make sure a particular computation has actually finished.
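
A minimal sketch of the same issue, for illustration only (not part of this PR; assumes a JAX install with at least one local device):

import time
import jax
import jax.numpy as jnp

p_square = jax.pmap(lambda x: x * x)
xs = jnp.ones((jax.local_device_count(), 2048))

jax.block_until_ready(p_square(xs))  # warmup so compilation is not timed

start = time.time()
ys = p_square(xs)  # returns almost immediately: only the dispatch is timed
print("dispatch only: %.6fs" % (time.time() - start))

start = time.time()
ys = jax.block_until_ready(p_square(xs))  # waits until the computation has actually finished
print("with block_until_ready: %.6fs" % (time.time() - start))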


# calculate throughput
time_elapsed = time.time() - start
step_time = time.time() - step_start
sample_processed = len(samples) # same as GBS
Review comment (@hchings, Owner Author, Feb 7, 2023): To my understanding, the len(samples) here is already all the samples processed across all GPUs, based on how JAX works (L817 shards all data across 8 GPUs, and then each process “sees” local input and output in parallelized functions). So we don't have to multiply by dp_size as the R/H script does:

sample_processed = input_ids.shape[0] * dp_size
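
A shape sketch for context, using the settings from launch_script.sh (per_device_train_batch_size=32, max_seq_length=128) and assuming the 8 GPUs mentioned above — illustrative values, not measured output:

len(samples)                                   # 32 * 8 = 256, the global batch size
model_inputs["input_ids"].shape                # (256, 128) before shard()
shard(model_inputs.data)["input_ids"].shape    # (8, 32, 128): (devices, per-device batch, seq_length)

So sample_processed = len(samples) already counts the samples handled by every device in the step; multiplying by dp_size again would double-count.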

throughput = sample_processed / step_time # block_until_ready?
tokens_per_gpu = model_inputs["input_ids"].shape[1] * model_inputs["input_ids"].shape[2]

train_metrics.append(train_metric)

cur_step = epoch * (num_train_samples // train_batch_size) + step
@@ -815,6 +845,14 @@ def eval_step(params, batch):

train_metrics = []

# log throughput
# Based on the formula in https://developer.nvidia.com/blog/scaling-language-model-training-to-a-trillion-parameters-using-megatron/
tflops_per_gpu = 8 * num_params * tokens_per_gpu / step_time / 1e12
logger.info("(%ds), Batch %d Loss: %s, Speed: %s samples/sec, TFLOPS/GPU: %s" % (
int(time_elapsed), step, train_metric['loss'],
throughput, tflops_per_gpu))
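# Worked example of the formula above with hypothetical values (illustration only, not a measured result):
#   num_params ≈ 8e7 for this 6-layer config, tokens_per_gpu = 32 * 128 = 4096, step_time = 0.05 s
#   tflops_per_gpu = 8 * 8e7 * 4096 / 0.05 / 1e12 ≈ 52.4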


if cur_step % training_args.eval_steps == 0 and cur_step > 0:
# ======================== Evaluating ==============================
num_eval_samples = len(tokenized_datasets["validation"])
26 changes: 26 additions & 0 deletions examples/flax/language-modeling/tokenizer.py
@@ -0,0 +1,26 @@
from datasets import load_dataset
from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer

# load dataset
print("loading dataset... ")
dataset = load_dataset("oscar", "unshuffled_deduplicated_no", split="train")

print("done..")
# Instantiate tokenizer
tokenizer = ByteLevelBPETokenizer()

def batch_iterator(batch_size=1000):
for i in range(0, len(dataset), batch_size):
yield dataset[i: i + batch_size]["text"]

# Customized training
tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=2, special_tokens=[
"<s>",
"<pad>",
"</s>",
"<unk>",
"<mask>",
])

# Save files to disk
tokenizer.save("./roberta-base/tokenizer.json")
16 changes: 16 additions & 0 deletions examples/flax/playground/multi-node/launch.sh
@@ -0,0 +1,16 @@
#!/bin/bash

num_nodes=${1:-2}

SMP_USER=${2:-"erincho"}
CONTAINER_NAME=${3:-"smp"}
SOURCE_CODE_USER=${4:-"$SMP_USER"}


set -ex
# build mpi command
smprun $SMP_USER -n $num_nodes -v --mpi-path /opt/amazon/openmpi/bin/mpirun --notify-exit \
-c $CONTAINER_NAME \
-d /fsx/${SOURCE_CODE_USER}/examples/flax/playground/multi-node/ \
-x NCCL_DEBUG=INFO -x NCCL_PROTO=simple \
/opt/conda/bin/python run.py
17 changes: 17 additions & 0 deletions examples/flax/playground/multi-node/run.py
@@ -0,0 +1,17 @@
import jax
import jax.numpy as jnp

# no parameters needed when using OpenMPI
jax.distributed.initialize()

print("total devices: %s, devices per task: %s" % (jax.device_count(), jax.local_device_count()))

xs = jnp.ones(jax.local_device_count())

# Computes a reduction (sum) across all devices of x
# and broadcast the result, in y, to all devices.
# If x=[1] on all devices and we have 16 devices,
# the result is y=[16] on all devices.

y = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(xs)
print(y)
68 changes: 68 additions & 0 deletions examples/flax/playground/pjit-pmap/exp_speed.py
@@ -0,0 +1,68 @@
# script.py
import jax
from jax.experimental.pjit import pjit
from jax.experimental.maps import Mesh
from jax.experimental import PartitionSpec as P
import numpy as np
import argparse
import timeit


parser = argparse.ArgumentParser()
parser.add_argument('-m', '--mode', type=str, choices=['pmap', 'pjit'], default='pmap')
args = parser.parse_args()


# Init data
x = np.random.randn(32, 1024).astype(np.float32)
W = np.random.randn(1024, 8).astype(np.float32)


def step(x, W):
return jax.lax.dot(x, W)


# Compute pmap or pjit functions
# Preload batch data and model parameters onto the devices as ShardedDeviceArrays
if args.mode == 'pmap':
p_step = jax.pmap(step, axis_name='batch')
print("pmap mode. backend=%s, device_count=%s, local_device_count=%s" % (jax.lib.xla_bridge.get_backend(),
jax.device_count(),
jax.local_device_count()))
x = np.reshape(x, (jax.local_device_count(), -1, x.shape[1]))
print("x shape:", x.shape)

# Gets correct device order that matches pmap
devices = jax.lib.xla_bridge.get_backend().get_default_device_assignment(jax.device_count())
x = jax.device_put_sharded(list(x), devices)
W = jax.device_put_replicated(W, devices)
else:
# ===== DP only =====
mesh = Mesh(np.asarray(jax.devices(), dtype=object).reshape(jax.local_device_count(), ), ['dp'])
jax.experimental.maps.thread_resources.env = (
jax.experimental.maps.ResourceEnv(physical_mesh=mesh, loops=())
)
p_step = pjit(step, in_axis_resources=(P('dp'), None), out_axis_resources=P('dp'))

# Map batch and weights to devices
p_init = pjit(lambda x, W: (x, W), in_axis_resources=(P('dp'), None), out_axis_resources=(P('dp'), None))
x, W = p_init(x, W)

# ===== DP & MP =====
# mesh = Mesh(np.asarray(jax.devices(), dtype=object).reshape(jax.local_device_count(), 1), ['dp', 'mp'])
# jax.experimental.maps.thread_resources.env = (
# jax.experimental.maps.ResourceEnv(physical_mesh=mesh, loops=())
# )
# p_step = pjit(step, in_axis_resources=(P('dp'), P('mp', None)), out_axis_resources=P('dp'))
#
# # Map batch and weights to devices
# p_init = pjit(lambda x, W: (x, W), in_axis_resources=(P('dp'), P('mp', None)), out_axis_resources=(P('dp'), P('mp', None)))
# x, W = p_init(x, W)

# Warmup for initial compilation
p_step(x, W).block_until_ready()

# Time
iterations = 1000
avg = timeit.timeit(lambda: p_step(x, W).block_until_ready(), number=iterations) / iterations
print('Estimated Time:', avg, 'per itr')
1 change: 1 addition & 0 deletions examples/flax/playground/pjit-pmap/exp_speed_2.py
@@ -126,6 +126,7 @@ def print_model_size(params, name=''):
print_model_size(variables)

def step(x, variables):
# shaped array
return model.apply(variables, x)

# Compute pmap or pjit functions
30 changes: 30 additions & 0 deletions examples/pytorch/language-modeling/launch_script.sh
@@ -0,0 +1,30 @@
export TRAINING_DIR="/fsx/erincho/examples/pytorch/language-modeling"
export ARTIFACTS_DIR="/fsx/erincho/examples/pytorch/language-modeling/roberta-base"
#export TRAINING_DIR="./"

export NUM_GPUS=1
export TOKENIZERS_PARALLELISM=0
export MODEL_DIR="./roberta-base"
export MASTER_ADDR="compute-st-worker-60"
#mkdir -p ${MODEL_DIR}

python3 -m torch.distributed.launch --nproc_per_node ${NUM_GPUS} \
--rdzv_endpoint=$MASTER_ADDR:29400 \
--rdzv_id=100 \
--rdzv_backend=c10d ${TRAINING_DIR}/run_mlm_no_trainer.py \
--output_dir=$ARTIFACTS_DIR \
--model_type="roberta" \
--config_name=$ARTIFACTS_DIR \
--tokenizer_name="${MODEL_DIR}" \
--dataset_name="oscar" \
--dataset_config_name="unshuffled_deduplicated_no" \
--max_seq_length="128" \
--weight_decay="0.01" \
--per_device_train_batch_size="160" \
--per_device_eval_batch_size="160" \
--gradient_accumulation_steps="4" \
--learning_rate="3e-4" \
--num_warmup_steps="1000" \
--num_train_epochs="18" \
--logging_steps="10" \
--seed=42