Deepspeed integration #4693
base: main
Changes from 12 commits
@@ -0,0 +1,11 @@
from allennlp.training.deepspeed.trainer import DeepspeedTrainer

from allennlp.training.deepspeed.optimizers import (
    FusedAdamOptimizer,
    DeepspeedCPUAdamOptimizer,
    FusedLambOptimizer,
)

try:
    from allennlp.training.deepspeed.sparse_transformer_embedder import SparseTransformerEmbedder
except ImportError:
    pass
@@ -0,0 +1,131 @@
from typing import Union, Dict, Any, List, Tuple, Optional

import logging
import os
import re
import shutil
import time

from pathlib import Path

import torch

import allennlp
from allennlp.nn import util as nn_util
from allennlp.training import util as training_util, Checkpointer

logger = logging.getLogger(__name__)
_DeepspeedTrainer = "allennlp.training.deepspeed.trainer.DeepspeedTrainer"


class DeepspeedCheckpointer(Checkpointer):
    # def maybe_save_checkpoint(
    #     self,
    #     trainer: _DeepspeedTrainer,
    #     epoch: int,
    #     batches_this_epoch: int
    # ) -> None:
    #     0/0
    def save_checkpoint(
        self,
        epoch: Union[int, str],
        trainer: _DeepspeedTrainer,
        is_best_so_far: bool = False,
        save_model_only=False,
    ) -> None:
        if self._serialization_dir is None:
            return

        with trainer.get_checkpoint_state() as state:
            model_engine, model_state, training_states = state

            checkpoint_id = "deepspeed_epoch_{}".format(epoch)
            model_path = os.path.join(
                self._serialization_dir, "model_state_epoch_{}".format(epoch)
            )
            model_engine.save_checkpoint(self._serialization_dir, checkpoint_id)

            # TODO
            # Model will need a weight file to load;
            # not sure if ZeRO stage 2 will mess this up
            if not os.path.isfile(model_path):
                torch.save(model_state, model_path)
Review discussion:
- This would be good to know. Have you tried the checkpointing logic? Does it work?
- The checkpointing works for saving; it's able to go through the training process end to end, doing the checkpointing and so on. I'm just not sure how model parallelism affects this part, i.e. whether it saves the entire model state or just the state local to that device. I imagine this could be validated in a test case.
- That seems important. I don't want to make a release claiming that this works and then have it not work in a fairly common use case.
            if save_model_only:
                return

            training_path = os.path.join(
                self._serialization_dir, "training_state_epoch_{}.th".format(epoch)
            )
            if not os.path.isfile(training_path):
                torch.save({**training_states, "epoch": epoch}, training_path)

        # The main checkpointing logic is now done; this is just shuffling files around to keep
        # track of the best weights, and to remove old checkpoints, if desired.
        if is_best_so_far:
            logger.info(
                "Best validation performance so far. Copying weights to '%s/best.th'.",
                self._serialization_dir,
            )
            shutil.copyfile(model_path, os.path.join(self._serialization_dir, "best.th"))

            engine_dir = os.path.join(self._serialization_dir, "best_deepspeed")
            shutil.rmtree(engine_dir, ignore_errors=True)  # in case there are no previous checkpoints
            shutil.copytree(os.path.join(self._serialization_dir, checkpoint_id), engine_dir)

        if (
            self._num_serialized_models_to_keep is not None
            and self._num_serialized_models_to_keep >= 0
        ):
            self._serialized_paths.append((time.time(), model_path, training_path))
            if len(self._serialized_paths) > self._num_serialized_models_to_keep:
                paths_to_remove = self._serialized_paths.pop(0)
                # Check to see if we should keep this checkpoint, if it has been longer
                # than self._keep_serialized_model_every_num_seconds since the last
                # kept checkpoint.
                remove_path = True
                if self._keep_serialized_model_every_num_seconds is not None:
                    save_time = paths_to_remove[0]
                    time_since_checkpoint_kept = (
                        save_time - self._last_permanent_saved_checkpoint_time
                    )
                    if (
                        time_since_checkpoint_kept
                        > self._keep_serialized_model_every_num_seconds
                    ):
                        # We want to keep this checkpoint.
                        remove_path = False
                        self._last_permanent_saved_checkpoint_time = save_time
                if remove_path:
                    for fname in paths_to_remove[1:]:
                        if os.path.isfile(fname):
                            os.remove(fname)
    def find_latest_checkpoint(self) -> Optional[Tuple[Path, str, str]]:
        latest = super().find_latest_checkpoint()
        if not latest:
            return None

        model_path, training_state_path = latest

        checkpoints = (
            self._serialization_dir and Path(self._serialization_dir).glob("deepspeed_epoch_*")
        ) or []
        checkpoints = sorted(c for c in checkpoints if c.is_dir())
        if not checkpoints:
            return None

        engine_path = checkpoints[-1]
        return engine_path, model_path, training_state_path
Review discussion:
- I didn't check closely, but a lot of these functions look identical to the ones from the regular checkpointer. Can you derive from that one and just override the methods that differ?
- It is derived from the regular checkpointer. I might be able to clean more of this up depending on the points above; if I didn't have to re-load the torch weights and could delegate almost entirely to DeepSpeed, it would simplify things quite a lot.
- How do you find out whether you can do this? I suspect DeepSpeed will work best the more we delegate to it.

    def restore_checkpoint(self) -> Tuple[Any, Dict[str, Any], Dict[str, Any]]:
        latest_checkpoint = self.find_latest_checkpoint()

        if latest_checkpoint is None:
            # No checkpoint to restore, start at 0
            return {}, {}, {}

        checkpoint_id, model_path, training_state_path = latest_checkpoint

        model_state = torch.load(model_path, map_location=nn_util.device_mapping(-1))
        training_state = torch.load(training_state_path, map_location=nn_util.device_mapping(-1))
        return checkpoint_id, model_state, training_state

    def best_model_state(self) -> Dict[str, Any]:
        pass
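Note on the review discussion above about what ZeRO actually saves: below is a minimal pytest sketch of the suggested round-trip check. `build_deepspeed_trainer` and the `_restore_checkpoint` call are hypothetical stand-ins for however the trainer in this PR gets constructed and restored; only the assertion logic is the point here.

```python
# Hedged sketch of a checkpoint round-trip test; `build_deepspeed_trainer` is a
# hypothetical helper (small model + DeepspeedTrainer over a temp directory).
import torch


def assert_state_dicts_close(expected, actual):
    # The restored model should contain *every* parameter, not just the shard
    # local to one device, so compare full key sets and tensor values.
    assert expected.keys() == actual.keys()
    for name in expected:
        assert torch.allclose(expected[name].cpu(), actual[name].detach().cpu()), name


def test_checkpoint_round_trip(tmp_path):
    trainer = build_deepspeed_trainer(serialization_dir=str(tmp_path))  # hypothetical
    trainer.train()  # saves checkpoints via DeepspeedCheckpointer along the way

    expected = {k: v.detach().cpu().clone() for k, v in trainer.model.state_dict().items()}

    restored = build_deepspeed_trainer(serialization_dir=str(tmp_path))  # hypothetical
    restored._restore_checkpoint()  # assumed restore entry point on the trainer
    assert_state_dicts_close(expected, restored.model.state_dict())
```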
@@ -0,0 +1,68 @@
from typing import Dict, Any
from enum import IntEnum
from allennlp.common import FromParams
from dataclasses import dataclass, asdict


@dataclass
class DeepspeedFP16Config(FromParams):
    enabled: bool = True
    loss_scale: float = 0.0
    initial_scale_power: int = 32
    loss_scale_window: int = 1000
    hysteresis: int = 2
    min_loss_scale: float = 1.0


@dataclass
class DeepspeedAMPConfig(FromParams):
    enabled: bool = False
    opt_level: str = "O1"
Review discussion:
- I thought AMP was dead and we now use what's built directly into PyTorch?
- Yeah, but it's a required install for DeepSpeed and you can use it there, so I thought I would keep it in for compatibility. It can be removed if need be.
- Hmm. Surely the next DeepSpeed version will switch to PyTorch-native AMP. But if we need it for now, that's cool.
@dataclass
class DeepspeedOptimizerConfig(FromParams):
    type: str
    params: Dict[str, Any]


@dataclass
class DeepspeedLRSchedulerConfig(FromParams):
    type: str
    params: Dict[str, Any]


class DeepspeedZeROStage(IntEnum):
    DISABLED = 0
    OPTIMIZER = 1
    GRADIENT = 2


@dataclass
class DeepspeedZeROConfig(FromParams):
    stage: DeepspeedZeROStage = DeepspeedZeROStage.GRADIENT
    allgather_partitions: bool = True
    allgather_bucket_size: int = 500000000
    overlap_comm: bool = False
    reduce_scatter: bool = True
    reduce_bucket_size: int = 500000000
    contiguous_gradients: bool = False
    cpu_offload: bool = False


@dataclass
class DeepspeedConfig(FromParams):
    zero_optimization: DeepspeedZeROConfig
    fp16: DeepspeedFP16Config
    amp: DeepspeedAMPConfig = DeepspeedAMPConfig()
    optimizer: DeepspeedOptimizerConfig = None
    scheduler: DeepspeedLRSchedulerConfig = None

    zero_allow_untested_optimizer: bool = True
    wall_clock_breakdown: bool = False

    def to_dict(self):
        return asdict(self)


@dataclass
class DeepspeedArgs(FromParams):
    local_rank: int
    deepspeed: bool = True
    deepspeed_mpi: bool = False
    deepspeed_config: str = None
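For orientation (not part of the diff): a small sketch of how these dataclasses compose and what `to_dict()` hands to DeepSpeed. The module path `allennlp.training.deepspeed.config` is an assumption, since the file name is not visible in this view.

```python
from allennlp.training.deepspeed.config import (  # assumed module path
    DeepspeedConfig,
    DeepspeedFP16Config,
    DeepspeedZeROConfig,
    DeepspeedZeROStage,
)

# ZeRO stage 1 (optimizer-state partitioning) with fp16 enabled.
ds_config = DeepspeedConfig(
    zero_optimization=DeepspeedZeROConfig(stage=DeepspeedZeROStage.OPTIMIZER),
    fp16=DeepspeedFP16Config(enabled=True),
)

# asdict() recurses into the nested dataclasses, producing the plain dict
# ("zero_optimization", "fp16", "amp", ...) that the DeepSpeed engine consumes.
print(ds_config.to_dict())
```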
@@ -0,0 +1,87 @@
from typing import List, Tuple, Dict, Any

import torch

from apex.optimizers.fused_adam import FusedAdam
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.ops.lamb import FusedLamb
from deepspeed.runtime.fp16.onebit_adam import OnebitAdam

from allennlp.training.optimizers import Optimizer, make_parameter_groups
@Optimizer.register("fused_adam") | ||
class FusedAdamOptimizer(Optimizer, FusedAdam): | ||
def __init__( | ||
self, | ||
model_parameters: List[Tuple[str, torch.nn.Parameter]], | ||
parameter_groups: List[Tuple[List[str], Dict[str, Any]]] = None, | ||
lr: float = 0.001, | ||
betas: Tuple[float, float] = (0.9, 0.999), | ||
eps: float = 1e-08, | ||
weight_decay: float = 0.0, | ||
amsgrad: bool = False, | ||
bias_correction: bool =True, | ||
adam_w_mode: bool = True, | ||
set_grad_none: bool = True, | ||
): | ||
super().__init__( | ||
params=make_parameter_groups(model_parameters, parameter_groups), | ||
lr=lr, | ||
betas=betas, | ||
eps=eps, | ||
weight_decay=weight_decay, | ||
amsgrad=amsgrad, | ||
bias_correction=bias_correction, | ||
adam_w_mode=adam_w_mode, | ||
set_grad_none=set_grad_none, | ||
) | ||
|
||
# This does not currently work | ||
Review discussion:
- Why not? If it doesn't work and is not necessary, can we remove it?
@Optimizer.register("cpu_adam")
class DeepspeedCPUAdamOptimizer(Optimizer, DeepSpeedCPUAdam):
    def __init__(
        self,
        model_parameters: List[Tuple[str, torch.nn.Parameter]],
        parameter_groups: List[Tuple[List[str], Dict[str, Any]]] = None,
        lr: float = 0.001,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-08,
        weight_decay: float = 0.0,
        amsgrad: bool = False,
    ):
        super().__init__(
            model_params=make_parameter_groups(model_parameters, parameter_groups),
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
        )
@Optimizer.register("fused_lamb") | ||
class FusedLambOptimizer(Optimizer, FusedLamb): | ||
def __init__( | ||
self, | ||
model_parameters: List[Tuple[str, torch.nn.Parameter]], | ||
parameter_groups: List[Tuple[List[str], Dict[str, Any]]] = None, | ||
lr: float = 0.001, | ||
betas: Tuple[float, float] = (0.9, 0.999), | ||
eps: float = 1e-08, | ||
eps_inside_sqrt: bool = False, | ||
weight_decay: float = 0.0, | ||
amsgrad: bool = False, | ||
max_grad_norm: float = 0., | ||
max_coeff: float = 10.0, | ||
min_coeff: float = 0.01 | ||
): | ||
super().__init__( | ||
params=make_parameter_groups(model_parameters, parameter_groups), | ||
lr=lr, | ||
betas=betas, | ||
eps=eps, | ||
weight_decay=weight_decay, | ||
amsgrad=amsgrad, | ||
max_grad_norm=max_grad_norm, | ||
max_coeff=max_coeff, | ||
min_coeff=min_coeff, | ||
) |
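As a usage note (assuming apex/deepspeed are installed and the standard AllenNLP registry machinery), the registrations above make these optimizers constructible from an experiment config; a minimal sketch:

```python
import torch

from allennlp.common import Params
from allennlp.training.optimizers import Optimizer

model = torch.nn.Linear(10, 10)  # stand-in for any model
model_parameters = [(n, p) for n, p in model.named_parameters() if p.requires_grad]

# Equivalent to `"optimizer": {"type": "fused_adam", "lr": 1e-4}` in a training
# config; "cpu_adam" and "fused_lamb" work the same way once deepspeed is installed.
optimizer = Optimizer.from_params(
    model_parameters=model_parameters,
    params=Params({"type": "fused_adam", "lr": 1e-4}),
)
```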
@@ -0,0 +1,11 @@
from allennlp.modules.token_embedders.token_embedder import TokenEmbedder
from allennlp.modules.token_embedders.pretrained_transformer_embedder import PretrainedTransformerEmbedder

from deepspeed.ops.sparse_attention.sparse_attention_utils import SparseAttentionUtils


# Doesn't work yet
Review discussion:
- Same question as above: if it doesn't work, can we remove it?
- This is just a question of scope. I think it would be a very neat little addition for not much cost, but I've been having trouble getting it installed properly (doing all of this on a SLURM cluster is causing most of these problems, especially not being able to use Docker). For what it's worth, someone in the issue thread expressed interest in this. So overall it's just a question of whether you'd like me to give it a shot at getting it working now, or leave it for the future. Happy either way.
- I think it would be a great addition, but let's not make this PR bigger than it already is. It would be a nice follow-up though, especially if those classes end up being the secret to the massive promises on their website. You can make PRs that branch off of other PRs, just in case you didn't know :-P
@TokenEmbedder.register("sparse_transformer")
class SparseTransformerEmbedder(PretrainedTransformerEmbedder):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.transformer_model = SparseAttentionUtils.replace_model_self_attention_with_sparse_self_attention(
            self.transformer_model
        )
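For context, and hedged since the class above is marked as not working yet: the registration would let this embedder be selected like any other token embedder. The model name below is only an example.

```python
from allennlp.common import Params
from allennlp.modules.token_embedders import TokenEmbedder

# Roughly equivalent to `"token_embedders": {"tokens": {"type": "sparse_transformer",
# "model_name": "bert-base-uncased"}}` in an experiment config; it requires
# deepspeed's sparse attention ops to be installed and compatible with the model.
embedder = TokenEmbedder.from_params(
    Params({"type": "sparse_transformer", "model_name": "bert-base-uncased"})
)
```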
Review discussion:
- Is this leftover debug code?
- Depends on how we want to include this. Based on my experience, I wouldn't recommend making DeepSpeed a required dependency. If we're doing the `pip install allennlp[deepspeed]` thing, this could be replaced/updated (I'm not sure offhand how that gets handled, but I can look for some examples).
- If you don't mind doing the work of making it optional, then let's make it optional.
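On the `pip install allennlp[deepspeed]` idea: one common way to handle it is a setuptools extra, sketched below. The package list and layout are illustrative, not taken from this PR.

```python
# Sketch of declaring an optional extra in setup.py (illustrative only).
from setuptools import find_packages, setup

setup(
    name="allennlp",
    packages=find_packages(),
    extras_require={
        # Pulled in only with `pip install allennlp[deepspeed]`.
        "deepspeed": ["deepspeed"],
    },
)
```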