Enabling data-parallel multi-GPU training (#1188)
* First pass over multi-GPU

* Multi-gpu test passes now locally (without metric calculations)

* Introducing MultiLoader & add auto-initialization of Model

* Enable multi-GPU with metric-calculations

* Remove unused `to` method in ModelOutput

* fix test for cpu

* use multigpu marker

* automatically repartition if repartition is not provided

* test rank

* Add comment for follow up tasks

* lint

* fix test for cpu

---------

Co-authored-by: edknv <[email protected]>
Co-authored-by: edknv <[email protected]>
3 people authored Jul 10, 2023
1 parent c5afbd1 commit 145e592
Showing 12 changed files with 269 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cpu-horovod.yml
@@ -72,4 +72,4 @@ jobs:
if [[ "${{ github.ref }}" != 'refs/heads/main' ]]; then
extra_pytest_markers="and changed"
fi
EXTRA_PYTEST_MARKERS="$extra_pytest_markers" MERLIN_BRANCH="$merlin_branch" COMPARE_BRANCH=${{ github.base_ref }} tox -e horovod-cpu
PYTEST_MARKERS="$extra_pytest_markers" MERLIN_BRANCH="$merlin_branch" COMPARE_BRANCH=${{ github.base_ref }} tox -e horovod-cpu
45 changes: 44 additions & 1 deletion .github/workflows/gpu-multi.yml
@@ -56,4 +56,47 @@ jobs:
if [[ "${{ github.ref }}" != 'refs/heads/main' ]]; then
extra_pytest_markers="and changed"
fi
cd ${{ github.workspace }}; EXTRA_PYTEST_MARKERS=$extra_pytest_markers MERLIN_BRANCH=$branch COMPARE_BRANCH=${{ github.base_ref }} tox -e multi-gpu
cd ${{ github.workspace }}; PYTEST_MARKERS="multigpu $extra_pytest_markers" MERLIN_BRANCH=$branch COMPARE_BRANCH=${{ github.base_ref }} tox -e gpu,horovod-gpu
check-changes-torch:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: 3.8
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install GitPython
pip install . --no-deps
- name: Get changed backends
id: backend_check
run: |
echo "changed=$(python ci/get_changed_backends.py --backend torch --branch ${{github.base_ref}})" >> "$GITHUB_OUTPUT"
outputs:
needs_testing: ${{ steps.backend_check.outputs.changed }}

torch:
needs: check-changes-torch
if: ${{needs.check-changes-torch.outputs.needs_testing == 'true' || github.ref == 'refs/heads/main'}}
runs-on: 2GPU

steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Run tests
run: |
ref_type=${{ github.ref_type }}
branch=main
if [[ $ref_type == "tag"* ]]
then
git -c protocol.version=2 fetch --no-tags --prune --progress --no-recurse-submodules --depth=1 origin +refs/heads/release*:refs/remotes/origin/release*
branch=$(git branch -r --contains ${{ github.ref_name }} --list '*release*' --format "%(refname:short)" | sed -e 's/^origin\///')
fi
if [[ "${{ github.ref }}" != 'refs/heads/main' ]]; then
extra_pytest_markers="and changed"
fi
cd ${{ github.workspace }}; PYTEST_MARKERS="multigpu $extra_pytest_markers" MERLIN_BRANCH=$branch COMPARE_BRANCH=${{ github.base_ref }} tox -e gpu
2 changes: 1 addition & 1 deletion .github/workflows/gpu.yml
@@ -34,7 +34,7 @@ jobs:
if [[ "${{ github.ref }}" != 'refs/heads/main' ]]; then
extra_pytest_markers="and changed"
fi
cd ${{ github.workspace }}; PYTEST_MARKERS="unit and not (examples or integration or notebook) $extra_pytest_markers" MERLIN_BRANCH=$branch COMPARE_BRANCH=${{ github.base_ref }} tox -e gpu
cd ${{ github.workspace }}; PYTEST_MARKERS="unit and not (examples or integration or notebook) and (singlegpu or not multigpu) $extra_pytest_markers" MERLIN_BRANCH=$branch COMPARE_BRANCH=${{ github.base_ref }} tox -e gpu
tests-examples:
runs-on: 1GPU
3 changes: 2 additions & 1 deletion merlin/models/torch/__init__.py
@@ -22,7 +22,7 @@
from merlin.models.torch.inputs.embedding import EmbeddingTable, EmbeddingTables
from merlin.models.torch.inputs.select import SelectFeatures, SelectKeys
from merlin.models.torch.inputs.tabular import TabularInputBlock
from merlin.models.torch.models.base import Model
from merlin.models.torch.models.base import Model, MultiLoader
from merlin.models.torch.models.ranking import DCNModel, DLRMModel
from merlin.models.torch.outputs.base import ModelOutput
from merlin.models.torch.outputs.classification import (
@@ -48,6 +48,7 @@
"DLRMBlock",
"MLPBlock",
"Model",
"MultiLoader",
"EmbeddingTable",
"EmbeddingTables",
"ParallelBlock",
165 changes: 155 additions & 10 deletions merlin/models/torch/models/base.py
@@ -13,10 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from typing import Dict, List, Optional, Sequence, Union

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning import LightningDataModule, LightningModule
from torch import nn

from merlin.dataloader.torch import Loader
@@ -53,16 +54,10 @@ class Model(LightningModule, Block):
... BinaryOutput(schema.select_by_tag(Tags.TARGET).first),
... )
... trainer = Trainer(max_epochs=1)
... with Loader(dataset, batch_size=16) as loader:
... model.initialize(loader)
... trainer.fit(model, loader)
... trainer.fit(model, Loader(dataset, batch_size=16))
"""

def __init__(
self,
*blocks: nn.Module,
optimizer=torch.optim.Adam,
):
def __init__(self, *blocks: nn.Module, optimizer=torch.optim.Adam, initialization="auto"):
super().__init__()

# Copied from BlockContainer.__init__
@@ -71,6 +66,7 @@ def __init__(
self.values.append(self.wrap_module(module))

self.optimizer = optimizer
self.initialization = initialization

def initialize(self, data: Union[Dataset, Loader, Batch]):
"""Initializes the model based on a given data set."""
@@ -96,7 +92,9 @@ def training_step(self, batch, batch_idx):

predictions = self(features, batch=Batch(features, targets))

loss_and_metrics = compute_loss(predictions, targets, self.model_outputs())
loss_and_metrics = compute_loss(
predictions, targets, self.model_outputs(), compute_metrics=True
)
for name, value in loss_and_metrics.items():
self.log(f"train_{name}", value)

@@ -137,6 +135,150 @@ def last(self) -> nn.Module:
"""Returns the last block in the model."""
return self.values[-1]

def setup(self, stage):
"""Initialize the model if `initialization="auto"`."""
if self.initialization == "auto":
loop = getattr(self.trainer, f"{stage}_loop")

data_instance = loop._data_source.instance
if isinstance(data_instance, MultiLoader):
self.initialize(data_instance.batch.to(None, device=self.device))
else:
dataloader = loop._data_source.dataloader()
if isinstance(dataloader, Loader):
self.initialize(dataloader)
else:
raise ValueError(
f"Can't auto-initialize from a non-merlin dataloader, got: {dataloader}",
"Please initialize the model manually with `model.initialize(batch)`",
)

def teardown(self, stage: str) -> None:
"""Teardown the data-loader after training."""
loop = getattr(self.trainer, f"{stage}_loop")
dataloader = loop._data_source.dataloader()
if isinstance(dataloader, Loader):
dataloader.stop()


class MultiLoader(LightningDataModule):
"""
Data Module for handling multiple types of data loaders. It facilitates the usage
of multiple datasets, as well as distributed training on multiple GPUs.
This class is particularly useful in scenarios where you have separate train,
validation and test datasets, and you want to use PyTorch Lightning's Trainer
which requires a single DataModule.
Parameters
----------
train : Union[Dataset, Loader]
Training dataset or data loader.
valid : Optional[Union[Dataset, Loader]], optional
Validation dataset or data loader, by default None
test : Optional[Union[Dataset, Loader]], optional
Test dataset or data loader, by default None
repartition : int, optional
Number of partitions to divide the dataset into, by default None
batch_size : int, optional
Number of data points per batch, by default 1024
Example usage for multi-GPU::
model = mm.Model(...)
train, valid = generate_data(...)
model.initialize(train)
trainer = pl.Trainer(max_epochs=5, devices=[0, 1])
trainer.fit(model, mm.MultiLoader(train, valid, batch_size=1024, repartition=4))
"""

def __init__(
self,
train: Union[Dataset, Loader],
valid: Optional[Union[Dataset, Loader]] = None,
test: Optional[Union[Dataset, Loader]] = None,
batch_size: int = 1024,
repartition: Optional[int] = None,
):
super().__init__()
self.repartition = repartition
self.train = train
self.batch_size = batch_size
self.batch = Batch.sample_from(train, batch_size=1, shuffle=False)
if valid:
self.val_dataloader = lambda: self._create_loader(valid, "valid")
if test:
self.test_dataloader = lambda: self._create_loader(test, "test")

def train_dataloader(self) -> Loader:
return self._create_loader(self.train, "train")

def _create_loader(self, data: Union[Dataset, Loader], name: str) -> Loader:
"""
Create a data loader with the right arguments.
Parameters
----------
data : Union[Dataset, Loader]
The input data, can be a dataset or data loader.
name : str
Name of the data loader.
Returns
-------
Loader
The created data loader.
"""

_dataset = data.dataset if isinstance(data, Loader) else data

has_world_size = "WORLD_SIZE" in os.environ

if self.repartition:
npartitions = self.repartition
elif has_world_size:
npartitions = int(os.environ["WORLD_SIZE"])
elif isinstance(data, Loader):
npartitions = data.global_size
else:
npartitions = None

if npartitions:
_dataset = _dataset.repartition(npartitions=npartitions)

if isinstance(data, Loader):
output = Loader(
_dataset,
batch_size=data.batch_size,
shuffle=data.shuffle,
drop_last=int(os.environ["WORLD_SIZE"]) > 1 if has_world_size else data.drop_last,
global_size=int(os.environ["WORLD_SIZE"]) if has_world_size else data.global_size,
global_rank=int(os.environ["LOCAL_RANK"]) if has_world_size else data.global_rank,
transforms=data.transforms,
)
else:
output = Loader(
_dataset,
batch_size=self.batch_size,
drop_last=int(os.environ["WORLD_SIZE"]) > 1 if has_world_size else False,
global_size=int(os.environ["WORLD_SIZE"]) if has_world_size else None,
global_rank=int(os.environ["LOCAL_RANK"]) if has_world_size else None,
)

setattr(self, f"loader_{name}", output)
return output

def teardown(self, stage):
"""
Stop all data loaders.
"""
for attr in dir(self):
if attr.startswith("loader"):
if hasattr(getattr(self, attr), "stop"):
getattr(self, attr).stop()
delattr(self, attr)


def compute_loss(
predictions: Union[torch.Tensor, Dict[str, torch.Tensor]],
@@ -229,5 +371,8 @@ def compute_loss(

for metric in model_out.metrics:
metric_name = camelcase_to_snakecase(metric.__class__.__name__)
if not metric.device or metric.device != _predictions.device:
metric = metric.to(_predictions.device)

results[metric_name] = metric(_predictions, _targets)
return results
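For orientation, the pieces above compose as follows. The sketch below is assembled from the `Model` and `MultiLoader` docstrings and the new multi-GPU test in this diff; the `generate_data` import, the `"e-commerce"` dataset name, and the layer sizes are illustrative assumptions rather than part of this commit.

```python
import pytorch_lightning as pl

import merlin.models.torch as mm
from merlin.datasets.synthetic import generate_data  # assumed synthetic-data helper
from merlin.schema import Tags

# Two Merlin Datasets for train/valid; the dataset name and size are placeholders.
train, valid = generate_data("e-commerce", num_rows=1_000)
schema = train.schema

model = mm.Model(
    mm.TabularInputBlock(schema, init="defaults"),
    mm.MLPBlock([32, 16]),
    mm.BinaryOutput(schema.select_by_tag(Tags.TARGET).first),
)

# With initialization="auto" (the default), Model.setup() initializes the model
# from the batch that MultiLoader samples at construction time, so there is no
# explicit model.initialize(...) call before fitting.
trainer = pl.Trainer(max_epochs=1, devices=[0, 1])
trainer.fit(model, mm.MultiLoader(train, valid, batch_size=1024, repartition=4))
```

Lightning launches one process per device; `MultiLoader._create_loader` then reads `WORLD_SIZE` and `LOCAL_RANK` to repartition the dataset and hand each rank its own shard through the loader's `global_size`/`global_rank` arguments.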
3 changes: 1 addition & 2 deletions merlin/models/torch/models/ranking.py
@@ -51,8 +51,7 @@ class DLRMModel(Model):
... output_block=mm.BinaryOutput(ColumnSchema("target")),
... )
>>> trainer = pl.Trainer()
>>> model.initialize(dataloader)
>>> trainer.fit(model, dataloader)
>>> trainer.fit(model, Loader(dataset, batch_size=32))
{dlrm_reference}
"""
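The same auto-initialization shortens the ranking-model flow documented here. A hedged sketch follows, reusing the assumed `generate_data` helper from the earlier example; the `DLRMModel` arguments other than `output_block` are assumptions based on the DLRM architecture, not taken from this diff.

```python
import pytorch_lightning as pl

import merlin.models.torch as mm
from merlin.dataloader.torch import Loader
from merlin.datasets.synthetic import generate_data  # assumed synthetic-data helper
from merlin.schema import Tags

train, _ = generate_data("e-commerce", num_rows=1_000)

model = mm.DLRMModel(
    train.schema,
    dim=64,                               # assumed embedding-dimension argument
    bottom_block=mm.MLPBlock([256, 64]),  # assumed bottom-MLP argument
    output_block=mm.BinaryOutput(train.schema.select_by_tag(Tags.TARGET).first),
)

# The Loader can be passed straight to fit(); Model.setup() picks it up from the
# trainer and auto-initializes the model, replacing the old model.initialize() step.
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, Loader(train, batch_size=32))
```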
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -62,5 +62,7 @@ markers = [
"integration",
"unit",
"changed",
"unchanged"
"unchanged",
"singlegpu",
"multigpu"
]
2 changes: 2 additions & 0 deletions pytest.ini
@@ -15,3 +15,5 @@ markers =
horovod: mark as requiring horovod
changed: mark as requiring changed files
always: mark as always running
multigpu: Tests only run in multiple-GPU environments
singlegpu: Optional marker to run tests in single-GPU environments. Usually used when running in both single- and multi-GPU.
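To illustrate how the new markers interact with the workflow changes above, here is a hypothetical test module (the test names and bodies are made up):

```python
import pytest


@pytest.mark.multigpu
def test_distributed_training():
    """Selected by the multi-GPU job, which runs with PYTEST_MARKERS="multigpu ..."."""


@pytest.mark.multigpu
@pytest.mark.singlegpu
def test_training_any_gpu_count():
    """Also kept by gpu.yml's "... and (singlegpu or not multigpu)" expression."""
```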
Empty file.
31 changes: 31 additions & 0 deletions tests/integration/torch/test_multi_gpu.py
@@ -0,0 +1,31 @@
import pytest
import pytorch_lightning as pl

import merlin.models.torch as mm


# TODO: This test is not complete because Lightning launches separate processes
# under the hood with the correct environment variables like `LOCAL_RANK`, but
# the pytest stays in the main process and tests only the LOCAL_RANK=0 case.
# Follow-up with proper test that ensures dataloader is working properly with
# e.g. global_rank > 0.
class TestMultiGPU:
@pytest.mark.multigpu
def test_multi_gpu(self, music_streaming_data):
schema = music_streaming_data.schema
data = music_streaming_data
model = mm.Model(
mm.TabularInputBlock(schema, init="defaults"),
mm.MLPBlock([5]),
mm.BinaryOutput(schema["click"]),
)

trainer = pl.Trainer(max_epochs=3, devices=2)
multi_loader = mm.MultiLoader(data, batch_size=2)
trainer.fit(model, multi_loader)

# 100 rows total / 2 devices -> 50 rows per device
# 50 rows / 2 per batch -> 25 steps per device
assert trainer.num_training_batches == 25

assert trainer.global_rank == 0 # This should fail for node 1.
30 changes: 26 additions & 4 deletions tests/unit/torch/models/test_base.py
@@ -220,10 +220,7 @@ def test_train_classification_with_lightning_trainer(self, music_streaming_data,
)

trainer = pl.Trainer(max_epochs=1, devices=1)

with Loader(music_streaming_data, batch_size=batch_size) as loader:
model.initialize(loader)
trainer.fit(model, loader)
trainer.fit(model, Loader(music_streaming_data, batch_size=batch_size))

assert trainer.logged_metrics["train_loss"] > 0.0
assert trainer.num_training_batches == 7 # 100 rows // 16 per batch + 1 for last batch
@@ -232,6 +229,31 @@ def test_train_classification_with_lightning_trainer(self, music_streaming_data,
_ = module_utils.module_test(model, batch)


class TestMultiLoader:
def test_train_dataset(self, music_streaming_data):
multi_loader = mm.MultiLoader(music_streaming_data)
assert multi_loader.train_dataloader() is multi_loader.loader_train

def test_train_loader(self, music_streaming_data):
multi_loader = mm.MultiLoader(Loader(music_streaming_data, 2))
assert multi_loader.train_dataloader() is multi_loader.loader_train

def test_valid_dataloader(self, music_streaming_data):
multi_loader = mm.MultiLoader(music_streaming_data, music_streaming_data)
assert multi_loader.val_dataloader() is multi_loader.loader_valid

def test_test_dataloader(self, music_streaming_data):
multi_loader = mm.MultiLoader(*([music_streaming_data] * 3))
assert multi_loader.test_dataloader() is multi_loader.loader_test

def test_teardown(self, music_streaming_data):
multi_loader = mm.MultiLoader(*([music_streaming_data] * 3))
multi_loader.teardown(None)
assert not hasattr(multi_loader, "loader_train")
assert not hasattr(multi_loader, "loader_valid")
assert not hasattr(multi_loader, "loader_test")


class TestComputeLoss:
def test_tensor_inputs(self):
predictions = torch.sigmoid(torch.randn(2, 1))