Skip to content

Commit

Permalink
Merge pull request #121 from ttngu207/main
Browse files Browse the repository at this point in the history
add new table `LabeledVideo` to generate/store labeled video data after PoseEstimation
  • Loading branch information
kushalbakshi authored Aug 16, 2024
2 parents 2833f15 + 8fff7b0 commit 117694e
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 2 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@
Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
[Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.


## [0.3.1] - 2024-08-16

+ Add - add new table `LabeledVideo` to generate/store labeled video data after PoseEstimation

## [0.3.0] - 2024-08-08

+ Add - add support for inference (PoseEstimation) using pytorch model

## [0.2.14] - 2024-08-02

+ Fix - improve imports, avoid circular dependencies
Expand Down
106 changes: 105 additions & 1 deletion element_deeplabcut/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from .readers import dlc_reader

schema = dj.schema()
logger = dj.logger

_linking_module = None


Expand Down Expand Up @@ -733,10 +735,16 @@ def make(self, key):
find_full_path(get_dlc_root_data_dir(), fp).as_posix()
for fp in video_relpaths
]
analyze_video_params = (PoseEstimationTask & key).fetch1(
pose_estimation_params = (PoseEstimationTask & key).fetch1(
"pose_estimation_params"
) or {}

# expect a nested dictionary with "analyze_videos" params
# if not, assume "pose_estimation_params" as a flat dictionary that include relevant "analyze_videos" params
analyze_video_params = (
pose_estimation_params.get("analyze_videos") or pose_estimation_params
)

@memoized_result(
uniqueness_dict={
**analyze_video_params,
Expand Down Expand Up @@ -867,6 +875,102 @@ def get_trajectory(cls, key: dict, body_parts: list = "all") -> pd.DataFrame:
return df


@schema
class LabeledVideo(dj.Computed):
definition = """
-> PoseEstimation
"""

class File(dj.Part):
definition = """
-> master
-> VideoRecording.File
---
labeled_video_path: varchar(255) # relative path to labeled video
"""

@property
def key_source(self):
return PoseEstimation & RecordingInfo

def make(self, key):
import deeplabcut

pose_estimation_params = (PoseEstimationTask & key).fetch1(
"pose_estimation_params"
) or {}

# expect a nested dictionary with "create_labeled_video" and "extract_outlier_frames" params
# if not, assume "pose_estimation_params" as a flat dictionary
create_labeled_video_params = (
pose_estimation_params.get("create_labeled_video") or pose_estimation_params
)

outputframerate = create_labeled_video_params.pop(
"outputframerate", 5
) # final labeled video FPS defaults to 5 Hz

dlc_model_ = (Model & key).fetch1()
fps, nframes = (RecordingInfo & key).fetch1("fps", "nframes")
output_dir = (PoseEstimationTask & key).fetch1("pose_estimation_output_dir")
output_dir = find_full_path(get_dlc_root_data_dir(), output_dir)

project_path = find_full_path(
get_dlc_root_data_dir(), dlc_model_["project_path"]
)

try:
dlc_config = next(output_dir.glob("dj_dlc_config*.yaml"))
dlc_config = project_path / dlc_config.name
assert dlc_config.exists()
except (StopIteration, AssertionError):
dlc_config = next(project_path.glob("dj_dlc_config*.yaml"))
logger.warning(
f"No dj_dlc_config*.yaml file found in {output_dir} - this is unexpected.\nUsing {dlc_config}"
)

entries = []
for vkey in (VideoRecording.File & key).fetch("KEY"):
video_file = (VideoRecording.File & vkey).fetch1("file_path")
video_file = find_full_path(get_dlc_root_data_dir(), video_file)

# -- create labeled video --
create_labeled_video_kwargs = {
k: v
for k, v in create_labeled_video_params.items()
if k in inspect.signature(deeplabcut.create_labeled_video).parameters
}
create_labeled_video_kwargs.update(
dict(
config=dlc_config.as_posix(),
videos=[video_file.as_posix()],
shuffle=dlc_model_["shuffle"],
trainingsetindex=dlc_model_["trainingsetindex"],
modelprefix=dlc_model_["model_prefix"],
destfolder=output_dir.as_posix(),
Frames2plot=np.arange(0, nframes, int(fps / outputframerate)),
outputframerate=outputframerate,
)
)
deeplabcut.create_labeled_video(**create_labeled_video_kwargs)

labeled_video_path = next(
output_dir.glob(f"{video_file.stem}*_labeled.mp4")
)
entries.append(
{
**key,
**vkey,
"labeled_video_path": labeled_video_path.relative_to(
get_dlc_processed_data_dir()
).as_posix(),
}
)

self.insert1(key)
self.File.insert(entries)


def str_to_bool(value) -> bool:
"""Return whether the provided string represents true. Otherwise false.
Expand Down
2 changes: 1 addition & 1 deletion element_deeplabcut/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
Package metadata
"""

__version__ = "0.3.0"
__version__ = "0.3.1"

0 comments on commit 117694e

Please sign in to comment.