From 97ee2fa46e20a8962715f834a0bafd7edf61f693 Mon Sep 17 00:00:00 2001 From: tocean Date: Fri, 5 Jan 2024 03:23:52 +0000 Subject: [PATCH 01/18] integrate with fsdp --- examples/mnist_fsdp.py | 207 ++++++++++++++++++++++ msamp/fsdp/__init__.py | 9 + msamp/fsdp/flat_param.py | 62 +++++++ msamp/fsdp/fully_sharded_data_parallel.py | 120 +++++++++++++ msamp/nn/functional.py | 23 ++- msamp/optim/__init__.py | 6 +- msamp/optim/adam.py | 31 +++- msamp/optim/adamw.py | 90 +++++++++- 8 files changed, 539 insertions(+), 9 deletions(-) create mode 100644 examples/mnist_fsdp.py create mode 100644 msamp/fsdp/__init__.py create mode 100644 msamp/fsdp/flat_param.py create mode 100644 msamp/fsdp/fully_sharded_data_parallel.py diff --git a/examples/mnist_fsdp.py b/examples/mnist_fsdp.py new file mode 100644 index 00000000..893fea43 --- /dev/null +++ b/examples/mnist_fsdp.py @@ -0,0 +1,207 @@ +# Based on: https://github.com/pytorch/examples/blob/master/mnist/main.py + +import os +import argparse +import functools + +import torch +import torch.nn as nn +import torch.multiprocessing as mp +import torch.nn.functional as F +import torch.distributed as dist + +from torch.optim.lr_scheduler import StepLR +from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.utils.data.distributed import DistributedSampler + + +from torchvision import datasets, transforms + +def setup(rank, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + + # initialize the process group + dist.init_process_group("nccl", rank=rank, world_size=world_size) + +def cleanup(): + dist.destroy_process_group() + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, 1) + self.conv2 = nn.Conv2d(32, 64, 3, 1) + self.dropout1 = nn.Dropout(0.25) + self.dropout2 = nn.Dropout(0.5) + self.fc1 = nn.Linear(9216, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + +def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None): + model.train() + ddp_loss = torch.zeros(2).to(rank) + if sampler: + sampler.set_epoch(epoch) + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(rank), target.to(rank) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target, reduction='sum') + loss.backward() + + optimizer.step() + ddp_loss[0] += loss.item() + ddp_loss[1] += len(data) + + #break + dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM) + if rank == 0: + print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1])) + +def test(model, rank, world_size, test_loader): + model.eval() + correct = 0 + ddp_loss = torch.zeros(3).to(rank) + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(rank), target.to(rank) + output = model(data) + ddp_loss[0] += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss + pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + ddp_loss[1] += pred.eq(target.view_as(pred)).sum().item() + ddp_loss[2] += len(data) + + dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM) + + if rank == 0: + test_loss = ddp_loss[0] / 
ddp_loss[2] + print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format( + test_loss, int(ddp_loss[1]), int(ddp_loss[2]), + 100. * ddp_loss[1] / ddp_loss[2])) + + +def fsdp_main(rank, world_size, args): + setup(rank, world_size) + + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + + if rank == 0: + dataset1 = datasets.MNIST('./data', train=True, download=True, + transform=transform) + dist.barrier() + if rank != 0: + dataset1 = datasets.MNIST('./data', train=True, + transform=transform) + + dataset2 = datasets.MNIST('./data', train=False, + transform=transform) + + sampler1 = DistributedSampler(dataset1, rank=rank, num_replicas=world_size, shuffle=True) + sampler2 = DistributedSampler(dataset2, rank=rank, num_replicas=world_size) + + train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1} + test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2} + cuda_kwargs = {'num_workers': 2, + 'pin_memory': True, + 'shuffle': False} + train_kwargs.update(cuda_kwargs) + test_kwargs.update(cuda_kwargs) + + train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) + test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) + my_auto_wrap_policy = functools.partial( + size_based_auto_wrap_policy, min_num_params=20000 + ) + torch.cuda.set_device(rank) + + init_start_event = torch.cuda.Event(enable_timing=True) + init_end_event = torch.cuda.Event(enable_timing=True) + + model = Net().to(rank) + + if args.msamp: + from msamp.nn import LinearReplacer + from msamp.common.dtype import Dtypes + from msamp.fsdp import FP8FullyShardedDataParallel + model = LinearReplacer.replace(model, weight_qtype=Dtypes.kfloat8_e4m3) + model = FP8FullyShardedDataParallel(model, use_orig_params=True) + else: + model = FSDP(model, use_orig_params=True) + + if rank == 0: + print(f'FSDP model:') + print(f'{model}') + + if args.msamp: + from msamp.optim import FSDPAdam + optimizer = FSDPAdam(model.parameters(), lr=args.lr) + else: + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + + scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) + init_start_event.record() + for epoch in range(1, args.epochs + 1): + train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1) + test(model, rank, world_size, test_loader) + scheduler.step() + + init_end_event.record() + + if args.save_model: + # use a barrier to make sure training is done on all ranks + dist.barrier() + states = model.state_dict() + if rank == 0: + torch.save(states, "mnist_cnn.pt") + + cleanup() + +if __name__ == '__main__': + # Training settings + parser = argparse.ArgumentParser(description='PyTorch MNIST Example') + parser.add_argument('--batch-size', type=int, default=64, metavar='N', + help='input batch size for training (default: 64)') + parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', + help='input batch size for testing (default: 1000)') + parser.add_argument('--epochs', type=int, default=14, metavar='N', + help='number of epochs to train (default: 14)') + parser.add_argument('--lr', type=float, default=3e-4, metavar='LR', + help='learning rate (default: 3e-4)') + parser.add_argument('--gamma', type=float, default=0.7, metavar='M', + help='Learning rate step gamma (default: 0.7)') + parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') + parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed 
(default: 1)') + parser.add_argument('--save-model', action='store_true', default=False, + help='For Saving the current Model') + parser.add_argument('--msamp', action='store_true', default=False, + help='whether use MS-AMP') + args = parser.parse_args() + + torch.manual_seed(args.seed) + + WORLD_SIZE = torch.cuda.device_count() + mp.spawn(fsdp_main, + args=(WORLD_SIZE, args), + nprocs=WORLD_SIZE, + join=True) \ No newline at end of file diff --git a/msamp/fsdp/__init__.py b/msamp/fsdp/__init__.py new file mode 100644 index 00000000..f81d6826 --- /dev/null +++ b/msamp/fsdp/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Expose the interface of MS-AMP fsdp package.""" + +from msamp.fsdp.flat_param import FP8FlatParamHandle +from msamp.fsdp.fully_sharded_data_parallel import FP8FullyShardedDataParallel + +__all__ = ['FP8FlatParamHandle', 'FP8FullyShardedDataParallel'] \ No newline at end of file diff --git a/msamp/fsdp/flat_param.py b/msamp/fsdp/flat_param.py new file mode 100644 index 00000000..374a7009 --- /dev/null +++ b/msamp/fsdp/flat_param.py @@ -0,0 +1,62 @@ + +from typing import ( + Optional, + Sequence +) + +import torch +import torch.nn as nn +import torch.distributed as dist +from torch.distributed.fsdp.flat_param import FlatParamHandle + +class FP8FlatParamHandle(FlatParamHandle): + def _init_flat_param( + self, + params: Sequence[Optional[nn.Parameter]], + module: nn.Module, + use_orig_params: bool, + ) -> None: + super()._init_flat_param(params, module, use_orig_params) + + metas = [] + paddeds = [] + original_shapes = [] + scaling_metas = [] + + for param in self.flat_param._params: + if hasattr(param, '_fp8') and param._fp8: + metas.append(param._meta) + paddeds.append(param._padded) + original_shapes.append(param._original_shape) + scaling_metas.append(param._scaling_metas) + else: + metas.append(None) + paddeds.append(0) + original_shapes.append(None) + scaling_metas.append(None) + + self.flat_param._metas = metas + self.flat_param._paddeds = paddeds + self.flat_param._original_shapes = original_shapes + self.flat_param._scaling_metas = scaling_metas + + def _init_shard_metadata( + self, + numel_padded: int, + start: int, + end: int, + ) -> None: + super()._init_shard_metadata(numel_padded, start, end) + start_offset = 0 + end_offset = 0 + sharded_flat_param_numel = self.flat_param.numel() + for i, meta in enumerate(self.flat_param._metas): + start_offset += self.flat_param._numels[i-1] if i >=1 else 0 + end_offset += self.flat_param._numels[i] + if meta is not None: + start_rank = start_offset // sharded_flat_param_numel + end_rank = (end_offset-1) // sharded_flat_param_numel + ranks = list(range(start_rank, end_rank + 1)) + meta.group = dist.new_group(ranks=ranks) + +torch.distributed.fsdp._init_utils.FlatParamHandle = FP8FlatParamHandle \ No newline at end of file diff --git a/msamp/fsdp/fully_sharded_data_parallel.py b/msamp/fsdp/fully_sharded_data_parallel.py new file mode 100644 index 00000000..7cdbd3d0 --- /dev/null +++ b/msamp/fsdp/fully_sharded_data_parallel.py @@ -0,0 +1,120 @@ +import functools + +import torch +from torch.distributed.utils import _p_assert +from torch.distributed.fsdp._common_utils import FSDP_PREFIX, TrainingState +from torch.distributed.fsdp import FullyShardedDataParallel +from torch.distributed.fsdp._runtime_utils import ( + _post_forward, + _post_forward_reshard, + _pre_forward, + _pre_forward_unshard, + _root_pre_forward, +) +from torch.distributed.fsdp._init_utils import 
_get_default_comm_hook +from torch.distributed.algorithms._comm_hooks import default_hooks + + +class FP8FullyShardedDataParallel(FullyShardedDataParallel): + def __init__(self, module, *args, **kwargs): + for _, submodule in module.named_modules(): + params_to_process = list(submodule.named_parameters(recurse=False)) + for param_name, param in params_to_process: + if not isinstance(param, torch.Tensor): + data = param.value.view(-1) + padded = 0 + if data.numel() % 4 != 0: + padded = 4 - data.numel() % 4 + data = torch.nn.functional.pad(data, (0, padded)) + + data = data.view(dtype=torch.float32) + new_param = torch.nn.Parameter(data) + new_param._fp8 = True + new_param._original_shape = param.shape + new_param._padded = padded + new_param._meta = param.meta + new_param._scaling_metas = param._scaling_metas + setattr(submodule, param_name, new_param) + + super().__init__(module, *args, **kwargs) + + self._communication_hook = self._get_fp8_comm_hook() + + def _get_fp8_comm_hook(self): + def _fp8_allreduce_hook(state, grad, output): + start = 0 + end = 0 + has_meta = False + for meta in self._flat_param._metas: + if meta is not None: + has_meta = True + break + if has_meta: + for i, meta in enumerate(self._flat_param._metas): + start += self._flat_param._numels[i - 1] if i >= 1 else 0 + end += self._flat_param._numels[i] + if meta is not None: + from msamp.common.dtype import Dtypes + from msamp.operators.dist_op import DistOp + dtype = Dtypes.get_dtype_from_qtype(meta.qtype) + DistOp.enable_fp8(meta.qtype) + torch.distributed.all_reduce(grad[start:end].view(dtype), group=state.process_group) + DistOp.disable_fp8() + else: + default_hooks.allreduce_hook( + state=state, + grad=grad[start:end], + ) + start = self.rank * output.numel() + end = (self.rank + 1) * output.numel() + output.copy_(grad[start:end]) + else: + _get_default_comm_hook()(state, grad, output) + + return _fp8_allreduce_hook + + def forward(self, *args, **kwargs): + with torch.autograd.profiler.record_function( + "FullyShardedDataParallel.forward" + ): + args, kwargs = _root_pre_forward(self, self, args, kwargs) + unused = None + unshard_fn = functools.partial(_pre_forward_unshard, self, self._handles) + reshard_fn = functools.partial(_post_forward_reshard, self, self._handles) + args, kwargs = _pre_forward( + self, self._handles, unshard_fn, self._fsdp_wrapped_module, args, kwargs + ) + + for handle in self._handles: + _p_assert( + handle.flat_param.device == self.compute_device, + "Expected `FlatParameter` to be on the compute device " + f"{self.compute_device} but got {handle.flat_param.device}", + ) + i = 0 + for _, submodule in self._fsdp_wrapped_module.named_modules(): + for _, param in submodule.named_parameters(recurse=False): + if self._flat_param._metas[i] is not None: + param._fp8 = True + param._scaling_metas = self._flat_param._scaling_metas[i] + param._meta = self._flat_param._metas[i] + param._padded = self._flat_param._paddeds[i] + param._original_shape = self._flat_param._original_shapes[i] + i += 1 + output = self._fsdp_wrapped_module(*args, **kwargs) + return _post_forward(self, self._handles, reshard_fn, self, unused, output) + + + def named_parameters(self, *args, **kwargs): + should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS + i = 0 + for param_name, param in super().named_parameters(*args, **kwargs): + if self._flat_param._metas[i] is not None: + param._meta = self._flat_param._metas[i] + param._grad_meta = self._flat_param._scaling_metas[i]['wgrad'] + i += 1 + if 
should_clean_name: + # Remove any instances of the FSDP-specific prefix; there can + # be multiple in the case of nested FSDP modules + param_name = param_name.replace(FSDP_PREFIX, "") + yield (param_name, param) \ No newline at end of file diff --git a/msamp/nn/functional.py b/msamp/nn/functional.py index 511281d5..9510ad47 100644 --- a/msamp/nn/functional.py +++ b/msamp/nn/functional.py @@ -10,6 +10,7 @@ from msamp.common.tensor import ScalingTensor from msamp.operators.gemm import Gemm from msamp.nn.state import model_state +from msamp.nn.parameter import ScalingParameter class _FP8GemmFunction(torch.autograd.Function): @@ -26,6 +27,18 @@ def forward(ctx, input, weight, metas, dtype_holder): dtype_holder (torch.Tensor): A tensor to hold the output dtype. The required_grad of this tensor should be if input.required_grad is False. """ + if hasattr(weight, '_fp8') and weight._fp8: + padded = weight._padded + original_shape = weight._original_shape + meta = weight._meta + + weight = weight.view(dtype=torch.uint8) + if padded != 0: + weight = weight[0: weight.numel() - padded] + weight = weight.view(original_shape) + weight = ScalingParameter(ScalingTensor(weight, meta)) + ctx._fp8 = True + ctx.metas = metas model_state.check_metas_in_flat(metas) input_meta = metas['input'] @@ -96,8 +109,12 @@ def backward(ctx, output_grad): use_split_accumulator=True, ) del old_wgrad - - if model_state.use_fp8_ddp: + if ctx._fp8: + wgrad = wgrad.cast(Dtypes.kfloat8_e4m3, meta=wgrad_meta, sync=True) + wgrad = wgrad.value.view(-1).view(dtype=torch.float32) + wgrad.meta = wgrad_meta + return input_grad, wgrad, None, None + elif model_state.use_fp8_ddp: wgrad.meta = wgrad_meta else: # wgrad above this line is torch.Tensor w/o tensor scaling @@ -149,7 +166,7 @@ def new_fn(input, weight, bias=None): if bias is not None and not isinstance(bias, torch.Tensor): raise TypeError(f'bias should be a torch.Tensor. current type: {type(bias)}') - if isinstance(weight, torch.Tensor): + if isinstance(weight, torch.Tensor) and not hasattr(weight, '_fp8'): return old_fn(input, weight, bias=bias) if not hasattr(weight, '_scaling_metas'): diff --git a/msamp/optim/__init__.py b/msamp/optim/__init__.py index 71d2e06f..4f01ea1a 100644 --- a/msamp/optim/__init__.py +++ b/msamp/optim/__init__.py @@ -5,7 +5,7 @@ from msamp.optim.optimizer import LBOptimizer from msamp.optim.adamw_base import LBAdamWBase -from msamp.optim.adamw import LBAdamW -from msamp.optim.adam import LBAdam, DSAdam +from msamp.optim.adamw import LBAdamW, FSDPAdamW +from msamp.optim.adam import LBAdam, DSAdam, FSDPAdam -__all__ = ['LBOptimizer', 'LBAdamWBase', 'LBAdamW', 'LBAdam', 'DSAdam'] +__all__ = ['LBOptimizer', 'LBAdamWBase', 'LBAdamW', 'LBAdam', 'DSAdam', 'FSDPAdamW', 'FSDPAdam'] diff --git a/msamp/optim/adam.py b/msamp/optim/adam.py index f2179e26..e706fbad 100644 --- a/msamp/optim/adam.py +++ b/msamp/optim/adam.py @@ -3,8 +3,7 @@ """MS-AMP adam module.""" -from msamp.optim import LBAdamW - +from msamp.optim import LBAdamW, FSDPAdamW class LBAdam(LBAdamW): """Implements Adam algorithm with weight decay fix.""" @@ -63,3 +62,31 @@ def __init__( ) self.use_adam = not adam_w_mode self.set_grad_none = set_grad_none + + +class FSDPAdam(FSDPAdamW): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + maximize: bool = False, + *args, + **kwargs + ): + """Constructor. 
See LBAdamW class docstring for details.""" + super().__init__( + params=params, + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + maximize=maximize, + *args, + **kwargs + ) + self.use_adam = True \ No newline at end of file diff --git a/msamp/optim/adamw.py b/msamp/optim/adamw.py index f012adf1..34d120ba 100644 --- a/msamp/optim/adamw.py +++ b/msamp/optim/adamw.py @@ -11,7 +11,7 @@ from msamp.optim import LBAdamWBase from msamp.common.tensor import ScalingMeta, ScalingTensor -from msamp.common.dtype import Floating +from msamp.common.dtype import Floating, Dtypes import msamp_adamw @@ -228,3 +228,91 @@ def adamw_fn( # noqa: C901 if isinstance(params[i], ScalingTensor): params[i].copy_(param.cast(params[i].qtype, meta=params[i].meta)) + + +class FSDPAdamW(LBAdamW): + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + *, + maximize: bool = False, + exp_avg_dtype=torch.uint8, + exp_avg_sq_dtype=torch.float16, + tensor_scale=True, + ): + super().__init__( + params, + lr=lr, + bias_correction=bias_correction, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=False, + maximize=maximize, + exp_avg_dtype=exp_avg_dtype, + exp_avg_sq_dtype=exp_avg_sq_dtype, + tensor_scale=tensor_scale + ) + + self.original_params = [] + self.master_weights = [] + + for group in self.param_groups: + params = [] + for param in group['params']: + if param is None or param.numel() == 0: + continue + if hasattr(param, '_meta') and param._meta is not None: + self.original_params.append(param) + dtype = Dtypes.qtype_to_dtype[param._meta.qtype] + param = ScalingTensor(param.view(dtype), param._meta) + master_weight = param.cast(Dtypes.kfloat16) + master_weight.requires_grad = True + self.master_weights.append(master_weight) + params.append(master_weight) + else: + self.original_params.append(param) + self.master_weights.append(None) + params.append(param) + + group['params'] = params + + + def zero_grad(self, set_to_none=False): + for param in self.original_params: + if set_to_none: + param.grad = None + else: + if param.grad is not None: + if param.grad.grad_fn is not None: + param.grad.detach_() + else: + param.grad.requires_grad_(False) + param.grad.zero_() + + def step(self): + # cast gradient to ScalingTensor + for i, param in enumerate(self.original_params): + if param.grad is None: + continue + + if hasattr(param, '_meta') and param._meta is not None: + grad_meta = param._grad_meta + dtype = Dtypes.qtype_to_dtype[grad_meta.qtype] + self.master_weights[i].grad = ScalingTensor(param.grad.view(dtype), grad_meta) + param.grad = None + + # call step() to update master weight + super().step() + + # sync params and copy master weight to weight + for i, param in enumerate(self.original_params): + if hasattr(param, '_meta') and param._meta is not None and param.numel() > 0: + data = self.master_weights[i].float().cast(param._meta.qtype, param._meta, True).value.view(torch.float32) + param.data.copy_(data) From e6d4822bc6a91bd9e66c70956986ef004a6ad46f Mon Sep 17 00:00:00 2001 From: tocean Date: Mon, 8 Jan 2024 09:47:14 +0000 Subject: [PATCH 02/18] support auto wrap policy --- examples/mnist_fsdp.py | 17 +-- msamp/fsdp/__init__.py | 4 +- msamp/fsdp/_runtime_utils.py | 18 +++ msamp/fsdp/flat_param.py | 31 ++++- msamp/fsdp/fully_sharded_data_parallel.py | 152 ++++++---------------- msamp/fsdp/replacer.py | 38 ++++++ msamp/nn/functional.py | 8 +- msamp/optim/adam.py | 66 +++++++++- 
msamp/optim/adamw.py | 6 +- 9 files changed, 205 insertions(+), 135 deletions(-) create mode 100644 msamp/fsdp/_runtime_utils.py create mode 100644 msamp/fsdp/replacer.py diff --git a/examples/mnist_fsdp.py b/examples/mnist_fsdp.py index 893fea43..9979daee 100644 --- a/examples/mnist_fsdp.py +++ b/examples/mnist_fsdp.py @@ -9,15 +9,13 @@ import torch.multiprocessing as mp import torch.nn.functional as F import torch.distributed as dist - from torch.optim.lr_scheduler import StepLR from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.utils.data.distributed import DistributedSampler - - from torchvision import datasets, transforms + def setup(rank, world_size): os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '12355' @@ -25,9 +23,11 @@ def setup(rank, world_size): # initialize the process group dist.init_process_group("nccl", rank=rank, world_size=world_size) + def cleanup(): dist.destroy_process_group() + class Net(nn.Module): def __init__(self): super(Net, self).__init__() @@ -53,6 +53,7 @@ def forward(self, x): output = F.log_softmax(x, dim=1) return output + def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None): model.train() ddp_loss = torch.zeros(2).to(rank) @@ -74,6 +75,7 @@ def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler if rank == 0: print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1])) + def test(model, rank, world_size, test_loader): model.eval() correct = 0 @@ -139,13 +141,12 @@ def fsdp_main(rank, world_size, args): model = Net().to(rank) if args.msamp: - from msamp.nn import LinearReplacer - from msamp.common.dtype import Dtypes + from msamp.fsdp import FsdpReplacer from msamp.fsdp import FP8FullyShardedDataParallel - model = LinearReplacer.replace(model, weight_qtype=Dtypes.kfloat8_e4m3) - model = FP8FullyShardedDataParallel(model, use_orig_params=True) + model = FsdpReplacer.replace(model) + model = FP8FullyShardedDataParallel(model, use_orig_params=True, auto_wrap_policy=my_auto_wrap_policy) else: - model = FSDP(model, use_orig_params=True) + model = FSDP(model, use_orig_params=True, auto_wrap_policy=my_auto_wrap_policy) if rank == 0: print(f'FSDP model:') diff --git a/msamp/fsdp/__init__.py b/msamp/fsdp/__init__.py index f81d6826..b377bcbc 100644 --- a/msamp/fsdp/__init__.py +++ b/msamp/fsdp/__init__.py @@ -3,7 +3,7 @@ """Expose the interface of MS-AMP fsdp package.""" -from msamp.fsdp.flat_param import FP8FlatParamHandle +from msamp.fsdp.replacer import FsdpReplacer from msamp.fsdp.fully_sharded_data_parallel import FP8FullyShardedDataParallel -__all__ = ['FP8FlatParamHandle', 'FP8FullyShardedDataParallel'] \ No newline at end of file +__all__ = ['FsdpReplacer', 'FP8FullyShardedDataParallel'] \ No newline at end of file diff --git a/msamp/fsdp/_runtime_utils.py b/msamp/fsdp/_runtime_utils.py new file mode 100644 index 00000000..6b7c4acd --- /dev/null +++ b/msamp/fsdp/_runtime_utils.py @@ -0,0 +1,18 @@ +from typing import no_type_check + +import torch +from torch.distributed.fsdp import FullyShardedDataParallel + +old_post_backward_hook = torch.distributed.fsdp._runtime_utils._post_backward_hook + +@no_type_check +@torch.no_grad() +def _post_backward_hook(state, handle, *unused): + if not isinstance(state, FullyShardedDataParallel): + return old_post_backward_hook(state, handle, *unused) + + old_communication_hook = state._communication_hook + state._communication_hook = 
state._get_fp8_comm_hook() + old_post_backward_hook(state, handle, *unused) + state._communication_hook = old_communication_hook + diff --git a/msamp/fsdp/flat_param.py b/msamp/fsdp/flat_param.py index 374a7009..9d39b1e0 100644 --- a/msamp/fsdp/flat_param.py +++ b/msamp/fsdp/flat_param.py @@ -1,8 +1,5 @@ -from typing import ( - Optional, - Sequence -) +from typing import Optional, Sequence import torch import torch.nn as nn @@ -24,7 +21,7 @@ def _init_flat_param( scaling_metas = [] for param in self.flat_param._params: - if hasattr(param, '_fp8') and param._fp8: + if hasattr(param, '_meta') and param._meta: metas.append(param._meta) paddeds.append(param._padded) original_shapes.append(param._original_shape) @@ -59,4 +56,26 @@ def _init_shard_metadata( ranks = list(range(start_rank, end_rank + 1)) meta.group = dist.new_group(ranks=ranks) -torch.distributed.fsdp._init_utils.FlatParamHandle = FP8FlatParamHandle \ No newline at end of file + + def _use_unsharded_views(self, as_params: bool) -> None: + super()._use_unsharded_views(as_params) + for i, param_info in enumerate(self.flat_param._param_infos): + if hasattr(param_info.module, param_info.param_name): + param = getattr(param_info.module, param_info.param_name) + + param._scaling_metas = self.flat_param._scaling_metas[i] + param._meta = self.flat_param._metas[i] + param._padded = self.flat_param._paddeds[i] + param._original_shape = self.flat_param._original_shapes[i] + + @torch.no_grad() + def _use_sharded_views(self) -> None: + super()._use_sharded_views() + for i, param_info in enumerate(self.flat_param._param_infos): + if hasattr(param_info.module, param_info.param_name): + param = getattr(param_info.module, param_info.param_name) + if self.flat_param._metas[i] is not None: + param._meta = self.flat_param._metas[i] + param._grad_meta = self.flat_param._scaling_metas[i]['wgrad'] + + diff --git a/msamp/fsdp/fully_sharded_data_parallel.py b/msamp/fsdp/fully_sharded_data_parallel.py index 7cdbd3d0..88d07a80 100644 --- a/msamp/fsdp/fully_sharded_data_parallel.py +++ b/msamp/fsdp/fully_sharded_data_parallel.py @@ -1,120 +1,54 @@ -import functools - import torch -from torch.distributed.utils import _p_assert -from torch.distributed.fsdp._common_utils import FSDP_PREFIX, TrainingState from torch.distributed.fsdp import FullyShardedDataParallel -from torch.distributed.fsdp._runtime_utils import ( - _post_forward, - _post_forward_reshard, - _pre_forward, - _pre_forward_unshard, - _root_pre_forward, -) -from torch.distributed.fsdp._init_utils import _get_default_comm_hook from torch.distributed.algorithms._comm_hooks import default_hooks +from torch.distributed.fsdp._init_utils import _get_default_comm_hook +from msamp.fsdp.flat_param import FP8FlatParamHandle +from msamp.fsdp._runtime_utils import _post_backward_hook + + +def _get_fp8_comm_hook(self): + def _fp8_allreduce_hook(state, grad, output): + start = 0 + end = 0 + has_meta = False + for meta in self._flat_param._metas: + if meta is not None: + has_meta = True + break + if has_meta: + for i, meta in enumerate(self._flat_param._metas): + start += self._flat_param._numels[i - 1] if i >= 1 else 0 + end += self._flat_param._numels[i] + if meta is not None: + from msamp.common.dtype import Dtypes + from msamp.operators.dist_op import DistOp + dtype = Dtypes.get_dtype_from_qtype(meta.qtype) + DistOp.enable_fp8(meta.qtype) + torch.distributed.all_reduce(grad[start:end].view(dtype), group=state.process_group) + DistOp.disable_fp8() + else: + default_hooks.allreduce_hook( + state=state, + 
grad=grad[start:end], + ) + start = self.rank*output.numel() + end = (self.rank+1)*output.numel() + output.copy_(grad[start:end]) + else: + _get_default_comm_hook()(state, grad, output) + + return _fp8_allreduce_hook class FP8FullyShardedDataParallel(FullyShardedDataParallel): def __init__(self, module, *args, **kwargs): - for _, submodule in module.named_modules(): - params_to_process = list(submodule.named_parameters(recurse=False)) - for param_name, param in params_to_process: - if not isinstance(param, torch.Tensor): - data = param.value.view(-1) - padded = 0 - if data.numel() % 4 != 0: - padded = 4 - data.numel() % 4 - data = torch.nn.functional.pad(data, (0, padded)) - - data = data.view(dtype=torch.float32) - new_param = torch.nn.Parameter(data) - new_param._fp8 = True - new_param._original_shape = param.shape - new_param._padded = padded - new_param._meta = param.meta - new_param._scaling_metas = param._scaling_metas - setattr(submodule, param_name, new_param) - super().__init__(module, *args, **kwargs) - self._communication_hook = self._get_fp8_comm_hook() - - def _get_fp8_comm_hook(self): - def _fp8_allreduce_hook(state, grad, output): - start = 0 - end = 0 - has_meta = False - for meta in self._flat_param._metas: - if meta is not None: - has_meta = True - break - if has_meta: - for i, meta in enumerate(self._flat_param._metas): - start += self._flat_param._numels[i - 1] if i >= 1 else 0 - end += self._flat_param._numels[i] - if meta is not None: - from msamp.common.dtype import Dtypes - from msamp.operators.dist_op import DistOp - dtype = Dtypes.get_dtype_from_qtype(meta.qtype) - DistOp.enable_fp8(meta.qtype) - torch.distributed.all_reduce(grad[start:end].view(dtype), group=state.process_group) - DistOp.disable_fp8() - else: - default_hooks.allreduce_hook( - state=state, - grad=grad[start:end], - ) - start = self.rank * output.numel() - end = (self.rank + 1) * output.numel() - output.copy_(grad[start:end]) - else: - _get_default_comm_hook()(state, grad, output) - - return _fp8_allreduce_hook - - def forward(self, *args, **kwargs): - with torch.autograd.profiler.record_function( - "FullyShardedDataParallel.forward" - ): - args, kwargs = _root_pre_forward(self, self, args, kwargs) - unused = None - unshard_fn = functools.partial(_pre_forward_unshard, self, self._handles) - reshard_fn = functools.partial(_post_forward_reshard, self, self._handles) - args, kwargs = _pre_forward( - self, self._handles, unshard_fn, self._fsdp_wrapped_module, args, kwargs - ) + @classmethod + def override(cls): + torch.distributed.fsdp._init_utils.FlatParamHandle = FP8FlatParamHandle + torch.distributed.fsdp._runtime_utils._post_backward_hook = _post_backward_hook + FullyShardedDataParallel._get_fp8_comm_hook = _get_fp8_comm_hook - for handle in self._handles: - _p_assert( - handle.flat_param.device == self.compute_device, - "Expected `FlatParameter` to be on the compute device " - f"{self.compute_device} but got {handle.flat_param.device}", - ) - i = 0 - for _, submodule in self._fsdp_wrapped_module.named_modules(): - for _, param in submodule.named_parameters(recurse=False): - if self._flat_param._metas[i] is not None: - param._fp8 = True - param._scaling_metas = self._flat_param._scaling_metas[i] - param._meta = self._flat_param._metas[i] - param._padded = self._flat_param._paddeds[i] - param._original_shape = self._flat_param._original_shapes[i] - i += 1 - output = self._fsdp_wrapped_module(*args, **kwargs) - return _post_forward(self, self._handles, reshard_fn, self, unused, output) - - def 
named_parameters(self, *args, **kwargs): - should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS - i = 0 - for param_name, param in super().named_parameters(*args, **kwargs): - if self._flat_param._metas[i] is not None: - param._meta = self._flat_param._metas[i] - param._grad_meta = self._flat_param._scaling_metas[i]['wgrad'] - i += 1 - if should_clean_name: - # Remove any instances of the FSDP-specific prefix; there can - # be multiple in the case of nested FSDP modules - param_name = param_name.replace(FSDP_PREFIX, "") - yield (param_name, param) \ No newline at end of file +FP8FullyShardedDataParallel.override() diff --git a/msamp/fsdp/replacer.py b/msamp/fsdp/replacer.py new file mode 100644 index 00000000..17e26e71 --- /dev/null +++ b/msamp/fsdp/replacer.py @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""MS-AMP fsdp.replacer module.""" + +import torch + +from msamp.common.dtype import Dtypes +from msamp.nn import LinearReplacer + + +class FsdpReplacer: + """A replacer to replace the FP8 weights with FP32 nn.Parameter and attributes.""" + + @classmethod + def replace(cls, model): + """Replace the weights with ScalingParameter in modules.""" + + model = LinearReplacer.replace(model, weight_qtype=Dtypes.kfloat8_e4m3) + for _, submodule in model.named_modules(): + params_to_process = list(submodule.named_parameters(recurse=False)) + for param_name, param in params_to_process: + if not isinstance(param, torch.Tensor): + data = param.value.view(-1) + padded = 0 + if data.numel() % 4 != 0: + padded = 4 - data.numel() % 4 + data = torch.nn.functional.pad(data, (0, padded)) + + data = data.view(dtype=torch.float32) + new_param = torch.nn.Parameter(data) + new_param._original_shape = param.shape + new_param._padded = 0 + new_param._meta = param.meta + new_param._scaling_metas = param._scaling_metas + + setattr(submodule, param_name, new_param) + return model diff --git a/msamp/nn/functional.py b/msamp/nn/functional.py index 9510ad47..24694dea 100644 --- a/msamp/nn/functional.py +++ b/msamp/nn/functional.py @@ -27,7 +27,7 @@ def forward(ctx, input, weight, metas, dtype_holder): dtype_holder (torch.Tensor): A tensor to hold the output dtype. The required_grad of this tensor should be if input.required_grad is False. """ - if hasattr(weight, '_fp8') and weight._fp8: + if isinstance(weight, torch.Tensor) and hasattr(weight, '_meta'): padded = weight._padded original_shape = weight._original_shape meta = weight._meta @@ -37,7 +37,7 @@ def forward(ctx, input, weight, metas, dtype_holder): weight = weight[0: weight.numel() - padded] weight = weight.view(original_shape) weight = ScalingParameter(ScalingTensor(weight, meta)) - ctx._fp8 = True + ctx.return_wgrad = True ctx.metas = metas model_state.check_metas_in_flat(metas) @@ -109,7 +109,7 @@ def backward(ctx, output_grad): use_split_accumulator=True, ) del old_wgrad - if ctx._fp8: + if ctx.return_wgrad: wgrad = wgrad.cast(Dtypes.kfloat8_e4m3, meta=wgrad_meta, sync=True) wgrad = wgrad.value.view(-1).view(dtype=torch.float32) wgrad.meta = wgrad_meta @@ -166,7 +166,7 @@ def new_fn(input, weight, bias=None): if bias is not None and not isinstance(bias, torch.Tensor): raise TypeError(f'bias should be a torch.Tensor. 
current type: {type(bias)}') - if isinstance(weight, torch.Tensor) and not hasattr(weight, '_fp8'): + if isinstance(weight, torch.Tensor) and not hasattr(weight, '_meta'): return old_fn(input, weight, bias=bias) if not hasattr(weight, '_scaling_metas'): diff --git a/msamp/optim/adam.py b/msamp/optim/adam.py index e706fbad..4ab0495d 100644 --- a/msamp/optim/adam.py +++ b/msamp/optim/adam.py @@ -3,6 +3,10 @@ """MS-AMP adam module.""" +import torch + +from msamp.common.dtype import Dtypes +from msamp.common.tensor import ScalingTensor from msamp.optim import LBAdamW, FSDPAdamW class LBAdam(LBAdamW): @@ -64,7 +68,8 @@ def __init__( self.set_grad_none = set_grad_none -class FSDPAdam(FSDPAdamW): + +class FSDPAdam(LBAdam): def __init__( self, params, @@ -77,7 +82,6 @@ def __init__( *args, **kwargs ): - """Constructor. See LBAdamW class docstring for details.""" super().__init__( params=params, lr=lr, @@ -89,4 +93,60 @@ def __init__( *args, **kwargs ) - self.use_adam = True \ No newline at end of file + + self.original_params = [] + self.master_weights = [] + + for group in self.param_groups: + params = [] + for param in group['params']: + if param is None or param.numel() == 0: + continue + if hasattr(param, '_meta') and param._meta is not None: + self.original_params.append(param) + dtype = Dtypes.qtype_to_dtype[param._meta.qtype] + param = ScalingTensor(param.view(dtype), param._meta) + master_weight = param.cast(Dtypes.kfloat16) + master_weight.requires_grad = True + self.master_weights.append(master_weight) + params.append(master_weight) + else: + self.original_params.append(param) + self.master_weights.append(None) + params.append(param) + + group['params'] = params + + + def zero_grad(self, set_to_none=False): + for param in self.original_params: + if set_to_none: + param.grad = None + else: + if param.grad is not None: + if param.grad.grad_fn is not None: + param.grad.detach_() + else: + param.grad.requires_grad_(False) + param.grad.zero_() + + def step(self): + # cast gradient to ScalingTensor + for i, param in enumerate(self.original_params): + if param.grad is None: + continue + + if hasattr(param, '_meta') and param._meta is not None: + grad_meta = param._grad_meta + dtype = Dtypes.qtype_to_dtype[grad_meta.qtype] + self.master_weights[i].grad = ScalingTensor(param.grad.view(dtype), grad_meta) + param.grad = None + + # call step() to update master weight + super().step() + + # sync params and copy master weight to weight + for i, param in enumerate(self.original_params): + if hasattr(param, '_meta') and param._meta is not None and param.numel() > 0: + data = self.master_weights[i].float().cast(param._meta.qtype, param._meta, True).value.view(torch.float32) + param.data.copy_(data) \ No newline at end of file diff --git a/msamp/optim/adamw.py b/msamp/optim/adamw.py index 34d120ba..d3447c00 100644 --- a/msamp/optim/adamw.py +++ b/msamp/optim/adamw.py @@ -230,7 +230,7 @@ def adamw_fn( # noqa: C901 params[i].copy_(param.cast(params[i].qtype, meta=params[i].meta)) -class FSDPAdamW(LBAdamW): +class FSDPAdamW(LBAdamWBase): def __init__( self, params, @@ -246,6 +246,7 @@ def __init__( exp_avg_sq_dtype=torch.float16, tensor_scale=True, ): + self.tensor_scale = tensor_scale super().__init__( params, lr=lr, @@ -256,8 +257,7 @@ def __init__( amsgrad=False, maximize=maximize, exp_avg_dtype=exp_avg_dtype, - exp_avg_sq_dtype=exp_avg_sq_dtype, - tensor_scale=tensor_scale + exp_avg_sq_dtype=exp_avg_sq_dtype ) self.original_params = [] From eb800e808814074fbeb3dc370b2d260e0b9b71dd Mon Sep 17 00:00:00 
2001 From: tocean Date: Mon, 8 Jan 2024 09:50:25 +0000 Subject: [PATCH 03/18] FSDPAdam extends from FSDPAdamW --- msamp/optim/adam.py | 65 +++------------------------------------------ 1 file changed, 3 insertions(+), 62 deletions(-) diff --git a/msamp/optim/adam.py b/msamp/optim/adam.py index 4ab0495d..14082755 100644 --- a/msamp/optim/adam.py +++ b/msamp/optim/adam.py @@ -3,10 +3,6 @@ """MS-AMP adam module.""" -import torch - -from msamp.common.dtype import Dtypes -from msamp.common.tensor import ScalingTensor from msamp.optim import LBAdamW, FSDPAdamW class LBAdam(LBAdamW): @@ -69,7 +65,7 @@ def __init__( -class FSDPAdam(LBAdam): +class FSDPAdam(FSDPAdamW): def __init__( self, params, @@ -93,60 +89,5 @@ def __init__( *args, **kwargs ) - - self.original_params = [] - self.master_weights = [] - - for group in self.param_groups: - params = [] - for param in group['params']: - if param is None or param.numel() == 0: - continue - if hasattr(param, '_meta') and param._meta is not None: - self.original_params.append(param) - dtype = Dtypes.qtype_to_dtype[param._meta.qtype] - param = ScalingTensor(param.view(dtype), param._meta) - master_weight = param.cast(Dtypes.kfloat16) - master_weight.requires_grad = True - self.master_weights.append(master_weight) - params.append(master_weight) - else: - self.original_params.append(param) - self.master_weights.append(None) - params.append(param) - - group['params'] = params - - - def zero_grad(self, set_to_none=False): - for param in self.original_params: - if set_to_none: - param.grad = None - else: - if param.grad is not None: - if param.grad.grad_fn is not None: - param.grad.detach_() - else: - param.grad.requires_grad_(False) - param.grad.zero_() - - def step(self): - # cast gradient to ScalingTensor - for i, param in enumerate(self.original_params): - if param.grad is None: - continue - - if hasattr(param, '_meta') and param._meta is not None: - grad_meta = param._grad_meta - dtype = Dtypes.qtype_to_dtype[grad_meta.qtype] - self.master_weights[i].grad = ScalingTensor(param.grad.view(dtype), grad_meta) - param.grad = None - - # call step() to update master weight - super().step() - - # sync params and copy master weight to weight - for i, param in enumerate(self.original_params): - if hasattr(param, '_meta') and param._meta is not None and param.numel() > 0: - data = self.master_weights[i].float().cast(param._meta.qtype, param._meta, True).value.view(torch.float32) - param.data.copy_(data) \ No newline at end of file + self.use_adam = True + \ No newline at end of file From 993bb7cd01d37c6f12a15211086bb2906e1ca4ff Mon Sep 17 00:00:00 2001 From: tocean Date: Mon, 8 Jan 2024 11:04:27 +0000 Subject: [PATCH 04/18] add document --- docs/getting-started/run-msamp.md | 6 ++++++ docs/user-tutorial/usage.md | 24 ++++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/docs/getting-started/run-msamp.md b/docs/getting-started/run-msamp.md index dd7d29e2..5465aaf7 100644 --- a/docs/getting-started/run-msamp.md +++ b/docs/getting-started/run-msamp.md @@ -20,6 +20,12 @@ python mnist.py --enable-msamp --opt-level=O2 torchrun --nproc_per_node=8 mnist_ddp.py --enable-msamp --opt-level=O2 ``` +### 3. Run mnist using FSDP + +```bash +python mnist_fsdp.py --msamp +``` + ## CIFAR10 ### 1. 
Run cifar10 using deepspeed diff --git a/docs/user-tutorial/usage.md b/docs/user-tutorial/usage.md index 58832da9..9e247b11 100644 --- a/docs/user-tutorial/usage.md +++ b/docs/user-tutorial/usage.md @@ -39,6 +39,30 @@ For enabling MS-AMP in DeepSpeed, add one line of code `from msamp import deepsp "O3" is designed for FP8 in ZeRO optimizer, so please make sure ZeRO is enabled when using "O3". "use_te" is designed for Transformer Engine, if you have already used Transformer Engine in your model, don't forget to set "use_te" to true. +## Usage in FSDP + +When using FSDP, enabling MS-AMP is very easy, just use `FsdpReplacer.replace` and `FP8FullyShardedDataParallel` to initialize model and `FSDPAdam` to initialize optimizer. + +Example: + +```python +# Your model +model = ... + +# Initialize model +from msamp.fsdp import FsdpReplacer +from msamp.fsdp import FP8FullyShardedDataParallel +my_auto_wrap_policy = ... +model = FsdpReplacer.replace(model) +model = FP8FullyShardedDataParallel(model, use_orig_params=True, auto_wrap_policy=my_auto_wrap_policy) + +# Initialize optimizer +from msamp.optim import FSDPAdam +optimizer = FSDPAdam(model.parameters(), lr=3e-04) +``` + +Please note that currenlty we only support `use_orig_params=True`. + ## Usage in Megatron-DeepSpeed and Megatron-LM For integrating MS-AMP with Megatron-DeepSpeed and Megatron-LM, you need to make some code changes. We provide a patch as a reference for the integration. Here is the instruction of integrating MS-AMP with Megatron-DeepSpeed/Megatron-LM and how to run [gpt-3](https://github.com/Azure/MS-AMP-Examples/tree/main/gpt3) with MS-AMP. From 515e6fc3deec9234fca2dde0adf0fdc29bfe4b8f Mon Sep 17 00:00:00 2001 From: tocean Date: Mon, 8 Jan 2024 11:19:34 +0000 Subject: [PATCH 05/18] add comment --- msamp/fsdp/__init__.py | 2 +- msamp/fsdp/_runtime_utils.py | 9 +++++++-- msamp/fsdp/flat_param.py | 12 +++++++++++- msamp/fsdp/fully_sharded_data_parallel.py | 13 +++++++++++-- 4 files changed, 30 insertions(+), 6 deletions(-) diff --git a/msamp/fsdp/__init__.py b/msamp/fsdp/__init__.py index b377bcbc..34ffb331 100644 --- a/msamp/fsdp/__init__.py +++ b/msamp/fsdp/__init__.py @@ -6,4 +6,4 @@ from msamp.fsdp.replacer import FsdpReplacer from msamp.fsdp.fully_sharded_data_parallel import FP8FullyShardedDataParallel -__all__ = ['FsdpReplacer', 'FP8FullyShardedDataParallel'] \ No newline at end of file +__all__ = ['FsdpReplacer', 'FP8FullyShardedDataParallel'] diff --git a/msamp/fsdp/_runtime_utils.py b/msamp/fsdp/_runtime_utils.py index 6b7c4acd..25cba74c 100644 --- a/msamp/fsdp/_runtime_utils.py +++ b/msamp/fsdp/_runtime_utils.py @@ -1,3 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""MS-AMP fsdp._runtime_utils module.""" + from typing import no_type_check import torch @@ -7,7 +12,8 @@ @no_type_check @torch.no_grad() -def _post_backward_hook(state, handle, *unused): +def _fp8_post_backward_hook(state, handle, *unused): + """A post-backward communication hook which supports fp8.""" if not isinstance(state, FullyShardedDataParallel): return old_post_backward_hook(state, handle, *unused) @@ -15,4 +21,3 @@ def _post_backward_hook(state, handle, *unused): state._communication_hook = state._get_fp8_comm_hook() old_post_backward_hook(state, handle, *unused) state._communication_hook = old_communication_hook - diff --git a/msamp/fsdp/flat_param.py b/msamp/fsdp/flat_param.py index 9d39b1e0..219d21a5 100644 --- a/msamp/fsdp/flat_param.py +++ b/msamp/fsdp/flat_param.py @@ -1,4 +1,9 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""MS-AMP fsdp.flat_param module.""" + from typing import Optional, Sequence import torch @@ -7,19 +12,21 @@ from torch.distributed.fsdp.flat_param import FlatParamHandle class FP8FlatParamHandle(FlatParamHandle): + """A handle for a flat parameter which may have fp32 and fp8.""" def _init_flat_param( self, params: Sequence[Optional[nn.Parameter]], module: nn.Module, use_orig_params: bool, ) -> None: + """Initialize the flat parameter and save fp8 related metadata.""" super()._init_flat_param(params, module, use_orig_params) metas = [] paddeds = [] original_shapes = [] scaling_metas = [] - + for param in self.flat_param._params: if hasattr(param, '_meta') and param._meta: metas.append(param._meta) @@ -43,6 +50,7 @@ def _init_shard_metadata( start: int, end: int, ) -> None: + """Initialize the shard metadata for the flat parameter and create a group for each fp8 parameter""" super()._init_shard_metadata(numel_padded, start, end) start_offset = 0 end_offset = 0 @@ -58,6 +66,7 @@ def _init_shard_metadata( def _use_unsharded_views(self, as_params: bool) -> None: + """Use unsharded views of the flat parameter and set fp8 related attritutes, which will be use in msamp.nn.functional.""" super()._use_unsharded_views(as_params) for i, param_info in enumerate(self.flat_param._param_infos): if hasattr(param_info.module, param_info.param_name): @@ -70,6 +79,7 @@ def _use_unsharded_views(self, as_params: bool) -> None: @torch.no_grad() def _use_sharded_views(self) -> None: + """Use sharded views of the flat parameter and set meta of scaling tensor, which will be used in optimizer.""" super()._use_sharded_views() for i, param_info in enumerate(self.flat_param._param_infos): if hasattr(param_info.module, param_info.param_name): diff --git a/msamp/fsdp/fully_sharded_data_parallel.py b/msamp/fsdp/fully_sharded_data_parallel.py index 88d07a80..0b31e7e0 100644 --- a/msamp/fsdp/fully_sharded_data_parallel.py +++ b/msamp/fsdp/fully_sharded_data_parallel.py @@ -1,13 +1,19 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""MS-AMP fsdp.fully_sharded_data_parallel module.""" + import torch from torch.distributed.fsdp import FullyShardedDataParallel from torch.distributed.algorithms._comm_hooks import default_hooks from torch.distributed.fsdp._init_utils import _get_default_comm_hook from msamp.fsdp.flat_param import FP8FlatParamHandle -from msamp.fsdp._runtime_utils import _post_backward_hook +from msamp.fsdp._runtime_utils import _fp8_post_backward_hook def _get_fp8_comm_hook(self): + """Get the communication hook for fp8 gradient.""" def _fp8_allreduce_hook(state, grad, output): start = 0 end = 0 @@ -40,14 +46,17 @@ def _fp8_allreduce_hook(state, grad, output): return _fp8_allreduce_hook + class FP8FullyShardedDataParallel(FullyShardedDataParallel): + """A FullyShardedDataParallel with supports fp8.""" def __init__(self, module, *args, **kwargs): super().__init__(module, *args, **kwargs) @classmethod def override(cls): + """Override FlatParamHandle and _post_backward_hook with class/function which support fp8.""" torch.distributed.fsdp._init_utils.FlatParamHandle = FP8FlatParamHandle - torch.distributed.fsdp._runtime_utils._post_backward_hook = _post_backward_hook + torch.distributed.fsdp._runtime_utils._post_backward_hook = _fp8_post_backward_hook FullyShardedDataParallel._get_fp8_comm_hook = _get_fp8_comm_hook From b8da96b764cba71cabd6017025c6d75e19786acd Mon Sep 17 00:00:00 2001 From: tocean Date: Mon, 8 Jan 2024 11:43:29 +0000 Subject: [PATCH 06/18] fix linting --- examples/mnist_fsdp.py | 99 +++++++++++------------ msamp/fsdp/_runtime_utils.py | 1 + msamp/fsdp/flat_param.py | 20 ++--- msamp/fsdp/fully_sharded_data_parallel.py | 7 +- msamp/fsdp/replacer.py | 3 +- msamp/nn/functional.py | 4 +- msamp/optim/adam.py | 5 +- msamp/optim/adamw.py | 10 ++- 8 files changed, 76 insertions(+), 73 deletions(-) diff --git a/examples/mnist_fsdp.py b/examples/mnist_fsdp.py index 9979daee..3495494d 100644 --- a/examples/mnist_fsdp.py +++ b/examples/mnist_fsdp.py @@ -1,4 +1,10 @@ -# Based on: https://github.com/pytorch/examples/blob/master/mnist/main.py +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""The fsdp mnist exampe using MS-AMP. + +It is adapted from https://github.com/pytorch/examples/blob/main/mnist/main.py. 
+""" import os import argparse @@ -17,19 +23,23 @@ def setup(rank, world_size): + """Initialize the distributed environment.""" os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '12355' # initialize the process group - dist.init_process_group("nccl", rank=rank, world_size=world_size) + dist.init_process_group('nccl', rank=rank, world_size=world_size) def cleanup(): + """Destroy the distributed environment.""" dist.destroy_process_group() class Net(nn.Module): + """The neural network model for mnist.""" def __init__(self): + """Constructor.""" super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3, 1) self.conv2 = nn.Conv2d(32, 64, 3, 1) @@ -39,6 +49,7 @@ def __init__(self): self.fc2 = nn.Linear(128, 10) def forward(self, x): + """Forward function.""" x = self.conv1(x) x = F.relu(x) x = self.conv2(x) @@ -55,6 +66,7 @@ def forward(self, x): def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None): + """Train the model with given data loader and optimizer.""" model.train() ddp_loss = torch.zeros(2).to(rank) if sampler: @@ -70,22 +82,22 @@ def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler ddp_loss[0] += loss.item() ddp_loss[1] += len(data) - #break dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM) if rank == 0: print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1])) def test(model, rank, world_size, test_loader): + """Test the model on test data set.""" + model.eval() - correct = 0 ddp_loss = torch.zeros(3).to(rank) with torch.no_grad(): for data, target in test_loader: data, target = data.to(rank), target.to(rank) output = model(data) - ddp_loss[0] += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss - pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + ddp_loss[0] += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss + pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability ddp_loss[1] += pred.eq(target.view_as(pred)).sum().item() ddp_loss[2] += len(data) @@ -93,46 +105,39 @@ def test(model, rank, world_size, test_loader): if rank == 0: test_loss = ddp_loss[0] / ddp_loss[2] - print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format( - test_loss, int(ddp_loss[1]), int(ddp_loss[2]), - 100. * ddp_loss[1] / ddp_loss[2])) + print( + 'Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format( + test_loss, int(ddp_loss[1]), int(ddp_loss[2]), 100. 
* ddp_loss[1] / ddp_loss[2] + ) + ) def fsdp_main(rank, world_size, args): + """The main function for fsdp mnist example.""" setup(rank, world_size) - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ]) + transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, ))]) if rank == 0: - dataset1 = datasets.MNIST('./data', train=True, download=True, - transform=transform) + dataset1 = datasets.MNIST('./data', train=True, download=True, transform=transform) dist.barrier() if rank != 0: - dataset1 = datasets.MNIST('./data', train=True, - transform=transform) + dataset1 = datasets.MNIST('./data', train=True, transform=transform) - dataset2 = datasets.MNIST('./data', train=False, - transform=transform) + dataset2 = datasets.MNIST('./data', train=False, transform=transform) sampler1 = DistributedSampler(dataset1, rank=rank, num_replicas=world_size, shuffle=True) sampler2 = DistributedSampler(dataset2, rank=rank, num_replicas=world_size) train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1} test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2} - cuda_kwargs = {'num_workers': 2, - 'pin_memory': True, - 'shuffle': False} + cuda_kwargs = {'num_workers': 2, 'pin_memory': True, 'shuffle': False} train_kwargs.update(cuda_kwargs) test_kwargs.update(cuda_kwargs) - train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) + train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs) test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) - my_auto_wrap_policy = functools.partial( - size_based_auto_wrap_policy, min_num_params=20000 - ) + my_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=20000) torch.cuda.set_device(rank) init_start_event = torch.cuda.Event(enable_timing=True) @@ -149,8 +154,7 @@ def fsdp_main(rank, world_size, args): model = FSDP(model, use_orig_params=True, auto_wrap_policy=my_auto_wrap_policy) if rank == 0: - print(f'FSDP model:') - print(f'{model}') + print('FSDP model: {model}') if args.msamp: from msamp.optim import FSDPAdam @@ -172,37 +176,30 @@ def fsdp_main(rank, world_size, args): dist.barrier() states = model.state_dict() if rank == 0: - torch.save(states, "mnist_cnn.pt") + torch.save(states, 'mnist_cnn.pt') cleanup() + if __name__ == '__main__': # Training settings parser = argparse.ArgumentParser(description='PyTorch MNIST Example') - parser.add_argument('--batch-size', type=int, default=64, metavar='N', - help='input batch size for training (default: 64)') - parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', - help='input batch size for testing (default: 1000)') - parser.add_argument('--epochs', type=int, default=14, metavar='N', - help='number of epochs to train (default: 14)') - parser.add_argument('--lr', type=float, default=3e-4, metavar='LR', - help='learning rate (default: 3e-4)') - parser.add_argument('--gamma', type=float, default=0.7, metavar='M', - help='Learning rate step gamma (default: 0.7)') - parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--seed', type=int, default=1, metavar='S', - help='random seed (default: 1)') - parser.add_argument('--save-model', action='store_true', default=False, - help='For Saving the current Model') - parser.add_argument('--msamp', action='store_true', default=False, - help='whether use MS-AMP') + parser.add_argument( + '--batch-size', type=int, 
default=64, metavar='N', help='input batch size for training (default: 64)' + ) + parser.add_argument( + '--test-batch-size', type=int, default=1000, metavar='N', help='input batch size for testing (default: 1000)' + ) + parser.add_argument('--epochs', type=int, default=14, metavar='N', help='number of epochs to train (default: 14)') + parser.add_argument('--lr', type=float, default=3e-4, metavar='LR', help='learning rate (default: 3e-4)') + parser.add_argument('--gamma', type=float, default=0.7, metavar='M', help='Learning rate step gamma (default: 0.7)') + parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training') + parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') + parser.add_argument('--save-model', action='store_true', default=False, help='For Saving the current Model') + parser.add_argument('--msamp', action='store_true', default=False, help='whether use MS-AMP') args = parser.parse_args() torch.manual_seed(args.seed) WORLD_SIZE = torch.cuda.device_count() - mp.spawn(fsdp_main, - args=(WORLD_SIZE, args), - nprocs=WORLD_SIZE, - join=True) \ No newline at end of file + mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True) diff --git a/msamp/fsdp/_runtime_utils.py b/msamp/fsdp/_runtime_utils.py index 25cba74c..9c3092ac 100644 --- a/msamp/fsdp/_runtime_utils.py +++ b/msamp/fsdp/_runtime_utils.py @@ -10,6 +10,7 @@ old_post_backward_hook = torch.distributed.fsdp._runtime_utils._post_backward_hook + @no_type_check @torch.no_grad() def _fp8_post_backward_hook(state, handle, *unused): diff --git a/msamp/fsdp/flat_param.py b/msamp/fsdp/flat_param.py index 219d21a5..92037309 100644 --- a/msamp/fsdp/flat_param.py +++ b/msamp/fsdp/flat_param.py @@ -1,4 +1,3 @@ - # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. @@ -11,6 +10,7 @@ import torch.distributed as dist from torch.distributed.fsdp.flat_param import FlatParamHandle + class FP8FlatParamHandle(FlatParamHandle): """A handle for a flat parameter which may have fp32 and fp8.""" def _init_flat_param( @@ -50,33 +50,35 @@ def _init_shard_metadata( start: int, end: int, ) -> None: - """Initialize the shard metadata for the flat parameter and create a group for each fp8 parameter""" + """Initialize the shard metadata for the flat parameter and create a group for each fp8 parameter.""" super()._init_shard_metadata(numel_padded, start, end) start_offset = 0 end_offset = 0 sharded_flat_param_numel = self.flat_param.numel() for i, meta in enumerate(self.flat_param._metas): - start_offset += self.flat_param._numels[i-1] if i >=1 else 0 + start_offset += self.flat_param._numels[i - 1] if i >= 1 else 0 end_offset += self.flat_param._numels[i] if meta is not None: start_rank = start_offset // sharded_flat_param_numel - end_rank = (end_offset-1) // sharded_flat_param_numel + end_rank = (end_offset - 1) // sharded_flat_param_numel ranks = list(range(start_rank, end_rank + 1)) meta.group = dist.new_group(ranks=ranks) - def _use_unsharded_views(self, as_params: bool) -> None: - """Use unsharded views of the flat parameter and set fp8 related attritutes, which will be use in msamp.nn.functional.""" + """Use unsharded views of the flat parameter. + + It will also set fp8 related attritutes, which will be use in msamp.nn.functional. 
+ """ super()._use_unsharded_views(as_params) for i, param_info in enumerate(self.flat_param._param_infos): if hasattr(param_info.module, param_info.param_name): param = getattr(param_info.module, param_info.param_name) - + param._scaling_metas = self.flat_param._scaling_metas[i] param._meta = self.flat_param._metas[i] param._padded = self.flat_param._paddeds[i] param._original_shape = self.flat_param._original_shapes[i] - + @torch.no_grad() def _use_sharded_views(self) -> None: """Use sharded views of the flat parameter and set meta of scaling tensor, which will be used in optimizer.""" @@ -87,5 +89,3 @@ def _use_sharded_views(self) -> None: if self.flat_param._metas[i] is not None: param._meta = self.flat_param._metas[i] param._grad_meta = self.flat_param._scaling_metas[i]['wgrad'] - - diff --git a/msamp/fsdp/fully_sharded_data_parallel.py b/msamp/fsdp/fully_sharded_data_parallel.py index 0b31e7e0..2fc9438a 100644 --- a/msamp/fsdp/fully_sharded_data_parallel.py +++ b/msamp/fsdp/fully_sharded_data_parallel.py @@ -38,18 +38,19 @@ def _fp8_allreduce_hook(state, grad, output): state=state, grad=grad[start:end], ) - start = self.rank*output.numel() - end = (self.rank+1)*output.numel() + start = self.rank * output.numel() + end = (self.rank + 1) * output.numel() output.copy_(grad[start:end]) else: _get_default_comm_hook()(state, grad, output) - + return _fp8_allreduce_hook class FP8FullyShardedDataParallel(FullyShardedDataParallel): """A FullyShardedDataParallel with supports fp8.""" def __init__(self, module, *args, **kwargs): + """Constructor.""" super().__init__(module, *args, **kwargs) @classmethod diff --git a/msamp/fsdp/replacer.py b/msamp/fsdp/replacer.py index 17e26e71..6d9c6400 100644 --- a/msamp/fsdp/replacer.py +++ b/msamp/fsdp/replacer.py @@ -11,7 +11,6 @@ class FsdpReplacer: """A replacer to replace the FP8 weights with FP32 nn.Parameter and attributes.""" - @classmethod def replace(cls, model): """Replace the weights with ScalingParameter in modules.""" @@ -24,7 +23,7 @@ def replace(cls, model): data = param.value.view(-1) padded = 0 if data.numel() % 4 != 0: - padded = 4 - data.numel() % 4 + padded = 4 - data.numel() % 4 data = torch.nn.functional.pad(data, (0, padded)) data = data.view(dtype=torch.float32) diff --git a/msamp/nn/functional.py b/msamp/nn/functional.py index 24694dea..baa02483 100644 --- a/msamp/nn/functional.py +++ b/msamp/nn/functional.py @@ -34,7 +34,7 @@ def forward(ctx, input, weight, metas, dtype_holder): weight = weight.view(dtype=torch.uint8) if padded != 0: - weight = weight[0: weight.numel() - padded] + weight = weight[0:weight.numel() - padded] weight = weight.view(original_shape) weight = ScalingParameter(ScalingTensor(weight, meta)) ctx.return_wgrad = True @@ -113,7 +113,7 @@ def backward(ctx, output_grad): wgrad = wgrad.cast(Dtypes.kfloat8_e4m3, meta=wgrad_meta, sync=True) wgrad = wgrad.value.view(-1).view(dtype=torch.float32) wgrad.meta = wgrad_meta - return input_grad, wgrad, None, None + return input_grad, wgrad, None, None elif model_state.use_fp8_ddp: wgrad.meta = wgrad_meta else: diff --git a/msamp/optim/adam.py b/msamp/optim/adam.py index 14082755..ea92daa6 100644 --- a/msamp/optim/adam.py +++ b/msamp/optim/adam.py @@ -5,6 +5,7 @@ from msamp.optim import LBAdamW, FSDPAdamW + class LBAdam(LBAdamW): """Implements Adam algorithm with weight decay fix.""" def __init__( @@ -64,8 +65,8 @@ def __init__( self.set_grad_none = set_grad_none - class FSDPAdam(FSDPAdamW): + """Implements Adam algorithm for FSDP.""" def __init__( self, params, @@ -78,6 
+79,7 @@ def __init__( *args, **kwargs ): + """Constructor. See LBAdamW class docstring for details.""" super().__init__( params=params, lr=lr, @@ -90,4 +92,3 @@ def __init__( **kwargs ) self.use_adam = True - \ No newline at end of file diff --git a/msamp/optim/adamw.py b/msamp/optim/adamw.py index d3447c00..aef5443f 100644 --- a/msamp/optim/adamw.py +++ b/msamp/optim/adamw.py @@ -231,6 +231,7 @@ def adamw_fn( # noqa: C901 class FSDPAdamW(LBAdamWBase): + """Implements AdamW algorithm for FSDP.""" def __init__( self, params, @@ -246,6 +247,7 @@ def __init__( exp_avg_sq_dtype=torch.float16, tensor_scale=True, ): + """Constructor. See LBAdamW class docstring for details.""" self.tensor_scale = tensor_scale super().__init__( params, @@ -275,7 +277,7 @@ def __init__( master_weight = param.cast(Dtypes.kfloat16) master_weight.requires_grad = True self.master_weights.append(master_weight) - params.append(master_weight) + params.append(master_weight) else: self.original_params.append(param) self.master_weights.append(None) @@ -283,8 +285,8 @@ def __init__( group['params'] = params - def zero_grad(self, set_to_none=False): + """Zero gradients.""" for param in self.original_params: if set_to_none: param.grad = None @@ -297,6 +299,7 @@ def zero_grad(self, set_to_none=False): param.grad.zero_() def step(self): + """Performs a single optimization step.""" # cast gradient to ScalingTensor for i, param in enumerate(self.original_params): if param.grad is None: @@ -314,5 +317,6 @@ def step(self): # sync params and copy master weight to weight for i, param in enumerate(self.original_params): if hasattr(param, '_meta') and param._meta is not None and param.numel() > 0: - data = self.master_weights[i].float().cast(param._meta.qtype, param._meta, True).value.view(torch.float32) + data = self.master_weights[i].float().cast(param._meta.qtype, param._meta, True) \ + .value.view(torch.float32) param.data.copy_(data) From f00520201cfae49c82e0aae81420263f66f6480f Mon Sep 17 00:00:00 2001 From: tocean Date: Mon, 8 Jan 2024 11:49:53 +0000 Subject: [PATCH 07/18] fix typo --- docs/user-tutorial/usage.md | 2 +- examples/mnist_fsdp.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/user-tutorial/usage.md b/docs/user-tutorial/usage.md index 9e247b11..4824fd25 100644 --- a/docs/user-tutorial/usage.md +++ b/docs/user-tutorial/usage.md @@ -61,7 +61,7 @@ from msamp.optim import FSDPAdam optimizer = FSDPAdam(model.parameters(), lr=3e-04) ``` -Please note that currenlty we only support `use_orig_params=True`. +Please note that currently we only support `use_orig_params=True`. 
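
For a quick end-to-end reference, the sketch below mirrors the unit test added later in this series (`tests/fsdp/test_fsdp_distributed.py`). The process-group bootstrap (environment variables, device selection) is an assumption for illustration and not part of the MS-AMP API; only `FsdpReplacer.replace`, `FP8FullyShardedDataParallel` and `FSDPAdam` come from the patches themselves.

```python
import torch
import torch.distributed as dist

from msamp.fsdp import FsdpReplacer, FP8FullyShardedDataParallel
from msamp.optim import FSDPAdam

# Assumes a torchrun-style launch: RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT are set.
dist.init_process_group('nccl')
torch.cuda.set_device(dist.get_rank())

model = torch.nn.Sequential(torch.nn.Linear(1024, 2048), torch.nn.Linear(2048, 1024))
model = FsdpReplacer.replace(model)   # FP8 weights become flat FP32 parameters plus scaling metadata
model = FP8FullyShardedDataParallel(model, use_orig_params=True)
optimizer = FSDPAdam(model.parameters(), lr=3e-4)

for _ in range(10):
    optimizer.zero_grad()
    output = model(torch.randn(16, 1024).cuda())
    output.sum().backward()           # gradients are reduced through the FP8 communication hook
    optimizer.step()                  # FP16 master weights updated, FP8 weights re-cast with synchronized scales
```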
## Usage in Megatron-DeepSpeed and Megatron-LM diff --git a/examples/mnist_fsdp.py b/examples/mnist_fsdp.py index 3495494d..3e9fd1f9 100644 --- a/examples/mnist_fsdp.py +++ b/examples/mnist_fsdp.py @@ -154,7 +154,7 @@ def fsdp_main(rank, world_size, args): model = FSDP(model, use_orig_params=True, auto_wrap_policy=my_auto_wrap_policy) if rank == 0: - print('FSDP model: {model}') + print(f'FSDP model: {model}') if args.msamp: from msamp.optim import FSDPAdam From bcfd7e54068fde7063b2e97abbcd3083f9a66193 Mon Sep 17 00:00:00 2001 From: tocean Date: Tue, 9 Jan 2024 05:21:57 +0000 Subject: [PATCH 08/18] broadcast scaling_inv in optimizer.step --- msamp/fsdp/flat_param.py | 3 ++- msamp/optim/adamw.py | 34 ++++++++++++++++++---------------- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/msamp/fsdp/flat_param.py b/msamp/fsdp/flat_param.py index 92037309..0191eef1 100644 --- a/msamp/fsdp/flat_param.py +++ b/msamp/fsdp/flat_param.py @@ -28,7 +28,7 @@ def _init_flat_param( scaling_metas = [] for param in self.flat_param._params: - if hasattr(param, '_meta') and param._meta: + if hasattr(param, '_meta') and param._meta is not None: metas.append(param._meta) paddeds.append(param._padded) original_shapes.append(param._original_shape) @@ -63,6 +63,7 @@ def _init_shard_metadata( end_rank = (end_offset - 1) // sharded_flat_param_numel ranks = list(range(start_rank, end_rank + 1)) meta.group = dist.new_group(ranks=ranks) + meta.rank = ranks[0] def _use_unsharded_views(self, as_params: bool) -> None: """Use unsharded views of the flat parameter. diff --git a/msamp/optim/adamw.py b/msamp/optim/adamw.py index aef5443f..cc3c8f2b 100644 --- a/msamp/optim/adamw.py +++ b/msamp/optim/adamw.py @@ -8,6 +8,7 @@ import torch from torch import Tensor +import torch.distributed as dist from msamp.optim import LBAdamWBase from msamp.common.tensor import ScalingMeta, ScalingTensor @@ -268,10 +269,11 @@ def __init__( for group in self.param_groups: params = [] for param in group['params']: - if param is None or param.numel() == 0: + if param is None: continue - if hasattr(param, '_meta') and param._meta is not None: - self.original_params.append(param) + + self.original_params.append(param) + if hasattr(param, '_meta') and param._meta is not None and param.numel() > 0: dtype = Dtypes.qtype_to_dtype[param._meta.qtype] param = ScalingTensor(param.view(dtype), param._meta) master_weight = param.cast(Dtypes.kfloat16) @@ -279,7 +281,6 @@ def __init__( self.master_weights.append(master_weight) params.append(master_weight) else: - self.original_params.append(param) self.master_weights.append(None) params.append(param) @@ -300,23 +301,24 @@ def zero_grad(self, set_to_none=False): def step(self): """Performs a single optimization step.""" - # cast gradient to ScalingTensor - for i, param in enumerate(self.original_params): - if param.grad is None: - continue - - if hasattr(param, '_meta') and param._meta is not None: + # Set gradient of master weight. 
+ for i, master_param in enumerate(self.master_weights): + if master_param is not None: + param = self.original_params[i] grad_meta = param._grad_meta dtype = Dtypes.qtype_to_dtype[grad_meta.qtype] - self.master_weights[i].grad = ScalingTensor(param.grad.view(dtype), grad_meta) + master_param[i].grad = ScalingTensor(param.grad.view(dtype), grad_meta) param.grad = None # call step() to update master weight super().step() - # sync params and copy master weight to weight + # Copy master weight to weight for i, param in enumerate(self.original_params): - if hasattr(param, '_meta') and param._meta is not None and param.numel() > 0: - data = self.master_weights[i].float().cast(param._meta.qtype, param._meta, True) \ - .value.view(torch.float32) - param.data.copy_(data) + if hasattr(param, '_meta') and param._meta is not None: + if param.numel() > 0: + data = self.master_weights[i].float().cast(param._meta.qtype, param._meta, True) \ + .value.view(torch.float32) + param.data.copy_(data) + # broadcast scale_inv + dist.broadcast(param._meta.scale_inv, src=param._meta.rank) From 103b7c72a6d132717f3516f2060e0b09d3a19fc6 Mon Sep 17 00:00:00 2001 From: tocean Date: Tue, 9 Jan 2024 06:00:23 +0000 Subject: [PATCH 09/18] fix bug --- msamp/optim/adamw.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/msamp/optim/adamw.py b/msamp/optim/adamw.py index cc3c8f2b..f65c41ba 100644 --- a/msamp/optim/adamw.py +++ b/msamp/optim/adamw.py @@ -302,12 +302,11 @@ def zero_grad(self, set_to_none=False): def step(self): """Performs a single optimization step.""" # Set gradient of master weight. - for i, master_param in enumerate(self.master_weights): - if master_param is not None: - param = self.original_params[i] + for i, param in enumerate(self.original_params): + if self.master_weights[i] is not None: grad_meta = param._grad_meta dtype = Dtypes.qtype_to_dtype[grad_meta.qtype] - master_param[i].grad = ScalingTensor(param.grad.view(dtype), grad_meta) + self.master_weights[i].grad = ScalingTensor(param.grad.view(dtype), grad_meta) param.grad = None # call step() to update master weight From 19a543408fb9c60e28a0ba9b357abaaa0447d07b Mon Sep 17 00:00:00 2001 From: tocean Date: Tue, 9 Jan 2024 06:56:29 +0000 Subject: [PATCH 10/18] add ut for fsdp --- msamp/fsdp/replacer.py | 2 +- msamp/nn/functional.py | 2 +- tests/fsdp/test_fsdp_distributed.py | 57 +++++++++++++++++++ tests/fsdp/test_fsdp_replacer.py | 48 ++++++++++++++++ tests/nn/test_distributed.py | 2 +- .../{test_replacer.py => test_te_replacer.py} | 0 6 files changed, 108 insertions(+), 3 deletions(-) create mode 100644 tests/fsdp/test_fsdp_distributed.py create mode 100644 tests/fsdp/test_fsdp_replacer.py rename tests/te/{test_replacer.py => test_te_replacer.py} (100%) diff --git a/msamp/fsdp/replacer.py b/msamp/fsdp/replacer.py index 6d9c6400..3bc856c8 100644 --- a/msamp/fsdp/replacer.py +++ b/msamp/fsdp/replacer.py @@ -29,7 +29,7 @@ def replace(cls, model): data = data.view(dtype=torch.float32) new_param = torch.nn.Parameter(data) new_param._original_shape = param.shape - new_param._padded = 0 + new_param._padded = padded new_param._meta = param.meta new_param._scaling_metas = param._scaling_metas diff --git a/msamp/nn/functional.py b/msamp/nn/functional.py index baa02483..23d023cb 100644 --- a/msamp/nn/functional.py +++ b/msamp/nn/functional.py @@ -109,7 +109,7 @@ def backward(ctx, output_grad): use_split_accumulator=True, ) del old_wgrad - if ctx.return_wgrad: + if hasattr(ctx, 'return_wgrad') and ctx.return_wgrad: wgrad = 
wgrad.cast(Dtypes.kfloat8_e4m3, meta=wgrad_meta, sync=True) wgrad = wgrad.value.view(-1).view(dtype=torch.float32) wgrad.meta = wgrad_meta diff --git a/tests/fsdp/test_fsdp_distributed.py b/tests/fsdp/test_fsdp_distributed.py new file mode 100644 index 00000000..cda048cc --- /dev/null +++ b/tests/fsdp/test_fsdp_distributed.py @@ -0,0 +1,57 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Tests for msamp.fsdp package.""" + +import os + +import torch +import torch.distributed as dist +from torch.testing._internal.common_distributed import MultiProcessTestCase, skip_if_lt_x_gpu, requires_nccl + +from tests.helper import decorator +from msamp.fsdp import FsdpReplacer, FP8FullyShardedDataParallel + + +class FsdpDistributedTestCast(MultiProcessTestCase): + """Test functions in distributed module with FSDP.""" + def setUp(self): + """Hook method for setting up the test fixture before exercising it.""" + super().setUp() + torch.manual_seed(1000) + + self._spawn_processes() + + def tearDown(self): + """Hook method for deconstructing the test fixture after testing it.""" + super().tearDown() + try: + os.remove(self.file_name) + except OSError: + pass + + @property + def world_size(self): + """Return the number of processes.""" + return 2 + + @requires_nccl() + @skip_if_lt_x_gpu(2) + @decorator.cuda_test + def test_fp8_fsdp(self): + """Test forward and backward functionality in FP8 FSDP.""" + rank = self.rank + store = dist.FileStore(self.file_name, self.world_size) + torch.cuda.set_device(rank) + dist.init_process_group(backend='nccl', store=store, rank=self.rank, world_size=self.world_size) + model = torch.nn.Sequential(torch.nn.Linear(10000, 20000), torch.nn.Dropout(), torch.nn.Linear(20000, 10000)) + model = FsdpReplacer.replace(model) + model = FP8FullyShardedDataParallel(model, use_orig_params=True) + for _ in range(10): + input = torch.randn(128, 10000).cuda() + output = model(input) + loss = output.sum() + loss.backward() + for param in model.parameters(): + if param.numel() > 0: + assert param.grad is not None diff --git a/tests/fsdp/test_fsdp_replacer.py b/tests/fsdp/test_fsdp_replacer.py new file mode 100644 index 00000000..0e4405fb --- /dev/null +++ b/tests/fsdp/test_fsdp_replacer.py @@ -0,0 +1,48 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Tests for msamp.fsdp.replacer module.""" + +import unittest +import copy +import math + +import torch +import torch.nn as nn +from tests.helper import decorator + +from msamp.fsdp import FsdpReplacer + + +class FsdpReplacerTestCase(unittest.TestCase): + """Test TeExtention overrider.""" + def setUp(self): + """Hook method for setting up the test fixture before exercising it.""" + torch.manual_seed(1000) + + def tearDown(self): + """Hook method for deconstructing the test fixture after testing it.""" + pass + + @decorator.cuda_test + def test_replace(self): + """Test replace function in FsdpReplacer.""" + model = nn.Linear(5, 10) + model1 = copy.deepcopy(model) + + model2 = FsdpReplacer.replace(model) + + params1 = list(model1.parameters()) + params2 = list(model2.parameters()) + + assert len(params1) == len(params2) + + param1 = params1[0] + param2 = params2[0] + assert isinstance(param1, torch.nn.Parameter) + assert isinstance(param2, torch.nn.Parameter) + assert param2.numel() == int(math.ceil(param1.numel() / 4)) + assert param2._original_shape == torch.Size([10, 5]) + assert param2._padded == 4 - (5 * 10) % 4 + assert param2._meta is not None + assert param2._scaling_metas is not None diff --git a/tests/nn/test_distributed.py b/tests/nn/test_distributed.py index 10fee252..9634476f 100644 --- a/tests/nn/test_distributed.py +++ b/tests/nn/test_distributed.py @@ -27,7 +27,7 @@ def forward(self, x): return self.fc1(x) -class DistributedTestCast(MultiProcessTestCase): +class DistributedTestCase(MultiProcessTestCase): """Test functions in distributed module.""" def setUp(self): """Hook method for setting up the test fixture before exercising it.""" diff --git a/tests/te/test_replacer.py b/tests/te/test_te_replacer.py similarity index 100% rename from tests/te/test_replacer.py rename to tests/te/test_te_replacer.py From 4415d825089bd33e0531c69a08882747958b95bc Mon Sep 17 00:00:00 2001 From: tocean Date: Tue, 9 Jan 2024 07:19:52 +0000 Subject: [PATCH 11/18] fix comments --- examples/cifar10_deepspeed.py | 2 +- examples/cifar10_deepspeed_te.py | 2 +- examples/mnist.py | 2 +- examples/mnist_ddp.py | 2 +- examples/mnist_fsdp.py | 4 +++- msamp/fsdp/_runtime_utils.py | 4 ++++ 6 files changed, 11 insertions(+), 5 deletions(-) diff --git a/examples/cifar10_deepspeed.py b/examples/cifar10_deepspeed.py index db247153..5322e3e7 100644 --- a/examples/cifar10_deepspeed.py +++ b/examples/cifar10_deepspeed.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -"""The deepspeed cifar10 exampe using MS-AMP. It is adapted from official deepspeed example. +"""The deepspeed cifar10 example using MS-AMP. It is adapted from official deepspeed example. The only change is add "from msamp import deepspeed" and remove moe related code. """ diff --git a/examples/cifar10_deepspeed_te.py b/examples/cifar10_deepspeed_te.py index c51ff846..c29f0969 100644 --- a/examples/cifar10_deepspeed_te.py +++ b/examples/cifar10_deepspeed_te.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -"""The deepspeed cifar10 exampe using MS-AMP and TransformerEngine. It is adapted from official deepspeed example. +"""The deepspeed cifar10 example using MS-AMP and TransformerEngine. It is adapted from official deepspeed example. The model is adapted from VisionTransfomrer in timm and it uses te.TransformerLayer as encoder block. 
""" diff --git a/examples/mnist.py b/examples/mnist.py index b0bad1fa..5a3e5f2a 100644 --- a/examples/mnist.py +++ b/examples/mnist.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -"""The mnist exampe using MS-AMP. It is adapted from https://github.com/pytorch/examples/blob/main/mnist/main.py.""" +"""The mnist example using MS-AMP. It is adapted from https://github.com/pytorch/examples/blob/main/mnist/main.py.""" from __future__ import print_function import argparse diff --git a/examples/mnist_ddp.py b/examples/mnist_ddp.py index 239c4943..0864b6bc 100644 --- a/examples/mnist_ddp.py +++ b/examples/mnist_ddp.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -"""The ddp mnist exampe using MS-AMP. It is adapted from https://github.com/pytorch/examples/blob/main/mnist/main.py.""" +"""The ddp mnist example using MS-AMP. It is adapted from https://github.com/pytorch/examples/blob/main/mnist/main.py.""" from __future__ import print_function import os diff --git a/examples/mnist_fsdp.py b/examples/mnist_fsdp.py index 3e9fd1f9..1e894e4d 100644 --- a/examples/mnist_fsdp.py +++ b/examples/mnist_fsdp.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -"""The fsdp mnist exampe using MS-AMP. +"""The fsdp mnist example using MS-AMP. It is adapted from https://github.com/pytorch/examples/blob/main/mnist/main.py. """ @@ -73,7 +73,9 @@ def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler sampler.set_epoch(epoch) for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(rank), target.to(rank) + optimizer.zero_grad() + output = model(data) loss = F.nll_loss(output, target, reduction='sum') loss.backward() diff --git a/msamp/fsdp/_runtime_utils.py b/msamp/fsdp/_runtime_utils.py index 9c3092ac..f6bb674f 100644 --- a/msamp/fsdp/_runtime_utils.py +++ b/msamp/fsdp/_runtime_utils.py @@ -18,6 +18,10 @@ def _fp8_post_backward_hook(state, handle, *unused): if not isinstance(state, FullyShardedDataParallel): return old_post_backward_hook(state, handle, *unused) + accumulate_grad = hasattr(state._flat_param, "_saved_grad_shard") + if accumulate_grad and not torch.all(state._flat_param._saved_grad_shard == 0): + raise NotImplementedError("accumulate_grad is not supported for fp8") + old_communication_hook = state._communication_hook state._communication_hook = state._get_fp8_comm_hook() old_post_backward_hook(state, handle, *unused) From e2d2f03ffbd3403dd955371b9d0f01d790e64bb7 Mon Sep 17 00:00:00 2001 From: tocean Date: Tue, 9 Jan 2024 08:24:03 +0000 Subject: [PATCH 12/18] fix lint --- examples/mnist_ddp.py | 5 ++++- examples/mnist_fsdp.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/mnist_ddp.py b/examples/mnist_ddp.py index 0864b6bc..85a56fac 100644 --- a/examples/mnist_ddp.py +++ b/examples/mnist_ddp.py @@ -1,7 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -"""The ddp mnist example using MS-AMP. It is adapted from https://github.com/pytorch/examples/blob/main/mnist/main.py.""" +"""The ddp mnist example using MS-AMP. + +It is adapted from https://github.com/pytorch/examples/blob/main/mnist/main.py. 
+""" from __future__ import print_function import os diff --git a/examples/mnist_fsdp.py b/examples/mnist_fsdp.py index 1e894e4d..5faed7b5 100644 --- a/examples/mnist_fsdp.py +++ b/examples/mnist_fsdp.py @@ -73,9 +73,9 @@ def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler sampler.set_epoch(epoch) for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(rank), target.to(rank) - + optimizer.zero_grad() - + output = model(data) loss = F.nll_loss(output, target, reduction='sum') loss.backward() From 2379e32b1cf948197e98d7a31f7e3e17008b8442 Mon Sep 17 00:00:00 2001 From: tocean Date: Tue, 9 Jan 2024 08:27:02 +0000 Subject: [PATCH 13/18] fix lint --- msamp/fsdp/_runtime_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/msamp/fsdp/_runtime_utils.py b/msamp/fsdp/_runtime_utils.py index f6bb674f..de799c86 100644 --- a/msamp/fsdp/_runtime_utils.py +++ b/msamp/fsdp/_runtime_utils.py @@ -18,9 +18,9 @@ def _fp8_post_backward_hook(state, handle, *unused): if not isinstance(state, FullyShardedDataParallel): return old_post_backward_hook(state, handle, *unused) - accumulate_grad = hasattr(state._flat_param, "_saved_grad_shard") + accumulate_grad = hasattr(state._flat_param, '_saved_grad_shard') if accumulate_grad and not torch.all(state._flat_param._saved_grad_shard == 0): - raise NotImplementedError("accumulate_grad is not supported for fp8") + raise NotImplementedError('accumulate_grad is not supported for fp8') old_communication_hook = state._communication_hook state._communication_hook = state._get_fp8_comm_hook() From 6ffc5c17f08402cbd47a0e460e4dd7f8e89cd22a Mon Sep 17 00:00:00 2001 From: tocean Date: Tue, 9 Jan 2024 12:05:40 +0000 Subject: [PATCH 14/18] fix bug in optimizer --- msamp/fsdp/flat_param.py | 21 --------------------- msamp/optim/adamw.py | 18 ++++++++++++++---- 2 files changed, 14 insertions(+), 25 deletions(-) diff --git a/msamp/fsdp/flat_param.py b/msamp/fsdp/flat_param.py index 0191eef1..56de6a3f 100644 --- a/msamp/fsdp/flat_param.py +++ b/msamp/fsdp/flat_param.py @@ -44,27 +44,6 @@ def _init_flat_param( self.flat_param._original_shapes = original_shapes self.flat_param._scaling_metas = scaling_metas - def _init_shard_metadata( - self, - numel_padded: int, - start: int, - end: int, - ) -> None: - """Initialize the shard metadata for the flat parameter and create a group for each fp8 parameter.""" - super()._init_shard_metadata(numel_padded, start, end) - start_offset = 0 - end_offset = 0 - sharded_flat_param_numel = self.flat_param.numel() - for i, meta in enumerate(self.flat_param._metas): - start_offset += self.flat_param._numels[i - 1] if i >= 1 else 0 - end_offset += self.flat_param._numels[i] - if meta is not None: - start_rank = start_offset // sharded_flat_param_numel - end_rank = (end_offset - 1) // sharded_flat_param_numel - ranks = list(range(start_rank, end_rank + 1)) - meta.group = dist.new_group(ranks=ranks) - meta.rank = ranks[0] - def _use_unsharded_views(self, as_params: bool) -> None: """Use unsharded views of the flat parameter. 
diff --git a/msamp/optim/adamw.py b/msamp/optim/adamw.py index f65c41ba..d71b41c7 100644 --- a/msamp/optim/adamw.py +++ b/msamp/optim/adamw.py @@ -315,9 +315,19 @@ def step(self): # Copy master weight to weight for i, param in enumerate(self.original_params): if hasattr(param, '_meta') and param._meta is not None: + hp_data = None + if param.numel() == 0: + param._meta.amax[0].zero_() + else: + hp_data = self.master_weights[i].float() + param._meta.amax[0] = hp_data.abs().max() + + dist.all_reduce(param._meta.amax[0], op=dist.ReduceOp.MAX) + param._meta.reset_scaling_factor() if param.numel() > 0: - data = self.master_weights[i].float().cast(param._meta.qtype, param._meta, True) \ - .value.view(torch.float32) + with ScalingMeta.in_time_scaling_context(False): + data = hp_data.cast(param._meta.qtype, param._meta, False) \ + .value.view(torch.float32) param.data.copy_(data) - # broadcast scale_inv - dist.broadcast(param._meta.scale_inv, src=param._meta.rank) + else: + param._meta.scale_inv.data.copy_(torch.reciprocal(param._meta.scale)) From 5ea7679a59fed8d03e38b70b19490c02fe8ff2e9 Mon Sep 17 00:00:00 2001 From: tocean Date: Wed, 10 Jan 2024 07:02:24 +0000 Subject: [PATCH 15/18] fix comments --- msamp/fsdp/flat_param.py | 1 - tests/fsdp/test_fsdp_replacer.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/msamp/fsdp/flat_param.py b/msamp/fsdp/flat_param.py index 56de6a3f..35807f7f 100644 --- a/msamp/fsdp/flat_param.py +++ b/msamp/fsdp/flat_param.py @@ -7,7 +7,6 @@ import torch import torch.nn as nn -import torch.distributed as dist from torch.distributed.fsdp.flat_param import FlatParamHandle diff --git a/tests/fsdp/test_fsdp_replacer.py b/tests/fsdp/test_fsdp_replacer.py index 0e4405fb..d53b0514 100644 --- a/tests/fsdp/test_fsdp_replacer.py +++ b/tests/fsdp/test_fsdp_replacer.py @@ -9,8 +9,8 @@ import torch import torch.nn as nn -from tests.helper import decorator +from tests.helper import decorator from msamp.fsdp import FsdpReplacer From 8388adc070a4501de02a229a4ddebc91f1768bbb Mon Sep 17 00:00:00 2001 From: tocean Date: Wed, 10 Jan 2024 07:10:50 +0000 Subject: [PATCH 16/18] remove state type check in _fp8_post_backward_hook --- msamp/fsdp/_runtime_utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/msamp/fsdp/_runtime_utils.py b/msamp/fsdp/_runtime_utils.py index de799c86..9eb25c73 100644 --- a/msamp/fsdp/_runtime_utils.py +++ b/msamp/fsdp/_runtime_utils.py @@ -15,9 +15,6 @@ @torch.no_grad() def _fp8_post_backward_hook(state, handle, *unused): """A post-backward communication hook which supports fp8.""" - if not isinstance(state, FullyShardedDataParallel): - return old_post_backward_hook(state, handle, *unused) - accumulate_grad = hasattr(state._flat_param, '_saved_grad_shard') if accumulate_grad and not torch.all(state._flat_param._saved_grad_shard == 0): raise NotImplementedError('accumulate_grad is not supported for fp8') From a8ff215851cd22c9cef2f303f965553bf9745949 Mon Sep 17 00:00:00 2001 From: tocean Date: Wed, 10 Jan 2024 08:22:01 +0000 Subject: [PATCH 17/18] fix ut error --- msamp/fsdp/_runtime_utils.py | 1 - tests/fsdp/test_fsdp_distributed.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/msamp/fsdp/_runtime_utils.py b/msamp/fsdp/_runtime_utils.py index 9eb25c73..566590fd 100644 --- a/msamp/fsdp/_runtime_utils.py +++ b/msamp/fsdp/_runtime_utils.py @@ -6,7 +6,6 @@ from typing import no_type_check import torch -from torch.distributed.fsdp import FullyShardedDataParallel old_post_backward_hook = 
torch.distributed.fsdp._runtime_utils._post_backward_hook diff --git a/tests/fsdp/test_fsdp_distributed.py b/tests/fsdp/test_fsdp_distributed.py index cda048cc..9379cb75 100644 --- a/tests/fsdp/test_fsdp_distributed.py +++ b/tests/fsdp/test_fsdp_distributed.py @@ -55,3 +55,4 @@ def test_fp8_fsdp(self): for param in model.parameters(): if param.numel() > 0: assert param.grad is not None + param.grad.zero_() From d3acfa24b10d0ac8835a228dd66e972a6f2d920c Mon Sep 17 00:00:00 2001 From: tocean Date: Thu, 11 Jan 2024 02:02:41 +0000 Subject: [PATCH 18/18] fix comments --- msamp/fsdp/_runtime_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/msamp/fsdp/_runtime_utils.py b/msamp/fsdp/_runtime_utils.py index 566590fd..0e8de20d 100644 --- a/msamp/fsdp/_runtime_utils.py +++ b/msamp/fsdp/_runtime_utils.py @@ -15,8 +15,8 @@ def _fp8_post_backward_hook(state, handle, *unused): """A post-backward communication hook which supports fp8.""" accumulate_grad = hasattr(state._flat_param, '_saved_grad_shard') - if accumulate_grad and not torch.all(state._flat_param._saved_grad_shard == 0): - raise NotImplementedError('accumulate_grad is not supported for fp8') + if accumulate_grad and torch.count_nonzero(state._flat_param._saved_grad_shard).item() > 0: + raise NotImplementedError('accumulate_grad is not supported yet for fp8') old_communication_hook = state._communication_hook state._communication_hook = state._get_fp8_comm_hook()
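
A practical consequence of the guard above: gradient accumulation across micro-batches is not supported on the FP8 path yet. Once a backward pass has filled the sharded gradient, the next backward must be preceded by `optimizer.zero_grad()` (or a manual `grad.zero_()`, as the updated unit test does), otherwise the FP8 post-backward hook raises `NotImplementedError`. A minimal sketch of the supported training pattern, assuming a model and optimizer set up as in the usage example (`loader` and `loss_fn` are placeholders):

```python
for data, target in loader:
    optimizer.zero_grad()                     # keep the saved grad shard zeroed between iterations
    loss = loss_fn(model(data.cuda()), target.cuda())
    loss.backward()                           # FP8 post-backward hook runs here
    optimizer.step()

# Not supported yet: calling backward() several times before step()/zero_grad() would
# accumulate into a non-zero _saved_grad_shard and trigger
# NotImplementedError('accumulate_grad is not supported yet for fp8').
```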