Skip to content

Commit

Permalink
Enable torch.compile with ZeRO (Experimental) (#4878)
Browse files Browse the repository at this point in the history
This PR enables `torch.compile` with ZeRO stages 1/2/3. You need to add
`compile` section in your DeepSpeed config. The fields in the section
are passed to `torch.compile`.

```json
  "compile": {
    "disable": false,
    "backend": "inductor"
  }
```

To enable a custom backend, you can pass the fully qualified name of the
backend function. For example, if you have a backend class `my_backend`
in `my_backend.py` in the current directory, you can enable it by
`"backend": "my_backend.my_backend"`. You can find an example in [a unit
test](https://github.com/microsoft/DeepSpeed/blob/eb9d4e06e9596f391aea305a6a5c6ec70cc28b58/tests/unit/runtime/compile/test_config.py#L116).

Currently we validated the results with Megatron-DeepSpeed. See the
[example](https://github.com/microsoft/Megatron-DeepSpeed/tree/tohtana/enable_compile/examples_deepspeed/compile)
for the details.

NOTICE: This PR is a draft. We will need to validate the coverage and
accuracy with many more examples.

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Michael Wyatt <[email protected]>
  • Loading branch information
3 people authored Feb 6, 2024
1 parent e212845 commit c3cfe96
Show file tree
Hide file tree
Showing 15 changed files with 784 additions and 125 deletions.
1 change: 1 addition & 0 deletions deepspeed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

from .runtime import zero
from .runtime import DeepSpeedOptimizer, ZeROOptimizer
from .runtime.compiler import is_compile_supported

from .pipe import PipelineModule

Expand Down
24 changes: 24 additions & 0 deletions deepspeed/comm/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .utils import *
from .backend import *
from .comm import *
from ..runtime import compiler
import os

DS_COMM_ALL_GATHER_OFF = False
Expand Down Expand Up @@ -120,6 +121,7 @@ def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='
self.init_process_group(backend, timeout, init_method, rank, world_size)

@classmethod
@compiler.disable
def get_all_gather_function(self):
if hasattr(torch.distributed, "all_gather_into_tensor"):
return torch.distributed.all_gather_into_tensor
Expand All @@ -128,6 +130,7 @@ def get_all_gather_function(self):
return None

@classmethod
@compiler.disable
def get_reduce_scatter_function(self):
if hasattr(torch.distributed, "reduce_scatter_tensor"):
return torch.distributed.reduce_scatter_tensor
Expand All @@ -150,14 +153,17 @@ def init_process_group(self, backend, timeout, init_method, rank, world_size):
world_size=world_size)
self.using_mpi = torch.distributed.get_backend() == 'mpi'

@compiler.disable
def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)

@compiler.disable
def inference_all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)

