From d35f07e9e04d833b9ad8f9d3fa4b4246110658a7 Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Sat, 4 May 2024 22:41:28 -0700 Subject: [PATCH] Fix serialization and logic for checking for embedded images --- sleap_io/io/slp.py | 22 ++++++++++++++++++---- sleap_io/io/video.py | 5 ++--- tests/io/test_slp.py | 40 +++++++++++++++++----------------------- 3 files changed, 37 insertions(+), 30 deletions(-) diff --git a/sleap_io/io/slp.py b/sleap_io/io/slp.py index 9d11d3ab..0ac3d252 100644 --- a/sleap_io/io/slp.py +++ b/sleap_io/io/slp.py @@ -25,9 +25,15 @@ read_hdf5_attrs, read_hdf5_dataset, ) -import imageio.v3 as iio from enum import IntEnum from pathlib import Path +import imageio.v3 as iio +import sys + +try: + import cv2 +except ImportError: + pass class InstanceType(IntEnum): @@ -214,9 +220,17 @@ def embed_video( if image_format == "hdf5": img_data = frame else: - img_data = iio.imwrite( - "", frame, extension="." + image_format - ).astype("int8") + if "cv2" in sys.modules: + img_data = np.squeeze( + cv2.imencode("." + image_format, frame)[1] + ).astype("int8") + else: + img_data = np.frombuffer( + iio.imwrite( + "", frame.squeeze(axis=-1), extension="." + image_format + ), + dtype="int8", + ) imgs_data.append(img_data) diff --git a/sleap_io/io/video.py b/sleap_io/io/video.py index 9323a696..fae09234 100644 --- a/sleap_io/io/video.py +++ b/sleap_io/io/video.py @@ -585,13 +585,12 @@ def has_embedded_images(self) -> bool: """Return True if the dataset contains embedded images.""" return self.image_format is not None and self.image_format != "hdf5" - def decode_embedded(self, img_string: np.ndarray, format: str) -> np.ndarray: + def decode_embedded(self, img_string: np.ndarray) -> np.ndarray: """Decode an embedded image string into a numpy array. Args: img_string: Binary string of the image as a `int8` numpy vector with the bytes as values corresponding to the format-encoded image. - format: Image format (e.g., "png" or "jpg"). Returns: The decoded image as a numpy array of shape `(height, width, channels)`. If @@ -604,7 +603,7 @@ def decode_embedded(self, img_string: np.ndarray, format: str) -> np.ndarray: if "cv2" in sys.modules: img = cv2.imdecode(img_string, cv2.IMREAD_UNCHANGED) else: - img = iio.imread(BytesIO(img_string), extension=f".{format}") + img = iio.imread(BytesIO(img_string), extension=f".{self.image_format}") if img.ndim == 2: img = np.expand_dims(img, axis=-1) diff --git a/tests/io/test_slp.py b/tests/io/test_slp.py index 8a90694e..766692d6 100644 --- a/tests/io/test_slp.py +++ b/tests/io/test_slp.py @@ -101,32 +101,26 @@ def test_read_videos_pkg(slp_minimal_pkg): def test_write_videos(slp_minimal_pkg, centered_pair, tmp_path): - def load_jsons(h5_path, dataset): - return [json.loads(x) for x in read_hdf5_dataset(h5_path, dataset)] - - def compare_jsons(jsons_ref, jsons_test): - for jsons_ref, jsons_test in zip(jsons_ref, jsons_test): - for k in jsons_ref["backend"]: - assert jsons_ref["backend"][k] == jsons_test["backend"][k] - - videos = read_videos(slp_minimal_pkg) - write_videos(tmp_path / "test_minimal_pkg.slp", videos) - json_fixture = load_jsons(slp_minimal_pkg, "videos_json") - json_test = load_jsons(tmp_path / "test_minimal_pkg.slp", "videos_json") - compare_jsons(json_fixture, json_test) - - videos = read_videos(centered_pair) - write_videos(tmp_path / "test_centered_pair.slp", videos) - json_fixture = load_jsons(centered_pair, "videos_json") - json_test = load_jsons(tmp_path / "test_centered_pair.slp", "videos_json") - compare_jsons(json_fixture, json_test) + def compare_videos(videos_ref, videos_test): + assert len(videos_ref) == len(videos_test) + for video_ref, video_test in zip(videos_ref, videos_test): + assert video_ref.shape == video_test.shape + assert (video_ref[0] == video_test[0]).all() + + videos_ref = read_videos(slp_minimal_pkg) + write_videos(tmp_path / "test_minimal_pkg.slp", videos_ref) + videos_test = read_videos(tmp_path / "test_minimal_pkg.slp") + compare_videos(videos_ref, videos_test) + + videos_ref = read_videos(centered_pair) + write_videos(tmp_path / "test_centered_pair.slp", videos_ref) + videos_test = read_videos(tmp_path / "test_centered_pair.slp") + compare_videos(videos_ref, videos_test) videos = read_videos(centered_pair) * 2 write_videos(tmp_path / "test_centered_pair_2vids.slp", videos) - json_test = read_hdf5_dataset( - tmp_path / "test_centered_pair_2vids.slp", "videos_json" - ) - assert len(json_test) == 2 + videos_test = read_videos(tmp_path / "test_centered_pair_2vids.slp") + compare_videos(videos, videos_test) def test_write_tracks(centered_pair, tmp_path):