Skip to content

Commit

Permalink
Merge pull request #110 from datajoint/staging
Browse files Browse the repository at this point in the history
Bugfix - store `snapshotindex` for the latest snapshot instead of the filename
  • Loading branch information
MilagrosMarin authored Jul 3, 2024
2 parents e6944a9 + 7949f61 commit 1cf8347
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 12 deletions.
9 changes: 8 additions & 1 deletion element_deeplabcut/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,14 @@ def make(self, key):
task_mode, output_dir = (PoseEstimationTask & key).fetch1(
"task_mode", "pose_estimation_output_dir"
)

if not output_dir:
output_dir = PoseEstimationTask.infer_output_dir(
key, relative=True, mkdir=True
)
# update pose_estimation_output_dir
PoseEstimationTask.update1(
{**key, "pose_estimation_output_dir": output_dir.as_posix()}
)
output_dir = find_full_path(get_dlc_root_data_dir(), output_dir)

# Triger PoseEstimation
Expand Down
53 changes: 42 additions & 11 deletions element_deeplabcut/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
import datajoint as dj
import inspect
import importlib
import os
import re
from pathlib import Path
import yaml

from element_interface.utils import find_full_path, dict_to_uuid
from .readers import dlc_reader

Expand Down Expand Up @@ -241,7 +243,7 @@ class ModelTraining(dj.Computed):
# https://github.com/DeepLabCut/DeepLabCut/issues/70

def make(self, key):
from deeplabcut import train_network # isort:skip
import deeplabcut

try:
from deeplabcut.utils.auxiliaryfunctions import (
Expand Down Expand Up @@ -288,13 +290,39 @@ def make(self, key):
)
model_train_folder = project_path / model_folder / "train"

# update path of the init_weight
with open(model_train_folder / "pose_cfg.yaml", "r") as f:
pose_cfg = yaml.safe_load(f)
init_weights_path = Path(pose_cfg["init_weights"])

if (
"pose_estimation_tensorflow/models/pretrained"
in init_weights_path.as_posix()
):
# this is the res_net models, construct new path here
init_weights_path = (
Path(deeplabcut.__path__[0])
/ "pose_estimation_tensorflow/models/pretrained"
/ init_weights_path.name
)
else:
# this is existing snapshot weights, update path here
init_weights_path = model_train_folder / init_weights_path.name

edit_config(
model_train_folder / "pose_cfg.yaml",
{"project_path": project_path.as_posix()},
{
"project_path": project_path.as_posix(),
"init_weights": init_weights_path.as_posix(),
"dataset": Path(pose_cfg["dataset"]).as_posix(),
"metadataset": Path(pose_cfg["metadataset"]).as_posix(),
},
)

# ---- Trigger DLC model training job ----
train_network_input_args = list(inspect.signature(train_network).parameters)
train_network_input_args = list(
inspect.signature(deeplabcut.train_network).parameters
)
train_network_kwargs = {
k: int(v) if k in ("shuffle", "trainingsetindex", "maxiters") else v
for k, v in dlc_config.items()
Expand All @@ -304,25 +332,28 @@ def make(self, key):
train_network_kwargs[k] = int(train_network_kwargs[k])

try:
train_network(dlc_cfg_filepath, **train_network_kwargs)
deeplabcut.train_network(dlc_cfg_filepath, **train_network_kwargs)
except KeyboardInterrupt: # Instructions indicate to train until interrupt
print("DLC training stopped via Keyboard Interrupt")

snapshots = list(model_train_folder.glob("*index*"))
max_modified_time = 0
# DLC goes by snapshot magnitude when judging 'latest' for evaluation
# Here, we mean most recently generated
snapshots = sorted(model_train_folder.glob("snapshot*.index"))
max_modified_time = 0
for snapshot in snapshots:
modified_time = os.path.getmtime(snapshot)
modified_time = snapshot.stat().st_mtime
if modified_time > max_modified_time:
latest_snapshot = int(snapshot.stem[9:])
latest_snapshot_file = snapshot
latest_snapshot = int(re.search(r"(\d+)\.index", latest_snapshot_file.name).group(1))
max_modified_time = modified_time

# update snapshotindex in the config
dlc_config["snapshotindex"] = latest_snapshot
snapshotindex = snapshots.index(latest_snapshot_file)

dlc_config["snapshotindex"] = snapshotindex
edit_config(
dlc_cfg_filepath,
{"snapshotindex": latest_snapshot},
{"snapshotindex": snapshotindex},
)

self.insert1(
Expand Down

0 comments on commit 1cf8347

Please sign in to comment.