@compiler.disable
def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
""" proxy func to torch.distributed.all_reduce_coalesced,
which is included in PyTorch 1.13 and above
Expand All @@ -168,13 +174,15 @@ def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group
op = self._reduce_op(op)
return torch.distributed.all_reduce_coalesced(tensors=tensors, op=op, group=group, async_op=async_op)

@compiler.disable
def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
if DS_COMM_REDUCE_OFF:
if int(os.getenv('RANK', '0')) == 0:
utils.logger.warning("REDUCE is OFF")
return Noop()
return torch.distributed.reduce(tensor=tensor, dst=dst, op=self._reduce_op(op), group=group, async_op=async_op)

@compiler.disable
def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
if DS_COMM_REDUCE_SCATTER_OFF:
if int(os.getenv('RANK', '0')) == 0:
Expand All @@ -187,6 +195,7 @@ def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_
group=group,
async_op=async_op)

@compiler.disable
def broadcast(self, tensor, src, group=None, async_op=False):
if DS_COMM_BROADCAST_OFF:
if int(os.getenv('RANK', '0')) == 0:
Expand All @@ -195,6 +204,7 @@ def broadcast(self, tensor, src, group=None, async_op=False):
else:
return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)

@compiler.disable
def all_gather(self, tensor_list, tensor, group=None, async_op=False):
if DS_COMM_ALL_GATHER_OFF:
if int(os.getenv('RANK', '0')) == 0:
Expand All @@ -203,13 +213,15 @@ def all_gather(self, tensor_list, tensor, group=None, async_op=False):
else:
return torch.distributed.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)

@compiler.disable
def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False):
if self.has_all_gather_into_tensor():
return self.all_gather_function(output_tensor=output_tensor,
input_tensor=input_tensor,
group=group,
async_op=async_op)

@compiler.disable
def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False):
if DS_COMM_ALL_GATHER_OFF:
if int(os.getenv('RANK', '0')) == 0:
Expand All @@ -227,6 +239,7 @@ def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=Fals
"please consider upgrading your pytorch installation.")
pass

@compiler.disable
def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_op=False):
""""""
assert len(output_tensors) == len(input_tensors), ""
Expand All @@ -250,6 +263,7 @@ def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_
else:
reqs[-1].wait()

@compiler.disable
def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
if self.has_reduce_scatter_tensor():
return self.reduce_scatter_function(output_tensor,
Expand All @@ -263,6 +277,7 @@ def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, gr
"please consider upgrading your pytorch installation.")
pass

@compiler.disable
def all_to_all_single(self,
output,
input,
Expand All @@ -277,40 +292,49 @@ def all_to_all_single(self,
group=group,
async_op=async_op)

@compiler.disable
def all_to_all(self, output_tensor_list, input_tensor_list, group=None, async_op=False):
return torch.distributed.all_to_all(output_tensor_list, input_tensor_list, group=group, async_op=async_op)

@compiler.disable
def send(self, tensor, dst, group=None, tag=0):
return torch.distributed.send(tensor=tensor, dst=dst, group=group, tag=tag)

@compiler.disable
def recv(self, tensor, src=None, group=None, tag=0):
return torch.distributed.recv(tensor=tensor, src=src, group=group, tag=tag)

@compiler.disable
def isend(self, tensor, dst, group=None, tag=0):
return torch.distributed.isend(tensor=tensor, dst=dst, group=group, tag=tag)

@compiler.disable
def irecv(self, tensor, src=None, group=None, tag=0):
return torch.distributed.irecv(tensor=tensor, src=src, group=group, tag=tag)

@compiler.disable
def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False):
return torch.distributed.gather(tensor=tensor,
gather_list=gather_list,
dst=dst,
group=group,
async_op=async_op)

@compiler.disable
def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False):
return torch.distributed.scatter(tensor=tensor,
scatter_list=scatter_list,
src=src,
group=group,
async_op=async_op)

@compiler.disable
def barrier(self, group=torch.distributed.GroupMember.WORLD, async_op=False, device_ids=None):
if group is None:
group = torch.distributed.GroupMember.WORLD
return torch.distributed.barrier(group=group, async_op=async_op, device_ids=device_ids)

@compiler.disable
def monitored_barrier(self, group=torch.distributed.GroupMember.WORLD, timeout=None, wait_all_ranks=False):
if group is None:
group = torch.distributed.GroupMember.WORLD
Expand Down
166 changes: 166 additions & 0 deletions deepspeed/runtime/compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Union, Callable, Dict, Any
import importlib
import torch
from ..pydantic_v1 import validator
from .config_utils import DeepSpeedConfigModel

COMPILE_CONFIG = "compile"


def is_compile_supported():
return hasattr(torch, "compile")


def disable(func):
if is_compile_supported():
return torch.compiler.disable(func)
return func


def get_compile_config(param_dict):
if COMPILE_CONFIG in param_dict:
compile_config_dict = param_dict[COMPILE_CONFIG]
else:
compile_config_dict = {}
return CompileConfig(**compile_config_dict)


def get_backend_fn(backend: Union[str, Callable]) -> Union[str, Callable]:
if isinstance(backend, Callable):
return backend

elif isinstance(backend, str):
if backend in torch._dynamo.list_backends():
return backend

# Get module name from backend name
module_name = '.'.join(backend.split('.')[:-1])
fn_name = backend.split('.')[-1]

try:
module = importlib.import_module(module_name)
backend_fn = getattr(module, fn_name)
except ImportError:
raise ValueError(
f"The backend {backend} is not in the list of available backends and could not be imported.")
return backend_fn

raise ValueError(f"backend for torch.compile must be a string or Callable: {backend}")


class CompileConfig(DeepSpeedConfigModel):
"""
[EXPERIMENTAL] This configuration enables users to activate `torch.compile` within DeepSpeed and customize its settings.
Please be aware that these features and API designs are experimental and subject to change.
"""

enabled: bool = False
"""
Enable torch.compile when True.
"""

backend: str = "inductor"
"""
Passed to `backend` argument of torch.compile.
If the given value is not in torch._dynamo.list_backends(),
DeepSpeed attempts to import and instantiate the module with the given name.
"""

kwargs: Dict[str, Any] = {}
"""
Passed to `kwargs` argument of torch.compile.
"""

@validator("enabled")
def validate_enabled(cls, field_value, values):
if field_value and not is_compile_supported():
raise ValueError("torch.compile is not supported on this version of PyTorch.")
return field_value


class CompiledModuleWrapper(torch.nn.Module):

def __init__(self, module, compile_config: Union[CompileConfig, None] = None):
super().__init__()

assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch."

modules = self.__dict__.get('_modules')
modules['wrapped'] = module
self.__dict__['wrapped'] = module
self._is_compiled = False
self._backend = get_backend_fn(compile_config.backend)
self._compile_kwargs = compile_config.kwargs
self._compiler_fn = None

def __getattr__(self, name):
return getattr(self.__dict__['wrapped'], name)

def set_backend(self, backend: Union[str, Callable]):
"""Set the backend for torch.compile.
Args:
backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module.
You can directly pass a function that works as a backend.
See also `backend` field in `CompileConfig` for more details.
"""
self._backend = get_backend_fn(backend)

def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None:
"""Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten.
You can also pass a backend name with "backend" key to change the backend.
Args:
kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile.
"""

if "backend" in kwargs:
raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.")
self._compile_kwargs.update(kwargs)

def set_compiler_fn(self, compiler_fn: Callable) -> None:
"""Set a function to be used for compiling the module.
This function should take a torch.nn.Module as input and return a compiled module.
Note that other compile options are ignored when a compiler_fn is set.
Example:
```python
def my_compiler_fn(module: torch.nn.Module):
...
return torch.compile(module, ...)
engine.set_compiler_fn(my_compiler_fn)
```
"""
self._compiler_fn = compiler_fn

def forward(self, *args, **kwargs) -> Any:
if not self.is_compiled:
if self._compiler_fn is None:
self.__dict__['wrapped'] = torch.compile(self.wrapped, backend=self._backend, **self._compile_kwargs)
else:
self.__dict__['wrapped'] = self._compiler_fn(self.wrapped)
self._is_compiled = True

return self.__dict__['wrapped'](*args, **kwargs)

@property
def is_compiled(self) -> bool:
return self._is_compiled

@property
def backend(self) -> Union[str, Callable]:
return self._backend

@property
def torch_compile_kwargs(self) -> Dict[str, Any]:
return self._compile_kwargs

@property
def compiler_fn(self) -> Union[Callable, None]:
return self._compiler_fn
3 changes: 3 additions & 0 deletions deepspeed/runtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from ..comm.config import DeepSpeedCommsConfig
from ..monitor.config import get_monitor_config
from ..inference.config import WeightQuantConfig
from .compiler import get_compile_config

from deepspeed import comm as dist
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
Expand Down Expand Up @@ -899,6 +900,8 @@ def _initialize_params(self, param_dict):
self.weight_quantization_config = WeightQuantConfig(
**param_dict['weight_quantization']) if 'weight_quantization' in param_dict else None

self.compile_config = get_compile_config(param_dict)

def _batch_assertion(self):

train_batch = self.train_batch_size
Expand Down
Loading

0 comments on commit c3cfe96

Please sign in to comment.