Enabling data-parallel multi-GPU training (#1188)
* First pass over multi-GPU
* Multi-GPU test passes now locally (without metric calculations)
* Introducing MultiLoader & add auto-initialization of Model
* Enable multi-GPU with metric calculations
* Remove unused `to` method in ModelOutput
* fix test for cpu
* use multigpu marker
* automatically repartition if repartition is not provided
* test rank
* Add comment for follow-up tasks
* lint
* fix test for cpu

---------

Co-authored-by: edknv <[email protected]>
Co-authored-by: edknv <[email protected]>
1 parent c5afbd1, commit 145e592
Showing 12 changed files with 269 additions and 25 deletions.
@@ -62,5 +62,7 @@ markers = [
     "integration",
     "unit",
     "changed",
-    "unchanged"
+    "unchanged",
+    "singlegpu",
+    "multigpu"
 ]
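The two new markers make it possible to select GPU tests with `pytest -m multigpu` or `pytest -m singlegpu`. A minimal sketch of how a test suite might decide whether such a test can run on the current machine is shown here. The mapping `MARKER_MIN_GPUS` and the helper `skip_reason` are hypothetical names introduced for illustration and are not part of this commit; the GPU counts assigned to each marker are an assumption based on the `devices=2` setting used in the new test.

```python
# Hypothetical mapping from marker name to the minimum number of GPUs it needs.
# These values are an assumption, not taken from the repository.
MARKER_MIN_GPUS = {"singlegpu": 1, "multigpu": 2}


def skip_reason(marker_names, available_gpus):
    """Return a skip message if a marked test needs more GPUs than are
    available, otherwise None."""
    # Markers without an entry (e.g. "unit") are treated as needing no GPU.
    needed = max(
        (MARKER_MIN_GPUS.get(name, 0) for name in marker_names), default=0
    )
    if available_gpus < needed:
        return f"requires {needed} GPU(s), found {available_gpus}"
    return None
```

In a `conftest.py`, a `pytest_collection_modifyitems` hook could call this helper with the names from `item.iter_markers()` and a GPU count such as `torch.cuda.device_count()`, then attach `pytest.mark.skip(reason=...)` whenever a message is returned. Whether the repository does anything like this is not shown in the diff.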
Empty file.
@@ -0,0 +1,31 @@
+import pytest
+import pytorch_lightning as pl
+
+import merlin.models.torch as mm
+
+
+# TODO: This test is not complete because Lightning launches separate processes
+# under the hood with the correct environment variables like `LOCAL_RANK`, but
+# the pytest stays in the main process and tests only the LOCAL_RANK=0 case.
+# Follow-up with proper test that ensures dataloader is working properly with
+# e.g. global_rank > 0.
+class TestMultiGPU:
+    @pytest.mark.multigpu
+    def test_multi_gpu(self, music_streaming_data):
+        schema = music_streaming_data.schema
+        data = music_streaming_data
+        model = mm.Model(
+            mm.TabularInputBlock(schema, init="defaults"),
+            mm.MLPBlock([5]),
+            mm.BinaryOutput(schema["click"]),
+        )
+
+        trainer = pl.Trainer(max_epochs=3, devices=2)
+        multi_loader = mm.MultiLoader(data, batch_size=2)
+        trainer.fit(model, multi_loader)
+
+        # 100 rows total / 2 devices -> 50 rows per device
+        # 50 rows / 2 per batch -> 25 steps per device
+        assert trainer.num_training_batches == 25
+
+        assert trainer.global_rank == 0  # This should fail for node 1.
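The arithmetic in the test comments, and the property the TODO asks a follow-up test to cover for ranks other than 0, can be sketched in plain Python. This is only an illustration under stated assumptions: the strided split below is a stand-in, not how `MultiLoader` actually partitions data (the commit message mentions repartitioning), and `shard_indices`, `steps_per_device` and `shards_cover_all_rows` are hypothetical helper names.

```python
import math


def shard_indices(num_rows, rank, world_size):
    """Row indices one rank would read under an assumed strided split."""
    return list(range(rank, num_rows, world_size))


def steps_per_device(num_rows, world_size, batch_size, rank=0):
    """Batches per epoch on one device: rows in its shard / batch size,
    rounded up so a final partial batch still counts as a step."""
    rows_on_device = len(shard_indices(num_rows, rank, world_size))
    return math.ceil(rows_on_device / batch_size)


def shards_cover_all_rows(num_rows, world_size):
    """True when the per-rank shards are disjoint and together contain
    every row exactly once."""
    seen = []
    for rank in range(world_size):
        seen.extend(shard_indices(num_rows, rank, world_size))
    return sorted(seen) == list(range(num_rows))
```

With the numbers from the test, `steps_per_device(100, 2, 2)` gives 25 for rank 0 and also for rank 1, matching the `num_training_batches == 25` assertion. A check like `shards_cover_all_rows(100, 2)` is the kind of property a multi-process test would need to verify against the real loader on every rank.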