ddp backend fix and documentation changes #68

Open · wants to merge 1 commit into main
18 changes: 10 additions & 8 deletions pymarlin/core/module_interface.py
@@ -78,14 +78,16 @@ def on_end_train(self, global_step:int):
class ModuleInterface(torch.nn.Module, CallbackInterface):
"""Interface for PyTorch modules.

This interface contains model architecture in the form of a PyTorch
`nn.Module` together with optimizers and schedules, train and validation
step recipes and any callbacks.

Note: The forward function is overridden.

Note: Users are encouraged to override the `train_step` and `val_step`
methods.
This is where scientists and researchers write their training code.
This module can be thought of as the implementation of the training recipe.
ModuleInterface inherits from nn.Module and hence can be treated like any PyTorch module.
Scientists need to implement the abstract functions to create a training recipe.

Note: The forward function is overridden and replaced with two functions, train_step and val_step,
to separate training and validation loop code.

ModuleInterface also inherits from CallbackInterface.
Users can optionally override callbacks like on_end_val_epoch() to calculate metrics.
"""
@abstractmethod
def get_optimizers_schedulers(
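To make the new docstring concrete, here is a minimal sketch of a ModuleInterface implementation. The method signatures (the arguments of train_step, val_step and get_optimizers_schedulers), the import path, and the tiny linear model are assumptions for illustration rather than the verified pymarlin API; the real class also declares other abstract methods (e.g. dataloader getters) that are omitted here.

```python
# Minimal sketch of a training recipe, assuming hypothetical method signatures.
import torch
from pymarlin.core import module_interface  # assumed import path


class MyRecipe(module_interface.ModuleInterface):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(10, 2)          # any nn.Module works here
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def get_optimizers_schedulers(self, estimated_global_steps_per_epoch, epochs):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [scheduler]

    def train_step(self, global_step, batch, device):
        # Replaces forward() for the training loop; returns the loss to backpropagate.
        x, y = batch
        return self.loss_fn(self.net(x.to(device)), y.to(device))

    def val_step(self, global_step, batch, device):
        # Replaces forward() for the validation loop.
        x, y = batch
        with torch.no_grad():
            return self.loss_fn(self.net(x.to(device)), y.to(device))

    def on_end_val_epoch(self, global_step, *collected_outputs, key="default"):
        # Optional CallbackInterface hook: compute and log metrics here.
        pass
```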
51 changes: 37 additions & 14 deletions pymarlin/core/trainer.py
@@ -42,6 +42,27 @@
class TrainerArguments:
"""
Trainer Arguments class.
Args:
epochs(int): Number of training epochs. Defaults to 1
use_gpu(bool): Defaults to True
train_batch_size(int): Global batch size for training. This value need not be changed for distributed training.
This is the number of samples the model sees before each weight update.
Defaults to 1
gpu_batch_size_limit(int): Maximum training batch size that each GPU can handle. Used for calculating gradient accumulation. Defaults to 512
val_batch_size(int): Maximum validation batch size that each GPU can handle. Defaults to 1
max_train_steps_per_epoch(Optional[int]): Maximum train global steps per epoch. Mostly used for sanity checks. Defaults to None, which trains on the entire dataloader
max_val_steps_per_epoch(Optional[int]): Maximum validation global steps per epoch. Defaults to None, which runs validation on the entire dataloader
clip_grads(bool): Enables or disables gradient clipping. Uses `torch.nn.utils.clip_grad_norm_`. Defaults to True
max_grad_norm(float): Maximum norm for gradient clipping. Defaults to 1.0
reset_optimizers_schedulers(bool): Whether to reset optimizers and schedulers after loading a checkpoint. Generally useful if the scheduler is almost exhausted. Defaults to False
checkpointer_args(DefaultCheckpointerArguments): Defaults to DefaultCheckpointerArguments()
distributed_training_args(DistributedTrainingArguments): Instance of DistributedTrainingArguments. Defaults to None

writers(List): List of writers. Can be a combination of `pymarlin.utils.writer.base.Writer` instances and strings. Only the strings `stdout`, `aml` and `tensorboard` are supported. Defaults to ["stdout", "aml", "tensorboard"]
stats_args(stats.StatInitArguments): Defaults to stats.StatInitArguments()
writer_args(WriterInitArguments): Defaults to WriterInitArguments()
disable_tqdm(bool): Disable tqdm-style output. Generally disabled if output is piped to a file, as in AzureML. Defaults to False
log_level(str): Defaults to "INFO"
"""
epochs: int = 1
use_gpu: bool = True
@@ -89,19 +110,10 @@ def validate(self):


class Trainer(AbstractTrainer):
"""Orchestrates model training.

Args:
module (ModuleInterface): Contains model definition, train and validation
definition, optimizer and scheduler, and optional callbacks.
args (TrainerArguments): Training hyperparameters.

Optional keyword arguments:
trainer_backend (TrainerBackend): How the training will be carried out.
For example, the training is distributed and/or using AMP (automatic mixed precision).
This can also be specified in args using the backend keyword.
Defaults to singleprocess. Options are: sp (singleprocess), sp-amp, ddp, ddp-amp.
checkpointer (AbstractCheckpointer): Used to handle model checkpointing.
"""The bridge between TrainerBackend and ModuleInterface: this module takes care of device management,
rank fetching, checkpointing, reloading. Also handles all the complicated calculations for mini batch size,
gradient accumulation, number of remaining epochs, initializing stats writers like tensor board,
restarting training from previous state etc.
"""

def __init__(
@@ -111,8 +123,19 @@ def __init__(
trainer_backend: Optional[trn.TrainerBackend] = None,
checkpointer: Optional[AbstractCheckpointer] = None
):

"""
Initializes stats, writers, trainer_backend and other helpers.
Args:
module (pymarlin.ModuleInterface): Contains model definition, train and validation
definition, optimizer and scheduler, and optional callbacks.
args (pymarlin.core.trainer.TrainerArguments): Training hyperparameters.

Optional keyword arguments:
trainer_backend (pymarlin.core.trainer.TrainerBackend): How the training will be carried out.
For example, the training is distributed and/or using AMP (automatic mixed precision).
This can also be specified in args using the backend keyword.
Defaults to singleprocess. Options are: sp (singleprocess), sp-amp, ddp, ddp-amp.
checkpointer (pymarlin.utils.checkpointer.checkpoint_utils.AbstractCheckpointer): Used to handle model checkpointing.
"""
self.module = module
self.args = args
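Putting the two docstrings above together, a typical training script might look like the following sketch. Field names come from the TrainerArguments docstring in this diff; the constructor keywords, the train() entry point and the MyRecipe class (from the earlier sketch) are illustrative assumptions.

```python
# Sketch: wiring TrainerArguments, a ModuleInterface and the Trainer together.
from pymarlin.core import trainer as trn  # assumed import path

args = trn.TrainerArguments(
    epochs=3,
    use_gpu=True,
    train_batch_size=32,      # global batch size: samples seen per weight update
    gpu_batch_size_limit=8,   # per-GPU micro-batch limit, used to derive gradient accumulation
    val_batch_size=8,
    writers=["stdout", "tensorboard"],
)

# With 2 GPUs, gradient accumulation works out to roughly
#   train_batch_size / (gpu_batch_size_limit * n_gpus) = 32 / (8 * 2) = 2,
# i.e. two micro-batches are accumulated before each optimizer step.

module = MyRecipe()  # the hypothetical ModuleInterface sketch shown earlier
trainer = trn.Trainer(module=module, args=args)  # backend defaults to single process;
                                                 # "sp-amp", "ddp" or "ddp-amp" can be selected instead
trainer.train()  # assumed entry point that runs train/validate for all epochs
```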
76 changes: 66 additions & 10 deletions pymarlin/core/trainer_backend.py
@@ -1,6 +1,8 @@
"""
Trainer Backend module:

Responsible for training/validating the ModuleInterface for one entire epoch. PyMarlin offers various useful backend implementations, such as SingleProcess, SingleProcessAmp, and DDPTrainerBackend.

Currently we support:
1. SingleProcess
2. SingleProcess Amp
@@ -60,8 +62,21 @@ def build_trainer_backend(trainer_backend_name, *args, **kwargs):
@dataclasses.dataclass
class TrainerBackendArguments:
"""
Trainer Backend Arguments dataclass.
Trainer Backend Arguments dataclass. Attributes can also be passed as arguments.
Args:
model (module_interface.ModuleInterface): PyMarlin module
device (Union[torch.device, str, int]): Device on which to run training
max_train_steps_per_epoch (Optional[int]): Maximum train global steps per epoch. Mostly used for sanity checks
max_val_steps_per_epoch (Optional[int]): Maximum validation global steps per epoch
distributed_training_args (DistributedTrainingArguments): DistributedTrainingArguments instance
optimizers (Iterable[torch.optim.Optimizer]): One or more optimizers. All optimizers are stepped at once at the end of each global step
schedulers (Optional[Iterable[torch.optim.lr_scheduler._LRScheduler]]): One or more schedulers. All schedulers are stepped at once at the end of each global step. Defaults to None
gradient_accumulation (int): Number of batches over which gradients are accumulated before an optimizer step. Defaults to 1
clip_grads (bool): Enables or disables gradient clipping. Uses `torch.nn.utils.clip_grad_norm_`. Defaults to True
max_grad_norm (float): Maximum norm for gradient clipping. Defaults to 1.0
disable_tqdm (bool): Disable tqdm-style output. Generally disabled if output is piped to a file, as in AzureML. Defaults to False
"""

model: module_interface.ModuleInterface
device: Union[torch.device, str, int]
max_train_steps_per_epoch: Optional[int]
@@ -81,21 +96,34 @@ class TrainerBackendArguments:

class TrainerBackend(ABC):
"""
Trainer Backend abstract class.
Trainer Backend abstract class. This is responsible for running training and validation on an entire dataloader.
"""
def __init__(self):
pass

@abstractmethod
def init(self, args: TrainerBackendArguments):
'''
Called before the start of training and validation.
TrainerBackend implementations should handle all initializations, such as module wrapping, optimizer modifications and distributed environment setup, in this method.

Args:
args (TrainerBackendArguments): instance of TrainerBackendArguments
'''
pass

@abstractmethod
def train_dl(self, *args, **kwargs):
def train_dl(self, dataloader, callback: module_interface.CallbackInterface):
'''
Train over an entire dataloader
'''
pass

@abstractmethod
def validate_dl(self, *args, **kwargs):
def validate_dl(self, dataloader):
'''
Validate an entire dataloader
'''
pass

@abstractmethod
@@ -118,10 +146,16 @@ def val_sampler(self):

@abstractmethod
def get_state(self):
'''
Returns the trainer backend state, which is saved along with the model checkpoint
'''
pass

@abstractmethod
def update_state(self, state):
'''
Reload the state saved by get_state()
'''
pass


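For anyone implementing a custom backend against the abstract class above, a bare-bones skeleton could look like the sketch below. Only the methods visible in this diff are stubbed; the real TrainerBackend declares additional abstract members (samplers, step counters) not shown here, and the import path is an assumption.

```python
# Skeleton of a hypothetical TrainerBackend that only counts batches.
from pymarlin.core import trainer_backend as trn_backend  # assumed import path


class CountingBackend(trn_backend.TrainerBackend):
    def __init__(self):
        super().__init__()
        self.batches_seen = 0

    def init(self, args: trn_backend.TrainerBackendArguments):
        # Wrap the module, patch optimizers, set up the distributed env, etc.
        self.args = args
        self.model = args.model

    def train_dl(self, dataloader, callback):
        for _ in dataloader:
            self.batches_seen += 1  # a real backend runs train_step, backward and optimizer steps here

    def validate_dl(self, dataloader):
        for _ in dataloader:
            pass  # a real backend runs val_step and collects outputs here

    def get_state(self):
        # Saved along with the model checkpoint.
        return {"batches_seen": self.batches_seen}

    def update_state(self, state):
        # Reloads the state returned by get_state().
        if state:
            self.batches_seen = state.get("batches_seen", 0)
```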
@@ -188,9 +222,6 @@ class SingleProcess(TrainerBackend):

# pylint: disable=super-init-not-called
def __init__(self):
"""
Single process trainer_backend
"""
self.global_step_completed = 0
self.batches_completed = 0
self.distributed = False
@@ -200,6 +231,10 @@ def stats(self):
return stats.global_stats

def init(self, args: TrainerBackendArguments):
'''
Args:
args (TrainerBackendArguments): init arguments
'''
self.args = args
self.model = self.args.model
if not self.distributed:
@@ -213,6 +248,14 @@ def get_global_steps_completed(self):
return self.global_step_completed

def train_dl(self, dataloader, callback: module_interface.CallbackInterface):
'''
Iterates through the dataloader and trains the ModuleInterface passed to init().
Handles the forward pass, backward pass, gradient accumulation, optimizer step, scheduler step and collection of train_step outputs.

Args:
dataloader: An iterator which returns a batch of samples at every iteration
callback (module_interface.CallbackInterface): Callback instance
'''

epoch_collector = OutputCollector()
global_step_collector = OutputCollector()
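The train_dl docstring above describes the usual accumulate-then-step pattern. As a framework-free sketch (not PyMarlin's actual implementation, and train_step here is a hypothetical method returning a scalar loss), the core of such a loop looks like this:

```python
# Simplified gradient-accumulation loop: forward, scaled backward, step every N batches.
import torch

def train_one_epoch(model, dataloader, optimizer, scheduler,
                    grad_accumulation=2, max_grad_norm=1.0):
    model.train()
    for i, batch in enumerate(dataloader):
        loss = model.train_step(batch)             # hypothetical: returns a scalar loss
        (loss / grad_accumulation).backward()      # scale so accumulated grads match the global batch
        if (i + 1) % grad_accumulation == 0:       # one global step per grad_accumulation batches
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
```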
@@ -294,6 +337,13 @@ def _clip_gradients(self):
)

def validate_dl(self, dataloader):
'''
Iterates through the dataloader and runs evaluation on the ModuleInterface passed to init().
Handles the forward pass and output collection.

Args:
dataloader: An iterator which returns a batch of samples at every iteration
'''
collector = OutputCollector()
for i, batch in enumerate(tqdm(dataloader, desc=f"Validation {self.args.distributed_training_args.global_rank}", disable=self.args.disable_tqdm)):
if (
@@ -347,7 +397,7 @@ def update_state(self, state) -> None:
Update the trainer_backend from a checkpointed state.

Args:
state (dict) : Output of get_state() during checkpointing
state (dict): Output of get_state() during checkpointing
"""
if state:
self.global_step_completed = state["global_step_completed"]
@@ -571,6 +621,11 @@ class DDPTrainerBackend(AbstractTrainerBackendDecorator):
"""
# pylint: disable=super-init-not-called
def __init__(self, trainer_backend, gather_frequency: Optional[int] = None):
'''
Args:
gather_frequency (Optional[int]): Unused. Maximum number of samples for one all_gather operation.
There was an issue with a mismatched number of chunks across nodes which halted all_gather.
'''
self.trainer_backend = trainer_backend
self.gather_frequency = gather_frequency
self.trainer_backend.distributed = True
@@ -672,10 +727,11 @@ def gather_tensors_on_cpu(self, x: torch.tensor):
Gathered tensor on the cpu.
"""
n_samples = len(x)
self._set_gather_frequency(n_samples)
gather_frequency = n_samples

gathered = []
n_chunks = n_samples // self.gather_frequency + 1
Collaborator:
@krishansubudhi For my understanding, why did we have to chunk before? I assumed it was to avoid exceeding the GPU memory limit, but it looks like we only move tensors to GPU in this loop and never out of it.

Contributor (author):
Adding @aminsaied, who also initially created the DDP trainer backend and chunking logic.

Contributor:
The loop first moves the tensors to GPU, then does the all_gather op, then moves the gathered tensors back to CPU. I believe this was at the request of @gshruti95 at the time for a specific workload that was being tested (keep me honest, Shruti).

Contributor (@gshruti95, Oct 20, 2021):
We decided to introduce chunking in case of potential memory or timeout issues when trying to all_gather for pretraining workloads.

Contributor (author):
I believe this logic will have to change back then; chunking needs to be implemented correctly and not hard-coded to 1, but it can be for the time being (it will just be slow, I think).
n_chunks = 1 # hardcoded to one as there was issue with mismatched number of chunks across nodes
#for NER type scenarios where token labels are flattened and filtered
for i in range(n_chunks):
# get chunk on cpu
chunk_cpu = x[i * self.gather_frequency: (i + 1) * self.gather_frequency]
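For context on the chunking discussion above, the sketch below shows one way to gather a CPU tensor across ranks in chunks using plain torch.distributed (it is not the PyMarlin helper). It assumes every rank holds the same number of samples; uneven shards, as in NER-style flattened and filtered labels, are exactly what caused the mismatched chunk counts discussed in the review.

```python
# Sketch: move CPU chunks to GPU, all_gather them, and move the results back to CPU.
import torch
import torch.distributed as dist

def gather_tensor_on_cpu(x: torch.Tensor, chunk_size: int, device: torch.device) -> torch.Tensor:
    """Assumes len(x) is identical on every rank, so chunk shapes line up for all_gather."""
    world_size = dist.get_world_size()
    gathered = []
    n_chunks = (len(x) + chunk_size - 1) // chunk_size
    for i in range(n_chunks):
        chunk = x[i * chunk_size:(i + 1) * chunk_size].to(device)
        buffers = [torch.empty_like(chunk) for _ in range(world_size)]
        dist.all_gather(buffers, chunk)          # requires equal chunk shapes on all ranks
        gathered.extend(b.cpu() for b in buffers)
    # Note: samples come back interleaved per chunk (rank0 chunk0, rank1 chunk0, ...),
    # not grouped by rank.
    return torch.cat(gathered)
```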
6 changes: 6 additions & 0 deletions pymarlin/utils/stats/basic_stats.py
@@ -18,6 +18,12 @@
class StatInitArguments:
"""
Stats Arguments.
Args:
log_steps(int): Interval of logging. If log_steps is 50 and metric X is updated every step, then only every 50th step's value is logged. Defaults to 1
update_system_stats(bool): Logs system stats like CPU, RAM and GPU usage when enabled. Defaults to False
log_model_steps(int): Interval at which to log model weight norms and grad norms. Defaults to 1000
exclude_list(str): Regular expression; parameters whose names match it are excluded from weight norm and grad norm logging.
Defaults to r"bias|LayerNorm|layer\.[3-9]|layer\.1(?!1)|layer\.2(?!3)"
"""
log_steps: int = 1
update_system_stats: bool = False
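To see what the default exclude_list pattern actually filters, here is a quick standalone check with Python's re module; the parameter names are made up for illustration.

```python
# Which parameter names does the default exclude_list drop from weight/grad norm logging?
import re

exclude = re.compile(r"bias|LayerNorm|layer\.[3-9]|layer\.1(?!1)|layer\.2(?!3)")

params = [
    "encoder.layer.0.attention.weight",  # logged: layer 0 matches nothing
    "encoder.layer.0.attention.bias",    # excluded: matches "bias"
    "encoder.layer.5.output.weight",     # excluded: matches layer\.[3-9]
    "encoder.layer.11.output.weight",    # logged: layer\.1 is negated when followed by 1
    "encoder.layer.23.output.weight",    # logged: layer\.2 is negated when followed by 3
]
for name in params:
    print(name, "-> excluded" if exclude.search(name) else "-> logged")
```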
6 changes: 6 additions & 0 deletions pymarlin/utils/writer/base.py
@@ -8,6 +8,12 @@
class WriterInitArguments:
"""
Writer Arguments.
Args:
tb_log_dir(str): Defaults to 'logs'
tb_logpath_parent_env(str): Defaults to None
tb_log_multi(bool): Defaults to False
tb_log_hist_steps(int): Defaults to 20000
model_log_level(str): Defaults to 'INFO'
"""
tb_log_dir: str = 'logs'
tb_logpath_parent_env: str = None
6 changes: 3 additions & 3 deletions website/package.json
@@ -22,8 +22,8 @@
"file-loader": "^6.2.0",
"react": "^17.0.1",
"react-dom": "^17.0.1",
"url-loader": "^4.1.1",
"trim": "^0.0.3"
"trim": "^0.0.3",
"url-loader": "^4.1.1"
},
"browserslist": {
"production": [
@@ -37,4 +37,4 @@
"last 1 safari version"
]
}
}
}