
JAX exp scripts #1

Open · wants to merge 2 commits into main
Conversation

@hchings (Owner) commented Feb 7, 2023

Temporary PR for easier code review.


# state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
# Use this line for benchmark
state, train_metric, dropout_rngs = jax.block_until_ready(p_train_step(state, model_inputs, dropout_rngs))

pmapped functions in JAX dispatch asynchronously, so we need to call block_until_ready to make sure a particular computation has actually finished before measuring time.
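A minimal sketch of that timing pattern (using `jax.jit` on a single device for portability; `jax.pmap` dispatches asynchronously in the same way, and the shapes here are hypothetical):

```python
import time
import jax
import jax.numpy as jnp

@jax.jit  # jit-compiled calls, like pmapped ones, return before the device finishes
def step(x):
    return jnp.sin(x) @ jnp.cos(x).T

x = jnp.ones((256, 256))
step(x)  # warm-up call so compilation time is excluded from the measurement

start = time.time()
out = jax.block_until_ready(step(x))  # wait until the result is actually computed
elapsed = time.time() - start
```

Without `block_until_ready`, `elapsed` would only measure dispatch time, not the actual device computation.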

# calculate throughput
time_elapsed = time.time() - start
step_time = time.time() - step_start
sample_processed = len(samples) # same as GBS
@hchings Feb 7, 2023

To my understanding, len(samples) here already counts all samples processed across all GPUs, given how JAX works (L817 shards the data across 8 GPUs, and each process then "sees" only its local inputs and outputs inside the parallelized functions). So we don't have to multiply by dp_size the way the R/H script does:

sample_processed = input_ids.shape[0] * dp_size
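A toy illustration of that sharding convention (shapes and names are hypothetical; `pmap` consumes arrays whose leading axis is the device count):

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()
global_batch = 8 * n_devices  # hypothetical global batch size (GBS)
samples = jnp.arange(global_batch * 4).reshape(global_batch, 4)

# Typical pmap input prep: split the leading axis into (n_devices, per_device_batch).
model_inputs = samples.reshape(n_devices, -1, samples.shape[1])

# len(samples) already counts every sample across all devices,
# so no extra multiplication by the data-parallel size is needed.
sample_processed = len(samples)
```

Each device sees only its `(per_device_batch, 4)` slice inside the pmapped function, but the pre-shard `len(samples)` is the global count.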

1 participant