From c014ad25dd93e5e9627777632cf5021829e172a4 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Mon, 29 Jul 2024 21:28:55 +0400 Subject: [PATCH 1/3] Ignore invalid keys instead of raising an error --- kornia/augmentation/container/augment.py | 35 ++++++++++++++++-------- tests/augmentation/test_container.py | 4 ++- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/kornia/augmentation/container/augment.py b/kornia/augmentation/container/augment.py index 26ce74cdc9..20dca73836 100644 --- a/kornia/augmentation/container/augment.py +++ b/kornia/augmentation/container/augment.py @@ -291,7 +291,7 @@ def inverse( # type: ignore[override] """ original_keys = None if len(args) == 1 and isinstance(args[0], dict): - original_keys, data_keys, args = self._preproc_dict_data(args[0]) + original_keys, data_keys, args, invalid_data = self._preproc_dict_data(args[0]) # args here should already be `DataType` # NOTE: how to right type to: unpacked args <-> tuple of args to unpack @@ -324,7 +324,10 @@ def inverse( # type: ignore[override] outputs = self._arguments_postproc(args, outputs, data_keys=self.transform_op.data_keys) # type: ignore if isinstance(original_keys, tuple): - return {k: v for v, k in zip(outputs, original_keys)} + result = {k: v for v, k in zip(outputs, original_keys)} + if invalid_data: + result.update(invalid_data) + return result if len(outputs) == 1 and isinstance(outputs, list): return outputs[0] @@ -414,7 +417,7 @@ def forward( # type: ignore[override] # Unpack/handle dictionary args original_keys = None if len(args) == 1 and isinstance(args[0], dict): - original_keys, data_keys, args = self._preproc_dict_data(args[0]) + original_keys, data_keys, args, invalid_data = self._preproc_dict_data(args[0]) self.transform_op.data_keys = self.transform_op.preproc_datakeys(data_keys) @@ -455,7 +458,10 @@ def forward( # type: ignore[override] self._params = params if isinstance(original_keys, tuple): - return {k: v for v, k in zip(outputs, original_keys)} + result = {k: v for v, k in zip(outputs, original_keys)} + if invalid_data: + result.update(invalid_data) + return result if len(outputs) == 1 and isinstance(outputs, list): return outputs[0] @@ -464,15 +470,17 @@ def forward( # type: ignore[override] def _preproc_dict_data( self, data: Dict[str, DataType] - ) -> Tuple[Tuple[str, ...], List[DataKey], Tuple[DataType, ...]]: + ) -> Tuple[Tuple[str, ...], List[DataKey], Tuple[DataType, ...], Optional[Dict[str, Any]]]: if self.data_keys is not None: raise ValueError("If you are using a dictionary as input, the data_keys should be None.") keys = tuple(data.keys()) - data_keys = self._read_datakeys_from_dict(keys) + data_keys, invalid_keys = self._read_datakeys_from_dict(keys) + invalid_data = {i: data.pop(i) for i in invalid_keys} if invalid_keys else None + keys = tuple(k for k in keys if k not in invalid_keys) if invalid_keys else keys data_unpacked = tuple(data.values()) - return keys, data_keys, data_unpacked + return keys, data_keys, data_unpacked, invalid_data def _read_datakeys_from_dict(self, keys: Sequence[str]) -> List[DataKey]: def retrieve_key(key: str) -> DataKey: @@ -487,12 +495,15 @@ def retrieve_key(key: str) -> DataKey: if key.upper().startswith(dk.name): return DataKey.get(dk.name) - allowed_dk = " | ".join(f"`{d.name}`" for d in DataKey) - raise ValueError( - f"Your input data dictionary keys should start with some of datakey values: {allowed_dk}. Got `{key}`" - ) + valid_data_keys = [] + invalid_keys = [] + for k in keys: + try: + valid_data_keys.append(DataKey.get(retrieve_key(k))) + except TypeError: + invalid_keys.append(k) - return [DataKey.get(retrieve_key(k)) for k in keys] + return valid_data_keys, invalid_keys def _preproc_mask(self, arg: MaskDataType) -> MaskDataType: if isinstance(arg, list): diff --git a/tests/augmentation/test_container.py b/tests/augmentation/test_container.py index 3d58a621f3..74f2dd1344 100644 --- a/tests/augmentation/test_container.py +++ b/tests/augmentation/test_container.py @@ -776,13 +776,14 @@ def test_dict_as_input_forward_and_inverse(self, random_apply, bbox_key, device, random_apply=random_apply, ) - data = {"input": inp, "mask": mask, bbox_key: bbox, "keypoints": keypoints} + data = {"input": inp, "mask": mask, bbox_key: bbox, "keypoints": keypoints, "invalid": 45} out = aug(data) assert out["input"].shape == inp.shape assert out["mask"].shape == mask.shape assert out[bbox_key].shape == bbox.shape assert out["keypoints"].shape == keypoints.shape assert set(out["mask"].unique().tolist()).issubset(set(mask.unique().tolist())) + assert out["invalid"] == 45 out_inv = aug.inverse(out) assert out_inv["input"].shape == inp.shape @@ -790,6 +791,7 @@ def test_dict_as_input_forward_and_inverse(self, random_apply, bbox_key, device, assert out_inv[bbox_key].shape == bbox.shape assert out_inv["keypoints"].shape == keypoints.shape assert set(out_inv["mask"].unique().tolist()).issubset(set(mask.unique().tolist())) + assert out_inv["invalid"] == 45 if random_apply is False: reproducibility_test(data, aug) From d4ab78d23ca5c603027986221725975ad4ac4199 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Mon, 29 Jul 2024 22:02:46 +0400 Subject: [PATCH 2/3] Add doc --- kornia/augmentation/container/augment.py | 5 +++-- tests/augmentation/test_container.py | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/kornia/augmentation/container/augment.py b/kornia/augmentation/container/augment.py index 20dca73836..aaf3a5fd33 100644 --- a/kornia/augmentation/container/augment.py +++ b/kornia/augmentation/container/augment.py @@ -175,8 +175,9 @@ class AugmentationSequential(TransformMatrixMinIn, ImageSequential): ... ) >>> out = aug_list(input, mask, bbox) - How to use a dictionary as input with AugmentationSequential? The dictionary should starts with - one of the datakey availables. + How to use a dictionary as input with AugmentationSequential? The dictionary keys that start with + one of the available datakeys will be augmented accordingly. Otherwise, the dictionary item is passed + without any augmentation. >>> import kornia.augmentation as K >>> img = torch.randn(1, 3, 256, 256) diff --git a/tests/augmentation/test_container.py b/tests/augmentation/test_container.py index 74f2dd1344..f0dcff64b1 100644 --- a/tests/augmentation/test_container.py +++ b/tests/augmentation/test_container.py @@ -776,14 +776,14 @@ def test_dict_as_input_forward_and_inverse(self, random_apply, bbox_key, device, random_apply=random_apply, ) - data = {"input": inp, "mask": mask, bbox_key: bbox, "keypoints": keypoints, "invalid": 45} + data = {"input": inp, "mask": mask, bbox_key: bbox, "keypoints": keypoints, "id": 45} out = aug(data) assert out["input"].shape == inp.shape assert out["mask"].shape == mask.shape assert out[bbox_key].shape == bbox.shape assert out["keypoints"].shape == keypoints.shape assert set(out["mask"].unique().tolist()).issubset(set(mask.unique().tolist())) - assert out["invalid"] == 45 + assert out["id"] == 45 out_inv = aug.inverse(out) assert out_inv["input"].shape == inp.shape @@ -791,7 +791,7 @@ def test_dict_as_input_forward_and_inverse(self, random_apply, bbox_key, device, assert out_inv[bbox_key].shape == bbox.shape assert out_inv["keypoints"].shape == keypoints.shape assert set(out_inv["mask"].unique().tolist()).issubset(set(mask.unique().tolist())) - assert out_inv["invalid"] == 45 + assert out_inv["id"] == 45 if random_apply is False: reproducibility_test(data, aug) From 37494fb73d33bb5ebcabcba47a28c6486bd38d93 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Mon, 29 Jul 2024 22:16:19 +0400 Subject: [PATCH 3/3] Raise error --- kornia/augmentation/container/augment.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/kornia/augmentation/container/augment.py b/kornia/augmentation/container/augment.py index aaf3a5fd33..914a1e29d4 100644 --- a/kornia/augmentation/container/augment.py +++ b/kornia/augmentation/container/augment.py @@ -483,7 +483,7 @@ def _preproc_dict_data( return keys, data_keys, data_unpacked, invalid_data - def _read_datakeys_from_dict(self, keys: Sequence[str]) -> List[DataKey]: + def _read_datakeys_from_dict(self, keys: Sequence[str]) -> Tuple[List[DataKey], Optional[List[str]]]: def retrieve_key(key: str) -> DataKey: """Try to retrieve the datakey value by matching `*`""" # Alias cases, like INPUT, will not be get by the enum iterator. @@ -496,12 +496,17 @@ def retrieve_key(key: str) -> DataKey: if key.upper().startswith(dk.name): return DataKey.get(dk.name) + allowed_dk = " | ".join(f"`{d.name}`" for d in DataKey) + raise ValueError( + f"Your input data dictionary keys should start with some of datakey values: {allowed_dk}. Got `{key}`" + ) + valid_data_keys = [] invalid_keys = [] for k in keys: try: valid_data_keys.append(DataKey.get(retrieve_key(k))) - except TypeError: + except ValueError: invalid_keys.append(k) return valid_data_keys, invalid_keys