Add example for OPT model with distribution. #1727
base: master
Conversation
Thanks for the PR!
```python
support in the coming future.
"""

import os
```
Please group the imports at the top.
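For example, something along these lines (the `keras_nlp` import is my assumption based on the example's topic, not necessarily the exact set the script needs):

```python
# All imports grouped at the top of the example.
import os

import keras
import keras_nlp
```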
```python
print(keras.version())
print(keras.backend.backend())

keras.mixed_precision.set_global_policy("mixed_float16")
```
Please add a comment about using mixed precision.
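Something brief would do, e.g. (the exact wording is just a sketch):

```python
# Run computations in float16 while keeping variables in float32. This roughly
# halves activation memory and speeds up inference on recent GPUs.
keras.mixed_precision.set_global_policy("mixed_float16")
```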
```python
count other items like optimizer states, as well as the forward and backward
passes.
"""
# model_spec = 'opt_6.7b_en'
# language_model = create_opt_model(model_spec)
```
Should this be uncommented?
```python
# Create a 2D mesh for model parallelism; change the mesh shape to tune the
# ratio of data/model parallelism.
_BATCH_DIM_NAME = "batch"
```
No need for the leading underscores.
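For example, renamed and shown in context; this is a sketch assuming the Keras 3 `keras.distribution` API, and `MODEL_DIM_NAME` plus the `(1, 8)` shape are illustrative:

```python
BATCH_DIM_NAME = "batch"
MODEL_DIM_NAME = "model"

# Create a 2D mesh for model parallelism; change the mesh shape, e.g. to
# (2, 4), to tune the ratio of data/model parallelism.
device_mesh = keras.distribution.DeviceMesh(
    shape=(1, 8),
    axis_names=[BATCH_DIM_NAME, MODEL_DIM_NAME],
    devices=keras.distribution.list_devices(),
)
```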
```python
generate function with XLA. The follow-up runs will be much faster.
"""
prompt = "What is machine learning?"
print(large_model.generate(prompt))
```
Please add a second prompt, possibly with some `time()` calls, to demonstrate the regular (post-compilation) step time.
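Something like the following (the second prompt and the `time.time()` bookkeeping are illustrative):

```python
import time

prompt = "What is machine learning?"

# The first call compiles the generate function with XLA, so it is slow.
start = time.time()
print(large_model.generate(prompt))
print(f"First (compiling) call: {time.time() - start:.2f}s")

# Follow-up calls reuse the compiled function and show the regular step time.
start = time.time()
print(large_model.generate("What is deep learning?"))
print(f"Second (post-compilation) call: {time.time() - start:.2f}s")
```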
""" | ||
## Introduction to KerasNLP |
I suggest focusing the example purely on the distribution aspects, so we can replace the KerasNLP intro with ~1 sentence. Meanwhile maybe we could flesh out the distribution part, e.g. include fine-tuning or other inference performance considerations.
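To sketch what the fleshed-out distribution section could cover, assuming the Keras 3 `keras.distribution` API (the variable-path regexes and layout rules below are illustrative, not taken from this PR):

```python
# Shard the large weight matrices over the "model" mesh axis and replicate
# everything else. LayoutMap keys are regexes matched against variable paths.
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = (None, "model")
layout_map[".*(query|key|value)/kernel"] = (None, "model", None)

model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="batch"
)
keras.distribution.set_distribution(model_parallel)
```

Once the distribution is set, both `generate()` and `fit()` run sharded, so the same setup could demonstrate fine-tuning as well as inference timing.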
@fchollet, this is the draft of the OPT model inference example with the Keras distribution API (I can add the fine-tuning parts later). Do you still have the instructions for converting the .py example to a Colab/MD file?
This one requires 8 V100 GPUs to run, and I think we would need A100s to properly simulate a fine-tuning workflow.