JAX exp scripts #1
base: main
@@ -0,0 +1,7 @@
from transformers import RobertaConfig

# config = RobertaConfig.from_pretrained("roberta-base", vocab_size=50265)
config = RobertaConfig.from_pretrained("klue/roberta-small", vocab_size=50265)


config.save_pretrained("./roberta-base")
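A minimal sketch (not part of the PR) of how this saved directory is consumed later: the Flax launch script below passes --config_name=$ARTIFACTS_DIR, which presumably points at this same ./roberta-base directory when the script is run from the training directory, and is equivalent to reloading the config locally.

    from transformers import RobertaConfig

    # Reload the config saved above; --config_name=$ARTIFACTS_DIR resolves to this directory.
    config = RobertaConfig.from_pretrained("./roberta-base")
    print(config.num_hidden_layers, config.vocab_size)  # 6 50265 for the config.json checked in below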
@@ -0,0 +1,23 @@
export TRAINING_DIR="/fsx/erincho/examples/flax/language-modeling"
export ARTIFACTS_DIR="/fsx/erincho/examples/flax/language-modeling/roberta-base"
#export TRAINING_DIR="./"

python ${TRAINING_DIR}/run_mlm_flax.py \
    --output_dir=$ARTIFACTS_DIR \
    --model_type="roberta" \
    --config_name=$ARTIFACTS_DIR \
    --dataset_name="oscar" \
    --dataset_config_name="unshuffled_deduplicated_no" \
    --max_seq_length="128" \
    --weight_decay="0.01" \
    --per_device_train_batch_size="32" \
    --per_device_eval_batch_size="32" \
    --learning_rate="3e-4" \
    --warmup_steps="1000" \
    --overwrite_output_dir \
    --num_train_epochs="18" \
    --adam_beta1="0.9" \
    --adam_beta2="0.98" \
    --logging_steps="10" \
    --save_steps="2500" \
    --eval_steps="2500"
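For reference, the effective global batch size the Flax script derives from these flags scales with the number of devices; the device count below is an assumption, not something fixed by this script.

    per_device_train_batch_size = 32             # from --per_device_train_batch_size above
    device_count = 8                              # assumption: depends on the host(s)
    train_batch_size = per_device_train_batch_size * device_count   # 256 = global batch size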
@@ -0,0 +1,27 @@
{
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 6,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "tokenizer_class": "BertTokenizer",
  "transformers_version": "4.26.0.dev0",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 50265
}
@@ -66,6 +66,11 @@
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


def compute_num_params(model):
    """Get num params."""
    # https://github.com/google/jax/discussions/6153
    return sum(x.size for x in jax.tree_leaves(model.params))
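# Note (not in the original diff): newer JAX releases deprecate jax.tree_leaves
# in favor of jax.tree_util.tree_leaves, so the call above may need updating
# when JAX is upgraded.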


@dataclass
class TrainingArguments:
    output_dir: str = field(
@@ -507,9 +512,11 @@ def main():
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name or True:
        logging.info("use default tokenizer")
        tokenizer = AutoTokenizer.from_pretrained(
            # model_args.tokenizer_name,
            "klue/roberta-small",
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            use_auth_token=True if model_args.use_auth_token else None,
@@ -527,7 +534,10 @@ def main():
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    # Preprocessing the datasets.
    # Only tokenize a subset of the data, as we don't need to run until convergence
    datasets["train"] = datasets["train"].select(range(10000))
    datasets["validation"] = datasets["validation"].select(range(1000))

    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = datasets["train"].column_names
@@ -631,6 +641,7 @@ def group_texts(examples):
    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())
    logger.info("local_device_count: %s", jax.local_device_count())

    if model_args.model_name_or_path:
        model = FlaxAutoModelForMaskedLM.from_pretrained(
@@ -647,16 +658,20 @@ def group_texts(examples):
        dtype=getattr(jnp, model_args.dtype),
    )

    num_params = compute_num_params(model)

    if training_args.gradient_checkpointing:
        model.enable_gradient_checkpointing()

    # Store some constants
    num_epochs = int(training_args.num_train_epochs)
    # GBS = global batch size (per-device batch size * device count)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
    eval_batch_size = per_device_eval_batch_size * jax.device_count()

    num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
    logger.info("per_device_train_batch_size=%s, train_batch_size=%s, eval_batch_size=%s" % (training_args.per_device_train_batch_size, train_batch_size, eval_batch_size))

    # Create learning rate schedule
    warmup_fn = optax.linear_schedule(
@@ -773,6 +788,7 @@ def eval_step(params, batch):
    # Replicate the train state on each device
    state = jax_utils.replicate(state)

    start = time.time()
    train_time = 0
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    for epoch in epochs:
@@ -792,11 +808,25 @@ def eval_step(params, batch):
        # Gather the indexes for creating the batch and do a training step
        for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
            samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]

            # (GBS, seq_length)
            model_inputs = data_collator(samples, pad_to_multiple_of=16)
            step_start = time.time()
            # Model forward
            # (# of GPUs, per-GPU BS, seq_length)
            model_inputs = shard(model_inputs.data)

            # state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
            # Use this line for benchmarking
            state, train_metric, dropout_rngs = jax.block_until_ready(p_train_step(state, model_inputs, dropout_rngs))

            # Calculate throughput
            time_elapsed = time.time() - start
            step_time = time.time() - step_start
            sample_processed = len(samples)  # same as GBS
            throughput = sample_processed / step_time  # block_until_ready?
            tokens_per_gpu = model_inputs["input_ids"].shape[1] * model_inputs["input_ids"].shape[2]

            train_metrics.append(train_metric)

            cur_step = epoch * (num_train_samples // train_batch_size) + step
|
@@ -815,6 +845,14 @@ def eval_step(params, batch): | |
|
||
train_metrics = [] | ||
|
||
# log throughput | ||
# Based on the formula in https://developer.nvidia.com/blog/scaling-language-model-training-to-a-trillion-parameters-using-megatron/ | ||
tflops_per_gpu = 8 * num_params * tokens_per_gpu / step_time / 1e12 | ||
logger.info("(%ds), Batch %d Loss: %s, Speed: %s samples/sec, TFLOPS/GPU: %s" % ( | ||
int(time_elapsed), step, train_metric['loss'], | ||
throughput, tflops_per_gpu)) | ||
|
||
|
||
if cur_step % training_args.eval_steps == 0 and cur_step > 0: | ||
# ======================== Evaluating ============================== | ||
num_eval_samples = len(tokenized_datasets["validation"]) | ||
|
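For intuition on the TFLOPS estimate above (my reading of the Megatron-style approximation, not something stated in the PR): a dense transformer needs roughly 2 * num_params FLOPs per token for the forward pass and about 4 * num_params for the backward pass; the blog's factor of 8 additionally assumes activation recomputation (one extra forward pass), otherwise 6 is the usual multiplier. Plugging in illustrative numbers:

    # Hypothetical values, for illustration only (not measured in this PR)
    num_params = 83_000_000          # roughly roberta-small scale (assumption)
    tokens_per_gpu = 32 * 128        # per-device batch size * max_seq_length
    step_time = 0.05                 # seconds per step (assumption)
    tflops_per_gpu = 8 * num_params * tokens_per_gpu / step_time / 1e12
    print(round(tflops_per_gpu, 1))  # ~54.4 TFLOPS/GPU under these assumptions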
@@ -0,0 +1,26 @@
from datasets import load_dataset
from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer

# load dataset
print("loading dataset... ")
dataset = load_dataset("oscar", "unshuffled_deduplicated_no", split="train")

print("done..")
# Instantiate tokenizer
tokenizer = ByteLevelBPETokenizer()

def batch_iterator(batch_size=1000):
    for i in range(0, len(dataset), batch_size):
        yield dataset[i: i + batch_size]["text"]

# Customized training
tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=2, special_tokens=[
    "<s>",
    "<pad>",
    "</s>",
    "<unk>",
    "<mask>",
])

# Save files to disk
tokenizer.save("./roberta-base/tokenizer.json")
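As a quick sanity check (an illustrative sketch, not part of the PR), the saved tokenizer.json can be wrapped for use with transformers; the special-token arguments below mirror the list used during training.

    from transformers import PreTrainedTokenizerFast

    # Load the trained tokenizer file produced above
    tokenizer = PreTrainedTokenizerFast(
        tokenizer_file="./roberta-base/tokenizer.json",
        bos_token="<s>", eos_token="</s>", unk_token="<unk>",
        pad_token="<pad>", mask_token="<mask>",
    )
    print(tokenizer("hello world")["input_ids"])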
@@ -0,0 +1,16 @@
#!/bin/bash

num_nodes=${1:-2}

SMP_USER=${2:-"erincho"}
CONTAINER_NAME=${3:-"smp"}
SOURCE_CODE_USER=${4:-"$SMP_USER"}


set -ex
# build mpi command
smprun $SMP_USER -n $num_nodes -v --mpi-path /opt/amazon/openmpi/bin/mpirun --notify-exit \
    -c $CONTAINER_NAME \
    -d /fsx/${SOURCE_CODE_USER}/examples/flax/playground/multi-node/ \
    -x NCCL_DEBUG=INFO -x NCCL_PROTO=simple \
    /opt/conda/bin/python run.py \
@@ -0,0 +1,17 @@
import jax
import jax.numpy as jnp

# No need to pass params if using OpenMPI
jax.distributed.initialize()

print("total devices: %s, devices per task: %s" % (jax.device_count(), jax.local_device_count()))

xs = jnp.ones(jax.local_device_count())

# Computes a reduction (sum) across all devices of x
# and broadcasts the result, in y, to all devices.
# If x=[1] on all devices and we have 16 devices,
# the result is y=[16] on all devices.

y = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(xs)
print(y)
@@ -0,0 +1,68 @@
# script.py
import jax
from jax.experimental.pjit import pjit
from jax.experimental.maps import Mesh
from jax.experimental import PartitionSpec as P
import numpy as np
import argparse
import timeit


parser = argparse.ArgumentParser()
parser.add_argument('-m', '--mode', type=str, choices=['pmap', 'pjit'], default='pmap')
args = parser.parse_args()


# Init data
x = np.random.randn(32, 1024).astype(np.float32)
W = np.random.randn(1024, 8).astype(np.float32)


def step(x, W):
    return jax.lax.dot(x, W)


# Compute pmap or pjit functions
# Preload batch data and model parameters onto the devices as ShardedDeviceArrays
if args.mode == 'pmap':
    p_step = jax.pmap(step, axis_name='batch')
    print("pmap mode. backend=%s, device_count=%s, local_device_count=%s" % (jax.lib.xla_bridge.get_backend(),
                                                                             jax.device_count(),
                                                                             jax.local_device_count()))
    x = np.reshape(x, (jax.local_device_count(), -1, x.shape[1]))
    print("x shape:", x.shape)

    # Gets correct device order that matches pmap
    devices = jax.lib.xla_bridge.get_backend().get_default_device_assignment(jax.device_count())
    x = jax.device_put_sharded(list(x), devices)
    W = jax.device_put_replicated(W, devices)
else:
    # ===== DP only =====
    mesh = Mesh(np.asarray(jax.devices(), dtype=object).reshape(jax.local_device_count(), ), ['dp'])
    jax.experimental.maps.thread_resources.env = (
        jax.experimental.maps.ResourceEnv(physical_mesh=mesh, loops=())
    )
    p_step = pjit(step, in_axis_resources=(P('dp'), None), out_axis_resources=P('dp'))

    # Map batch and weights to devices
    p_init = pjit(lambda x, W: (x, W), in_axis_resources=(P('dp'), None), out_axis_resources=(P('dp'), None))
    x, W = p_init(x, W)

    # ===== DP & MP =====
    # mesh = Mesh(np.asarray(jax.devices(), dtype=object).reshape(jax.local_device_count(), 1), ['dp', 'mp'])
    # jax.experimental.maps.thread_resources.env = (
    #     jax.experimental.maps.ResourceEnv(physical_mesh=mesh, loops=())
    # )
    # p_step = pjit(step, in_axis_resources=(P('dp'), P('mp', None)), out_axis_resources=P('dp'))
    #
    # # Map batch and weights to devices
    # p_init = pjit(lambda x, W: (x, W), in_axis_resources=(P('dp'), P('mp', None)), out_axis_resources=(P('dp'), P('mp', None)))
    # x, W = p_init(x, W)

# Warmup for initial compilation
p_step(x, W).block_until_ready()

# Time
iterations = 1000
avg = timeit.timeit(lambda: p_step(x, W).block_until_ready(), number=iterations) / iterations
print('Estimated Time:', avg, 'per itr')
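Side note: jax.experimental.maps and jax.experimental.pjit have since been folded into core JAX. A rough modern equivalent of the DP-only branch above (an assumption about newer JAX, roughly 0.4+, not part of this PR) looks like:

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # 1-D device mesh with a single 'dp' (data-parallel) axis
    mesh = Mesh(np.asarray(jax.devices()), axis_names=('dp',))

    x = jax.device_put(np.random.randn(32, 1024).astype(np.float32),
                       NamedSharding(mesh, P('dp')))   # shard the batch over 'dp'
    W = jax.device_put(np.random.randn(1024, 8).astype(np.float32),
                       NamedSharding(mesh, P()))       # replicate the weights

    step = jax.jit(lambda x, W: x @ W)                 # jit propagates input shardings
    y = step(x, W).block_until_ready()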
@@ -0,0 +1,30 @@
export TRAINING_DIR="/fsx/erincho/examples/pytorch/language-modeling"
export ARTIFACTS_DIR="/fsx/erincho/examples/pytorch/language-modeling/roberta-base"
#export TRAINING_DIR="./"

export NUM_GPUS=1
export TOKENIZERS_PARALLELISM=0
export MODEL_DIR="./roberta-base"
export MASTER_ADDR="compute-st-worker-60"
#mkdir -p ${MODEL_DIR}

python3 -m torch.distributed.launch --nproc_per_node ${NUM_GPUS} \
    --rdzv_endpoint=$MASTER_ADDR:29400 \
    --rdzv_id=100 \
    --rdzv_backend=c10d ${TRAINING_DIR}/run_mlm_no_trainer.py \
    --output_dir=$ARTIFACTS_DIR \
    --model_type="roberta" \
    --config_name=$ARTIFACTS_DIR \
    --tokenizer_name="${MODEL_DIR}" \
    --dataset_name="oscar" \
    --dataset_config_name="unshuffled_deduplicated_no" \
    --max_seq_length="128" \
    --weight_decay="0.01" \
    --per_device_train_batch_size="160" \
    --per_device_eval_batch_size="160" \
    --gradient_accumulation="4" \
    --learning_rate="3e-4" \
    --num_warmup_steps="1000" \
    --num_train_epochs="18" \
    --logging_steps="10" \
    --seed=42
Review comment: pmapped functions in JAX run asynchronously, so we need to call block_until_ready to make sure a particular computation has actually finished.
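To illustrate that point (a minimal sketch, not code from this PR): JAX dispatches device computations asynchronously, so timing the call alone mostly measures dispatch overhead, while blocking on the result measures the actual computation.

    import time
    import jax
    import jax.numpy as jnp

    x = jnp.ones((4096, 4096))
    f = jax.jit(lambda a: a @ a)
    f(x).block_until_ready()              # warm up / compile once

    t0 = time.time()
    y = f(x)                              # returns almost immediately (async dispatch)
    print("dispatch only:", time.time() - t0)

    t0 = time.time()
    y = f(x)
    y.block_until_ready()                 # wait for the device computation to finish
    print("with block_until_ready:", time.time() - t0)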