diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 4bb18cf6b48..d2fed552c4f 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -508,11 +508,20 @@ are combining pairs of images together. These can be used after the dataloader Developer tools ^^^^^^^^^^^^^^^ +.. autosummary:: + :toctree: generated/ + :template: class.rst + + v2.Transform + .. autosummary:: :toctree: generated/ :template: function.rst v2.functional.register_kernel + v2.query_size + v2.query_chw + v2.get_bounding_boxes V1 API Reference diff --git a/gallery/transforms/plot_custom_transforms.py b/gallery/transforms/plot_custom_transforms.py index 19bc955b934..d1bd9455bfb 100644 --- a/gallery/transforms/plot_custom_transforms.py +++ b/gallery/transforms/plot_custom_transforms.py @@ -12,6 +12,8 @@ """ # %% +from typing import Any, Dict, List + import torch from torchvision import tv_tensors from torchvision.transforms import v2 @@ -89,33 +91,110 @@ def forward(self, img, bboxes, label): # we assume inputs are always structured # A key feature of the builtin Torchvision V2 transforms is that they can accept # arbitrary input structure and return the same structure as output (with # transformed entries). For example, transforms can accept a single image, or a -# tuple of ``(img, label)``, or an arbitrary nested dictionary as input: +# tuple of ``(img, label)``, or an arbitrary nested dictionary as input. Here's +# an example on the built-in transform :class:`~torchvision.transforms.v2.RandomHorizontalFlip`: structured_input = { "img": img, "annotations": (bboxes, label), - "something_that_will_be_ignored": (1, "hello") + "something that will be ignored": (1, "hello"), + "another tensor that is ignored": torch.arange(10), } structured_output = v2.RandomHorizontalFlip(p=1)(structured_input) assert isinstance(structured_output, dict) -assert structured_output["something_that_will_be_ignored"] == (1, "hello") +assert structured_output["something that will be ignored"] == (1, "hello") +assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all() +print(f"The input bboxes are:\n{structured_input['annotations'][0]}") +print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}") + +# %% +# Basics: override the `transform()` method +# ----------------------------------------- +# +# In order to support arbitrary inputs in your custom transform, you will need +# to inherit from :class:`~torchvision.transforms.v2.Transform` and override the +# `.transform()` method (not the `forward()` method!). Below is a basic example: + + +class MyCustomTransform(v2.Transform): + def transform(self, inpt: Any, params: Dict[str, Any]): + if type(inpt) == torch.Tensor: + print(f"I'm transforming an image of shape {inpt.shape}") + return inpt + 1 # dummy transformation + elif isinstance(inpt, tv_tensors.BoundingBoxes): + print(f"I'm transforming bounding boxes! {inpt.canvas_size = }") + return tv_tensors.wrap(inpt + 100, like=inpt) # dummy transformation + + +my_custom_transform = MyCustomTransform() +structured_output = my_custom_transform(structured_input) + +assert isinstance(structured_output, dict) +assert structured_output["something that will be ignored"] == (1, "hello") +assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all() +print(f"The input bboxes are:\n{structured_input['annotations'][0]}") print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}") # %% -# If you want to reproduce this behavior in your own transform, we invite you to -# look at our `code -# `_ -# and adapt it to your needs. -# -# In brief, the core logic is to unpack the input into a flat list using `pytree -# `_, and -# then transform only the entries that can be transformed (the decision is made -# based on the **class** of the entries, as all TVTensors are -# tensor-subclasses) plus some custom logic that is out of score here - check the -# code for details. The (potentially transformed) entries are then repacked and -# returned, in the same structure as the input. -# -# We do not provide public dev-facing tools to achieve that at this time, but if -# this is something that would be valuable to you, please let us know by opening -# an issue on our `GitHub repo `_. +# An important thing to note is that when we call ``my_custom_transform`` on +# ``structured_input``, the input is flattened and then each individual part is +# passed to ``transform()``. That is, ``transform()``` receives the input image, +# then the bounding boxes, etc. Within ``transform()``, you can decide how to +# transform each input, based on their type. +# +# If you're curious why the other tensor (``torch.arange()``) didn't get passed +# to ``transform()``, see :ref:`this note ` for more +# details. +# +# Advanced: The ``make_params()`` method +# -------------------------------------- +# +# The ``make_params()`` method is called internally before calling +# ``transform()`` on each input. This is typically useful to generate random +# parameter values. In the example below, we use it to randomly apply the +# transformation with a probability of 0.5 + + +class MyRandomTransform(MyCustomTransform): + def __init__(self, p=0.5): + self.p = p + super().__init__() + + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + apply_transform = (torch.rand(size=(1,)) < self.p).item() + params = dict(apply_transform=apply_transform) + return params + + def transform(self, inpt: Any, params: Dict[str, Any]): + if not params["apply_transform"]: + print("Not transforming anything!") + return inpt + else: + return super().transform(inpt, params) + + +my_random_transform = MyRandomTransform() + +torch.manual_seed(0) +_ = my_random_transform(structured_input) # transforms +_ = my_random_transform(structured_input) # doesn't transform + +# %% +# +# .. note:: +# +# It's important for such random parameter generation to happen within +# ``make_params()`` and not within ``transform()``, so that for a given +# transform call, the same RNG applies to all the inputs in the same way. If +# we were to perform the RNG within ``transform()``, we would risk e.g. +# transforming the image while *not* transforming the bounding boxes. +# +# The ``make_params()`` method takes the list of all the inputs as parameter +# (each of the elements in this list will later be pased to ``transform()``). +# You can use ``flat_inputs`` to e.g. figure out the dimensions on the input, +# using :func:`~torchvision.transforms.v2.query_chw` or +# :func:`~torchvision.transforms.v2.query_size`. +# +# ``make_params()`` should return a dict (or actually, anything you want) that +# will then be passed to ``transform()``. diff --git a/references/segmentation/v2_extras.py b/references/segmentation/v2_extras.py index e1a8b53e02b..2d9eb3e661a 100644 --- a/references/segmentation/v2_extras.py +++ b/references/segmentation/v2_extras.py @@ -10,13 +10,13 @@ def __init__(self, size, fill=0): self.size = size self.fill = v2._utils._setup_fill_arg(fill) - def _get_params(self, sample): + def make_params(self, sample): _, height, width = v2._utils.query_chw(sample) padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)] needs_padding = any(padding) return dict(padding=padding, needs_padding=needs_padding) - def _transform(self, inpt, params): + def transform(self, inpt, params): if not params["needs_padding"]: return inpt diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 3f2e5015863..85ef98cf7b8 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -159,7 +159,7 @@ def test__copy_paste(self, label_type): class TestFixedSizeCrop: - def test__get_params(self, mocker): + def test_make_params(self, mocker): crop_size = (7, 7) batch_shape = (10,) canvas_size = (11, 5) @@ -170,7 +170,7 @@ def test__get_params(self, mocker): make_image(size=canvas_size, color_space="RGB"), make_bounding_boxes(format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, num_boxes=batch_shape[0]), ] - params = transform._get_params(flat_inputs) + params = transform.make_params(flat_inputs) assert params["needs_crop"] assert params["height"] <= crop_size[0] @@ -191,7 +191,7 @@ def test__transform_culling(self, mocker): is_valid = torch.randint(0, 2, (batch_size,), dtype=torch.bool) mocker.patch( - "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params", + "torchvision.prototype.transforms._geometry.FixedSizeCrop.make_params", return_value=dict( needs_crop=True, top=0, @@ -229,7 +229,7 @@ def test__transform_bounding_boxes_clamping(self, mocker): canvas_size = (10, 10) mocker.patch( - "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params", + "torchvision.prototype.transforms._geometry.FixedSizeCrop.make_params", return_value=dict( needs_crop=True, top=0, diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index e16c0677c9f..fb49525ecfe 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1355,7 +1355,7 @@ def test_transform_bounding_boxes_correctness(self, format, center, seed): transform = transforms.RandomAffine(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES, center=center) torch.manual_seed(seed) - params = transform._get_params([bounding_boxes]) + params = transform.make_params([bounding_boxes]) torch.manual_seed(seed) actual = transform(bounding_boxes) @@ -1369,14 +1369,14 @@ def test_transform_bounding_boxes_correctness(self, format, center, seed): @pytest.mark.parametrize("scale", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["scale"]) @pytest.mark.parametrize("shear", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["shear"]) @pytest.mark.parametrize("seed", list(range(10))) - def test_transform_get_params_bounds(self, degrees, translate, scale, shear, seed): + def test_transformmake_params_bounds(self, degrees, translate, scale, shear, seed): image = make_image() height, width = F.get_size(image) transform = transforms.RandomAffine(degrees=degrees, translate=translate, scale=scale, shear=shear) torch.manual_seed(seed) - params = transform._get_params([image]) + params = transform.make_params([image]) if isinstance(degrees, (int, float)): assert -degrees <= params["angle"] <= degrees @@ -1783,7 +1783,7 @@ def test_transform_bounding_boxes_correctness(self, format, expand, center, seed transform = transforms.RandomRotation(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES, expand=expand, center=center) torch.manual_seed(seed) - params = transform._get_params([bounding_boxes]) + params = transform.make_params([bounding_boxes]) torch.manual_seed(seed) actual = transform(bounding_boxes) @@ -1795,11 +1795,11 @@ def test_transform_bounding_boxes_correctness(self, format, expand, center, seed @pytest.mark.parametrize("degrees", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["degrees"]) @pytest.mark.parametrize("seed", list(range(10))) - def test_transform_get_params_bounds(self, degrees, seed): + def test_transformmake_params_bounds(self, degrees, seed): transform = transforms.RandomRotation(degrees=degrees) torch.manual_seed(seed) - params = transform._get_params([]) + params = transform.make_params([]) if isinstance(degrees, (int, float)): assert -degrees <= params["angle"] <= degrees @@ -1843,7 +1843,7 @@ def test_functional_image_fast_path_correctness(self, size, angle, expand): class TestContainerTransforms: class BuiltinTransform(transforms.Transform): - def _transform(self, inpt, params): + def transform(self, inpt, params): return inpt class PackedInputTransform(nn.Module): @@ -2996,7 +2996,7 @@ def test_transform_bounding_boxes_correctness(self, output_size, format, dtype, with freeze_rng_state(): torch.manual_seed(seed) - params = transform._get_params([bounding_boxes]) + params = transform.make_params([bounding_boxes]) assert not params.pop("needs_pad") del params["padding"] assert params.pop("needs_crop") @@ -3129,9 +3129,9 @@ def test_transform_image_correctness(self, param, value, dtype, device, seed): with freeze_rng_state(): torch.manual_seed(seed) - # This emulates the random apply check that happens before _get_params is called + # This emulates the random apply check that happens before make_params is called torch.rand(1) - params = transform._get_params([image]) + params = transform.make_params([image]) torch.manual_seed(seed) actual = transform(image) @@ -3159,7 +3159,7 @@ def test_transform_errors(self): transform = transforms.RandomErasing(value=[1, 2, 3, 4]) with pytest.raises(ValueError, match="If value is a sequence, it should have either a single value"): - transform._get_params([make_image()]) + transform.make_params([make_image()]) class TestGaussianBlur: @@ -3244,9 +3244,9 @@ def test_assertions(self): transforms.GaussianBlur(3, sigma={}) @pytest.mark.parametrize("sigma", [10.0, [10.0, 12.0], (10, 12.0), [10]]) - def test__get_params(self, sigma): + def test_make_params(self, sigma): transform = transforms.GaussianBlur(3, sigma=sigma) - params = transform._get_params([]) + params = transform.make_params([]) if isinstance(sigma, float): assert params["sigma"][0] == params["sigma"][1] == sigma @@ -5251,7 +5251,7 @@ def test_transform_params_correctness(self, side_range, make_input, device): input = make_input() height, width = F.get_size(input) - params = transform._get_params([input]) + params = transform.make_params([input]) assert "padding" in params padding = params["padding"] @@ -5305,13 +5305,13 @@ def test_transform(self, make_input, device): check_transform(transforms.ScaleJitter(self.TARGET_SIZE), make_input(self.INPUT_SIZE, device=device)) - def test__get_params(self): + def test_make_params(self): input_size = self.INPUT_SIZE target_size = self.TARGET_SIZE scale_range = (0.5, 1.5) transform = transforms.ScaleJitter(target_size=target_size, scale_range=scale_range) - params = transform._get_params([make_image(input_size)]) + params = transform.make_params([make_image(input_size)]) assert "size" in params size = params["size"] @@ -5544,7 +5544,7 @@ def split_on_pure_tensor(to_split): return pure_tensors[0] if pure_tensors else None, pure_tensors[1:], others class CopyCloneTransform(transforms.Transform): - def _transform(self, inpt, params): + def transform(self, inpt, params): return inpt.clone() if isinstance(inpt, torch.Tensor) else inpt.copy() @staticmethod @@ -5580,7 +5580,7 @@ def was_applied(output, inpt): class TestRandomIoUCrop: @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]]) - def test__get_params(self, device, options): + def test_make_params(self, device, options): orig_h, orig_w = size = (24, 32) image = make_image(size) bboxes = tv_tensors.BoundingBoxes( @@ -5596,7 +5596,7 @@ def test__get_params(self, device, options): n_samples = 5 for _ in range(n_samples): - params = transform._get_params(sample) + params = transform.make_params(sample) if options == [2.0]: assert len(params) == 0 @@ -5622,8 +5622,8 @@ def test__transform_empty_params(self, mocker): bboxes = tv_tensors.BoundingBoxes(torch.tensor([[1, 1, 2, 2]]), format="XYXY", canvas_size=(4, 4)) label = torch.tensor([1]) sample = [image, bboxes, label] - # Let's mock transform._get_params to control the output: - transform._get_params = mocker.MagicMock(return_value={}) + # Let's mock transform.make_params to control the output: + transform.make_params = mocker.MagicMock(return_value={}) output = transform(sample) torch.testing.assert_close(output, sample) @@ -5648,7 +5648,7 @@ def test__transform(self, mocker): is_within_crop_area = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool) params = dict(top=1, left=2, height=12, width=12, is_within_crop_area=is_within_crop_area) - transform._get_params = mocker.MagicMock(return_value=params) + transform.make_params = mocker.MagicMock(return_value=params) output = transform(sample) # check number of bboxes vs number of labels: @@ -5662,13 +5662,13 @@ def test__transform(self, mocker): class TestRandomShortestSize: @pytest.mark.parametrize("min_size,max_size", [([5, 9], 20), ([5, 9], None)]) - def test__get_params(self, min_size, max_size): + def test_make_params(self, min_size, max_size): canvas_size = (3, 10) transform = transforms.RandomShortestSize(min_size=min_size, max_size=max_size, antialias=True) sample = make_image(canvas_size) - params = transform._get_params([sample]) + params = transform.make_params([sample]) assert "size" in params size = params["size"] @@ -5685,14 +5685,14 @@ def test__get_params(self, min_size, max_size): class TestRandomResize: - def test__get_params(self): + def test_make_params(self): min_size = 3 max_size = 6 transform = transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True) for _ in range(10): - params = transform._get_params([]) + params = transform.make_params([]) assert isinstance(params["size"], list) and len(params["size"]) == 1 size = params["size"][0] @@ -6148,12 +6148,12 @@ def test_transform_image_correctness(self, quality, color_space, seed): @pytest.mark.parametrize("quality", [5, (10, 20)]) @pytest.mark.parametrize("seed", list(range(10))) - def test_transform_get_params_bounds(self, quality, seed): + def test_transformmake_params_bounds(self, quality, seed): transform = transforms.JPEG(quality=quality) with freeze_rng_state(): torch.manual_seed(seed) - params = transform._get_params([]) + params = transform.make_params([]) if isinstance(quality, int): assert params["quality"] == quality diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index b04e1fe5a2a..e7c501aabe0 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -35,7 +35,7 @@ def __init__( self.padding_mode = padding_mode - def _check_inputs(self, flat_inputs: List[Any]) -> None: + def check_inputs(self, flat_inputs: List[Any]) -> None: if not has_any( flat_inputs, PIL.Image.Image, @@ -53,7 +53,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None: f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel." ) - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: height, width = query_size(flat_inputs) new_height = min(height, self.crop_height) new_width = min(width, self.crop_width) @@ -107,7 +107,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: needs_pad=needs_pad, ) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["needs_crop"]: inpt = self._call_kernel( F.crop, diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index bab2c70812e..6ea6256b171 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -39,7 +39,7 @@ def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]] ) self.dims = dims - def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor: + def transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor: dims = self.dims[type(inpt)] if dims is None: return inpt.as_subclass(torch.Tensor) @@ -61,7 +61,7 @@ def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, i ) self.dims = dims - def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor: + def transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor: dims = self.dims[type(inpt)] if dims is None: return inpt.as_subclass(torch.Tensor) diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py index 3532abb3759..025cd13a766 100644 --- a/torchvision/prototype/transforms/_type_conversion.py +++ b/torchvision/prototype/transforms/_type_conversion.py @@ -15,7 +15,7 @@ def __init__(self, num_categories: int = -1): super().__init__() self.num_categories = num_categories - def _transform(self, inpt: proto_tv_tensors.Label, params: Dict[str, Any]) -> proto_tv_tensors.OneHotLabel: + def transform(self, inpt: proto_tv_tensors.Label, params: Dict[str, Any]) -> proto_tv_tensors.OneHotLabel: num_categories = self.num_categories if num_categories == -1 and inpt.categories is not None: num_categories = len(inpt.categories) diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index b1dd5083408..93d4ba45d65 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -96,7 +96,7 @@ def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: An ) return super()._call_kernel(functional, inpt, *args, **kwargs) - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: img_c, img_h, img_w = query_chw(flat_inputs) if self.value is not None and not (len(self.value) in (1, img_c)): @@ -134,7 +134,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(i=i, j=j, h=h, w=w, v=v) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["v"] is not None: inpt = self._call_kernel(F.erase, inpt, **params, inplace=self.inplace) @@ -181,7 +181,7 @@ def forward(self, *inputs): params = { "labels": labels, "batch_size": labels.shape[0], - **self._get_params( + **self.make_params( [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform] ), } @@ -190,7 +190,7 @@ def forward(self, *inputs): # after an image or video. However, we need to handle them in _transform, so we make sure to set them to True needs_transform_list[next(idx for idx, inpt in enumerate(flat_inputs) if inpt is labels)] = True flat_outputs = [ - self._transform(inpt, params) if needs_transform else inpt + self.transform(inpt, params) if needs_transform else inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) ] @@ -243,10 +243,10 @@ class MixUp(_BaseMixUpCutMix): It can also be a callable that takes the same input as the transform, and returns the labels. """ - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(lam=float(self._dist.sample(()))) # type: ignore[arg-type] - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: lam = params["lam"] if inpt is params["labels"]: @@ -292,7 +292,7 @@ class CutMix(_BaseMixUpCutMix): It can also be a callable that takes the same input as the transform, and returns the labels. """ - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: lam = float(self._dist.sample(())) # type: ignore[arg-type] H, W = query_size(flat_inputs) @@ -314,7 +314,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(box=box, lam_adjusted=lam_adjusted) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if inpt is params["labels"]: return self._mixup_label(inpt, lam=params["lam_adjusted"]) elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt): @@ -361,9 +361,9 @@ def __init__(self, quality: Union[int, Sequence[int]]): self.quality = quality - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: quality = torch.randint(self.quality[0], self.quality[1] + 1, ()).item() return dict(quality=quality) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.jpeg, inpt, quality=params["quality"]) diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py index 49b4a8d8b10..7a471e7c1f6 100644 --- a/torchvision/transforms/v2/_color.py +++ b/torchvision/transforms/v2/_color.py @@ -25,7 +25,7 @@ def __init__(self, num_output_channels: int = 1): super().__init__() self.num_output_channels = num_output_channels - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=self.num_output_channels) @@ -46,11 +46,11 @@ class RandomGrayscale(_RandomApplyTransform): def __init__(self, p: float = 0.1) -> None: super().__init__(p=p) - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: num_input_channels, *_ = query_chw(flat_inputs) return dict(num_input_channels=num_input_channels) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=params["num_input_channels"]) @@ -64,7 +64,7 @@ class RGB(Transform): def __init__(self): super().__init__() - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.grayscale_to_rgb, inpt) @@ -142,7 +142,7 @@ def _check_input( def _generate_value(left: float, right: float) -> float: return torch.empty(1).uniform_(left, right).item() - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: fn_idx = torch.randperm(4) b = None if self.brightness is None else self._generate_value(self.brightness[0], self.brightness[1]) @@ -152,7 +152,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(fn_idx=fn_idx, brightness_factor=b, contrast_factor=c, saturation_factor=s, hue_factor=h) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: output = inpt brightness_factor = params["brightness_factor"] contrast_factor = params["contrast_factor"] @@ -173,11 +173,11 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class RandomChannelPermutation(Transform): """Randomly permute the channels of an image or video""" - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: num_channels, *_ = query_chw(flat_inputs) return dict(permutation=torch.randperm(num_channels)) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.permute_channels, inpt, params["permutation"]) @@ -220,7 +220,7 @@ def __init__( self.saturation = saturation self.p = p - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: num_channels, *_ = query_chw(flat_inputs) params: Dict[str, Any] = { key: ColorJitter._generate_value(range[0], range[1]) if torch.rand(1) < self.p else None @@ -235,7 +235,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None return params - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["brightness_factor"] is not None: inpt = self._call_kernel(F.adjust_brightness, inpt, brightness_factor=params["brightness_factor"]) if params["contrast_factor"] is not None and params["contrast_before"]: @@ -264,7 +264,7 @@ class RandomEqualize(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomEqualize - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.equalize, inpt) @@ -281,7 +281,7 @@ class RandomInvert(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomInvert - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.invert, inpt) @@ -304,7 +304,7 @@ def __init__(self, bits: int, p: float = 0.5) -> None: super().__init__(p=p) self.bits = bits - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.posterize, inpt, bits=self.bits) @@ -332,7 +332,7 @@ def __init__(self, threshold: float, p: float = 0.5) -> None: super().__init__(p=p) self.threshold = threshold - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.solarize, inpt, threshold=self.threshold) @@ -349,7 +349,7 @@ class RandomAutocontrast(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomAutocontrast - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.autocontrast, inpt) @@ -372,5 +372,5 @@ def __init__(self, sharpness_factor: float, p: float = 0.5) -> None: super().__init__(p=p) self.sharpness_factor = sharpness_factor - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=self.sharpness_factor) diff --git a/torchvision/transforms/v2/_deprecated.py b/torchvision/transforms/v2/_deprecated.py index a664cb3fbbd..51a4f076e49 100644 --- a/torchvision/transforms/v2/_deprecated.py +++ b/torchvision/transforms/v2/_deprecated.py @@ -46,5 +46,5 @@ def __init__(self) -> None: ) super().__init__() - def _transform(self, inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str, Any]) -> torch.Tensor: + def transform(self, inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str, Any]) -> torch.Tensor: return _F.to_tensor(inpt) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 5d6b1841d7f..c2461418a42 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -44,7 +44,7 @@ class RandomHorizontalFlip(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomHorizontalFlip - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.horizontal_flip, inpt) @@ -62,7 +62,7 @@ class RandomVerticalFlip(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomVerticalFlip - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.vertical_flip, inpt) @@ -156,7 +156,7 @@ def __init__( self.max_size = max_size self.antialias = antialias - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel( F.resize, inpt, @@ -189,7 +189,7 @@ def __init__(self, size: Union[int, Sequence[int]]): super().__init__() self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.center_crop, inpt, output_size=self.size) @@ -268,7 +268,7 @@ def __init__( self._log_ratio = torch.log(torch.tensor(self.ratio)) - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: height, width = query_size(flat_inputs) area = height * width @@ -306,7 +306,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(top=i, left=j, height=h, width=w) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel( F.resized_crop, inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias ) @@ -363,10 +363,10 @@ def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: An ) return super()._call_kernel(functional, inpt, *args, **kwargs) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.five_crop, inpt, self.size) - def _check_inputs(self, flat_inputs: List[Any]) -> None: + def check_inputs(self, flat_inputs: List[Any]) -> None: if has_any(flat_inputs, tv_tensors.BoundingBoxes, tv_tensors.Mask): raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()") @@ -408,11 +408,11 @@ def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: An ) return super()._call_kernel(functional, inpt, *args, **kwargs) - def _check_inputs(self, flat_inputs: List[Any]) -> None: + def check_inputs(self, flat_inputs: List[Any]) -> None: if has_any(flat_inputs, tv_tensors.BoundingBoxes, tv_tensors.Mask): raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()") - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.ten_crop, inpt, self.size, vertical_flip=self.vertical_flip) @@ -483,7 +483,7 @@ def __init__( self._fill = _setup_fill_arg(fill) self.padding_mode = padding_mode - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = _get_fill(self._fill, type(inpt)) return self._call_kernel(F.pad, inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type] @@ -535,7 +535,7 @@ def __init__( if side_range[0] < 1.0 or side_range[0] > side_range[1]: raise ValueError(f"Invalid side range provided {side_range}.") - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: orig_h, orig_w = query_size(flat_inputs) r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0]) @@ -551,7 +551,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(padding=padding) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = _get_fill(self._fill, type(inpt)) return self._call_kernel(F.pad, inpt, **params, fill=fill) @@ -618,11 +618,11 @@ def __init__( self.center = center - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item() return dict(angle=angle) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = _get_fill(self._fill, type(inpt)) return self._call_kernel( F.rotate, @@ -716,7 +716,7 @@ def __init__( self.center = center - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: height, width = query_size(flat_inputs) angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item() @@ -743,7 +743,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: shear = (shear_x, shear_y) return dict(angle=angle, translate=translate, scale=scale, shear=shear) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = _get_fill(self._fill, type(inpt)) return self._call_kernel( F.affine, @@ -839,7 +839,7 @@ def __init__( self._fill = _setup_fill_arg(fill) self.padding_mode = padding_mode - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: padded_height, padded_width = query_size(flat_inputs) if self.padding is not None: @@ -897,7 +897,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: padding=padding, ) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["needs_pad"]: fill = _get_fill(self._fill, type(inpt)) inpt = self._call_kernel(F.pad, inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode) @@ -952,7 +952,7 @@ def __init__( self.fill = fill self._fill = _setup_fill_arg(fill) - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: height, width = query_size(flat_inputs) distortion_scale = self.distortion_scale @@ -982,7 +982,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: perspective_coeffs = _get_perspective_coeffs(startpoints, endpoints) return dict(coefficients=perspective_coeffs) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = _get_fill(self._fill, type(inpt)) return self._call_kernel( F.perspective, @@ -1051,7 +1051,7 @@ def __init__( self.fill = fill self._fill = _setup_fill_arg(fill) - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: size = list(query_size(flat_inputs)) dx = torch.rand([1, 1] + size) * 2 - 1 @@ -1074,7 +1074,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) # 1 x H x W x 2 return dict(displacement=displacement) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = _get_fill(self._fill, type(inpt)) return self._call_kernel( F.elastic, @@ -1132,7 +1132,7 @@ def __init__( self.options = sampler_options self.trials = trials - def _check_inputs(self, flat_inputs: List[Any]) -> None: + def check_inputs(self, flat_inputs: List[Any]) -> None: if not ( has_all(flat_inputs, tv_tensors.BoundingBoxes) and has_any(flat_inputs, PIL.Image.Image, tv_tensors.Image, is_pure_tensor) @@ -1142,7 +1142,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None: "and bounding boxes. Sample can also contain masks." ) - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: orig_h, orig_w = query_size(flat_inputs) bboxes = get_bounding_boxes(flat_inputs) @@ -1194,7 +1194,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if len(params) < 1: return inpt @@ -1262,7 +1262,7 @@ def __init__( self.interpolation = interpolation self.antialias = antialias - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: orig_height, orig_width = query_size(flat_inputs) scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0]) @@ -1272,7 +1272,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(size=(new_height, new_width)) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel( F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias ) @@ -1327,7 +1327,7 @@ def __init__( self.interpolation = interpolation self.antialias = antialias - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: orig_height, orig_width = query_size(flat_inputs) min_size = self.min_size[int(torch.randint(len(self.min_size), ()))] @@ -1340,7 +1340,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(size=(new_height, new_width)) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel( F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias ) @@ -1406,11 +1406,11 @@ def __init__( self.interpolation = interpolation self.antialias = antialias - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: size = int(torch.randint(self.min_size, self.max_size, ())) return dict(size=[size]) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel( F.resize, inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias ) diff --git a/torchvision/transforms/v2/_meta.py b/torchvision/transforms/v2/_meta.py index 01a356f46f5..1890b43115a 100644 --- a/torchvision/transforms/v2/_meta.py +++ b/torchvision/transforms/v2/_meta.py @@ -19,7 +19,7 @@ def __init__(self, format: Union[str, tv_tensors.BoundingBoxFormat]) -> None: super().__init__() self.format = format - def _transform(self, inpt: tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes: + def transform(self, inpt: tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes: return F.convert_bounding_box_format(inpt, new_format=self.format) # type: ignore[return-value, arg-type] @@ -32,5 +32,5 @@ class ClampBoundingBoxes(Transform): _transformed_types = (tv_tensors.BoundingBoxes,) - def _transform(self, inpt: tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes: + def transform(self, inpt: tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes: return F.clamp_bounding_boxes(inpt) # type: ignore[return-value] diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 93198f0009d..d38a6ad8767 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -14,7 +14,7 @@ # TODO: do we want/need to expose this? class Identity(Transform): - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return inpt @@ -34,7 +34,7 @@ def __init__(self, lambd: Callable[[Any], Any], *types: Type): self.lambd = lambd self.types = types or self._transformed_types - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if isinstance(inpt, self.types): return self.lambd(inpt) else: @@ -99,11 +99,11 @@ def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tenso self.transformation_matrix = transformation_matrix self.mean_vector = mean_vector - def _check_inputs(self, sample: Any) -> Any: + def check_inputs(self, sample: Any) -> Any: if has_any(sample, PIL.Image.Image): raise TypeError(f"{type(self).__name__}() does not support PIL images.") - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: shape = inpt.shape n = shape[-3] * shape[-2] * shape[-1] if n != self.transformation_matrix.shape[0]: @@ -157,11 +157,11 @@ def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = self.std = list(std) self.inplace = inplace - def _check_inputs(self, sample: Any) -> Any: + def check_inputs(self, sample: Any) -> Any: if has_any(sample, PIL.Image.Image): raise TypeError(f"{type(self).__name__}() does not support PIL images.") - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.normalize, inpt, mean=self.mean, std=self.std, inplace=self.inplace) @@ -197,11 +197,11 @@ def __init__( if not 0.0 < self.sigma[0] <= self.sigma[1]: raise ValueError(f"sigma values should be positive and of the form (min, max). Got {self.sigma}") - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item() return dict(sigma=[sigma, sigma]) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.gaussian_blur, inpt, self.kernel_size, **params) @@ -228,7 +228,7 @@ def __init__(self, mean: float = 0.0, sigma: float = 0.1, clip=True) -> None: self.sigma = sigma self.clip = clip - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.gaussian_noise, inpt, mean=self.mean, sigma=self.sigma, clip=self.clip) @@ -272,7 +272,7 @@ def __init__( self.dtype = dtype self.scale = scale - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if isinstance(self.dtype, torch.dtype): # For consistency / BC with ConvertImageDtype, we only care about images or videos when dtype # is a simple torch.dtype @@ -335,7 +335,7 @@ def __init__(self, dtype: torch.dtype = torch.float32) -> None: super().__init__() self.dtype = dtype - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.to_dtype, inpt, dtype=self.dtype, scale=True) @@ -432,11 +432,11 @@ def forward(self, *inputs: Any) -> Any: ) params = dict(valid=valid, labels=labels) - flat_outputs = [self._transform(inpt, params) for inpt in flat_inputs] + flat_outputs = [self.transform(inpt, params) for inpt in flat_inputs] return tree_unflatten(flat_outputs, spec) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: is_label = params["labels"] is not None and any(inpt is label for label in params["labels"]) is_bounding_boxes_or_mask = isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)) diff --git a/torchvision/transforms/v2/_temporal.py b/torchvision/transforms/v2/_temporal.py index c59d5078d46..687b50188a8 100644 --- a/torchvision/transforms/v2/_temporal.py +++ b/torchvision/transforms/v2/_temporal.py @@ -22,5 +22,5 @@ def __init__(self, num_samples: int): super().__init__() self.num_samples = num_samples - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.uniform_temporal_subsample, inpt, self.num_samples) diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index b7eced5a287..5f274589709 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -15,6 +15,11 @@ class Transform(nn.Module): + """Base class to implement your own v2 transforms. + + See :ref:`sphx_glr_auto_examples_transforms_plot_custom_transforms.py` for + more details. + """ # Class attribute defining transformed types. Other types are passed-through without any transformation # We support both Types and callables that are able to do further checks on the type of the input. @@ -24,31 +29,44 @@ def __init__(self) -> None: super().__init__() _log_api_usage_once(self) - def _check_inputs(self, flat_inputs: List[Any]) -> None: + def check_inputs(self, flat_inputs: List[Any]) -> None: pass - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + # When v2 was introduced, this method was private and called + # `_get_params()`. Now it's publicly exposed as `make_params()`. It cannot + # be exposed as `get_params()` because there is already a `get_params()` + # methods for v2 transforms: it's the v1's `get_params()` that we have to + # keep in order to guarantee 100% BC with v1. (It's defined in + # __init_subclass__ below). + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + """Method to override for custom transforms. + + See :ref:`sphx_glr_auto_examples_transforms_plot_custom_transforms.py`""" return dict() def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: kernel = _get_kernel(functional, type(inpt), allow_passthrough=True) return kernel(inpt, *args, **kwargs) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + """Method to override for custom transforms. + + See :ref:`sphx_glr_auto_examples_transforms_plot_custom_transforms.py`""" raise NotImplementedError def forward(self, *inputs: Any) -> Any: + """Do not override this! Use ``transform()`` instead.""" flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0]) - self._check_inputs(flat_inputs) + self.check_inputs(flat_inputs) needs_transform_list = self._needs_transform_list(flat_inputs) - params = self._get_params( + params = self.make_params( [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform] ) flat_outputs = [ - self._transform(inpt, params) if needs_transform else inpt + self.transform(inpt, params) if needs_transform else inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) ] @@ -153,23 +171,23 @@ def __init__(self, p: float = 0.5) -> None: def forward(self, *inputs: Any) -> Any: # We need to almost duplicate `Transform.forward()` here since we always want to check the inputs, but return # early afterwards in case the random check triggers. The same result could be achieved by calling - # `super().forward()` after the random check, but that would call `self._check_inputs` twice. + # `super().forward()` after the random check, but that would call `self.check_inputs` twice. inputs = inputs if len(inputs) > 1 else inputs[0] flat_inputs, spec = tree_flatten(inputs) - self._check_inputs(flat_inputs) + self.check_inputs(flat_inputs) if torch.rand(1) >= self.p: return inputs needs_transform_list = self._needs_transform_list(flat_inputs) - params = self._get_params( + params = self.make_params( [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform] ) flat_outputs = [ - self._transform(inpt, params) if needs_transform else inpt + self.transform(inpt, params) if needs_transform else inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) ] diff --git a/torchvision/transforms/v2/_type_conversion.py b/torchvision/transforms/v2/_type_conversion.py index 7c7439b1d02..bf9f7185239 100644 --- a/torchvision/transforms/v2/_type_conversion.py +++ b/torchvision/transforms/v2/_type_conversion.py @@ -20,7 +20,7 @@ class PILToTensor(Transform): _transformed_types = (PIL.Image.Image,) - def _transform(self, inpt: PIL.Image.Image, params: Dict[str, Any]) -> torch.Tensor: + def transform(self, inpt: PIL.Image.Image, params: Dict[str, Any]) -> torch.Tensor: return F.pil_to_tensor(inpt) @@ -33,7 +33,7 @@ class ToImage(Transform): _transformed_types = (is_pure_tensor, PIL.Image.Image, np.ndarray) - def _transform( + def transform( self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] ) -> tv_tensors.Image: return F.to_image(inpt) @@ -66,7 +66,7 @@ def __init__(self, mode: Optional[str] = None) -> None: super().__init__() self.mode = mode - def _transform( + def transform( self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] ) -> PIL.Image.Image: return F.to_pil_image(inpt, mode=self.mode) @@ -80,5 +80,5 @@ class ToPureTensor(Transform): _transformed_types = (tv_tensors.TVTensor,) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor: + def transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor: return inpt.as_subclass(torch.Tensor) diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index e7cde4c5c33..dd65ca4d9c9 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -151,6 +151,10 @@ def _parse_labels_getter(labels_getter: Union[str, Callable[[Any], Any], None]) def get_bounding_boxes(flat_inputs: List[Any]) -> tv_tensors.BoundingBoxes: + """Return the Bounding Boxes in the input. + + Assumes only one ``BoundingBoxes`` object is present. + """ # This assumes there is only one bbox per sample as per the general convention try: return next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.BoundingBoxes)) @@ -159,6 +163,7 @@ def get_bounding_boxes(flat_inputs: List[Any]) -> tv_tensors.BoundingBoxes: def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]: + """Return Channel, Height, and Width.""" chws = { tuple(get_dimensions(inpt)) for inpt in flat_inputs @@ -173,6 +178,7 @@ def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]: def query_size(flat_inputs: List[Any]) -> Tuple[int, int]: + """Return Height and Width.""" sizes = { tuple(get_size(inpt)) for inpt in flat_inputs