diff --git a/mmengine/model/base_model/base_model.py b/mmengine/model/base_model/base_model.py index 299cd67557..699a7221e4 100644 --- a/mmengine/model/base_model/base_model.py +++ b/mmengine/model/base_model/base_model.py @@ -110,7 +110,7 @@ def train_step(self, data: Union[dict, tuple, list], """ # Enable automatic mixed precision training context. with optim_wrapper.optim_context(self): - data = self.data_preprocessor(data, True) + data = self.data_preprocessor(data, True) # ! data preprocessing (mean subtraction / std division) losses = self._run_forward(data, mode='loss') # type: ignore parsed_losses, log_vars = self.parse_losses(losses) # type: ignore optim_wrapper.update_params(parsed_losses) diff --git a/mmengine/runner/base_loop.py b/mmengine/runner/base_loop.py index 5bae459a20..a97f6981be 100644 --- a/mmengine/runner/base_loop.py +++ b/mmengine/runner/base_loop.py @@ -23,7 +23,7 @@ def __init__(self, runner, dataloader: Union[DataLoader, Dict]) -> None: # Determine whether or not different ranks use different seed. diff_rank_seed = runner._randomness_cfg.get( 'diff_rank_seed', False) - self.dataloader = runner.build_dataloader( + self.dataloader = runner.build_dataloader( # build the dataloader dataloader, seed=runner.seed, diff_rank_seed=diff_rank_seed) else: self.dataloader = dataloader diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 5a678db7b9..f9e928d86e 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -102,7 +102,7 @@ def run(self) -> torch.nn.Module: and self._epoch >= self.val_begin and (self._epoch % self.val_interval == 0 or self._epoch == self._max_epochs)): - self.runner.val_loop.run() + self.runner.val_loop.run() # ! 
run validation self.runner.call_hook('after_train') return self.runner.model diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 7d1f655aad..aaf26def87 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -1364,10 +1364,10 @@ def build_dataloader(dataloader: Union[DataLoader, Dict], dataloader_cfg = copy.deepcopy(dataloader) - # build dataset + # build the dataset dataset_cfg = dataloader_cfg.pop('dataset') if isinstance(dataset_cfg, dict): - dataset = DATASETS.build(dataset_cfg) + dataset = DATASETS.build(dataset_cfg) # build the concrete dataset class selected by dataset_cfg's `type` if hasattr(dataset, 'full_init'): dataset.full_init() else: @@ -1473,7 +1473,7 @@ def build_dataloader(dataloader: Union[DataLoader, Dict], raise TypeError( 'collate_fn should be a dict or callable object, but got ' f'{collate_fn_cfg}') - data_loader = DataLoader( + data_loader = DataLoader( # finally build the PyTorch DataLoader dataset=dataset, sampler=sampler if batch_sampler is None else None, batch_sampler=batch_sampler, @@ -1724,7 +1724,7 @@ def train(self) -> nn.Module: 'method. Please provide `train_dataloader`, `train_cfg`, ' '`optimizer` and `param_scheduler` arguments when ' 'initializing runner.') - + # ! build the training loop self._train_loop = self.build_train_loop( self._train_loop) # type: ignore @@ -1742,10 +1742,10 @@ def train(self) -> nn.Module: self._val_loop = self.build_val_loop( self._val_loop) # type: ignore # TODO: add a contextmanager to avoid calling `before_run` many times - self.call_hook('before_run') + self.call_hook('before_run') # hooks before the run starts # initialize the model weights - self._init_model_weights() + self._init_model_weights() # initialize the model weights # try to enable activation_checkpointing feature modules = self.cfg.get('activation_checkpointing', None) @@ -1773,9 +1773,10 @@ def train(self) -> nn.Module: # Maybe compile the model according to options in self.cfg.compile # This must be called **AFTER** model has been wrapped. 
self._maybe_compile('train_step') - + + # ! start training the model model = self.train_loop.run() # type: ignore - self.call_hook('after_run') + self.call_hook('after_run') # ! hooks after the run finishes return model def val(self) -> dict: @@ -1874,7 +1875,7 @@ def register_hook( if 'priority' in hook: _priority = hook.pop('priority') - hook_obj = HOOKS.build(hook) + hook_obj = HOOKS.build(hook) # build the hook class from its config else: hook_obj = hook @@ -1963,7 +1964,7 @@ def register_default_hooks( default_hooks[name] = hook for hook in default_hooks.values(): - self.register_hook(hook) + self.register_hook(hook) # register them one by one def register_custom_hooks(self, hooks: List[Union[Hook, Dict]]) -> None: """Register custom hooks into hook list.