[checkpointio]support asyncio for all models #6152
base: main
ce61035 to 7fc7c60
772f164 to c30bede
7b83bb5 to 8b1f649
for more information, see https://pre-commit.ci
b77ebe9 to 83dc8be
@@ -216,15 +244,31 @@ def save_sharded_model(

# Then collect the sharded parameters & buffers along tp_group.
# Only devices with tp_rank == 0 are responsible for model saving.
state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
control_saving = self.tp_rank == 0
Should this also check sp_rank == 0?
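To make the reviewer's question concrete, here is a minimal sketch of what gating the save on both ranks could look like. The helper name `should_control_saving` and its `sp_rank` argument are hypothetical and do not come from the PR; whether sequence-parallel ranks actually hold duplicate shards depends on the plugin configuration, so treat this as an illustration rather than the intended fix.

```python
def should_control_saving(tp_rank: int, sp_rank: int = 0) -> bool:
    """Hypothetical helper: True only for the rank that should write the checkpoint.

    Assumes that ranks sharing the same shard differ only in their
    tensor-parallel and sequence-parallel coordinates.
    """
    return tp_rank == 0 and sp_rank == 0


# Enumerate a toy 2x2 (tp_rank, sp_rank) grid and collect the ranks that would write.
writers = [
    (tp, sp)
    for tp in range(2)
    for sp in range(2)
    if should_control_saving(tp, sp)
]
print(writers)
```

With a check on tp_rank alone, both (0, 0) and (0, 1) in this toy grid would pass the gate, which is the duplicate write the question is probing for.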
# exception for fsdp, part[1] isn't param_id
idx = parts[1]
Revert this now?
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
model = resnet18()
criterion = lambda x: x.mean()
- optimizer = SGD((model.parameters()), lr=0.001)
+ optimizer = SGD((model.parameters()), lr=0.001, momentum=0.5)
keep momentum=0 to test corner case
📌 Checklist before creating the PR
[doc/gemini/tensor/...]: A concise description
pip install pre-commit && pre-commit install
🚨 Issue number
📝 What does this PR do?
💥 Checklist before requesting a review