Enable torch.compile with ZeRO (Experimental) #4878

Merged · 49 commits · Feb 6, 2024

Commits (49)
91c01ab
add option to run torch.compile
tohtana Dec 15, 2023
49c7acd
improve compile helper
tohtana Dec 23, 2023
bbeb38c
fix compile wrapper to make modules() work
tohtana Dec 24, 2023
719cc79
add torch.compiler-disable to comm module
tohtana Dec 25, 2023
3ea9b44
move options for torch.compile to ds config
tohtana Dec 27, 2023
eb9d4e0
Merge branch 'master' into tohtana/compile-zero
tohtana Dec 28, 2023
fbae027
rename module and wrap decorator
tohtana Dec 28, 2023
4f8f86d
fix validation of compile config
tohtana Dec 28, 2023
d83963b
avoid reference to torch._dynamo when torch has no support
tohtana Dec 28, 2023
6920ab6
fix custom backend for test
tohtana Dec 28, 2023
c3429a6
fix validation
tohtana Dec 28, 2023
bfafb88
refactor config for torch.compile
tohtana Jan 10, 2024
3c13fd4
Merge branch 'master' into tohtana/compile-zero
tohtana Jan 10, 2024
9e63c95
Merge branch 'master' into tohtana/compile-zero
tohtana Jan 22, 2024
ff9c1ef
fix validation of compiler config
tohtana Jan 22, 2024
26b7f25
fix access to wrapped model
tohtana Jan 23, 2024
48d2453
add test for api to set torch compile options
tohtana Jan 23, 2024
d5584b0
rename util module
tohtana Jan 23, 2024
b9157ac
fix import
tohtana Jan 23, 2024
c19bf97
Merge branch 'master' into tohtana/compile-zero
tohtana Jan 23, 2024
93268f3
delay reduce-scatter for z3 leaf modules
tohtana Jan 24, 2024
a56ffec
Merge branch 'master' into tohtana/z3_moe_bwd
tohtana Jan 25, 2024
2a5e741
add comment to config class
tohtana Jan 25, 2024
ec91925
Merge branch 'master' into tohtana/compile-zero
tohtana Jan 25, 2024
ca5cff6
add api to get leaf modules
tohtana Jan 26, 2024
f615138
Merge branch 'master' into tohtana/compile-zero
tjruwase Jan 28, 2024
08770b8
Merge branch 'master' into tohtana/z3_moe_bwd
tjruwase Jan 29, 2024
3e5658b
add api to set a function to run torch.compile
tohtana Jan 30, 2024
5d9992e
Merge branch 'master' into tohtana/compile-zero
mrwyattii Jan 31, 2024
a3c0e5d
refactor compile config
tohtana Jan 31, 2024
95f4f34
lift is_compile_supported up to use as `deepspeed.is_compile_supporte…
tohtana Jan 31, 2024
1932b78
avoid overwriting backend fn in validator
tohtana Jan 31, 2024
ca85605
add tests combining compile and zero
tohtana Feb 1, 2024
19dd454
rename test modules
tohtana Feb 1, 2024
fccbd95
Merge branch 'master' into tohtana/z3_moe_bwd
tohtana Feb 1, 2024
da1f41d
Merge branch 'master' into tohtana/z3_moe_bwd
tohtana Feb 2, 2024
d8c0a14
use no zero + no compile as baseline for tests
tohtana Feb 2, 2024
ca419b4
disable memory_efficient_linear when torch.compile is enabled
tohtana Feb 2, 2024
5e3a070
pass only tensors to z3 hooks to prevent dynamo from displaying errors
tohtana Feb 2, 2024
0637968
Merge branch 'tohtana/z3_moe_bwd' into tohtana/compile-zero
tohtana Feb 2, 2024
7131d6e
Merge branch 'master' into tohtana/compile-zero
tohtana Feb 2, 2024
96c8647
fix exception used in test
tohtana Feb 2, 2024
0b3dae9
increse tolerance in tests
tohtana Feb 3, 2024
94cc97a
add check for bf16
tohtana Feb 3, 2024
eb27b9d
enable accelerator check for bf16
tohtana Feb 3, 2024
c1e10e5
Merge branch 'master' into tohtana/compile-zero
tohtana Feb 5, 2024
c2ba829
update DistributedTest to work with torch.compile tests
mrwyattii Feb 5, 2024
82f80d1
remove unused global
mrwyattii Feb 5, 2024
cb540d3
Merge branch 'master' into tohtana/compile-zero
tohtana Feb 6, 2024
1 change: 1 addition & 0 deletions deepspeed/__init__.py
@@ -39,6 +39,7 @@

from .runtime import zero
from .runtime import DeepSpeedOptimizer, ZeROOptimizer
from .runtime.compiler import is_compile_supported

from .pipe import PipelineModule

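The one-line change above re-exports `is_compile_supported` at the package top level, so callers can feature-detect `torch.compile` before enabling anything. A trivial usage sketch (the messages are illustrative only):

```python
# Trivial usage of the new top-level re-export. The rejection behavior noted in
# the comment follows the CompileConfig validator added later in this PR.
import deepspeed

if deepspeed.is_compile_supported():
    print("torch.compile is available; the 'compile' config section may be enabled")
else:
    # On older PyTorch builds, setting "compile": {"enabled": true} is rejected.
    print("torch.compile is not available in this PyTorch build")
```
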
24 changes: 24 additions & 0 deletions deepspeed/comm/torch.py
@@ -8,6 +8,7 @@
from .utils import *
from .backend import *
from .comm import *
from ..runtime import compiler
import os

DS_COMM_ALL_GATHER_OFF = False
@@ -120,6 +121,7 @@ def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='
self.init_process_group(backend, timeout, init_method, rank, world_size)

@classmethod
@compiler.disable
def get_all_gather_function(self):
if hasattr(torch.distributed, "all_gather_into_tensor"):
return torch.distributed.all_gather_into_tensor
@@ -128,6 +130,7 @@ def get_all_gather_function(self):
return None

@classmethod
@compiler.disable
def get_reduce_scatter_function(self):
if hasattr(torch.distributed, "reduce_scatter_tensor"):
return torch.distributed.reduce_scatter_tensor
@@ -150,14 +153,17 @@ def init_process_group(self, backend, timeout, init_method, rank, world_size):
world_size=world_size)
self.using_mpi = torch.distributed.get_backend() == 'mpi'

@compiler.disable
def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)

@compiler.disable
def inference_all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)

@compiler.disable
def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
""" proxy func to torch.distributed.all_reduce_coalesced,
which is included in PyTorch 1.13 and above
@@ -168,13 +174,15 @@ def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group
op = self._reduce_op(op)
return torch.distributed.all_reduce_coalesced(tensors=tensors, op=op, group=group, async_op=async_op)

@compiler.disable
def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
if DS_COMM_REDUCE_OFF:
if int(os.getenv('RANK', '0')) == 0:
utils.logger.warning("REDUCE is OFF")
return Noop()
return torch.distributed.reduce(tensor=tensor, dst=dst, op=self._reduce_op(op), group=group, async_op=async_op)

@compiler.disable
def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
if DS_COMM_REDUCE_SCATTER_OFF:
if int(os.getenv('RANK', '0')) == 0:
@@ -187,6 +195,7 @@ def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_
group=group,
async_op=async_op)

@compiler.disable
def broadcast(self, tensor, src, group=None, async_op=False):
if DS_COMM_BROADCAST_OFF:
if int(os.getenv('RANK', '0')) == 0:
@@ -195,6 +204,7 @@ def broadcast(self, tensor, src, group=None, async_op=False):
else:
return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)

@compiler.disable
def all_gather(self, tensor_list, tensor, group=None, async_op=False):
if DS_COMM_ALL_GATHER_OFF:
if int(os.getenv('RANK', '0')) == 0:
@@ -203,13 +213,15 @@ def all_gather(self, tensor_list, tensor, group=None, async_op=False):
else:
return torch.distributed.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)

@compiler.disable
def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False):
if self.has_all_gather_into_tensor():
return self.all_gather_function(output_tensor=output_tensor,
input_tensor=input_tensor,
group=group,
async_op=async_op)

@compiler.disable
def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False):
if DS_COMM_ALL_GATHER_OFF:
if int(os.getenv('RANK', '0')) == 0:
@@ -227,6 +239,7 @@ def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=Fals
"please consider upgrading your pytorch installation.")
pass

@compiler.disable
def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_op=False):
""""""
assert len(output_tensors) == len(input_tensors), ""
@@ -250,6 +263,7 @@ def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_
else:
reqs[-1].wait()

@compiler.disable
def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
if self.has_reduce_scatter_tensor():
return self.reduce_scatter_function(output_tensor,
@@ -263,6 +277,7 @@ def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, gr
"please consider upgrading your pytorch installation.")
pass

@compiler.disable
def all_to_all_single(self,
output,
input,
@@ -277,40 +292,49 @@ def all_to_all_single(self,
group=group,
async_op=async_op)

@compiler.disable
def all_to_all(self, output_tensor_list, input_tensor_list, group=None, async_op=False):
return torch.distributed.all_to_all(output_tensor_list, input_tensor_list, group=group, async_op=async_op)

@compiler.disable
def send(self, tensor, dst, group=None, tag=0):
return torch.distributed.send(tensor=tensor, dst=dst, group=group, tag=tag)

@compiler.disable
def recv(self, tensor, src=None, group=None, tag=0):
return torch.distributed.recv(tensor=tensor, src=src, group=group, tag=tag)

@compiler.disable
def isend(self, tensor, dst, group=None, tag=0):
return torch.distributed.isend(tensor=tensor, dst=dst, group=group, tag=tag)

@compiler.disable
def irecv(self, tensor, src=None, group=None, tag=0):
return torch.distributed.irecv(tensor=tensor, src=src, group=group, tag=tag)

@compiler.disable
def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False):
return torch.distributed.gather(tensor=tensor,
gather_list=gather_list,
dst=dst,
group=group,
async_op=async_op)

@compiler.disable
def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False):
return torch.distributed.scatter(tensor=tensor,
scatter_list=scatter_list,
src=src,
group=group,
async_op=async_op)

@compiler.disable
def barrier(self, group=torch.distributed.GroupMember.WORLD, async_op=False, device_ids=None):
if group is None:
group = torch.distributed.GroupMember.WORLD
return torch.distributed.barrier(group=group, async_op=async_op, device_ids=device_ids)

@compiler.disable
def monitored_barrier(self, group=torch.distributed.GroupMember.WORLD, timeout=None, wait_all_ranks=False):
if group is None:
group = torch.distributed.GroupMember.WORLD
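Every collective wrapper above is decorated with `@compiler.disable` so that TorchDynamo never traces into `torch.distributed` calls when the engine is compiled; on PyTorch builds without `torch.compile` the decorator is a no-op. A minimal sketch of the pattern, written as an assumption for illustration rather than a copy of DeepSpeed internals (the all_reduce example assumes an initialized process group and PyTorch >= 2.1 for `torch.compiler.disable`):

```python
import torch


def disable(func):
    # Keep TorchDynamo from tracing into `func`; degrade to a no-op on builds
    # without torch.compiler.disable (PyTorch < 2.1).
    if hasattr(torch, "compiler") and hasattr(torch.compiler, "disable"):
        return torch.compiler.disable(func)
    return func


@disable
def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
    # Collective calls stay out of the captured graph and always run eagerly.
    # Assumes torch.distributed has been initialized elsewhere.
    torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
    return tensor
```
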
166 changes: 166 additions & 0 deletions deepspeed/runtime/compiler.py
@@ -0,0 +1,166 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Union, Callable, Dict, Any
import importlib
import torch
from ..pydantic_v1 import validator
from .config_utils import DeepSpeedConfigModel

COMPILE_CONFIG = "compile"


def is_compile_supported():
return hasattr(torch, "compile")


def disable(func):
if is_compile_supported():
return torch.compiler.disable(func)
return func


def get_compile_config(param_dict):
if COMPILE_CONFIG in param_dict:
compile_config_dict = param_dict[COMPILE_CONFIG]
else:
compile_config_dict = {}
return CompileConfig(**compile_config_dict)


def get_backend_fn(backend: Union[str, Callable]) -> Union[str, Callable]:
if isinstance(backend, Callable):
return backend

elif isinstance(backend, str):
if backend in torch._dynamo.list_backends():
Review comment (Contributor): @tohtana The default list_backends call will exclude debug and experimental backends, e.g. eager. I think it's better to use list_backends(exclude_tags=()) here.

Reply (Contributor Author, tohtana): Thank you for the comment. I opened #5191.

return backend

# Get module name from backend name
module_name = '.'.join(backend.split('.')[:-1])
fn_name = backend.split('.')[-1]

try:
module = importlib.import_module(module_name)
backend_fn = getattr(module, fn_name)
except ImportError:
raise ValueError(
f"The backend {backend} is not in the list of available backends and could not be imported.")
return backend_fn

raise ValueError(f"backend for torch.compile must be a string or Callable: {backend}")


class CompileConfig(DeepSpeedConfigModel):
"""
[EXPERIMENTAL] This configuration enables users to activate `torch.compile` within DeepSpeed and customize its settings.
Please be aware that these features and API designs are experimental and subject to change.
"""

enabled: bool = False
"""
Enable torch.compile when True.
"""

backend: str = "inductor"
"""
Passed to `backend` argument of torch.compile.
If the given value is not in torch._dynamo.list_backends(),
DeepSpeed attempts to import and instantiate the module with the given name.
"""

kwargs: Dict[str, Any] = {}
"""
Passed to `kwargs` argument of torch.compile.
"""

@validator("enabled")
def validate_enabled(cls, field_value, values):
if field_value and not is_compile_supported():
raise ValueError("torch.compile is not supported on this version of PyTorch.")
return field_value


class CompiledModuleWrapper(torch.nn.Module):

def __init__(self, module, compile_config: Union[CompileConfig, None] = None):
super().__init__()

assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch."

modules = self.__dict__.get('_modules')
modules['wrapped'] = module
self.__dict__['wrapped'] = module
self._is_compiled = False
self._backend = get_backend_fn(compile_config.backend)
self._compile_kwargs = compile_config.kwargs
self._compiler_fn = None

def __getattr__(self, name):
return getattr(self.__dict__['wrapped'], name)

def set_backend(self, backend: Union[str, Callable]):
"""Set the backend for torch.compile.

Args:
backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module.
You can directly pass a function that works as a backend.
See also `backend` field in `CompileConfig` for more details.
"""
self._backend = get_backend_fn(backend)

def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None:
"""Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten.
You can also pass a backend name with "backend" key to change the backend.

Args:
kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile.
"""

if "backend" in kwargs:
raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.")
self._compile_kwargs.update(kwargs)

def set_compiler_fn(self, compiler_fn: Callable) -> None:
"""Set a function to be used for compiling the module.
This function should take a torch.nn.Module as input and return a compiled module.
Note that other compile options are ignored when a compiler_fn is set.

Example:
```python
def my_compiler_fn(module: torch.nn.Module):
...
return torch.compile(module, ...)

engine.set_compiler_fn(my_compiler_fn)
```
"""
self._compiler_fn = compiler_fn

def forward(self, *args, **kwargs) -> Any:
if not self.is_compiled:
if self._compiler_fn is None:
self.__dict__['wrapped'] = torch.compile(self.wrapped, backend=self._backend, **self._compile_kwargs)
else:
self.__dict__['wrapped'] = self._compiler_fn(self.wrapped)
self._is_compiled = True

return self.__dict__['wrapped'](*args, **kwargs)

@property
def is_compiled(self) -> bool:
return self._is_compiled

@property
def backend(self) -> Union[str, Callable]:
return self._backend

@property
def torch_compile_kwargs(self) -> Dict[str, Any]:
return self._compile_kwargs

@property
def compiler_fn(self) -> Union[Callable, None]:
return self._compiler_fn
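Taken together, the new module works like this: the wrapper resolves the backend through `get_backend_fn`, defers the actual `torch.compile` call to the first forward pass, and exposes `set_backend` / `set_torch_compile_kwargs` / `set_compiler_fn` as overrides. A minimal, hedged sketch of how the pieces fit together (constructing `CompiledModuleWrapper` by hand is for illustration only; in normal use the engine builds it from the "compile" section of the DeepSpeed config):

```python
# Illustrative sketch only: a toy model wrapped directly with the classes added
# in this file. Requires a PyTorch build with torch.compile support.
import torch
from deepspeed.runtime.compiler import (CompileConfig, CompiledModuleWrapper,
                                        is_compile_supported)

if is_compile_supported():
    config = CompileConfig(enabled=True, backend="inductor", kwargs={"dynamic": False})
    model = torch.nn.Linear(16, 16)
    wrapped = CompiledModuleWrapper(model, config)

    # Optional: replace the whole compile step; the other compile options are
    # then ignored, as documented in set_compiler_fn above.
    def my_compiler_fn(module: torch.nn.Module):
        return torch.compile(module, backend="eager")

    wrapped.set_compiler_fn(my_compiler_fn)

    # Compilation happens lazily on the first forward call.
    out = wrapped(torch.randn(4, 16))
```
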
3 changes: 3 additions & 0 deletions deepspeed/runtime/config.py
@@ -31,6 +31,7 @@
from ..comm.config import DeepSpeedCommsConfig
from ..monitor.config import get_monitor_config
from ..inference.config import WeightQuantConfig
from .compiler import get_compile_config

from deepspeed import comm as dist
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
@@ -899,6 +900,8 @@ def _initialize_params(self, param_dict):
self.weight_quantization_config = WeightQuantConfig(
**param_dict['weight_quantization']) if 'weight_quantization' in param_dict else None

self.compile_config = get_compile_config(param_dict)

def _batch_assertion(self):

train_batch = self.train_batch_size
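Finally, the config plumbing: `get_compile_config` pulls the optional "compile" section out of the user's DeepSpeed config dict during `_initialize_params` and validates it into a `CompileConfig`. A small sketch of the expected shape (the surrounding keys are illustrative placeholders, and the example assumes a torch.compile-capable PyTorch, since the validator rejects enabled=True otherwise):

```python
# Sketch of the "compile" config section consumed by get_compile_config above.
from deepspeed.runtime.compiler import get_compile_config

param_dict = {
    "train_batch_size": 8,  # illustrative placeholder, not part of this PR
    "compile": {
        "enabled": True,           # off by default
        "backend": "inductor",     # a name in torch._dynamo.list_backends() or an importable "module.fn"
        "kwargs": {"mode": "max-autotune"},  # forwarded to torch.compile
    },
}

compile_config = get_compile_config(param_dict)
print(compile_config.enabled, compile_config.backend, compile_config.kwargs)

# Omitting the section falls back to the defaults (disabled, inductor backend).
print(get_compile_config({}).enabled)  # False
```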