From 4b6f2ba7b8ec5849c67d26c5de4726fa6e056e4a Mon Sep 17 00:00:00 2001
From: Krishan Subudhi
Date: Wed, 18 Aug 2021 16:33:25 +0530
Subject: [PATCH] backend fix and documentation changes

---
 pymarlin/core/module_interface.py   | 18 ++++---
 pymarlin/core/trainer.py            | 51 +++++++++++++------
 pymarlin/core/trainer_backend.py    | 76 +++++++++++++++++++++++++----
 pymarlin/utils/stats/basic_stats.py |  6 +++
 pymarlin/utils/writer/base.py       |  6 +++
 website/package.json                |  6 +--
 6 files changed, 128 insertions(+), 35 deletions(-)

diff --git a/pymarlin/core/module_interface.py b/pymarlin/core/module_interface.py
index 65af91b..fcf6bc8 100644
--- a/pymarlin/core/module_interface.py
+++ b/pymarlin/core/module_interface.py
@@ -78,14 +78,16 @@ def on_end_train(self, global_step:int):
 class ModuleInterface(torch.nn.Module, CallbackInterface):
     """Interface for PyTorch modules.
 
-    This interface contains model architecture in the form of a PyTorch
-    `nn.Module` together with optimizers and schedules, train and validation
-    step recipes and any callbacks.
-
-    Note: The forward function is overridden.
-
-    Note: Users are encouraged to override the `train_step` and `val_step`
-    methods.
+    This is where scientists and researchers write their training code.
+    This class can be thought of as the implementation of the training recipe.
+    ModuleInterface inherits from nn.Module and hence can be treated like any PyTorch module.
+    Scientists need to implement the abstract functions to create a training recipe.
+
+    Note: The forward function is overridden and replaced with two functions, train_step and val_step,
+    to separate the training and validation loop code.
+
+    ModuleInterface also inherits from CallbackInterface.
+    Users can optionally override callbacks like on_end_val_epoch() to calculate metrics.
     """
 
     @abstractmethod
     def get_optimizers_schedulers(
diff --git a/pymarlin/core/trainer.py b/pymarlin/core/trainer.py
index 6b925d3..3706c30 100644
--- a/pymarlin/core/trainer.py
+++ b/pymarlin/core/trainer.py
@@ -42,6 +42,27 @@ class TrainerArguments:
     """
     Trainer Arguments class.
 
+    Args:
+        epochs (int): Number of training epochs. Defaults to 1.
+        use_gpu (bool): Defaults to True.
+        train_batch_size (int): Global batch size for training. This value need not be changed for distributed training.
+            It is the number of samples the model sees before each weight update.
+            Defaults to 1.
+        gpu_batch_size_limit (int): Maximum training batch size that each GPU can handle. Used for calculating gradient accumulation. Defaults to 512.
+        val_batch_size (int): Maximum validation batch size that each GPU can handle. Defaults to 1.
+        max_train_steps_per_epoch (Optional[int]): Maximum train global steps per epoch. Mostly used as a sanity check. Defaults to None, which trains on the entire dataloader.
+        max_val_steps_per_epoch (Optional[int]): Maximum validation global steps per epoch. Defaults to None, which runs validation on the entire dataloader.
+        clip_grads (bool): Enables or disables gradient clipping. Uses `torch.nn.utils.clip_grad_norm_`. Defaults to True.
+        max_grad_norm (float): Maximum norm for gradient clipping. Defaults to 1.0.
+        reset_optimizers_schedulers (bool): Whether to reset the optimizers and schedulers after loading a checkpoint. Generally useful if the scheduler is almost exhausted. Defaults to False.
+        checkpointer_args (DefaultCheckpointerArguments): Defaults to DefaultCheckpointerArguments().
+        distributed_training_args (DistributedTrainingArguments): Instance of DistributedTrainingArguments. Defaults to None.
+        writers (List): List of writers.
+            Can be a combination of `pymarlin.utils.writer.base.Writer` instances and strings.
+            Only `stdout`, `aml` and `tensorboard` are supported as strings. Defaults to ["stdout", "aml", "tensorboard"].
+        stats_args (stats.StatInitArguments): Defaults to stats.StatInitArguments().
+        writer_args (WriterInitArguments): Defaults to WriterInitArguments().
+        disable_tqdm (bool): Disable tqdm-style output. Generally disabled if output is piped to a file, as in AzureML. Defaults to False.
+        log_level (str): Defaults to "INFO".
     """
     epochs: int = 1
     use_gpu: bool = True
@@ -89,19 +110,10 @@ def validate(self):
 
 
 class Trainer(AbstractTrainer):
-    """Orchestrates model training.
-
-    Args:
-        module (ModuleInterface): Contains model definition, train and validation
-            definition, optimizer and scheduler, and optional callbacks.
-        args (TrainerArguments): Training hyperparameters.
-
-    Optional keyword arguments:
-        trainer_backend (TrainerBackend): How the training will be carried out.
-            For example, the training is distributed and/or using AMP (automatic mixed precision).
-            This can also be specified in args using the backend keyword.
-            Defaults to singleprocess. Options are: sp (singleprocess), sp-amp, ddp, ddp-amp.
-        checkpointer (AbstractCheckpointer): Used to handle model checkpointing.
+    """The bridge between TrainerBackend and ModuleInterface: this class takes care of device management,
+    rank fetching, checkpointing and reloading. It also handles the calculations for mini-batch size,
+    gradient accumulation, the number of remaining epochs, initializing stats writers like TensorBoard,
+    restarting training from a previous state, etc.
     """
 
     def __init__(
@@ -111,8 +123,19 @@ def __init__(
         trainer_backend: Optional[trn.TrainerBackend] = None,
         checkpointer: Optional[AbstractCheckpointer] = None
     ):
+
         """
-        Initializes stats, writers, trainer_backend and other helper functions
+        Args:
+            module (pymarlin.ModuleInterface): Contains model definition, train and validation
+                definition, optimizer and scheduler, and optional callbacks.
+            args (pymarlin.core.trainer.TrainerArguments): Training hyperparameters.
+
+        Optional keyword arguments:
+            trainer_backend (pymarlin.core.trainer_backend.TrainerBackend): How the training will be carried out.
+                For example, whether training is distributed and/or uses AMP (automatic mixed precision).
+                This can also be specified in args using the backend keyword.
+                Defaults to singleprocess. Options are: sp (singleprocess), sp-amp, ddp, ddp-amp.
+            checkpointer (pymarlin.utils.checkpointer.checkpoint_utils.AbstractCheckpointer): Used to handle model checkpointing.
         """
         self.module = module
         self.args = args
diff --git a/pymarlin/core/trainer_backend.py b/pymarlin/core/trainer_backend.py
index 2336185..9ac9016 100644
--- a/pymarlin/core/trainer_backend.py
+++ b/pymarlin/core/trainer_backend.py
@@ -1,6 +1,8 @@
 """
 Trainer Backend module:
+Responsible for training and validating the ModuleInterface for one entire epoch. PyMarlin offers various useful backend implementations, such as SingleProcess, SingleProcessAmp, and DDPTrainerBackend.
+
 Currently we support:
 1. SingleProcess
 2. SingleProcess Amp
@@ -60,8 +62,21 @@ def build_trainer_backend(trainer_backend_name, *args, **kwargs):
 
 @dataclasses.dataclass
 class TrainerBackendArguments:
     """
-    Trainer Backend Arguments dataclass.
+    Trainer Backend Arguments dataclass.
+    Attributes can also be passed as arguments.
+
+    Args:
+        model (module_interface.ModuleInterface): PyMarlin module.
+        device (Union[torch.device, str, int]): Device to run on.
+        max_train_steps_per_epoch (Optional[int]): Maximum train global steps per epoch. Mostly used as a sanity check.
+        max_val_steps_per_epoch (Optional[int]): Maximum validation global steps per epoch.
+        distributed_training_args (DistributedTrainingArguments): Instance of DistributedTrainingArguments.
+        optimizers (Iterable[torch.optim.Optimizer]): One or more optimizers. All optimizers are stepped together at the end of each global step.
+        schedulers (Optional[Iterable[torch.optim.lr_scheduler._LRScheduler]]): One or more schedulers. All schedulers are stepped together at the end of each global step. Defaults to None.
+        gradient_accumulation (int): Number of batches over which gradients are accumulated before each optimizer step. Defaults to 1.
+        clip_grads (bool): Enables or disables gradient clipping. Uses `torch.nn.utils.clip_grad_norm_`. Defaults to True.
+        max_grad_norm (float): Maximum norm for gradient clipping. Defaults to 1.0.
+        disable_tqdm (bool): Disable tqdm-style output. Generally disabled if output is piped to a file, as in AzureML. Defaults to False.
     """
+
     model: module_interface.ModuleInterface
     device: Union[torch.device, str, int]
     max_train_steps_per_epoch: Optional[int]
@@ -81,21 +96,34 @@ class TrainerBackendArguments:
 
 class TrainerBackend(ABC):
     """
-    Trainer Backend abstract class.
+    Trainer Backend abstract class. A trainer backend is responsible for running training and validation over an entire dataloader.
     """
 
     def __init__(self):
         pass
 
     @abstractmethod
     def init(self, args: TrainerBackendArguments):
+        '''
+        Called before the start of training and validation.
+        TrainerBackend implementations should handle all initializations, such as module wrapping, optimizer modifications and distributed environment initialization, in this method.
+
+        Args:
+            args (TrainerBackendArguments): Instance of TrainerBackendArguments.
+        '''
         pass
 
     @abstractmethod
-    def train_dl(self, *args, **kwargs):
+    def train_dl(self, dataloader, callback: module_interface.CallbackInterface):
+        '''
+        Train on an entire dataloader.
+        '''
         pass
 
     @abstractmethod
-    def validate_dl(self, *args, **kwargs):
+    def validate_dl(self, dataloader):
+        '''
+        Validate on an entire dataloader.
+        '''
         pass
 
     @abstractmethod
@@ -118,10 +146,16 @@ def val_sampler(self):
 
     @abstractmethod
     def get_state(self):
+        '''
+        Returns the trainer backend state, which is saved along with the model checkpoint.
+        '''
         pass
 
     @abstractmethod
     def update_state(self, state):
+        '''
+        Reload the state saved by get_state().
+        '''
         pass
 
@@ -188,9 +222,6 @@ class SingleProcess(TrainerBackend):
 
     # pylint: disable=super-init-not-called
     def __init__(self):
-        """
-        Single process trainer_backend
-        """
         self.global_step_completed = 0
         self.batches_completed = 0
         self.distributed = False
@@ -200,6 +231,10 @@ def stats(self):
         return stats.global_stats
 
     def init(self, args: TrainerBackendArguments):
+        '''
+        Args:
+            args (TrainerBackendArguments): init arguments
+        '''
         self.args = args
         self.model = self.args.model
         if not self.distributed:
@@ -213,6 +248,14 @@ def get_global_steps_completed(self):
         return self.global_step_completed
 
     def train_dl(self, dataloader, callback: module_interface.CallbackInterface):
+        '''
+        Iterates through the dataloader and trains the ModuleInterface passed in init().
+        Handles the forward and backward operations, gradient accumulation, optimizer and scheduler steps, and collection of train_step outputs.
+
+        Args:
+            dataloader: an iterator which returns a batch of samples at every iteration
+            callback (module_interface.CallbackInterface): callback instance
+        '''
         epoch_collector = OutputCollector()
         global_step_collector = OutputCollector()
 
@@ -294,6 +337,13 @@ def _clip_gradients(self):
         )
 
     def validate_dl(self, dataloader):
+        '''
+        Iterates through the dataloader and runs evaluation on the ModuleInterface passed in init().
+        Handles the forward operation and output collection.
+
+        Args:
+            dataloader: an iterator which returns a batch of samples at every iteration
+        '''
         collector = OutputCollector()
         for i, batch in enumerate(tqdm(dataloader, desc=f"Validation {self.args.distributed_training_args.global_rank}", disable=self.args.disable_tqdm)):
             if (
@@ -347,7 +397,7 @@ def update_state(self, state) -> None:
         Update the trainer_backend from a checkpointed state.
 
         Args:
-            state (dict) : Output of get_state() during checkpointing
+            state (dict): Output of get_state() during checkpointing
         """
         if state:
             self.global_step_completed = state["global_step_completed"]
@@ -571,6 +621,11 @@ class DDPTrainerBackend(AbstractTrainerBackendDecorator):
     """
     # pylint: disable=super-init-not-called
     def __init__(self, trainer_backend, gather_frequency: Optional[int] = None):
+        '''
+        Args:
+            gather_frequency (Optional[int]): Unused. Maximum number of samples per all_gather operation.
+                There was an issue with a mismatched number of chunks across nodes which halted all_gather, so chunked gathering is currently disabled.
+        '''
         self.trainer_backend = trainer_backend
         self.gather_frequency = gather_frequency
         self.trainer_backend.distributed = True
@@ -672,10 +727,11 @@ def gather_tensors_on_cpu(self, x: torch.tensor):
             Gathered tensor on the cpu.
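+
+        Example (illustrative; assumes torch.distributed is initialized, every rank calls this
+        method, and ``backend`` is this DDPTrainerBackend instance)::
+
+            local_preds = torch.tensor([0, 2, 1])        # this rank's outputs
+            all_preds = backend.gather_tensors_on_cpu(local_preds)
+            # all_preds holds the gathered outputs of every rank, on the CPU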
""" n_samples = len(x) - self._set_gather_frequency(n_samples) + gather_frequency = n_samples gathered = [] - n_chunks = n_samples // self.gather_frequency + 1 + n_chunks = 1 # hardcoded to one as there was issue with mismatched number of chunks across nodes + #for NER type scenarios where token labels are flattened and filtered for i in range(n_chunks): # get chunk on cpu chunk_cpu = x[i * self.gather_frequency: (i + 1) * self.gather_frequency] diff --git a/pymarlin/utils/stats/basic_stats.py b/pymarlin/utils/stats/basic_stats.py index e2212d8..ea8c0f0 100644 --- a/pymarlin/utils/stats/basic_stats.py +++ b/pymarlin/utils/stats/basic_stats.py @@ -18,6 +18,12 @@ class StatInitArguments: """ Stats Arguments. + Args: + log_steps(int): Interval of logging. If log_steps is 50 and metric X is updated every step, then only 50th step value will be logged. Defaults to 1 + update_system_stats(bool): Logs system stats like CPU, RAM, GPU usage when enabled. Defaults to False + log_model_steps(int): Interval to log model weight norm and grad norm. Defaults to 1000 + exclude_list(str): Regular expression which when matched with parameter name, won't print weight norm and grad norms for that parameter. + Defaults to r"bias|LayerNorm|layer\.[3-9]|layer\.1(?!1)|layer\.2(?!3)" """ log_steps: int = 1 update_system_stats: bool = False diff --git a/pymarlin/utils/writer/base.py b/pymarlin/utils/writer/base.py index 996e7c2..16f6ba5 100644 --- a/pymarlin/utils/writer/base.py +++ b/pymarlin/utils/writer/base.py @@ -8,6 +8,12 @@ class WriterInitArguments: """ Writer Arguments. + Args: + tb_log_dir(str): Defaults to 'logs' + tb_logpath_parent_env(str): Defaults to None + tb_log_multi(bool): Defaults to False + tb_log_hist_steps(int): Defaults to 20000 + model_log_level(str): Defaults to 'INFO' """ tb_log_dir: str = 'logs' tb_logpath_parent_env: str = None diff --git a/website/package.json b/website/package.json index 98996c2..68a90e9 100644 --- a/website/package.json +++ b/website/package.json @@ -22,8 +22,8 @@ "file-loader": "^6.2.0", "react": "^17.0.1", "react-dom": "^17.0.1", - "url-loader": "^4.1.1", - "trim": "^0.0.3" + "trim": "^0.0.3", + "url-loader": "^4.1.1" }, "browserslist": { "production": [ @@ -37,4 +37,4 @@ "last 1 safari version" ] } -} \ No newline at end of file +}