Skip to content

Commit

Permalink
fix issue in processing tensors and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
ejolly committed Oct 19, 2024
1 parent fd263db commit 24e062e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
22 changes: 19 additions & 3 deletions feat/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2547,7 +2547,13 @@ def __len__(self):

def __getitem__(self, idx):
    """Return the sample at ``idx`` as a dict of image data plus metadata.

    A bare ``return self.tensor[idx, ...]`` placed ahead of the dict would
    make the dict unreachable, so the dict is the only return here.

    Args:
        idx: index into the first dimension of ``self.tensor``.

    Returns:
        dict with the keys ``"Image"`` (``self.tensor[idx, ...]``),
        ``"Frame"`` (``idx`` itself), ``"FileName"`` (the constant string
        ``"tensor"``), ``"Scale"`` (``1.0``) and ``"Padding"`` (``0`` for
        ``Left``/``Top``/``Right``/``Bottom``).
    """
    # NOTE(review): Scale/Padding look like "no rescaling, no padding"
    # placeholders for whatever consumes these samples — confirm with caller.
    return {
        "Image": self.tensor[idx, ...],
        "Frame": idx,
        "FileName": "tensor",
        "Scale": 1.0,
        "Padding": {"Left": 0, "Top": 0, "Right": 0, "Bottom": 0},
    }


class VideoDataset(Dataset):
Expand All @@ -2560,7 +2566,8 @@ class VideoDataset(Dataset):
Dataset: dataset of [batch, channels, height, width] that can be passed to DataLoader
"""

def __init__(self, video_file, skip_frames=None, output_size=None):
def __init__(self, video_file, skip_frames=None, output_size=None, low_memory=True):
self.low_memory = low_memory
self.file_name = video_file
self.skip_frames = skip_frames
self.output_size = output_size
Expand All @@ -2569,14 +2576,23 @@ def __init__(self, video_file, skip_frames=None, output_size=None):
self.video_frames = np.arange(
0, self.metadata["num_frames"], 1 if skip_frames is None else skip_frames
)
if not self.low_memory:
self._container = av.open(self.file_name)
self._stream = self._container.streams.video[0]
self._frame_generator = self._container.decode(self._stream)

def __len__(self):
    """Return how many frames this dataset yields.

    This is the length of ``self.video_frames``, i.e. the frame indices
    that remain once ``skip_frames`` has been taken into account.
    """
    n_selected = len(self.video_frames)
    return n_selected

def __getitem__(self, idx):
# Get the frame data and the frame number, taking skip_frames into account
frame_data, frame_idx = self.load_frame(idx)
if self.low_memory:
frame_data, frame_idx = self.load_frame(idx)
else:
frame = next(self._frame_generator)
frame_data = torch.from_numpy(frame.to_ndarray(format="rgb24"))
frame_idx = int(self.video_frames[idx])

# Swap frame dims to match the output of read_image: [time, channels, height, width].
# Otherwise the detectors fail on a tensor dimension mismatch.
Expand Down
7 changes: 7 additions & 0 deletions feat/tests/test_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import warnings
import os
from feat.utils.io import read_feat
import torch

EXPECTED_FEX_WIDTH = 691

Expand Down Expand Up @@ -263,3 +264,9 @@ def test_save_detect(self, single_face_mov, tmp_path):
np.allclose(df._get_numeric_data(), out_nosave._get_numeric_data())
assert all(df.columns == out_nosave.columns) and all(df.columns == out.columns)
assert df.shape == out_nosave.shape == out.shape

def test_detect_tensor(self, single_face_img_data):
    """Check that ``detect`` accepts a stacked tensor via ``data_type="tensor"``.

    The ``single_face_img_data`` fixture is stacked ``n_frames`` times with
    ``torch.stack`` and the result must contain one row per stacked item.

    Args:
        single_face_img_data: fixture defined outside this block; presumably
            a single-face image tensor — confirm against the conftest.
    """
    n_frames = 3
    tensor = torch.stack([single_face_img_data] * n_frames)
    fex = self.detector.detect(tensor, data_type="tensor")
    # One output row per stacked item, with the expected column count.
    assert fex.shape == (n_frames, EXPECTED_FEX_WIDTH)

0 comments on commit 24e062e

Please sign in to comment.