Skip to content

Commit

Permalink
Merge pull request #120 from ttngu207/main
Browse files Browse the repository at this point in the history
feat(pose_estimation): use `memoized_results`
  • Loading branch information
kushalbakshi authored Aug 8, 2024
2 parents 34a6edb + c2eb090 commit 2833f15
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 14 deletions.
88 changes: 75 additions & 13 deletions element_deeplabcut/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
import pandas as pd
from pathlib import Path
from typing import Optional
from datetime import datetime
from element_interface.utils import find_full_path, find_root_directory
from datetime import datetime, timezone
from element_interface.utils import find_full_path, find_root_directory, memoized_result
from .readers import dlc_reader

schema = dj.schema()
Expand Down Expand Up @@ -705,7 +705,7 @@ class BodyPartPosition(dj.Part):
def make(self, key):
""".populate() method will launch training for each PoseEstimationTask"""
# ID model and directories
dlc_model = (Model & key).fetch1()
dlc_model_ = (Model & key).fetch1()
task_mode, output_dir = (PoseEstimationTask & key).fetch1(
"task_mode", "pose_estimation_output_dir"
)
Expand All @@ -719,31 +719,93 @@ def make(self, key):
)
output_dir = find_full_path(get_dlc_root_data_dir(), output_dir)

# Triger PoseEstimation
# Trigger PoseEstimation
if task_mode == "trigger":
# Triggering dlc for pose estimation required:
# - project_path: full path to the directory containing the trained model
# - video_filepaths: full paths to the video files for inference
# - analyze_video_params: optional parameters to analyze video
project_path = find_full_path(
get_dlc_root_data_dir(), dlc_model["project_path"]
get_dlc_root_data_dir(), dlc_model_["project_path"]
)
video_relpaths = list((VideoRecording.File & key).fetch("file_path"))
video_filepaths = [
find_full_path(get_dlc_root_data_dir(), fp).as_posix()
for fp in (VideoRecording.File & key).fetch("file_path")
for fp in video_relpaths
]
analyze_video_params = (PoseEstimationTask & key).fetch1(
"pose_estimation_params"
) or {}

dlc_reader.do_pose_estimation(
key,
video_filepaths,
dlc_model,
project_path,
output_dir,
**analyze_video_params,
@memoized_result(
uniqueness_dict={
**analyze_video_params,
"project_path": dlc_model_["project_path"],
"shuffle": dlc_model_["shuffle"],
"trainingsetindex": dlc_model_["trainingsetindex"],
"video_filepaths": video_relpaths,
},
output_directory=output_dir,
)
def do_analyze_videos():
from deeplabcut.pose_estimation_tensorflow import analyze_videos

# ---- Build and save DLC configuration (yaml) file ----
dlc_config = dlc_model_["config_template"]
dlc_project_path = Path(project_path)
dlc_config["project_path"] = dlc_project_path.as_posix()

# ---- Special handling for "cropping" ----
# `analyze_videos` behavior:
# i) if is None, use the "cropping" from the config file
# ii) if defined, use the specified "cropping" values but not updating the config file
# new behavior: if defined as "False", overwrite "cropping" to False in config file
cropping = analyze_video_params.get("cropping", None)
if cropping is not None:
if cropping:
dlc_config["cropping"] = True
(
dlc_config["x1"],
dlc_config["x2"],
dlc_config["y1"],
dlc_config["y2"],
) = cropping
else: # cropping is False
dlc_config["cropping"] = False

# ---- Write config files ----
config_filename = f"dj_dlc_config_{datetime.now(tz=timezone.utc).strftime('%Y%m%d_%H%M%S')}.yaml"
# To output dir: Important for loading/parsing output in datajoint
_ = dlc_reader.save_yaml(
output_dir, dlc_config, filename=config_filename
)
# To project dir: Required by DLC to run the analyze_videos
if dlc_project_path != output_dir:
config_filepath = dlc_reader.save_yaml(
dlc_project_path,
dlc_config,
filename=config_filename,
)

# ---- Take valid parameters for analyze_videos ----
kwargs = {
k: v
for k, v in analyze_video_params.items()
if k in inspect.signature(analyze_videos).parameters
}

# ---- Trigger DLC prediction job ----
analyze_videos(
config=config_filepath,
videos=video_filepaths,
shuffle=dlc_model_["shuffle"],
trainingsetindex=dlc_model_["trainingsetindex"],
destfolder=output_dir,
modelprefix=dlc_model_["model_prefix"],
**kwargs,
)

do_analyze_videos()

dlc_result = dlc_reader.PoseEstimation(output_dir)
creation_time = datetime.fromtimestamp(dlc_result.creation_time).strftime(
Expand Down
7 changes: 7 additions & 0 deletions element_deeplabcut/readers/dlc_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,12 @@ def do_pose_estimation(
resulting in constant memory footprint.
"""
# this function should no longer be used, throw a deprecation warning
logger.warning(
"This function is deprecated and will be removed in a future release. "
+ "Its usage is now incorporated into model.PoseEstimation's `make` function"
)

from deeplabcut.pose_estimation_tensorflow import analyze_videos

# ---- Build and save DLC configuration (yaml) file ----
Expand All @@ -332,6 +338,7 @@ def do_pose_estimation(
dlc_config["project_path"] = dlc_project_path.as_posix()

# ---- Add current video to config ---
# FIXME: I don't think the code block below is necessary
for video_filepath in video_filepaths:
if video_filepath not in dlc_config["video_sets"]:
try:
Expand Down
2 changes: 1 addition & 1 deletion element_deeplabcut/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
Package metadata
"""

__version__ = "0.2.14"
__version__ = "0.3.0"

0 comments on commit 2833f15

Please sign in to comment.