diff --git a/pytorch_pfn_extras/handler/_logic.py b/pytorch_pfn_extras/handler/_logic.py index 8851af046..ba77f79e6 100644 --- a/pytorch_pfn_extras/handler/_logic.py +++ b/pytorch_pfn_extras/handler/_logic.py @@ -206,7 +206,16 @@ def consume_options(self, options: Dict[str, Any]) -> None: "torch.cuda.amp.GradScaler object" ) - def _forward(self, model: torch.nn.Module, batch: Any) -> Any: + def forward(self, model: torch.nn.Module, batch: Any) -> Any: + """Get the result of inputting the sampled data batch into the model. + + Args: + model (torch.nn.Module): Model to input data. + batch (Any): Mini-batch sampled from the data loader running in Trainer. + + Returns: + Any: Output of the model. The loss is assumed to be the output. + """ if isinstance(batch, tuple) and hasattr(batch, "_fields"): # namedtuple return model(batch) @@ -309,7 +318,7 @@ def train_step( """ with self._autocast.autocast(): optimizers[self.model_name].zero_grad() - outs = self._forward(models[self.model_name], batch) + outs = self.forward(models[self.model_name], batch) to_back_outs = _normalize_outputs(outs) self._backward(to_back_outs) return outs @@ -369,7 +378,7 @@ def eval_step( """ model = models[self.model_name] with self._autocast.autocast(): - outs = self._forward(model, batch) + outs = self.forward(model, batch) return outs @@ -530,7 +539,7 @@ def train_step( def clousure() -> ClousureModelOutput: with self._autocast.autocast(): optimizers[self.model_name].zero_grad() - outs = self._forward(models[self.model_name], batch) + outs = self.forward(models[self.model_name], batch) to_back_outs = _normalize_outputs(outs) if len(to_back_outs) > 1: raise RuntimeError(