Skip to content

Commit

Permalink
new save arg to .detect() to auto-append to file and save RAM
Browse files Browse the repository at this point in the history
  • Loading branch information
ejolly committed Oct 18, 2024
1 parent 693def0 commit e0db3d3
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 7 deletions.
38 changes: 34 additions & 4 deletions feat/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from torchvision.transforms import Compose, Normalize
import sys
import warnings
from pathlib import Path

sys.modules["__main__"].__dict__["XGBClassifier"] = XGBClassifier
sys.modules["__main__"].__dict__["SVMClassifier"] = SVMClassifier
Expand Down Expand Up @@ -516,10 +517,11 @@ def detect(
face_detection_threshold=0.5,
skip_frames=None,
progress_bar=True,
save=None,
**kwargs,
):
"""
Detects FEX from one or more image files.
Detects FEX from one or more imagathe files.
Args:
inputs (list of str, torch.Tensor): Path to a list of paths to image files or torch.Tensor of images (B, C, H, W)
Expand All @@ -533,11 +535,14 @@ def detect(
skip_frames (int or None): number of frames to skip to speed up inference (video only); Default None
progress_bar (bool): Whether to show the tqdm progress bar. Default is True.
**kwargs: additional detector-specific kwargs
save (None or str or Path): if immediately append detections to a csv file at with the given name after processing each batch, which can be useful to interrupted/resuming jobs and saving memory/RAM
Returns:
pd.DataFrame: Concatenated results for all images in the batch
"""

save = Path(save) if save else None

if data_type.lower() == "image":
data_loader = DataLoader(
ImageDataset(
Expand Down Expand Up @@ -660,14 +665,39 @@ def detect(
- batch_data["Padding"]["Top"].detach().numpy()[j]
) / batch_data["Scale"].detach().numpy()[j]

batch_output.append(batch_results)
if save:
batch_results.to_csv(save, mode="a", index=False, header=batch_id == 0)
else:
batch_output.append(batch_results)
frame_counter += 1 * batch_size
batch_output = pd.concat(batch_output)
batch_output.reset_index(drop=True, inplace=True)

batch_output = (
Fex(
pd.read_csv(save),
au_columns=AU_LANDMARK_MAP["Feat"],
emotion_columns=FEAT_EMOTION_COLUMNS,
facebox_columns=FEAT_FACEBOX_COLUMNS,
landmark_columns=openface_2d_landmark_columns,
facepose_columns=FEAT_FACEPOSE_COLUMNS_6D,
identity_columns=FEAT_IDENTITY_COLUMNS[1:],
detector="Feat",
face_model=self.info["face_model"],
landmark_model=self.info["landmark_model"],
au_model=self.info["au_model"],
emotion_model=self.info["emotion_model"],
facepose_model=self.info["facepose_model"],
identity_model=self.info["identity_model"],
)
if save
else pd.concat(batch_output).reset_index(drop=True)
)
if data_type.lower() == "video":
batch_output["approx_time"] = [
dataset.calc_approx_frame_time(x)
for x in batch_output["frame"].to_numpy()
]
batch_output.compute_identities(threshold=face_identity_threshold, inplace=True)
# Overwrite with approx_time and identity columns
if save:
batch_output.to_csv(save, mode="w", index=False)
return batch_output
22 changes: 22 additions & 0 deletions feat/tests/test_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from feat.utils.io import get_test_data_path
import warnings
import os
from feat.utils.io import read_feat

EXPECTED_FEX_WIDTH = 691

Expand Down Expand Up @@ -243,3 +244,24 @@ def test_fast_detect_video(
assert not out.happiness.iloc[4:7].isnull().all().all()
# ending doesn't
assert out.happiness.iloc[7:].isnull().all().all()

def test_save_detect(self, single_face_mov, tmp_path):
"""Test appending to file during detection"""

out = self.detector.detect(
single_face_mov, skip_frames=12, data_type="video", save=tmp_path / "test.csv"
)
assert (tmp_path / "test.csv").exists()
df = read_feat(tmp_path / "test.csv")
assert df.equals(out)
out_nosave = self.detector.detect(
single_face_mov, skip_frames=12, data_type="video"
)
# We're not enforcing dtypes so the output Fex from .detect()
# uses Float32, but read_feat() uses Float64
# So instead we check the numeric values are close, columns match,
# and shapes, match
np.allclose(out._get_numeric_data(), out_nosave._get_numeric_data())
np.allclose(df._get_numeric_data(), out_nosave._get_numeric_data())
assert all(df.columns == out_nosave.columns) and all(df.columns == out.columns)
assert df.shape == out_nosave.shape == out.shape
9 changes: 6 additions & 3 deletions feat/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
openface_facepose_columns,
openface_gaze_columns,
openface_time_columns,
FEAT_FACEPOSE_COLUMNS_6D,
FEAT_IDENTITY_COLUMNS,
)


Expand Down Expand Up @@ -97,11 +99,12 @@ def read_feat(fexfile):
filename=fexfile,
au_columns=au_columns,
emotion_columns=FEAT_EMOTION_COLUMNS,
landmark_columns=openface_2d_landmark_columns,
facebox_columns=FEAT_FACEBOX_COLUMNS,
time_columns=FEAT_TIME_COLUMNS,
facepose_columns=["Pitch", "Roll", "Yaw"],
landmark_columns=openface_2d_landmark_columns,
facepose_columns=FEAT_FACEPOSE_COLUMNS_6D,
identity_columns=FEAT_IDENTITY_COLUMNS[1:],
detector="Feat",
time_columns=FEAT_TIME_COLUMNS,
)
return fex

Expand Down

0 comments on commit e0db3d3

Please sign in to comment.