diff --git a/element_deeplabcut/model.py b/element_deeplabcut/model.py index a57ecbb..e6935c9 100644 --- a/element_deeplabcut/model.py +++ b/element_deeplabcut/model.py @@ -708,7 +708,14 @@ def make(self, key): task_mode, output_dir = (PoseEstimationTask & key).fetch1( "task_mode", "pose_estimation_output_dir" ) - + if not output_dir: + output_dir = PoseEstimationTask.infer_output_dir( + key, relative=True, mkdir=True + ) + # update pose_estimation_output_dir + PoseEstimationTask.update1( + {**key, "pose_estimation_output_dir": output_dir.as_posix()} + ) output_dir = find_full_path(get_dlc_root_data_dir(), output_dir) # Triger PoseEstimation diff --git a/element_deeplabcut/train.py b/element_deeplabcut/train.py index b4f2765..87e88ad 100644 --- a/element_deeplabcut/train.py +++ b/element_deeplabcut/train.py @@ -7,8 +7,10 @@ import datajoint as dj import inspect import importlib -import os +import re from pathlib import Path +import yaml + from element_interface.utils import find_full_path, dict_to_uuid from .readers import dlc_reader @@ -241,7 +243,7 @@ class ModelTraining(dj.Computed): # https://github.com/DeepLabCut/DeepLabCut/issues/70 def make(self, key): - from deeplabcut import train_network # isort:skip + import deeplabcut try: from deeplabcut.utils.auxiliaryfunctions import ( @@ -288,13 +290,39 @@ def make(self, key): ) model_train_folder = project_path / model_folder / "train" + # update path of the init_weight + with open(model_train_folder / "pose_cfg.yaml", "r") as f: + pose_cfg = yaml.safe_load(f) + init_weights_path = Path(pose_cfg["init_weights"]) + + if ( + "pose_estimation_tensorflow/models/pretrained" + in init_weights_path.as_posix() + ): + # this is the res_net models, construct new path here + init_weights_path = ( + Path(deeplabcut.__path__[0]) + / "pose_estimation_tensorflow/models/pretrained" + / init_weights_path.name + ) + else: + # this is existing snapshot weights, update path here + init_weights_path = model_train_folder / init_weights_path.name + edit_config( model_train_folder / "pose_cfg.yaml", - {"project_path": project_path.as_posix()}, + { + "project_path": project_path.as_posix(), + "init_weights": init_weights_path.as_posix(), + "dataset": Path(pose_cfg["dataset"]).as_posix(), + "metadataset": Path(pose_cfg["metadataset"]).as_posix(), + }, ) # ---- Trigger DLC model training job ---- - train_network_input_args = list(inspect.signature(train_network).parameters) + train_network_input_args = list( + inspect.signature(deeplabcut.train_network).parameters + ) train_network_kwargs = { k: int(v) if k in ("shuffle", "trainingsetindex", "maxiters") else v for k, v in dlc_config.items() @@ -304,25 +332,28 @@ def make(self, key): train_network_kwargs[k] = int(train_network_kwargs[k]) try: - train_network(dlc_cfg_filepath, **train_network_kwargs) + deeplabcut.train_network(dlc_cfg_filepath, **train_network_kwargs) except KeyboardInterrupt: # Instructions indicate to train until interrupt print("DLC training stopped via Keyboard Interrupt") - snapshots = list(model_train_folder.glob("*index*")) - max_modified_time = 0 # DLC goes by snapshot magnitude when judging 'latest' for evaluation # Here, we mean most recently generated + snapshots = sorted(model_train_folder.glob("snapshot*.index")) + max_modified_time = 0 for snapshot in snapshots: - modified_time = os.path.getmtime(snapshot) + modified_time = snapshot.stat().st_mtime if modified_time > max_modified_time: - latest_snapshot = int(snapshot.stem[9:]) + latest_snapshot_file = snapshot + latest_snapshot = int(re.search(r"(\d+)\.index", latest_snapshot_file.name).group(1)) max_modified_time = modified_time # update snapshotindex in the config - dlc_config["snapshotindex"] = latest_snapshot + snapshotindex = snapshots.index(latest_snapshot_file) + + dlc_config["snapshotindex"] = snapshotindex edit_config( dlc_cfg_filepath, - {"snapshotindex": latest_snapshot}, + {"snapshotindex": snapshotindex}, ) self.insert1(