Skip to content

Commit

Permalink
More docs
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Dec 9, 2024
1 parent e2519c8 commit 9c3308a
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions torchvision/transforms/v2/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,23 @@ def check_inputs(self, flat_inputs: List[Any]) -> None:
# keep in order to guarantee 100% BC with v1. (It's defined in
# __init_subclass__ below).
def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
"""Method to override for custom transforms.
See :ref:`sphx_glr_auto_examples_transforms_plot_custom_transforms.py`"""
return dict()

def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
kernel = _get_kernel(functional, type(inpt), allow_passthrough=True)
return kernel(inpt, *args, **kwargs)

def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
"""Method to override for custom transforms.
See :ref:`sphx_glr_auto_examples_transforms_plot_custom_transforms.py`"""
raise NotImplementedError

def forward(self, *inputs: Any) -> Any:
"""Do not override this! Use ``transform()`` instead."""
flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])

self.check_inputs(flat_inputs)
Expand Down

0 comments on commit 9c3308a

Please sign in to comment.