Commit
torchgpipe.balancing -> torchgpipe.balance
sublee committed Oct 25, 2019
1 parent b88c09a commit e50f2da
Showing 13 changed files with 30 additions and 30 deletions.
4 changes: 2 additions & 2 deletions docs/api.rst
@@ -19,6 +19,6 @@ Inspecting GPipe Timeline
Automatic Balancing
~~~~~~~~~~~~~~~~~~~

-.. autofunction:: torchgpipe.balancing.balance_by_time(module, canary, partitions, device, timeout)
+.. autofunction:: torchgpipe.balance.balance_by_time(module, canary, partitions, device, timeout)

-.. autofunction:: torchgpipe.balancing.balance_by_size(module, canary, partitions, device)
+.. autofunction:: torchgpipe.balance.balance_by_size(module, canary, partitions, device)
2 changes: 1 addition & 1 deletion docs/changelog.rst
@@ -10,7 +10,7 @@ Improvements:
- Checkpointing deterministically handles randomness managed by PyTorch.

Breaking Changes:
-- Moved ``torchgpipe_balancing`` module to :mod:`torchgpipe.balancing`.
+- Moved ``torchgpipe_balancing`` module to :mod:`torchgpipe.balance`.

v0.0.4
~~~~~~
14 changes: 7 additions & 7 deletions docs/guide.rst
@@ -93,20 +93,20 @@ Automatic Balancing

It could be hard to determine the optimal balance of a model. In particular, if
you are still designing a model, the model architecture may change over time.
-In this case, we highly recommend :mod:`torchgpipe.balancing` for automatic
+In this case, we highly recommend :mod:`torchgpipe.balance` for automatic
balancing. This won't give you the optimal balance, but a good-enough balance.
Note that this is provided by `torchgpipe` package, and is not from the GPipe
paper.

-There are two balancing tools, :func:`~torchgpipe.balancing.balance_by_time`
-and :func:`~torchgpipe.balancing.balance_by_size`. Both are based on per-layer
+There are two balancing tools, :func:`~torchgpipe.balance.balance_by_time` and
+:func:`~torchgpipe.balance.balance_by_size`. Both are based on per-layer
profiling. Just like `PyTorch JIT`_, you need to feed a sample input into the
-model. :func:`~torchgpipe.balancing.balance_by_time` traces elapsed time of
-each layer, while :func:`~torchgpipe.balancing.balance_by_size` detects the
-CUDA memory usage of each layer. Choose the balancing tool for your needs::
+model. :func:`~torchgpipe.balance.balance_by_time` traces elapsed time of each
+layer, while :func:`~torchgpipe.balance.balance_by_size` detects the CUDA
+memory usage of each layer. Choose the balancing tool for your needs::

from torchgpipe import GPipe
-from torchgpipe.balancing import balance_by_time
+from torchgpipe.balance import balance_by_time

sample = torch.rand(128, 3, 224, 224)
balance = balance_by_time(model, sample, partitions=4)
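The snippet in this hunk ends at computing `balance`; the list it returns is the block sizes handed to `GPipe`. Under the hood, both balancers reduce to a classic problem: split a sequence of per-layer costs (elapsed times or CUDA memory sizes) into `partitions` consecutive blocks so the largest block cost is as small as possible. A dependency-free sketch of that step (a textbook dynamic program, not torchgpipe's actual `blockpartition` implementation):

```python
from itertools import accumulate


def min_max_blocks(costs, k):
    """Split `costs` into `k` consecutive blocks minimizing the largest
    block sum; returns block sizes. A textbook O(k * n^2) dynamic
    program, not torchgpipe's real algorithm."""
    n = len(costs)
    prefix = [0] + list(accumulate(costs))
    INF = float('inf')
    # dp[j][i]: smallest achievable largest-block-sum when the first i
    # layers are split into j blocks; cut[j][i] remembers the split point.
    dp = [[INF] * (n + 1) for _ in range(k + 1)]
    cut = [[0] * (n + 1) for _ in range(k + 1)]
    dp[0][0] = 0
    for j in range(1, k + 1):
        for i in range(j, n + 1):
            for m in range(j - 1, i):
                candidate = max(dp[j - 1][m], prefix[i] - prefix[m])
                if candidate < dp[j][i]:
                    dp[j][i], cut[j][i] = candidate, m
    # Walk the cuts backwards to recover block sizes (the "balance").
    sizes, i = [], n
    for j in range(k, 0, -1):
        sizes.append(i - cut[j][i])
        i = cut[j][i]
    return sizes[::-1]
```

For per-layer costs `[1, 2, 3, 4, 5, 6]` and 3 partitions this yields a balance of `[3, 2, 1]` (block sums 6, 9, 6), which is what a balancer aims for: no single pipeline stage dominates the others.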
2 changes: 1 addition & 1 deletion setup.py
@@ -66,7 +66,7 @@

zip_safe=False,

-packages=['torchgpipe', 'torchgpipe.balancing'],
+packages=['torchgpipe', 'torchgpipe.balance'],
package_data={'torchgpipe': ['py.typed']},
py_modules=['torchgpipe_balancing'],

4 changes: 2 additions & 2 deletions tests/test_balancing.py → tests/test_balance.py
@@ -4,7 +4,7 @@
import torch
from torch import nn

-from torchgpipe.balancing import balance_by_size, balance_by_time, blockpartition
+from torchgpipe.balance import balance_by_size, balance_by_time, blockpartition


def test_blockpartition():
@@ -101,5 +101,5 @@ def forward(self, x):


def test_deprecated_torchgpipe_balancing():
-with pytest.raises(ImportError, match='torchgpipe.balancing'):
+with pytest.raises(ImportError, match='torchgpipe.balance'):
__import__('torchgpipe_balancing')
8 changes: 4 additions & 4 deletions tests/test_gpipe.py
@@ -468,16 +468,16 @@ def test_named_children():
model.a


-def test_recommend_balancing():
-with pytest.raises(ValueError, match='balancing'):
+def test_recommend_auto_balance():
+with pytest.raises(ValueError, match='torchgpipe.balance'):
# balance is required
GPipe(nn.Sequential())

-with pytest.raises(ValueError, match='balancing'):
+with pytest.raises(ValueError, match='torchgpipe.balance'):
# module and sum of balance have differen length (module: 0, sum of balance: 1)
GPipe(nn.Sequential(), [1])

-with pytest.raises(ValueError, match='balancing'):
+with pytest.raises(ValueError, match='torchgpipe.balance'):
# module and sum of balance have different length (module: 2, sum of balance: 1)
GPipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1])

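The three failure cases exercised above hinge on one invariant: `balance` must be given, and its sum must equal the number of child modules being partitioned. A minimal dependency-free stand-in for that check (names hypothetical; torchgpipe's real validation happens inside `GPipe.__init__` and `split_module`):

```python
def validate_balance(n_layers, balance):
    """Hypothetical sketch of the balance checks GPipe performs.

    Mirrors the three ValueError cases in test_recommend_auto_balance:
    missing balance, and a layer count / balance-sum mismatch.
    """
    if balance is None:
        raise ValueError('balance is required')
    if any(b <= 0 for b in balance):
        raise ValueError('all balance numbers must be positive integer')
    if sum(balance) != n_layers:
        raise ValueError(
            'module and sum of balance have different length '
            '(module: %d, sum of balance: %d)' % (n_layers, sum(balance)))


validate_balance(2, [1, 1])  # OK: two layers, one layer per partition
```

`validate_balance(2, [1])` and `validate_balance(0, [1])` raise the mismatch error, matching the second and third test cases above.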
@@ -4,7 +4,7 @@
import torch
from torchgpipe import GPipe
-from torchgpipe.balancing import balance_by_time
+from torchgpipe.balance import balance_by_time
sample = torch.rand(128, 3, 224, 224)
balance = balance_by_time(model, sample, partitions=4)
@@ -18,8 +18,8 @@
from torch import Tensor
import torch.nn as nn

-from torchgpipe.balancing import utils
-from torchgpipe.balancing.profile import profile_sizes, profile_times
+from torchgpipe.balance import utils
+from torchgpipe.balance.profile import profile_sizes, profile_times

__all__ = ['balance_by_time', 'balance_by_size']

File renamed without changes.
@@ -6,7 +6,7 @@
from torch import Tensor
import torch.nn as nn

-from torchgpipe.balancing import utils
+from torchgpipe.balance import utils

__all__: List[str] = []

File renamed without changes.
@@ -6,7 +6,7 @@
from torch import Tensor
import torch.nn as nn

-from torchgpipe.balancing import blockpartition
+from torchgpipe.balance import blockpartition

__all__: List[str] = []

12 changes: 6 additions & 6 deletions torchgpipe/gpipe.py
@@ -29,16 +29,16 @@
NamedModules = OrderedDict


-def recommend_balancing(message: str) -> str:
-    """Expands a message with recommendation to :mod:`torchgpipe.balancing`."""
+def recommend_auto_balance(message: str) -> str:
+    """Expands a message with recommendation to :mod:`torchgpipe.balance`."""
return '''{message}
If your model is still under development, its optimal balance would change
-frequently. In this case, we highly recommend 'torchgpipe.balancing' for naive
+frequently. In this case, we highly recommend 'torchgpipe.balance' for naive
automatic balancing:
from torchgpipe import GPipe
-from torchgpipe.balancing import balance_by_time
+from torchgpipe.balance import balance_by_time
sample = torch.rand(...)
balance = balance_by_time(model, sample, partitions=...)
@@ -204,7 +204,7 @@ def __init__(self,
super().__init__()

if balance is None:
-raise ValueError(recommend_balancing('balance is required'))
+raise ValueError(recommend_auto_balance('balance is required'))
if chunks <= 0:
raise ValueError('number of chunks must be positive integer')
if checkpoint not in ['always', 'except_last', 'never']:
@@ -227,7 +227,7 @@ def __init__(self,
try:
self.partitions, self.balance, self.devices = split_module(module, balance, devices)
except BalanceError as exc:
-raise ValueError(recommend_balancing(str(exc)))
+raise ValueError(recommend_auto_balance(str(exc)))

self._copy_streams: List[List[AbstractStream]] = []

4 changes: 2 additions & 2 deletions torchgpipe_balancing.py
@@ -1,2 +1,2 @@
-# torchgpipe_balancing has moved to torchgpipe.balancing in v0.0.5.
-raise ImportError('import torchgpipe.balancing instead')
+# 'torchgpipe_balancing' has moved to 'torchgpipe.balance' in v0.0.5.
+raise ImportError("import 'torchgpipe.balance' instead")
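The new shim body is the entire module: importing `torchgpipe_balancing` now fails immediately with a pointer to the new package, which is exactly what `test_deprecated_torchgpipe_balancing` asserts. The pattern in miniature, executing the shim's source here instead of installing the package (a sketch, not the project's own test):

```python
# The deprecated module's whole body is a single raise, so any
# 'import torchgpipe_balancing' aborts with a redirect message.
shim_source = "raise ImportError(\"import 'torchgpipe.balance' instead\")"

try:
    # Simulate importing the shim by executing its source.
    exec(compile(shim_source, 'torchgpipe_balancing.py', 'exec'), {})
    raise AssertionError('shim should have raised')
except ImportError as exc:
    message = str(exc)

print(message)  # -> import 'torchgpipe.balance' instead
```

Raising at module top level is a common way to retire a module name: the old import path stays reserved, and users get an actionable error rather than a silent `ModuleNotFoundError` on the new name.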
